Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Open
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 72 additions & 14 deletions Lib/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -1222,11 +1222,6 @@ def _get_slots(cls):


def _update_func_cell_for__class__(f, oldcls, newcls):
# Returns True if we update a cell, else False.
if f is None:
# f will be None in the case of a property where not all of
# fget, fset, and fdel are used. Nothing to do in that case.
return False
try:
idx = f.__code__.co_freevars.index("__class__")
except ValueError:
Expand All @@ -1235,13 +1230,54 @@ def _update_func_cell_for__class__(f, oldcls, newcls):
# Fix the cell to point to the new class, if it's already pointing
# at the old class. I'm not convinced that the "is oldcls" test
# is needed, but other than performance can't hurt.
closure = f.__closure__[idx]
if closure.cell_contents is oldcls:
closure.cell_contents = newcls
cell = f.__closure__[idx]
if cell.cell_contents is oldcls:
cell.cell_contents = newcls
return True
return False


def _safe_get_attributes(obj):
# we should avoid triggering any user-defined code
# when inspecting attributes if possible

# look for __slots__ descriptors
type_dict = object.__getattribute__(type(obj), "__dict__")
for value in type_dict.values():
if isinstance(value, types.MemberDescriptorType):
yield value.__get__(obj)

instance_dict_descriptor = type_dict.get("__dict__", None)
if not isinstance(instance_dict_descriptor, types.GetSetDescriptorType):
# __dict__ is either not present, or redefined by user
# as custom descriptor, either way, we're done here
return

yield from instance_dict_descriptor.__get__(obj).values()


def _find_inner_functions(obj, seen=None, depth=0):
if seen is None:
seen = set()
if id(obj) in seen:
return None
seen.add(id(obj))

depth += 1
# Normally just an inspection of a descriptor object itself should be enough,
# and we should encounter the function as its attribute,
# but in case function was wrapped (e.g. functools.partial was used),
# we want to dive at least one level deeper.
if depth > 2:
return None

for value in _safe_get_attributes(obj):
if isinstance(value, types.FunctionType):
yield inspect.unwrap(value)
return
yield from _find_inner_functions(value, seen, depth)


def _create_slots(defined_fields, inherited_slots, field_names, weakref_slot):
# The slots for our class. Remove slots from our base classes. Add
# '__weakref__' if weakref_slot was given, unless it is already present.
Expand Down Expand Up @@ -1317,19 +1353,41 @@ def _add_slots(cls, is_frozen, weakref_slot, defined_fields):
# (the newly created one, which we're returning) and not the
# original class. We can break out of this loop as soon as we
# make an update, since all closures for a class will share a
# given cell.
# given cell. First we try to find a pure function or a property,
# and then fallback to inspecting custom descriptors
# if no pure function or property is found.

custom_descriptors_to_check = []
for member in newcls.__dict__.values():
# If this is a wrapped function, unwrap it.
member = inspect.unwrap(member)

if isinstance(member, types.FunctionType):
if _update_func_cell_for__class__(member, cls, newcls):
break
elif isinstance(member, property):
if (_update_func_cell_for__class__(member.fget, cls, newcls)
or _update_func_cell_for__class__(member.fset, cls, newcls)
or _update_func_cell_for__class__(member.fdel, cls, newcls)):
break
elif isinstance(member, property) and (
any(
# Unwrap once more in case function
# was wrapped before it became property.
_update_func_cell_for__class__(inspect.unwrap(f), cls, newcls)
for f in (member.fget, member.fset, member.fdel)
if f is not None
)
):
break
elif hasattr(member, "__get__") and not inspect.ismemberdescriptor(
member
):
# We don't want to inspect custom descriptors just yet
# there's still a chance we'll encounter a pure function
# or a property and won't have to use slower recursive search.
custom_descriptors_to_check.append(member)
else:
# Now let's ensure custom descriptors won't be left out.
for descriptor in custom_descriptors_to_check:
for f in _find_inner_functions(descriptor):
if _update_func_cell_for__class__(f, cls, newcls):
break

