Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 4800e11

Browse files
committed
Move axisartist towards untransposed transforms.
While Matplotlib normally represents lists of (x, y) coordinates as (N, 2) arrays and transforms (which we'll call "trans") have shape signature (N, 2) -> (N, 2), axisartist uses the opposite convention of using (2, N) arrays (or size-2 tuples of 1D arrays) and transforms (which it typically calls "transform_xy"). Change that and go back to Matplotlib's standard represenation in some of axisartist's internal representations for consistency. Also replace some uses of (x1, y1, x2, y2) quadruplets by single Bbox objects, which avoid having to keep track of the order of the points (is it x1, y1, x2, y2 or x1, x2, y1, y2?). - Add a `_find_transformed_bbox(trans, bbox)` API to ExtremeFinderSimple and its subclasses, replacing `__call__(transform_xy, x1, y1, x2, y2)`. (I intentionally did not overload `__call__`'s signature yet nor did I deprecate it for now; we can consider doing that later.) - Deprecate `GridFinder.{,inv_}transform_xy`, which implement the transposed transform API. - Switch `grid_info["extremes"]` from quadruplet representation to Bbox. - Switch `grid_info["lon"]["lines"]` and likewise for "lat" from list-of-size-1-lists-of-pairs-of-1D-arrays to list-of-(N, 2)-arrays. - Switch `grid_info["line_xy"]` from pair-of-1D-arrays to a (N, 2) array. - Let `_get_grid_info` and `_get_raw_grid_lines` take a Bbox as (last) argument instead of 4 coordinates. Note that I intentionally mostly didn't touch (transpose) public-facing APIs for now, this may happen later.
1 parent f8900ea commit 4800e11

File tree

6 files changed

+90
-103
lines changed

6 files changed

+90
-103
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
``GridFinder.transform_xy`` and ``GridFinder.inv_transform_xy``
2+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
3+
... are deprecated. Directly use the standard transform returned by
4+
`.GridFinder.get_transform` instead.

lib/mpl_toolkits/axisartist/angle_helper.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22
import math
33

4+
from matplotlib.transforms import Bbox
45
from mpl_toolkits.axisartist.grid_finder import ExtremeFinderSimple
56

67

@@ -347,11 +348,12 @@ def __init__(self, nx, ny,
347348
self.lon_minmax = lon_minmax
348349
self.lat_minmax = lat_minmax
349350

350-
def __call__(self, transform_xy, x1, y1, x2, y2):
351+
def _find_transformed_bbox(self, trans, bbox):
351352
# docstring inherited
352-
x, y = np.meshgrid(
353-
np.linspace(x1, x2, self.nx), np.linspace(y1, y2, self.ny))
354-
lon, lat = transform_xy(np.ravel(x), np.ravel(y))
353+
grid = np.reshape(np.meshgrid(np.linspace(bbox.x0, bbox.x1, self.nx),
354+
np.linspace(bbox.y0, bbox.y1, self.ny)),
355+
(2, -1)).T
356+
lon, lat = trans.transform(grid).T
355357

356358
# iron out jumps, but algorithm should be improved.
357359
# This is just naive way of doing and my fail for some cases.
@@ -367,11 +369,10 @@ def __call__(self, transform_xy, x1, y1, x2, y2):
367369
lat0 = np.nanmin(lat)
368370
lat -= 360. * ((lat - lat0) > 180.)
369371

370-
lon_min, lon_max = np.nanmin(lon), np.nanmax(lon)
371-
lat_min, lat_max = np.nanmin(lat), np.nanmax(lat)
372-
373-
lon_min, lon_max, lat_min, lat_max = \
374-
self._add_pad(lon_min, lon_max, lat_min, lat_max)
372+
tbbox = Bbox.null()
373+
tbbox.update_from_data_xy(np.column_stack([lon, lat]))
374+
tbbox = tbbox.expanded(1 + 2 / self.nx, 1 + 2 / self.ny)
375+
lon_min, lat_min, lon_max, lat_max = tbbox.extents
375376

376377
# check cycle
377378
if self.lon_cycle:
@@ -391,4 +392,4 @@ def __call__(self, transform_xy, x1, y1, x2, y2):
391392
max0 = self.lat_minmax[1]
392393
lat_max = min(max0, lat_max)
393394

394-
return lon_min, lon_max, lat_min, lat_max
395+
return Bbox.from_extents(lon_min, lat_min, lon_max, lat_max)

lib/mpl_toolkits/axisartist/axislines.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@
4545
from matplotlib import _api
4646
import matplotlib.axes as maxes
4747
from matplotlib.path import Path
48+
from matplotlib.transforms import Bbox
49+
4850
from mpl_toolkits.axes_grid1 import mpl_axes
4951
from .axisline_style import AxislineStyle # noqa
5052
from .axis_artist import AxisArtist, GridlinesCollection
@@ -285,10 +287,10 @@ def update_lim(self, axes):
285287
x1, x2 = axes.get_xlim()
286288
y1, y2 = axes.get_ylim()
287289
if self._old_limits != (x1, x2, y1, y2):
288-
self._update_grid(x1, y1, x2, y2)
290+
self._update_grid(Bbox.from_extents(x1, y1, x2, y2))
289291
self._old_limits = (x1, x2, y1, y2)
290292

291-
def _update_grid(self, x1, y1, x2, y2):
293+
def _update_grid(self, bbox):
292294
"""Cache relevant computations when the axes limits have changed."""
293295

294296
def get_gridlines(self, which, axis):

lib/mpl_toolkits/axisartist/floating_axes.py

Lines changed: 17 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,8 @@
1313
from matplotlib import _api, cbook
1414
import matplotlib.patches as mpatches
1515
from matplotlib.path import Path
16-
16+
from matplotlib.transforms import Bbox
1717
from mpl_toolkits.axes_grid1.parasite_axes import host_axes_class_factory
18-
1918
from . import axislines, grid_helper_curvelinear
2019
from .axis_artist import AxisArtist
2120
from .grid_finder import ExtremeFinderSimple
@@ -109,8 +108,7 @@ def get_line(self, axes):
109108
right=("lon_lines0", 1),
110109
bottom=("lat_lines0", 0),
111110
top=("lat_lines0", 1))[self._side]
112-
xx, yy = self._grid_info[k][v]
113-
return Path(np.column_stack([xx, yy]))
111+
return Path(self._grid_info[k][v])
114112

