22
33from collections import OrderedDict
44from contextlib import contextmanager
5+ import itertools
56from typing import (
67 cast , Dict , Set , List , Tuple , Callable , Union , Optional , Sequence , Iterator
78)
@@ -2554,15 +2555,18 @@ def visit_index_with_type(self, left_type: Type, e: IndexExpr,
25542555 if isinstance (index , SliceExpr ):
25552556 return self .visit_tuple_slice_helper (left_type , index )
25562557
2557- n = self ._get_value (index )
2558- if n is not None :
2559- if n < 0 :
2560- n += len (left_type .items )
2561- if 0 <= n < len (left_type .items ):
2562- return left_type .items [n ]
2563- else :
2564- self .chk .fail (message_registry .TUPLE_INDEX_OUT_OF_RANGE , e )
2565- return AnyType (TypeOfAny .from_error )
2558+ ns = self .try_getting_int_literals (index )
2559+ if ns is not None :
2560+ out = []
2561+ for n in ns :
2562+ if n < 0 :
2563+ n += len (left_type .items )
2564+ if 0 <= n < len (left_type .items ):
2565+ out .append (left_type .items [n ])
2566+ else :
2567+ self .chk .fail (message_registry .TUPLE_INDEX_OUT_OF_RANGE , e )
2568+ return AnyType (TypeOfAny .from_error )
2569+ return UnionType .make_simplified_union (out )
25662570 else :
25672571 return self .nonliteral_tuple_index_helper (left_type , index )
25682572 elif isinstance (left_type , TypedDictType ):
@@ -2578,26 +2582,66 @@ def visit_index_with_type(self, left_type: Type, e: IndexExpr,
25782582 return result
25792583
25802584 def visit_tuple_slice_helper (self , left_type : TupleType , slic : SliceExpr ) -> Type :
2581- begin = None
2582- end = None
2583- stride = None
2585+ begin = [ None ] # type: Sequence[Optional[int]]
2586+ end = [ None ] # type: Sequence[Optional[int]]
2587+ stride = [ None ] # type: Sequence[Optional[int]]
25842588
25852589 if slic .begin_index :
2586- begin = self ._get_value (slic .begin_index )
2587- if begin is None :
2590+ begin_raw = self .try_getting_int_literals (slic .begin_index )
2591+ if begin_raw is None :
25882592 return self .nonliteral_tuple_index_helper (left_type , slic )
2593+ begin = begin_raw
25892594
25902595 if slic .end_index :
2591- end = self ._get_value (slic .end_index )
2592- if end is None :
2596+ end_raw = self .try_getting_int_literals (slic .end_index )
2597+ if end_raw is None :
25932598 return self .nonliteral_tuple_index_helper (left_type , slic )
2599+ end = end_raw
25942600
25952601 if slic .stride :
2596- stride = self ._get_value (slic .stride )
2597- if stride is None :
2602+ stride_raw = self .try_getting_int_literals (slic .stride )
2603+ if stride_raw is None :
25982604 return self .nonliteral_tuple_index_helper (left_type , slic )
2605+ stride = stride_raw
2606+
2607+ items = [] # type: List[Type]
2608+ for b , e , s in itertools .product (begin , end , stride ):
2609+ items .append (left_type .slice (b , e , s ))
2610+ return UnionType .make_simplified_union (items )
25992611
2600- return left_type .slice (begin , stride , end )
2612+ def try_getting_int_literals (self , index : Expression ) -> Optional [List [int ]]:
2613+ """If the given expression or type corresponds to an int literal
2614+ or a union of int literals, returns a list of the underlying ints.
2615+ Otherwise, returns None.
2616+
2617+ Specifically, this function is guaranteed to return a list with
2618+ one or more ints if one one the following is true:
2619+
2620+ 1. 'expr' is a IntExpr or a UnaryExpr backed by an IntExpr
2621+ 2. 'typ' is a LiteralType containing an int
2622+ 3. 'typ' is a UnionType containing only LiteralType of ints
2623+ """
2624+ if isinstance (index , IntExpr ):
2625+ return [index .value ]
2626+ elif isinstance (index , UnaryExpr ):
2627+ if index .op == '-' :
2628+ operand = index .expr
2629+ if isinstance (operand , IntExpr ):
2630+ return [- 1 * operand .value ]
2631+ typ = self .accept (index )
2632+ if isinstance (typ , Instance ) and typ .last_known_value is not None :
2633+ typ = typ .last_known_value
2634+ if isinstance (typ , LiteralType ) and isinstance (typ .value , int ):
2635+ return [typ .value ]
2636+ if isinstance (typ , UnionType ):
2637+ out = []
2638+ for item in typ .items :
2639+ if isinstance (item , LiteralType ) and isinstance (item .value , int ):
2640+ out .append (item .value )
2641+ else :
2642+ return None
2643+ return out
2644+ return None
26012645
26022646 def nonliteral_tuple_index_helper (self , left_type : TupleType , index : Expression ) -> Type :
26032647 index_type = self .accept (index )
@@ -2614,40 +2658,36 @@ def nonliteral_tuple_index_helper(self, left_type: TupleType, index: Expression)
26142658 else :
26152659 return union
26162660
2617- def _get_value (self , index : Expression ) -> Optional [int ]:
2618- if isinstance (index , IntExpr ):
2619- return index .value
2620- elif isinstance (index , UnaryExpr ):
2621- if index .op == '-' :
2622- operand = index .expr
2623- if isinstance (operand , IntExpr ):
2624- return - 1 * operand .value
2625- typ = self .accept (index )
2626- if isinstance (typ , Instance ) and typ .last_known_value is not None :
2627- typ = typ .last_known_value
2628- if isinstance (typ , LiteralType ) and isinstance (typ .value , int ):
2629- return typ .value
2630- return None
2631-
26322661 def visit_typeddict_index_expr (self , td_type : TypedDictType , index : Expression ) -> Type :
26332662 if isinstance (index , (StrExpr , UnicodeExpr )):
2634- item_name = index .value
2663+ key_names = [ index .value ]
26352664 else :
26362665 typ = self .accept (index )
2637- if isinstance (typ , Instance ) and typ .last_known_value is not None :
2638- typ = typ .last_known_value
2639-
2640- if isinstance (typ , LiteralType ) and isinstance (typ .value , str ):
2641- item_name = typ .value
2666+ if isinstance (typ , UnionType ):
2667+ key_types = typ .items
26422668 else :
2643- self .msg .typeddict_key_must_be_string_literal (td_type , index )
2644- return AnyType (TypeOfAny .from_error )
2669+ key_types = [typ ]
26452670
2646- item_type = td_type .items .get (item_name )
2647- if item_type is None :
2648- self .msg .typeddict_key_not_found (td_type , item_name , index )
2649- return AnyType (TypeOfAny .from_error )
2650- return item_type
2671+ key_names = []
2672+ for key_type in key_types :
2673+ if isinstance (key_type , Instance ) and key_type .last_known_value is not None :
2674+ key_type = key_type .last_known_value
2675+
2676+ if isinstance (key_type , LiteralType ) and isinstance (key_type .value , str ):
2677+ key_names .append (key_type .value )
2678+ else :
2679+ self .msg .typeddict_key_must_be_string_literal (td_type , index )
2680+ return AnyType (TypeOfAny .from_error )
2681+
2682+ value_types = []
2683+ for key_name in key_names :
2684+ value_type = td_type .items .get (key_name )
2685+ if value_type is None :
2686+ self .msg .typeddict_key_not_found (td_type , key_name , index )
2687+ return AnyType (TypeOfAny .from_error )
2688+ else :
2689+ value_types .append (value_type )
2690+ return UnionType .make_simplified_union (value_types )
26512691
26522692 def visit_enum_index_expr (self , enum_type : TypeInfo , index : Expression ,
26532693 context : Context ) -> Type :
0 commit comments