Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 500c9bd

Browse files
committed
Add a helper to generate xy coordinates for AxisArtistHelper.
AxisArtistHelper can generate either x or y ticks/gridlines depending on the value of self.nth_coord. The implementation often requires generating e.g. shape (2,) arrays such that the nth_coord column is set to a tick position, and the 1-nth_coord column has is set to 0. This is currently done using constructs like ``verts = [0, 0]; verts[self.nth_coord] = value`` where the mutation doesn't really help legibility. Instead, introduce a ``_to_xy`` helper that allows writing ``to_xy(variable=x, fixed=0)``.
1 parent a6da11e commit 500c9bd

File tree

2 files changed

+40
-33
lines changed

2 files changed

+40
-33
lines changed
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
``passthru_pt``
2+
~~~~~~~~~~~~~~~
3+
This attribute of ``AxisArtistHelper``\s is deprecated.

lib/mpl_toolkits/axisartist/axislines.py

Lines changed: 37 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -108,41 +108,53 @@ def update_lim(self, axes):
108108
delta2 = _api.deprecated("3.6")(
109109
property(lambda self: 0.00001, lambda self, value: None))
110110

111+
def _to_xy(self, values, const):
112+
"""
113+
Create a (len(values), 2)-shape array representing (x, y) pairs.
114+
115+
*values* go into the coordinate determined by ``self.nth_coord``.
116+
The other coordinate is filled with the constant *const*.
117+
118+
Example::
119+
120+
>>> self.nth_coord = 0
121+
>>> self._to_xy([1, 2, 3], const=0)
122+
array([[1, 0],
123+
[2, 0],
124+
[3, 0]])
125+
"""
126+
if self.nth_coord == 0:
127+
return np.stack(np.broadcast_arrays(values, const), axis=-1)
128+
elif self.nth_coord == 1:
129+
return np.stack(np.broadcast_arrays(const, values), axis=-1)
130+
else:
131+
raise ValueError("Unxpected nth_coord")
132+
111133
class Fixed(_Base):
112134
"""Helper class for a fixed (in the axes coordinate) axis."""
113135

136+
# deprecated with passthru_pt
114137
_default_passthru_pt = dict(left=(0, 0),
115138
right=(1, 0),
116139
bottom=(0, 0),
117140
top=(0, 1))
141+
passthru_pt = _api.deprecated("3.6")(property(
142+
lambda self: self._default_passthru_pt[self._loc]))
118143

119144
def __init__(self, loc, nth_coord=None):
120145
"""
121146
nth_coord = along which coordinate value varies
122147
in 2D, nth_coord = 0 -> x axis, nth_coord = 1 -> y axis
123148
"""
124-
_api.check_in_list(["left", "right", "bottom", "top"], loc=loc)
125149
self._loc = loc
126-
127-
if nth_coord is None:
128-
if loc in ["left", "right"]:
129-
nth_coord = 1
130-
elif loc in ["bottom", "top"]:
131-
nth_coord = 0
132-
133-
self.nth_coord = nth_coord
134-
150+
self._pos = _api.check_getitem(
151+
{"bottom": 0, "top": 1, "left": 0, "right": 1}, loc=loc)
152+
self.nth_coord = (
153+
nth_coord if nth_coord is not None else
154+
{"bottom": 0, "top": 0, "left": 1, "right": 1}[loc])
135155
super().__init__()
136-
137-
self.passthru_pt = self._default_passthru_pt[loc]
138-
139-
_verts = np.array([[0., 0.],
140-
[1., 1.]])
141-
fixed_coord = 1 - nth_coord
142-
_verts[:, fixed_coord] = self.passthru_pt[fixed_coord]
143-
144156
# axis line in transAxes
145-
self._path = Path(_verts)
157+
self._path = Path(self._to_xy((0, 1), const=self._pos))
146158

147159
def get_nth_coord(self):
148160
return self.nth_coord
@@ -225,8 +237,7 @@ def get_tick_iterators(self, axes):
225237

226238
def _f(locs, labels):
227239
for x, l in zip(locs, labels):
228-
c = list(self.passthru_pt) # copy
229-
c[self.nth_coord] = x
240+
c = self._to_xy(x, const=self._pos)
230241
# check if the tick point is inside axes
231242
c2 = tick_to_axes.transform(c)
232243
if mpl.transforms._interval_contains_close(
@@ -243,15 +254,10 @@ def __init__(self, axes, nth_coord,
243254
self.axis = [axes.xaxis, axes.yaxis][self.nth_coord]
244255

245256
def get_line(self, axes):
246-
_verts = np.array([[0., 0.],
247-
[1., 1.]])
248-
249257
fixed_coord = 1 - self.nth_coord
250258
data_to_axes = axes.transData - axes.transAxes
251259
p = data_to_axes.transform([self._value, self._value])
252-
_verts[:, fixed_coord] = p[fixed_coord]
253-
254-
return Path(_verts)
260+
return Path(self._to_xy((0, 1), const=p[fixed_coord]))
255261

256262
def get_line_transform(self, axes):
257263
return axes.transAxes
@@ -266,13 +272,12 @@ def get_axislabel_pos_angle(self, axes):
266272
get_label_transform() returns a transform of (transAxes+offset)
267273
"""
268274
angle = [0, 90][self.nth_coord]
269-
_verts = [0.5, 0.5]
270275
fixed_coord = 1 - self.nth_coord
271276
data_to_axes = axes.transData - axes.transAxes
272277
p = data_to_axes.transform([self._value, self._value])
273-
_verts[fixed_coord] = p[fixed_coord]
274-
if 0 <= _verts[fixed_coord] <= 1:
275-
return _verts, angle
278+
verts = self._to_xy(0.5, const=p[fixed_coord])
279+
if 0 <= verts[fixed_coord] <= 1:
280+
return verts, angle
276281
else:
277282
return None, None
278283

@@ -298,8 +303,7 @@ def get_tick_iterators(self, axes):
298303

299304
def _f(locs, labels):
300305
for x, l in zip(locs, labels):
301-
c = [self._value, self._value]
302-
c[self.nth_coord] = x
306+
c = self._to_xy(x, const=self._value)
303307
c1, c2 = data_to_axes.transform(c)
304308
if 0 <= c1 <= 1 and 0 <= c2 <= 1:
305309
yield c, angle_normal, angle_tangent, l

0 commit comments

Comments
 (0)