3
3
import abc
4
4
import contextlib
5
5
import collections
6
+ from collections import defaultdict
6
7
import collections .abc
7
8
from functools import lru_cache
8
9
import inspect
9
10
import pickle
10
11
import subprocess
11
12
import types
12
13
from unittest import TestCase , main , skipUnless , skipIf
14
+ from unittest .mock import patch
13
15
from test import ann_module , ann_module2 , ann_module3
14
16
import typing
15
17
from typing import TypeVar , Optional , Union , Any , AnyStr
21
23
from typing_extensions import NoReturn , ClassVar , Final , IntVar , Literal , Type , NewType , TypedDict , Self
22
24
from typing_extensions import TypeAlias , ParamSpec , Concatenate , ParamSpecArgs , ParamSpecKwargs , TypeGuard
23
25
from typing_extensions import Awaitable , AsyncIterator , AsyncContextManager , Required , NotRequired
24
- from typing_extensions import Protocol , runtime , runtime_checkable , Annotated , overload , final , is_typeddict
26
+ from typing_extensions import Protocol , runtime , runtime_checkable , Annotated , final , is_typeddict
25
27
from typing_extensions import TypeVarTuple , Unpack , dataclass_transform , reveal_type , Never , assert_never , LiteralString
26
28
from typing_extensions import assert_type , get_type_hints , get_origin , get_args
29
+ from typing_extensions import clear_overloads , get_overloads , overload
27
30
28
31
# Flags used to mark tests that only apply after a specific
29
32
# version of the typing module.
@@ -403,6 +406,20 @@ def test_no_multiple_subscripts(self):
403
406
Literal [1 ][1 ]
404
407
405
408
409
+ class MethodHolder :
410
+ @classmethod
411
+ def clsmethod (cls ): ...
412
+ @staticmethod
413
+ def stmethod (): ...
414
+ def method (self ): ...
415
+
416
+
417
+ if TYPING_3_11_0 :
418
+ registry_holder = typing
419
+ else :
420
+ registry_holder = typing_extensions
421
+
422
+
406
423
class OverloadTests (BaseTestCase ):
407
424
408
425
def test_overload_fails (self ):
@@ -424,6 +441,61 @@ def blah():
424
441
425
442
blah ()
426
443
444
+ def set_up_overloads (self ):
445
+ def blah ():
446
+ pass
447
+
448
+ overload1 = blah
449
+ overload (blah )
450
+
451
+ def blah ():
452
+ pass
453
+
454
+ overload2 = blah
455
+ overload (blah )
456
+
457
+ def blah ():
458
+ pass
459
+
460
+ return blah , [overload1 , overload2 ]
461
+
462
+ # Make sure we don't clear the global overload registry
463
+ @patch (
464
+ f"{ registry_holder .__name__ } ._overload_registry" ,
465
+ defaultdict (lambda : defaultdict (dict ))
466
+ )
467
+ def test_overload_registry (self ):
468
+ registry = registry_holder ._overload_registry
469
+ # The registry starts out empty
470
+ self .assertEqual (registry , {})
471
+
472
+ impl , overloads = self .set_up_overloads ()
473
+ self .assertNotEqual (registry , {})
474
+ self .assertEqual (list (get_overloads (impl )), overloads )
475
+
476
+ def some_other_func (): pass
477
+ overload (some_other_func )
478
+ other_overload = some_other_func
479
+ def some_other_func (): pass
480
+ self .assertEqual (list (get_overloads (some_other_func )), [other_overload ])
481
+
482
+ # Make sure that after we clear all overloads, the registry is
483
+ # completely empty.
484
+ clear_overloads ()
485
+ self .assertEqual (registry , {})
486
+ self .assertEqual (get_overloads (impl ), [])
487
+
488
+ # Querying a function with no overloads shouldn't change the registry.
489
+ def the_only_one (): pass
490
+ self .assertEqual (get_overloads (the_only_one ), [])
491
+ self .assertEqual (registry , {})
492
+
493
+ def test_overload_registry_repeated (self ):
494
+ for _ in range (2 ):
495
+ impl , overloads = self .set_up_overloads ()
496
+
497
+ self .assertEqual (list (get_overloads (impl )), overloads )
498
+
427
499
428
500
class AssertTypeTests (BaseTestCase ):
429
501
0 commit comments