Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 7903f74

Browse files
committed
gh-130870: Preserve GenericAlias subclasses in typing.get_type_hints()
1 parent 18249d9 commit 7903f74

File tree

2 files changed

+34
-10
lines changed

2 files changed

+34
-10
lines changed

Lib/test/test_typing.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7178,6 +7178,24 @@ def func(x: undefined) -> undefined: ...
71787178
self.assertEqual(get_type_hints(func, format=annotationlib.Format.STRING),
71797179
{'x': 'undefined', 'return': 'undefined'})
71807180

7181+
def test_get_type_hints_preserve_generic_alias_subclasses(self):
7182+
# https://github.com/python/cpython/issues/130870
7183+
# A real world example of this is `collections.abc.Callable`. When parameterized,
7184+
# the result is a subclass of `types.GenericAlias`.
7185+
class MyAlias(types.GenericAlias):
7186+
pass
7187+
7188+
class MyClass:
7189+
def __class_getitem__(cls, args):
7190+
return MyAlias(cls, args)
7191+
7192+
# Using a forward reference is important, otherwise it works as expected.
7193+
# `y` tests that the `GenericAlias` subclass is preserved when stripping `Annotated`.
7194+
def func(x: MyClass['int'], y: MyClass[Annotated[int, ...]]): ...
7195+
7196+
assert isinstance(get_type_hints(func)['x'], MyAlias)
7197+
assert isinstance(get_type_hints(func)['y'], MyAlias)
7198+
71817199

71827200
class GetUtilitiesTestCase(TestCase):
71837201
def test_get_origin(self):

Lib/typing.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -421,6 +421,17 @@ def inner(*args, **kwds):
421421
return decorator
422422

423423

424+
def _rebuild_generic_alias(alias: GenericAlias, args: tuple[object, ...]) -> GenericAlias:
425+
is_unpacked = alias.__unpacked__
426+
if _should_unflatten_callable_args(alias, args):
427+
t = alias.__origin__[(args[:-1], args[-1])]
428+
else:
429+
t = alias.__origin__[args]
430+
if is_unpacked:
431+
t = Unpack[t]
432+
return t
433+
434+
424435
def _deprecation_warning_for_no_type_params_passed(funcname: str) -> None:
425436
import warnings
426437

@@ -468,25 +479,20 @@ def _eval_type(t, globalns, localns, type_params=_sentinel, *, recursive_guard=f
468479
_make_forward_ref(arg) if isinstance(arg, str) else arg
469480
for arg in t.__args__
470481
)
471-
is_unpacked = t.__unpacked__
472-
if _should_unflatten_callable_args(t, args):
473-
t = t.__origin__[(args[:-1], args[-1])]
474-
else:
475-
t = t.__origin__[args]
476-
if is_unpacked:
477-
t = Unpack[t]
482+
else:
483+
args = t.__args__
478484

479485
ev_args = tuple(
480486
_eval_type(
481487
a, globalns, localns, type_params, recursive_guard=recursive_guard,
482488
format=format, owner=owner,
483489
)
484-
for a in t.__args__
490+
for a in args
485491
)
486492
if ev_args == t.__args__:
487493
return t
488494
if isinstance(t, GenericAlias):
489-
return GenericAlias(t.__origin__, ev_args)
495+
return _rebuild_generic_alias(t, ev_args)
490496
if isinstance(t, Union):
491497
return functools.reduce(operator.or_, ev_args)
492498
else:
@@ -2400,7 +2406,7 @@ def _strip_annotations(t):
24002406
stripped_args = tuple(_strip_annotations(a) for a in t.__args__)
24012407
if stripped_args == t.__args__:
24022408
return t
2403-
return GenericAlias(t.__origin__, stripped_args)
2409+
return _rebuild_generic_alias(t, stripped_args)
24042410
if isinstance(t, Union):
24052411
stripped_args = tuple(_strip_annotations(a) for a in t.__args__)
24062412
if stripped_args == t.__args__:

0 commit comments

Comments
 (0)