115113

116114
class ExtremeFinderFixed(ExtremeFinderSimple):
@@ -125,11 +123,12 @@ def __init__(self, extremes):
125123
extremes : (float, float, float, float)
126124
The bounding box that this helper always returns.
127125
"""
128-
self._extremes = extremes
126+
x0, x1, y0, y1 = extremes
127+
self._tbbox = Bbox.from_extents(x0, y0, x1, y1)
129128

130-
def __call__(self, transform_xy, x1, y1, x2, y2):
129+
def _find_transformed_bbox(self, trans, bbox):
131130
# docstring inherited
132-
return self._extremes
131+
return self._tbbox
133132

134133

135134
class GridHelperCurveLinear(grid_helper_curvelinear.GridHelperCurveLinear):
@@ -177,25 +176,22 @@ def new_fixed_axis(
177176
# axis.get_helper().set_extremes(*self._extremes[2:])
178177
# return axis
179178

180-
def _update_grid(self, x1, y1, x2, y2):
179+
def _update_grid(self, bbox):
181180
if self._grid_info is None:
182181
self._grid_info = dict()
183182

184183
grid_info = self._grid_info
185184

186185
grid_finder = self.grid_finder
187-
extremes = grid_finder.extreme_finder(grid_finder.inv_transform_xy,
188-
x1, y1, x2, y2)
186+
tbbox = grid_finder.extreme_finder._find_transformed_bbox(
187+
grid_finder.get_transform().inverted(), bbox)
189188

190-
lon_min, lon_max = sorted(extremes[:2])
191-
lat_min, lat_max = sorted(extremes[2:])
192-
grid_info["extremes"] = lon_min, lon_max, lat_min, lat_max # extremes
189+
lon_min, lat_min, lon_max, lat_max = tbbox.extents
190+
grid_info["extremes"] = tbbox
193191

194-
lon_levs, lon_n, lon_factor = \
195-
grid_finder.grid_locator1(lon_min, lon_max)
192+
lon_levs, lon_n, lon_factor = grid_finder.grid_locator1(lon_min, lon_max)
196193
lon_levs = np.asarray(lon_levs)
197-
lat_levs, lat_n, lat_factor = \
198-
grid_finder.grid_locator2(lat_min, lat_max)
194+
lat_levs, lat_n, lat_factor = grid_finder.grid_locator2(lat_min, lat_max)
199195
lat_levs = np.asarray(lat_levs)
200196

201197
grid_info["lon_info"] = lon_levs, lon_n, lon_factor
@@ -212,24 +208,23 @@ def _update_grid(self, x1, y1, x2, y2):
212208
lon_lines, lat_lines = grid_finder._get_raw_grid_lines(
213209
lon_values[(lon_min < lon_values) & (lon_values < lon_max)],
214210
lat_values[(lat_min < lat_values) & (lat_values < lat_max)],
215-
lon_min, lon_max, lat_min, lat_max)
211+
tbbox)
216212

217213
grid_info["lon_lines"] = lon_lines
218214
grid_info["lat_lines"] = lat_lines
219215

220216
lon_lines, lat_lines = grid_finder._get_raw_grid_lines(
221-
# lon_min, lon_max, lat_min, lat_max)
222-
extremes[:2], extremes[2:], *extremes)
217+
tbbox.intervalx, tbbox.intervaly, tbbox)
223218

224219
grid_info["lon_lines0"] = lon_lines
225220
grid_info["lat_lines0"] = lat_lines
226221

227222
def get_gridlines(self, which="major", axis="both"):
228223
grid_lines = []
229224
if axis in ["both", "x"]:
230-
grid_lines.extend(self._grid_info["lon_lines"])
225+
grid_lines.extend(map(np.transpose, self._grid_info["lon_lines"]))
231226
if axis in ["both", "y"]:
232-
grid_lines.extend(self._grid_info["lat_lines"])
227+
grid_lines.extend(map(np.transpose, self._grid_info["lat_lines"]))
233228
return grid_lines
234229

235230

lib/mpl_toolkits/axisartist/grid_finder.py

Lines changed: 32 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -77,20 +77,21 @@ def __call__(self, transform_xy, x1, y1, x2, y2):
7777
extremal coordinates; then adding some padding to take into account the
7878
finite sampling.
7979
80-
As each sampling step covers a relative range of *1/nx* or *1/ny*,
80+
As each sampling step covers a relative range of ``1/nx`` or ``1/ny``,
8181
the padding is computed by expanding the span covered by the extremal
8282
coordinates by these fractions.
8383
"""
84-
x, y = np.meshgrid(
85-
np.linspace(x1, x2, self.nx), np.linspace(y1, y2, self.ny))
86-
xt, yt = transform_xy(np.ravel(x), np.ravel(y))
87-
return self._add_pad(xt.min(), xt.max(), yt.min(), yt.max())
84+
tbbox = self._find_transformed_bbox(
85+
_User2DTransform(transform_xy, None), Bbox.from_extents(x1, y1, x2, y2))
86+
return tbbox.x0, tbbox.x1, tbbox.y0, tbbox.y1
8887

