Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 82adc45

Browse files
authored
Merge pull request #27551 from anntzer/aaut
Move axisartist towards untransposed transforms (operating on (N, 2) arrays instead of (2, N) arrays).
2 parents 05fc1b3 + a83af6a commit 82adc45

File tree

6 files changed

+98
-103
lines changed

6 files changed

+98
-103
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
``GridFinder.transform_xy`` and ``GridFinder.inv_transform_xy``
2+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
3+
... are deprecated. Directly use the standard transform returned by
4+
`.GridFinder.get_transform` instead.

lib/mpl_toolkits/axisartist/angle_helper.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22
import math
33

4+
from matplotlib.transforms import Bbox
45
from mpl_toolkits.axisartist.grid_finder import ExtremeFinderSimple
56

67

@@ -347,11 +348,12 @@ def __init__(self, nx, ny,
347348
self.lon_minmax = lon_minmax
348349
self.lat_minmax = lat_minmax
349350

350-
def __call__(self, transform_xy, x1, y1, x2, y2):
351+
def _find_transformed_bbox(self, trans, bbox):
351352
# docstring inherited
352-
x, y = np.meshgrid(
353-
np.linspace(x1, x2, self.nx), np.linspace(y1, y2, self.ny))
354-
lon, lat = transform_xy(np.ravel(x), np.ravel(y))
353+
grid = np.reshape(np.meshgrid(np.linspace(bbox.x0, bbox.x1, self.nx),
354+
np.linspace(bbox.y0, bbox.y1, self.ny)),
355+
(2, -1)).T
356+
lon, lat = trans.transform(grid).T
355357

356358
# iron out jumps, but algorithm should be improved.
357359
# This is just naive way of doing and my fail for some cases.
@@ -367,11 +369,10 @@ def __call__(self, transform_xy, x1, y1, x2, y2):
367369
lat0 = np.nanmin(lat)
368370
lat -= 360. * ((lat - lat0) > 180.)
369371

370-
lon_min, lon_max = np.nanmin(lon), np.nanmax(lon)
371-
lat_min, lat_max = np.nanmin(lat), np.nanmax(lat)
372-
373-
lon_min, lon_max, lat_min, lat_max = \
374-
self._add_pad(lon_min, lon_max, lat_min, lat_max)
372+
tbbox = Bbox.null()
373+
tbbox.update_from_data_xy(np.column_stack([lon, lat]))
374+
tbbox = tbbox.expanded(1 + 2 / self.nx, 1 + 2 / self.ny)
375+
lon_min, lat_min, lon_max, lat_max = tbbox.extents
375376

376377
# check cycle
377378
if self.lon_cycle:
@@ -391,4 +392,4 @@ def __call__(self, transform_xy, x1, y1, x2, y2):
391392
max0 = self.lat_minmax[1]
392393
lat_max = min(max0, lat_max)
393394

394-
return lon_min, lon_max, lat_min, lat_max
395+
return Bbox.from_extents(lon_min, lat_min, lon_max, lat_max)

lib/mpl_toolkits/axisartist/axislines.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@
4545
from matplotlib import _api
4646
import matplotlib.axes as maxes
4747
from matplotlib.path import Path
48+
from matplotlib.transforms import Bbox
49+
4850
from mpl_toolkits.axes_grid1 import mpl_axes
4951
from .axisline_style import AxislineStyle # noqa
5052
from .axis_artist import AxisArtist, GridlinesCollection
@@ -285,10 +287,10 @@ def update_lim(self, axes):
285287
x1, x2 = axes.get_xlim()
286288
y1, y2 = axes.get_ylim()
287289
if self._old_limits != (x1, x2, y1, y2):
288-
self._update_grid(x1, y1, x2, y2)
290+
self._update_grid(Bbox.from_extents(x1, y1, x2, y2))
289291
self._old_limits = (x1, x2, y1, y2)
290292

291-
def _update_grid(self, x1, y1, x2, y2):
293+
def _update_grid(self, bbox):
292294
"""Cache relevant computations when the axes limits have changed."""
293295

294296
def get_gridlines(self, which, axis):

lib/mpl_toolkits/axisartist/floating_axes.py

Lines changed: 17 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,8 @@
1313
from matplotlib import _api, cbook
1414
import matplotlib.patches as mpatches
1515
from matplotlib.path import Path
16-
16+
from matplotlib.transforms import Bbox
1717
from mpl_toolkits.axes_grid1.parasite_axes import host_axes_class_factory
18-
1918
from . import axislines, grid_helper_curvelinear
2019
from .axis_artist import AxisArtist
2120
from .grid_finder import ExtremeFinderSimple
@@ -103,8 +102,7 @@ def get_line(self, axes):
103102
right=("lon_lines0", 1),
104103
bottom=("lat_lines0", 0),
105104
top=("lat_lines0", 1))[self._side]
106-
xx, yy = self._grid_info[k][v]
107-
return Path(np.column_stack([xx, yy]))
105+
return Path(self._grid_info[k][v])
108106

