gh-112319: Allow special protocol members #112340

Draft: wants to merge 11 commits into main
40 changes: 40 additions & 0 deletions Lib/test/test_typing.py
@@ -4091,6 +4091,46 @@ def method(self) -> None: ...
         self.assertIsInstance(Foo(), ProtocolWithMixedMembers)
         self.assertNotIsInstance(42, ProtocolWithMixedMembers)

+    def test_protocol_special_members(self):
+        # See https://github.com/python/cpython/issues/112319
+
+        T_co = TypeVar("T_co", covariant=True)
+
+        @runtime_checkable
+        class GenericIterable(Protocol[T_co]):
+            def __class_getitem__(cls, item): ...
+            def __iter__(self): ...
+
+        self.assertIsInstance([1,2,3], GenericIterable)
+        self.assertNotIsInstance("123", GenericIterable) # str is not a generic type!
+
+        class TakesKWARGS(Protocol):
+            def __init__(self, **kwargs): ... # NOTE: For static checking.
+
+        self.assertEqual(TakesKWARGS.__protocol_attrs__, {"__init__"})
+
+    def test_protocol_special_attributes(self):
+        class Documented(Protocol):
+            """Matches classes that have a docstring."""
+            __doc__: str # NOTE: For static checking, undocumented classes have __doc__ = None.
+
+        self.assertEqual(Documented.__protocol_attrs__, {"__doc__"})
+
+        @runtime_checkable
+        class Slotted(Protocol):
+            """Matches classes that have a __slots__ attribute."""
+            __slots__: tuple
+
+        class Unslotted:
+            pass
+
+        class WithSlots:
+            __slots__ = ("foo", "bar")
+
+        self.assertEqual(Slotted.__protocol_attrs__, {"__slots__"})
+        self.assertNotIsInstance(Unslotted(), Slotted)
+        self.assertIsInstance(WithSlots(), Slotted)
+

 class GenericTests(BaseTestCase):
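
The expectations in test_protocol_special_members follow from which objects expose the collected member names. Below is a minimal sketch, assuming the runtime check amounts to an attribute-presence test for each member. The helper `has_members` is hypothetical and not part of `typing`, and Python 3.9+ is assumed for `list.__class_getitem__`.

```python
def has_members(obj, names):
    """Return True if obj exposes every attribute in names (hypothetical helper)."""
    return all(hasattr(obj, name) for name in names)


# Member names that GenericIterable is expected to collect under this patch.
members = {"__class_getitem__", "__iter__"}

# list is iterable and subscriptable (list[int]), so both names are present.
list_matches = has_members([1, 2, 3], members)

# str is iterable, but str[int] is not supported: no __class_getitem__.
str_matches = has_members("123", members)
```

This is only an approximation of the real check, which lives in the protocol metaclass and may look attributes up differently.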

72 changes: 64 additions & 8 deletions Lib/typing.py
@@ -1675,13 +1675,65 @@ class _TypingEllipsis:

 _SPECIAL_NAMES = frozenset({
     '__abstractmethods__', '__annotations__', '__dict__', '__doc__',
-    '__init__', '__module__', '__new__', '__slots__',
-    '__subclasshook__', '__weakref__', '__class_getitem__',
-    '__match_args__',
+    '__module__', '__slots__', '__match_args__', '__qualname__',
 })
+_SPECIAL_CALLABLE_NAMES = frozenset({
+    '__init__', '__new__', '__subclasshook__','__class_getitem__', '__weakref__',
+})

 # These special attributes will be not collected as protocol members.
 EXCLUDED_ATTRIBUTES = _TYPING_INTERNALS | _SPECIAL_NAMES | {'_MutableMapping__marker'}
+EXCLUDED_MEMBERS = EXCLUDED_ATTRIBUTES | _SPECIAL_CALLABLE_NAMES
Comment on lines -1678 to +1686
Member

Please keep the diff as minimal as possible. As far as I can tell, splitting the _SPECIAL_NAMES frozenset into a _SPECIAL_NAMES set and a _SPECIAL_CALLABLE_NAMES set shouldn't make any difference here.

Contributor Author

For my solution to work, I need both sets. For names defined directly inside the protocol body, I exclude only EXCLUDED_ATTRIBUTES, not _SPECIAL_CALLABLE_NAMES (hence we can test for __class_getitem__). For general classes, however, the callable names must be excluded as well; otherwise test.test_typing.ProtocolTests.test_collections_protocols_allowed fails.

See the difference in function body between _get_local_members and _get_local_protocol_members.
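
The distinction can be sketched outside of typing.py. The following is a simplified stand-in, not the patch itself: the two sets are abbreviated (the real EXCLUDED_ATTRIBUTES also contains the typing internals), the functions drop the leading underscore, and the `namespace` dict is a hand-written assumption of what a protocol class body might contain.

```python
# Abbreviated stand-ins for the sets in the diff (assumption: typing internals omitted).
EXCLUDED_ATTRIBUTES = frozenset({
    "__abstractmethods__", "__annotations__", "__dict__", "__doc__",
    "__module__", "__slots__", "__match_args__", "__qualname__",
})
SPECIAL_CALLABLE_NAMES = frozenset({
    "__init__", "__new__", "__subclasshook__", "__class_getitem__", "__weakref__",
})
EXCLUDED_MEMBERS = EXCLUDED_ATTRIBUTES | SPECIAL_CALLABLE_NAMES


def get_local_members(namespace):
    """General classes: special callables are excluded too."""
    annotations = namespace.get("__annotations__", {})
    attrs = (namespace.keys() | annotations.keys()) - EXCLUDED_MEMBERS
    return {attr for attr in attrs if not attr.startswith("_abc_")}


def get_local_protocol_members(namespace):
    """Protocol bodies: only EXCLUDED_ATTRIBUTES is filtered; annotations always count."""
    annotations = namespace.get("__annotations__", {})
    return (namespace.keys() - EXCLUDED_ATTRIBUTES) | annotations.keys()


# Hypothetical class-body namespace resembling GenericIterable from the tests.
namespace = {
    "__module__": "demo",
    "__qualname__": "GenericIterable",
    "__class_getitem__": lambda cls, item: cls,
    "__iter__": lambda self: iter(()),
}

protocol_view = get_local_protocol_members(namespace)  # keeps __class_getitem__
general_view = get_local_members(namespace)            # drops __class_getitem__
```

With this namespace, only the protocol-body collector reports `__class_getitem__`, which is the behavior the comment above relies on.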



+def _get_local_members(namespace):
+    """Collect the specified attributes from a classes' namespace."""
+    annotations = namespace.get("__annotations__", {})
+    attrs = (namespace.keys() | annotations.keys()) - EXCLUDED_MEMBERS
+    # exclude special "_abc_" attributes
+    return {attr for attr in attrs if not attr.startswith('_abc_')}
+
+
+def _get_local_protocol_members(namespace):
+    """Collect the specified attributes from the protocols' namespace."""
+    # annotated attributes are always considered protocol members
+    annotations = namespace.get("__annotations__", {})
+    # only namespace members outside the excluded set are considered protocol members
+    return (namespace.keys() - EXCLUDED_ATTRIBUTES) | annotations.keys()
+
+
+def _get_parent_members(cls):
+    """Collect protocol members from parents of arbitrary class object.
+
+    This includes names actually defined in the class dictionary, as well
+    as names that appear in annotations. Special names (above) are skipped.
+    """
+    attrs = set()
+    for base in cls.__mro__[1:-1]: # without self and object
+        if base.__name__ in {'Protocol', 'Generic'}:
+            continue
+        elif getattr(base, "_is_protocol", False):
+            attrs |= getattr(base, "__protocol_attrs__", set())
+        else: # get from annotations
+            attrs |= _get_local_members(base.__dict__)
+    return attrs
+
+
+def _get_cls_members(cls):
+    """Collect protocol members from an arbitrary class object.
+
+    This includes names actually defined in the class dictionary, as well
+    as names that appear in annotations. Special names (above) are skipped.
+    """
+    attrs = set()
+    for base in cls.__mro__[:-1]: # without object
+        if base.__name__ in {'Protocol', 'Generic'}:
+            continue
+        elif getattr(base, "_is_protocol", False):
+            attrs |= getattr(base, "__protocol_attrs__", set())
+        else:
+            attrs |= _get_local_members(base.__dict__)
+    return attrs
+
+
 def _get_protocol_attrs(cls):
Contributor Author

@randolf-scholz randolf-scholz Nov 23, 2023

After the modifications here, _get_protocol_attrs is no longer used, and it would even give wrong results in some cases (e.g. if the Protocol defines __class_getitem__). So maybe just delete it?

Member

If it's not used anywhere anymore, then yes, it should be deleted. It's an undocumented, private function; if anybody's using it and their code is broken by us deleting it, that's on them. I can't find any uses of it on grep.app anyway, so I don't think anybody will have their code broken by us deleting it.

@@ -1696,7 +1748,7 @@ def _get_protocol_attrs(cls):
             continue
         annotations = getattr(base, '__annotations__', {})
         for attr in (*base.__dict__, *annotations):
-            if not attr.startswith('_abc_') and attr not in EXCLUDED_ATTRIBUTES:
+            if not attr.startswith('_abc_') and attr not in EXCLUDED_MEMBERS:
                 attrs.add(attr)
     return attrs

@@ -1804,10 +1856,14 @@ def __new__(mcls, name, bases, namespace, /, **kwargs):
                     )
         return super().__new__(mcls, name, bases, namespace, **kwargs)

-    def __init__(cls, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        if getattr(cls, "_is_protocol", False):
-            cls.__protocol_attrs__ = _get_protocol_attrs(cls)
+    def __init__(cls, name, bases, namespace, **kwds):
+        super().__init__(name, bases, namespace, **kwds)
+        if getattr(cls, "_is_protocol", False) and cls.__name__ != "Protocol":
+            cls.__protocol_attrs__ = (
+                _get_local_protocol_members(namespace) | _get_parent_members(cls)
+            )
+            # local_attrs = _get_local_protocol_members(namespace)
+            # cls.__protocol_attrs__ = local_attrs.union(*map(_get_cls_members, bases))
Comment on lines +1862 to +1866
Contributor Author

Not sure which variant is better...

Contributor Author

The only difference between _get_cls_members and _get_parent_members is that the former checks cls.__mro__[:-1] and the latter only cls.__mro__[1:-1].

I would expect _get_parent_members to be better for deeply inherited classes, and _get_cls_members to be better for classes with many disjoint parents.
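
To make the two slices concrete, here is a small illustration on a hypothetical hierarchy (`Base`, `Mixin`, and `Child` are made-up names, not classes from typing.py):

```python
class Base: ...
class Mixin: ...
class Child(Base, Mixin): ...


# Child.__mro__ is (Child, Base, Mixin, object).
# [:-1] drops only object, so the class itself is included.
with_self = [c.__name__ for c in Child.__mro__[:-1]]

# [1:-1] drops both the class itself and object, leaving only the parents.
parents_only = [c.__name__ for c in Child.__mro__[1:-1]]
```

Here `with_self` is `['Child', 'Base', 'Mixin']` and `parents_only` is `['Base', 'Mixin']`.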

             # PEP 544 prohibits using issubclass()
             # with protocols that have non-method members.
             cls.__callable_proto_members_only__ = all(
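
The restriction named in the code comment above can be observed from user code. This is a small sketch using a hypothetical `Named` protocol with one data member; it relies only on documented `typing` behavior, not on this patch:

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class Named(Protocol):
    name: str  # a non-method (data) member


class Person:
    name = "Ada"


# isinstance() is allowed for protocols with data members.
instance_ok = isinstance(Person(), Named)

# issubclass() is rejected with TypeError because Named has a non-method member.
try:
    issubclass(Person, Named)
    issubclass_rejected = False
except TypeError:
    issubclass_rejected = True
```

Any annotated attribute collected as a protocol member, including the special ones this PR enables, would presumably count as a non-method member for this purpose.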
@@ -0,0 +1,22 @@
+:class:`Protocol` classes now allow specification of previously excluded attributes and methods.
+
+Example::
+
+    class Documented(Protocol):
+        """A Protocol for documented classes."""
+        __doc__: str
+
+    class Slotted(Protocol):
+        """A Protocol for classes with __slots__."""
+        __slots__: tuple[str, ...]
+
+    @runtime_checkable
+    class GenericIterable(Protocol):
+        """An iterable that must also be a generic type."""
+        def __class_getitem__(cls, item): ...
+        def __iter__(self): ...
+
+    assert isinstance(["a", "b", "c"], GenericIterable) # ✅
+    assert not isinstance("abc", GenericIterable) # ✅
+
+Patch by Randolf Scholz.