diff --git a/CHANGELOG.md b/CHANGELOG.md index a95f31ed..b8d8b628 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -52,6 +52,16 @@ - Add `__orig_bases__` to non-generic TypedDicts, call-based TypedDicts, and call-based NamedTuples. Other TypedDicts and NamedTuples already had the attribute. Patch by Adrian Garcia Badaracco. +- Add `typing_extensions.get_original_bases`, a backport of + [`types.get_original_bases`](https://docs.python.org/3.12/library/types.html#types.get_original_bases), + introduced in Python 3.12 (CPython PR + https://github.com/python/cpython/pull/101827, originally by James + Hilton-Balfe). Patch by Alex Waygood. + + This function should always produce correct results when called on classes + constructed using features from `typing_extensions`. However, it may + produce incorrect results when called on some `NamedTuple` or `TypedDict` + classes that use `typing.{NamedTuple,TypedDict}` on Python <=3.11. - Constructing a call-based `TypedDict` using keyword arguments for the fields now causes a `DeprecationWarning` to be emitted. This matches the behaviour of `typing.TypedDict` on 3.11 and 3.12. diff --git a/README.md b/README.md index 59d5d3d0..94d157f4 100644 --- a/README.md +++ b/README.md @@ -43,6 +43,14 @@ This module currently contains the following: - `override` (equivalent to `typing.override`; see [PEP 698](https://peps.python.org/pep-0698/)) - `Buffer` (equivalent to `collections.abc.Buffer`; see [PEP 688](https://peps.python.org/pep-0688/)) + - `get_original_bases` (equivalent to + [`types.get_original_bases`](https://docs.python.org/3.12/library/types.html#types.get_original_bases) + on 3.12+). + + This function should always produce correct results when called on classes + constructed using features from `typing_extensions`. However, it may + produce incorrect results when called on some `NamedTuple` or `TypedDict` + classes that use `typing.{NamedTuple,TypedDict}` on Python <=3.11. - In `typing` since Python 3.11 diff --git a/src/test_typing_extensions.py b/src/test_typing_extensions.py index 77171100..7f3c0ef2 100644 --- a/src/test_typing_extensions.py +++ b/src/test_typing_extensions.py @@ -31,7 +31,7 @@ from typing_extensions import Awaitable, AsyncIterator, AsyncContextManager, Required, NotRequired from typing_extensions import Protocol, runtime, runtime_checkable, Annotated, final, is_typeddict from typing_extensions import TypeVarTuple, Unpack, dataclass_transform, reveal_type, Never, assert_never, LiteralString -from typing_extensions import assert_type, get_type_hints, get_origin, get_args +from typing_extensions import assert_type, get_type_hints, get_origin, get_args, get_original_bases from typing_extensions import clear_overloads, get_overloads, overload from typing_extensions import NamedTuple from typing_extensions import override, deprecated, Buffer @@ -4286,5 +4286,90 @@ def __buffer__(self, flags: int) -> memoryview: self.assertIsSubclass(MySubclassedBuffer, Buffer) +class GetOriginalBasesTests(BaseTestCase): + def test_basics(self): + T = TypeVar('T') + class A: pass + class B(Generic[T]): pass + class C(B[int]): pass + class D(B[str], float): pass + self.assertEqual(get_original_bases(A), (object,)) + self.assertEqual(get_original_bases(B), (Generic[T],)) + self.assertEqual(get_original_bases(C), (B[int],)) + self.assertEqual(get_original_bases(int), (object,)) + self.assertEqual(get_original_bases(D), (B[str], float)) + + with self.assertRaisesRegex(TypeError, "Expected an instance of type"): + get_original_bases(object()) + + @skipUnless(TYPING_3_9_0, "PEP 585 is yet to be") + def test_builtin_generics(self): + class E(list[T]): pass + class F(list[int]): pass + + self.assertEqual(get_original_bases(E), (list[T],)) + self.assertEqual(get_original_bases(F), (list[int],)) + + def test_namedtuples(self): + # On 3.12, this should work well with typing.NamedTuple and typing_extensions.NamedTuple + # On lower versions, it will only work fully with typing_extensions.NamedTuple + if sys.version_info >= (3, 12): + namedtuple_classes = (typing.NamedTuple, typing_extensions.NamedTuple) + else: + namedtuple_classes = (typing_extensions.NamedTuple,) + + for NamedTuple in namedtuple_classes: # noqa: F402 + with self.subTest(cls=NamedTuple): + class ClassBasedNamedTuple(NamedTuple): + x: int + + class GenericNamedTuple(NamedTuple, Generic[T]): + x: T + + CallBasedNamedTuple = NamedTuple("CallBasedNamedTuple", [("x", int)]) + + self.assertIs( + get_original_bases(ClassBasedNamedTuple)[0], NamedTuple + ) + self.assertEqual( + get_original_bases(GenericNamedTuple), + (NamedTuple, Generic[T]) + ) + self.assertIs( + get_original_bases(CallBasedNamedTuple)[0], NamedTuple + ) + + def test_typeddicts(self): + # On 3.12, this should work well with typing.TypedDict and typing_extensions.TypedDict + # On lower versions, it will only work fully with typing_extensions.TypedDict + if sys.version_info >= (3, 12): + typeddict_classes = (typing.TypedDict, typing_extensions.TypedDict) + else: + typeddict_classes = (typing_extensions.TypedDict,) + + for TypedDict in typeddict_classes: # noqa: F402 + with self.subTest(cls=TypedDict): + class ClassBasedTypedDict(TypedDict): + x: int + + class GenericTypedDict(TypedDict, Generic[T]): + x: T + + CallBasedTypedDict = TypedDict("CallBasedTypedDict", {"x": int}) + + self.assertIs( + get_original_bases(ClassBasedTypedDict)[0], + TypedDict + ) + self.assertEqual( + get_original_bases(GenericTypedDict), + (TypedDict, Generic[T]) + ) + self.assertIs( + get_original_bases(CallBasedTypedDict)[0], + TypedDict + ) + + if __name__ == '__main__': main() diff --git a/src/typing_extensions.py b/src/typing_extensions.py index 411ccd42..2ee36eb3 100644 --- a/src/typing_extensions.py +++ b/src/typing_extensions.py @@ -59,6 +59,7 @@ 'final', 'get_args', 'get_origin', + 'get_original_bases', 'get_type_hints', 'IntVar', 'is_typeddict', @@ -2440,3 +2441,39 @@ class Buffer(abc.ABC): Buffer.register(memoryview) Buffer.register(bytearray) Buffer.register(bytes) + + +# Backport of types.get_original_bases, available on 3.12+ in CPython +if hasattr(_types, "get_original_bases"): + get_original_bases = _types.get_original_bases +else: + def get_original_bases(__cls): + """Return the class's "original" bases prior to modification by `__mro_entries__`. + + Examples:: + + from typing import TypeVar, Generic + from typing_extensions import NamedTuple, TypedDict + + T = TypeVar("T") + class Foo(Generic[T]): ... + class Bar(Foo[int], float): ... + class Baz(list[str]): ... + Eggs = NamedTuple("Eggs", [("a", int), ("b", str)]) + Spam = TypedDict("Spam", {"a": int, "b": str}) + + assert get_original_bases(Bar) == (Foo[int], float) + assert get_original_bases(Baz) == (list[str],) + assert get_original_bases(Eggs) == (NamedTuple,) + assert get_original_bases(Spam) == (TypedDict,) + assert get_original_bases(int) == (object,) + """ + try: + return __cls.__orig_bases__ + except AttributeError: + try: + return __cls.__bases__ + except AttributeError: + raise TypeError( + f'Expected an instance of type, not {type(__cls).__name__!r}' + ) from None