Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 8bfbeeb

Browse files
[3.11] gh-112281: Allow Union with unhashable Annotated metadata (GH-112283) (#116288)
Co-authored-by: Alex Waygood <[email protected]>
1 parent 6c2484b commit 8bfbeeb

File tree

4 files changed

+155
-15
lines changed

4 files changed

+155
-15
lines changed

Lib/test/test_types.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -709,6 +709,26 @@ def test_hash(self):
709709
self.assertEqual(hash(int | str), hash(str | int))
710710
self.assertEqual(hash(int | str), hash(typing.Union[int, str]))
711711

712+
def test_union_of_unhashable(self):
713+
class UnhashableMeta(type):
714+
__hash__ = None
715+
716+
class A(metaclass=UnhashableMeta): ...
717+
class B(metaclass=UnhashableMeta): ...
718+
719+
self.assertEqual((A | B).__args__, (A, B))
720+
union1 = A | B
721+
with self.assertRaises(TypeError):
722+
hash(union1)
723+
724+
union2 = int | B
725+
with self.assertRaises(TypeError):
726+
hash(union2)
727+
728+
union3 = A | int
729+
with self.assertRaises(TypeError):
730+
hash(union3)
731+
712732
def test_instancecheck_and_subclasscheck(self):
713733
for x in (int | str, typing.Union[int, str]):
714734
with self.subTest(x=x):

Lib/test/test_typing.py

Lines changed: 102 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import contextlib
22
import collections
33
from collections import defaultdict
4-
from functools import lru_cache, wraps
4+
from functools import lru_cache, wraps, reduce
55
import inspect
66
import itertools
77
import gc
8+
import operator
89
import pickle
910
import re
1011
import sys
@@ -1705,6 +1706,26 @@ def test_union_union(self):
17051706
v = Union[u, Employee]
17061707
self.assertEqual(v, Union[int, float, Employee])
17071708

1709+
def test_union_of_unhashable(self):
1710+
class UnhashableMeta(type):
1711+
__hash__ = None
1712+
1713+
class A(metaclass=UnhashableMeta): ...
1714+
class B(metaclass=UnhashableMeta): ...
1715+
1716+
self.assertEqual(Union[A, B].__args__, (A, B))
1717+
union1 = Union[A, B]
1718+
with self.assertRaises(TypeError):
1719+
hash(union1)
1720+
1721+
union2 = Union[int, B]
1722+
with self.assertRaises(TypeError):
1723+
hash(union2)
1724+
1725+
union3 = Union[A, int]
1726+
with self.assertRaises(TypeError):
1727+
hash(union3)
1728+
17081729
def test_repr(self):
17091730
self.assertEqual(repr(Union), 'typing.Union')
17101731
u = Union[Employee, int]
@@ -7374,6 +7395,76 @@ def test_flatten(self):
73747395
self.assertEqual(A.__metadata__, (4, 5))
73757396
self.assertEqual(A.__origin__, int)
73767397

7398+
def test_deduplicate_from_union(self):
7399+
# Regular:
7400+
self.assertEqual(get_args(Annotated[int, 1] | int),
7401+
(Annotated[int, 1], int))
7402+
self.assertEqual(get_args(Union[Annotated[int, 1], int]),
7403+
(Annotated[int, 1], int))
7404+
self.assertEqual(get_args(Annotated[int, 1] | Annotated[int, 2] | int),
7405+
(Annotated[int, 1], Annotated[int, 2], int))
7406+
self.assertEqual(get_args(Union[Annotated[int, 1], Annotated[int, 2], int]),
7407+
(Annotated[int, 1], Annotated[int, 2], int))
7408+
self.assertEqual(get_args(Annotated[int, 1] | Annotated[str, 1] | int),
7409+
(Annotated[int, 1], Annotated[str, 1], int))
7410+
self.assertEqual(get_args(Union[Annotated[int, 1], Annotated[str, 1], int]),
7411+
(Annotated[int, 1], Annotated[str, 1], int))
7412+
7413+
# Duplicates:
7414+
self.assertEqual(Annotated[int, 1] | Annotated[int, 1] | int,
7415+
Annotated[int, 1] | int)
7416+
self.assertEqual(Union[Annotated[int, 1], Annotated[int, 1], int],
7417+
Union[Annotated[int, 1], int])
7418+
7419+
# Unhashable metadata:
7420+
self.assertEqual(get_args(str | Annotated[int, {}] | Annotated[int, set()] | int),
7421+
(str, Annotated[int, {}], Annotated[int, set()], int))
7422+
self.assertEqual(get_args(Union[str, Annotated[int, {}], Annotated[int, set()], int]),
7423+
(str, Annotated[int, {}], Annotated[int, set()], int))
7424+
self.assertEqual(get_args(str | Annotated[int, {}] | Annotated[str, {}] | int),
7425+
(str, Annotated[int, {}], Annotated[str, {}], int))
7426+
self.assertEqual(get_args(Union[str, Annotated[int, {}], Annotated[str, {}], int]),
7427+
(str, Annotated[int, {}], Annotated[str, {}], int))
7428+
7429+
self.assertEqual(get_args(Annotated[int, 1] | str | Annotated[str, {}] | int),
7430+
(Annotated[int, 1], str, Annotated[str, {}], int))
7431+
self.assertEqual(get_args(Union[Annotated[int, 1], str, Annotated[str, {}], int]),
7432+
(Annotated[int, 1], str, Annotated[str, {}], int))
7433+
7434+
import dataclasses
7435+
@dataclasses.dataclass
7436+
class ValueRange:
7437+
lo: int
7438+
hi: int
7439+
v = ValueRange(1, 2)
7440+
self.assertEqual(get_args(Annotated[int, v] | None),
7441+
(Annotated[int, v], types.NoneType))
7442+
self.assertEqual(get_args(Union[Annotated[int, v], None]),
7443+
(Annotated[int, v], types.NoneType))
7444+
self.assertEqual(get_args(Optional[Annotated[int, v]]),
7445+
(Annotated[int, v], types.NoneType))
7446+
7447+
# Unhashable metadata duplicated:
7448+
self.assertEqual(Annotated[int, {}] | Annotated[int, {}] | int,
7449+
Annotated[int, {}] | int)
7450+
self.assertEqual(Annotated[int, {}] | Annotated[int, {}] | int,
7451+
int | Annotated[int, {}])
7452+
self.assertEqual(Union[Annotated[int, {}], Annotated[int, {}], int],
7453+
Union[Annotated[int, {}], int])
7454+
self.assertEqual(Union[Annotated[int, {}], Annotated[int, {}], int],
7455+
Union[int, Annotated[int, {}]])
7456+
7457+
def test_order_in_union(self):
7458+
expr1 = Annotated[int, 1] | str | Annotated[str, {}] | int
7459+
for args in itertools.permutations(get_args(expr1)):
7460+
with self.subTest(args=args):
7461+
self.assertEqual(expr1, reduce(operator.or_, args))
7462+
7463+
expr2 = Union[Annotated[int, 1], str, Annotated[str, {}], int]
7464+
for args in itertools.permutations(get_args(expr2)):
7465+
with self.subTest(args=args):
7466+
self.assertEqual(expr2, Union[args])
7467+
73777468
def test_specialize(self):
73787469
L = Annotated[List[T], "my decoration"]
73797470
LI = Annotated[List[int], "my decoration"]
@@ -7394,6 +7485,16 @@ def test_hash_eq(self):
73947485
{Annotated[int, 4, 5], Annotated[int, 4, 5], Annotated[T, 4, 5]},
73957486
{Annotated[int, 4, 5], Annotated[T, 4, 5]}
73967487
)
7488+
# Unhashable `metadata` raises `TypeError`:
7489+
a1 = Annotated[int, []]
7490+
with self.assertRaises(TypeError):
7491+
hash(a1)
7492+
7493+
class A:
7494+
__hash__ = None
7495+
a2 = Annotated[int, A()]
7496+
with self.assertRaises(TypeError):
7497+
hash(a2)
73977498

