diff --git a/Lib/asyncio/taskgroups.py b/Lib/asyncio/taskgroups.py index 45dfebc65904fce..6e678e40e7727ac 100644 --- a/Lib/asyncio/taskgroups.py +++ b/Lib/asyncio/taskgroups.py @@ -30,6 +30,7 @@ def __init__(self): self._entered = False self._exiting = False self._aborting = False + self._explicitly_cancelled = False self._loop = None self._parent_task = None self._parent_cancel_requested = False @@ -196,7 +197,7 @@ def create_task(self, coro, **kwargs): if self._exiting and not self._tasks: coro.close() raise RuntimeError(f"TaskGroup {self!r} is finished") - if self._aborting: + if self._aborting and not self._explicitly_cancelled: coro.close() raise RuntimeError(f"TaskGroup {self!r} is shutting down") task = self._loop.create_task(coro, **kwargs) @@ -209,6 +210,12 @@ def create_task(self, coro, **kwargs): # the current task too early. gh-128550, gh-128588 self._tasks.add(task) task.add_done_callback(self._on_task_done) + if self._aborting and self._explicitly_cancelled: + def _cancel_later(task=task): + if not task.done(): + task.cancel() + self._loop.call_soon(_cancel_later) + try: return task finally: @@ -307,6 +314,7 @@ def cancel(self): if self._exiting and not self._tasks: return if not self._aborting: + self._explicitly_cancelled = True self._abort() if self._parent_task and not self._parent_cancel_requested: self._parent_cancel_requested = True diff --git a/Lib/test/test_asyncio/test_taskgroups.py b/Lib/test/test_asyncio/test_taskgroups.py index 8925884b9dcf731..3f6632c9c8834b6 100644 --- a/Lib/test/test_asyncio/test_taskgroups.py +++ b/Lib/test/test_asyncio/test_taskgroups.py @@ -1149,10 +1149,9 @@ async def test_taskgroup_cancel_before_enter(self): async def test_taskgroup_cancel_before_create_task(self): async with asyncio.TaskGroup() as tg: tg.cancel() - # TODO: This behavior is not ideal. We'd rather have no exception - # raised, and the child task run until the first await. - with self.assertRaises(RuntimeError): - tg.create_task(asyncio.sleep(1)) + t = tg.create_task(asyncio.sleep(1)) + await asyncio.sleep(0) + self.assertTrue(t.cancelled()) async def test_taskgroup_cancel_before_exception(self): async def raise_exc(parent_tg: asyncio.TaskGroup):