Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 7cf95f1

Browse files
committed
Add ability to blend any number of transforms
1 parent 12e7d01 commit 7cf95f1

File tree

3 files changed

+116
-92
lines changed

3 files changed

+116
-92
lines changed

lib/matplotlib/tests/test_transforms.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -850,7 +850,7 @@ def test_str_transform():
850850
CompositeGenericTransform(
851851
CompositeGenericTransform(
852852
TransformWrapper(
853-
BlendedAffine2D(
853+
BlendedAffine(
854854
IdentityTransform(),
855855
IdentityTransform())),
856856
CompositeAffine2D(
@@ -864,7 +864,7 @@ def test_str_transform():
864864
CompositeGenericTransform(
865865
PolarAffine(
866866
TransformWrapper(
867-
BlendedAffine2D(
867+
BlendedAffine(
868868
IdentityTransform(),
869869
IdentityTransform())),
870870
LockableBbox(

lib/matplotlib/transforms.py

Lines changed: 106 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -2216,182 +2216,207 @@ def inverted(self):
22162216

22172217

22182218
class _BlendedMixin:
2219-
"""Common methods for `BlendedGenericTransform` and `BlendedAffine2D`."""
2219+
"""Common methods for `BlendedGenericTransform` and `BlendedAffine`."""
22202220

22212221
def __eq__(self, other):
2222-
if isinstance(other, (BlendedAffine2D, BlendedGenericTransform)):
2223-
return (self._x == other._x) and (self._y == other._y)
2224-
elif self._x == self._y:
2225-
return self._x == other
2222+
num_transforms = len(self._transforms)
2223+
2224+
if (isinstance(other, (BlendedGenericTransform, BlendedAffine))
2225+
and num_transforms == len(other._transforms)):
2226+
return all(self._transforms[i] == other._transforms[i]
2227+
for i in range(num_transforms))
22262228
else:
22272229
return NotImplemented
22282230

22292231
def contains_branch_seperately(self, transform):
2230-
return (self._x.contains_branch(transform),
2231-
self._y.contains_branch(transform))
2232+
return tuple(branch.contains_branch(transform) for branch in self._transforms)
22322233

2233-
__str__ = _make_str_method("_x", "_y")
2234+
def __str__(self):
2235+
indent = functools.partial(textwrap.indent, prefix=" " * 4)
2236+
return (
2237+
type(self).__name__ + "("
2238+
+ ",".join([*(indent("\n" + transform.__str__())
2239+
for transform in self._transforms)])
2240+
+ ")")
22342241

22352242

22362243
class BlendedGenericTransform(_BlendedMixin, Transform):
22372244
"""
2238-
A "blended" transform uses one transform for the *x*-direction, and
2239-
another transform for the *y*-direction.
2245+
A "blended" transform uses one transform for each direction
22402246
2241-
This "generic" version can handle any given child transform in the
2242-
*x*- and *y*-directions.
2247+
This "generic" version can handle any number of given child transforms, each
2248+
handling a different axis.
22432249
"""
2244-
input_dims = 2
2245-
output_dims = 2
22462250
is_separable = True
22472251
pass_through = True
22482252

2249-
def __init__(self, x_transform, y_transform, **kwargs):
2253+
def __init__(self, *args, **kwargs):
22502254
"""
2251-
Create a new "blended" transform using *x_transform* to transform the
2252-
*x*-axis and *y_transform* to transform the *y*-axis.
2255+
Create a new "blended" transform, with the first argument providing
2256+
a transform for the *x*-axis, the second argument providing a transform
2257+
for the *y*-axis, etc.
22532258
22542259
You will generally not call this constructor directly but use the
22552260
`blended_transform_factory` function instead, which can determine
22562261
automatically which kind of blended transform to create.
22572262
"""
2263+
self.input_dims = self.output_dims = len(args)
2264+
2265+
for i in range(self.input_dims):
2266+
transform = args[i]
2267+
if transform.input_dims > 1 and transform.input_dims <= i:
2268+
raise TypeError("Invalid transform provided to"
2269+
"`BlendedGenericTransform`")
2270+
22582271
Transform.__init__(self, **kwargs)
2259-
self._x = x_transform
2260-
self._y = y_transform
2261-
self.set_children(x_transform, y_transform)
2272+
self.set_children(*args)
2273+
self._transforms = args
22622274
self._affine = None
22632275

22642276
@property
22652277
def depth(self):
2266-
return max(self._x.depth, self._y.depth)
2278+
return max(transform.depth for transform in self._transforms)
22672279

22682280
def contains_branch(self, other):
22692281
# A blended transform cannot possibly contain a branch from two
22702282
# different transforms.
22712283
return False
22722284

2273-
is_affine = property(lambda self: self._x.is_affine and self._y.is_affine)
2274-
has_inverse = property(
2275-
lambda self: self._x.has_inverse and self._y.has_inverse)
2285+
is_affine = property(lambda self: all(transform.is_affine
2286+
for transform in self._transforms))
2287+
has_inverse = property(lambda self: all(transform.has_inverse
2288+
for transform in self._transforms))
22762289

22772290
def frozen(self):
22782291
# docstring inherited
2279-
return blended_transform_factory(self._x.frozen(), self._y.frozen())
2292+
return blended_transform_factory(*(transform.frozen()
2293+
for transform in self._transforms))
22802294

22812295
@_api.rename_parameter("3.8", "points", "values")
22822296
def transform_non_affine(self, values):
22832297
# docstring inherited
2284-
if self._x.is_affine and self._y.is_affine:
2298+
if self.is_affine:
22852299
return values
2286-
x = self._x
2287-
y = self._y
22882300

2289-
if x == y and x.input_dims == 2:
2290-
return x.transform_non_affine(values)
2301+
if all(transform == self._transforms[0]
2302+
for transform in self._transforms) and self.input_dims >= 2:
2303+
return self._transforms[0].transform_non_affine(values)
22912304

2292-
if x.input_dims == 2:
2293-
x_points = x.transform_non_affine(values)[:, 0:1]
2294-
else:
2295-
x_points = x.transform_non_affine(values[:, 0])
2296-
x_points = x_points.reshape((len(x_points), 1))
2305+
all_points = []
2306+
masked = False
22972307

2298-
if y.input_dims == 2:
2299-
y_points = y.transform_non_affine(values)[:, 1:]
2300-
else:
2301-
y_points = y.transform_non_affine(values[:, 1])
2302-
y_points = y_points.reshape((len(y_points), 1))
2308+
for dim in range(self.input_dims):
2309+
transform = self._transforms[dim]
2310+
if transform.input_dims == 1:
2311+
points = transform.transform_non_affine(values[:, dim])
2312+
points = points.reshape((len(points), 1))
2313+
else:
2314+
points = transform.transform_non_affine(values)[:, dim:dim+1]
23032315

2304-
if (isinstance(x_points, np.ma.MaskedArray) or
2305-
isinstance(y_points, np.ma.MaskedArray)):
2306-
return np.ma.concatenate((x_points, y_points), 1)
2316+
masked = masked or isinstance(points, np.ma.MaskedArray)
2317+
all_points.append(points)
2318+
2319+
if masked:
2320+
return np.ma.concatenate(tuple(all_points), 1)
23072321
else:
2308-
return np.concatenate((x_points, y_points), 1)
2322+
return np.concatenate(tuple(all_points), 1)
23092323

23102324
def inverted(self):
23112325
# docstring inherited
2312-
return BlendedGenericTransform(self._x.inverted(), self._y.inverted())
2326+
return BlendedGenericTransform(*(transform.inverted()
2327+
for transform in self._transforms))
23132328

23142329
def get_affine(self):
23152330
# docstring inherited
23162331
if self._invalid or self._affine is None:
2317-
if self._x == self._y:
2318-
self._affine = self._x.get_affine()
2332+
if all(transform == self._transforms[0] for transform in self._transforms):
2333+
self._affine = self._transforms[0].get_affine()
23192334
else:
2320-
x_mtx = self._x.get_affine().get_matrix()
2321-
y_mtx = self._y.get_affine().get_matrix()
2322-
# We already know the transforms are separable, so we can skip
2323-
# setting b and c to zero.
2324-
mtx = np.array([x_mtx[0], y_mtx[1], [0.0, 0.0, 1.0]])
2325-
self._affine = Affine2D(mtx)
2335+
mtx = np.identity(self.input_dims + 1)
2336+
for i in range(self.input_dims):
2337+
transform = self._transforms[i]
2338+
if transform.output_dims > 1:
2339+
mtx[i] = transform.get_affine().get_matrix()[i]
2340+
2341+
self._affine = _affine_factory(mtx, dims=self.input_dims)
23262342
self._invalid = 0
23272343
return self._affine
23282344

23292345

2330-
class BlendedAffine2D(_BlendedMixin, Affine2DBase):
2346+
class BlendedAffine(_BlendedMixin, AffineImmutable):
23312347
"""
23322348
A "blended" transform uses one transform for the *x*-direction, and
23332349
another transform for the *y*-direction.
23342350
23352351
This version is an optimization for the case where both child
2336-
transforms are of type `Affine2DBase`.
2352+
transforms are of type `AffineImmutable`.
23372353
"""
23382354

23392355
is_separable = True
23402356

2341-
def __init__(self, x_transform, y_transform, **kwargs):
2357+
def __init__(self, *args, **kwargs):
23422358
"""
2343-
Create a new "blended" transform using *x_transform* to transform the
2344-
*x*-axis and *y_transform* to transform the *y*-axis.
2359+
Create a new "blended" transform, with the first argument providing
2360+
a transform for the *x*-axis, the second argument providing a transform
2361+
for the *y*-axis, etc.
23452362
2346-
Both *x_transform* and *y_transform* must be 2D affine transforms.
2363+
All provided transforms must be affine transforms.
23472364
23482365
You will generally not call this constructor directly but use the
23492366
`blended_transform_factory` function instead, which can determine
23502367
automatically which kind of blended transform to create.
23512368
"""
2352-
is_affine = x_transform.is_affine and y_transform.is_affine
2353-
is_separable = x_transform.is_separable and y_transform.is_separable
2354-
is_correct = is_affine and is_separable
2355-
if not is_correct:
2356-
raise ValueError("Both *x_transform* and *y_transform* must be 2D "
2357-
"affine transforms")
2358-
23592369
Transform.__init__(self, **kwargs)
2360-
self._x = x_transform
2361-
self._y = y_transform
2362-
self.set_children(x_transform, y_transform)
2370+
AffineImmutable.__init__(self, **kwargs)
2371+
2372+
if not all(transform.is_affine and transform.is_separable
2373+
for transform in args):
2374+
raise ValueError("Given transforms must be affine")
2375+
2376+
for i in range(self.input_dims):
2377+
transform = args[i]
2378+
if transform.input_dims > 1 and transform.input_dims <= i:
2379+
raise TypeError("Invalid transform provided to"
2380+
"`BlendedGenericTransform`")
2381+
2382+
self._transforms = args
2383+
self.set_children(*args)
23632384

2364-
Affine2DBase.__init__(self)
23652385
self._mtx = None
23662386

23672387
def get_matrix(self):
23682388
# docstring inherited
23692389
if self._invalid:
2370-
if self._x == self._y:
2371-
self._mtx = self._x.get_matrix()
2390+
if all(transform == self._transforms[0] for transform in self._transforms):
2391+
self._mtx = self._transforms[0].get_matrix()
23722392
else:
2373-
x_mtx = self._x.get_matrix()
2374-
y_mtx = self._y.get_matrix()
23752393
# We already know the transforms are separable, so we can skip
2376-
# setting b and c to zero.
2377-
self._mtx = np.array([x_mtx[0], y_mtx[1], [0.0, 0.0, 1.0]])
2394+
# setting non-diagonal values to zero.
2395+
self._mtx = np.array(
2396+
[self._transforms[i].get_affine().get_matrix()[i]
2397+
for i in range(self.input_dims)] +
2398+
[[0.0] * self.input_dims + [1.0]])
23782399
self._inverted = None
23792400
self._invalid = 0
23802401
return self._mtx
23812402

23822403

2383-
def blended_transform_factory(x_transform, y_transform):
2404+
@_api.deprecated("3.9", alternative="BlendedAffine")
2405+
class BlendedAffine2D(BlendedAffine):
2406+
pass
2407+
2408+
2409+
def blended_transform_factory(*args):
23842410
"""
23852411
Create a new "blended" transform using *x_transform* to transform
23862412
the *x*-axis and *y_transform* to transform the *y*-axis.
23872413
23882414
A faster version of the blended transform is returned for the case
23892415
where both child transforms are affine.
23902416
"""
2391-
if (isinstance(x_transform, Affine2DBase) and
2392-
isinstance(y_transform, Affine2DBase)):
2393-
return BlendedAffine2D(x_transform, y_transform)
2394-
return BlendedGenericTransform(x_transform, y_transform)
2417+
if all(isinstance(transform, AffineImmutable) for transform in args):
2418+
return BlendedAffine(*args)
2419+
return BlendedGenericTransform(*args)
23952420

23962421

23972422
class CompositeGenericTransform(Transform):

lib/matplotlib/transforms.pyi

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -258,26 +258,25 @@ class _BlendedMixin:
258258
def contains_branch_seperately(self, transform: Transform) -> Sequence[bool]: ...
259259

260260
class BlendedGenericTransform(_BlendedMixin, Transform):
261-
input_dims: Literal[2]
262-
output_dims: Literal[2]
263261
pass_through: bool
264262
def __init__(
265-
self, x_transform: Transform, y_transform: Transform, **kwargs
263+
self, *args: Transform, **kwargs
266264
) -> None: ...
267265
@property
268266
def depth(self) -> int: ...
269267
def contains_branch(self, other: Transform) -> Literal[False]: ...
270268
@property
271269
def is_affine(self) -> bool: ...
272270

273-
class BlendedAffine2D(_BlendedMixin, Affine2DBase):
274-
def __init__(
275-
self, x_transform: Transform, y_transform: Transform, **kwargs
276-
) -> None: ...
271+
class BlendedAffine(_BlendedMixin, AffineImmutable):
272+
def __init__(self, *args: Transform, **kwargs) -> None: ...
273+
274+
class BlendedAffine2D(BlendedAffine):
275+
pass
277276

278277
def blended_transform_factory(
279-
x_transform: Transform, y_transform: Transform
280-
) -> BlendedGenericTransform | BlendedAffine2D: ...
278+
*args: Transform
279+
) -> BlendedGenericTransform | BlendedAffine: ...
281280

282281
class CompositeGenericTransform(Transform):
283282
pass_through: bool

0 commit comments

Comments
 (0)