From a83af6a5a89077c3fbd38e8ff476f3a08184d43a Mon Sep 17 00:00:00 2001 From: Antony Lee Date: Sun, 12 Nov 2023 13:50:13 +0100 Subject: [PATCH] Move axisartist towards untransposed transforms. While Matplotlib normally represents lists of (x, y) coordinates as (N, 2) arrays and transforms (which we'll call "trans") have shape signature (N, 2) -> (N, 2), axisartist uses the opposite convention of using (2, N) arrays (or size-2 tuples of 1D arrays) and transforms (which it typically calls "transform_xy"). Change that and go back to Matplotlib's standard represenation in some of axisartist's internal representations for consistency. Also replace some uses of (x1, y1, x2, y2) quadruplets by single Bbox objects, which avoid having to keep track of the order of the points (is it x1, y1, x2, y2 or x1, x2, y1, y2?). - Add a `_find_transformed_bbox(trans, bbox)` API to ExtremeFinderSimple and its subclasses, replacing `__call__(transform_xy, x1, y1, x2, y2)`. (I intentionally did not overload `__call__`'s signature yet nor did I deprecate it for now; we can consider doing that later.) - Deprecate `GridFinder.{,inv_}transform_xy`, which implement the transposed transform API. - Switch `grid_info["extremes"]` from quadruplet representation to Bbox. - Switch `grid_info["lon"]["lines"]` and likewise for "lat" from list-of-size-1-lists-of-pairs-of-1D-arrays to list-of-(N, 2)-arrays. - Switch `grid_info["line_xy"]` from pair-of-1D-arrays to a (N, 2) array. - Let `_get_grid_info` and `_get_raw_grid_lines` take a Bbox as (last) argument instead of 4 coordinates. Note that I intentionally mostly didn't touch (transpose) public-facing APIs for now, this may happen later. --- .../deprecations/27551-AL.rst | 4 + lib/mpl_toolkits/axisartist/angle_helper.py | 21 ++--- lib/mpl_toolkits/axisartist/axislines.py | 6 +- lib/mpl_toolkits/axisartist/floating_axes.py | 39 ++++----- lib/mpl_toolkits/axisartist/grid_finder.py | 84 +++++++++---------- .../axisartist/grid_helper_curvelinear.py | 47 +++++------ 6 files changed, 98 insertions(+), 103 deletions(-) create mode 100644 doc/api/next_api_changes/deprecations/27551-AL.rst diff --git a/doc/api/next_api_changes/deprecations/27551-AL.rst b/doc/api/next_api_changes/deprecations/27551-AL.rst new file mode 100644 index 000000000000..4811f542bd5f --- /dev/null +++ b/doc/api/next_api_changes/deprecations/27551-AL.rst @@ -0,0 +1,4 @@ +``GridFinder.transform_xy`` and ``GridFinder.inv_transform_xy`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +... are deprecated. Directly use the standard transform returned by +`.GridFinder.get_transform` instead. diff --git a/lib/mpl_toolkits/axisartist/angle_helper.py b/lib/mpl_toolkits/axisartist/angle_helper.py index 1786cd70bcdb..56b461e4a1d3 100644 --- a/lib/mpl_toolkits/axisartist/angle_helper.py +++ b/lib/mpl_toolkits/axisartist/angle_helper.py @@ -1,6 +1,7 @@ import numpy as np import math +from matplotlib.transforms import Bbox from mpl_toolkits.axisartist.grid_finder import ExtremeFinderSimple @@ -347,11 +348,12 @@ def __init__(self, nx, ny, self.lon_minmax = lon_minmax self.lat_minmax = lat_minmax - def __call__(self, transform_xy, x1, y1, x2, y2): + def _find_transformed_bbox(self, trans, bbox): # docstring inherited - x, y = np.meshgrid( - np.linspace(x1, x2, self.nx), np.linspace(y1, y2, self.ny)) - lon, lat = transform_xy(np.ravel(x), np.ravel(y)) + grid = np.reshape(np.meshgrid(np.linspace(bbox.x0, bbox.x1, self.nx), + np.linspace(bbox.y0, bbox.y1, self.ny)), + (2, -1)).T + lon, lat = trans.transform(grid).T # iron out jumps, but algorithm should be improved. # This is just naive way of doing and my fail for some cases. @@ -367,11 +369,10 @@ def __call__(self, transform_xy, x1, y1, x2, y2): lat0 = np.nanmin(lat) lat -= 360. * ((lat - lat0) > 180.) - lon_min, lon_max = np.nanmin(lon), np.nanmax(lon) - lat_min, lat_max = np.nanmin(lat), np.nanmax(lat) - - lon_min, lon_max, lat_min, lat_max = \ - self._add_pad(lon_min, lon_max, lat_min, lat_max) + tbbox = Bbox.null() + tbbox.update_from_data_xy(np.column_stack([lon, lat])) + tbbox = tbbox.expanded(1 + 2 / self.nx, 1 + 2 / self.ny) + lon_min, lat_min, lon_max, lat_max = tbbox.extents # check cycle if self.lon_cycle: @@ -391,4 +392,4 @@ def __call__(self, transform_xy, x1, y1, x2, y2): max0 = self.lat_minmax[1] lat_max = min(max0, lat_max) - return lon_min, lon_max, lat_min, lat_max + return Bbox.from_extents(lon_min, lat_min, lon_max, lat_max) diff --git a/lib/mpl_toolkits/axisartist/axislines.py b/lib/mpl_toolkits/axisartist/axislines.py index 8d06cb236269..c0379f11b8d4 100644 --- a/lib/mpl_toolkits/axisartist/axislines.py +++ b/lib/mpl_toolkits/axisartist/axislines.py @@ -45,6 +45,8 @@ from matplotlib import _api import matplotlib.axes as maxes from matplotlib.path import Path +from matplotlib.transforms import Bbox + from mpl_toolkits.axes_grid1 import mpl_axes from .axisline_style import AxislineStyle # noqa from .axis_artist import AxisArtist, GridlinesCollection @@ -285,10 +287,10 @@ def update_lim(self, axes): x1, x2 = axes.get_xlim() y1, y2 = axes.get_ylim() if self._old_limits != (x1, x2, y1, y2): - self._update_grid(x1, y1, x2, y2) + self._update_grid(Bbox.from_extents(x1, y1, x2, y2)) self._old_limits = (x1, x2, y1, y2) - def _update_grid(self, x1, y1, x2, y2): + def _update_grid(self, bbox): """Cache relevant computations when the axes limits have changed.""" def get_gridlines(self, which, axis): diff --git a/lib/mpl_toolkits/axisartist/floating_axes.py b/lib/mpl_toolkits/axisartist/floating_axes.py index 74e4c941879b..1893955046e6 100644 --- a/lib/mpl_toolkits/axisartist/floating_axes.py +++ b/lib/mpl_toolkits/axisartist/floating_axes.py @@ -13,9 +13,8 @@ from matplotlib import _api, cbook import matplotlib.patches as mpatches from matplotlib.path import Path - +from matplotlib.transforms import Bbox from mpl_toolkits.axes_grid1.parasite_axes import host_axes_class_factory - from . import axislines, grid_helper_curvelinear from .axis_artist import AxisArtist from .grid_finder import ExtremeFinderSimple @@ -109,8 +108,7 @@ def get_line(self, axes): right=("lon_lines0", 1), bottom=("lat_lines0", 0), top=("lat_lines0", 1))[self._side] - xx, yy = self._grid_info[k][v] - return Path(np.column_stack([xx, yy])) + return Path(self._grid_info[k][v]) class ExtremeFinderFixed(ExtremeFinderSimple): @@ -125,11 +123,12 @@ def __init__(self, extremes): extremes : (float, float, float, float) The bounding box that this helper always returns. """ - self._extremes = extremes + x0, x1, y0, y1 = extremes + self._tbbox = Bbox.from_extents(x0, y0, x1, y1) - def __call__(self, transform_xy, x1, y1, x2, y2): + def _find_transformed_bbox(self, trans, bbox): # docstring inherited - return self._extremes + return self._tbbox class GridHelperCurveLinear(grid_helper_curvelinear.GridHelperCurveLinear): @@ -177,25 +176,22 @@ def new_fixed_axis( # axis.get_helper().set_extremes(*self._extremes[2:]) # return axis - def _update_grid(self, x1, y1, x2, y2): + def _update_grid(self, bbox): if self._grid_info is None: self._grid_info = dict() grid_info = self._grid_info grid_finder = self.grid_finder - extremes = grid_finder.extreme_finder(grid_finder.inv_transform_xy, - x1, y1, x2, y2) + tbbox = grid_finder.extreme_finder._find_transformed_bbox( + grid_finder.get_transform().inverted(), bbox) - lon_min, lon_max = sorted(extremes[:2]) - lat_min, lat_max = sorted(extremes[2:]) - grid_info["extremes"] = lon_min, lon_max, lat_min, lat_max # extremes + lon_min, lat_min, lon_max, lat_max = tbbox.extents + grid_info["extremes"] = tbbox - lon_levs, lon_n, lon_factor = \ - grid_finder.grid_locator1(lon_min, lon_max) + lon_levs, lon_n, lon_factor = grid_finder.grid_locator1(lon_min, lon_max) lon_levs = np.asarray(lon_levs) - lat_levs, lat_n, lat_factor = \ - grid_finder.grid_locator2(lat_min, lat_max) + lat_levs, lat_n, lat_factor = grid_finder.grid_locator2(lat_min, lat_max) lat_levs = np.asarray(lat_levs) grid_info["lon_info"] = lon_levs, lon_n, lon_factor @@ -212,14 +208,13 @@ def _update_grid(self, x1, y1, x2, y2): lon_lines, lat_lines = grid_finder._get_raw_grid_lines( lon_values[(lon_min < lon_values) & (lon_values < lon_max)], lat_values[(lat_min < lat_values) & (lat_values < lat_max)], - lon_min, lon_max, lat_min, lat_max) + tbbox) grid_info["lon_lines"] = lon_lines grid_info["lat_lines"] = lat_lines lon_lines, lat_lines = grid_finder._get_raw_grid_lines( - # lon_min, lon_max, lat_min, lat_max) - extremes[:2], extremes[2:], *extremes) + tbbox.intervalx, tbbox.intervaly, tbbox) grid_info["lon_lines0"] = lon_lines grid_info["lat_lines0"] = lat_lines @@ -227,9 +222,9 @@ def _update_grid(self, x1, y1, x2, y2): def get_gridlines(self, which="major", axis="both"): grid_lines = [] if axis in ["both", "x"]: - grid_lines.extend(self._grid_info["lon_lines"]) + grid_lines.extend(map(np.transpose, self._grid_info["lon_lines"])) if axis in ["both", "y"]: - grid_lines.extend(self._grid_info["lat_lines"]) + grid_lines.extend(map(np.transpose, self._grid_info["lat_lines"])) return grid_lines diff --git a/lib/mpl_toolkits/axisartist/grid_finder.py b/lib/mpl_toolkits/axisartist/grid_finder.py index ff67aa6e8720..b2495b8d883f 100644 --- a/lib/mpl_toolkits/axisartist/grid_finder.py +++ b/lib/mpl_toolkits/axisartist/grid_finder.py @@ -77,20 +77,29 @@ def __call__(self, transform_xy, x1, y1, x2, y2): extremal coordinates; then adding some padding to take into account the finite sampling. - As each sampling step covers a relative range of *1/nx* or *1/ny*, + As each sampling step covers a relative range of ``1/nx`` or ``1/ny``, the padding is computed by expanding the span covered by the extremal coordinates by these fractions. """ - x, y = np.meshgrid( - np.linspace(x1, x2, self.nx), np.linspace(y1, y2, self.ny)) - xt, yt = transform_xy(np.ravel(x), np.ravel(y)) - return self._add_pad(xt.min(), xt.max(), yt.min(), yt.max()) + tbbox = self._find_transformed_bbox( + _User2DTransform(transform_xy, None), Bbox.from_extents(x1, y1, x2, y2)) + return tbbox.x0, tbbox.x1, tbbox.y0, tbbox.y1 - def _add_pad(self, x_min, x_max, y_min, y_max): - """Perform the padding mentioned in `__call__`.""" - dx = (x_max - x_min) / self.nx - dy = (y_max - y_min) / self.ny - return x_min - dx, x_max + dx, y_min - dy, y_max + dy + def _find_transformed_bbox(self, trans, bbox): + """ + Compute an approximation of the bounding box obtained by applying + *trans* to *bbox*. + + See ``__call__`` for details; this method performs similar + calculations, but using a different representation of the arguments and + return value. + """ + grid = np.reshape(np.meshgrid(np.linspace(bbox.x0, bbox.x1, self.nx), + np.linspace(bbox.y0, bbox.y1, self.ny)), + (2, -1)).T + tbbox = Bbox.null() + tbbox.update_from_data_xy(trans.transform(grid)) + return tbbox.expanded(1 + 2 / self.nx, 1 + 2 / self.ny) class _User2DTransform(Transform): @@ -170,42 +179,31 @@ def get_grid_info(self, x1, y1, x2, y2): rough number of grids in each direction. """ - extremes = self.extreme_finder(self.inv_transform_xy, x1, y1, x2, y2) - - # min & max rage of lat (or lon) for each grid line will be drawn. - # i.e., gridline of lon=0 will be drawn from lat_min to lat_max. + bbox = Bbox.from_extents(x1, y1, x2, y2) + tbbox = self.extreme_finder._find_transformed_bbox( + self.get_transform().inverted(), bbox) - lon_min, lon_max, lat_min, lat_max = extremes - lon_levs, lon_n, lon_factor = self.grid_locator1(lon_min, lon_max) - lon_levs = np.asarray(lon_levs) - lat_levs, lat_n, lat_factor = self.grid_locator2(lat_min, lat_max) - lat_levs = np.asarray(lat_levs) + lon_levs, lon_n, lon_factor = self.grid_locator1(*tbbox.intervalx) + lat_levs, lat_n, lat_factor = self.grid_locator2(*tbbox.intervaly) - lon_values = lon_levs[:lon_n] / lon_factor - lat_values = lat_levs[:lat_n] / lat_factor + lon_values = np.asarray(lon_levs[:lon_n]) / lon_factor + lat_values = np.asarray(lat_levs[:lat_n]) / lat_factor - lon_lines, lat_lines = self._get_raw_grid_lines(lon_values, - lat_values, - lon_min, lon_max, - lat_min, lat_max) + lon_lines, lat_lines = self._get_raw_grid_lines(lon_values, lat_values, tbbox) - bb = Bbox.from_extents(x1, y1, x2, y2).expanded(1 + 2e-10, 1 + 2e-10) - - grid_info = { - "extremes": extremes, - # "lon", "lat", filled below. - } + bbox_expanded = bbox.expanded(1 + 2e-10, 1 + 2e-10) + grid_info = {"extremes": tbbox} # "lon", "lat" keys filled below. for idx, lon_or_lat, levs, factor, values, lines in [ (1, "lon", lon_levs, lon_factor, lon_values, lon_lines), (2, "lat", lat_levs, lat_factor, lat_values, lat_lines), ]: grid_info[lon_or_lat] = gi = { - "lines": [[l] for l in lines], + "lines": lines, "ticks": {"left": [], "right": [], "bottom": [], "top": []}, } - for (lx, ly), v, level in zip(lines, values, levs): - all_crossings = _find_line_box_crossings(np.column_stack([lx, ly]), bb) + for xys, v, level in zip(lines, values, levs): + all_crossings = _find_line_box_crossings(xys, bbox_expanded) for side, crossings in zip( ["left", "right", "bottom", "top"], all_crossings): for crossing in crossings: @@ -218,18 +216,14 @@ def get_grid_info(self, x1, y1, x2, y2): return grid_info - def _get_raw_grid_lines(self, - lon_values, lat_values, - lon_min, lon_max, lat_min, lat_max): - - lons_i = np.linspace(lon_min, lon_max, 100) # for interpolation - lats_i = np.linspace(lat_min, lat_max, 100) - - lon_lines = [self.transform_xy(np.full_like(lats_i, lon), lats_i) + def _get_raw_grid_lines(self, lon_values, lat_values, bbox): + trans = self.get_transform() + lons = np.linspace(bbox.x0, bbox.x1, 100) # for interpolation + lats = np.linspace(bbox.y0, bbox.y1, 100) + lon_lines = [trans.transform(np.column_stack([np.full_like(lats, lon), lats])) for lon in lon_values] - lat_lines = [self.transform_xy(lons_i, np.full_like(lons_i, lat)) + lat_lines = [trans.transform(np.column_stack([lons, np.full_like(lons, lat)])) for lat in lat_values] - return lon_lines, lat_lines def set_transform(self, aux_trans): @@ -246,9 +240,11 @@ def get_transform(self): update_transform = set_transform # backcompat alias. + @_api.deprecated("3.11", alternative="grid_finder.get_transform()") def transform_xy(self, x, y): return self._aux_transform.transform(np.column_stack([x, y])).T + @_api.deprecated("3.11", alternative="grid_finder.get_transform().inverted()") def inv_transform_xy(self, x, y): return self._aux_transform.inverted().transform( np.column_stack([x, y])).T diff --git a/lib/mpl_toolkits/axisartist/grid_helper_curvelinear.py b/lib/mpl_toolkits/axisartist/grid_helper_curvelinear.py index a7eb9d5cfe21..508e7a9bc506 100644 --- a/lib/mpl_toolkits/axisartist/grid_helper_curvelinear.py +++ b/lib/mpl_toolkits/axisartist/grid_helper_curvelinear.py @@ -9,7 +9,7 @@ import matplotlib as mpl from matplotlib import _api from matplotlib.path import Path -from matplotlib.transforms import Affine2D, IdentityTransform +from matplotlib.transforms import Affine2D, Bbox, IdentityTransform from .axislines import ( _FixedAxisArtistHelperBase, _FloatingAxisArtistHelperBase, GridHelperBase) from .axis_artist import AxisArtist @@ -115,10 +115,10 @@ def update_lim(self, axes): x1, x2 = axes.get_xlim() y1, y2 = axes.get_ylim() grid_finder = self.grid_helper.grid_finder - extremes = grid_finder.extreme_finder(grid_finder.inv_transform_xy, - x1, y1, x2, y2) + tbbox = grid_finder.extreme_finder._find_transformed_bbox( + grid_finder.get_transform().inverted(), Bbox.from_extents(x1, y1, x2, y2)) - lon_min, lon_max, lat_min, lat_max = extremes + lon_min, lat_min, lon_max, lat_max = tbbox.extents e_min, e_max = self._extremes # ranges of other coordinates if self.nth_coord == 0: lat_min = max(e_min, lat_min) @@ -127,29 +127,29 @@ def update_lim(self, axes): lon_min = max(e_min, lon_min) lon_max = min(e_max, lon_max) - lon_levs, lon_n, lon_factor = \ - grid_finder.grid_locator1(lon_min, lon_max) - lat_levs, lat_n, lat_factor = \ - grid_finder.grid_locator2(lat_min, lat_max) + lon_levs, lon_n, lon_factor = grid_finder.grid_locator1(lon_min, lon_max) + lat_levs, lat_n, lat_factor = grid_finder.grid_locator2(lat_min, lat_max) if self.nth_coord == 0: - xx0 = np.full(self._line_num_points, self.value) - yy0 = np.linspace(lat_min, lat_max, self._line_num_points) - xx, yy = grid_finder.transform_xy(xx0, yy0) + xys = grid_finder.get_transform().transform(np.column_stack([ + np.full(self._line_num_points, self.value), + np.linspace(lat_min, lat_max, self._line_num_points), + ])) elif self.nth_coord == 1: - xx0 = np.linspace(lon_min, lon_max, self._line_num_points) - yy0 = np.full(self._line_num_points, self.value) - xx, yy = grid_finder.transform_xy(xx0, yy0) + xys = grid_finder.get_transform().transform(np.column_stack([ + np.linspace(lon_min, lon_max, self._line_num_points), + np.full(self._line_num_points, self.value), + ])) self._grid_info = { - "extremes": (lon_min, lon_max, lat_min, lat_max), + "extremes": Bbox.from_extents(lon_min, lat_min, lon_max, lat_max), "lon_info": (lon_levs, lon_n, np.asarray(lon_factor)), "lat_info": (lat_levs, lat_n, np.asarray(lat_factor)), "lon_labels": grid_finder._format_ticks( 1, "bottom", lon_factor, lon_levs), "lat_labels": grid_finder._format_ticks( 2, "bottom", lat_factor, lat_levs), - "line_xy": (xx, yy), + "line_xy": xys, } def get_axislabel_transform(self, axes): @@ -160,7 +160,7 @@ def trf_xy(x, y): trf = self.grid_helper.grid_finder.get_transform() + axes.transData return trf.transform([x, y]).T - xmin, xmax, ymin, ymax = self._grid_info["extremes"] + xmin, ymin, xmax, ymax = self._grid_info["extremes"].extents if self.nth_coord == 0: xx0 = self.value yy0 = (ymin + ymax) / 2 @@ -232,8 +232,7 @@ def get_line_transform(self, axes): def get_line(self, axes): self.update_lim(axes) - x, y = self._grid_info["line_xy"] - return Path(np.column_stack([x, y])) + return Path(self._grid_info["line_xy"]) class GridHelperCurveLinear(GridHelperBase): @@ -303,17 +302,15 @@ def new_floating_axis(self, nth_coord, value, axes=None, axis_direction="bottom" # axisline.minor_ticklabels.set_visible(False) return axisline - def _update_grid(self, x1, y1, x2, y2): - self._grid_info = self.grid_finder.get_grid_info(x1, y1, x2, y2) + def _update_grid(self, bbox): + self._grid_info = self.grid_finder.get_grid_info(*bbox.extents) def get_gridlines(self, which="major", axis="both"): grid_lines = [] if axis in ["both", "x"]: - for gl in self._grid_info["lon"]["lines"]: - grid_lines.extend(gl) + grid_lines.extend([gl.T for gl in self._grid_info["lon"]["lines"]]) if axis in ["both", "y"]: - for gl in self._grid_info["lat"]["lines"]: - grid_lines.extend(gl) + grid_lines.extend([gl.T for gl in self._grid_info["lat"]["lines"]]) return grid_lines @_api.deprecated("3.9")