From 396e8ee6b95f9761cc480cb0e281675e00f6841c Mon Sep 17 00:00:00 2001 From: James Hilton-Balfe Date: Sat, 15 Jul 2023 23:04:06 +0200 Subject: [PATCH 1/8] Allow returning Literals in __new__ Unblocks https://github.com/python/typeshed/pull/10465 --- mypy/checker.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mypy/checker.py b/mypy/checker.py index f2873c7d58e4..06d770665ea9 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -1475,7 +1475,8 @@ def check___new___signature(self, fdef: FuncDef, typ: CallableType) -> None: "but must return a subtype of", ) elif not isinstance( - get_proper_type(bound_type.ret_type), (AnyType, Instance, TupleType, UninhabitedType) + get_proper_type(bound_type.ret_type), + (AnyType, Instance, TupleType, UninhabitedType, LiteralType), ): self.fail( message_registry.NON_INSTANCE_NEW_TYPE.format( From abbad4ede09495b9b4472c32cac11fd702f624e1 Mon Sep 17 00:00:00 2001 From: Gobot1234 Date: Wed, 19 Jul 2023 10:25:51 +0100 Subject: [PATCH 2/8] Add tests --- test-data/unit/check-classes.test | 22 ++++++++++++++++++++ test-data/unit/fixtures/__new__-literals.pyi | 6 ++++++ 2 files changed, 28 insertions(+) create mode 100644 test-data/unit/fixtures/__new__-literals.pyi diff --git a/test-data/unit/check-classes.test b/test-data/unit/check-classes.test index 957eb9214d7c..ed18b859e9df 100644 --- a/test-data/unit/check-classes.test +++ b/test-data/unit/check-classes.test @@ -473,6 +473,28 @@ class B(A): def __new__(cls) -> B: pass +[case testOverride__new__WithLiteralReturnPassing] +[file builtins.py] +from typing import Literal + +class int: pass +class str: pass +class dict: pass +class float: pass + +class bool: + def __new__(cls) -> Literal[True]: pass +[typing fixtures/typing-medium.pyi] + +[case testOverride__new__WithLiteralReturnFailing] +from typing import Literal + +class Foo: + def __new__(cls) -> Literal[1]: pass # E: Incompatible return type for "__new__" (returns "Literal[1]", but must return a subtype of "Foo") + +[builtins fixtures/__new__-literals.pyi] +[typing fixtures/typing-medium.pyi] + [case testOverride__new__AndCallObject] from typing import TypeVar, Generic diff --git a/test-data/unit/fixtures/__new__-literals.pyi b/test-data/unit/fixtures/__new__-literals.pyi new file mode 100644 index 000000000000..414f0d2d858a --- /dev/null +++ b/test-data/unit/fixtures/__new__-literals.pyi @@ -0,0 +1,6 @@ +# builtins stub for testing __new__ literals + +class float: pass +class int: pass +class str: pass +class dict: pass From 90b37d4cf82ec863b6942b96437154e87b688485 Mon Sep 17 00:00:00 2001 From: Gobot1234 Date: Wed, 19 Jul 2023 11:19:16 +0100 Subject: [PATCH 3/8] Update tests --- test-data/unit/check-classes.test | 23 +++++++++++++++----- test-data/unit/fixtures/__new__-literals.pyi | 6 ----- test-data/unit/fixtures/__new__.pyi | 1 + 3 files changed, 19 insertions(+), 11 deletions(-) delete mode 100644 test-data/unit/fixtures/__new__-literals.pyi diff --git a/test-data/unit/check-classes.test b/test-data/unit/check-classes.test index ed18b859e9df..d9f0e4a7967a 100644 --- a/test-data/unit/check-classes.test +++ b/test-data/unit/check-classes.test @@ -475,15 +475,28 @@ class B(A): [case testOverride__new__WithLiteralReturnPassing] [file builtins.py] -from typing import Literal +from typing import Literal, Protocol, overload -class int: pass class str: pass class dict: pass class float: pass -class bool: - def __new__(cls) -> Literal[True]: pass +class _Truthy(Protocol): + def __bool__(self) -> Literal[True]: pass + +class _Falsy(Protocol): + def __bool__(self) -> Literal[False]: pass + +class bool(int): + @overload + def __new__(cls, __o: _Truthy) -> Literal[True]: pass + @overload + def __new__(cls, __o: _Falsy) -> Literal[False]: pass + def __new__(cls, __o: object): + pass + +class int: + def __new__(cls) -> Literal[0]: pass [typing fixtures/typing-medium.pyi] [case testOverride__new__WithLiteralReturnFailing] @@ -492,7 +505,7 @@ from typing import Literal class Foo: def __new__(cls) -> Literal[1]: pass # E: Incompatible return type for "__new__" (returns "Literal[1]", but must return a subtype of "Foo") -[builtins fixtures/__new__-literals.pyi] +[builtins fixtures/__new__.pyi] [typing fixtures/typing-medium.pyi] [case testOverride__new__AndCallObject] diff --git a/test-data/unit/fixtures/__new__-literals.pyi b/test-data/unit/fixtures/__new__-literals.pyi deleted file mode 100644 index 414f0d2d858a..000000000000 --- a/test-data/unit/fixtures/__new__-literals.pyi +++ /dev/null @@ -1,6 +0,0 @@ -# builtins stub for testing __new__ literals - -class float: pass -class int: pass -class str: pass -class dict: pass diff --git a/test-data/unit/fixtures/__new__.pyi b/test-data/unit/fixtures/__new__.pyi index 401de6fb9cd1..57d3624ce92c 100644 --- a/test-data/unit/fixtures/__new__.pyi +++ b/test-data/unit/fixtures/__new__.pyi @@ -12,6 +12,7 @@ class object: class type: def __init__(self, x) -> None: pass +class float: pass class int: pass class bool: pass class str: pass From 79cb39ef6862181f509abbcee4790e2f0fb968a0 Mon Sep 17 00:00:00 2001 From: James Hilton-Balfe Date: Wed, 19 Jul 2023 11:33:01 +0100 Subject: [PATCH 4/8] Temporarily update typeshed --- mypy/typeshed/stdlib/builtins.pyi | 116 +++++++++++++++++++++++++++++- 1 file changed, 114 insertions(+), 2 deletions(-) diff --git a/mypy/typeshed/stdlib/builtins.pyi b/mypy/typeshed/stdlib/builtins.pyi index d6ca39049c77..e11b06025b6d 100644 --- a/mypy/typeshed/stdlib/builtins.pyi +++ b/mypy/typeshed/stdlib/builtins.pyi @@ -56,6 +56,7 @@ from typing import ( # noqa: Y022 from typing_extensions import ( Concatenate, Literal, + LiteralString, ParamSpec, Self, SupportsIndex, @@ -213,6 +214,8 @@ _NegativeInteger: TypeAlias = Literal[-1, -2, -3, -4, -5, -6, -7, -8, -9, -10, - _LiteralInteger = _PositiveInteger | _NegativeInteger | Literal[0] # noqa: Y026 # TODO: Use TypeAlias once mypy bugs are fixed class int: + @overload + def __new__(cls) -> Literal[0]: ... @overload def __new__(cls, __x: str | ReadableBuffer | SupportsInt | SupportsIndex | SupportsTrunc = ...) -> Self: ... @overload @@ -431,12 +434,23 @@ class _TranslateTable(Protocol): def __getitem__(self, __key: int) -> str | int | None: ... class str(Sequence[str]): + @overload + def __new__(cls) -> Literal[""]: ... @overload def __new__(cls, object: object = ...) -> Self: ... @overload def __new__(cls, object: ReadableBuffer, encoding: str = ..., errors: str = ...) -> Self: ... + @overload + def capitalize(self: LiteralString) -> LiteralString: ... + @overload def capitalize(self) -> str: ... # type: ignore[misc] + @overload + def casefold(self: LiteralString) -> LiteralString: ... + @overload def casefold(self) -> str: ... # type: ignore[misc] + @overload + def center(self: LiteralString, __width: SupportsIndex, __fillchar: LiteralString = " ") -> LiteralString: ... + @overload def center(self, __width: SupportsIndex, __fillchar: str = " ") -> str: ... # type: ignore[misc] def count(self, x: str, __start: SupportsIndex | None = ..., __end: SupportsIndex | None = ...) -> int: ... def encode(self, encoding: str = "utf-8", errors: str = "strict") -> bytes: ... @@ -444,11 +458,20 @@ class str(Sequence[str]): self, __suffix: str | tuple[str, ...], __start: SupportsIndex | None = ..., __end: SupportsIndex | None = ... ) -> bool: ... if sys.version_info >= (3, 8): + @overload + def expandtabs(self: LiteralString, tabsize: SupportsIndex = 8) -> LiteralString: ... + @overload def expandtabs(self, tabsize: SupportsIndex = 8) -> str: ... # type: ignore[misc] else: + @overload + def expandtabs(self: LiteralString, tabsize: int = 8) -> LiteralString: ... + @overload def expandtabs(self, tabsize: int = 8) -> str: ... # type: ignore[misc] def find(self, __sub: str, __start: SupportsIndex | None = ..., __end: SupportsIndex | None = ...) -> int: ... + @overload + def format(self: LiteralString, *args: LiteralString, **kwargs: LiteralString) -> LiteralString: ... + @overload def format(self, *args: object, **kwargs: object) -> str: ... def format_map(self, map: _FormatMapMapping) -> str: ... def index(self, __sub: str, __start: SupportsIndex | None = ..., __end: SupportsIndex | None = ...) -> int: ... @@ -464,32 +487,91 @@ class str(Sequence[str]): def isspace(self) -> bool: ... def istitle(self) -> bool: ... def isupper(self) -> bool: ... + @overload + def join(self: LiteralString, __iterable: Iterable[LiteralString]) -> LiteralString: ... + @overload def join(self, __iterable: Iterable[str]) -> str: ... # type: ignore[misc] + @overload + def ljust(self: LiteralString, __width: SupportsIndex, __fillchar: LiteralString = " ") -> LiteralString: ... + @overload def ljust(self, __width: SupportsIndex, __fillchar: str = " ") -> str: ... # type: ignore[misc] + @overload + def lower(self: LiteralString) -> LiteralString: ... + @overload def lower(self) -> str: ... # type: ignore[misc] + @overload + def lstrip(self: LiteralString, __chars: LiteralString | None = None) -> LiteralString: ... + @overload def lstrip(self, __chars: str | None = None) -> str: ... # type: ignore[misc] + @overload + def partition(self: LiteralString, __sep: LiteralString) -> tuple[LiteralString, LiteralString, LiteralString]: ... + @overload def partition(self, __sep: str) -> tuple[str, str, str]: ... # type: ignore[misc] + @overload + def replace( + self: LiteralString, __old: LiteralString, __new: LiteralString, __count: SupportsIndex = -1 + ) -> LiteralString: ... + @overload def replace(self, __old: str, __new: str, __count: SupportsIndex = -1) -> str: ... # type: ignore[misc] if sys.version_info >= (3, 9): + @overload + def removeprefix(self: LiteralString, __prefix: LiteralString) -> LiteralString: ... + @overload def removeprefix(self, __prefix: str) -> str: ... # type: ignore[misc] + @overload + def removesuffix(self: LiteralString, __suffix: LiteralString) -> LiteralString: ... + @overload def removesuffix(self, __suffix: str) -> str: ... # type: ignore[misc] def rfind(self, __sub: str, __start: SupportsIndex | None = ..., __end: SupportsIndex | None = ...) -> int: ... def rindex(self, __sub: str, __start: SupportsIndex | None = ..., __end: SupportsIndex | None = ...) -> int: ... + @overload + def rjust(self: LiteralString, __width: SupportsIndex, __fillchar: LiteralString = " ") -> LiteralString: ... + @overload def rjust(self, __width: SupportsIndex, __fillchar: str = " ") -> str: ... # type: ignore[misc] + @overload + def rpartition(self: LiteralString, __sep: LiteralString) -> tuple[LiteralString, LiteralString, LiteralString]: ... + @overload def rpartition(self, __sep: str) -> tuple[str, str, str]: ... # type: ignore[misc] + @overload + def rsplit(self: LiteralString, sep: LiteralString | None = None, maxsplit: SupportsIndex = -1) -> list[LiteralString]: ... + @overload def rsplit(self, sep: str | None = None, maxsplit: SupportsIndex = -1) -> list[str]: ... # type: ignore[misc] + @overload + def rstrip(self: LiteralString, __chars: LiteralString | None = None) -> LiteralString: ... + @overload def rstrip(self, __chars: str | None = None) -> str: ... # type: ignore[misc] + @overload + def split(self: LiteralString, sep: LiteralString | None = None, maxsplit: SupportsIndex = -1) -> list[LiteralString]: ... + @overload def split(self, sep: str | None = None, maxsplit: SupportsIndex = -1) -> list[str]: ... # type: ignore[misc] + @overload + def splitlines(self: LiteralString, keepends: bool = False) -> list[LiteralString]: ... + @overload def splitlines(self, keepends: bool = False) -> list[str]: ... # type: ignore[misc] def startswith( self, __prefix: str | tuple[str, ...], __start: SupportsIndex | None = ..., __end: SupportsIndex | None = ... ) -> bool: ... + @overload + def strip(self: LiteralString, __chars: LiteralString | None = None) -> LiteralString: ... + @overload def strip(self, __chars: str | None = None) -> str: ... # type: ignore[misc] + @overload + def swapcase(self: LiteralString) -> LiteralString: ... + @overload def swapcase(self) -> str: ... # type: ignore[misc] + @overload + def title(self: LiteralString) -> LiteralString: ... + @overload def title(self) -> str: ... # type: ignore[misc] def translate(self, __table: _TranslateTable) -> str: ... + @overload + def upper(self: LiteralString) -> LiteralString: ... + @overload def upper(self) -> str: ... # type: ignore[misc] + @overload + def zfill(self: LiteralString, __width: SupportsIndex) -> LiteralString: ... + @overload def zfill(self, __width: SupportsIndex) -> str: ... # type: ignore[misc] @staticmethod @overload @@ -500,6 +582,9 @@ class str(Sequence[str]): @staticmethod @overload def maketrans(__x: str, __y: str, __z: str) -> dict[int, int | None]: ... + @overload + def __add__(self: LiteralString, __value: LiteralString) -> LiteralString: ... + @overload def __add__(self, __value: str) -> str: ... # type: ignore[misc] # Incompatible with Sequence.__contains__ def __contains__(self, __key: str) -> bool: ... # type: ignore[override] @@ -508,17 +593,31 @@ class str(Sequence[str]): def __getitem__(self, __key: SupportsIndex | slice) -> str: ... def __gt__(self, __value: str) -> bool: ... def __hash__(self) -> int: ... + @overload + def __iter__(self: LiteralString) -> Iterator[LiteralString]: ... + @overload def __iter__(self) -> Iterator[str]: ... # type: ignore[misc] def __le__(self, __value: str) -> bool: ... def __len__(self) -> int: ... def __lt__(self, __value: str) -> bool: ... + @overload + def __mod__(self: LiteralString, __value: LiteralString | tuple[LiteralString, ...]) -> LiteralString: ... + @overload def __mod__(self, __value: Any) -> str: ... + @overload + def __mul__(self: LiteralString, __value: SupportsIndex) -> LiteralString: ... + @overload def __mul__(self, __value: SupportsIndex) -> str: ... # type: ignore[misc] def __ne__(self, __value: object) -> bool: ... + @overload + def __rmul__(self: LiteralString, __value: SupportsIndex) -> LiteralString: ... + @overload def __rmul__(self, __value: SupportsIndex) -> str: ... # type: ignore[misc] def __getnewargs__(self) -> tuple[str]: ... class bytes(Sequence[int]): + @overload + def __new__(cls) -> Literal[b""]: ... @overload def __new__(cls, __o: Iterable[SupportsIndex] | SupportsIndex | SupportsBytes | ReadableBuffer) -> Self: ... @overload @@ -805,8 +904,21 @@ class memoryview(Sequence[int]): def __buffer__(self, __flags: int) -> memoryview: ... def __release_buffer__(self, __buffer: memoryview) -> None: ... +class _Truthy(Protocol): + def __bool__(self) -> Literal[True]: ... + +class _Falsy(Protocol): + def __bool__(self) -> Literal[False]: ... + @final class bool(int): + @overload + def __new__(cls) -> Literal[False]: ... + @overload + def __new__(cls, __o: _Truthy) -> Literal[True]: ... + @overload + def __new__(cls, __o: _Falsy) -> Literal[False]: ... + @overload def __new__(cls, __o: object = ...) -> Self: ... # The following overloads could be represented more elegantly with a TypeVar("_B", bool, int), # however mypy has a bug regarding TypeVar constraints (https://github.com/python/mypy/issues/11880). @@ -1676,11 +1788,11 @@ _SupportsSumNoDefaultT = TypeVar("_SupportsSumNoDefaultT", bound=_SupportsSumWit # Instead, we special-case the most common examples of this: bool and literal integers. if sys.version_info >= (3, 8): @overload - def sum(__iterable: Iterable[bool], start: int = 0) -> int: ... # type: ignore[misc] + def sum(__iterable: Iterable[bool | _LiteralInteger], start: int = 0) -> int: ... # type: ignore[misc] else: @overload - def sum(__iterable: Iterable[bool], __start: int = 0) -> int: ... # type: ignore[misc] + def sum(__iterable: Iterable[bool | _LiteralInteger], __start: int = 0) -> int: ... # type: ignore[misc] @overload def sum(__iterable: Iterable[_SupportsSumNoDefaultT]) -> _SupportsSumNoDefaultT | Literal[0]: ... From 427987cb9be198033c2582485cecfa8584bb7cca Mon Sep 17 00:00:00 2001 From: Gobot1234 Date: Wed, 19 Jul 2023 11:54:14 +0100 Subject: [PATCH 5/8] Revert "Temporarily update typeshed" This reverts commit 79cb39ef6862181f509abbcee4790e2f0fb968a0. --- mypy/typeshed/stdlib/builtins.pyi | 116 +----------------------------- 1 file changed, 2 insertions(+), 114 deletions(-) diff --git a/mypy/typeshed/stdlib/builtins.pyi b/mypy/typeshed/stdlib/builtins.pyi index e11b06025b6d..d6ca39049c77 100644 --- a/mypy/typeshed/stdlib/builtins.pyi +++ b/mypy/typeshed/stdlib/builtins.pyi @@ -56,7 +56,6 @@ from typing import ( # noqa: Y022 from typing_extensions import ( Concatenate, Literal, - LiteralString, ParamSpec, Self, SupportsIndex, @@ -214,8 +213,6 @@ _NegativeInteger: TypeAlias = Literal[-1, -2, -3, -4, -5, -6, -7, -8, -9, -10, - _LiteralInteger = _PositiveInteger | _NegativeInteger | Literal[0] # noqa: Y026 # TODO: Use TypeAlias once mypy bugs are fixed class int: - @overload - def __new__(cls) -> Literal[0]: ... @overload def __new__(cls, __x: str | ReadableBuffer | SupportsInt | SupportsIndex | SupportsTrunc = ...) -> Self: ... @overload @@ -434,23 +431,12 @@ class _TranslateTable(Protocol): def __getitem__(self, __key: int) -> str | int | None: ... class str(Sequence[str]): - @overload - def __new__(cls) -> Literal[""]: ... @overload def __new__(cls, object: object = ...) -> Self: ... @overload def __new__(cls, object: ReadableBuffer, encoding: str = ..., errors: str = ...) -> Self: ... - @overload - def capitalize(self: LiteralString) -> LiteralString: ... - @overload def capitalize(self) -> str: ... # type: ignore[misc] - @overload - def casefold(self: LiteralString) -> LiteralString: ... - @overload def casefold(self) -> str: ... # type: ignore[misc] - @overload - def center(self: LiteralString, __width: SupportsIndex, __fillchar: LiteralString = " ") -> LiteralString: ... - @overload def center(self, __width: SupportsIndex, __fillchar: str = " ") -> str: ... # type: ignore[misc] def count(self, x: str, __start: SupportsIndex | None = ..., __end: SupportsIndex | None = ...) -> int: ... def encode(self, encoding: str = "utf-8", errors: str = "strict") -> bytes: ... @@ -458,20 +444,11 @@ class str(Sequence[str]): self, __suffix: str | tuple[str, ...], __start: SupportsIndex | None = ..., __end: SupportsIndex | None = ... ) -> bool: ... if sys.version_info >= (3, 8): - @overload - def expandtabs(self: LiteralString, tabsize: SupportsIndex = 8) -> LiteralString: ... - @overload def expandtabs(self, tabsize: SupportsIndex = 8) -> str: ... # type: ignore[misc] else: - @overload - def expandtabs(self: LiteralString, tabsize: int = 8) -> LiteralString: ... - @overload def expandtabs(self, tabsize: int = 8) -> str: ... # type: ignore[misc] def find(self, __sub: str, __start: SupportsIndex | None = ..., __end: SupportsIndex | None = ...) -> int: ... - @overload - def format(self: LiteralString, *args: LiteralString, **kwargs: LiteralString) -> LiteralString: ... - @overload def format(self, *args: object, **kwargs: object) -> str: ... def format_map(self, map: _FormatMapMapping) -> str: ... def index(self, __sub: str, __start: SupportsIndex | None = ..., __end: SupportsIndex | None = ...) -> int: ... @@ -487,91 +464,32 @@ class str(Sequence[str]): def isspace(self) -> bool: ... def istitle(self) -> bool: ... def isupper(self) -> bool: ... - @overload - def join(self: LiteralString, __iterable: Iterable[LiteralString]) -> LiteralString: ... - @overload def join(self, __iterable: Iterable[str]) -> str: ... # type: ignore[misc] - @overload - def ljust(self: LiteralString, __width: SupportsIndex, __fillchar: LiteralString = " ") -> LiteralString: ... - @overload def ljust(self, __width: SupportsIndex, __fillchar: str = " ") -> str: ... # type: ignore[misc] - @overload - def lower(self: LiteralString) -> LiteralString: ... - @overload def lower(self) -> str: ... # type: ignore[misc] - @overload - def lstrip(self: LiteralString, __chars: LiteralString | None = None) -> LiteralString: ... - @overload def lstrip(self, __chars: str | None = None) -> str: ... # type: ignore[misc] - @overload - def partition(self: LiteralString, __sep: LiteralString) -> tuple[LiteralString, LiteralString, LiteralString]: ... - @overload def partition(self, __sep: str) -> tuple[str, str, str]: ... # type: ignore[misc] - @overload - def replace( - self: LiteralString, __old: LiteralString, __new: LiteralString, __count: SupportsIndex = -1 - ) -> LiteralString: ... - @overload def replace(self, __old: str, __new: str, __count: SupportsIndex = -1) -> str: ... # type: ignore[misc] if sys.version_info >= (3, 9): - @overload - def removeprefix(self: LiteralString, __prefix: LiteralString) -> LiteralString: ... - @overload def removeprefix(self, __prefix: str) -> str: ... # type: ignore[misc] - @overload - def removesuffix(self: LiteralString, __suffix: LiteralString) -> LiteralString: ... - @overload def removesuffix(self, __suffix: str) -> str: ... # type: ignore[misc] def rfind(self, __sub: str, __start: SupportsIndex | None = ..., __end: SupportsIndex | None = ...) -> int: ... def rindex(self, __sub: str, __start: SupportsIndex | None = ..., __end: SupportsIndex | None = ...) -> int: ... - @overload - def rjust(self: LiteralString, __width: SupportsIndex, __fillchar: LiteralString = " ") -> LiteralString: ... - @overload def rjust(self, __width: SupportsIndex, __fillchar: str = " ") -> str: ... # type: ignore[misc] - @overload - def rpartition(self: LiteralString, __sep: LiteralString) -> tuple[LiteralString, LiteralString, LiteralString]: ... - @overload def rpartition(self, __sep: str) -> tuple[str, str, str]: ... # type: ignore[misc] - @overload - def rsplit(self: LiteralString, sep: LiteralString | None = None, maxsplit: SupportsIndex = -1) -> list[LiteralString]: ... - @overload def rsplit(self, sep: str | None = None, maxsplit: SupportsIndex = -1) -> list[str]: ... # type: ignore[misc] - @overload - def rstrip(self: LiteralString, __chars: LiteralString | None = None) -> LiteralString: ... - @overload def rstrip(self, __chars: str | None = None) -> str: ... # type: ignore[misc] - @overload - def split(self: LiteralString, sep: LiteralString | None = None, maxsplit: SupportsIndex = -1) -> list[LiteralString]: ... - @overload def split(self, sep: str | None = None, maxsplit: SupportsIndex = -1) -> list[str]: ... # type: ignore[misc] - @overload - def splitlines(self: LiteralString, keepends: bool = False) -> list[LiteralString]: ... - @overload def splitlines(self, keepends: bool = False) -> list[str]: ... # type: ignore[misc] def startswith( self, __prefix: str | tuple[str, ...], __start: SupportsIndex | None = ..., __end: SupportsIndex | None = ... ) -> bool: ... - @overload - def strip(self: LiteralString, __chars: LiteralString | None = None) -> LiteralString: ... - @overload def strip(self, __chars: str | None = None) -> str: ... # type: ignore[misc] - @overload - def swapcase(self: LiteralString) -> LiteralString: ... - @overload def swapcase(self) -> str: ... # type: ignore[misc] - @overload - def title(self: LiteralString) -> LiteralString: ... - @overload def title(self) -> str: ... # type: ignore[misc] def translate(self, __table: _TranslateTable) -> str: ... - @overload - def upper(self: LiteralString) -> LiteralString: ... - @overload def upper(self) -> str: ... # type: ignore[misc] - @overload - def zfill(self: LiteralString, __width: SupportsIndex) -> LiteralString: ... - @overload def zfill(self, __width: SupportsIndex) -> str: ... # type: ignore[misc] @staticmethod @overload @@ -582,9 +500,6 @@ class str(Sequence[str]): @staticmethod @overload def maketrans(__x: str, __y: str, __z: str) -> dict[int, int | None]: ... - @overload - def __add__(self: LiteralString, __value: LiteralString) -> LiteralString: ... - @overload def __add__(self, __value: str) -> str: ... # type: ignore[misc] # Incompatible with Sequence.__contains__ def __contains__(self, __key: str) -> bool: ... # type: ignore[override] @@ -593,31 +508,17 @@ class str(Sequence[str]): def __getitem__(self, __key: SupportsIndex | slice) -> str: ... def __gt__(self, __value: str) -> bool: ... def __hash__(self) -> int: ... - @overload - def __iter__(self: LiteralString) -> Iterator[LiteralString]: ... - @overload def __iter__(self) -> Iterator[str]: ... # type: ignore[misc] def __le__(self, __value: str) -> bool: ... def __len__(self) -> int: ... def __lt__(self, __value: str) -> bool: ... - @overload - def __mod__(self: LiteralString, __value: LiteralString | tuple[LiteralString, ...]) -> LiteralString: ... - @overload def __mod__(self, __value: Any) -> str: ... - @overload - def __mul__(self: LiteralString, __value: SupportsIndex) -> LiteralString: ... - @overload def __mul__(self, __value: SupportsIndex) -> str: ... # type: ignore[misc] def __ne__(self, __value: object) -> bool: ... - @overload - def __rmul__(self: LiteralString, __value: SupportsIndex) -> LiteralString: ... - @overload def __rmul__(self, __value: SupportsIndex) -> str: ... # type: ignore[misc] def __getnewargs__(self) -> tuple[str]: ... class bytes(Sequence[int]): - @overload - def __new__(cls) -> Literal[b""]: ... @overload def __new__(cls, __o: Iterable[SupportsIndex] | SupportsIndex | SupportsBytes | ReadableBuffer) -> Self: ... @overload @@ -904,21 +805,8 @@ class memoryview(Sequence[int]): def __buffer__(self, __flags: int) -> memoryview: ... def __release_buffer__(self, __buffer: memoryview) -> None: ... -class _Truthy(Protocol): - def __bool__(self) -> Literal[True]: ... - -class _Falsy(Protocol): - def __bool__(self) -> Literal[False]: ... - @final class bool(int): - @overload - def __new__(cls) -> Literal[False]: ... - @overload - def __new__(cls, __o: _Truthy) -> Literal[True]: ... - @overload - def __new__(cls, __o: _Falsy) -> Literal[False]: ... - @overload def __new__(cls, __o: object = ...) -> Self: ... # The following overloads could be represented more elegantly with a TypeVar("_B", bool, int), # however mypy has a bug regarding TypeVar constraints (https://github.com/python/mypy/issues/11880). @@ -1788,11 +1676,11 @@ _SupportsSumNoDefaultT = TypeVar("_SupportsSumNoDefaultT", bound=_SupportsSumWit # Instead, we special-case the most common examples of this: bool and literal integers. if sys.version_info >= (3, 8): @overload - def sum(__iterable: Iterable[bool | _LiteralInteger], start: int = 0) -> int: ... # type: ignore[misc] + def sum(__iterable: Iterable[bool], start: int = 0) -> int: ... # type: ignore[misc] else: @overload - def sum(__iterable: Iterable[bool | _LiteralInteger], __start: int = 0) -> int: ... # type: ignore[misc] + def sum(__iterable: Iterable[bool], __start: int = 0) -> int: ... # type: ignore[misc] @overload def sum(__iterable: Iterable[_SupportsSumNoDefaultT]) -> _SupportsSumNoDefaultT | Literal[0]: ... From bcaa4c92db14c506df73848a2cabfd7c367c199c Mon Sep 17 00:00:00 2001 From: Gobot1234 Date: Fri, 4 Aug 2023 11:01:54 +0100 Subject: [PATCH 6/8] Add reveal_types --- test-data/unit/check-classes.test | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/test-data/unit/check-classes.test b/test-data/unit/check-classes.test index d9f0e4a7967a..8a5369d9bf34 100644 --- a/test-data/unit/check-classes.test +++ b/test-data/unit/check-classes.test @@ -495,8 +495,15 @@ class bool(int): def __new__(cls, __o: object): pass +class Falsey: + def __bool__(self) -> Literal[False]: pass + +reveal_type(bool(Falsey())) # N: Revealed type is "Literal[False]" + class int: def __new__(cls) -> Literal[0]: pass + +reveal_type(int()) # N: Revealed type is "Literal[0]" [typing fixtures/typing-medium.pyi] [case testOverride__new__WithLiteralReturnFailing] From ec940176e574ebe3e21b9f6f61b7df324a70694b Mon Sep 17 00:00:00 2001 From: Gobot1234 Date: Thu, 10 Aug 2023 22:20:30 +0100 Subject: [PATCH 7/8] Make things actually work --- mypy/checkmember.py | 2 ++ mypy/typeops.py | 2 +- mypy/types.py | 2 ++ 3 files changed, 5 insertions(+), 1 deletion(-) diff --git a/mypy/checkmember.py b/mypy/checkmember.py index 343dfe3de243..b9fc60f6dbfd 100644 --- a/mypy/checkmember.py +++ b/mypy/checkmember.py @@ -367,6 +367,8 @@ def analyze_type_callable_member_access(name: str, typ: FunctionLike, mx: Member ret_type = tuple_fallback(ret_type) if isinstance(ret_type, TypedDictType): ret_type = ret_type.fallback + if isinstance(ret_type, LiteralType): + ret_type = ret_type.fallback if isinstance(ret_type, Instance): if not mx.is_operator: # When Python sees an operator (eg `3 == 4`), it automatically translates that diff --git a/mypy/typeops.py b/mypy/typeops.py index 519d3de995f5..4e0b3665c2c0 100644 --- a/mypy/typeops.py +++ b/mypy/typeops.py @@ -193,7 +193,7 @@ def class_callable( default_ret_type = fill_typevars(info) explicit_type = init_ret_type if is_new else orig_self_type if ( - isinstance(explicit_type, (Instance, TupleType, UninhabitedType)) + isinstance(explicit_type, (Instance, TupleType, UninhabitedType, LiteralType)) # We have to skip protocols, because it can be a subtype of a return type # by accident. Like `Hashable` is a subtype of `object`. See #11799 and isinstance(default_ret_type, Instance) diff --git a/mypy/types.py b/mypy/types.py index ba629a3553cf..c3a3d22a496c 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -1916,6 +1916,8 @@ def type_object(self) -> mypy.nodes.TypeInfo: ret = ret.partial_fallback if isinstance(ret, TypedDictType): ret = ret.fallback + if isinstance(ret, LiteralType): + ret = ret.fallback assert isinstance(ret, Instance) return ret.type From dad9b05dabd6ab233a6c6c0624a9943503b2d4f0 Mon Sep 17 00:00:00 2001 From: Gobot1234 Date: Fri, 11 Aug 2023 00:25:19 +0100 Subject: [PATCH 8/8] Push what I've got --- test-data/unit/check-classes.test | 31 ++++------------------ test-data/unit/fixtures/literal__new__.pyi | 19 +++++++++++++ 2 files changed, 24 insertions(+), 26 deletions(-) create mode 100644 test-data/unit/fixtures/literal__new__.pyi diff --git a/test-data/unit/check-classes.test b/test-data/unit/check-classes.test index 8a5369d9bf34..62e79e76c95f 100644 --- a/test-data/unit/check-classes.test +++ b/test-data/unit/check-classes.test @@ -474,36 +474,15 @@ class B(A): pass [case testOverride__new__WithLiteralReturnPassing] -[file builtins.py] -from typing import Literal, Protocol, overload - -class str: pass -class dict: pass -class float: pass - -class _Truthy(Protocol): - def __bool__(self) -> Literal[True]: pass - -class _Falsy(Protocol): - def __bool__(self) -> Literal[False]: pass - -class bool(int): - @overload - def __new__(cls, __o: _Truthy) -> Literal[True]: pass - @overload - def __new__(cls, __o: _Falsy) -> Literal[False]: pass - def __new__(cls, __o: object): - pass +from typing import Literal -class Falsey: +class Falsy: def __bool__(self) -> Literal[False]: pass -reveal_type(bool(Falsey())) # N: Revealed type is "Literal[False]" - -class int: - def __new__(cls) -> Literal[0]: pass - +reveal_type(bool(Falsy())) # N: Revealed type is "Literal[False]" reveal_type(int()) # N: Revealed type is "Literal[0]" + +[builtins fixtures/literal__new__.pyi] [typing fixtures/typing-medium.pyi] [case testOverride__new__WithLiteralReturnFailing] diff --git a/test-data/unit/fixtures/literal__new__.pyi b/test-data/unit/fixtures/literal__new__.pyi new file mode 100644 index 000000000000..df2fbdd628d7 --- /dev/null +++ b/test-data/unit/fixtures/literal__new__.pyi @@ -0,0 +1,19 @@ +from typing import Literal, Protocol, overload + +class str: pass +class dict: pass +class float: pass +class int: + def __new__(cls) -> Literal[0]: pass + +class _Truthy(Protocol): + def __bool__(self) -> Literal[True]: pass + +class _Falsy(Protocol): + def __bool__(self) -> Literal[False]: pass + +class bool(int): + @overload + def __new__(cls, __o: _Truthy) -> Literal[True]: pass + @overload + def __new__(cls, __o: _Falsy) -> Literal[False]: pass