diff --git a/lib/matplotlib/tests/test_colors.py b/lib/matplotlib/tests/test_colors.py index c17edfa1bf0a..f40f6d20e846 100644 --- a/lib/matplotlib/tests/test_colors.py +++ b/lib/matplotlib/tests/test_colors.py @@ -16,6 +16,7 @@ import matplotlib.colorbar as mcolorbar import matplotlib.cbook as cbook import matplotlib.pyplot as plt +import matplotlib.scale as mscale from matplotlib.testing.decorators import image_comparison @@ -1320,3 +1321,17 @@ def test_2d_to_rgba(): rgba_1d = mcolors.to_rgba(color.reshape(-1)) rgba_2d = mcolors.to_rgba(color.reshape((1, -1))) assert rgba_1d == rgba_2d + + +def test_norm_deepcopy(): + norm = mcolors.LogNorm() + norm.vmin = 0.0002 + norm2 = copy.deepcopy(norm) + assert norm2.vmin == norm.vmin + assert isinstance(norm2._scale, mscale.LogScale) + norm = mcolors.Normalize() + norm.vmin = 0.0002 + norm2 = copy.deepcopy(norm) + assert isinstance(norm2._scale, mscale.LinearScale) + assert norm2.vmin == norm.vmin + assert norm2._scale is not norm._scale diff --git a/lib/matplotlib/tests/test_scale.py b/lib/matplotlib/tests/test_scale.py index e4577a6b017c..8fba86d2e82e 100644 --- a/lib/matplotlib/tests/test_scale.py +++ b/lib/matplotlib/tests/test_scale.py @@ -1,3 +1,5 @@ +import copy + import matplotlib.pyplot as plt from matplotlib.scale import ( LogTransform, InvertedLogTransform, @@ -210,3 +212,10 @@ def test_pass_scale(): ax.set_yscale(scale) assert ax.xaxis.get_scale() == 'log' assert ax.yaxis.get_scale() == 'log' + + +def test_scale_deepcopy(): + sc = mscale.LogScale(axis='x', base=10) + sc2 = copy.deepcopy(sc) + assert str(sc.get_transform()) == str(sc2.get_transform()) + assert sc._transform is not sc2._transform diff --git a/lib/matplotlib/tests/test_transforms.py b/lib/matplotlib/tests/test_transforms.py index a7ddbf770d7e..ad572fe5287a 100644 --- a/lib/matplotlib/tests/test_transforms.py +++ b/lib/matplotlib/tests/test_transforms.py @@ -1,3 +1,5 @@ +import copy + import numpy as np from numpy.testing import (assert_allclose, assert_almost_equal, assert_array_equal, assert_array_almost_equal) @@ -696,3 +698,41 @@ def test_lockable_bbox(locked_element): assert getattr(locked, 'locked_' + locked_element) == 3 for elem in other_elements: assert getattr(locked, elem) == getattr(orig, elem) + + +def test_copy(): + a = mtransforms.Affine2D() + b = mtransforms.Affine2D() + s = a + b + # Updating a dependee should invalidate a copy of the dependent. + s.get_matrix() # resolve it. + s1 = copy.copy(s) + assert not s._invalid and not s1._invalid + a.translate(1, 2) + assert s._invalid and s1._invalid + assert (s1.get_matrix() == a.get_matrix()).all() + # Updating a copy of a dependee shouldn't invalidate a dependent. + s.get_matrix() # resolve it. + b1 = copy.copy(b) + b1.translate(3, 4) + assert not s._invalid + assert (s.get_matrix() == a.get_matrix()).all() + + +def test_deepcopy(): + a = mtransforms.Affine2D() + b = mtransforms.Affine2D() + s = a + b + # Updating a dependee shouldn't invalidate a deepcopy of the dependent. + s.get_matrix() # resolve it. + s1 = copy.deepcopy(s) + assert not s._invalid and not s1._invalid + a.translate(1, 2) + assert s._invalid and not s1._invalid + assert (s1.get_matrix() == mtransforms.Affine2D().get_matrix()).all() + # Updating a deepcopy of a dependee shouldn't invalidate a dependent. + s.get_matrix() # resolve it. + b1 = copy.deepcopy(b) + b1.translate(3, 4) + assert not s._invalid + assert (s.get_matrix() == a.get_matrix()).all() diff --git a/lib/matplotlib/transforms.py b/lib/matplotlib/transforms.py index bfc2915872c4..b2491b8cf8cc 100644 --- a/lib/matplotlib/transforms.py +++ b/lib/matplotlib/transforms.py @@ -33,6 +33,7 @@ # `np.minimum` instead of the builtin `min`, and likewise for `max`. This is # done so that `nan`s are propagated, instead of being silently dropped. +import copy import functools import textwrap import weakref @@ -139,11 +140,33 @@ def __setstate__(self, data_dict): k: weakref.ref(v, lambda _, pop=self._parents.pop, k=k: pop(k)) for k, v in self._parents.items() if v is not None} - def __copy__(self, *args): - raise NotImplementedError( - "TransformNode instances can not be copied. " - "Consider using frozen() instead.") - __deepcopy__ = __copy__ + def __copy__(self): + other = copy.copy(super()) + # If `c = a + b; a1 = copy(a)`, then modifications to `a1` do not + # propagate back to `c`, i.e. we need to clear the parents of `a1`. + other._parents = {} + # If `c = a + b; c1 = copy(c)`, then modifications to `a` also need to + # be propagated to `c1`. + for key, val in vars(self).items(): + if isinstance(val, TransformNode) and id(self) in val._parents: + other.set_children(val) # val == getattr(other, key) + return other + + def __deepcopy__(self, memo): + # We could deepcopy the entire transform tree, but nothing except + # `self` is accessible publicly, so we may as well just freeze `self`. + other = self.frozen() + if other is not self: + return other + # Some classes implement frozen() as returning self, which is not + # acceptable for deepcopying, so we need to handle them separately. + other = copy.deepcopy(super(), memo) + # If `c = a + b; a1 = copy(a)`, then modifications to `a1` do not + # propagate back to `c`, i.e. we need to clear the parents of `a1`. + other._parents = {} + # If `c = a + b; c1 = copy(c)`, this creates a separate tree + # (`c1 = a1 + b1`) so nothing needs to be done. + return other def invalidate(self): """