Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 88e0d84

Browse files
committed
Simplify axes_grid.Grid/axes_grid.ImageGrid construction.
Store axes in a 2D, None-initialized numpy array (`axes_array`) rather than nested lists (converting everything to nested lists with tolist() at the end for backcompat); this allows one to get rid of self._refax/ self._column_refax/self._row_refax and instead just use `axes_array[0, 0]`/`axes_array[0, col]`/`axes_array[row, 0]` as axes for sharing (these default correctly to None on the first iteration). Minor cleanups: use check_in_list; factor out common kwargs when constructing self._divider; directly construct self.cbar_axes as a list comprehension. Note that the Grid and AxesGrid constructors are almost duplicate of one another and could be deduplicated later.
1 parent f7e7e46 commit 88e0d84

1 file changed

Lines changed: 37 additions & 116 deletions

File tree

lib/mpl_toolkits/axes_grid1/axes_grid.py

Lines changed: 37 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from numbers import Number
22

3+
import numpy as np
4+
35
import matplotlib as mpl
46
from matplotlib import cbook
57
import matplotlib.ticker as ticker
@@ -183,75 +185,37 @@ def __init__(self, fig,
183185

184186
self._init_axes_pad(axes_pad)
185187

186-
if direction not in ["column", "row"]:
187-
raise Exception("")
188-
188+
cbook._check_in_list(["column", "row"], direction=direction)
189189
self._direction = direction
190190

191191
if axes_class is None:
192192
axes_class = self._defaultAxesClass
193193

194-
self.axes_all = []
195-
self.axes_column = [[] for _ in range(self._ncols)]
196-
self.axes_row = [[] for _ in range(self._nrows)]
197-
198-
h = []
199-
v = []
200-
if isinstance(rect, (str, Number)):
201-
self._divider = SubplotDivider(fig, rect, horizontal=h, vertical=v,
202-
aspect=False)
203-
elif isinstance(rect, SubplotSpec):
204-
self._divider = SubplotDivider(fig, rect, horizontal=h, vertical=v,
205-
aspect=False)
194+
kw = dict(horizontal=[], vertical=[], aspect=False)
195+
if isinstance(rect, (str, Number, SubplotSpec)):
196+
self._divider = SubplotDivider(fig, rect, **kw)
206197
elif len(rect) == 3:
207-
kw = dict(horizontal=h, vertical=v, aspect=False)
208198
self._divider = SubplotDivider(fig, *rect, **kw)
209199
elif len(rect) == 4:
210-
self._divider = Divider(fig, rect, horizontal=h, vertical=v,
211-
aspect=False)
200+
self._divider = Divider(fig, rect, **kw)
212201
else:
213202
raise Exception("")
214203

215204
rect = self._divider.get_position()
216205

217-
# reference axes
218-
self._column_refax = [None for _ in range(self._ncols)]
219-
self._row_refax = [None for _ in range(self._nrows)]
220-
self._refax = None
221-
206+
axes_array = np.full((self._nrows, self._ncols), None, dtype=object)
222207
for i in range(self.ngrids):
223-
224208
col, row = self._get_col_row(i)
225-
226209
if share_all:
227-
sharex = self._refax
228-
sharey = self._refax
210+
sharex = sharey = axes_array[0, 0]
229211
else:
230-
if share_x:
231-
sharex = self._column_refax[col]
232-
else:
233-
sharex = None
234-
235-
if share_y:
236-
sharey = self._row_refax[row]
237-
else:
238-
sharey = None
239-
240-
ax = axes_class(fig, rect, sharex=sharex, sharey=sharey)
241-
242-
if share_all:
243-
if self._refax is None:
244-
self._refax = ax
245-
else:
246-
if sharex is None:
247-
self._column_refax[col] = ax
248-
if sharey is None:
249-
self._row_refax[row] = ax
250-
251-
self.axes_all.append(ax)
252-
self.axes_column[col].append(ax)
253-
self.axes_row[row].append(ax)
254-
212+
sharex = axes_array[0, col] if share_x else None
213+
sharey = axes_array[row, 0] if share_y else None
214+
axes_array[row, col] = axes_class(
215+
fig, rect, sharex=sharex, sharey=sharey)
216+
self.axes_all = axes_array.ravel().tolist()
217+
self.axes_column = axes_array.T.tolist()
218+
self.axes_row = axes_array.tolist()
255219
self.axes_llc = self.axes_column[0][-1]
256220

257221
self._update_locators()
@@ -272,27 +236,19 @@ def _init_axes_pad(self, axes_pad):
272236
def _update_locators(self):
273237

274238
h = []
275-
276239
h_ax_pos = []
277-
278-
for _ in self._column_refax:
279-
#if h: h.append(Size.Fixed(self._axes_pad))
240+
for _ in range(self._ncols):
280241
if h:
281242
h.append(self._horiz_pad_size)
282-
283243
h_ax_pos.append(len(h))
284-
285244
sz = Size.Scaled(1)
286245
h.append(sz)
287246

288247
v = []
289-
290248
v_ax_pos = []
291-
for _ in self._row_refax[::-1]:
292-
#if v: v.append(Size.Fixed(self._axes_pad))
249+
for _ in range(self._nrows):
293250
if v:
294251
v.append(self._vert_pad_size)
295-
296252
v_ax_pos.append(len(v))
297253
sz = Size.Scaled(1)
298254
v.append(sz)
@@ -512,79 +468,44 @@ def __init__(self, fig,
512468

513469
self._init_axes_pad(axes_pad)
514470

515-
if direction not in ["column", "row"]:
516-
raise Exception("")
517-
471+
cbook._check_in_list(["column", "row"], direction=direction)
518472
self._direction = direction
519473

520474
if axes_class is None:
521475
axes_class = self._defaultAxesClass
522476

523-
self.axes_all = []
524-
self.axes_column = [[] for _ in range(self._ncols)]
525-
self.axes_row = [[] for _ in range(self._nrows)]
526-
527-
self.cbar_axes = []
528-
529-
h = []
530-
v = []
531-
if isinstance(rect, (str, Number)):
532-
self._divider = SubplotDivider(fig, rect, horizontal=h, vertical=v,
533-
aspect=aspect)
534-
elif isinstance(rect, SubplotSpec):
535-
self._divider = SubplotDivider(fig, rect, horizontal=h, vertical=v,
536-
aspect=aspect)
477+
kw = dict(horizontal=[], vertical=[], aspect=aspect)
478+
if isinstance(rect, (str, Number, SubplotSpec)):
479+
self._divider = SubplotDivider(fig, rect, **kw)
537480
elif len(rect) == 3:
538-
kw = dict(horizontal=h, vertical=v, aspect=aspect)
539481
self._divider = SubplotDivider(fig, *rect, **kw)
540482
elif len(rect) == 4:
541-
self._divider = Divider(fig, rect, horizontal=h, vertical=v,
542-
aspect=aspect)
483+
self._divider = Divider(fig, rect, **kw)
543484
else:
544485
raise Exception("")
545486

546487
rect = self._divider.get_position()
547488

548-
# reference axes
549-
self._column_refax = [None for _ in range(self._ncols)]
550-
self._row_refax = [None for _ in range(self._nrows)]
551-
self._refax = None
552-
489+
axes_array = np.full((self._nrows, self._ncols), None, dtype=object)
553490
for i in range(self.ngrids):
554-
555491
col, row = self._get_col_row(i)
556-
557492
if share_all:
558-
if self.axes_all:
559-
sharex = self.axes_all[0]
560-
sharey = self.axes_all[0]
561-
else:
562-
sharex = None
563-
sharey = None
493+
sharex = sharey = axes_array[0, 0]
564494
else:
565-
sharex = self._column_refax[col]
566-
sharey = self._row_refax[row]
567-
568-
ax = axes_class(fig, rect, sharex=sharex, sharey=sharey)
569-
570-
self.axes_all.append(ax)
571-
self.axes_column[col].append(ax)
572-
self.axes_row[row].append(ax)
573-
574-
if share_all:
575-
if self._refax is None:
576-
self._refax = ax
577-
if sharex is None:
578-
self._column_refax[col] = ax
579-
if sharey is None:
580-
self._row_refax[row] = ax
581-
582-
cax = self._defaultCbarAxesClass(fig, rect,
583-
orientation=self._colorbar_location)
584-
self.cbar_axes.append(cax)
585-
495+
sharex = axes_array[0, col]
496+
sharey = axes_array[row, 0]
497+
axes_array[row, col] = axes_class(
498+
fig, rect, sharex=sharex, sharey=sharey)
499+
self.axes_all = axes_array.ravel().tolist()
500+
self.axes_column = axes_array.T.tolist()
501+
self.axes_row = axes_array.tolist()
586502
self.axes_llc = self.axes_column[0][-1]
587503

504+
self.cbar_axes = [
505+
self._defaultCbarAxesClass(fig, rect,
506+
orientation=self._colorbar_location)
507+
for _ in range(self.ngrids)]
508+
588509
self._update_locators()
589510

590511
if add_all:

0 commit comments

Comments
 (0)