diff --git a/Lib/test/test_typing.py b/Lib/test/test_typing.py index 2b5f34b4b92e0c..10e3a83edb524c 100644 --- a/Lib/test/test_typing.py +++ b/Lib/test/test_typing.py @@ -4091,6 +4091,46 @@ def method(self) -> None: ... self.assertIsInstance(Foo(), ProtocolWithMixedMembers) self.assertNotIsInstance(42, ProtocolWithMixedMembers) + def test_protocol_special_members(self): + # See https://github.com/python/cpython/issues/112319 + + T_co = TypeVar("T_co", covariant=True) + + @runtime_checkable + class GenericIterable(Protocol[T_co]): + def __class_getitem__(cls, item): ... + def __iter__(self): ... + + self.assertIsInstance([1,2,3], GenericIterable) + self.assertNotIsInstance("123", GenericIterable) # str is not a generic type! + + class TakesKWARGS(Protocol): + def __init__(self, **kwargs): ... # NOTE: For static checking. + + self.assertEqual(TakesKWARGS.__protocol_attrs__, {"__init__"}) + + def test_protocol_special_attributes(self): + class Documented(Protocol): + """Matches classes that have a docstring.""" + __doc__: str # NOTE: For static checking, undocumented classes have __doc__ = None. + + self.assertEqual(Documented.__protocol_attrs__, {"__doc__"}) + + @runtime_checkable + class Slotted(Protocol): + """Matches classes that have a __slots__ attribute.""" + __slots__: tuple + + class Unslotted: + pass + + class WithSlots: + __slots__ = ("foo", "bar") + + self.assertEqual(Slotted.__protocol_attrs__, {"__slots__"}) + self.assertNotIsInstance(Unslotted(), Slotted) + self.assertIsInstance(WithSlots(), Slotted) + class GenericTests(BaseTestCase): diff --git a/Lib/typing.py b/Lib/typing.py index a96c7083eb785e..35865933bd199c 100644 --- a/Lib/typing.py +++ b/Lib/typing.py @@ -1675,13 +1675,65 @@ class _TypingEllipsis: _SPECIAL_NAMES = frozenset({ '__abstractmethods__', '__annotations__', '__dict__', '__doc__', - '__init__', '__module__', '__new__', '__slots__', - '__subclasshook__', '__weakref__', '__class_getitem__', - '__match_args__', + '__module__', '__slots__', '__match_args__', '__qualname__', +}) +_SPECIAL_CALLABLE_NAMES = frozenset({ + '__init__', '__new__', '__subclasshook__','__class_getitem__', '__weakref__', }) # These special attributes will be not collected as protocol members. EXCLUDED_ATTRIBUTES = _TYPING_INTERNALS | _SPECIAL_NAMES | {'_MutableMapping__marker'} +EXCLUDED_MEMBERS = EXCLUDED_ATTRIBUTES | _SPECIAL_CALLABLE_NAMES + + +def _get_local_members(namespace): + """Collect the specified attributes from a classes' namespace.""" + annotations = namespace.get("__annotations__", {}) + attrs = (namespace.keys() | annotations.keys()) - EXCLUDED_MEMBERS + # exclude special "_abc_" attributes + return {attr for attr in attrs if not attr.startswith('_abc_')} + + +def _get_local_protocol_members(namespace): + """Collect the specified attributes from the protocols' namespace.""" + # annotated attributes are always considered protocol members + annotations = namespace.get("__annotations__", {}) + # only namespace members outside the excluded set are considered protocol members + return (namespace.keys() - EXCLUDED_ATTRIBUTES) | annotations.keys() + + +def _get_parent_members(cls): + """Collect protocol members from parents of arbitrary class object. + + This includes names actually defined in the class dictionary, as well + as names that appear in annotations. Special names (above) are skipped. + """ + attrs = set() + for base in cls.__mro__[1:-1]: # without self and object + if base.__name__ in {'Protocol', 'Generic'}: + continue + elif getattr(base, "_is_protocol", False): + attrs |= getattr(base, "__protocol_attrs__", set()) + else: # get from annotations + attrs |= _get_local_members(base.__dict__) + return attrs + + +def _get_cls_members(cls): + """Collect protocol members from an arbitrary class object. + + This includes names actually defined in the class dictionary, as well + as names that appear in annotations. Special names (above) are skipped. + """ + attrs = set() + for base in cls.__mro__[:-1]: # without object + if base.__name__ in {'Protocol', 'Generic'}: + continue + elif getattr(base, "_is_protocol", False): + attrs |= getattr(base, "__protocol_attrs__", set()) + else: + attrs |= _get_local_members(base.__dict__) + return attrs def _get_protocol_attrs(cls): @@ -1696,7 +1748,7 @@ def _get_protocol_attrs(cls): continue annotations = getattr(base, '__annotations__', {}) for attr in (*base.__dict__, *annotations): - if not attr.startswith('_abc_') and attr not in EXCLUDED_ATTRIBUTES: + if not attr.startswith('_abc_') and attr not in EXCLUDED_MEMBERS: attrs.add(attr) return attrs @@ -1804,10 +1856,14 @@ def __new__(mcls, name, bases, namespace, /, **kwargs): ) return super().__new__(mcls, name, bases, namespace, **kwargs) - def __init__(cls, *args, **kwargs): - super().__init__(*args, **kwargs) - if getattr(cls, "_is_protocol", False): - cls.__protocol_attrs__ = _get_protocol_attrs(cls) + def __init__(cls, name, bases, namespace, **kwds): + super().__init__(name, bases, namespace, **kwds) + if getattr(cls, "_is_protocol", False) and cls.__name__ != "Protocol": + cls.__protocol_attrs__ = ( + _get_local_protocol_members(namespace) | _get_parent_members(cls) + ) + # local_attrs = _get_local_protocol_members(namespace) + # cls.__protocol_attrs__ = local_attrs.union(*map(_get_cls_members, bases)) # PEP 544 prohibits using issubclass() # with protocols that have non-method members. cls.__callable_proto_members_only__ = all( diff --git a/Misc/NEWS.d/next/Library/2023-11-23-17-43-28.gh-issue-112319.MdsRx7.rst b/Misc/NEWS.d/next/Library/2023-11-23-17-43-28.gh-issue-112319.MdsRx7.rst new file mode 100644 index 00000000000000..f84b9a3c117e51 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2023-11-23-17-43-28.gh-issue-112319.MdsRx7.rst @@ -0,0 +1,22 @@ +:class:`Protocol` classes now allow specification of previously excluded attributes and methods. + +Example:: + + class Documented(Protocol): + """A Protocol for documented classes.""" + __doc__: str + + class Slotted(Protocol): + """A Protocol for classes with __slots__.""" + __slots__: tuple[str, ...] + + @runtime_checkable + class GenericIterable(Protocol): + """An iterable that must also be a generic type.""" + def __class_getitem__(cls, item): ... + def __iter__(self): ... + + assert isinstance(["a", "b", "c"], GenericIterable) # ✅ + assert not isinstance("abc", GenericIterable) # ✅ + +Patch by Randolf Scholz.