1
1
# Adapted with permission from the EdgeDB project;
2
2
# license: PSFL.
3
3
4
+ import weakref
5
+ import sys
4
6
import gc
5
7
import asyncio
6
8
import contextvars
@@ -27,7 +29,25 @@ def get_error_types(eg):
27
29
return {type (exc ) for exc in eg .exceptions }
28
30
29
31
30
- class TestTaskGroup (unittest .IsolatedAsyncioTestCase ):
32
+ def set_gc_state (enabled ):
33
+ was_enabled = gc .isenabled ()
34
+ if enabled :
35
+ gc .enable ()
36
+ else :
37
+ gc .disable ()
38
+ return was_enabled
39
+
40
+
41
+ @contextlib .contextmanager
42
+ def disable_gc ():
43
+ was_enabled = set_gc_state (enabled = False )
44
+ try :
45
+ yield
46
+ finally :
47
+ set_gc_state (enabled = was_enabled )
48
+
49
+
50
+ class BaseTestTaskGroup :
31
51
32
52
async def test_taskgroup_01 (self ):
33
53
@@ -880,6 +900,30 @@ async def coro_fn():
880
900
self .assertIsInstance (exc , _Done )
881
901
self .assertListEqual (gc .get_referrers (exc ), [])
882
902
903
+
904
+ async def test_exception_refcycles_parent_task_wr (self ):
905
+ """Test that TaskGroup deletes self._parent_task and create_task() deletes task"""
906
+ tg = asyncio .TaskGroup ()
907
+ exc = None
908
+
909
+ class _Done (Exception ):
910
+ pass
911
+
912
+ async def coro_fn ():
913
+ async with tg :
914
+ raise _Done
915
+
916
+ with disable_gc ():
917
+ try :
918
+ async with asyncio .TaskGroup () as tg2 :
919
+ task_wr = weakref .ref (tg2 .create_task (coro_fn ()))
920
+ except* _Done as excs :
921
+ exc = excs .exceptions [0 ].exceptions [0 ]
922
+
923
+ self .assertIsNone (task_wr ())
924
+ self .assertIsInstance (exc , _Done )
925
+ self .assertListEqual (gc .get_referrers (exc ), [])
926
+
883
927
async def test_exception_refcycles_propagate_cancellation_error (self ):
884
928
"""Test that TaskGroup deletes propagate_cancellation_error"""
885
929
tg = asyncio .TaskGroup ()
@@ -912,6 +956,81 @@ class MyKeyboardInterrupt(KeyboardInterrupt):
912
956
self .assertIsNotNone (exc )
913
957
self .assertListEqual (gc .get_referrers (exc ), [])
914
958
959
+ async def test_cancels_task_if_created_during_creation (self ):
960
+ # regression test for gh-128550
961
+ ran = False
962
+ class MyError (Exception ):
963
+ pass
964
+
965
+ exc = None
966
+ try :
967
+ async with asyncio .TaskGroup () as tg :
968
+ async def third_task ():
969
+ raise MyError ("third task failed" )
970
+
971
+ async def second_task ():
972
+ nonlocal ran
973
+ tg .create_task (third_task ())
974
+ with self .assertRaises (asyncio .CancelledError ):
975
+ await asyncio .sleep (0 ) # eager tasks cancel here
976
+ await asyncio .sleep (0 ) # lazy tasks cancel here
977
+ ran = True
978
+
979
+ tg .create_task (second_task ())
980
+ except* MyError as excs :
981
+ exc = excs .exceptions [0 ]
982
+
983
+ self .assertTrue (ran )
984
+ self .assertIsInstance (exc , MyError )
985
+
986
+ async def test_cancellation_does_not_leak_out_of_tg (self ):
987
+ class MyError (Exception ):
988
+ pass
989
+
990
+ async def throw_error ():
991
+ raise MyError
992
+
993
+ try :
994
+ async with asyncio .TaskGroup () as tg :
995
+ tg .create_task (throw_error ())
996
+ except* MyError :
997
+ pass
998
+ else :
999
+ self .fail ("should have raised one MyError in group" )
1000
+
1001
+ # if this test fails this current task will be cancelled
1002
+ # outside the task group and inside unittest internals
1003
+ # we yield to the event loop with sleep(0) so that
1004
+ # cancellation happens here and error is more understandable
1005
+ await asyncio .sleep (0 )
1006
+
1007
+
1008
+ if sys .platform == "win32" :
1009
+ EventLoop = asyncio .ProactorEventLoop
1010
+ else :
1011
+ EventLoop = asyncio .SelectorEventLoop
1012
+
1013
+
1014
+ class IsolatedAsyncioTestCase (unittest .IsolatedAsyncioTestCase ):
1015
+ loop_factory = None
1016
+
1017
+ def _setupAsyncioRunner (self ):
1018
+ assert self ._asyncioRunner is None , 'asyncio runner is already initialized'
1019
+ runner = asyncio .Runner (debug = True , loop_factory = self .loop_factory )
1020
+ self ._asyncioRunner = runner
1021
+
1022
+
1023
+ class TestTaskGroup (BaseTestTaskGroup , IsolatedAsyncioTestCase ):
1024
+ loop_factory = EventLoop
1025
+
1026
+
1027
+ class TestEagerTaskTaskGroup (BaseTestTaskGroup , IsolatedAsyncioTestCase ):
1028
+ @staticmethod
1029
+ def loop_factory ():
1030
+ loop = EventLoop ()
1031
+ loop .set_task_factory (asyncio .eager_task_factory )
1032
+ return loop
1033
+
915
1034
916
1035
if __name__ == "__main__" :
917
1036
unittest .main ()
0 commit comments