Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 0bcd2af

Browse files
committed
Move axisartist towards untransposed transforms.
While Matplotlib normally represents lists of (x, y) coordinates as (N, 2) arrays and transforms (which we'll call "trans") have shape signature (N, 2) -> (N, 2), axisartist uses the opposite convention of using (2, N) arrays (or size-2 tuples of 1D arrays) and transforms (which it typically calls "transform_xy"). Change that and go back to Matplotlib's standard represenation in some of axisartist's internal representations for consistency. Also replace some uses of (x1, y1, x2, y2) quadruplets by single Bbox objects, which avoid having to keep track of the order of the points (is it x1, y1, x2, y2 or x1, x2, y1, y2?). - Add a `_find_transformed_bbox(trans, bbox)` API to ExtremeFinderSimple and its subclasses, replacing `__call__(transform_xy, x1, y1, x2, y2)`. (I intentionally did not overload `__call__`'s signature yet nor did I deprecate it for now; we can consider doing that later.) - Deprecate `GridFinder.{,inv_}transform_xy`, which implement the transposed transform API. - Switch `grid_info["extremes"]` from quadruplet representation to Bbox. - Switch `grid_info["lon"]["lines"]` and likewise for "lat" from list-of-size-1-lists-of-pairs-of-1D-arrays to list-of-(N, 2)-arrays. - Switch `grid_info["line_xy"]` from pair-of-1D-arrays to a (N, 2) array. - Let `_get_raw_grid_lines` take a Bbox as last argument instead of 4 coordinates.
1 parent 8f296db commit 0bcd2af

File tree

5 files changed

+79
-84
lines changed

5 files changed

+79
-84
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
``GridFinder.transform_xy`` and ``GridFinder.inv_transform_xy``
2+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
3+
... are deprecated. Directly use the standard transform returned by
4+
`.GridFinder.get_transform` instead.

lib/mpl_toolkits/axisartist/angle_helper.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22
import math
33

4+
from matplotlib.transforms import Bbox
45
from mpl_toolkits.axisartist.grid_finder import ExtremeFinderSimple
56

67

@@ -347,11 +348,12 @@ def __init__(self, nx, ny,
347348
self.lon_minmax = lon_minmax
348349
self.lat_minmax = lat_minmax
349350

350-
def __call__(self, transform_xy, x1, y1, x2, y2):
351+
def _find_transformed_bbox(self, trans, bbox):
351352
# docstring inherited
352-
x, y = np.meshgrid(
353-
np.linspace(x1, x2, self.nx), np.linspace(y1, y2, self.ny))
354-
lon, lat = transform_xy(np.ravel(x), np.ravel(y))
353+
grid = np.reshape(np.meshgrid(np.linspace(bbox.x0, bbox.x1, self.nx),
354+
np.linspace(bbox.y0, bbox.y1, self.ny)),
355+
(2, -1)).T
356+
lon, lat = trans.transform(grid).T
355357

356358
# iron out jumps, but algorithm should be improved.
357359
# This is just naive way of doing and my fail for some cases.
@@ -367,11 +369,10 @@ def __call__(self, transform_xy, x1, y1, x2, y2):
367369
lat0 = np.nanmin(lat)
368370
lat -= 360. * ((lat - lat0) > 180.)
369371

370-
lon_min, lon_max = np.nanmin(lon), np.nanmax(lon)
371-
lat_min, lat_max = np.nanmin(lat), np.nanmax(lat)
372-
373-
lon_min, lon_max, lat_min, lat_max = \
374-
self._add_pad(lon_min, lon_max, lat_min, lat_max)
372+
tbbox = Bbox.null()
373+
tbbox.update_from_data_xy(np.column_stack([lon, lat]))
374+
tbbox = tbbox.expanded(1 + 2 / self.nx, 1 + 2 / self.ny)
375+
lon_min, lat_min, lon_max, lat_max = tbbox.extents
375376

376377
# check cycle
377378
if self.lon_cycle:
@@ -391,4 +392,4 @@ def __call__(self, transform_xy, x1, y1, x2, y2):
391392
max0 = self.lat_minmax[1]
392393
lat_max = min(max0, lat_max)
393394

394-
return lon_min, lon_max, lat_min, lat_max
395+
return Bbox.from_extents(lon_min, lat_min, lon_max, lat_max)

lib/mpl_toolkits/axisartist/floating_axes.py

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,8 @@
1313
from matplotlib import _api, cbook
1414
import matplotlib.patches as mpatches
1515
from matplotlib.path import Path
16-
16+
from matplotlib.transforms import Bbox
1717
from mpl_toolkits.axes_grid1.parasite_axes import host_axes_class_factory
18-
1918
from . import axislines, grid_helper_curvelinear
2019
from .axis_artist import AxisArtist
2120
from .grid_finder import ExtremeFinderSimple
@@ -109,8 +108,7 @@ def get_line(self, axes):
109108
right=("lon_lines0", 1),
110109
bottom=("lat_lines0", 0),
111110
top=("lat_lines0", 1))[self._side]
112-
xx, yy = self._grid_info[k][v]
113-
return Path(np.column_stack([xx, yy]))
111+
return Path(self._grid_info[k][v])
114112

