From b3aafa7fd78cb661257dc336680c6dd6b9fd5352 Mon Sep 17 00:00:00 2001 From: TH3CHARLie Date: Sat, 14 Dec 2019 16:29:20 +0800 Subject: [PATCH 1/4] flatten TypeAliasType when it is aliased as a Union --- mypy/checkexpr.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 35c58478ce1e..4a216f619fe8 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -19,7 +19,7 @@ PartialType, DeletedType, UninhabitedType, TypeType, TypeOfAny, LiteralType, LiteralValue, is_named_instance, FunctionLike, StarType, is_optional, remove_optional, is_generic_instance, get_proper_type, ProperType, - get_proper_types + get_proper_types, flatten_nested_unions, TypeAliasType ) from mypy.nodes import ( NameExpr, RefExpr, Var, FuncDef, OverloadedFuncDef, TypeInfo, CallExpr, @@ -2521,7 +2521,10 @@ def check_op(self, method: str, base_type: Type, left_variants = [base_type] base_type = get_proper_type(base_type) if isinstance(base_type, UnionType): - left_variants = [item for item in base_type.relevant_items()] + items = [get_proper_type(item) if isinstance(item, TypeAliasType) + else item for item in base_type.relevant_items()] + left_variants = [item for item in + flatten_nested_unions(items)] right_type = self.accept(arg) # Step 1: We first try leaving the right arguments alone and destructure @@ -2563,8 +2566,10 @@ def check_op(self, method: str, base_type: Type, right_variants = [(right_type, arg)] right_type = get_proper_type(right_type) if isinstance(right_type, UnionType): + items = [get_proper_type(item) if isinstance(item, TypeAliasType) + else item for item in right_type.relevant_items()] right_variants = [(item, TempNode(item, context=context)) - for item in right_type.relevant_items()] + for item in flatten_nested_unions(items)] msg = self.msg.clean_copy() msg.disable_count = 0 From 0a12cfd26c1bd52fbe99b89886ffaecec70b4aac Mon Sep 17 00:00:00 2001 From: TH3CHARLie Date: Mon, 16 Dec 2019 09:56:59 +0800 Subject: [PATCH 2/4] support recursive flattening --- mypy/checkexpr.py | 13 +++++-------- mypy/types.py | 9 +++++++-- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 4a216f619fe8..9b15075da85e 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -19,7 +19,7 @@ PartialType, DeletedType, UninhabitedType, TypeType, TypeOfAny, LiteralType, LiteralValue, is_named_instance, FunctionLike, StarType, is_optional, remove_optional, is_generic_instance, get_proper_type, ProperType, - get_proper_types, flatten_nested_unions, TypeAliasType + get_proper_types, flatten_nested_unions ) from mypy.nodes import ( NameExpr, RefExpr, Var, FuncDef, OverloadedFuncDef, TypeInfo, CallExpr, @@ -2521,10 +2521,9 @@ def check_op(self, method: str, base_type: Type, left_variants = [base_type] base_type = get_proper_type(base_type) if isinstance(base_type, UnionType): - items = [get_proper_type(item) if isinstance(item, TypeAliasType) - else item for item in base_type.relevant_items()] left_variants = [item for item in - flatten_nested_unions(items)] + flatten_nested_unions(base_type.relevant_items(), + handle_type_alias_type=True)] right_type = self.accept(arg) # Step 1: We first try leaving the right arguments alone and destructure @@ -2566,11 +2565,9 @@ def check_op(self, method: str, base_type: Type, right_variants = [(right_type, arg)] right_type = get_proper_type(right_type) if isinstance(right_type, UnionType): - items = [get_proper_type(item) if isinstance(item, TypeAliasType) - else item for item in right_type.relevant_items()] right_variants = [(item, TempNode(item, context=context)) - for item in flatten_nested_unions(items)] - + for item in flatten_nested_unions(right_type.relevant_items(), + handle_type_alias_type=True)] msg = self.msg.clean_copy() msg.disable_count = 0 all_results = [] diff --git a/mypy/types.py b/mypy/types.py index ae678acedb3a..ba0c8af0048a 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -2231,15 +2231,20 @@ def has_type_vars(typ: Type) -> bool: return typ.accept(HasTypeVars()) -def flatten_nested_unions(types: Iterable[Type]) -> List[Type]: +def flatten_nested_unions(types: Iterable[Type], + handle_type_alias_type: bool = False) -> List[Type]: """Flatten nested unions in a type list.""" # This and similar functions on unions can cause infinite recursion # if passed a "pathological" alias like A = Union[int, A] or similar. # TODO: ban such aliases in semantic analyzer. flat_items = [] # type: List[Type] + if handle_type_alias_type: + types = [get_proper_type(item) if isinstance(item, TypeAliasType) + else item for item in types] for tp in types: if isinstance(tp, ProperType) and isinstance(tp, UnionType): - flat_items.extend(flatten_nested_unions(tp.items)) + flat_items.extend(flatten_nested_unions(tp.items, + handle_type_alias_type=handle_type_alias_type)) else: flat_items.append(tp) return flat_items From eeb18b26a47ff496530addd6d565c1e15e666003 Mon Sep 17 00:00:00 2001 From: TH3CHARLie Date: Fri, 20 Dec 2019 10:16:33 +0800 Subject: [PATCH 3/4] minor fixes --- mypy/types.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/mypy/types.py b/mypy/types.py index ba0c8af0048a..2890cc35b22b 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -2239,12 +2239,11 @@ def flatten_nested_unions(types: Iterable[Type], # TODO: ban such aliases in semantic analyzer. flat_items = [] # type: List[Type] if handle_type_alias_type: - types = [get_proper_type(item) if isinstance(item, TypeAliasType) - else item for item in types] + types = get_proper_types(types) for tp in types: if isinstance(tp, ProperType) and isinstance(tp, UnionType): flat_items.extend(flatten_nested_unions(tp.items, - handle_type_alias_type=handle_type_alias_type)) + handle_type_alias_type=handle_type_alias_type)) else: flat_items.append(tp) return flat_items From 053914b131aede26099007be7caef5f388b911c4 Mon Sep 17 00:00:00 2001 From: TH3CHARLie Date: Fri, 20 Dec 2019 10:47:14 +0800 Subject: [PATCH 4/4] add test --- test-data/unit/check-unions.test | 15 +++++++++++++++ test-data/unit/fixtures/ops.pyi | 4 ++++ 2 files changed, 19 insertions(+) diff --git a/test-data/unit/check-unions.test b/test-data/unit/check-unions.test index ed2b415e8f99..92e886fee419 100644 --- a/test-data/unit/check-unions.test +++ b/test-data/unit/check-unions.test @@ -1013,3 +1013,18 @@ y: Union[int, Dict[int, int]] = 1 if bool() else {} u: Union[int, List[int]] = [] if bool() else 1 v: Union[int, Dict[int, int]] = {} if bool() else 1 [builtins fixtures/isinstancelist.pyi] + +[case testFlattenTypeAliasWhenAliasedAsUnion] +from typing import Union + +T1 = int +T2 = Union[T1, float] +T3 = Union[T2, complex] +T4 = Union[T3, int] + +def foo(a: T2, b: T2) -> T2: + return a + b + +def bar(a: T4, b: T4) -> T4: # test multi-level alias + return a + b +[builtins fixtures/ops.pyi] diff --git a/test-data/unit/fixtures/ops.pyi b/test-data/unit/fixtures/ops.pyi index 34cfb176243e..0c3497b1667f 100644 --- a/test-data/unit/fixtures/ops.pyi +++ b/test-data/unit/fixtures/ops.pyi @@ -64,6 +64,10 @@ class float: def __truediv__(self, x: 'float') -> 'float': pass def __rtruediv__(self, x: 'float') -> 'float': pass +class complex: + def __add__(self, x: complex) -> complex: pass + def __radd__(self, x: complex) -> complex: pass + class BaseException: pass def __print(a1=None, a2=None, a3=None, a4=None): pass