@@ -610,3 +610,247 @@ class SomeEnum(Enum):
610610main:2: note: Revealed type is 'builtins.int'
611611[out2]
612612main:2: note: Revealed type is 'builtins.str'
613+
614+ [case testEnumReachabilityChecksBasic]
615+ from enum import Enum
616+ from typing_extensions import Literal
617+
618+ class Foo(Enum):
619+ A = 1
620+ B = 2
621+ C = 3
622+
623+ x: Literal[Foo.A, Foo.B, Foo.C]
624+ if x is Foo.A:
625+ reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]'
626+ elif x is Foo.B:
627+ reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.B]'
628+ elif x is Foo.C:
629+ reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.C]'
630+ else:
631+ reveal_type(x) # No output here: this branch is unreachable
632+
633+ if Foo.A is x:
634+ reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]'
635+ elif Foo.B is x:
636+ reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.B]'
637+ elif Foo.C is x:
638+ reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.C]'
639+ else:
640+ reveal_type(x) # No output here: this branch is unreachable
641+
642+ y: Foo
643+ if y is Foo.A:
644+ reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'
645+ elif y is Foo.B:
646+ reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.B]'
647+ elif y is Foo.C:
648+ reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.C]'
649+ else:
650+ reveal_type(y) # No output here: this branch is unreachable
651+
652+ if Foo.A is y:
653+ reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'
654+ elif Foo.B is y:
655+ reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.B]'
656+ elif Foo.C is y:
657+ reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.C]'
658+ else:
659+ reveal_type(y) # No output here: this branch is unreachable
660+ [builtins fixtures/bool.pyi]
661+
662+ [case testEnumReachabilityChecksIndirect]
663+ from enum import Enum
664+ from typing_extensions import Literal, Final
665+
666+ class Foo(Enum):
667+ A = 1
668+ B = 2
669+ C = 3
670+
671+ def accepts_foo_a(x: Literal[Foo.A]) -> None: ...
672+
673+ x: Foo
674+ y: Literal[Foo.A]
675+ z: Final = Foo.A
676+
677+ if x is y:
678+ reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]'
679+ reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'
680+ else:
681+ reveal_type(x) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]'
682+ reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'
683+ if y is x:
684+ reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]'
685+ reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'
686+ else:
687+ reveal_type(x) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]'
688+ reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'
689+
690+ if x is z:
691+ reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]'
692+ reveal_type(z) # N: Revealed type is '__main__.Foo'
693+ accepts_foo_a(z)
694+ else:
695+ reveal_type(x) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]'
696+ reveal_type(z) # N: Revealed type is '__main__.Foo'
697+ accepts_foo_a(z)
698+ if z is x:
699+ reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]'
700+ reveal_type(z) # N: Revealed type is '__main__.Foo'
701+ accepts_foo_a(z)
702+ else:
703+ reveal_type(x) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]'
704+ reveal_type(z) # N: Revealed type is '__main__.Foo'
705+ accepts_foo_a(z)
706+
707+ if y is z:
708+ reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'
709+ reveal_type(z) # N: Revealed type is '__main__.Foo'
710+ accepts_foo_a(z)
711+ else:
712+ reveal_type(y) # No output: this branch is unreachable
713+ reveal_type(z) # No output: this branch is unreachable
714+ if z is y:
715+ reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'
716+ reveal_type(z) # N: Revealed type is '__main__.Foo'
717+ accepts_foo_a(z)
718+ else:
719+ reveal_type(y) # No output: this branch is unreachable
720+ reveal_type(z) # No output: this branch is unreachable
721+ [builtins fixtures/bool.pyi]
722+
723+ [case testEnumReachabilityNoNarrowingForUnionMessiness]
724+ from enum import Enum
725+ from typing_extensions import Literal
726+
727+ class Foo(Enum):
728+ A = 1
729+ B = 2
730+ C = 3
731+
732+ x: Foo
733+ y: Literal[Foo.A, Foo.B]
734+ z: Literal[Foo.B, Foo.C]
735+
736+ # For the sake of simplicity, no narrowing is done when the narrower type is a Union.
737+ if x is y:
738+ reveal_type(x) # N: Revealed type is '__main__.Foo'
739+ reveal_type(y) # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.B]]'
740+ else:
741+ reveal_type(x) # N: Revealed type is '__main__.Foo'
742+ reveal_type(y) # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.B]]'
743+
744+ if y is z:
745+ reveal_type(y) # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.B]]'
746+ reveal_type(z) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]'
747+ else:
748+ reveal_type(y) # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.B]]'
749+ reveal_type(z) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]'
750+ [builtins fixtures/bool.pyi]
751+
752+ [case testEnumReachabilityWithNone]
753+ # flags: --strict-optional
754+ from enum import Enum
755+ from typing import Optional
756+
757+ class Foo(Enum):
758+ A = 1
759+ B = 2
760+ C = 3
761+
762+ x: Optional[Foo]
763+ if x:
764+ reveal_type(x) # N: Revealed type is '__main__.Foo'
765+ else:
766+ reveal_type(x) # N: Revealed type is 'Union[__main__.Foo, None]'
767+
768+ if x is not None:
769+ reveal_type(x) # N: Revealed type is '__main__.Foo'
770+ else:
771+ reveal_type(x) # N: Revealed type is 'None'
772+
773+ if x is Foo.A:
774+ reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]'
775+ else:
776+ reveal_type(x) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C], None]'
777+ [builtins fixtures/bool.pyi]
778+
779+ [case testEnumReachabilityWithMultipleEnums]
780+ from enum import Enum
781+ from typing import Union
782+ from typing_extensions import Literal
783+
784+ class Foo(Enum):
785+ A = 1
786+ B = 2
787+ class Bar(Enum):
788+ A = 1
789+ B = 2
790+
791+ x1: Union[Foo, Bar]
792+ if x1 is Foo.A:
793+ reveal_type(x1) # N: Revealed type is 'Literal[__main__.Foo.A]'
794+ else:
795+ reveal_type(x1) # N: Revealed type is 'Union[Literal[__main__.Foo.B], __main__.Bar]'
796+
797+ x2: Union[Foo, Bar]
798+ if x2 is Bar.A:
799+ reveal_type(x2) # N: Revealed type is 'Literal[__main__.Bar.A]'
800+ else:
801+ reveal_type(x2) # N: Revealed type is 'Union[__main__.Foo, Literal[__main__.Bar.B]]'
802+
803+ x3: Union[Foo, Bar]
804+ if x3 is Foo.A or x3 is Bar.A:
805+ reveal_type(x3) # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Bar.A]]'
806+ else:
807+ reveal_type(x3) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Bar.B]]'
808+
809+ [builtins fixtures/bool.pyi]
810+
811+ [case testEnumReachabilityPEP484Example1]
812+ # flags: --strict-optional
813+ from typing import Union
814+ from typing_extensions import Final
815+ from enum import Enum
816+
817+ class Empty(Enum):
818+ token = 0
819+ _empty: Final = Empty.token
820+
821+ def func(x: Union[int, None, Empty] = _empty) -> int:
822+ boom = x + 42 # E: Unsupported left operand type for + ("None") \
823+ # E: Unsupported left operand type for + ("Empty") \
824+ # N: Left operand is of type "Union[int, None, Empty]"
825+ if x is _empty:
826+ reveal_type(x) # N: Revealed type is 'Literal[__main__.Empty.token]'
827+ return 0
828+ elif x is None:
829+ reveal_type(x) # N: Revealed type is 'None'
830+ return 1
831+ else: # At this point typechecker knows that x can only have type int
832+ reveal_type(x) # N: Revealed type is 'builtins.int'
833+ return x + 2
834+ [builtins fixtures/primitives.pyi]
835+
836+ [case testEnumReachabilityPEP484Example2]
837+ from typing import Union
838+ from enum import Enum
839+
840+ class Reason(Enum):
841+ timeout = 1
842+ error = 2
843+
844+ def process(response: Union[str, Reason] = '') -> str:
845+ if response is Reason.timeout:
846+ reveal_type(response) # N: Revealed type is 'Literal[__main__.Reason.timeout]'
847+ return 'TIMEOUT'
848+ elif response is Reason.error:
849+ reveal_type(response) # N: Revealed type is 'Literal[__main__.Reason.error]'
850+ return 'ERROR'
851+ else:
852+ # response can be only str, all other possible values exhausted
853+ reveal_type(response) # N: Revealed type is 'builtins.str'
854+ return 'PROCESSED: ' + response
855+
856+ [builtins fixtures/primitives.pyi]
0 commit comments