|
1 | 1 | # Adapted with permission from the EdgeDB project;
|
2 | 2 | # license: PSFL.
|
3 | 3 |
|
| 4 | +import weakref |
4 | 5 | import sys
|
5 | 6 | import gc
|
6 | 7 | import asyncio
|
@@ -38,7 +39,25 @@ def no_other_refs():
|
38 | 39 | return [coro]
|
39 | 40 |
|
40 | 41 |
|
41 |
| -class TestTaskGroup(unittest.IsolatedAsyncioTestCase): |
| 42 | +def set_gc_state(enabled): |
| 43 | + was_enabled = gc.isenabled() |
| 44 | + if enabled: |
| 45 | + gc.enable() |
| 46 | + else: |
| 47 | + gc.disable() |
| 48 | + return was_enabled |
| 49 | + |
| 50 | + |
| 51 | +@contextlib.contextmanager |
| 52 | +def disable_gc(): |
| 53 | + was_enabled = set_gc_state(enabled=False) |
| 54 | + try: |
| 55 | + yield |
| 56 | + finally: |
| 57 | + set_gc_state(enabled=was_enabled) |
| 58 | + |
| 59 | + |
| 60 | +class BaseTestTaskGroup: |
42 | 61 |
|
43 | 62 | async def test_taskgroup_01(self):
|
44 | 63 |
|
@@ -832,15 +851,15 @@ async def test_taskgroup_without_parent_task(self):
|
832 | 851 | with self.assertRaisesRegex(RuntimeError, "has not been entered"):
|
833 | 852 | tg.create_task(coro)
|
834 | 853 |
|
835 |
| - def test_coro_closed_when_tg_closed(self): |
| 854 | + async def test_coro_closed_when_tg_closed(self): |
836 | 855 | async def run_coro_after_tg_closes():
|
837 | 856 | async with taskgroups.TaskGroup() as tg:
|
838 | 857 | pass
|
839 | 858 | coro = asyncio.sleep(0)
|
840 | 859 | with self.assertRaisesRegex(RuntimeError, "is finished"):
|
841 | 860 | tg.create_task(coro)
|
842 |
| - loop = asyncio.get_event_loop() |
843 |
| - loop.run_until_complete(run_coro_after_tg_closes()) |
| 861 | + |
| 862 | + await run_coro_after_tg_closes() |
844 | 863 |
|
845 | 864 | async def test_cancelling_level_preserved(self):
|
846 | 865 | async def raise_after(t, e):
|
@@ -965,6 +984,30 @@ async def coro_fn():
|
965 | 984 | self.assertIsInstance(exc, _Done)
|
966 | 985 | self.assertListEqual(gc.get_referrers(exc), no_other_refs())
|
967 | 986 |
|
| 987 | + |
| 988 | + async def test_exception_refcycles_parent_task_wr(self): |
| 989 | + """Test that TaskGroup deletes self._parent_task and create_task() deletes task""" |
| 990 | + tg = asyncio.TaskGroup() |
| 991 | + exc = None |
| 992 | + |
| 993 | + class _Done(Exception): |
| 994 | + pass |
| 995 | + |
| 996 | + async def coro_fn(): |
| 997 | + async with tg: |
| 998 | + raise _Done |
| 999 | + |
| 1000 | + with disable_gc(): |
| 1001 | + try: |
| 1002 | + async with asyncio.TaskGroup() as tg2: |
| 1003 | + task_wr = weakref.ref(tg2.create_task(coro_fn())) |
| 1004 | + except* _Done as excs: |
| 1005 | + exc = excs.exceptions[0].exceptions[0] |
| 1006 | + |
| 1007 | + self.assertIsNone(task_wr()) |
| 1008 | + self.assertIsInstance(exc, _Done) |
| 1009 | + self.assertListEqual(gc.get_referrers(exc), no_other_refs()) |
| 1010 | + |
968 | 1011 | async def test_exception_refcycles_propagate_cancellation_error(self):
|
969 | 1012 | """Test that TaskGroup deletes propagate_cancellation_error"""
|
970 | 1013 | tg = asyncio.TaskGroup()
|
@@ -998,5 +1041,16 @@ class MyKeyboardInterrupt(KeyboardInterrupt):
|
998 | 1041 | self.assertListEqual(gc.get_referrers(exc), no_other_refs())
|
999 | 1042 |
|
1000 | 1043 |
|
| 1044 | +class TestTaskGroup(BaseTestTaskGroup, unittest.IsolatedAsyncioTestCase): |
| 1045 | + loop_factory = asyncio.EventLoop |
| 1046 | + |
| 1047 | +class TestEagerTaskTaskGroup(BaseTestTaskGroup, unittest.IsolatedAsyncioTestCase): |
| 1048 | + @staticmethod |
| 1049 | + def loop_factory(): |
| 1050 | + loop = asyncio.EventLoop() |
| 1051 | + loop.set_task_factory(asyncio.eager_task_factory) |
| 1052 | + return loop |
| 1053 | + |
| 1054 | + |
1001 | 1055 | if __name__ == "__main__":
|
1002 | 1056 | unittest.main()
|
0 commit comments