115113

116114
class ExtremeFinderFixed(ExtremeFinderSimple):
@@ -125,11 +123,12 @@ def __init__(self, extremes):
125123
extremes : (float, float, float, float)
126124
The bounding box that this helper always returns.
127125
"""
128-
self._extremes = extremes
126+
x0, x1, y0, y1 = extremes
127+
self._tbbox = Bbox.from_extents(x0, y0, x1, y1)
129128

130-
def __call__(self, transform_xy, x1, y1, x2, y2):
129+
def _find_transformed_bbox(self, trans, bbox):
131130
# docstring inherited
132-
return self._extremes
131+
return self._tbbox
133132

134133

135134
class GridHelperCurveLinear(grid_helper_curvelinear.GridHelperCurveLinear):
@@ -195,18 +194,16 @@ def _update_grid(self, x1, y1, x2, y2):
195194
grid_info = self._grid_info
196195

197196
grid_finder = self.grid_finder
198-
extremes = grid_finder.extreme_finder(grid_finder.inv_transform_xy,
199-
x1, y1, x2, y2)
197+
tbbox = grid_finder.extreme_finder._find_transformed_bbox(
198+
grid_finder.get_transform().inverted(), Bbox.from_extents(x1, y1, x2, y2))
200199

201-
lon_min, lon_max = sorted(extremes[:2])
202-
lat_min, lat_max = sorted(extremes[2:])
203-
grid_info["extremes"] = lon_min, lon_max, lat_min, lat_max # extremes
200+
lon_min, lat_min = tbbox.min
201+
lon_max, lat_max = tbbox.max
202+
grid_info["extremes"] = Bbox.from_extents(lon_min, lat_min, lon_max, lat_max)
204203

205-
lon_levs, lon_n, lon_factor = \
206-
grid_finder.grid_locator1(lon_min, lon_max)
204+
lon_levs, lon_n, lon_factor = grid_finder.grid_locator1(lon_min, lon_max)
207205
lon_levs = np.asarray(lon_levs)
208-
lat_levs, lat_n, lat_factor = \
209-
grid_finder.grid_locator2(lat_min, lat_max)
206+
lat_levs, lat_n, lat_factor = grid_finder.grid_locator2(lat_min, lat_max)
210207
lat_levs = np.asarray(lat_levs)
211208

212209
grid_info["lon_info"] = lon_levs, lon_n, lon_factor
@@ -223,24 +220,23 @@ def _update_grid(self, x1, y1, x2, y2):
223220
lon_lines, lat_lines = grid_finder._get_raw_grid_lines(
224221
lon_values[(lon_min < lon_values) & (lon_values < lon_max)],
225222
lat_values[(lat_min < lat_values) & (lat_values < lat_max)],
226-
lon_min, lon_max, lat_min, lat_max)
223+
tbbox)
227224

228225
grid_info["lon_lines"] = lon_lines
229226
grid_info["lat_lines"] = lat_lines
230227

231228
lon_lines, lat_lines = grid_finder._get_raw_grid_lines(
232-
# lon_min, lon_max, lat_min, lat_max)
233-
extremes[:2], extremes[2:], *extremes)
229+
tbbox.intervalx, tbbox.intervaly, tbbox)
234230

235231
grid_info["lon_lines0"] = lon_lines
236232
grid_info["lat_lines0"] = lat_lines
237233

238234
def get_gridlines(self, which="major", axis="both"):
239235
grid_lines = []
240236
if axis in ["both", "x"]:
241-
grid_lines.extend(self._grid_info["lon_lines"])
237+
grid_lines.extend([xys.T for xys in self._grid_info["lon_lines"]])
242238
if axis in ["both", "y"]:
243-
grid_lines.extend(self._grid_info["lat_lines"])
239+
grid_lines.extend([xys.T for xys in self._grid_info["lat_lines"]])
244240
return grid_lines
245241

246242

lib/mpl_toolkits/axisartist/grid_finder.py

Lines changed: 27 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -77,20 +77,21 @@ def __call__(self, transform_xy, x1, y1, x2, y2):
7777
extremal coordinates; then adding some padding to take into account the
7878
finite sampling.
7979
80-
As each sampling step covers a relative range of *1/nx* or *1/ny*,
80+
As each sampling step covers a relative range of ``1/nx`` or ``1/ny``,
8181
the padding is computed by expanding the span covered by the extremal
8282
coordinates by these fractions.
8383
"""
84-
x, y = np.meshgrid(
85-
np.linspace(x1, x2, self.nx), np.linspace(y1, y2, self.ny))
86-
xt, yt = transform_xy(np.ravel(x), np.ravel(y))
87-
return self._add_pad(xt.min(), xt.max(), yt.min(), yt.max())
84+
tbbox = self._find_transformed_bbox(
85+
_User2DTransform(transform_xy, None), Bbox.from_extents(x1, y1, x2, y2))
86+
return tbbox.x0, tbbox.x1, tbbox.y0, tbbox.y1
8887

