Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit cba1283

Browse files
mdboomtacaswell
authored andcommitted
Merge pull request #4915 from QuLogic/transformwrapper-pickles
TransformWrapper pickling fixes
1 parent e99f1d5 commit cba1283

File tree

2 files changed

+43
-4
lines changed

2 files changed

+43
-4
lines changed

lib/matplotlib/tests/test_pickle.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
from matplotlib.testing.decorators import cleanup, image_comparison
1414
import matplotlib.pyplot as plt
15+
import matplotlib.transforms as mtransforms
1516

1617

1718
def depth_getter(obj,
@@ -252,6 +253,34 @@ def test_polar():
252253
plt.draw()
253254

254255

256+
class TransformBlob(object):
257+
def __init__(self):
258+
self.identity = mtransforms.IdentityTransform()
259+
self.identity2 = mtransforms.IdentityTransform()
260+
# Force use of the more complex composition.
261+
self.composite = mtransforms.CompositeGenericTransform(
262+
self.identity,
263+
self.identity2)
264+
# Check parent -> child links of TransformWrapper.
265+
self.wrapper = mtransforms.TransformWrapper(self.composite)
266+
# Check child -> parent links of TransformWrapper.
267+
self.composite2 = mtransforms.CompositeGenericTransform(
268+
self.wrapper,
269+
self.identity)
270+
271+
272+
def test_transform():
273+
obj = TransformBlob()
274+
pf = pickle.dumps(obj)
275+
del obj
276+
277+
obj = pickle.loads(pf)
278+
# Check parent -> child links of TransformWrapper.
279+
assert_equal(obj.wrapper._child, obj.composite)
280+
# Check child -> parent links of TransformWrapper.
281+
assert_equal(list(obj.wrapper._parents.values()), [obj.composite2])
282+
283+
255284
if __name__ == '__main__':
256285
import nose
257286
nose.runmodule(argv=['-s'])

lib/matplotlib/transforms.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1534,6 +1534,10 @@ def __init__(self, child):
15341534
msg = ("'child' must be an instance of"
15351535
" 'matplotlib.transform.Transform'")
15361536
raise ValueError(msg)
1537+
self._init(child)
1538+
self.set_children(child)
1539+
1540+
def _init(self, child):
15371541
Transform.__init__(self)
15381542
self.input_dims = child.input_dims
15391543
self.output_dims = child.output_dims
@@ -1549,12 +1553,18 @@ def __str__(self):
15491553
return str(self._child)
15501554

15511555
def __getstate__(self):
1552-
# only store the child
1553-
return {'child': self._child}
1556+
# only store the child and parents
1557+
return {
1558+
'child': self._child,
1559+
# turn the weakkey dictionary into a normal dictionary
1560+
'parents': dict(six.iteritems(self._parents))
1561+
}
15541562

15551563
def __setstate__(self, state):
15561564
# re-initialise the TransformWrapper with the state's child
1557-
self.__init__(state['child'])
1565+
self._init(state['child'])
1566+
# turn the normal dictionary back into a WeakValueDictionary
1567+
self._parents = WeakValueDictionary(state['parents'])
15581568

15591569
def __repr__(self):
15601570
return "TransformWrapper(%r)" % self._child
@@ -1565,7 +1575,6 @@ def frozen(self):
15651575

15661576
def _set(self, child):
15671577
self._child = child
1568-
self.set_children(child)
15691578

15701579
self.transform = child.transform
15711580
self.transform_affine = child.transform_affine
@@ -1594,6 +1603,7 @@ def set(self, child):
15941603
" output dimensions as the current child.")
15951604
raise ValueError(msg)
15961605

1606+
self.set_children(child)
15971607
self._set(child)
15981608

15991609
self._invalid = 0

0 commit comments

Comments
 (0)