From 1eb783523eac4a78553f721608a2138fc6bd8dab Mon Sep 17 00:00:00 2001 From: Yurii Karabas <1998uriyyo@gmail.com> Date: Thu, 2 Sep 2021 12:44:52 +0300 Subject: [PATCH 1/5] bpo-45081: Fix dataclass __init__ method generation when inherit from Protocol --- Lib/dataclasses.py | 5 +++++ Lib/test/test_dataclasses.py | 19 ++++++++++++++++++- .../2021-09-02-12-42-25.bpo-45081.tOjJ1k.rst | 2 ++ 3 files changed, 25 insertions(+), 1 deletion(-) create mode 100644 Misc/NEWS.d/next/Library/2021-09-02-12-42-25.bpo-45081.tOjJ1k.rst diff --git a/Lib/dataclasses.py b/Lib/dataclasses.py index 95ff39287bed61..6b4ab589e35eec 100644 --- a/Lib/dataclasses.py +++ b/Lib/dataclasses.py @@ -1015,6 +1015,11 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen, # Does this class have a post-init function? has_post_init = hasattr(cls, _POST_INIT_NAME) + # typing.Protocol can override __init__ method to object.__init__ + inherits_from_protocol = any(getattr(c, '_is_protocol', False) for c in cls.__bases__) + if inherits_from_protocol and cls.__dict__.get('__init__') is object.__init__: + del cls.__init__ + _set_new_attribute(cls, '__init__', _init_fn(all_init_fields, std_init_fields, diff --git a/Lib/test/test_dataclasses.py b/Lib/test/test_dataclasses.py index 8e645aeb4a7503..f1e0df3e529105 100644 --- a/Lib/test/test_dataclasses.py +++ b/Lib/test/test_dataclasses.py @@ -10,7 +10,7 @@ import builtins import unittest from unittest.mock import Mock -from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar, Optional +from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar, Optional, Protocol from typing import get_type_hints from collections import deque, OrderedDict, namedtuple from functools import total_ordering @@ -2150,6 +2150,23 @@ def __init__(self, x): self.x = 2 * x self.assertEqual(C(5).x, 10) + def test_inherit_from_protocol(self): + class P(Protocol): + a: int + + @dataclass + class C(P): + a: int + + self.assertEqual(C(5).a, 5) + + @dataclass + class D(P): + def __init__(self, a): + self.a = a * 2 + + self.assertEqual(D(5).a, 10) + class TestRepr(unittest.TestCase): def test_repr(self): diff --git a/Misc/NEWS.d/next/Library/2021-09-02-12-42-25.bpo-45081.tOjJ1k.rst b/Misc/NEWS.d/next/Library/2021-09-02-12-42-25.bpo-45081.tOjJ1k.rst new file mode 100644 index 00000000000000..86d7182003bb93 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2021-09-02-12-42-25.bpo-45081.tOjJ1k.rst @@ -0,0 +1,2 @@ +Fix issue when dataclasses that inherit from ``typing.Protocol`` subclasses +have wrong ``__init__``. Patch provided by Yurii Karabas. From 084840a29bf70abbea70ef4c11da3ef14fc042f7 Mon Sep 17 00:00:00 2001 From: Yurii Karabas <1998uriyyo@gmail.com> Date: Thu, 2 Sep 2021 15:02:34 +0300 Subject: [PATCH 2/5] Lazly override __init__ method of a Protocol subclasses --- Lib/dataclasses.py | 5 ----- Lib/test/test_dataclasses.py | 2 ++ Lib/typing.py | 29 +++++++++++++++++++---------- 3 files changed, 21 insertions(+), 15 deletions(-) diff --git a/Lib/dataclasses.py b/Lib/dataclasses.py index 6b4ab589e35eec..95ff39287bed61 100644 --- a/Lib/dataclasses.py +++ b/Lib/dataclasses.py @@ -1015,11 +1015,6 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen, # Does this class have a post-init function? has_post_init = hasattr(cls, _POST_INIT_NAME) - # typing.Protocol can override __init__ method to object.__init__ - inherits_from_protocol = any(getattr(c, '_is_protocol', False) for c in cls.__bases__) - if inherits_from_protocol and cls.__dict__.get('__init__') is object.__init__: - del cls.__init__ - _set_new_attribute(cls, '__init__', _init_fn(all_init_fields, std_init_fields, diff --git a/Lib/test/test_dataclasses.py b/Lib/test/test_dataclasses.py index f1e0df3e529105..6663d31ce6df50 100644 --- a/Lib/test/test_dataclasses.py +++ b/Lib/test/test_dataclasses.py @@ -2151,6 +2151,8 @@ def __init__(self, x): self.assertEqual(C(5).x, 10) def test_inherit_from_protocol(self): + # See bpo-45081. + class P(Protocol): a: int diff --git a/Lib/typing.py b/Lib/typing.py index 35c57c21b37c21..4beefdb4be223b 100644 --- a/Lib/typing.py +++ b/Lib/typing.py @@ -1401,7 +1401,25 @@ def _is_callable_members_only(cls): def _no_init(self, *args, **kwargs): - raise TypeError('Protocols cannot be instantiated') + cls = type(self) + + if cls._is_protocol: + raise TypeError('Protocols cannot be instantiated') + + # set correct __init__ method on a first initialization + # so all further initialization will call it directly + # see bpo-45081 + for base in cls.__mro__: + init = base.__dict__.get('__init__', _no_init) + if init is not _no_init: + cls.__init__ = init + break + else: + # should not happen + cls.__init__ = object.__init__ + + cls.__init__(self, *args, **kwargs) + def _caller(depth=1, default='__main__'): try: @@ -1541,15 +1559,6 @@ def _proto_hook(other): # We have nothing more to do for non-protocols... if not cls._is_protocol: - if cls.__init__ == _no_init: - for base in cls.__mro__: - init = base.__dict__.get('__init__', _no_init) - if init != _no_init: - cls.__init__ = init - break - else: - # should not happen - cls.__init__ = object.__init__ return # ... otherwise check consistency of bases, and prohibit instantiation. From 6d1be7435c470e8114b0fba9c83c3ca35309d14e Mon Sep 17 00:00:00 2001 From: Yurii Karabas <1998uriyyo@gmail.com> Date: Thu, 2 Sep 2021 16:30:14 +0300 Subject: [PATCH 3/5] Update Lib/test/test_dataclasses.py Co-authored-by: Ken Jin <28750310+Fidget-Spinner@users.noreply.github.com> --- Lib/test/test_dataclasses.py | 1 + 1 file changed, 1 insertion(+) diff --git a/Lib/test/test_dataclasses.py b/Lib/test/test_dataclasses.py index 6663d31ce6df50..33c9fcd1656219 100644 --- a/Lib/test/test_dataclasses.py +++ b/Lib/test/test_dataclasses.py @@ -2151,6 +2151,7 @@ def __init__(self, x): self.assertEqual(C(5).x, 10) def test_inherit_from_protocol(self): + # Dataclasses inheriting from protocol should preserve their own `__init__`. # See bpo-45081. class P(Protocol): From 91b011edbc1f6754abd781f42459f6c0ef400083 Mon Sep 17 00:00:00 2001 From: Yurii Karabas <1998uriyyo@gmail.com> Date: Thu, 2 Sep 2021 16:32:33 +0300 Subject: [PATCH 4/5] Update Ptocol init method --- Lib/typing.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/Lib/typing.py b/Lib/typing.py index 4beefdb4be223b..71cfe66b4587d2 100644 --- a/Lib/typing.py +++ b/Lib/typing.py @@ -1400,18 +1400,21 @@ def _is_callable_members_only(cls): return all(callable(getattr(cls, attr, None)) for attr in _get_protocol_attrs(cls)) -def _no_init(self, *args, **kwargs): +def _no_init_or_replace_init(self, *args, **kwargs): cls = type(self) if cls._is_protocol: raise TypeError('Protocols cannot be instantiated') - # set correct __init__ method on a first initialization - # so all further initialization will call it directly - # see bpo-45081 + # Initially, `__init__` of a protocol subclass is set to `_no_init`. + # The first instantiation of the subclass will call `_no_init` which + # searches for a proper new `__init__` in the MRO. The new `__init__` + # replaces the subclass' old `__init__` (ie `_no_init`). Subsequent + # instantiation of the protocol subclass will thus use the new + # `__init__` and no longer call `_no_init`. for base in cls.__mro__: - init = base.__dict__.get('__init__', _no_init) - if init is not _no_init: + init = base.__dict__.get('__init__', _no_init_or_replace_init) + if init is not _no_init_or_replace_init: cls.__init__ = init break else: @@ -1569,7 +1572,7 @@ def _proto_hook(other): issubclass(base, Generic) and base._is_protocol): raise TypeError('Protocols can only inherit from other' ' protocols, got %r' % base) - cls.__init__ = _no_init + cls.__init__ = _no_init_or_replace_init class _AnnotatedAlias(_GenericAlias, _root=True): From 8dc6a2ed074be0d157ca021d5290523c7277ce56 Mon Sep 17 00:00:00 2001 From: Yurii Karabas <1998uriyyo@gmail.com> Date: Thu, 2 Sep 2021 16:49:49 +0300 Subject: [PATCH 5/5] Update comment --- Lib/typing.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/Lib/typing.py b/Lib/typing.py index 71cfe66b4587d2..892f1b3506851d 100644 --- a/Lib/typing.py +++ b/Lib/typing.py @@ -1406,12 +1406,12 @@ def _no_init_or_replace_init(self, *args, **kwargs): if cls._is_protocol: raise TypeError('Protocols cannot be instantiated') - # Initially, `__init__` of a protocol subclass is set to `_no_init`. - # The first instantiation of the subclass will call `_no_init` which + # Initially, `__init__` of a protocol subclass is set to `_no_init_or_replace_init`. + # The first instantiation of the subclass will call `_no_init_or_replace_init` which # searches for a proper new `__init__` in the MRO. The new `__init__` - # replaces the subclass' old `__init__` (ie `_no_init`). Subsequent + # replaces the subclass' old `__init__` (ie `_no_init_or_replace_init`). Subsequent # instantiation of the protocol subclass will thus use the new - # `__init__` and no longer call `_no_init`. + # `__init__` and no longer call `_no_init_or_replace_init`. for base in cls.__mro__: init = base.__dict__.get('__init__', _no_init_or_replace_init) if init is not _no_init_or_replace_init: