12
12
the other half simulate the same workload as the
13
13
"memoization" variant.
14
14
15
- All variants also have an "eager" flavor that uses
16
- the asyncio eager task factory (if available).
15
+ All variants also have an "eager" flavor that uses the asyncio eager task
16
+ factory (if available), and a "tg" variant that uses TaskGroups .
17
17
"""
18
18
19
19
34
34
35
35
36
36
class AsyncTree :
37
- def __init__ (self ):
37
+ def __init__ (self , use_task_groups = False ):
38
38
self .cache = {}
39
+ self .use_task_groups = use_task_groups
39
40
# set to deterministic random, so that the results are reproducible
40
41
random .seed (RANDOM_SEED )
41
42
@@ -47,17 +48,31 @@ async def workload_func(self):
47
48
"To be implemented by each variant's derived class."
48
49
)
49
50
50
- async def recurse (self , recurse_level ):
51
+ async def recurse_with_gather (self , recurse_level ):
51
52
if recurse_level == 0 :
52
53
await self .workload_func ()
53
54
return
54
55
55
56
await asyncio .gather (
56
- * [self .recurse (recurse_level - 1 ) for _ in range (NUM_RECURSE_BRANCHES )]
57
+ * [self .recurse_with_gather (recurse_level - 1 )
58
+ for _ in range (NUM_RECURSE_BRANCHES )]
57
59
)
58
60
61
+ async def recurse_with_task_group (self , recurse_level ):
62
+ if recurse_level == 0 :
63
+ await self .workload_func ()
64
+ return
65
+
66
+ async with asyncio .TaskGroup () as tg :
67
+ for _ in range (NUM_RECURSE_BRANCHES ):
68
+ tg .create_task (
69
+ self .recurse_with_task_group (recurse_level - 1 ))
70
+
59
71
async def run (self ):
60
- await self .recurse (NUM_RECURSE_LEVELS )
72
+ if self .use_task_groups :
73
+ await self .recurse_with_task_group (NUM_RECURSE_LEVELS )
74
+ else :
75
+ await self .recurse_with_gather (NUM_RECURSE_LEVELS )
61
76
62
77
63
78
class EagerMixin :
@@ -132,6 +147,8 @@ def add_metadata(runner):
132
147
133
148
def add_cmdline_args (cmd , args ):
134
149
cmd .append (args .benchmark )
150
+ if args .task_groups :
151
+ cmd .append ("--task-groups" )
135
152
136
153
137
154
def add_parser_args (parser ):
@@ -149,6 +166,12 @@ def add_parser_args(parser):
149
166
"memoization" variant.
150
167
""" ,
151
168
)
169
+ parser .add_argument (
170
+ "--task-groups" ,
171
+ action = "store_true" ,
172
+ default = False ,
173
+ help = "Use TaskGroups instead of gather." ,
174
+ )
152
175
153
176
154
177
BENCHMARKS = {
@@ -171,5 +194,8 @@ def add_parser_args(parser):
171
194
benchmark = args .benchmark
172
195
173
196
async_tree_class = BENCHMARKS [benchmark ]
174
- async_tree = async_tree_class ()
175
- runner .bench_async_func (f"async_tree_{ benchmark } " , async_tree .run )
197
+ async_tree = async_tree_class (use_task_groups = args .task_groups )
198
+ bench_name = f"async_tree_{ benchmark } "
199
+ if args .task_groups :
200
+ bench_name += "_tg"
201
+ runner .bench_async_func (bench_name , async_tree .run )
0 commit comments