return newcls

Expand Down
150 changes: 150 additions & 0 deletions Lib/test/test_dataclasses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import weakref
import traceback
import unittest
from functools import partial, update_wrapper
from unittest.mock import Mock
from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar, Optional, Protocol, DefaultDict
from typing import get_type_hints
Expand Down Expand Up @@ -5031,6 +5032,155 @@ def foo(self):

A().foo()

def test_wrapped_property(self):
def mydecorator(f):
@wraps(f)
def wrapper(*args, **kwargs):
return f(*args, **kwargs)
return wrapper

class B:
@property
def foo(self):
return "bar"

@dataclass(slots=True)
class A(B):
@property
@mydecorator
def foo(self):
return super().foo

self.assertEqual(A().foo, "bar")

def test_custom_descriptor(self):
class CustomDescriptor:
def __init__(self, f):
self._f = f

def __get__(self, instance, owner):
return self._f(instance)

class B:
def foo(self):
return "bar"

@dataclass(slots=True)
class A(B):
@CustomDescriptor
def foo(cls):
return super().foo()

self.assertEqual(A().foo, "bar")

def test_custom_descriptor_wrapped(self):
class CustomDescriptor:
def __init__(self, f):
self._f = update_wrapper(lambda *args, **kwargs: f(*args, **kwargs), f)

def __get__(self, instance, owner):
return self._f(instance)

class B:
def foo(self):
return "bar"

@dataclass(slots=True)
class A(B):
@CustomDescriptor
def foo(cls):
return super().foo()

self.assertEqual(A().foo, "bar")

def test_custom_nested_descriptor(self):
class CustomFunctionWrapper:
def __init__(self, f):
self._f = f

def __call__(self, *args, **kwargs):
return self._f(*args, **kwargs)

class CustomDescriptor:
def __init__(self, f):
self._wrapper = CustomFunctionWrapper(f)

def __get__(self, instance, owner):
return self._wrapper(instance)

class B:
def foo(self):
return "bar"

@dataclass(slots=True)
class A(B):
@CustomDescriptor
def foo(cls):
return super().foo()

self.assertEqual(A().foo, "bar")

def test_custom_nested_descriptor_with_partial(self):
class CustomDescriptor:
def __init__(self, f):
self._wrapper = partial(f, value="bar")

def __get__(self, instance, owner):
return self._wrapper(instance)

class B:
def foo(self, value):
return value

@dataclass(slots=True)
class A(B):
@CustomDescriptor
def foo(self, value):
return super().foo(value)

self.assertEqual(A().foo, "bar")

def test_custom_too_nested_descriptor(self):
class UnnecessaryNestedWrapper:
def __init__(self, wrapper):
self._wrapper = wrapper

def __call__(self, *args, **kwargs):
return self._wrapper(*args, **kwargs)

class CustomFunctionWrapper:
def __init__(self, f):
self._f = f

def __call__(self, *args, **kwargs):
return self._f(*args, **kwargs)

class CustomDescriptor:
def __init__(self, f):
self._wrapper = UnnecessaryNestedWrapper(CustomFunctionWrapper(f))

def __get__(self, instance, owner):
return self._wrapper(instance)

class B:
def foo(self):
return "bar"

@dataclass(slots=True)
class A(B):
@CustomDescriptor
def foo(cls):
return super().foo()

with self.assertRaises(TypeError) as context:
A().foo

expected_error_message = (
'super(type, obj): obj (instance of A) is not '
'an instance or subtype of type (A).'
)
self.assertEqual(context.exception.args, (expected_error_message,))

def test_remembered_class(self):
# Apply the dataclass decorator manually (not when the class
# is created), so that we can keep a reference to the
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Modify dataclasses to enable zero argument support for ``super()`` when ``slots=True`` is
specified and custom descriptor is used or ``property`` function is wrapped.
Loading