Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 91d3329

Browse files
committed
Add ability to blend any number of transforms
1 parent f683fc7 commit 91d3329

File tree

3 files changed

+120
-94
lines changed

3 files changed

+120
-94
lines changed

lib/matplotlib/tests/test_transforms.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -850,7 +850,7 @@ def test_str_transform():
850850
CompositeGenericTransform(
851851
CompositeGenericTransform(
852852
TransformWrapper(
853-
BlendedAffine2D(
853+
BlendedAffine(
854854
IdentityTransform(),
855855
IdentityTransform())),
856856
CompositeAffine2D(
@@ -864,7 +864,7 @@ def test_str_transform():
864864
CompositeGenericTransform(
865865
PolarAffine(
866866
TransformWrapper(
867-
BlendedAffine2D(
867+
BlendedAffine(
868868
IdentityTransform(),
869869
IdentityTransform())),
870870
LockableBbox(

lib/matplotlib/transforms.py

Lines changed: 110 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -2203,182 +2203,208 @@ def inverted(self):
22032203

22042204

22052205
class _BlendedMixin:
2206-
"""Common methods for `BlendedGenericTransform` and `BlendedAffine2D`."""
2206+
"""Common methods for `BlendedGenericTransform` and `BlendedAffine`."""
22072207

22082208
def __eq__(self, other):
2209-
if isinstance(other, (BlendedAffine2D, BlendedGenericTransform)):
2210-
return (self._x == other._x) and (self._y == other._y)
2211-
elif self._x == self._y:
2212-
return self._x == other
2209+
num_transforms = len(self._transforms)
2210+
2211+
if (isinstance(other, (BlendedGenericTransform, BlendedAffine))
2212+
and num_transforms == len(other._transforms)):
2213+
return all(self._transforms[i] == other._transforms[i]
2214+
for i in range(num_transforms))
22132215
else:
22142216
return NotImplemented
22152217

22162218
def contains_branch_seperately(self, transform):
2217-
return (self._x.contains_branch(transform),
2218-
self._y.contains_branch(transform))
2219+
return tuple(branch.contains_branch(transform) for branch in self._transforms)
22192220

2220-
__str__ = _make_str_method("_x", "_y")
2221+
def __str__(self):
2222+
indent = functools.partial(textwrap.indent, prefix=" " * 4)
2223+
return (
2224+
type(self).__name__ + "("
2225+
+ ",".join([*(indent("\n" + transform.__str__())
2226+
for transform in self._transforms)])
2227+
+ ")")
22212228

22222229

22232230
class BlendedGenericTransform(_BlendedMixin, Transform):
22242231
"""
2225-
A "blended" transform uses one transform for the *x*-direction, and
2226-
another transform for the *y*-direction.
2232+
A "blended" transform uses one transform for each direction
22272233
2228-
This "generic" version can handle any given child transform in the
2229-
*x*- and *y*-directions.
2234+
This "generic" version can handle any number of given child transforms, each
2235+
handling a different axis.
22302236
"""
2231-
input_dims = 2
2232-
output_dims = 2
22332237
is_separable = True
22342238
pass_through = True
22352239

2236-
def __init__(self, x_transform, y_transform, **kwargs):
2240+
def __init__(self, *args, **kwargs):
22372241
"""
2238-
Create a new "blended" transform using *x_transform* to transform the
2239-
*x*-axis and *y_transform* to transform the *y*-axis.
2242+
Create a new "blended" transform, with the first argument providing
2243+
a transform for the *x*-axis, the second argument providing a transform
2244+
for the *y*-axis, etc.
22402245
22412246
You will generally not call this constructor directly but use the
22422247
`blended_transform_factory` function instead, which can determine
22432248
automatically which kind of blended transform to create.
22442249
"""
2250+
self.input_dims = self.output_dims = len(args)
2251+
2252+
for i in range(self.input_dims):
2253+
transform = args[i]
2254+
if transform.input_dims > 1 and transform.input_dims <= i:
2255+
raise TypeError("Invalid transform provided to"
2256+
"`BlendedGenericTransform`")
2257+
22452258
Transform.__init__(self, **kwargs)
2246-
self._x = x_transform
2247-
self._y = y_transform
2248-
self.set_children(x_transform, y_transform)
2259+
self.set_children(*args)
2260+
self._transforms = args
22492261
self._affine = None
22502262

22512263
@property
22522264
def depth(self):
2253-
return max(self._x.depth, self._y.depth)
2265+
return max(transform.depth for transform in self._transforms)
22542266

22552267
def contains_branch(self, other):
22562268
# A blended transform cannot possibly contain a branch from two
22572269
# different transforms.
22582270
return False
22592271

2260-
is_affine = property(lambda self: self._x.is_affine and self._y.is_affine)
2261-
has_inverse = property(
2262-
lambda self: self._x.has_inverse and self._y.has_inverse)
2272+
is_affine = property(lambda self: all(transform.is_affine
2273+
for transform in self._transforms))
2274+
has_inverse = property(lambda self: all(transform.has_inverse
2275+
for transform in self._transforms))
22632276

22642277
def frozen(self):
22652278
# docstring inherited
2266-
return blended_transform_factory(self._x.frozen(), self._y.frozen())
2279+
return blended_transform_factory(*(transform.frozen()
2280+
for transform in self._transforms))
22672281

22682282
@_api.rename_parameter("3.8", "points", "values")
22692283
def transform_non_affine(self, values):
22702284
# docstring inherited
2271-
if self._x.is_affine and self._y.is_affine:
2285+
if self.is_affine:
22722286
return values
2273-
x = self._x
2274-
y = self._y
22752287

2276-
if x == y and x.input_dims == 2:
2277-
return x.transform_non_affine(values)
2288+
if all(transform == self._transforms[0]
2289+
for transform in self._transforms) and self.input_dims >= 2:
2290+
return self._transforms[0].transform_non_affine(values)
22782291

2279-
if x.input_dims == 2:
2280-
x_points = x.transform_non_affine(values)[:, 0:1]
2281-
else:
2282-
x_points = x.transform_non_affine(values[:, 0])
2283-
x_points = x_points.reshape((len(x_points), 1))
2292+
all_points = []
2293+
masked = False
22842294

2285-
if y.input_dims == 2:
2286-
y_points = y.transform_non_affine(values)[:, 1:]
2287-
else:
2288-
y_points = y.transform_non_affine(values[:, 1])
2289-
y_points = y_points.reshape((len(y_points), 1))
2295+
for dim in range(self.input_dims):
2296+
transform = self._transforms[dim]
2297+
if transform.input_dims == 1:
2298+
points = transform.transform_non_affine(values[:, dim])
2299+
points = points.reshape((len(points), 1))
2300+
else:
2301+
points = transform.transform_non_affine(values)[:, dim:dim+1]
22902302

2291-
if (isinstance(x_points, np.ma.MaskedArray) or
2292-
isinstance(y_points, np.ma.MaskedArray)):
2293-
return np.ma.concatenate((x_points, y_points), 1)
2303+
masked = masked or isinstance(points, np.ma.MaskedArray)
2304+
all_points.append(points)
2305+
2306+
if masked:
2307+
return np.ma.concatenate(tuple(all_points), 1)
22942308
else:
2295-
return np.concatenate((x_points, y_points), 1)
2309+
return np.concatenate(tuple(all_points), 1)
22962310

22972311
def inverted(self):
22982312
# docstring inherited
2299-
return BlendedGenericTransform(self._x.inverted(), self._y.inverted())
2313+
return BlendedGenericTransform(*(transform.inverted()
2314+
for transform in self._transforms))
23002315

23012316
def get_affine(self):
23022317
# docstring inherited
23032318
if self._invalid or self._affine is None:
2304-
if self._x == self._y:
2305-
self._affine = self._x.get_affine()
2319+
if all(transform == self._transforms[0] for transform in self._transforms):
2320+
self._affine = self._transforms[0].get_affine()
23062321
else:
2307-
x_mtx = self._x.get_affine().get_matrix()
2308-
y_mtx = self._y.get_affine().get_matrix()
2309-
# We already know the transforms are separable, so we can skip
2310-
# setting b and c to zero.
2311-
mtx = np.array([x_mtx[0], y_mtx[1], [0.0, 0.0, 1.0]])
2312-
self._affine = Affine2D(mtx)
2322+
mtx = np.identity(self.input_dims + 1)
2323+
for i in range(self.input_dims):
2324+
transform = self._transforms[i]
2325+
if transform.output_dims > 1:
2326+
mtx[i] = transform.get_affine().get_matrix()[i]
2327+
2328+
self._affine = _affine_factory(mtx, dims=self.input_dims)
23132329
self._invalid = 0
23142330
return self._affine
23152331

23162332

2317-
class BlendedAffine2D(_BlendedMixin, Affine2DBase):
2333+
class BlendedAffine(_BlendedMixin, AffineImmutable):
23182334
"""
23192335
A "blended" transform uses one transform for the *x*-direction, and
23202336
another transform for the *y*-direction.
23212337
23222338
This version is an optimization for the case where both child
2323-
transforms are of type `Affine2DBase`.
2339+
transforms are of type `AffineImmutable`.
23242340
"""
23252341

23262342
is_separable = True
23272343

2328-
def __init__(self, x_transform, y_transform, **kwargs):
2344+
def __init__(self, *args, **kwargs):
23292345
"""
2330-
Create a new "blended" transform using *x_transform* to transform the
2331-
*x*-axis and *y_transform* to transform the *y*-axis.
2346+
Create a new "blended" transform, with the first argument providing
2347+
a transform for the *x*-axis, the second argument providing a transform
2348+
for the *y*-axis, etc.
23322349
2333-
Both *x_transform* and *y_transform* must be 2D affine transforms.
2350+
All provided transforms must be affine transforms.
23342351
23352352
You will generally not call this constructor directly but use the
23362353
`blended_transform_factory` function instead, which can determine
23372354
automatically which kind of blended transform to create.
23382355
"""
2339-
is_affine = x_transform.is_affine and y_transform.is_affine
2340-
is_separable = x_transform.is_separable and y_transform.is_separable
2341-
is_correct = is_affine and is_separable
2342-
if not is_correct:
2343-
raise ValueError("Both *x_transform* and *y_transform* must be 2D "
2344-
"affine transforms")
2345-
2356+
dims = len(args)
23462357
Transform.__init__(self, **kwargs)
2347-
self._x = x_transform
2348-
self._y = y_transform
2349-
self.set_children(x_transform, y_transform)
2358+
AffineImmutable.__init__(self, dims=dims, **kwargs)
2359+
2360+
if not all(transform.is_affine and transform.is_separable
2361+
for transform in args):
2362+
raise ValueError("Given transforms must be affine")
2363+
2364+
for i in range(self.input_dims):
2365+
transform = args[i]
2366+
if transform.input_dims > 1 and transform.input_dims <= i:
2367+
raise TypeError("Invalid transform provided to"
2368+
"`BlendedGenericTransform`")
2369+
2370+
self._transforms = args
2371+
self.set_children(*args)
23502372

2351-
Affine2DBase.__init__(self)
23522373
self._mtx = None
23532374

23542375
def get_matrix(self):
23552376
# docstring inherited
23562377
if self._invalid:
2357-
if self._x == self._y:
2358-
self._mtx = self._x.get_matrix()
2378+
if all(transform == self._transforms[0] for transform in self._transforms):
2379+
self._mtx = self._transforms[0].get_matrix()
23592380
else:
2360-
x_mtx = self._x.get_matrix()
2361-
y_mtx = self._y.get_matrix()
23622381
# We already know the transforms are separable, so we can skip
2363-
# setting b and c to zero.
2364-
self._mtx = np.array([x_mtx[0], y_mtx[1], [0.0, 0.0, 1.0]])
2382+
# setting non-diagonal values to zero.
2383+
self._mtx = np.array(
2384+
[self._transforms[i].get_affine().get_matrix()[i]
2385+
for i in range(self.input_dims)] +
2386+
[[0.0] * self.input_dims + [1.0]])
23652387
self._inverted = None
23662388
self._invalid = 0
23672389
return self._mtx
23682390

23692391

2370-
def blended_transform_factory(x_transform, y_transform):
2392+
@_api.deprecated("3.9", alternative="BlendedAffine")
2393+
class BlendedAffine2D(BlendedAffine):
2394+
pass
2395+
2396+
2397+
def blended_transform_factory(*args):
23712398
"""
23722399
Create a new "blended" transform using *x_transform* to transform
23732400
the *x*-axis and *y_transform* to transform the *y*-axis.
23742401
23752402
A faster version of the blended transform is returned for the case
23762403
where both child transforms are affine.
23772404
"""
2378-
if (isinstance(x_transform, Affine2DBase) and
2379-
isinstance(y_transform, Affine2DBase)):
2380-
return BlendedAffine2D(x_transform, y_transform)
2381-
return BlendedGenericTransform(x_transform, y_transform)
2405+
if all(isinstance(transform, AffineImmutable) for transform in args):
2406+
return BlendedAffine(*args)
2407+
return BlendedGenericTransform(*args)
23822408

23832409

23842410
class CompositeGenericTransform(Transform):
@@ -2479,8 +2505,9 @@ def get_affine(self):
24792505
if not self._b.is_affine:
24802506
return self._b.get_affine()
24812507
else:
2482-
return Affine2D(np.dot(self._b.get_affine().get_matrix(),
2483-
self._a.get_affine().get_matrix()))
2508+
return _affine_factory(np.dot(self._b.get_affine().get_matrix(),
2509+
self._a.get_affine().get_matrix()),
2510+
dims=self.input_dims)
24842511

24852512
def inverted(self):
24862513
# docstring inherited

lib/matplotlib/transforms.pyi

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -258,26 +258,25 @@ class _BlendedMixin:
258258
def contains_branch_seperately(self, transform: Transform) -> Sequence[bool]: ...
259259

260260
class BlendedGenericTransform(_BlendedMixin, Transform):
261-
input_dims: Literal[2]
262-
output_dims: Literal[2]
263261
pass_through: bool
264262
def __init__(
265-
self, x_transform: Transform, y_transform: Transform, **kwargs
263+
self, *args: Transform, **kwargs
266264
) -> None: ...
267265
@property
268266
def depth(self) -> int: ...
269267
def contains_branch(self, other: Transform) -> Literal[False]: ...
270268
@property
271269
def is_affine(self) -> bool: ...
272270

273-
class BlendedAffine2D(_BlendedMixin, Affine2DBase):
274-
def __init__(
275-
self, x_transform: Transform, y_transform: Transform, **kwargs
276-
) -> None: ...
271+
class BlendedAffine(_BlendedMixin, AffineImmutable):
272+
def __init__(self, *args: Transform, **kwargs) -> None: ...
273+
274+
class BlendedAffine2D(BlendedAffine):
275+
pass
277276

278277
def blended_transform_factory(
279-
x_transform: Transform, y_transform: Transform
280-
) -> BlendedGenericTransform | BlendedAffine2D: ...
278+
*args: Transform
279+
) -> BlendedGenericTransform | BlendedAffine: ...
281280

282281
class CompositeGenericTransform(Transform):
283282
pass_through: bool

0 commit comments

Comments
 (0)