Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 26224d9

Browse files
authored
Merge pull request #26036 from anntzer/ag
Cleanup AxesGrid
2 parents bb335a1 + f6eacd8 commit 26224d9

File tree

1 file changed

+23
-57
lines changed

1 file changed

+23
-57
lines changed

lib/mpl_toolkits/axes_grid1/axes_grid.py

Lines changed: 23 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -11,19 +11,6 @@
1111
from .mpl_axes import Axes, SimpleAxisArtist
1212

1313

14-
def _tick_only(ax, bottom_on, left_on):
15-
bottom_off = not bottom_on
16-
left_off = not left_on
17-
if isinstance(ax.axis, MethodType):
18-
bottom = SimpleAxisArtist(ax.xaxis, 1, ax.spines["bottom"])
19-
left = SimpleAxisArtist(ax.yaxis, 1, ax.spines["left"])
20-
else:
21-
bottom = ax.axis["bottom"]
22-
left = ax.axis["left"]
23-
bottom.toggle(ticklabels=bottom_off, label=bottom_off)
24-
left.toggle(ticklabels=left_off, label=left_off)
25-
26-
2714
class CbarAxesBase:
2815
def __init__(self, *args, orientation, **kwargs):
2916
self.orientation = orientation
@@ -170,31 +157,15 @@ def __init__(self, fig,
170157
self.set_label_mode(label_mode)
171158

172159
def _init_locators(self):
173-
174-
h = []
175-
h_ax_pos = []
176-
for _ in range(self._ncols):
177-
if h:
178-
h.append(self._horiz_pad_size)
179-
h_ax_pos.append(len(h))
180-
sz = Size.Scaled(1)
181-
h.append(sz)
182-
183-
v = []
184-
v_ax_pos = []
185-
for _ in range(self._nrows):
186-
if v:
187-
v.append(self._vert_pad_size)
188-
v_ax_pos.append(len(v))
189-
sz = Size.Scaled(1)
190-
v.append(sz)
191-
160+
h = [Size.Scaled(1), self._horiz_pad_size] * (self._ncols-1) + [Size.Scaled(1)]
161+
h_indices = range(0, 2 * self._ncols, 2) # Indices of Scaled(1).
162+
v = [Size.Scaled(1), self._vert_pad_size] * (self._nrows-1) + [Size.Scaled(1)]
163+
v_indices = range(0, 2 * self._nrows, 2) # Indices of Scaled(1).
192164
for i in range(self.ngrids):
193165
col, row = self._get_col_row(i)
194166
locator = self._divider.new_locator(
195-
nx=h_ax_pos[col], ny=v_ax_pos[self._nrows - 1 - row])
167+
nx=h_indices[col], ny=v_indices[self._nrows - 1 - row])
196168
self.axes_all[i].set_axes_locator(locator)
197-
198169
self._divider.set_horizontal(h)
199170
self._divider.set_vertical(v)
200171

@@ -266,32 +237,15 @@ def set_label_mode(self, mode):
266237
- "all": All axes are labelled.
267238
- "keep": Do not do anything.
268239
"""
240+
is_last_row, is_first_col = (
241+
np.mgrid[:self._nrows, :self._ncols] == [[[self._nrows - 1]], [[0]]])
269242
if mode == "all":
270-
for ax in self.axes_all:
271-
_tick_only(ax, False, False)
243+
bottom = left = np.full((self._nrows, self._ncols), True)
272244
elif mode == "L":
273-
# left-most axes
274-
for ax in self.axes_column[0][:-1]:
275-
_tick_only(ax, bottom_on=True, left_on=False)
276-
# lower-left axes
277-
ax = self.axes_column[0][-1]
278-
_tick_only(ax, bottom_on=False, left_on=False)
279-
280-
for col in self.axes_column[1:]:
281-
# axes with no labels
282-
for ax in col[:-1]:
283-
_tick_only(ax, bottom_on=True, left_on=True)
284-
285-
# bottom
286-
ax = col[-1]
287-
_tick_only(ax, bottom_on=False, left_on=True)
288-
245+
bottom = is_last_row
246+
left = is_first_col
289247
elif mode == "1":
290-
for ax in self.axes_all:
291-
_tick_only(ax, bottom_on=True, left_on=True)
292-
293-
ax = self.axes_llc
294-
_tick_only(ax, bottom_on=False, left_on=False)
248+
bottom = left = is_last_row & is_first_col
295249
else:
296250
# Use _api.check_in_list at the top of the method when deprecation
297251
# period expires
@@ -302,6 +256,18 @@ def set_label_mode(self, mode):
302256
'since %(since)s and will become an error '
303257
'%(removal)s. To silence this warning, pass '
304258
'"keep", which gives the same behaviour.')
259+
return
260+
for i in range(self._nrows):
261+
for j in range(self._ncols):
262+
ax = self.axes_row[i][j]
263+
if isinstance(ax.axis, MethodType):
264+
bottom_axis = SimpleAxisArtist(ax.xaxis, 1, ax.spines["bottom"])
265+
left_axis = SimpleAxisArtist(ax.yaxis, 1, ax.spines["left"])
266+
else:
267+
bottom_axis = ax.axis["bottom"]
268+
left_axis = ax.axis["left"]
269+
bottom_axis.toggle(ticklabels=bottom[i, j], label=bottom[i, j])
270+
left_axis.toggle(ticklabels=left[i, j], label=left[i, j])
305271

306272
def get_divider(self):
307273
return self._divider

0 commit comments

Comments
 (0)