73987499
def test_instantiate(self):
73997500
class C:

Lib/typing.py

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -303,19 +303,33 @@ def _unpack_args(args):
303303
newargs.append(arg)
304304
return newargs
305305

306-
def _deduplicate(params):
306+
def _deduplicate(params, *, unhashable_fallback=False):
307307
# Weed out strict duplicates, preserving the first of each occurrence.
308-
all_params = set(params)
309-
if len(all_params) < len(params):
310-
new_params = []
311-
for t in params:
312-
if t in all_params:
313-
new_params.append(t)
314-
all_params.remove(t)
315-
params = new_params
316-
assert not all_params, all_params
317-
return params
318-
308+
try:
309+
return dict.fromkeys(params)
310+
except TypeError:
311+
if not unhashable_fallback:
312+
raise
313+
# Happens for cases like `Annotated[dict, {'x': IntValidator()}]`
314+
return _deduplicate_unhashable(params)
315+
316+
def _deduplicate_unhashable(unhashable_params):
317+
new_unhashable = []
318+
for t in unhashable_params:
319+
if t not in new_unhashable:
320+
new_unhashable.append(t)
321+
return new_unhashable
322+
323+
def _compare_args_orderless(first_args, second_args):
324+
first_unhashable = _deduplicate_unhashable(first_args)
325+
second_unhashable = _deduplicate_unhashable(second_args)
326+
t = list(second_unhashable)
327+
try:
328+
for elem in first_unhashable:
329+
t.remove(elem)
330+
except ValueError:
331+
return False
332+
return not t
319333

320334
def _remove_dups_flatten(parameters):
321335
"""Internal helper for Union creation and substitution.
@@ -330,7 +344,7 @@ def _remove_dups_flatten(parameters):
330344
else:
331345
params.append(p)
332346

333-
return tuple(_deduplicate(params))
347+
return tuple(_deduplicate(params, unhashable_fallback=True))
334348

335349

336350
def _flatten_literal_params(parameters):
@@ -1673,7 +1687,10 @@ def copy_with(self, params):
16731687
def __eq__(self, other):
16741688
if not isinstance(other, (_UnionGenericAlias, types.UnionType)):
16751689
return NotImplemented
1676-
return set(self.__args__) == set(other.__args__)
1690+
try: # fast path
1691+
return set(self.__args__) == set(other.__args__)
1692+
except TypeError: # not hashable, slow path
1693+
return _compare_args_orderless(self.__args__, other.__args__)
16771694

16781695
def __hash__(self):
16791696
return hash(frozenset(self.__args__))
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Allow creating :ref:`union of types<types-union>` for
2+
:class:`typing.Annotated` with unhashable metadata.

0 commit comments

Comments
 (0)