89-
def _add_pad(self, x_min, x_max, y_min, y_max):
90-
"""Perform the padding mentioned in `__call__`."""
91-
dx = (x_max - x_min) / self.nx
92-
dy = (y_max - y_min) / self.ny
93-
return x_min - dx, x_max + dx, y_min - dy, y_max + dy
88+
def _find_transformed_bbox(self, trans, bbox):
89+
grid = np.reshape(np.meshgrid(np.linspace(bbox.x0, bbox.x1, self.nx),
90+
np.linspace(bbox.y0, bbox.y1, self.ny)),
91+
(2, -1)).T
92+
tbbox = Bbox.null()
93+
tbbox.update_from_data_xy(trans.transform(grid))
94+
return tbbox.expanded(1 + 2 / self.nx, 1 + 2 / self.ny)
9495

9596

9697
class _User2DTransform(Transform):
@@ -170,12 +171,13 @@ def get_grid_info(self, x1, y1, x2, y2):
170171
rough number of grids in each direction.
171172
"""
172173

173-
extremes = self.extreme_finder(self.inv_transform_xy, x1, y1, x2, y2)
174+
tbbox = self.extreme_finder._find_transformed_bbox(
175+
self.get_transform().inverted(), Bbox.from_extents(x1, y1, x2, y2))
174176

175177
# min & max rage of lat (or lon) for each grid line will be drawn.
176178
# i.e., gridline of lon=0 will be drawn from lat_min to lat_max.
177179

178-
lon_min, lon_max, lat_min, lat_max = extremes
180+
lon_min, lat_min, lon_max, lat_max = tbbox.extents
179181
lon_levs, lon_n, lon_factor = self.grid_locator1(lon_min, lon_max)
180182
lon_levs = np.asarray(lon_levs)
181183
lat_levs, lat_n, lat_factor = self.grid_locator2(lat_min, lat_max)
@@ -184,15 +186,12 @@ def get_grid_info(self, x1, y1, x2, y2):
184186
lon_values = lon_levs[:lon_n] / lon_factor
185187
lat_values = lat_levs[:lat_n] / lat_factor
186188

187-
lon_lines, lat_lines = self._get_raw_grid_lines(lon_values,
188-
lat_values,
189-
lon_min, lon_max,
190-
lat_min, lat_max)
189+
lon_lines, lat_lines = self._get_raw_grid_lines(lon_values, lat_values, tbbox)
191190

192191
bb = Bbox.from_extents(x1, y1, x2, y2).expanded(1 + 2e-10, 1 + 2e-10)
193192

194193
grid_info = {
195-
"extremes": extremes,
194+
"extremes": tbbox,
196195
# "lon", "lat", filled below.
197196
}
198197

@@ -201,11 +200,11 @@ def get_grid_info(self, x1, y1, x2, y2):
201200
(2, "lat", lat_levs, lat_factor, lat_values, lat_lines),
202201
]:
203202
grid_info[lon_or_lat] = gi = {
204-
"lines": [[l] for l in lines],
203+
"lines": lines,
205204
"ticks": {"left": [], "right": [], "bottom": [], "top": []},
206205
}
207-
for (lx, ly), v, level in zip(lines, values, levs):
208-
all_crossings = _find_line_box_crossings(np.column_stack([lx, ly]), bb)
206+
for xys, v, level in zip(lines, values, levs):
207+
all_crossings = _find_line_box_crossings(xys, bb)
209208
for side, crossings in zip(
210209
["left", "right", "bottom", "top"], all_crossings):
211210
for crossing in crossings:
@@ -218,18 +217,14 @@ def get_grid_info(self, x1, y1, x2, y2):
218217

219218
return grid_info
220219

221-
def _get_raw_grid_lines(self,
222-
lon_values, lat_values,
223-
lon_min, lon_max, lat_min, lat_max):
224-
225-
lons_i = np.linspace(lon_min, lon_max, 100) # for interpolation
226-
lats_i = np.linspace(lat_min, lat_max, 100)
227-
228-
lon_lines = [self.transform_xy(np.full_like(lats_i, lon), lats_i)
220+
def _get_raw_grid_lines(self, lon_values, lat_values, bbox):
221+
trans = self.get_transform()
222+
lons = np.linspace(bbox.x0, bbox.x1, 100) # for interpolation
223+
lats = np.linspace(bbox.y0, bbox.y1, 100)
224+
lon_lines = [trans.transform(np.column_stack([np.full_like(lats, lon), lats]))
229225
for lon in lon_values]
230-
lat_lines = [self.transform_xy(lons_i, np.full_like(lons_i, lat))
226+
lat_lines = [trans.transform(np.column_stack([lons, np.full_like(lons, lat)]))
231227
for lat in lat_values]
232-
233228
return lon_lines, lat_lines
234229

235230
def set_transform(self, aux_trans):
@@ -246,9 +241,11 @@ def get_transform(self):
246241

247242
update_transform = set_transform # backcompat alias.
248243

244+
@_api.deprecated("3.9", alternative="grid_finder.get_transform()")
249245
def transform_xy(self, x, y):
250246
return self._aux_transform.transform(np.column_stack([x, y])).T
251247

248+
@_api.deprecated("3.9", alternative="grid_finder.get_transform().inverted()")
252249
def inv_transform_xy(self, x, y):
253250
return self._aux_transform.inverted().transform(
254251
np.column_stack([x, y])).T

lib/mpl_toolkits/axisartist/grid_helper_curvelinear.py

Lines changed: 20 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import matplotlib as mpl
1010
from matplotlib import _api
1111
from matplotlib.path import Path
12-
from matplotlib.transforms import Affine2D, IdentityTransform
12+
from matplotlib.transforms import Affine2D, Bbox, IdentityTransform
1313
from .axislines import (
1414
_FixedAxisArtistHelperBase, _FloatingAxisArtistHelperBase, GridHelperBase)
1515
from .axis_artist import AxisArtist
@@ -115,10 +115,10 @@ def update_lim(self, axes):
115115
x1, x2 = axes.get_xlim()
116116
y1, y2 = axes.get_ylim()
117117
grid_finder = self.grid_helper.grid_finder
118-
extremes = grid_finder.extreme_finder(grid_finder.inv_transform_xy,
119-
x1, y1, x2, y2)
118+
tbbox = grid_finder.extreme_finder._find_transformed_bbox(
119+
grid_finder.get_transform().inverted(), Bbox.from_extents(x1, y1, x2, y2))
120120

121-
lon_min, lon_max, lat_min, lat_max = extremes
121+
lon_min, lat_min, lon_max, lat_max = tbbox.extents
122122
e_min, e_max = self._extremes # ranges of other coordinates
123123
if self.nth_coord == 0:
124124
lat_min = max(e_min, lat_min)
@@ -127,29 +127,29 @@ def update_lim(self, axes):
127127
lon_min = max(e_min, lon_min)
128128
lon_max = min(e_max, lon_max)
129129

130-
lon_levs, lon_n, lon_factor = \
131-
grid_finder.grid_locator1(lon_min, lon_max)
132-
lat_levs, lat_n, lat_factor = \
133-
grid_finder.grid_locator2(lat_min, lat_max)
130+
lon_levs, lon_n, lon_factor = grid_finder.grid_locator1(lon_min, lon_max)
131+
lat_levs, lat_n, lat_factor = grid_finder.grid_locator2(lat_min, lat_max)
134132

135133
if self.nth_coord == 0:
136-
xx0 = np.full(self._line_num_points, self.value)
137-
yy0 = np.linspace(lat_min, lat_max, self._line_num_points)
138-
xx, yy = grid_finder.transform_xy(xx0, yy0)
134+
xys = grid_finder.get_transform().transform(np.column_stack([
135+
np.full(self._line_num_points, self.value),
136+
np.linspace(lat_min, lat_max, self._line_num_points),
137+
]))
139138
elif self.nth_coord == 1:
140-
xx0 = np.linspace(lon_min, lon_max, self._line_num_points)
141-
yy0 = np.full(self._line_num_points, self.value)
142-
xx, yy = grid_finder.transform_xy(xx0, yy0)
139+
xys = grid_finder.get_transform().transform(np.column_stack([
140+
np.linspace(lon_min, lon_max, self._line_num_points),
141+
np.full(self._line_num_points, self.value),
142+
]))
143143

144144
self._grid_info = {
145-
"extremes": (lon_min, lon_max, lat_min, lat_max),
145+
"extremes": Bbox.from_extents(lon_min, lat_min, lon_max, lat_max),
146146
"lon_info": (lon_levs, lon_n, np.asarray(lon_factor)),
147147
"lat_info": (lat_levs, lat_n, np.asarray(lat_factor)),
148148
"lon_labels": grid_finder._format_ticks(
149149
1, "bottom", lon_factor, lon_levs),
150150
"lat_labels": grid_finder._format_ticks(
151151
2, "bottom", lat_factor, lat_levs),
152-
"line_xy": (xx, yy),
152+
"line_xy": xys,
153153
}
154154

155155
def get_axislabel_transform(self, axes):
@@ -160,7 +160,7 @@ def trf_xy(x, y):
160160
trf = self.grid_helper.grid_finder.get_transform() + axes.transData
161161
return trf.transform([x, y]).T
162162

163-
xmin, xmax, ymin, ymax = self._grid_info["extremes"]
163+
xmin, ymin, xmax, ymax = self._grid_info["extremes"].extents
164164
if self.nth_coord == 0:
165165
xx0 = self.value
166166
yy0 = (ymin + ymax) / 2
@@ -232,8 +232,7 @@ def get_line_transform(self, axes):
232232

233233
def get_line(self, axes):
234234
self.update_lim(axes)
235-
x, y = self._grid_info["line_xy"]
236-
return Path(np.column_stack([x, y]))
235+
return Path(self._grid_info["line_xy"])
237236

238237

239238
class GridHelperCurveLinear(GridHelperBase):
@@ -309,11 +308,9 @@ def _update_grid(self, x1, y1, x2, y2):
309308
def get_gridlines(self, which="major", axis="both"):
310309
grid_lines = []
311310
if axis in ["both", "x"]:
312-
for gl in self._grid_info["lon"]["lines"]:
313-
grid_lines.extend(gl)
311+
grid_lines.extend([gl.T for gl in self._grid_info["lon"]["lines"]])
314312
if axis in ["both", "y"]:
315-
for gl in self._grid_info["lat"]["lines"]:
316-
grid_lines.extend(gl)
313+
grid_lines.extend([gl.T for gl in self._grid_info["lat"]["lines"]])
317314
return grid_lines
318315

319316
@_api.deprecated("3.9")

0 commit comments

Comments
 (0)