89-
def _add_pad(self, x_min, x_max, y_min, y_max):
90-
"""Perform the padding mentioned in `__call__`."""
91-
dx = (x_max - x_min) / self.nx
92-
dy = (y_max - y_min) / self.ny
93-
return x_min - dx, x_max + dx, y_min - dy, y_max + dy
88+
def _find_transformed_bbox(self, trans, bbox):
89+
grid = np.reshape(np.meshgrid(np.linspace(bbox.x0, bbox.x1, self.nx),
90+
np.linspace(bbox.y0, bbox.y1, self.ny)),
91+
(2, -1)).T
92+
tbbox = Bbox.null()
93+
tbbox.update_from_data_xy(trans.transform(grid))
94+
return tbbox.expanded(1 + 2 / self.nx, 1 + 2 / self.ny)
9495

9596

9697
class _User2DTransform(Transform):
@@ -170,42 +171,31 @@ def get_grid_info(self, x1, y1, x2, y2):
170171
rough number of grids in each direction.
171172
"""
172173

173-
extremes = self.extreme_finder(self.inv_transform_xy, x1, y1, x2, y2)
174+
bbox = Bbox.from_extents(x1, y1, x2, y2)
175+
tbbox = self.extreme_finder._find_transformed_bbox(
176+
self.get_transform().inverted(), bbox)
174177

175-
# min & max rage of lat (or lon) for each grid line will be drawn.
176-
# i.e., gridline of lon=0 will be drawn from lat_min to lat_max.
178+
lon_levs, lon_n, lon_factor = self.grid_locator1(*tbbox.intervalx)
179+
lat_levs, lat_n, lat_factor = self.grid_locator2(*tbbox.intervaly)
177180

178-
lon_min, lon_max, lat_min, lat_max = extremes
179-
lon_levs, lon_n, lon_factor = self.grid_locator1(lon_min, lon_max)
180-
lon_levs = np.asarray(lon_levs)
181-
lat_levs, lat_n, lat_factor = self.grid_locator2(lat_min, lat_max)
182-
lat_levs = np.asarray(lat_levs)
181+
lon_values = np.asarray(lon_levs[:lon_n]) / lon_factor
182+
lat_values = np.asarray(lat_levs[:lat_n]) / lat_factor
183183

184-
lon_values = lon_levs[:lon_n] / lon_factor
185-
lat_values = lat_levs[:lat_n] / lat_factor
184+
lon_lines, lat_lines = self._get_raw_grid_lines(lon_values, lat_values, tbbox)
186185

187-
lon_lines, lat_lines = self._get_raw_grid_lines(lon_values,
188-
lat_values,
189-
lon_min, lon_max,
190-
lat_min, lat_max)
191-
192-
bb = Bbox.from_extents(x1, y1, x2, y2).expanded(1 + 2e-10, 1 + 2e-10)
193-
194-
grid_info = {
195-
"extremes": extremes,
196-
# "lon", "lat", filled below.
197-
}
186+
bbox_expanded = bbox.expanded(1 + 2e-10, 1 + 2e-10)
187+
grid_info = {"extremes": tbbox} # "lon", "lat" keys filled below.
198188

199189
for idx, lon_or_lat, levs, factor, values, lines in [
200190
(1, "lon", lon_levs, lon_factor, lon_values, lon_lines),
201191
(2, "lat", lat_levs, lat_factor, lat_values, lat_lines),
202192
]:
203193
grid_info[lon_or_lat] = gi = {
204-
"lines": [[l] for l in lines],
194+
"lines": lines,
205195
"ticks": {"left": [], "right": [], "bottom": [], "top": []},
206196
}
207-
for (lx, ly), v, level in zip(lines, values, levs):
208-
all_crossings = _find_line_box_crossings(np.column_stack([lx, ly]), bb)
197+
for xys, v, level in zip(lines, values, levs):
198+
all_crossings = _find_line_box_crossings(xys, bbox_expanded)
209199
for side, crossings in zip(
210200
["left", "right", "bottom", "top"], all_crossings):
211201
for crossing in crossings:
@@ -218,18 +208,14 @@ def get_grid_info(self, x1, y1, x2, y2):
218208

219209
return grid_info
220210

221-
def _get_raw_grid_lines(self,
222-
lon_values, lat_values,
223-
lon_min, lon_max, lat_min, lat_max):
224-
225-
lons_i = np.linspace(lon_min, lon_max, 100) # for interpolation
226-
lats_i = np.linspace(lat_min, lat_max, 100)
227-
228-
lon_lines = [self.transform_xy(np.full_like(lats_i, lon), lats_i)
211+
def _get_raw_grid_lines(self, lon_values, lat_values, bbox):
212+
trans = self.get_transform()
213+
lons = np.linspace(bbox.x0, bbox.x1, 100) # for interpolation
214+
lats = np.linspace(bbox.y0, bbox.y1, 100)
215+
lon_lines = [trans.transform(np.column_stack([np.full_like(lats, lon), lats]))
229216
for lon in lon_values]
230-
lat_lines = [self.transform_xy(lons_i, np.full_like(lons_i, lat))
217+
lat_lines = [trans.transform(np.column_stack([lons, np.full_like(lons, lat)]))
231218
for lat in lat_values]
232-
233219
return lon_lines, lat_lines
234220

235221
def set_transform(self, aux_trans):
@@ -246,9 +232,11 @@ def get_transform(self):
246232

247233
update_transform = set_transform # backcompat alias.
248234

235+
@_api.deprecated("3.11", alternative="grid_finder.get_transform()")
249236
def transform_xy(self, x, y):
250237
return self._aux_transform.transform(np.column_stack([x, y])).T
251238

239+
@_api.deprecated("3.11", alternative="grid_finder.get_transform().inverted()")
252240
def inv_transform_xy(self, x, y):
253241
return self._aux_transform.inverted().transform(
254242
np.column_stack([x, y])).T

0 commit comments

Comments
 (0)