@@ -51,14 +51,15 @@ def f(x: int) -> int:
5151 BasicBlock , AssignmentTarget , AssignmentTargetRegister , AssignmentTargetIndex ,
5252 AssignmentTargetAttr , AssignmentTargetTuple , Environment , Op , LoadInt , RType , Value , Register ,
5353 Return , FuncIR , Assign , Branch , Goto , RuntimeArg , Call , Box , Unbox , Cast , RTuple , Unreachable ,
54- TupleGet , TupleSet , ClassIR , RInstance , ModuleIR , GetAttr , SetAttr , LoadStatic , ROptional ,
54+ TupleGet , TupleSet , ClassIR , RInstance , ModuleIR , GetAttr , SetAttr , LoadStatic ,
5555 MethodCall , INVALID_FUNC_DEF , int_rprimitive , float_rprimitive , bool_rprimitive ,
5656 list_rprimitive , is_list_rprimitive , dict_rprimitive , set_rprimitive , str_rprimitive ,
5757 tuple_rprimitive , none_rprimitive , is_none_rprimitive , object_rprimitive , exc_rtuple ,
5858 PrimitiveOp , ControlOp , LoadErrorValue , ERR_FALSE , OpDescription , RegisterOp ,
5959 is_object_rprimitive , LiteralsMap , FuncSignature , VTableAttr , VTableMethod , VTableEntries ,
6060 NAMESPACE_TYPE , RaiseStandardError , LoadErrorValue , NO_TRACEBACK_LINE_NO , FuncDecl ,
6161 FUNC_NORMAL , FUNC_STATICMETHOD , FUNC_CLASSMETHOD ,
62+ RUnion , is_optional_type , optional_value_type
6263)
6364from mypyc .ops_primitive import binary_ops , unary_ops , func_ops , method_ops , name_ref_ops
6465from mypyc .ops_list import (
@@ -266,12 +267,8 @@ def type_to_rtype(self, typ: Type) -> RType:
266267 elif isinstance (typ , NoneTyp ):
267268 return none_rprimitive
268269 elif isinstance (typ , UnionType ):
269- assert len (typ .items ) == 2 and any (isinstance (it , NoneTyp ) for it in typ .items )
270- if isinstance (typ .items [0 ], NoneTyp ):
271- value_type = typ .items [1 ]
272- else :
273- value_type = typ .items [0 ]
274- return ROptional (self .type_to_rtype (value_type ))
270+ return RUnion ([self .type_to_rtype (item )
271+ for item in typ .items ])
275272 elif isinstance (typ , AnyType ):
276273 return object_rprimitive
277274 elif isinstance (typ , TypeType ):
@@ -770,7 +767,7 @@ def generate_attr_defaults(self, cdef: ClassDef) -> None:
770767 # don't initialize it to anything.
771768 if isinstance (stmt .rvalue , RefExpr ) and stmt .rvalue .fullname == 'builtins.None' :
772769 attr_type = cls .attr_type (lvalue .name )
773- if (not isinstance (attr_type , ROptional ) and not is_object_rprimitive (attr_type )
770+ if (not is_optional_type (attr_type ) and not is_object_rprimitive (attr_type )
774771 and not is_none_rprimitive (attr_type )):
775772 continue
776773
@@ -1664,10 +1661,93 @@ def visit_member_expr(self, expr: MemberExpr) -> Value:
16641661 return self .load_module_attr (expr )
16651662 else :
16661663 obj = self .accept (expr .expr )
1667- if isinstance (obj .type , RInstance ):
1668- return self .add (GetAttr (obj , expr .name , expr .line ))
1664+ return self .get_attr (obj , expr .name , self .node_type (expr ), expr .line )
1665+
1666+ def get_attr (self , obj : Value , attr : str , result_type : RType , line : int ) -> Value :
1667+ if isinstance (obj .type , RInstance ):
1668+ return self .add (GetAttr (obj , attr , line ))
1669+ elif isinstance (obj .type , RUnion ):
1670+ return self .union_get_attr (obj , obj .type , attr , result_type , line )
1671+ else :
1672+ return self .py_get_attr (obj , attr , line )
1673+
1674+ def union_get_attr (self ,
1675+ obj : Value ,
1676+ rtype : RUnion ,
1677+ attr : str ,
1678+ result_type : RType ,
1679+ line : int ) -> Value :
1680+ def get_item_attr (value : Value ) -> Value :
1681+ return self .get_attr (value , attr , result_type , line )
1682+
1683+ return self .decompose_union_helper (obj , rtype , result_type , get_item_attr , line )
1684+
1685+ def decompose_union_helper (self ,
1686+ obj : Value ,
1687+ rtype : RUnion ,
1688+ result_type : RType ,
1689+ process_item : Callable [[Value ], Value ],
1690+ line : int ) -> Value :
1691+ """Generate isinstance() + specialized operations for union items.
1692+
1693+ Say, for Union[A, B] generate ops resembling this (pseudocode):
1694+
1695+ if isinstance(obj, A):
1696+ result = <result of process_item(cast(A, obj)>
16691697 else:
1670- return self .py_get_attr (obj , expr .name , expr .line )
1698+ result = <result of process_item(cast(B, obj)>
1699+
1700+ Args:
1701+ obj: value with a union type
1702+ rtype: the union type
1703+ result_type: result of the operation
1704+ process_item: callback to generate op for a single union item (arg is coerced
1705+ to union item type)
1706+ line: line number
1707+ """
1708+ # TODO: Optimize cases where a single operation can handle multiple union items
1709+ # (say a method is implemented in a common base class)
1710+ fast_items = []
1711+ rest_items = []
1712+ for item in rtype .items :
1713+ if isinstance (item , RInstance ):
1714+ fast_items .append (item )
1715+ else :
1716+ # For everything but RInstance we fall back to C API
1717+ rest_items .append (item )
1718+ exit_block = BasicBlock ()
1719+ result = self .alloc_temp (result_type )
1720+ for i , item in enumerate (fast_items ):
1721+ more_types = i < len (fast_items ) - 1 or rest_items
1722+ if more_types :
1723+ # We are not at the final item so we need one more branch
1724+ op = self .isinstance (obj , item , line )
1725+ true_block , false_block = BasicBlock (), BasicBlock ()
1726+ self .add_bool_branch (op , true_block , false_block )
1727+ self .activate_block (true_block )
1728+ coerced = self .coerce (obj , item , line )
1729+ temp = process_item (coerced )
1730+ temp2 = self .coerce (temp , result_type , line )
1731+ self .add (Assign (result , temp2 ))
1732+ self .goto (exit_block )
1733+ if more_types :
1734+ self .activate_block (false_block )
1735+ if rest_items :
1736+ # For everything else we use generic operation. Use force=True to drop the
1737+ # union type.
1738+ coerced = self .coerce (obj , object_rprimitive , line , force = True )
1739+ temp = process_item (coerced )
1740+ temp2 = self .coerce (temp , result_type , line )
1741+ self .add (Assign (result , temp2 ))
1742+ self .goto (exit_block )
1743+ self .activate_block (exit_block )
1744+ return result
1745+
1746+ def isinstance (self , obj : Value , rtype : RInstance , line : int ) -> Value :
1747+ class_ir = rtype .class_ir
1748+ fullname = '%s.%s' % (class_ir .module_name , class_ir .name )
1749+ type_obj = self .load_native_type_object (fullname )
1750+ return self .primitive_op (fast_isinstance_op , [obj , type_obj ], line )
16711751
16721752 def py_get_attr (self , obj : Value , attr : str , line : int ) -> Value :
16731753 key = self .load_static_unicode (attr )
@@ -1937,6 +2017,9 @@ def gen_method_call(self,
19372017 arg_values = self .coerce_native_call_args (arg_values , decl .bound_sig , base .line )
19382018
19392019 return self .add (MethodCall (base , name , arg_values , line ))
2020+ elif isinstance (base .type , RUnion ):
2021+ return self .union_method_call (base , base .type , name , arg_values , return_rtype , line ,
2022+ arg_kinds , arg_names )
19402023
19412024 # Try to do a special-cased method call
19422025 target = self .translate_special_method_call (base , name , arg_values , return_rtype , line )
@@ -1946,6 +2029,21 @@ def gen_method_call(self,
19462029 # Fall back to Python method call
19472030 return self .py_method_call (base , name , arg_values , base .line , arg_kinds , arg_names )
19482031
2032+ def union_method_call (self ,
2033+ base : Value ,
2034+ obj_type : RUnion ,
2035+ name : str ,
2036+ arg_values : List [Value ],
2037+ return_rtype : RType ,
2038+ line : int ,
2039+ arg_kinds : Optional [List [int ]],
2040+ arg_names : Optional [List [Optional [str ]]]) -> Value :
2041+ def call_union_item (value : Value ) -> Value :
2042+ return self .gen_method_call (value , name , arg_values , return_rtype , line ,
2043+ arg_kinds , arg_names )
2044+
2045+ return self .decompose_union_helper (base , obj_type , return_rtype , call_union_item , line )
2046+
19492047 def translate_cast_expr (self , expr : CastExpr ) -> Value :
19502048 src = self .accept (expr .expr )
19512049 target_type = self .type_to_rtype (expr .type )
@@ -2109,26 +2207,27 @@ def add_bool_branch(self, value: Value, true: BasicBlock, false: BasicBlock) ->
21092207 length = self .primitive_op (list_len_op , [value ], value .line )
21102208 zero = self .add (LoadInt (0 ))
21112209 value = self .binary_op (length , zero , '!=' , value .line )
2112- elif isinstance (value .type , ROptional ):
2113- is_none = self .binary_op (value , self .add (PrimitiveOp ([], none_op , value .line )),
2114- 'is not' , value .line )
2115- branch = Branch (is_none , true , false , Branch .BOOL_EXPR )
2116- self .add (branch )
2117- value_type = value .type .value_type
2118- if isinstance (value_type , RInstance ):
2119- # Optional[X] where X is always truthy
2120- # TODO: Support __bool__
2121- pass
2122- else :
2123- # Optional[X] where X may be falsey and requires a check
2124- branch .true = self .new_block ()
2125- # unbox_or_cast instead of coerce because we want the
2126- # type to change even if it is a subtype.
2127- remaining = self .unbox_or_cast (value , value .type .value_type , value .line )
2128- self .add_bool_branch (remaining , true , false )
2129- return
2130- elif not is_same_type (value .type , bool_rprimitive ):
2131- value = self .primitive_op (bool_op , [value ], value .line )
2210+ else :
2211+ value_type = optional_value_type (value .type )
2212+ if value_type is not None :
2213+ is_none = self .binary_op (value , self .add (PrimitiveOp ([], none_op , value .line )),
2214+ 'is not' , value .line )
2215+ branch = Branch (is_none , true , false , Branch .BOOL_EXPR )
2216+ self .add (branch )
2217+ if isinstance (value_type , RInstance ):
2218+ # Optional[X] where X is always truthy
2219+ # TODO: Support __bool__
2220+ pass
2221+ else :
2222+ # Optional[X] where X may be falsey and requires a check
2223+ branch .true = self .new_block ()
2224+ # unbox_or_cast instead of coerce because we want the
2225+ # type to change even if it is a subtype.
2226+ remaining = self .unbox_or_cast (value , value_type , value .line )
2227+ self .add_bool_branch (remaining , true , false )
2228+ return
2229+ elif not is_same_type (value .type , bool_rprimitive ):
2230+ value = self .primitive_op (bool_op , [value ], value .line )
21322231 self .add (Branch (value , true , false , Branch .BOOL_EXPR ))
21332232
21342233 def visit_nonlocal_decl (self , o : NonlocalDecl ) -> None :
@@ -3300,12 +3399,15 @@ def load_native_type_object(self, fullname: str) -> Value:
33003399 module , name = fullname .rsplit ('.' , 1 )
33013400 return self .add (LoadStatic (object_rprimitive , name , module , NAMESPACE_TYPE ))
33023401
3303- def coerce (self , src : Value , target_type : RType , line : int ) -> Value :
3402+ def coerce (self , src : Value , target_type : RType , line : int , force : bool = False ) -> Value :
33043403 """Generate a coercion/cast from one type to other (only if needed).
33053404
33063405 For example, int -> object boxes the source int; int -> int emits nothing;
33073406 object -> int unboxes the object. All conversions preserve object value.
33083407
3408+ If force is true, always generate an op (even if it is just an assingment) so
3409+ that the result will have exactly target_type as the type.
3410+
33093411 Returns the register with the converted value (may be same as src).
33103412 """
33113413 if src .type .is_unboxed and not target_type .is_unboxed :
@@ -3319,6 +3421,10 @@ def coerce(self, src: Value, target_type: RType, line: int) -> Value:
33193421 if ((not src .type .is_unboxed and target_type .is_unboxed )
33203422 or not is_subtype (src .type , target_type )):
33213423 return self .unbox_or_cast (src , target_type , line )
3424+ elif force :
3425+ tmp = self .alloc_temp (target_type )
3426+ self .add (Assign (tmp , src ))
3427+ return tmp
33223428 return src
33233429
33243430 def keyword_args_to_positional (self ,
0 commit comments