Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit c755758

Browse files
committed
Support all the new stuff supported by the new pickle code:
- subclasses of list or dict - __reduce__ returning a 4-tuple or 5-tuple - slots
1 parent 0189266 commit c755758

2 files changed

Lines changed: 109 additions & 13 deletions

File tree

Lib/copy.py

Lines changed: 64 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
x = copy.copy(y) # make a shallow copy of y
88
x = copy.deepcopy(y) # make a deep copy of y
99
10-
For module specific errors, copy.error is raised.
10+
For module specific errors, copy.Error is raised.
1111
1212
The difference between shallow and deep copying is only relevant for
1313
compound objects (objects that contain other objects, like lists or
@@ -51,6 +51,7 @@ class instances).
5151
# XXX need to support copy_reg here too...
5252

5353
import types
54+
from pickle import _slotnames
5455

5556
class Error(Exception):
5657
pass
@@ -61,7 +62,7 @@ class Error(Exception):
6162
except ImportError:
6263
PyStringMap = None
6364

64-
__all__ = ["Error", "error", "copy", "deepcopy"]
65+
__all__ = ["Error", "copy", "deepcopy"]
6566

6667
def copy(x):
6768
"""Shallow copy operation on arbitrary Python objects.
@@ -76,18 +77,60 @@ def copy(x):
7677
copier = x.__copy__
7778
except AttributeError:
7879
try:
79-
reductor = x.__reduce__
80+
reductor = x.__class__.__reduce__
81+
if reductor == object.__reduce__:
82+
reductor = _better_reduce
8083
except AttributeError:
81-
raise error, \
82-
"un(shallow)copyable object of type %s" % type(x)
84+
raise Error("un(shallow)copyable object of type %s" % type(x))
8385
else:
84-
y = _reconstruct(x, reductor(), 0)
86+
y = _reconstruct(x, reductor(x), 0)
8587
else:
8688
y = copier()
8789
else:
8890
y = copierfunction(x)
8991
return y
9092

93+
def __newobj__(cls, *args):
94+
return cls.__new__(cls, *args)
95+
96+
def _better_reduce(obj):
97+
cls = obj.__class__
98+
getnewargs = getattr(obj, "__getnewargs__", None)
99+
if getnewargs:
100+
args = getnewargs()
101+
else:
102+
args = ()
103+
getstate = getattr(obj, "__getstate__", None)
104+
if getstate:
105+
try:
106+
state = getstate()
107+
except TypeError, err:
108+
# XXX Catch generic exception caused by __slots__
109+
if str(err) != ("a class that defines __slots__ "
110+
"without defining __getstate__ "
111+
"cannot be pickled"):
112+
raise # Not that specific exception
113+
getstate = None
114+
if not getstate:
115+
state = getattr(obj, "__dict__", None)
116+
names = _slotnames(cls)
117+
if names:
118+
slots = {}
119+
nil = []
120+
for name in names:
121+
value = getattr(obj, name, nil)
122+
if value is not nil:
123+
slots[name] = value
124+
if slots:
125+
state = (state, slots)
126+
listitems = dictitems = None
127+
if isinstance(obj, list):
128+
listitems = iter(obj)
129+
elif isinstance(obj, dict):
130+
dictitems = obj.iteritems()
131+
return __newobj__, (cls, args), state, listitems, dictitems
132+
133+
91134
_copy_dispatch = d = {}
92135

93136
def _copy_atomic(x):
@@ -175,12 +218,14 @@ def deepcopy(x, memo = None):
175218
copier = x.__deepcopy__
176219
except AttributeError:
177220
try:
178-
reductor = x.__reduce__
221+
reductor = x.__class__.__reduce__
222+
if reductor == object.__reduce__:
223+
reductor = _better_reduce
179224
except AttributeError:
180-
raise error, \
181-
"un-deep-copyable object of type %s" % type(x)
225+
raise Error("un(shallow)copyable object of type %s" %
226+
type(x))
182227
else:
183-
y = _reconstruct(x, reductor(), 1, memo)
228+
y = _reconstruct(x, reductor(x), 1, memo)
184229
else:
185230
y = copier(memo)
186231
else:
@@ -331,7 +376,15 @@ def _reconstruct(x, info, deep, memo=None):
331376
if hasattr(y, '__setstate__'):
332377
y.__setstate__(state)
333378
else:
334-
y.__dict__.update(state)
379+
if isinstance(state, tuple) and len(state) == 2:
380+
state, slotstate = state
381+
else:
382+
slotstate = None
383+
if state is not None:
384+
y.__dict__.update(state)
385+
if slotstate is not None:
386+
for key, value in slotstate.iteritems():
387+
setattr(y, key, value)
335388
return y
336389

337390
del d

Lib/test/test_copy.py

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,13 @@ def __reduce__(self):
4141
self.assert_(y is x)
4242

4343
def test_copy_cant(self):
44-
class C(object):
44+
class Meta(type):
4545
def __getattribute__(self, name):
4646
if name == "__reduce__":
4747
raise AttributeError, name
4848
return object.__getattribute__(self, name)
49+
class C:
50+
__metaclass__ = Meta
4951
x = C()
5052
self.assertRaises(copy.Error, copy.copy, x)
5153

@@ -189,11 +191,13 @@ def __reduce__(self):
189191
self.assert_(y is x)
190192

191193
def test_deepcopy_cant(self):
192-
class C(object):
194+
class Meta(type):
193195
def __getattribute__(self, name):
194196
if name == "__reduce__":
195197
raise AttributeError, name
196198
return object.__getattribute__(self, name)
199+
class C:
200+
__metaclass__ = Meta
197201
x = C()
198202
self.assertRaises(copy.Error, copy.deepcopy, x)
199203

@@ -411,6 +415,45 @@ def __cmp__(self, other):
411415
self.assert_(x is not y)
412416
self.assert_(x["foo"] is not y["foo"])
413417

418+
def test_copy_slots(self):
419+
class C(object):
420+
__slots__ = ["foo"]
421+
x = C()
422+
x.foo = [42]
423+
y = copy.copy(x)
424+
self.assert_(x.foo is y.foo)
425+
426+
def test_deepcopy_slots(self):
427+
class C(object):
428+
__slots__ = ["foo"]
429+
x = C()
430+
x.foo = [42]
431+
y = copy.deepcopy(x)
432+
self.assertEqual(x.foo, y.foo)
433+
self.assert_(x.foo is not y.foo)
434+
435+
def test_copy_list_subclass(self):
436+
class C(list):
437+
pass
438+
x = C([[1, 2], 3])
439+
x.foo = [4, 5]
440+
y = copy.copy(x)
441+
self.assertEqual(list(x), list(y))
442+
self.assertEqual(x.foo, y.foo)
443+
self.assert_(x[0] is y[0])
444+
self.assert_(x.foo is y.foo)
445+
446+
def test_deepcopy_list_subclass(self):
447+
class C(list):
448+
pass
449+
x = C([[1, 2], 3])
450+
x.foo = [4, 5]
451+
y = copy.deepcopy(x)
452+
self.assertEqual(list(x), list(y))
453+
self.assertEqual(x.foo, y.foo)
454+
self.assert_(x[0] is not y[0])
455+
self.assert_(x.foo is not y.foo)
456+
414457
def test_main():
415458
suite = unittest.TestSuite()
416459
suite.addTest(unittest.makeSuite(TestCopy))

0 commit comments

Comments
 (0)