Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Fix union simplification performance regression #12519

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions mypy/test/testtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,12 @@ def test_simplified_union(self) -> None:
self.assert_simplified_union([fx.a, UnionType([fx.a])], fx.a)
self.assert_simplified_union([fx.b, UnionType([fx.c, UnionType([fx.d])])],
UnionType([fx.b, fx.c, fx.d]))

def test_simplified_union_with_literals(self) -> None:
fx = self.fx

self.assert_simplified_union([fx.lit1, fx.a], fx.a)
self.assert_simplified_union([fx.lit1, fx.lit2, fx.a], fx.a)
self.assert_simplified_union([fx.lit1, fx.lit1], fx.lit1)
self.assert_simplified_union([fx.lit1, fx.lit2], UnionType([fx.lit1, fx.lit2]))
self.assert_simplified_union([fx.lit1, fx.lit3], UnionType([fx.lit1, fx.lit3]))
Expand All @@ -481,6 +486,40 @@ def test_simplified_union(self) -> None:
self.assert_simplified_union([fx.lit1, fx.lit2_inst], UnionType([fx.lit1, fx.lit2_inst]))
self.assert_simplified_union([fx.lit1, fx.lit3_inst], UnionType([fx.lit1, fx.lit3_inst]))

def test_simplified_union_with_str_literals(self) -> None:
fx = self.fx

self.assert_simplified_union([fx.lit_str1, fx.lit_str2, fx.str_type], fx.str_type)
self.assert_simplified_union([fx.lit_str1, fx.lit_str1, fx.lit_str1], fx.lit_str1)
self.assert_simplified_union([fx.lit_str1, fx.lit_str2, fx.lit_str3],
UnionType([fx.lit_str1, fx.lit_str2, fx.lit_str3]))
self.assert_simplified_union([fx.lit_str1, fx.lit_str2, fx.uninhabited],
UnionType([fx.lit_str1, fx.lit_str2]))

def test_simplified_union_with_str_instance_literals(self) -> None:
fx = self.fx

self.assert_simplified_union([fx.lit_str1_inst, fx.lit_str2_inst, fx.str_type],
fx.str_type)
self.assert_simplified_union([fx.lit_str1_inst, fx.lit_str1_inst, fx.lit_str1_inst],
fx.lit_str1_inst)
self.assert_simplified_union([fx.lit_str1_inst, fx.lit_str2_inst, fx.lit_str3_inst],
UnionType([fx.lit_str1_inst,
fx.lit_str2_inst,
fx.lit_str3_inst]))
self.assert_simplified_union([fx.lit_str1_inst, fx.lit_str2_inst, fx.uninhabited],
UnionType([fx.lit_str1_inst, fx.lit_str2_inst]))

def test_simplified_union_with_mixed_str_literals(self) -> None:
fx = self.fx

self.assert_simplified_union([fx.lit_str1, fx.lit_str2, fx.lit_str3_inst],
UnionType([fx.lit_str1,
fx.lit_str2,
fx.lit_str3_inst]))
self.assert_simplified_union([fx.lit_str1, fx.lit_str1, fx.lit_str1_inst],
UnionType([fx.lit_str1, fx.lit_str1_inst]))

