@@ -994,11 +994,40 @@ impl ExecutingFrame<'_> {
994
994
}
995
995
bytecode:: Instruction :: GetANext => {
996
996
let aiter = self . top_value ( ) ;
997
- let awaitable = vm. call_special_method ( aiter, identifier ! ( vm, __anext__) , ( ) ) ?;
998
- let awaitable = if awaitable. payload_is :: < PyCoroutine > ( ) {
999
- awaitable
997
+ let awaitable = if aiter. class ( ) . is ( vm. ctx . types . async_generator ) {
998
+ vm. call_special_method ( aiter, identifier ! ( vm, __anext__) , ( ) ) ?
1000
999
} else {
1001
- vm. call_special_method ( & awaitable, identifier ! ( vm, __await__) , ( ) ) ?
1000
+ if !aiter. has_attr ( "__anext__" , vm) . unwrap_or ( false ) {
1001
+ // TODO: __anext__ must be protocol
1002
+ let msg = format ! (
1003
+ "'async for' requires an iterator with __anext__ method, got {:.100}" ,
1004
+ aiter. class( ) . name( )
1005
+ ) ;
1006
+ return Err ( vm. new_type_error ( msg) ) ;
1007
+ }
1008
+ let next_iter =
1009
+ vm. call_special_method ( aiter, identifier ! ( vm, __anext__) , ( ) ) ?;
1010
+
1011
+ // _PyCoro_GetAwaitableIter in CPython
1012
+ fn get_awaitable_iter ( next_iter : & PyObject , vm : & VirtualMachine ) -> PyResult {
1013
+ let gen_is_coroutine = |_| {
1014
+ // TODO: cpython gen_is_coroutine
1015
+ true
1016
+ } ;
1017
+ if next_iter. class ( ) . is ( vm. ctx . types . coroutine_type )
1018
+ || gen_is_coroutine ( next_iter)
1019
+ {
1020
+ return Ok ( next_iter. to_owned ( ) ) ;
1021
+ }
1022
+ // TODO: error handling
1023
+ vm. call_special_method ( next_iter, identifier ! ( vm, __await__) , ( ) )
1024
+ }
1025
+ get_awaitable_iter ( & next_iter, vm) . map_err ( |_| {
1026
+ vm. new_type_error ( format ! (
1027
+ "'async for' received an invalid object from __anext__: {:.200}" ,
1028
+ next_iter. class( ) . name( )
1029
+ ) )
1030
+ } ) ?
1002
1031
} ;
1003
1032
self . push_value ( awaitable) ;
1004
1033
Ok ( None )
0 commit comments