Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 3a63e5e

Browse files
committed
Make all transforms copiable (and thus scales, too).
1 parent ee2bee3 commit 3a63e5e

File tree

4 files changed

+75
-5
lines changed

4 files changed

+75
-5
lines changed

lib/matplotlib/tests/test_colors.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import matplotlib.colorbar as mcolorbar
1717
import matplotlib.cbook as cbook
1818
import matplotlib.pyplot as plt
19+
import matplotlib.scale as mscale
1920
from matplotlib.testing.decorators import image_comparison
2021

2122

@@ -1320,3 +1321,16 @@ def test_2d_to_rgba():
13201321
rgba_1d = mcolors.to_rgba(color.reshape(-1))
13211322
rgba_2d = mcolors.to_rgba(color.reshape((1, -1)))
13221323
assert rgba_1d == rgba_2d
1324+
1325+
1326+
def test_norm_deepcopy():
1327+
norm = mcolors.LogNorm()
1328+
norm.vmin = 0.0002
1329+
norm2 = copy.deepcopy(norm)
1330+
assert norm2.vmin == norm.vmin
1331+
assert isinstance(norm2._scale, mscale.LogScale)
1332+
norm = mcolors.Normalize()
1333+
norm.vmin = 0.0002
1334+
norm2 = copy.deepcopy(norm)
1335+
assert isinstance(norm2._scale, mscale.LinearScale)
1336+
assert norm2.vmin == norm.vmin

lib/matplotlib/tests/test_scale.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import copy
2+
13
import matplotlib.pyplot as plt
24
from matplotlib.scale import (
35
LogTransform, InvertedLogTransform,
@@ -210,3 +212,9 @@ def test_pass_scale():
210212
ax.set_yscale(scale)
211213
assert ax.xaxis.get_scale() == 'log'
212214
assert ax.yaxis.get_scale() == 'log'
215+
216+
217+
def test_scale_deepcopy():
218+
sc = mscale.LogScale(axis='x', base=10)
219+
sc2 = copy.deepcopy(sc)
220+
assert str(sc.get_transform()) == str(sc2.get_transform())

lib/matplotlib/tests/test_transforms.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import copy
2+
13
import numpy as np
24
from numpy.testing import (assert_allclose, assert_almost_equal,
35
assert_array_equal, assert_array_almost_equal)
@@ -696,3 +698,37 @@ def test_lockable_bbox(locked_element):
696698
assert getattr(locked, 'locked_' + locked_element) == 3
697699
for elem in other_elements:
698700
assert getattr(locked, elem) == getattr(orig, elem)
701+
702+
703+
def test_copy():
704+
a = mtransforms.Affine2D()
705+
b = mtransforms.Affine2D()
706+
s = a + b
707+
s.get_matrix() # resolve it.
708+
s1 = copy.copy(s)
709+
assert not s._invalid and not s1._invalid
710+
a.translate(1, 2)
711+
assert s._invalid and s1._invalid
712+
assert (s1.get_matrix() == a.get_matrix()).all()
713+
s.get_matrix() # resolve it.
714+
b1 = copy.copy(b)
715+
b1.translate(3, 4)
716+
assert not s._invalid
717+
assert (s.get_matrix() == a.get_matrix()).all()
718+
719+
720+
def test_deepcopy():
721+
a = mtransforms.Affine2D()
722+
b = mtransforms.Affine2D()
723+
s = a + b
724+
s.get_matrix() # resolve it.
725+
s1 = copy.deepcopy(s)
726+
assert not s._invalid and not s1._invalid
727+
a.translate(1, 2)
728+
assert s._invalid and not s1._invalid
729+
assert (s1.get_matrix() == mtransforms.Affine2D().get_matrix()).all()
730+
s.get_matrix() # resolve it.
731+
b1 = copy.deepcopy(b)
732+
b1.translate(3, 4)
733+
assert not s._invalid
734+
assert (s.get_matrix() == a.get_matrix()).all()

lib/matplotlib/transforms.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
# `np.minimum` instead of the builtin `min`, and likewise for `max`. This is
3434
# done so that `nan`s are propagated, instead of being silently dropped.
3535

36+
import copy
3637
import functools
3738
import textwrap
3839
import weakref
@@ -139,11 +140,22 @@ def __setstate__(self, data_dict):
139140
k: weakref.ref(v, lambda _, pop=self._parents.pop, k=k: pop(k))
140141
for k, v in self._parents.items() if v is not None}
141142

142-
def __copy__(self, *args):
143-
raise NotImplementedError(
144-
"TransformNode instances can not be copied. "
145-
"Consider using frozen() instead.")
146-
__deepcopy__ = __copy__
143+
def __copy__(self):
144+
other = copy.copy(super())
145+
# If `c = a + b; a1 = copy(a)`, then modifications to `a1` do not
146+
# propagate back to `c`, i.e. we need to clear the parents of `a1`.
147+
other._parents = {}
148+
# If `c = a + b; c1 = copy(c)`, then modifications to `a` also need to
149+
# be propagated to `c1`.
150+
for key, val in vars(self).items():
151+
if isinstance(val, TransformNode) and id(self) in val._parents:
152+
other.set_children(val) # val == getattr(other, key)
153+
return other
154+
155+
def __deepcopy__(self, memo):
156+
# We could deepcopy the entire transform tree, but nothing except
157+
# `self` is accessible publicly, so we may as well just freeze `self`.
158+
return self.frozen()
147159

148160
def invalidate(self):
149161
"""

0 commit comments

Comments
 (0)