Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Make all transforms copiable (and thus scales, too). #19281

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions lib/matplotlib/tests/test_colors.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import matplotlib.colorbar as mcolorbar
import matplotlib.cbook as cbook
import matplotlib.pyplot as plt
import matplotlib.scale as mscale
from matplotlib.testing.decorators import image_comparison


Expand Down Expand Up @@ -1320,3 +1321,17 @@ def test_2d_to_rgba():
rgba_1d = mcolors.to_rgba(color.reshape(-1))
rgba_2d = mcolors.to_rgba(color.reshape((1, -1)))
assert rgba_1d == rgba_2d


def test_norm_deepcopy():
norm = mcolors.LogNorm()
norm.vmin = 0.0002
norm2 = copy.deepcopy(norm)
assert norm2.vmin == norm.vmin
assert isinstance(norm2._scale, mscale.LogScale)
norm = mcolors.Normalize()
norm.vmin = 0.0002
norm2 = copy.deepcopy(norm)
assert isinstance(norm2._scale, mscale.LinearScale)
assert norm2.vmin == norm.vmin
assert norm2._scale is not norm._scale
9 changes: 9 additions & 0 deletions lib/matplotlib/tests/test_scale.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import copy

import matplotlib.pyplot as plt
from matplotlib.scale import (
LogTransform, InvertedLogTransform,
Expand Down Expand Up @@ -210,3 +212,10 @@ def test_pass_scale():
ax.set_yscale(scale)
assert ax.xaxis.get_scale() == 'log'
assert ax.yaxis.get_scale() == 'log'


def test_scale_deepcopy():
sc = mscale.LogScale(axis='x', base=10)
sc2 = copy.deepcopy(sc)
assert str(sc.get_transform()) == str(sc2.get_transform())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As a rough check that there is no shared state, I'd add

assert sc._transform is not sc2._transform

or, if you're uncomfortable with testing private attributes:

assert sc.get_transform() is not sc2.get_transform()

But personally, I'd test the internal state here, because get_transform() could be written to return copies, so that testing get_transform() include an additional assumption.

Copy link
Contributor Author

@anntzer anntzer Jan 17, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes (although this revealed the need to separately implement deepcopy for transforms that just return self in frozen()).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍 The test found an implementation flaw. It's working 😄

assert sc._transform is not sc2._transform
40 changes: 40 additions & 0 deletions lib/matplotlib/tests/test_transforms.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import copy

import numpy as np
from numpy.testing import (assert_allclose, assert_almost_equal,
assert_array_equal, assert_array_almost_equal)
Expand Down Expand Up @@ -696,3 +698,41 @@ def test_lockable_bbox(locked_element):
assert getattr(locked, 'locked_' + locked_element) == 3
for elem in other_elements:
assert getattr(locked, elem) == getattr(orig, elem)


def test_copy():
a = mtransforms.Affine2D()
b = mtransforms.Affine2D()
s = a + b
# Updating a dependee should invalidate a copy of the dependent.
s.get_matrix() # resolve it.
s1 = copy.copy(s)
assert not s._invalid and not s1._invalid
a.translate(1, 2)
assert s._invalid and s1._invalid
assert (s1.get_matrix() == a.get_matrix()).all()
# Updating a copy of a dependee shouldn't invalidate a dependent.
s.get_matrix() # resolve it.
b1 = copy.copy(b)
b1.translate(3, 4)
assert not s._invalid
assert (s.get_matrix() == a.get_matrix()).all()


def test_deepcopy():
a = mtransforms.Affine2D()
b = mtransforms.Affine2D()
s = a + b
# Updating a dependee shouldn't invalidate a deepcopy of the dependent.
s.get_matrix() # resolve it.
s1 = copy.deepcopy(s)
assert not s._invalid and not s1._invalid
a.translate(1, 2)
assert s._invalid and not s1._invalid
assert (s1.get_matrix() == mtransforms.Affine2D().get_matrix()).all()
# Updating a deepcopy of a dependee shouldn't invalidate a dependent.
s.get_matrix() # resolve it.
b1 = copy.deepcopy(b)
b1.translate(3, 4)
assert not s._invalid
assert (s.get_matrix() == a.get_matrix()).all()
33 changes: 28 additions & 5 deletions lib/matplotlib/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
# `np.minimum` instead of the builtin `min`, and likewise for `max`. This is
# done so that `nan`s are propagated, instead of being silently dropped.

import copy
import functools
import textwrap
import weakref
Expand Down Expand Up @@ -139,11 +140,33 @@ def __setstate__(self, data_dict):
k: weakref.ref(v, lambda _, pop=self._parents.pop, k=k: pop(k))
for k, v in self._parents.items() if v is not None}

def __copy__(self, *args):
raise NotImplementedError(
"TransformNode instances can not be copied. "
"Consider using frozen() instead.")
__deepcopy__ = __copy__
def __copy__(self):
other = copy.copy(super())
# If `c = a + b; a1 = copy(a)`, then modifications to `a1` do not
# propagate back to `c`, i.e. we need to clear the parents of `a1`.
other._parents = {}
# If `c = a + b; c1 = copy(c)`, then modifications to `a` also need to
# be propagated to `c1`.
for key, val in vars(self).items():
if isinstance(val, TransformNode) and id(self) in val._parents:
other.set_children(val) # val == getattr(other, key)
return other

def __deepcopy__(self, memo):
# We could deepcopy the entire transform tree, but nothing except
# `self` is accessible publicly, so we may as well just freeze `self`.
other = self.frozen()
if other is not self:
return other
# Some classes implement frozen() as returning self, which is not
# acceptable for deepcopying, so we need to handle them separately.
other = copy.deepcopy(super(), memo)
# If `c = a + b; a1 = copy(a)`, then modifications to `a1` do not
# propagate back to `c`, i.e. we need to clear the parents of `a1`.
other._parents = {}
# If `c = a + b; c1 = copy(c)`, this creates a separate tree
# (`c1 = a1 + b1`) so nothing needs to be done.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, so with a deepcopy the tranform is frozen, or at least to the point where someone would need to dig out the children? Does the copy.deepcopy above really copy the children over?

return other

def invalidate(self):
"""
Expand Down