@@ -372,11 +372,8 @@ def tearDown(self):
372372 self .loop = None
373373 asyncio .set_event_loop_policy (None )
374374
375- def test_async_gen_anext (self ):
376- async def gen ():
377- yield 1
378- yield 2
379- g = gen ()
375+ def check_async_iterator_anext (self , ait_class ):
376+ g = ait_class ()
380377 async def consume ():
381378 results = []
382379 results .append (await anext (g ))
@@ -388,6 +385,66 @@ async def consume():
388385 with self .assertRaises (StopAsyncIteration ):
389386 self .loop .run_until_complete (consume ())
390387
388+ async def test_2 ():
389+ g1 = ait_class ()
390+ self .assertEqual (await anext (g1 ), 1 )
391+ self .assertEqual (await anext (g1 ), 2 )
392+ with self .assertRaises (StopAsyncIteration ):
393+ await anext (g1 )
394+ with self .assertRaises (StopAsyncIteration ):
395+ await anext (g1 )
396+
397+ g2 = ait_class ()
398+ self .assertEqual (await anext (g2 , "default" ), 1 )
399+ self .assertEqual (await anext (g2 , "default" ), 2 )
400+ self .assertEqual (await anext (g2 , "default" ), "default" )
401+ self .assertEqual (await anext (g2 , "default" ), "default" )
402+
403+ return "completed"
404+
405+ result = self .loop .run_until_complete (test_2 ())
406+ self .assertEqual (result , "completed" )
407+
408+ def test_async_generator_anext (self ):
409+ async def agen ():
410+ yield 1
411+ yield 2
412+ self .check_async_iterator_anext (agen )
413+
414+ def test_python_async_iterator_anext (self ):
415+ class MyAsyncIter :
416+ """Asynchronously yield 1, then 2."""
417+ def __init__ (self ):
418+ self .yielded = 0
419+ def __aiter__ (self ):
420+ return self
421+ async def __anext__ (self ):
422+ if self .yielded >= 2 :
423+ raise StopAsyncIteration ()
424+ else :
425+ self .yielded += 1
426+ return self .yielded
427+ self .check_async_iterator_anext (MyAsyncIter )
428+
429+ def test_python_async_iterator_types_coroutine_anext (self ):
430+ import types
431+ class MyAsyncIterWithTypesCoro :
432+ """Asynchronously yield 1, then 2."""
433+ def __init__ (self ):
434+ self .yielded = 0
435+ def __aiter__ (self ):
436+ return self
437+ @types .coroutine
438+ def __anext__ (self ):
439+ if False :
440+ yield "this is a generator-based coroutine"
441+ if self .yielded >= 2 :
442+ raise StopAsyncIteration ()
443+ else :
444+ self .yielded += 1
445+ return self .yielded
446+ self .check_async_iterator_anext (MyAsyncIterWithTypesCoro )
447+
391448 def test_async_gen_aiter (self ):
392449 async def gen ():
393450 yield 1
@@ -431,12 +488,85 @@ async def call_with_too_many_args():
431488 await anext (gen (), 1 , 3 )
432489 async def call_with_wrong_type_args ():
433490 await anext (1 , gen ())
491+ async def call_with_kwarg ():
492+ await anext (aiterator = gen ())
434493 with self .assertRaises (TypeError ):
435494 self .loop .run_until_complete (call_with_too_few_args ())
436495 with self .assertRaises (TypeError ):
437496 self .loop .run_until_complete (call_with_too_many_args ())
438497 with self .assertRaises (TypeError ):
439498 self .loop .run_until_complete (call_with_wrong_type_args ())
499+ with self .assertRaises (TypeError ):
500+ self .loop .run_until_complete (call_with_kwarg ())
501+
502+ def test_anext_bad_await (self ):
503+ async def bad_awaitable ():
504+ class BadAwaitable :
505+ def __await__ (self ):
506+ return 42
507+ class MyAsyncIter :
508+ def __aiter__ (self ):
509+ return self
510+ def __anext__ (self ):
511+ return BadAwaitable ()
512+ regex = r"__await__.*iterator"
513+ awaitable = anext (MyAsyncIter (), "default" )
514+ with self .assertRaisesRegex (TypeError , regex ):
515+ await awaitable
516+ awaitable = anext (MyAsyncIter ())
517+ with self .assertRaisesRegex (TypeError , regex ):
518+ await awaitable
519+ return "completed"
520+ result = self .loop .run_until_complete (bad_awaitable ())
521+ self .assertEqual (result , "completed" )
522+
523+ async def check_anext_returning_iterator (self , aiter_class ):
524+ awaitable = anext (aiter_class (), "default" )
525+ with self .assertRaises (TypeError ):
526+ await awaitable
527+ awaitable = anext (aiter_class ())
528+ with self .assertRaises (TypeError ):
529+ await awaitable
530+ return "completed"
531+
532+ def test_anext_return_iterator (self ):
533+ class WithIterAnext :
534+ def __aiter__ (self ):
535+ return self
536+ def __anext__ (self ):
537+ return iter ("abc" )
538+ result = self .loop .run_until_complete (self .check_anext_returning_iterator (WithIterAnext ))
539+ self .assertEqual (result , "completed" )
540+
541+ def test_anext_return_generator (self ):
542+ class WithGenAnext :
543+ def __aiter__ (self ):
544+ return self
545+ def __anext__ (self ):
546+ yield
547+ result = self .loop .run_until_complete (self .check_anext_returning_iterator (WithGenAnext ))
548+ self .assertEqual (result , "completed" )
549+
550+ def test_anext_await_raises (self ):
551+ class RaisingAwaitable :
552+ def __await__ (self ):
553+ raise ZeroDivisionError ()
554+ yield
555+ class WithRaisingAwaitableAnext :
556+ def __aiter__ (self ):
557+ return self
558+ def __anext__ (self ):
559+ return RaisingAwaitable ()
560+ async def do_test ():
561+ awaitable = anext (WithRaisingAwaitableAnext ())
562+ with self .assertRaises (ZeroDivisionError ):
563+ await awaitable
564+ awaitable = anext (WithRaisingAwaitableAnext (), "default" )
565+ with self .assertRaises (ZeroDivisionError ):
566+ await awaitable
567+ return "completed"
568+ result = self .loop .run_until_complete (do_test ())
569+ self .assertEqual (result , "completed" )
440570
441571 def test_aiter_bad_args (self ):
442572 async def gen ():
0 commit comments