diff --git a/Lib/asyncio/taskgroups.py b/Lib/asyncio/taskgroups.py index 6af21f3a15d93a..7cda05f218c524 100644 --- a/Lib/asyncio/taskgroups.py +++ b/Lib/asyncio/taskgroups.py @@ -8,6 +8,12 @@ from . import events from . import exceptions from . import tasks +from .futures import Future + + +class PyCoroEagerResult: + def __init__(self, value): + self.value = value class TaskGroup: @@ -153,6 +159,22 @@ def create_task(self, coro, *, name=None, context=None): self._tasks.add(task) return task + def enqueue(self, coro, no_future=True): + if not self._entered: + raise RuntimeError(f"TaskGroup {self!r} has not been entered") + + try: + fut = coro.send(None) + task = self.create_task(coro) + task._set_fut_awaiter(fut) + return task + except StopIteration as e: + # The co-routine has completed synchronously and we've got + # our result. + res = Future(loop=self._loop) + res.set_result(e.args[0] if e.args else None) + return res + # Since Python 3.8 Tasks propagate all exceptions correctly, # except for KeyboardInterrupt and SystemExit which are # still considered special. diff --git a/Lib/asyncio/tasks.py b/Lib/asyncio/tasks.py index 3952b5f2a7743d..9ad8c88fd6772f 100644 --- a/Lib/asyncio/tasks.py +++ b/Lib/asyncio/tasks.py @@ -117,7 +117,9 @@ def __init__(self, coro, *, loop=None, name=None, context=None): else: self._context = context - self._loop.call_soon(self.__step, context=self._context) + if not getattr(coro, "cr_suspended", False): + self._loop.call_soon(self.__step, context=self._context) + _register_task(self) def __del__(self): @@ -293,55 +295,58 @@ def __step(self, exc=None): except BaseException as exc: super().set_exception(exc) else: - blocking = getattr(result, '_asyncio_future_blocking', None) - if blocking is not None: - # Yielded Future must come from Future.__iter__(). - if not self._check_future(result): - new_exc = RuntimeError( - f'Task {self!r} got Future ' - f'{result!r} attached to a different loop') - self._loop.call_soon( - self.__step, new_exc, context=self._context) - elif blocking: - if result is self: - new_exc = RuntimeError( - f'Task cannot await on itself: {self!r}') - self._loop.call_soon( - self.__step, new_exc, context=self._context) - else: - result._asyncio_future_blocking = False - result.add_done_callback( - self.__wakeup, context=self._context) - self._fut_waiter = result - if self._must_cancel: - if self._fut_waiter.cancel( - msg=self._cancel_message): - self._must_cancel = False - else: - new_exc = RuntimeError( - f'yield was used instead of yield from ' - f'in task {self!r} with {result!r}') - self._loop.call_soon( - self.__step, new_exc, context=self._context) + self._set_fut_awaiter(result) + finally: + _leave_task(self._loop, self) + self = None # Needed to break cycles when an exception occurs. - elif result is None: - # Bare yield relinquishes control for one event loop iteration. - self._loop.call_soon(self.__step, context=self._context) - elif inspect.isgenerator(result): - # Yielding a generator is just wrong. + def _set_fut_awaiter(self, result): + blocking = getattr(result, '_asyncio_future_blocking', None) + if blocking is not None: + # Yielded Future must come from Future.__iter__(). + if not self._check_future(result): new_exc = RuntimeError( - f'yield was used instead of yield from for ' - f'generator in task {self!r} with {result!r}') + f'Task {self!r} got Future ' + f'{result!r} attached to a different loop') self._loop.call_soon( self.__step, new_exc, context=self._context) + elif blocking: + if result is self: + new_exc = RuntimeError( + f'Task cannot await on itself: {self!r}') + self._loop.call_soon( + self.__step, new_exc, context=self._context) + else: + result._asyncio_future_blocking = False + result.add_done_callback( + self.__wakeup, context=self._context) + self._fut_waiter = result + if self._must_cancel: + if self._fut_waiter.cancel( + msg=self._cancel_message): + self._must_cancel = False else: - # Yielding something else is an error. - new_exc = RuntimeError(f'Task got bad yield: {result!r}') + new_exc = RuntimeError( + f'yield was used instead of yield from ' + f'in task {self!r} with {result!r}') self._loop.call_soon( self.__step, new_exc, context=self._context) - finally: - _leave_task(self._loop, self) - self = None # Needed to break cycles when an exception occurs. + + elif result is None: + # Bare yield relinquishes control for one event loop iteration. + self._loop.call_soon(self.__step, context=self._context) + elif inspect.isgenerator(result): + # Yielding a generator is just wrong. + new_exc = RuntimeError( + f'yield was used instead of yield from for ' + f'generator in task {self!r} with {result!r}') + self._loop.call_soon( + self.__step, new_exc, context=self._context) + else: + # Yielding something else is an error. + new_exc = RuntimeError(f'Task got bad yield: {result!r}') + self._loop.call_soon( + self.__step, new_exc, context=self._context) def __wakeup(self, future): try: @@ -369,8 +374,8 @@ def __wakeup(self, future): pass else: # _CTask is needed for tests. - Task = _CTask = _asyncio.Task - + #Task = _CTask = _asyncio.Task + pass def create_task(coro, *, name=None, context=None): """Schedule the execution of a coroutine object in a spawn task. diff --git a/Lib/test/test_asyncio/test_taskgroups.py b/Lib/test/test_asyncio/test_taskgroups.py index 69369a6100a8fd..3517c0775f0a5e 100644 --- a/Lib/test/test_asyncio/test_taskgroups.py +++ b/Lib/test/test_asyncio/test_taskgroups.py @@ -722,6 +722,111 @@ async def coro(val): await t2 self.assertEqual(2, ctx.get(cvar)) + async def test_taskgroup_enqueue_01(self): + + async def foo1(): + await asyncio.sleep(0.1) + return 42 + + async def eager(): + return 11 + + async with taskgroups.TaskGroup() as g: + t1 = g.enqueue(foo1(), no_future=False) + t2 = g.enqueue(eager(), no_future=False) + + self.assertEqual(t1.result(), 42) + self.assertEqual(t2.result(), 11) + + async def test_taskgroup_enqueue_02(self): + + async def eager1(): + return 42 + + async def eager2(): + return 11 + + async with taskgroups.TaskGroup() as g: + t1 = g.enqueue(eager1(), no_future=False) + t2 = g.enqueue(eager2(), no_future=False) + + self.assertEqual(t1.result(), 42) + self.assertEqual(t2.result(), 11) + + async def test_taskgroup_enqueue_exception(self): + async def foo1(): + 1 / 0 + + with self.assertRaises(ExceptionGroup) as cm: + async with taskgroups.TaskGroup() as g: + g.enqueue(foo1()) + + self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError}) + + async def test_taskgroup_fanout_task(self): + async def step(i): + if i == 0: + return + async with taskgroups.TaskGroup() as g: + for _ in range(6): + g.create_task(step(i - 1)) + + import time + s = time.perf_counter() + await step(6) + e = time.perf_counter() + print(e-s) + + async def test_taskgroup_fanout_enqueue(self): + async def step(i): + if i == 0: + return + async with taskgroups.TaskGroup() as g: + for _ in range(6): + g.enqueue(step(i - 1)) + + import time + s = time.perf_counter() + await step(6) + e = time.perf_counter() + print(e-s) + + async def test_taskgroup_fanout_enqueue_02(self): + async def intermediate2(i): + return await intermediate(i) + + async def intermediate(i): + async with taskgroups.TaskGroup() as g: + for _ in range(6): + g.enqueue(step(i - 1)) + + async def step(i): + if i == 0: + return + + return await intermediate2(i) + + + import time + s = time.perf_counter() + await step(6) + e = time.perf_counter() + print(e-s) + + + async def test_taskgroup_fanout_enqueue_future(self): + async def step(i): + if i == 0: + return + async with taskgroups.TaskGroup() as g: + for _ in range(6): + g.enqueue(step(i - 1), no_future=False) + + import time + s = time.perf_counter() + await step(6) + e = time.perf_counter() + print(e-s) if __name__ == "__main__": unittest.main()