def assert_simplified_union(self, original: List[Type], union: Type) -> None:
assert_equal(make_simplified_union(original), union)
assert_equal(make_simplified_union(list(reversed(original))), union)
Expand Down
9 changes: 9 additions & 0 deletions mypy/test/typefixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def make_type_var(name: str, id: int, values: List[Type], upper_bound: Type,
variances=[COVARIANT]) # class tuple
self.type_typei = self.make_type_info('builtins.type') # class type
self.bool_type_info = self.make_type_info('builtins.bool')
self.str_type_info = self.make_type_info('builtins.str')
self.functioni = self.make_type_info('builtins.function') # function TODO
self.ai = self.make_type_info('A', mro=[self.oi]) # class A
self.bi = self.make_type_info('B', mro=[self.ai, self.oi]) # class B(A)
Expand Down Expand Up @@ -109,6 +110,7 @@ def make_type_var(name: str, id: int, values: List[Type], upper_bound: Type,
self.std_tuple = Instance(self.std_tuplei, [self.anyt]) # tuple
self.type_type = Instance(self.type_typei, []) # type
self.function = Instance(self.functioni, []) # function TODO
self.str_type = Instance(self.str_type_info, [])
self.a = Instance(self.ai, []) # A
self.b = Instance(self.bi, []) # B
self.c = Instance(self.ci, []) # C
Expand Down Expand Up @@ -163,6 +165,13 @@ def make_type_var(name: str, id: int, values: List[Type], upper_bound: Type,
self.lit3_inst = Instance(self.di, [], last_known_value=self.lit3)
self.lit4_inst = Instance(self.ai, [], last_known_value=self.lit4)

self.lit_str1 = LiteralType("x", self.str_type)
self.lit_str2 = LiteralType("y", self.str_type)
self.lit_str3 = LiteralType("z", self.str_type)
self.lit_str1_inst = Instance(self.str_type_info, [], last_known_value=self.lit_str1)
self.lit_str2_inst = Instance(self.str_type_info, [], last_known_value=self.lit_str2)
self.lit_str3_inst = Instance(self.str_type_info, [], last_known_value=self.lit_str3)

self.type_a = TypeType.make_normalized(self.a)
self.type_b = TypeType.make_normalized(self.b)
self.type_c = TypeType.make_normalized(self.c)
Expand Down
50 changes: 30 additions & 20 deletions mypy/typeops.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,15 +299,23 @@ def callable_corresponding_argument(typ: CallableType,
return by_name if by_name is not None else by_pos


def is_simple_literal(t: ProperType) -> bool:
"""
Whether a type is a simple enough literal to allow for fast Union simplification
def simple_literal_value_key(t: ProperType) -> Optional[Tuple[str, ...]]:
"""Return a hashable description of simple literal type.

Return None if not a simple literal type.

For now this means enum or string
The return value can be used to simplify away duplicate types in
unions by comparing keys for equality. For now enum, string or
Instance with string last_known_value are supported.
"""
return isinstance(t, LiteralType) and (
t.fallback.type.is_enum or t.fallback.type.fullname == 'builtins.str'
)
if isinstance(t, LiteralType):
if t.fallback.type.is_enum or t.fallback.type.fullname == 'builtins.str':
assert isinstance(t.value, str)
return 'literal', t.value, t.fallback.type.fullname
if isinstance(t, Instance):
if t.last_known_value is not None and isinstance(t.last_known_value.value, str):
return 'instance', t.last_known_value.value, t.type.fullname
return None


def make_simplified_union(items: Sequence[Type],
Expand Down Expand Up @@ -341,10 +349,20 @@ def make_simplified_union(items: Sequence[Type],
all_items.append(typ)
items = all_items

simplified_set = _remove_redundant_union_items(items, keep_erased)

# If more than one literal exists in the union, try to simplify
if (contract_literals and sum(isinstance(item, LiteralType) for item in simplified_set) > 1):
simplified_set = try_contracting_literals_in_union(simplified_set)

return UnionType.make_union(simplified_set, line, column)


def _remove_redundant_union_items(items: List[ProperType], keep_erased: bool) -> List[ProperType]:
from mypy.subtypes import is_proper_subtype

removed: Set[int] = set()
seen: Set[Tuple[str, str]] = set()
seen: Set[Tuple[str, ...]] = set()

# NB: having a separate fast path for Union of Literal and slow path for other things
# would arguably be cleaner, however it breaks down when simplifying the Union of two
Expand All @@ -354,10 +372,8 @@ def make_simplified_union(items: Sequence[Type],
if i in removed:
continue
# Avoid slow nested for loop for Union of Literal of strings/enums (issue #9169)
if is_simple_literal(item):
assert isinstance(item, LiteralType)
assert isinstance(item.value, str)
k = (item.value, item.fallback.type.fullname)
k = simple_literal_value_key(item)
if k is not None:
if k in seen:
removed.add(i)
continue
Expand All @@ -373,13 +389,13 @@ def make_simplified_union(items: Sequence[Type],
seen.add(k)
if safe_skip:
continue

# Keep track of the truishness info for deleted subtypes which can be relevant
cbt = cbf = False
for j, tj in enumerate(items):
# NB: we don't need to check literals as the fast path above takes care of that
if (
i != j
and not is_simple_literal(tj)
and is_proper_subtype(tj, item, keep_erased_types=keep_erased)
and is_redundant_literal_instance(item, tj) # XXX?
):
Expand All @@ -393,13 +409,7 @@ def make_simplified_union(items: Sequence[Type],
elif not item.can_be_false and cbf:
items[i] = true_or_false(item)

simplified_set = [items[i] for i in range(len(items)) if i not in removed]

# If more than one literal exists in the union, try to simplify
if (contract_literals and sum(isinstance(item, LiteralType) for item in simplified_set) > 1):
simplified_set = try_contracting_literals_in_union(simplified_set)

return UnionType.make_union(simplified_set, line, column)
return [items[i] for i in range(len(items)) if i not in removed]


def _get_type_special_method_bool_ret_type(t: Type) -> Optional[Type]:
Expand Down