109107

110108
class ExtremeFinderFixed(ExtremeFinderSimple):
@@ -119,11 +117,12 @@ def __init__(self, extremes):
119117
extremes : (float, float, float, float)
120118
The bounding box that this helper always returns.
121119
"""
122-
self._extremes = extremes
120+
x0, x1, y0, y1 = extremes
121+
self._tbbox = Bbox.from_extents(x0, y0, x1, y1)
123122

124-
def __call__(self, transform_xy, x1, y1, x2, y2):
123+
def _find_transformed_bbox(self, trans, bbox):
125124
# docstring inherited
126-
return self._extremes
125+
return self._tbbox
127126

128127

129128
class GridHelperCurveLinear(grid_helper_curvelinear.GridHelperCurveLinear):
@@ -171,25 +170,22 @@ def new_fixed_axis(
171170
# axis.get_helper().set_extremes(*self._extremes[2:])
172171
# return axis
173172

174-
def _update_grid(self, x1, y1, x2, y2):
173+
def _update_grid(self, bbox):
175174
if self._grid_info is None:
176175
self._grid_info = dict()
177176

178177
grid_info = self._grid_info
179178

180179
grid_finder = self.grid_finder
181-
extremes = grid_finder.extreme_finder(grid_finder.inv_transform_xy,
182-
x1, y1, x2, y2)
180+
tbbox = grid_finder.extreme_finder._find_transformed_bbox(
181+
grid_finder.get_transform().inverted(), bbox)
183182

184-
lon_min, lon_max = sorted(extremes[:2])
185-
lat_min, lat_max = sorted(extremes[2:])
186-
grid_info["extremes"] = lon_min, lon_max, lat_min, lat_max # extremes
183+
lon_min, lat_min, lon_max, lat_max = tbbox.extents
184+
grid_info["extremes"] = tbbox
187185

188-
lon_levs, lon_n, lon_factor = \
189-
grid_finder.grid_locator1(lon_min, lon_max)
186+
lon_levs, lon_n, lon_factor = grid_finder.grid_locator1(lon_min, lon_max)
190187
lon_levs = np.asarray(lon_levs)
191-
lat_levs, lat_n, lat_factor = \
192-
grid_finder.grid_locator2(lat_min, lat_max)
188+
lat_levs, lat_n, lat_factor = grid_finder.grid_locator2(lat_min, lat_max)
193189
lat_levs = np.asarray(lat_levs)
194190

195191
grid_info["lon_info"] = lon_levs, lon_n, lon_factor
@@ -206,24 +202,23 @@ def _update_grid(self, x1, y1, x2, y2):
206202
lon_lines, lat_lines = grid_finder._get_raw_grid_lines(
207203
lon_values[(lon_min < lon_values) & (lon_values < lon_max)],
208204
lat_values[(lat_min < lat_values) & (lat_values < lat_max)],
209-
lon_min, lon_max, lat_min, lat_max)
205+
tbbox)
210206

211207
grid_info["lon_lines"] = lon_lines
212208
grid_info["lat_lines"] = lat_lines
213209

214210
lon_lines, lat_lines = grid_finder._get_raw_grid_lines(
215-
# lon_min, lon_max, lat_min, lat_max)
216-
extremes[:2], extremes[2:], *extremes)
211+
tbbox.intervalx, tbbox.intervaly, tbbox)
217212

218213
grid_info["lon_lines0"] = lon_lines
219214
grid_info["lat_lines0"] = lat_lines
220215

221216
def get_gridlines(self, which="major", axis="both"):
222217
grid_lines = []
223218
if axis in ["both", "x"]:
224-
grid_lines.extend(self._grid_info["lon_lines"])
219+
grid_lines.extend(map(np.transpose, self._grid_info["lon_lines"]))
225220
if axis in ["both", "y"]:
226-
grid_lines.extend(self._grid_info["lat_lines"])
221+
grid_lines.extend(map(np.transpose, self._grid_info["lat_lines"]))
227222
return grid_lines
228223

229224

lib/mpl_toolkits/axisartist/grid_finder.py

Lines changed: 40 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -77,20 +77,29 @@ def __call__(self, transform_xy, x1, y1, x2, y2):
7777
extremal coordinates; then adding some padding to take into account the
7878
finite sampling.
7979
80-
As each sampling step covers a relative range of *1/nx* or *1/ny*,
80+
As each sampling step covers a relative range of ``1/nx`` or ``1/ny``,
8181
the padding is computed by expanding the span covered by the extremal
8282
coordinates by these fractions.
8383
"""
84-
x, y = np.meshgrid(
85-
np.linspace(x1, x2, self.nx), np.linspace(y1, y2, self.ny))
86-
xt, yt = transform_xy(np.ravel(x), np.ravel(y))
87-
return self._add_pad(xt.min(), xt.max(), yt.min(), yt.max())
84+
tbbox = self._find_transformed_bbox(
85+
_User2DTransform(transform_xy, None), Bbox.from_extents(x1, y1, x2, y2))
86+
return tbbox.x0, tbbox.x1, tbbox.y0, tbbox.y1
8887

