@@ -67,11 +67,11 @@ def silence_coro_gc():
6767class AsyncBadSyntaxTest (unittest .TestCase ):
6868
6969 def test_badsyntax_1 (self ):
70- with self .assertRaisesRegex (SyntaxError , 'invalid syntax' ):
70+ with self .assertRaisesRegex (SyntaxError , "'await' outside" ):
7171 import test .badsyntax_async1
7272
7373 def test_badsyntax_2 (self ):
74- with self .assertRaisesRegex (SyntaxError , 'invalid syntax' ):
74+ with self .assertRaisesRegex (SyntaxError , "'await' outside" ):
7575 import test .badsyntax_async2
7676
7777 def test_badsyntax_3 (self ):
@@ -103,10 +103,6 @@ def test_badsyntax_8(self):
103103 import test .badsyntax_async8
104104
105105 def test_badsyntax_9 (self ):
106- with self .assertRaisesRegex (SyntaxError , 'invalid syntax' ):
107- import test .badsyntax_async9
108-
109- def test_badsyntax_10 (self ):
110106 ns = {}
111107 for comp in {'(await a for a in b)' ,
112108 '[await a for a in b]' ,
@@ -116,6 +112,221 @@ def test_badsyntax_10(self):
116112 with self .assertRaisesRegex (SyntaxError , 'await.*in comprehen' ):
117113 exec ('async def f():\n \t {}' .format (comp ), ns , ns )
118114
115+ def test_badsyntax_10 (self ):
116+ # Tests for issue 24619
117+
118+ samples = [
119+ """async def foo():
120+ def bar(): pass
121+ await = 1
122+ """ ,
123+
124+ """async def foo():
125+
126+ def bar(): pass
127+ await = 1
128+ """ ,
129+
130+ """async def foo():
131+ def bar(): pass
132+ if 1:
133+ await = 1
134+ """ ,
135+
136+ """def foo():
137+ async def bar(): pass
138+ if 1:
139+ await a
140+ """ ,
141+
142+ """def foo():
143+ async def bar(): pass
144+ await a
145+ """ ,
146+
147+ """def foo():
148+ def baz(): pass
149+ async def bar(): pass
150+ await a
151+ """ ,
152+
153+ """def foo():
154+ def baz(): pass
155+ # 456
156+ async def bar(): pass
157+ # 123
158+ await a
159+ """ ,
160+
161+ """async def foo():
162+ def baz(): pass
163+ # 456
164+ async def bar(): pass
165+ # 123
166+ await = 2
167+ """ ,
168+
169+ """def foo():
170+
171+ def baz(): pass
172+
173+ async def bar(): pass
174+
175+ await a
176+ """ ,
177+
178+ """async def foo():
179+
180+ def baz(): pass
181+
182+ async def bar(): pass
183+
184+ await = 2
185+ """ ,
186+
187+ """async def foo():
188+ def async(): pass
189+ """ ,
190+
191+ """async def foo():
192+ def await(): pass
193+ """ ,
194+
195+ """async def foo():
196+ def bar():
197+ await
198+ """ ,
199+
200+ """async def foo():
201+ return lambda async: await
202+ """ ,
203+
204+ """async def foo():
205+ return lambda a: await
206+ """ ,
207+
208+ """async def foo(a: await b):
209+ pass
210+ """ ,
211+
212+ """def baz():
213+ async def foo(a: await b):
214+ pass
215+ """ ,
216+
217+ """async def foo(async):
218+ pass
219+ """ ,
220+
221+ """async def foo():
222+ def bar():
223+ def baz():
224+ async = 1
225+ """ ,
226+
227+ """async def foo():
228+ def bar():
229+ def baz():
230+ pass
231+ async = 1
232+ """ ,
233+
234+ """def foo():
235+ async def bar():
236+
237+ async def baz():
238+ pass
239+
240+ def baz():
241+ 42
242+
243+ async = 1
244+ """ ,
245+
246+ """async def foo():
247+ def bar():
248+ def baz():
249+ pass\n await foo()
250+ """ ,
251+
252+ """def foo():
253+ def bar():
254+ async def baz():
255+ pass\n await foo()
256+ """ ,
257+
258+ """async def foo(await):
259+ pass
260+ """ ,
261+
262+ """def foo():
263+
264+ async def bar(): pass
265+
266+ await a
267+ """ ,
268+
269+ """def foo():
270+ async def bar():
271+ pass\n await a
272+ """ ]
273+
274+ ns = {}
275+ for code in samples :
276+ with self .subTest (code = code ), self .assertRaises (SyntaxError ):
277+ exec (code , ns , ns )
278+
279+ def test_goodsyntax_1 (self ):
280+ # Tests for issue 24619
281+
282+ def foo (await ):
283+ async def foo (): pass
284+ async def foo ():
285+ pass
286+ return await + 1
287+ self .assertEqual (foo (10 ), 11 )
288+
289+ def foo (await ):
290+ async def foo (): pass
291+ async def foo (): pass
292+ return await + 2
293+ self .assertEqual (foo (20 ), 22 )
294+
295+ def foo (await ):
296+
297+ async def foo (): pass
298+
299+ async def foo (): pass
300+
301+ return await + 2
302+ self .assertEqual (foo (20 ), 22 )
303+
304+ def foo (await ):
305+ """spam"""
306+ async def foo (): \
307+ pass
308+ # 123
309+ async def foo (): pass
310+ # 456
311+ return await + 2
312+ self .assertEqual (foo (20 ), 22 )
313+
314+ def foo (await ):
315+ def foo (): pass
316+ def foo (): pass
317+ async def bar (): return await_
318+ await_ = await
319+ try :
320+ bar ().send (None )
321+ except StopIteration as ex :
322+ return ex .args [0 ]
323+ self .assertEqual (foo (42 ), 42 )
324+
325+ async def f ():
326+ async def g (): pass
327+ await z
328+ self .assertTrue (inspect .iscoroutinefunction (f ))
329+
119330
120331class TokenizerRegrTest (unittest .TestCase ):
121332
@@ -461,8 +672,7 @@ def test_await_8(self):
461672 class Awaitable :
462673 pass
463674
464- async def foo ():
465- return (await Awaitable ())
675+ async def foo (): return await Awaitable ()
466676
467677 with self .assertRaisesRegex (
468678 TypeError , "object Awaitable can't be used in 'await' expression" ):
0 commit comments