@@ -553,40 +553,18 @@ def visit_class_pattern(self, o: ClassPattern) -> PatternType:
553553 # Check class type
554554 #
555555 type_info = o .class_ref .node
556- typ = self .chk .expr_checker .accept (o .class_ref )
557- p_typ = get_proper_type (typ )
558556 if isinstance (type_info , TypeAlias ) and not type_info .no_args :
559557 self .msg .fail (message_registry .CLASS_PATTERN_GENERIC_TYPE_ALIAS , o )
560558 return self .early_non_match ()
561- elif isinstance (p_typ , FunctionLike ) and p_typ .is_type_obj ():
562- typ = fill_typevars_with_any (p_typ .type_object ())
563- type_range = TypeRange (typ , is_upper_bound = False )
564- elif (
565- isinstance (type_info , Var )
566- and type_info .type is not None
567- and type_info .fullname == "typing.Callable"
568- ):
569- # Create a `Callable[..., Any]`
570- fallback = self .chk .named_type ("builtins.function" )
571- any_type = AnyType (TypeOfAny .unannotated )
572- typ = callable_with_ellipsis (any_type , ret_type = any_type , fallback = fallback )
573- type_range = TypeRange (typ , is_upper_bound = False )
574- elif isinstance (p_typ , TypeType ):
575- typ = p_typ .item
576- type_range = TypeRange (p_typ .item , is_upper_bound = True )
577- elif not isinstance (p_typ , AnyType ):
578- self .msg .fail (
579- message_registry .CLASS_PATTERN_TYPE_REQUIRED .format (
580- typ .str_with_options (self .options )
581- ),
582- o ,
583- )
559+
560+ typ = self .chk .expr_checker .accept (o .class_ref )
561+ type_ranges = self .get_class_pattern_type_ranges (typ , o )
562+ if type_ranges is None :
584563 return self .early_non_match ()
585- else :
586- type_range = get_type_range (typ )
564+ typ = UnionType .make_union ([t .item for t in type_ranges ])
587565
588566 new_type , rest_type = self .chk .conditional_types_with_intersection (
589- current_type , [ type_range ] , o , default = current_type
567+ current_type , type_ranges , o , default = current_type
590568 )
591569 if is_uninhabited (new_type ):
592570 return self .early_non_match ()
@@ -717,6 +695,46 @@ def visit_class_pattern(self, o: ClassPattern) -> PatternType:
717695 new_type = UninhabitedType ()
718696 return PatternType (new_type , rest_type , captures )
719697
698+ def get_class_pattern_type_ranges (self , typ : Type , o : ClassPattern ) -> list [TypeRange ] | None :
699+ p_typ = get_proper_type (typ )
700+
701+ if isinstance (p_typ , UnionType ):
702+ type_ranges = []
703+ for item in p_typ .items :
704+ type_range = self .get_class_pattern_type_ranges (item , o )
705+ if type_range is not None :
706+ type_ranges .extend (type_range )
707+ if not type_ranges :
708+ return None
709+ return type_ranges
710+
711+ if isinstance (p_typ , FunctionLike ) and p_typ .is_type_obj ():
712+ typ = fill_typevars_with_any (p_typ .type_object ())
713+ return [TypeRange (typ , is_upper_bound = False )]
714+ if (
715+ isinstance (o .class_ref .node , Var )
716+ and o .class_ref .node .type is not None
717+ and o .class_ref .node .fullname == "typing.Callable"
718+ ):
719+ # Create a `Callable[..., Any]`
720+ fallback = self .chk .named_type ("builtins.function" )
721+ any_type = AnyType (TypeOfAny .unannotated )
722+ typ = callable_with_ellipsis (any_type , ret_type = any_type , fallback = fallback )
723+ return [TypeRange (typ , is_upper_bound = False )]
724+ if isinstance (p_typ , TypeType ):
725+ typ = p_typ .item
726+ return [TypeRange (p_typ .item , is_upper_bound = True )]
727+ if isinstance (p_typ , AnyType ):
728+ return [TypeRange (p_typ , is_upper_bound = False )]
729+
730+ self .msg .fail (
731+ message_registry .CLASS_PATTERN_TYPE_REQUIRED .format (
732+ typ .str_with_options (self .options )
733+ ),
734+ o ,
735+ )
736+ return None
737+
720738 def should_self_match (self , typ : Type ) -> bool :
721739 typ = get_proper_type (typ )
722740 if isinstance (typ , TupleType ):
0 commit comments