Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 25e479a

Browse files
committed
Fix incorrect dims in CompositeAffine [skip ci]
1 parent 2804271 commit 25e479a

File tree

2 files changed

+11
-5
lines changed

2 files changed

+11
-5
lines changed

lib/matplotlib/tests/test_transforms.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -437,7 +437,6 @@ def test_rotate_around(self):
437437
assert_array_almost_equal(r90[2].transform(self.multiple_points), [
438438
[2, 2, 0], [-1, 0, 0], [2, 0, 4], [-3, 5, 0], [-4, 6, 6]])
439439

440-
441440
r_pi = [Affine3D().rotate_around(*self.pivot, np.pi, dim) for dim in range(3)]
442441
r180 = [Affine3D().rotate_deg_around(*self.pivot, 180, dim) for dim in range(3)]
443442

@@ -448,7 +447,6 @@ def test_rotate_around(self):
448447
assert_array_almost_equal(r180[2].transform(self.multiple_points), [
449448
[0, 2, 0], [2, -1, 0], [2, 2, 4], [-3, -3, 0], [-4, -4, 6]])
450449

451-
452450
r_pi_3_2 = [Affine3D().rotate_around(*self.pivot, 3 * np.pi / 2, dim)
453451
for dim in range(3)]
454452
r270 = [Affine3D().rotate_deg_around(*self.pivot, 270, dim) for dim in range(3)]
@@ -472,6 +470,16 @@ def test_rotate_around(self):
472470
assert_array_almost_equal(
473471
(r90[dim] + r180[dim]).get_matrix(), r270[dim].get_matrix())
474472

473+
def test_scale(self):
474+
sx = Affine3D().scale(3, 1, 1)
475+
sy = Affine3D().scale(1, -2, 1)
476+
sz = Affine3D().scale(1, 1, 4)
477+
trans = Affine3D().scale(3, -2, 4)
478+
assert_array_equal((sx + sy + sz).get_matrix(), trans.get_matrix())
479+
assert_array_equal(trans.transform(self.single_point), [3, -2, 4])
480+
assert_array_equal(trans.transform(self.multiple_points), [
481+
[6, 0, 0], [0, -6, 0], [0, 0, 16], [15, -10, 0], [18, -12, 24]])
482+
475483

476484
def test_non_affine_caching():
477485
class AssertingNonAffineTransform(mtransforms.Transform):

lib/matplotlib/transforms.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2833,10 +2833,8 @@ def __init__(self, a, b, **kwargs):
28332833
if a.output_dims != b.input_dims:
28342834
raise ValueError("The output dimension of 'a' must be equal to "
28352835
"the input dimensions of 'b'")
2836-
self.input_dims = a.input_dims
2837-
self.output_dims = b.output_dims
2836+
super().__init__(dims=a.output_dims, **kwargs)
28382837

2839-
super().__init__(**kwargs)
28402838
self._a = a
28412839
self._b = b
28422840
self.set_children(a, b)

0 commit comments

Comments
 (0)