Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 08fcdf8

Browse files
JukkaLmsullivan
authored andcommitted
Support union types (mypyc/mypyc#340)
Add almost-minimal support for union types. Add new rtype `RUnion` which replaces `ROptional`. Optimize attribute get and method calls for union types by generating `isinstance()` checks and casts. Main limitations: * Attribute set and various other operations aren't optimized yet. * If a single operation could support all union items, we still generate multiple operations and type checks. * Only `RInstance` union item types generate efficient code. Everything else falls back to generic operations. Closes mypyc/mypyc#135.
1 parent 473f7e2 commit 08fcdf8

10 files changed

Lines changed: 600 additions & 95 deletions

File tree

mypyc/emit.py

Lines changed: 52 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
"""Utilities for emitting C code."""
22

33
from collections import OrderedDict
4-
from typing import List, Set, Dict, Optional, List, Callable
4+
from typing import List, Set, Dict, Optional, List, Callable, Union
55

66
from mypyc.common import REG_PREFIX, STATIC_PREFIX, TYPE_PREFIX, NATIVE_PREFIX
77
from mypyc.ops import (
88
Any, AssignmentTarget, Environment, BasicBlock, Value, Register, RType, RTuple, RInstance,
9-
ROptional, RPrimitive, is_int_rprimitive, is_float_rprimitive, is_bool_rprimitive,
9+
RUnion, RPrimitive, RUnion, is_int_rprimitive, is_float_rprimitive, is_bool_rprimitive,
1010
short_name, is_list_rprimitive, is_dict_rprimitive, is_set_rprimitive, is_tuple_rprimitive,
1111
is_none_rprimitive, is_object_rprimitive, object_rprimitive, is_str_rprimitive, ClassIR,
12-
FuncIR, FuncDecl, int_rprimitive
12+
FuncIR, FuncDecl, int_rprimitive, is_optional_type, optional_value_type
1313
)
1414
from mypyc.namegen import NameGenerator
1515
from mypyc.sametype import is_same_type
@@ -79,9 +79,13 @@ def emit_lines(self, *lines: str) -> None:
7979
for line in lines:
8080
self.emit_line(line)
8181

82-
def emit_label(self, label: BasicBlock) -> None:
82+
def emit_label(self, label: Union[BasicBlock, str]) -> None:
83+
if isinstance(label, str):
84+
text = label
85+
else:
86+
text = self.label(label)
8387
# Extra semicolon prevents an error when the next line declares a tempvar
84-
self.fragments.append('{}: ;\n'.format(self.label(label)))
88+
self.fragments.append('{}: ;\n'.format(text))
8589

8690
def emit_from_emitter(self, emitter: 'Emitter') -> None:
8791
self.fragments.extend(emitter.fragments)
@@ -95,6 +99,10 @@ def temp_name(self) -> str:
9599
self.context.temp_counter += 1
96100
return '__tmp%d' % self.context.temp_counter
97101

102+
def new_label(self) -> str:
103+
self.context.temp_counter += 1
104+
return '__LL%d' % self.context.temp_counter
105+
98106
def static_name(self, id: str, module: Optional[str], prefix: str = STATIC_PREFIX) -> str:
99107
"""Create name of a C static variable.
100108
@@ -262,8 +270,9 @@ def emit_dec_ref(self, dest: str, rtype: RType) -> None:
262270

263271
def pretty_name(self, typ: RType) -> str:
264272
pretty_name = typ.name
265-
if isinstance(typ, ROptional):
266-
pretty_name = '%s or None' % self.pretty_name(typ.value_type)
273+
value_type = optional_value_type(typ)
274+
if value_type is not None:
275+
pretty_name = '%s or None' % self.pretty_name(value_type)
267276
return short_name(pretty_name)
268277

269278
def emit_cast(self, src: str, dest: str, typ: RType, declare_dest: bool = False,
@@ -293,17 +302,19 @@ def emit_cast(self, src: str, dest: str, typ: RType, declare_dest: bool = False,
293302
self.pretty_name(typ))
294303

295304
# Special case casting *from* optional
296-
if (src_type and isinstance(src_type, ROptional) and not is_object_rprimitive(typ)
297-
and is_same_type(src_type.value_type, typ)):
298-
if declare_dest:
299-
self.emit_line('PyObject *{};'.format(dest))
300-
self.emit_arg_check(src, dest, typ, '({} != Py_None)'.format(src), optional)
301-
self.emit_lines(
302-
' {} = {};'.format(dest, src),
303-
'else {',
304-
err,
305-
'{} = NULL;'.format(dest),
306-
'}')
305+
if src_type and is_optional_type(src_type) and not is_object_rprimitive(typ):
306+
value_type = optional_value_type(src_type)
307+
assert value_type is not None
308+
if is_same_type(value_type, typ):
309+
if declare_dest:
310+
self.emit_line('PyObject *{};'.format(dest))
311+
self.emit_arg_check(src, dest, typ, '({} != Py_None)'.format(src), optional)
312+
self.emit_lines(
313+
' {} = {};'.format(dest, src),
314+
'else {',
315+
err,
316+
'{} = NULL;'.format(dest),
317+
'}')
307318

308319
# TODO: Verify refcount handling.
309320
elif (is_list_rprimitive(typ) or is_dict_rprimitive(typ) or is_set_rprimitive(typ) or
@@ -372,18 +383,32 @@ def emit_cast(self, src: str, dest: str, typ: RType, declare_dest: bool = False,
372383
self.emit_line('{} = {};'.format(dest, src))
373384
if optional:
374385
self.emit_line('}')
375-
elif isinstance(typ, ROptional):
376-
if declare_dest:
377-
self.emit_line('PyObject *{};'.format(dest))
378-
self.emit_arg_check(src, dest, typ, '({} == Py_None)'.format(src), optional)
379-
self.emit_lines(
380-
' {} = {};'.format(dest, src),
381-
'else {')
382-
self.emit_cast(src, dest, typ.value_type, custom_message=err)
383-
self.emit_line('}')
386+
elif isinstance(typ, RUnion):
387+
self.emit_union_cast(src, dest, typ, declare_dest, err, optional, src_type)
384388
else:
385389
assert False, 'Cast not implemented: %s' % typ
386390

391+
def emit_union_cast(self, src: str, dest: str, typ: RUnion, declare_dest: bool,
392+
err: str, optional: bool, src_type: Optional[RType]) -> None:
393+
"""Emit cast to a union type.
394+
395+
The arguments are similar to emit_cast.
396+
"""
397+
if declare_dest:
398+
self.emit_line('PyObject *{};'.format(dest))
399+
good_label = self.new_label()
400+
for item in typ.items:
401+
self.emit_cast(src,
402+
dest,
403+
item,
404+
declare_dest=False,
405+
custom_message='',
406+
optional=optional)
407+
self.emit_line('if ({} != NULL) goto {};'.format(dest, good_label))
408+
# Handle cast failure.
409+
self.emit_line(err)
410+
self.emit_label(good_label)
411+
387412
def emit_arg_check(self, src: str, dest: str, typ: RType, check: str, optional: bool) -> None:
388413
if optional:
389414
self.emit_line('if ({} == NULL) {{'.format(src))

mypyc/genops.py

Lines changed: 138 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -51,14 +51,15 @@ def f(x: int) -> int:
5151
BasicBlock, AssignmentTarget, AssignmentTargetRegister, AssignmentTargetIndex,
5252
AssignmentTargetAttr, AssignmentTargetTuple, Environment, Op, LoadInt, RType, Value, Register,
5353
Return, FuncIR, Assign, Branch, Goto, RuntimeArg, Call, Box, Unbox, Cast, RTuple, Unreachable,
54-
TupleGet, TupleSet, ClassIR, RInstance, ModuleIR, GetAttr, SetAttr, LoadStatic, ROptional,
54+
TupleGet, TupleSet, ClassIR, RInstance, ModuleIR, GetAttr, SetAttr, LoadStatic,
5555
MethodCall, INVALID_FUNC_DEF, int_rprimitive, float_rprimitive, bool_rprimitive,
5656
list_rprimitive, is_list_rprimitive, dict_rprimitive, set_rprimitive, str_rprimitive,
5757
tuple_rprimitive, none_rprimitive, is_none_rprimitive, object_rprimitive, exc_rtuple,
5858
PrimitiveOp, ControlOp, LoadErrorValue, ERR_FALSE, OpDescription, RegisterOp,
5959
is_object_rprimitive, LiteralsMap, FuncSignature, VTableAttr, VTableMethod, VTableEntries,
6060
NAMESPACE_TYPE, RaiseStandardError, LoadErrorValue, NO_TRACEBACK_LINE_NO, FuncDecl,
6161
FUNC_NORMAL, FUNC_STATICMETHOD, FUNC_CLASSMETHOD,
62+
RUnion, is_optional_type, optional_value_type
6263
)
6364
from mypyc.ops_primitive import binary_ops, unary_ops, func_ops, method_ops, name_ref_ops
6465
from mypyc.ops_list import (
@@ -266,12 +267,8 @@ def type_to_rtype(self, typ: Type) -> RType:
266267
elif isinstance(typ, NoneTyp):
267268
return none_rprimitive
268269
elif isinstance(typ, UnionType):
269-
assert len(typ.items) == 2 and any(isinstance(it, NoneTyp) for it in typ.items)
270-
if isinstance(typ.items[0], NoneTyp):
271-
value_type = typ.items[1]
272-
else:
273-
value_type = typ.items[0]
274-
return ROptional(self.type_to_rtype(value_type))
270+
return RUnion([self.type_to_rtype(item)
271+
for item in typ.items])
275272
elif isinstance(typ, AnyType):
276273
return object_rprimitive
277274
elif isinstance(typ, TypeType):
@@ -770,7 +767,7 @@ def generate_attr_defaults(self, cdef: ClassDef) -> None:
770767
# don't initialize it to anything.
771768
if isinstance(stmt.rvalue, RefExpr) and stmt.rvalue.fullname == 'builtins.None':
772769
attr_type = cls.attr_type(lvalue.name)
773-
if (not isinstance(attr_type, ROptional) and not is_object_rprimitive(attr_type)
770+
if (not is_optional_type(attr_type) and not is_object_rprimitive(attr_type)
774771
and not is_none_rprimitive(attr_type)):
775772
continue
776773

@@ -1664,10 +1661,93 @@ def visit_member_expr(self, expr: MemberExpr) -> Value:
16641661
return self.load_module_attr(expr)
16651662
else:
16661663
obj = self.accept(expr.expr)
1667-
if isinstance(obj.type, RInstance):
1668-
return self.add(GetAttr(obj, expr.name, expr.line))
1664+
return self.get_attr(obj, expr.name, self.node_type(expr), expr.line)
1665+
1666+
def get_attr(self, obj: Value, attr: str, result_type: RType, line: int) -> Value:
1667+
if isinstance(obj.type, RInstance):
1668+
return self.add(GetAttr(obj, attr, line))
1669+
elif isinstance(obj.type, RUnion):
1670+
return self.union_get_attr(obj, obj.type, attr, result_type, line)
1671+
else:
1672+
return self.py_get_attr(obj, attr, line)
1673+
1674+
def union_get_attr(self,
1675+
obj: Value,
1676+
rtype: RUnion,
1677+
attr: str,
1678+
result_type: RType,
1679+
line: int) -> Value:
1680+
def get_item_attr(value: Value) -> Value:
1681+
return self.get_attr(value, attr, result_type, line)
1682+
1683+
return self.decompose_union_helper(obj, rtype, result_type, get_item_attr, line)
1684+
1685+
def decompose_union_helper(self,
1686+
obj: Value,
1687+
rtype: RUnion,
1688+
result_type: RType,
1689+
process_item: Callable[[Value], Value],
1690+
line: int) -> Value:
1691+
"""Generate isinstance() + specialized operations for union items.
1692+
1693+
Say, for Union[A, B] generate ops resembling this (pseudocode):
1694+
1695+
if isinstance(obj, A):
1696+
result = <result of process_item(cast(A, obj)>
16691697
else:
1670-
return self.py_get_attr(obj, expr.name, expr.line)
1698+
result = <result of process_item(cast(B, obj)>
1699+
1700+
Args:
1701+
obj: value with a union type
1702+
rtype: the union type
1703+
result_type: result of the operation
1704+
process_item: callback to generate op for a single union item (arg is coerced
1705+
to union item type)
1706+
line: line number
1707+
"""
1708+
# TODO: Optimize cases where a single operation can handle multiple union items
1709+
# (say a method is implemented in a common base class)
1710+
fast_items = []
1711+
rest_items = []
1712+
for item in rtype.items:
1713+
if isinstance(item, RInstance):
1714+
fast_items.append(item)
1715+
else:
1716+
# For everything but RInstance we fall back to C API
1717+
rest_items.append(item)
1718+
exit_block = BasicBlock()
1719+
result = self.alloc_temp(result_type)
1720+
for i, item in enumerate(fast_items):
1721+
more_types = i < len(fast_items) - 1 or rest_items
1722+
if more_types:
1723+
# We are not at the final item so we need one more branch
1724+
op = self.isinstance(obj, item, line)
1725+
true_block, false_block = BasicBlock(), BasicBlock()
1726+
self.add_bool_branch(op, true_block, false_block)
1727+
self.activate_block(true_block)
1728+
coerced = self.coerce(obj, item, line)
1729+
temp = process_item(coerced)
1730+
temp2 = self.coerce(temp, result_type, line)
1731+
self.add(Assign(result, temp2))
1732+
self.goto(exit_block)
1733+
if more_types:
1734+
self.activate_block(false_block)
1735+
if rest_items:
1736+
# For everything else we use generic operation. Use force=True to drop the
1737+
# union type.
1738+
coerced = self.coerce(obj, object_rprimitive, line, force=True)
1739+
temp = process_item(coerced)
1740+
temp2 = self.coerce(temp, result_type, line)
1741+
self.add(Assign(result, temp2))
1742+
self.goto(exit_block)
1743+
self.activate_block(exit_block)
1744+
return result
1745+
1746+
def isinstance(self, obj: Value, rtype: RInstance, line: int) -> Value:
1747+
class_ir = rtype.class_ir
1748+
fullname = '%s.%s' % (class_ir.module_name, class_ir.name)
1749+
type_obj = self.load_native_type_object(fullname)
1750+
return self.primitive_op(fast_isinstance_op, [obj, type_obj], line)
16711751

16721752
def py_get_attr(self, obj: Value, attr: str, line: int) -> Value:
16731753
key = self.load_static_unicode(attr)
@@ -1937,6 +2017,9 @@ def gen_method_call(self,
19372017
arg_values = self.coerce_native_call_args(arg_values, decl.bound_sig, base.line)
19382018

19392019
return self.add(MethodCall(base, name, arg_values, line))
2020+
elif isinstance(base.type, RUnion):
2021+
return self.union_method_call(base, base.type, name, arg_values, return_rtype, line,
2022+
arg_kinds, arg_names)
19402023

19412024
# Try to do a special-cased method call
19422025
target = self.translate_special_method_call(base, name, arg_values, return_rtype, line)
@@ -1946,6 +2029,21 @@ def gen_method_call(self,
19462029
# Fall back to Python method call
19472030
return self.py_method_call(base, name, arg_values, base.line, arg_kinds, arg_names)
19482031

2032+
def union_method_call(self,
2033+
base: Value,
2034+
obj_type: RUnion,
2035+
name: str,
2036+
arg_values: List[Value],
2037+
return_rtype: RType,
2038+
line: int,
2039+
arg_kinds: Optional[List[int]],
2040+
arg_names: Optional[List[Optional[str]]]) -> Value:
2041+
def call_union_item(value: Value) -> Value:
2042+
return self.gen_method_call(value, name, arg_values, return_rtype, line,
2043+
arg_kinds, arg_names)
2044+
2045+
return self.decompose_union_helper(base, obj_type, return_rtype, call_union_item, line)
2046+
19492047
def translate_cast_expr(self, expr: CastExpr) -> Value:
19502048
src = self.accept(expr.expr)
19512049
target_type = self.type_to_rtype(expr.type)
@@ -2109,26 +2207,27 @@ def add_bool_branch(self, value: Value, true: BasicBlock, false: BasicBlock) ->
21092207
length = self.primitive_op(list_len_op, [value], value.line)
21102208
zero = self.add(LoadInt(0))
21112209
value = self.binary_op(length, zero, '!=', value.line)
2112-
elif isinstance(value.type, ROptional):
2113-
is_none = self.binary_op(value, self.add(PrimitiveOp([], none_op, value.line)),
2114-
'is not', value.line)
2115-
branch = Branch(is_none, true, false, Branch.BOOL_EXPR)
2116-
self.add(branch)
2117-
value_type = value.type.value_type
2118-
if isinstance(value_type, RInstance):
2119-
# Optional[X] where X is always truthy
2120-
# TODO: Support __bool__
2121-
pass
2122-
else:
2123-
# Optional[X] where X may be falsey and requires a check
2124-
branch.true = self.new_block()
2125-
# unbox_or_cast instead of coerce because we want the
2126-
# type to change even if it is a subtype.
2127-
remaining = self.unbox_or_cast(value, value.type.value_type, value.line)
2128-
self.add_bool_branch(remaining, true, false)
2129-
return
2130-
elif not is_same_type(value.type, bool_rprimitive):
2131-
value = self.primitive_op(bool_op, [value], value.line)
2210+
else:
2211+
value_type = optional_value_type(value.type)
2212+
if value_type is not None:
2213+
is_none = self.binary_op(value, self.add(PrimitiveOp([], none_op, value.line)),
2214+
'is not', value.line)
2215+
branch = Branch(is_none, true, false, Branch.BOOL_EXPR)
2216+
self.add(branch)
2217+
if isinstance(value_type, RInstance):
2218+
# Optional[X] where X is always truthy
2219+
# TODO: Support __bool__
2220+
pass
2221+
else:
2222+
# Optional[X] where X may be falsey and requires a check
2223+
branch.true = self.new_block()
2224+
# unbox_or_cast instead of coerce because we want the
2225+
# type to change even if it is a subtype.
2226+
remaining = self.unbox_or_cast(value, value_type, value.line)
2227+
self.add_bool_branch(remaining, true, false)
2228+
return
2229+
elif not is_same_type(value.type, bool_rprimitive):
2230+
value = self.primitive_op(bool_op, [value], value.line)
21322231
self.add(Branch(value, true, false, Branch.BOOL_EXPR))
21332232

21342233
def visit_nonlocal_decl(self, o: NonlocalDecl) -> None:
@@ -3300,12 +3399,15 @@ def load_native_type_object(self, fullname: str) -> Value:
33003399
module, name = fullname.rsplit('.', 1)
33013400
return self.add(LoadStatic(object_rprimitive, name, module, NAMESPACE_TYPE))
33023401

3303-
def coerce(self, src: Value, target_type: RType, line: int) -> Value:
3402+
def coerce(self, src: Value, target_type: RType, line: int, force: bool = False) -> Value:
33043403
"""Generate a coercion/cast from one type to other (only if needed).
33053404
33063405
For example, int -> object boxes the source int; int -> int emits nothing;
33073406
object -> int unboxes the object. All conversions preserve object value.
33083407
3408+
If force is true, always generate an op (even if it is just an assingment) so
3409+
that the result will have exactly target_type as the type.
3410+
33093411
Returns the register with the converted value (may be same as src).
33103412
"""
33113413
if src.type.is_unboxed and not target_type.is_unboxed:
@@ -3319,6 +3421,10 @@ def coerce(self, src: Value, target_type: RType, line: int) -> Value:
33193421
if ((not src.type.is_unboxed and target_type.is_unboxed)
33203422
or not is_subtype(src.type, target_type)):
33213423
return self.unbox_or_cast(src, target_type, line)
3424+
elif force:
3425+
tmp = self.alloc_temp(target_type)
3426+
self.add(Assign(tmp, src))
3427+
return tmp
33223428
return src
33233429

33243430
def keyword_args_to_positional(self,

0 commit comments

Comments
 (0)