Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Flatten TypeAliasType when it is aliased as a Union #8146

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Dec 20, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
PartialType, DeletedType, UninhabitedType, TypeType, TypeOfAny, LiteralType, LiteralValue,
is_named_instance, FunctionLike,
StarType, is_optional, remove_optional, is_generic_instance, get_proper_type, ProperType,
get_proper_types
get_proper_types, flatten_nested_unions
)
from mypy.nodes import (
NameExpr, RefExpr, Var, FuncDef, OverloadedFuncDef, TypeInfo, CallExpr,
Expand Down Expand Up @@ -2521,7 +2521,9 @@ def check_op(self, method: str, base_type: Type,
left_variants = [base_type]
base_type = get_proper_type(base_type)
if isinstance(base_type, UnionType):
left_variants = [item for item in base_type.relevant_items()]
left_variants = [item for item in
flatten_nested_unions(base_type.relevant_items(),
handle_type_alias_type=True)]
right_type = self.accept(arg)

# Step 1: We first try leaving the right arguments alone and destructure
Expand Down Expand Up @@ -2564,8 +2566,8 @@ def check_op(self, method: str, base_type: Type,
right_type = get_proper_type(right_type)
if isinstance(right_type, UnionType):
right_variants = [(item, TempNode(item, context=context))
for item in right_type.relevant_items()]

for item in flatten_nested_unions(right_type.relevant_items(),
handle_type_alias_type=True)]
msg = self.msg.clean_copy()
msg.disable_count = 0
all_results = []
Expand Down
8 changes: 6 additions & 2 deletions mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2231,15 +2231,19 @@ def has_type_vars(typ: Type) -> bool:
return typ.accept(HasTypeVars())


def flatten_nested_unions(types: Iterable[Type]) -> List[Type]:
def flatten_nested_unions(types: Iterable[Type],
handle_type_alias_type: bool = False) -> List[Type]:
"""Flatten nested unions in a type list."""
# This and similar functions on unions can cause infinite recursion
# if passed a "pathological" alias like A = Union[int, A] or similar.
# TODO: ban such aliases in semantic analyzer.
flat_items = [] # type: List[Type]
if handle_type_alias_type:
types = get_proper_types(types)
for tp in types:
if isinstance(tp, ProperType) and isinstance(tp, UnionType):
flat_items.extend(flatten_nested_unions(tp.items))
flat_items.extend(flatten_nested_unions(tp.items,
handle_type_alias_type=handle_type_alias_type))
else:
flat_items.append(tp)
return flat_items
Expand Down
15 changes: 15 additions & 0 deletions test-data/unit/check-unions.test
Original file line number Diff line number Diff line change
Expand Up @@ -1013,3 +1013,18 @@ y: Union[int, Dict[int, int]] = 1 if bool() else {}
u: Union[int, List[int]] = [] if bool() else 1
v: Union[int, Dict[int, int]] = {} if bool() else 1
[builtins fixtures/isinstancelist.pyi]

[case testFlattenTypeAliasWhenAliasedAsUnion]
from typing import Union

T1 = int
T2 = Union[T1, float]
T3 = Union[T2, complex]
T4 = Union[T3, int]

def foo(a: T2, b: T2) -> T2:
return a + b

def bar(a: T4, b: T4) -> T4: # test multi-level alias
return a + b
[builtins fixtures/ops.pyi]
4 changes: 4 additions & 0 deletions test-data/unit/fixtures/ops.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ class float:
def __truediv__(self, x: 'float') -> 'float': pass
def __rtruediv__(self, x: 'float') -> 'float': pass

class complex:
def __add__(self, x: complex) -> complex: pass
def __radd__(self, x: complex) -> complex: pass

class BaseException: pass

def __print(a1=None, a2=None, a3=None, a4=None): pass
Expand Down