diff --git a/Lib/test/test_typing.py b/Lib/test/test_typing.py index 6ef633e4545aef..d763f09f060655 100644 --- a/Lib/test/test_typing.py +++ b/Lib/test/test_typing.py @@ -3,6 +3,7 @@ import collections import collections.abc from collections import defaultdict +from collections.abc import Callable as ABCallable from functools import lru_cache, wraps, reduce import gc import inspect @@ -42,6 +43,7 @@ from typing import TypeAlias from typing import ParamSpec, Concatenate, ParamSpecArgs, ParamSpecKwargs from typing import TypeGuard, TypeIs, NoDefault +from typing import _eval_type import abc import textwrap import typing @@ -10668,9 +10670,126 @@ def test_eq(self): with self.assertWarns(DeprecationWarning): self.assertNotEqual(int, typing._UnionGenericAlias) +class MyType: + pass + +class TestGenericAliasHandling(BaseTestCase): + + def test_forward_ref(self): + fwd_ref = ForwardRef('MyType') + + def func(arg: fwd_ref): + pass + + result = get_type_hints(func) + self.assertEqual(result['arg'], MyType, f"Expected MyType, got {result['arg']}") + + def test_generic_alias(self): + fwd_ref = ForwardRef('MyType') + generic_list = List[fwd_ref] + + def func(arg: generic_list): + pass + + result = get_type_hints(func) + self.assertEqual(result['arg'], List[MyType], f"Expected List[MyType], got {result['arg']}") + + def test_union(self): + fwd_ref_1 = ForwardRef('MyType') + fwd_ref_2 = ForwardRef('int') + union_type = Union[fwd_ref_1, fwd_ref_2] + + def func(arg: union_type): + pass + + result = get_type_hints(func) + self.assertEqual(result['arg'], Union[MyType, int], f"Expected Union[MyType, int], got {result['arg']}") + + def test_recursive_forward_ref(self): + recursive_ref = ForwardRef('RecursiveType') + globals()['RecursiveType'] = recursive_ref + recursive_type = Dict[str, List[recursive_ref]] + + def func(arg: recursive_type): + pass + + result = get_type_hints(func) + self.assertEqual(result['arg'], Dict[str, List[recursive_ref]], f"Expected Dict[str, List[RecursiveType]], got {result['arg']}") + + def test_callable_unpacking(self): + fwd_ref = ForwardRef('MyType') + callable_type = Callable[[fwd_ref, int], str] + + def func(arg1: fwd_ref, arg2: int) -> str: + return "test" + + result = get_type_hints(func) + self.assertEqual(result['arg1'], MyType, f"Expected MyType for arg1, got {result['arg1']}") + self.assertEqual(result['arg2'], int, f"Expected int for arg2, got {result['arg2']}") + self.assertEqual(result['return'], str, f"Expected str for return, got {result['return']}") + + def test_unpacked_generic(self): + fwd_ref = ForwardRef('MyType') + generic_type = Tuple[fwd_ref, int] + + def func(arg: generic_type): + pass + + result = get_type_hints(func) + self.assertEqual(result['arg'], Tuple[MyType, int], f"Expected Tuple[MyType, int], got {result['arg']}") + + def test_preservation_of_type(self): + fwd_ref_1 = ForwardRef('MyType') + fwd_ref_2 = ForwardRef('int') + complex_type = Dict[str, Union[fwd_ref_1, fwd_ref_2]] + + def func(arg: complex_type): + pass + + result = get_type_hints(func) + self.assertEqual(result['arg'], Dict[str, Union[MyType, int]], f"Expected Dict[str, Union[MyType, int]], got {result['arg']}") + + def test_callable_unflattening(self): + callable_type = Callable[[int, str], bool] + + def func(arg1: int, arg2: str) -> bool: + return True + + result = get_type_hints(func) + self.assertEqual(result['arg1'], int, f"Expected int for arg1, got {result['arg1']}") + self.assertEqual(result['arg2'], str, f"Expected str for arg2, got {result['arg2']}") + self.assertEqual(result['return'], bool, f"Expected bool for return, got {result['return']}") + + callable_type_packed = Callable[[int, str], bool] + + def func_packed(arg1: int, arg2: str) -> bool: + return True + + result = get_type_hints(func_packed) + self.assertEqual(result['arg1'], int, f"Expected int for arg1, got {result['arg1']}") + self.assertEqual(result['arg2'], str, f"Expected str for arg2, got {result['arg2']}") + self.assertEqual(result['return'], bool, f"Expected bool for return, got {result['return']}") + def test_hashable(self): self.assertEqual(hash(typing._UnionGenericAlias), hash(Union)) +class TestCallableAlias(BaseTestCase): + def test_callable_alias_preserves_subclass(self): + C = ABCallable[[str, ForwardRef('int')], int] + class A: + c: C + # Explicitly pass global namespace to ensure correct resolution + hints = get_type_hints(A, globalns=globals()) + + # Ensure evaluated type retains the correct subclass (_CallableGenericAlias) + self.assertEqual(hints['c'].__class__, C.__class__) + + # Ensure evaluated type retains correct origin + self.assertEqual(hints['c'].__origin__, C.__origin__) + + # Instead of comparing raw ForwardRef, check if the resolution is correct + expected_args = tuple(int if isinstance(arg, ForwardRef) else arg for arg in C.__args__) + self.assertEqual(hints['c'].__args__, expected_args) def load_tests(loader, tests, pattern): import doctest diff --git a/Lib/typing.py b/Lib/typing.py index 3d64480e1431c1..7c546ec7fb8c10 100644 --- a/Lib/typing.py +++ b/Lib/typing.py @@ -472,7 +472,9 @@ def _eval_type(t, globalns, localns, type_params=_sentinel, *, recursive_guard=f if ev_args == t.__args__: return t if isinstance(t, GenericAlias): - return GenericAlias(t.__origin__, ev_args) + if _should_unflatten_callable_args(t, ev_args): + return t.__class__(t.__origin__, (ev_args[:-1], ev_args[-1])) + return t.__class__(t.__origin__, ev_args) if isinstance(t, Union): return functools.reduce(operator.or_, ev_args) else: diff --git a/Misc/NEWS.d/next/Library/2025-03-05-21-48-22.gh-issue-130870.uDz6AQ.rst b/Misc/NEWS.d/next/Library/2025-03-05-21-48-22.gh-issue-130870.uDz6AQ.rst new file mode 100644 index 00000000000000..e52af134eeff63 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2025-03-05-21-48-22.gh-issue-130870.uDz6AQ.rst @@ -0,0 +1 @@ +Ensure that typing.Callable retains its subclass (_CallableGenericAlias) instead of being incorrectly converted to GenericAlias.