From 17f0d96ed54a577868795bcc8d1ee0ca77f1fcde Mon Sep 17 00:00:00 2001 From: wookie184 Date: Mon, 1 May 2023 13:05:49 +0000 Subject: [PATCH] Update __class__ in function closures for dataclasses with slots=True --- Lib/dataclasses.py | 40 +++++++++++++-- Lib/test/test_dataclasses.py | 94 ++++++++++++++++++++++++++++++++++++ 2 files changed, 129 insertions(+), 5 deletions(-) diff --git a/Lib/dataclasses.py b/Lib/dataclasses.py index a73cdc22a5f4b3..75e1894ef48bb1 100644 --- a/Lib/dataclasses.py +++ b/Lib/dataclasses.py @@ -1221,16 +1221,46 @@ def _add_slots(cls, is_frozen, weakref_slot): # And finally create the class. qualname = getattr(cls, '__qualname__', None) - cls = type(cls)(cls.__name__, cls.__bases__, cls_dict) + new_cls = type(cls)(cls.__name__, cls.__bases__, cls_dict) if qualname is not None: - cls.__qualname__ = qualname + new_cls.__qualname__ = qualname if is_frozen: # Need this for pickling frozen classes with slots. - cls.__getstate__ = _dataclass_getstate - cls.__setstate__ = _dataclass_setstate + new_cls.__getstate__ = _dataclass_getstate + new_cls.__setstate__ = _dataclass_setstate + + # References to the old class may still be stored in class function __closure__s. + # Try and replace them with the new class so, for example, super() will work. + for cls_attribute in cls_dict.values(): + # Unwrap classmethods, staticmethods, and decorators made using @functools.wraps. + cls_attribute = inspect.unwrap(cls_attribute) + + if isinstance(cls_attribute, property): + # Special case to support property. + for func in (cls_attribute.fget, cls_attribute.fset, cls_attribute.fdel): + _update_class_cell(cls, new_cls, func) + else: + _update_class_cell(cls, new_cls, cls_attribute) - return cls + return new_cls + + +def _update_class_cell(old_cls, new_cls, function): + # Ignore if the object isn't actually a function. + if not inspect.isfunction(function): + return + + # Function does not use __class__ or super(), ignore. + if function.__closure__ is None: + return + + # Look through function closure to find the __class__ cell, and update + # it with the new class if it previously contained the old class. + for name, cell in zip(function.__code__.co_freevars, function.__closure__): + if name == "__class__" and cell.cell_contents is old_cls: + cell.cell_contents = new_cls + break def dataclass(cls=None, /, *, init=True, repr=True, eq=True, order=False, diff --git a/Lib/test/test_dataclasses.py b/Lib/test/test_dataclasses.py index 7b48b26f9e7743..c90d354fdfdfb9 100644 --- a/Lib/test/test_dataclasses.py +++ b/Lib/test/test_dataclasses.py @@ -3331,6 +3331,100 @@ class A(Base): a_ref = weakref.ref(a) self.assertIs(a.__weakref__, a_ref) + def test_slots_super(self): + @dataclass(slots=True) + class Base: + a: int + def b(self): + return 2 + + def c(self): + return 3 + + @dataclass(slots=True) + class A(Base): + def b(self): + return super().b() + + def d(self): + return 4 + + a = A(1) + self.assertEqual(a.a, 1) + self.assertEqual(a.b(), 2) + self.assertEqual(a.c(), 3) + self.assertEqual(a.d(), 4) + + def test_slots_dunder_class(self): + class Base: + def foo(self): + return 1 + + class OtherClass(Base): + def foo(self): + return __class__ + + @dataclass(slots=True) + class A: + foo = OtherClass.foo + + def bar(self): + return __class__ + + a = A() + self.assertEqual(a.foo(), OtherClass) + self.assertEqual(a.bar(), A) + + def test_slots_dunder_class_decorated(self): + import functools + + def deco(f): + @functools.wraps(f) + def wrapper(*args, **kwargs): + return f(*args, **kwargs) + return wrapper + + @dataclass(slots=True) + class A(): + @classmethod + @deco + def b(cls): + return __class__ + + a = A() + self.assertEqual(a.b(), A) + + def test_slots_dunder_class_property_getter(self): + @dataclass(slots=True) + class A: + @property + def foo(slf): + return __class__ + + a = A() + self.assertEqual(a.foo, A) + + def test_slots_dunder_class_property_setter(self): + @dataclass(slots=True) + class A: + foo = property() + @foo.setter + def foo(slf, val): + self.assertEqual(__class__, type(slf)) + + a = A() + a.foo = 4 + + def test_slots_dunder_class_property_deleter(self): + @dataclass(slots=True) + class A: + foo = property() + @foo.deleter + def foo(slf): + self.assertEqual(__class__, type(slf)) + + a = A() + del a.foo class TestDescriptors(unittest.TestCase): def test_set_name(self):