89-
def _add_pad(self, x_min, x_max, y_min, y_max):
90-
"""Perform the padding mentioned in `__call__`."""
91-
dx = (x_max - x_min) / self.nx
92-
dy = (y_max - y_min) / self.ny
93-
return x_min - dx, x_max + dx, y_min - dy, y_max + dy
88+
def _find_transformed_bbox(self, trans, bbox):
89+
"""
90+
Compute an approximation of the bounding box obtained by applying
91+
*trans* to *bbox*.
92+
93+
See ``__call__`` for details; this method performs similar
94+
calculations, but using a different representation of the arguments and
95+
return value.
96+
"""
97+
grid = np.reshape(np.meshgrid(np.linspace(bbox.x0, bbox.x1, self.nx),
98+
np.linspace(bbox.y0, bbox.y1, self.ny)),
99+
(2, -1)).T
100+
tbbox = Bbox.null()
101+
tbbox.update_from_data_xy(trans.transform(grid))
102+
return tbbox.expanded(1 + 2 / self.nx, 1 + 2 / self.ny)
94103

95104

96105
class _User2DTransform(Transform):
@@ -170,42 +179,31 @@ def get_grid_info(self, x1, y1, x2, y2):
170179
rough number of grids in each direction.
171180
"""
172181

173-
extremes = self.extreme_finder(self.inv_transform_xy, x1, y1, x2, y2)
174-
175-
# min & max rage of lat (or lon) for each grid line will be drawn.
176-
# i.e., gridline of lon=0 will be drawn from lat_min to lat_max.
182+
bbox = Bbox.from_extents(x1, y1, x2, y2)
183+
tbbox = self.extreme_finder._find_transformed_bbox(
184+
self.get_transform().inverted(), bbox)
177185

178-
lon_min, lon_max, lat_min, lat_max = extremes
179-
lon_levs, lon_n, lon_factor = self.grid_locator1(lon_min, lon_max)
180-
lon_levs = np.asarray(lon_levs)
181-
lat_levs, lat_n, lat_factor = self.grid_locator2(lat_min, lat_max)
182-
lat_levs = np.asarray(lat_levs)
186+
lon_levs, lon_n, lon_factor = self.grid_locator1(*tbbox.intervalx)
187+
lat_levs, lat_n, lat_factor = self.grid_locator2(*tbbox.intervaly)
183188

184-
lon_values = lon_levs[:lon_n] / lon_factor
185-
lat_values = lat_levs[:lat_n] / lat_factor
189+
lon_values = np.asarray(lon_levs[:lon_n]) / lon_factor
190+
lat_values = np.asarray(lat_levs[:lat_n]) / lat_factor
186191

187-
lon_lines, lat_lines = self._get_raw_grid_lines(lon_values,
188-
lat_values,
189-
lon_min, lon_max,
190-
lat_min, lat_max)
192+
lon_lines, lat_lines = self._get_raw_grid_lines(lon_values, lat_values, tbbox)
191193

192-
bb = Bbox.from_extents(x1, y1, x2, y2).expanded(1 + 2e-10, 1 + 2e-10)
193-
194-
grid_info = {
195-
"extremes": extremes,
196-
# "lon", "lat", filled below.
197-
}
194+
bbox_expanded = bbox.expanded(1 + 2e-10, 1 + 2e-10)
195+
grid_info = {"extremes": tbbox} # "lon", "lat" keys filled below.
198196

199197
for idx, lon_or_lat, levs, factor, values, lines in [
200198
(1, "lon", lon_levs, lon_factor, lon_values, lon_lines),
201199
(2, "lat", lat_levs, lat_factor, lat_values, lat_lines),
202200
]:
203201
grid_info[lon_or_lat] = gi = {
204-
"lines": [[l] for l in lines],
202+
"lines": lines,
205203
"ticks": {"left": [], "right": [], "bottom": [], "top": []},
206204
}
207-
for (lx, ly), v, level in zip(lines, values, levs):
208-
all_crossings = _find_line_box_crossings(np.column_stack([lx, ly]), bb)
205+
for xys, v, level in zip(lines, values, levs):
206+
all_crossings = _find_line_box_crossings(xys, bbox_expanded)
209207
for side, crossings in zip(
210208
["left", "right", "bottom", "top"], all_crossings):
211209
for crossing in crossings:
@@ -218,18 +216,14 @@ def get_grid_info(self, x1, y1, x2, y2):
218216

219217
return grid_info
220218

221-
def _get_raw_grid_lines(self,
222-
lon_values, lat_values,
223-
lon_min, lon_max, lat_min, lat_max):
224-
225-
lons_i = np.linspace(lon_min, lon_max, 100) # for interpolation
226-
lats_i = np.linspace(lat_min, lat_max, 100)
227-
228-
lon_lines = [self.transform_xy(np.full_like(lats_i, lon), lats_i)
219+
def _get_raw_grid_lines(self, lon_values, lat_values, bbox):
220+
trans = self.get_transform()
221+
lons = np.linspace(bbox.x0, bbox.x1, 100) # for interpolation
222+
lats = np.linspace(bbox.y0, bbox.y1, 100)
223+
lon_lines = [trans.transform(np.column_stack([np.full_like(lats, lon), lats]))
229224
for lon in lon_values]
230-
lat_lines = [self.transform_xy(lons_i, np.full_like(lons_i, lat))
225+
lat_lines = [trans.transform(np.column_stack([lons, np.full_like(lons, lat)]))
231226
for lat in lat_values]
232-
233227
return lon_lines, lat_lines
234228

235229
def set_transform(self, aux_trans):
@@ -246,9 +240,11 @@ def get_transform(self):
246240

247241
update_transform = set_transform # backcompat alias.
248242

243+
@_api.deprecated("3.11", alternative="grid_finder.get_transform()")
249244
def transform_xy(self, x, y):
250245
return self._aux_transform.transform(np.column_stack([x, y])).T
251246

247+
@_api.deprecated("3.11", alternative="grid_finder.get_transform().inverted()")
252248
def inv_transform_xy(self, x, y):
253249
return self._aux_transform.inverted().transform(
254250
np.column_stack([x, y])).T

0 commit comments

Comments
 (0)