@@ -1210,7 +1210,7 @@ def visit_return_stmt(self, s: ReturnStmt) -> Type:
12101210 self .fail (messages .NO_RETURN_VALUE_EXPECTED , s )
12111211 else :
12121212 if self .function_stack [- 1 ].is_coroutine : # Something similar will be needed to mix return and yield
1213- #If the function is a coroutine, wrap the return type in a Future
1213+ # If the function is a coroutine, wrap the return type in a Future
12141214 typ = self .wrap_generic_type (typ , self .return_types [- 1 ], 'asyncio.futures.Future' , s )
12151215 self .check_subtype (
12161216 typ , self .return_types [- 1 ], s ,
@@ -1225,20 +1225,21 @@ def visit_return_stmt(self, s: ReturnStmt) -> Type:
12251225 self .fail (messages .RETURN_VALUE_EXPECTED , s )
12261226
12271227 def wrap_generic_type (self , typ : Type , rtyp : Type , check_type : str , context : Context ) -> Type :
1228- n_diff = self .count_concatenated_types (rtyp , check_type ) - self .count_concatenated_types (typ , check_type )
1229- if n_diff > = 1 :
1228+ n_diff = self .count_nested_types (rtyp , check_type ) - self .count_nested_types (typ , check_type )
1229+ if n_diff = = 1 :
12301230 return self .named_generic_type (check_type , [typ ])
1231- elif n_diff == 0 :
1231+ elif n_diff == 0 or n_diff > 1 :
12321232 self .fail (messages .INCOMPATIBLE_RETURN_VALUE_TYPE
12331233 + ": expected {}, got {}" .format (rtyp , typ ), context )
12341234 return typ
12351235 return typ
12361236
1237- def count_concatenated_types (self , typ : Type , check_type : str ) -> int :
1237+ def count_nested_types (self , typ : Type , check_type : str ) -> int :
12381238 c = 0
12391239 while is_subtype (typ , self .named_type (check_type )):
12401240 c += 1
1241- if hasattr (typ , 'args' ) and typ .args :
1241+ typ = map_instance_to_supertype (typ , self .lookup_typeinfo (check_type ))
1242+ if typ .args :
12421243 typ = typ .args [0 ]
12431244 else :
12441245 return c
@@ -1268,7 +1269,7 @@ def visit_yield_from_stmt(self, s: YieldFromStmt) -> Type:
12681269 return_type = self .return_types [- 1 ]
12691270 type_func = self .accept (s .expr , return_type )
12701271 if isinstance (type_func , Instance ):
1271- if hasattr ( type_func , 'type' ) and hasattr ( type_func . type , 'fullname' ) and type_func .type .fullname () == 'asyncio.futures.Future' :
1272+ if type_func .type .fullname () == 'asyncio.futures.Future' :
12721273 # if is a Future, in stmt don't need to do nothing
12731274 # because the type Future[Some] jus matters to the main loop
12741275 # that python executes, in statement we shouldn't get the Future,
@@ -1277,15 +1278,15 @@ def visit_yield_from_stmt(self, s: YieldFromStmt) -> Type:
12771278 elif is_subtype (type_func , self .named_type ('typing.Iterable' )):
12781279 # If it's and Iterable-Like, let's check the types.
12791280 # Maybe just check if have __iter__? (like in analyse_iterable)
1280- self .check_iterable_yf (s )
1281+ self .check_iterable_yield_from (s )
12811282 else :
1282- self .msg .yield_from_not_valid_applied (type_func , s )
1283+ self .msg .yield_from_invalid_operand_type (type_func , s )
12831284 elif isinstance (type_func , AnyType ):
1284- self .check_iterable_yf (s )
1285+ self .check_iterable_yield_from (s )
12851286 else :
1286- self .msg .yield_from_not_valid_applied (type_func , s )
1287+ self .msg .yield_from_invalid_operand_type (type_func , s )
12871288
1288- def check_iterable_yf (self , s : YieldFromStmt ) -> Type :
1289+ def check_iterable_yield_from (self , s : YieldFromStmt ) -> Type :
12891290 """
12901291 Check that return type is super type of Iterable (Maybe just check if have __iter__?)
12911292 and compare it with the type of the expression
@@ -1295,9 +1296,9 @@ def check_iterable_yf(self, s: YieldFromStmt) -> Type:
12951296 if not is_subtype (expected_item_type , self .named_type ('typing.Iterable' )):
12961297 self .fail (messages .INVALID_RETURN_TYPE_FOR_YIELD_FROM , s )
12971298 return None
1298- elif hasattr (expected_item_type , 'args' ) and expected_item_type .args :
1299+ elif expected_item_type .args :
1300+ expected_item_type = map_instance_to_supertype (expected_item_type , self .lookup_typeinfo ('typing.Iterable' ))
12991301 expected_item_type = expected_item_type .args [0 ] # Take the item inside the iterator
1300- # expected_item_type = expected_item_type
13011302 elif isinstance (expected_item_type , AnyType ):
13021303 expected_item_type = AnyType ()
13031304 else :
@@ -1308,6 +1309,7 @@ def check_iterable_yf(self, s: YieldFromStmt) -> Type:
13081309 else :
13091310 actual_item_type = self .accept (s .expr , expected_item_type )
13101311 if hasattr (actual_item_type , 'args' ) and actual_item_type .args :
1312+ actual_item_type = map_instance_to_supertype (actual_item_type , self .lookup_typeinfo ('typing.Iterable' ))
13111313 actual_item_type = actual_item_type .args [0 ] # Take the item inside the iterator
13121314 self .check_subtype (actual_item_type , expected_item_type , s ,
13131315 messages .INCOMPATIBLE_TYPES_IN_YIELD_FROM ,
@@ -1625,7 +1627,7 @@ def visit_call_expr(self, e: CallExpr) -> Type:
16251627
16261628 def visit_yield_from_expr (self , e : YieldFromExpr ) -> Type :
16271629 result = self .expr_checker .visit_yield_from_expr (e )
1628- if hasattr ( result , 'type' ) and result .type .fullname () == "asyncio.futures.Future" :
1630+ if result .type .fullname () == "asyncio.futures.Future" :
16291631 self .function_stack [- 1 ].is_coroutine = True # Set the function as coroutine
16301632 result = result .args [0 ] # Set the return type as the type inside
16311633 elif is_subtype (result , self .named_type ('typing.Iterable' )):
@@ -1634,8 +1636,7 @@ def visit_yield_from_expr(self, e: YieldFromExpr) -> Type:
16341636 # Maybe set result like in the Future
16351637 pass
16361638 else :
1637- self .msg .yield_from_not_valid_applied (e .expr , e )
1638- self .breaking_out = False
1639+ self .msg .yield_from_invalid_operand_type (e .expr , e )
16391640 return result
16401641
16411642 def visit_member_expr (self , e : MemberExpr ) -> Type :
0 commit comments