Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 0d20e52

Browse files
authored
Better match narrowing for unions of type objects (#20905)
Sequel to #20872 that has a slightly thrashy diff
1 parent 0cf9c02 commit 0d20e52

2 files changed

Lines changed: 70 additions & 38 deletions

File tree

mypy/checkpattern.py

Lines changed: 46 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -553,40 +553,18 @@ def visit_class_pattern(self, o: ClassPattern) -> PatternType:
553553
# Check class type
554554
#
555555
type_info = o.class_ref.node
556-
typ = self.chk.expr_checker.accept(o.class_ref)
557-
p_typ = get_proper_type(typ)
558556
if isinstance(type_info, TypeAlias) and not type_info.no_args:
559557
self.msg.fail(message_registry.CLASS_PATTERN_GENERIC_TYPE_ALIAS, o)
560558
return self.early_non_match()
561-
elif isinstance(p_typ, FunctionLike) and p_typ.is_type_obj():
562-
typ = fill_typevars_with_any(p_typ.type_object())
563-
type_range = TypeRange(typ, is_upper_bound=False)
564-
elif (
565-
isinstance(type_info, Var)
566-
and type_info.type is not None
567-
and type_info.fullname == "typing.Callable"
568-
):
569-
# Create a `Callable[..., Any]`
570-
fallback = self.chk.named_type("builtins.function")
571-
any_type = AnyType(TypeOfAny.unannotated)
572-
typ = callable_with_ellipsis(any_type, ret_type=any_type, fallback=fallback)
573-
type_range = TypeRange(typ, is_upper_bound=False)
574-
elif isinstance(p_typ, TypeType):
575-
typ = p_typ.item
576-
type_range = TypeRange(p_typ.item, is_upper_bound=True)
577-
elif not isinstance(p_typ, AnyType):
578-
self.msg.fail(
579-
message_registry.CLASS_PATTERN_TYPE_REQUIRED.format(
580-
typ.str_with_options(self.options)
581-
),
582-
o,
583-
)
559+
560+
typ = self.chk.expr_checker.accept(o.class_ref)
561+
type_ranges = self.get_class_pattern_type_ranges(typ, o)
562+
if type_ranges is None:
584563
return self.early_non_match()
585-
else:
586-
type_range = get_type_range(typ)
564+
typ = UnionType.make_union([t.item for t in type_ranges])
587565

588566
new_type, rest_type = self.chk.conditional_types_with_intersection(
589-
current_type, [type_range], o, default=current_type
567+
current_type, type_ranges, o, default=current_type
590568
)
591569
if is_uninhabited(new_type):
592570
return self.early_non_match()
@@ -717,6 +695,46 @@ def visit_class_pattern(self, o: ClassPattern) -> PatternType:
717695
new_type = UninhabitedType()
718696
return PatternType(new_type, rest_type, captures)
719697

698+
def get_class_pattern_type_ranges(self, typ: Type, o: ClassPattern) -> list[TypeRange] | None:
699+
p_typ = get_proper_type(typ)
700+
701+
if isinstance(p_typ, UnionType):
702+
type_ranges = []
703+
for item in p_typ.items:
704+
type_range = self.get_class_pattern_type_ranges(item, o)
705+
if type_range is not None:
706+
type_ranges.extend(type_range)
707+
if not type_ranges:
708+
return None
709+
return type_ranges
710+
711+
if isinstance(p_typ, FunctionLike) and p_typ.is_type_obj():
712+
typ = fill_typevars_with_any(p_typ.type_object())
713+
return [TypeRange(typ, is_upper_bound=False)]
714+
if (
715+
isinstance(o.class_ref.node, Var)
716+
and o.class_ref.node.type is not None
717+
and o.class_ref.node.fullname == "typing.Callable"
718+
):
719+
# Create a `Callable[..., Any]`
720+
fallback = self.chk.named_type("builtins.function")
721+
any_type = AnyType(TypeOfAny.unannotated)
722+
typ = callable_with_ellipsis(any_type, ret_type=any_type, fallback=fallback)
723+
return [TypeRange(typ, is_upper_bound=False)]
724+
if isinstance(p_typ, TypeType):
725+
typ = p_typ.item
726+
return [TypeRange(p_typ.item, is_upper_bound=True)]
727+
if isinstance(p_typ, AnyType):
728+
return [TypeRange(p_typ, is_upper_bound=False)]
729+
730+
self.msg.fail(
731+
message_registry.CLASS_PATTERN_TYPE_REQUIRED.format(
732+
typ.str_with_options(self.options)
733+
),
734+
o,
735+
)
736+
return None
737+
720738
def should_self_match(self, typ: Type) -> bool:
721739
typ = get_proper_type(typ)
722740
if isinstance(typ, TupleType):

test-data/unit/check-python310.test

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1091,13 +1091,27 @@ match m:
10911091
[builtins fixtures/tuple.pyi]
10921092

10931093
[case testMatchClassPatternIsNotType]
1094-
a = 1
1095-
m: object
1094+
# flags: --strict-equality --warn-unreachable
1095+
from typing import Any
10961096

1097-
match m:
1098-
case a(i, j): # E: Expected type in class pattern; found "builtins.int"
1099-
reveal_type(i)
1100-
reveal_type(j)
1097+
def match_int(m: object, a: int):
1098+
match m:
1099+
case a(i, j): # E: Expected type in class pattern; found "builtins.int"
1100+
reveal_type(i) # E: Statement is unreachable
1101+
reveal_type(j)
1102+
1103+
def match_int_str(m: object, a: int | str):
1104+
match m:
1105+
case a(i, j): # E: Expected type in class pattern; found "builtins.int" \
1106+
# E: Expected type in class pattern; found "builtins.str"
1107+
reveal_type(i) # E: Statement is unreachable
1108+
reveal_type(j)
1109+
1110+
def match_int_any(m: object, a: int | Any):
1111+
match m:
1112+
case a(i, j): # E: Expected type in class pattern; found "builtins.int"
1113+
reveal_type(i) # N: Revealed type is "Any"
1114+
reveal_type(j) # N: Revealed type is "Any"
11011115

11021116
[case testMatchClassPatternAny]
11031117
from typing import Any
@@ -1300,15 +1314,15 @@ def f4(T: type[Example | Example2]) -> None:
13001314

13011315
def f5(T: type[Example | Example2]) -> None:
13021316
match Example("a"):
1303-
case T(value): # E: Expected type in class pattern; found "type[__main__.Example] | type[__main__.Example2]"
1304-
reveal_type(value) # E: Statement is unreachable
1317+
case T(value):
1318+
reveal_type(value) # N: Revealed type is "builtins.str"
13051319
case anything:
13061320
reveal_type(anything) # N: Revealed type is "__main__.Example"
13071321

13081322
def f6(T: type[Example | Example2]) -> None:
13091323
match T("a"):
1310-
case T(value): # E: Expected type in class pattern; found "type[__main__.Example] | type[__main__.Example2]"
1311-
reveal_type(value) # E: Statement is unreachable
1324+
case T(value):
1325+
reveal_type(value) # N: Revealed type is "builtins.str"
13121326
case anything:
13131327
reveal_type(anything) # N: Revealed type is "__main__.Example | __main__.Example2"
13141328

0 commit comments

Comments
 (0)