diff --git a/mypy/test/testtypes.py b/mypy/test/testtypes.py index 6af1c18145cf..923e39571ad9 100644 --- a/mypy/test/testtypes.py +++ b/mypy/test/testtypes.py @@ -465,7 +465,12 @@ def test_simplified_union(self) -> None: self.assert_simplified_union([fx.a, UnionType([fx.a])], fx.a) self.assert_simplified_union([fx.b, UnionType([fx.c, UnionType([fx.d])])], UnionType([fx.b, fx.c, fx.d])) + + def test_simplified_union_with_literals(self) -> None: + fx = self.fx + self.assert_simplified_union([fx.lit1, fx.a], fx.a) + self.assert_simplified_union([fx.lit1, fx.lit2, fx.a], fx.a) self.assert_simplified_union([fx.lit1, fx.lit1], fx.lit1) self.assert_simplified_union([fx.lit1, fx.lit2], UnionType([fx.lit1, fx.lit2])) self.assert_simplified_union([fx.lit1, fx.lit3], UnionType([fx.lit1, fx.lit3])) @@ -481,6 +486,40 @@ def test_simplified_union(self) -> None: self.assert_simplified_union([fx.lit1, fx.lit2_inst], UnionType([fx.lit1, fx.lit2_inst])) self.assert_simplified_union([fx.lit1, fx.lit3_inst], UnionType([fx.lit1, fx.lit3_inst])) + def test_simplified_union_with_str_literals(self) -> None: + fx = self.fx + + self.assert_simplified_union([fx.lit_str1, fx.lit_str2, fx.str_type], fx.str_type) + self.assert_simplified_union([fx.lit_str1, fx.lit_str1, fx.lit_str1], fx.lit_str1) + self.assert_simplified_union([fx.lit_str1, fx.lit_str2, fx.lit_str3], + UnionType([fx.lit_str1, fx.lit_str2, fx.lit_str3])) + self.assert_simplified_union([fx.lit_str1, fx.lit_str2, fx.uninhabited], + UnionType([fx.lit_str1, fx.lit_str2])) + + def test_simplified_union_with_str_instance_literals(self) -> None: + fx = self.fx + + self.assert_simplified_union([fx.lit_str1_inst, fx.lit_str2_inst, fx.str_type], + fx.str_type) + self.assert_simplified_union([fx.lit_str1_inst, fx.lit_str1_inst, fx.lit_str1_inst], + fx.lit_str1_inst) + self.assert_simplified_union([fx.lit_str1_inst, fx.lit_str2_inst, fx.lit_str3_inst], + UnionType([fx.lit_str1_inst, + fx.lit_str2_inst, + fx.lit_str3_inst])) + 
self.assert_simplified_union([fx.lit_str1_inst, fx.lit_str2_inst, fx.uninhabited], + UnionType([fx.lit_str1_inst, fx.lit_str2_inst])) + + def test_simplified_union_with_mixed_str_literals(self) -> None: + fx = self.fx + + self.assert_simplified_union([fx.lit_str1, fx.lit_str2, fx.lit_str3_inst], + UnionType([fx.lit_str1, + fx.lit_str2, + fx.lit_str3_inst])) + self.assert_simplified_union([fx.lit_str1, fx.lit_str1, fx.lit_str1_inst], + UnionType([fx.lit_str1, fx.lit_str1_inst])) + def assert_simplified_union(self, original: List[Type], union: Type) -> None: assert_equal(make_simplified_union(original), union) assert_equal(make_simplified_union(list(reversed(original))), union) diff --git a/mypy/test/typefixture.py b/mypy/test/typefixture.py index 7d4faeccf432..c8bbf67510a6 100644 --- a/mypy/test/typefixture.py +++ b/mypy/test/typefixture.py @@ -65,6 +65,7 @@ def make_type_var(name: str, id: int, values: List[Type], upper_bound: Type, variances=[COVARIANT]) # class tuple self.type_typei = self.make_type_info('builtins.type') # class type self.bool_type_info = self.make_type_info('builtins.bool') + self.str_type_info = self.make_type_info('builtins.str') self.functioni = self.make_type_info('builtins.function') # function TODO self.ai = self.make_type_info('A', mro=[self.oi]) # class A self.bi = self.make_type_info('B', mro=[self.ai, self.oi]) # class B(A) @@ -109,6 +110,7 @@ def make_type_var(name: str, id: int, values: List[Type], upper_bound: Type, self.std_tuple = Instance(self.std_tuplei, [self.anyt]) # tuple self.type_type = Instance(self.type_typei, []) # type self.function = Instance(self.functioni, []) # function TODO + self.str_type = Instance(self.str_type_info, []) self.a = Instance(self.ai, []) # A self.b = Instance(self.bi, []) # B self.c = Instance(self.ci, []) # C @@ -163,6 +165,13 @@ def make_type_var(name: str, id: int, values: List[Type], upper_bound: Type, self.lit3_inst = Instance(self.di, [], last_known_value=self.lit3) self.lit4_inst = 
Instance(self.ai, [], last_known_value=self.lit4) + self.lit_str1 = LiteralType("x", self.str_type) + self.lit_str2 = LiteralType("y", self.str_type) + self.lit_str3 = LiteralType("z", self.str_type) + self.lit_str1_inst = Instance(self.str_type_info, [], last_known_value=self.lit_str1) + self.lit_str2_inst = Instance(self.str_type_info, [], last_known_value=self.lit_str2) + self.lit_str3_inst = Instance(self.str_type_info, [], last_known_value=self.lit_str3) + self.type_a = TypeType.make_normalized(self.a) self.type_b = TypeType.make_normalized(self.b) self.type_c = TypeType.make_normalized(self.c) diff --git a/mypy/typeops.py b/mypy/typeops.py index 007a54e17b95..fa0523110941 100644 --- a/mypy/typeops.py +++ b/mypy/typeops.py @@ -299,15 +299,23 @@ def callable_corresponding_argument(typ: CallableType, return by_name if by_name is not None else by_pos -def is_simple_literal(t: ProperType) -> bool: - """ - Whether a type is a simple enough literal to allow for fast Union simplification +def simple_literal_value_key(t: ProperType) -> Optional[Tuple[str, ...]]: + """Return a hashable description of a simple literal type. + + Return None if not a simple literal type. - For now this means enum or string + The return value can be used to simplify away duplicate types in + unions by comparing keys for equality. For now enum, string or + Instance with string last_known_value are supported. 
""" - return isinstance(t, LiteralType) and ( - t.fallback.type.is_enum or t.fallback.type.fullname == 'builtins.str' - ) + if isinstance(t, LiteralType): + if t.fallback.type.is_enum or t.fallback.type.fullname == 'builtins.str': + assert isinstance(t.value, str) + return 'literal', t.value, t.fallback.type.fullname + if isinstance(t, Instance): + if t.last_known_value is not None and isinstance(t.last_known_value.value, str): + return 'instance', t.last_known_value.value, t.type.fullname + return None def make_simplified_union(items: Sequence[Type], @@ -341,10 +349,20 @@ def make_simplified_union(items: Sequence[Type], all_items.append(typ) items = all_items + simplified_set = _remove_redundant_union_items(items, keep_erased) + + # If more than one literal exists in the union, try to simplify + if (contract_literals and sum(isinstance(item, LiteralType) for item in simplified_set) > 1): + simplified_set = try_contracting_literals_in_union(simplified_set) + + return UnionType.make_union(simplified_set, line, column) + + +def _remove_redundant_union_items(items: List[ProperType], keep_erased: bool) -> List[ProperType]: from mypy.subtypes import is_proper_subtype removed: Set[int] = set() - seen: Set[Tuple[str, str]] = set() + seen: Set[Tuple[str, ...]] = set() # NB: having a separate fast path for Union of Literal and slow path for other things # would arguably be cleaner, however it breaks down when simplifying the Union of two @@ -354,10 +372,8 @@ def make_simplified_union(items: Sequence[Type], if i in removed: continue # Avoid slow nested for loop for Union of Literal of strings/enums (issue #9169) - if is_simple_literal(item): - assert isinstance(item, LiteralType) - assert isinstance(item.value, str) - k = (item.value, item.fallback.type.fullname) + k = simple_literal_value_key(item) + if k is not None: if k in seen: removed.add(i) continue @@ -373,13 +389,13 @@ def make_simplified_union(items: Sequence[Type], seen.add(k) if safe_skip: continue + # Keep track 
of the truishness info for deleted subtypes which can be relevant cbt = cbf = False for j, tj in enumerate(items): # NB: we don't need to check literals as the fast path above takes care of that if ( i != j - and not is_simple_literal(tj) and is_proper_subtype(tj, item, keep_erased_types=keep_erased) and is_redundant_literal_instance(item, tj) # XXX? ): @@ -393,13 +409,7 @@ def make_simplified_union(items: Sequence[Type], elif not item.can_be_false and cbf: items[i] = true_or_false(item) - simplified_set = [items[i] for i in range(len(items)) if i not in removed] - - # If more than one literal exists in the union, try to simplify - if (contract_literals and sum(isinstance(item, LiteralType) for item in simplified_set) > 1): - simplified_set = try_contracting_literals_in_union(simplified_set) - - return UnionType.make_union(simplified_set, line, column) + return [items[i] for i in range(len(items)) if i not in removed] def _get_type_special_method_bool_ret_type(t: Type) -> Optional[Type]: