Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit e62d61e

Browse files
committed
Make all transforms copiable (and thus scales, too).
1 parent ee2bee3 commit e62d61e

File tree

3 files changed

+43
-5
lines changed

3 files changed

+43
-5
lines changed

lib/matplotlib/tests/test_colors.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import matplotlib.colorbar as mcolorbar
1717
import matplotlib.cbook as cbook
1818
import matplotlib.pyplot as plt
19+
import matplotlib.scale as mscale
1920
from matplotlib.testing.decorators import image_comparison
2021

2122

@@ -1320,3 +1321,16 @@ def test_2d_to_rgba():
13201321
rgba_1d = mcolors.to_rgba(color.reshape(-1))
13211322
rgba_2d = mcolors.to_rgba(color.reshape((1, -1)))
13221323
assert rgba_1d == rgba_2d
1324+
1325+
1326+
def test_norm_deepcopy():
1327+
norm = mcolors.LogNorm()
1328+
norm.vmin = 0.0002
1329+
norm2 = copy.deepcopy(norm)
1330+
assert norm2.vmin == norm.vmin
1331+
assert isinstance(norm2._scale, mscale.LogScale)
1332+
norm = mcolors.Normalize()
1333+
norm.vmin = 0.0002
1334+
norm2 = copy.deepcopy(norm)
1335+
assert isinstance(norm2._scale, mscale.LinearScale)
1336+
assert norm2.vmin == norm.vmin

lib/matplotlib/tests/test_scale.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import copy
2+
13
import matplotlib.pyplot as plt
24
from matplotlib.scale import (
35
LogTransform, InvertedLogTransform,
@@ -210,3 +212,9 @@ def test_pass_scale():
210212
ax.set_yscale(scale)
211213
assert ax.xaxis.get_scale() == 'log'
212214
assert ax.yaxis.get_scale() == 'log'
215+
216+
217+
def test_scale_deepcopy():
218+
sc = mscale.LogScale(axis='x', base=10)
219+
sc2 = copy.deepcopy(sc)
220+
assert str(sc.get_transform()) == str(sc2.get_transform())

lib/matplotlib/transforms.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
# `np.minimum` instead of the builtin `min`, and likewise for `max`. This is
3434
# done so that `nan`s are propagated, instead of being silently dropped.
3535

36+
import copy
3637
import functools
3738
import textwrap
3839
import weakref
@@ -139,11 +140,26 @@ def __setstate__(self, data_dict):
139140
k: weakref.ref(v, lambda _, pop=self._parents.pop, k=k: pop(k))
140141
for k, v in self._parents.items() if v is not None}
141142

142-
def __copy__(self, *args):
143-
raise NotImplementedError(
144-
"TransformNode instances can not be copied. "
145-
"Consider using frozen() instead.")
146-
__deepcopy__ = __copy__
143+
def __copy__(self):
144+
other = copy.copy(super())
145+
# If `c = a + b; a1 = copy(a)`, then modifications to `a1` do not
146+
# propagate back to `c`, i.e. we need to clear the parents of `a1`.
147+
other._parents = {}
148+
# If `c = a + b; c1 = copy(c)`, then modifications to `a` also need to
149+
# be propagated to `c1`.
150+
for key, val in vars(self).items():
151+
if isinstance(val, TransformNode) and id(self) in val._parents:
152+
other.set_children(val) # val == getattr(other, key)
153+
return other
154+
155+
def __deepcopy__(self, memo):
156+
other = copy.deepcopy(super(), memo)
157+
# If `c = a + b; a1 = copy(a)`, then modifications to `a1` do not
158+
# propagate back to `c`, i.e. we need to clear the parents of `a1`.
159+
other._parents = {}
160+
# If `c = a + b; c1 = copy(c)`, this creates a separate tree
161+
# (`c1 = a1 + b1`) so nothing needs to be done.
162+
return other
147163

148164
def invalidate(self):
149165
"""

0 commit comments

Comments
 (0)