Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 38cb765

Browse files
authored
Merge pull request #15639 from anntzer/axes_grid-init
Simplify axes_grid.Grid/axes_grid.ImageGrid construction.
2 parents 66e4fb2 + 88e0d84 commit 38cb765

File tree

1 file changed

+37
-116
lines changed

1 file changed

+37
-116
lines changed

lib/mpl_toolkits/axes_grid1/axes_grid.py

Lines changed: 37 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from numbers import Number
22

3+
import numpy as np
4+
35
import matplotlib as mpl
46
from matplotlib import cbook
57
import matplotlib.ticker as ticker
@@ -156,75 +158,37 @@ def __init__(self, fig,
156158

157159
self._init_axes_pad(axes_pad)
158160

159-
if direction not in ["column", "row"]:
160-
raise Exception("")
161-
161+
cbook._check_in_list(["column", "row"], direction=direction)
162162
self._direction = direction
163163

164164
if axes_class is None:
165165
axes_class = self._defaultAxesClass
166166

167-
self.axes_all = []
168-
self.axes_column = [[] for _ in range(self._ncols)]
169-
self.axes_row = [[] for _ in range(self._nrows)]
170-
171-
h = []
172-
v = []
173-
if isinstance(rect, (str, Number)):
174-
self._divider = SubplotDivider(fig, rect, horizontal=h, vertical=v,
175-
aspect=False)
176-
elif isinstance(rect, SubplotSpec):
177-
self._divider = SubplotDivider(fig, rect, horizontal=h, vertical=v,
178-
aspect=False)
167+
kw = dict(horizontal=[], vertical=[], aspect=False)
168+
if isinstance(rect, (str, Number, SubplotSpec)):
169+
self._divider = SubplotDivider(fig, rect, **kw)
179170
elif len(rect) == 3:
180-
kw = dict(horizontal=h, vertical=v, aspect=False)
181171
self._divider = SubplotDivider(fig, *rect, **kw)
182172
elif len(rect) == 4:
183-
self._divider = Divider(fig, rect, horizontal=h, vertical=v,
184-
aspect=False)
173+
self._divider = Divider(fig, rect, **kw)
185174
else:
186175
raise Exception("")
187176

188177
rect = self._divider.get_position()
189178

190-
# reference axes
191-
self._column_refax = [None for _ in range(self._ncols)]
192-
self._row_refax = [None for _ in range(self._nrows)]
193-
self._refax = None
194-
179+
axes_array = np.full((self._nrows, self._ncols), None, dtype=object)
195180
for i in range(self.ngrids):
196-
197181
col, row = self._get_col_row(i)
198-
199182
if share_all:
200-
sharex = self._refax
201-
sharey = self._refax
183+
sharex = sharey = axes_array[0, 0]
202184
else:
203-
if share_x:
204-
sharex = self._column_refax[col]
205-
else:
206-
sharex = None
207-
208-
if share_y:
209-
sharey = self._row_refax[row]
210-
else:
211-
sharey = None
212-
213-
ax = axes_class(fig, rect, sharex=sharex, sharey=sharey)
214-
215-
if share_all:
216-
if self._refax is None:
217-
self._refax = ax
218-
else:
219-
if sharex is None:
220-
self._column_refax[col] = ax
221-
if sharey is None:
222-
self._row_refax[row] = ax
223-
224-
self.axes_all.append(ax)
225-
self.axes_column[col].append(ax)
226-
self.axes_row[row].append(ax)
227-
185+
sharex = axes_array[0, col] if share_x else None
186+
sharey = axes_array[row, 0] if share_y else None
187+
axes_array[row, col] = axes_class(
188+
fig, rect, sharex=sharex, sharey=sharey)
189+
self.axes_all = axes_array.ravel().tolist()
190+
self.axes_column = axes_array.T.tolist()
191+
self.axes_row = axes_array.tolist()
228192
self.axes_llc = self.axes_column[0][-1]
229193

230194
self._update_locators()
@@ -245,27 +209,19 @@ def _init_axes_pad(self, axes_pad):
245209
def _update_locators(self):
246210

247211
h = []
248-
249212
h_ax_pos = []
250-
251-
for _ in self._column_refax:
252-
#if h: h.append(Size.Fixed(self._axes_pad))
213+
for _ in range(self._ncols):
253214
if h:
254215
h.append(self._horiz_pad_size)
255-
256216
h_ax_pos.append(len(h))
257-
258217
sz = Size.Scaled(1)
259218
h.append(sz)
260219

261220
v = []
262-
263221
v_ax_pos = []
264-
for _ in self._row_refax[::-1]:
265-
#if v: v.append(Size.Fixed(self._axes_pad))
222+
for _ in range(self._nrows):
266223
if v:
267224
v.append(self._vert_pad_size)
268-
269225
v_ax_pos.append(len(v))
270226
sz = Size.Scaled(1)
271227
v.append(sz)
@@ -485,79 +441,44 @@ def __init__(self, fig,
485441

486442
self._init_axes_pad(axes_pad)
487443

488-
if direction not in ["column", "row"]:
489-
raise Exception("")
490-
444+
cbook._check_in_list(["column", "row"], direction=direction)
491445
self._direction = direction
492446

493447
if axes_class is None:
494448
axes_class = self._defaultAxesClass
495449

496-
self.axes_all = []
497-
self.axes_column = [[] for _ in range(self._ncols)]
498-
self.axes_row = [[] for _ in range(self._nrows)]
499-
500-
self.cbar_axes = []
501-
502-
h = []
503-
v = []
504-
if isinstance(rect, (str, Number)):
505-
self._divider = SubplotDivider(fig, rect, horizontal=h, vertical=v,
506-
aspect=aspect)
507-
elif isinstance(rect, SubplotSpec):
508-
self._divider = SubplotDivider(fig, rect, horizontal=h, vertical=v,
509-
aspect=aspect)
450+
kw = dict(horizontal=[], vertical=[], aspect=aspect)
451+
if isinstance(rect, (str, Number, SubplotSpec)):
452+
self._divider = SubplotDivider(fig, rect, **kw)
510453
elif len(rect) == 3:
511-
kw = dict(horizontal=h, vertical=v, aspect=aspect)
512454
self._divider = SubplotDivider(fig, *rect, **kw)
513455
elif len(rect) == 4:
514-
self._divider = Divider(fig, rect, horizontal=h, vertical=v,
515-
aspect=aspect)
456+
self._divider = Divider(fig, rect, **kw)
516457
else:
517458
raise Exception("")
518459

519460
rect = self._divider.get_position()
520461

521-
# reference axes
522-
self._column_refax = [None for _ in range(self._ncols)]
523-
self._row_refax = [None for _ in range(self._nrows)]
524-
self._refax = None
525-
462+
axes_array = np.full((self._nrows, self._ncols), None, dtype=object)
526463
for i in range(self.ngrids):
527-
528464
col, row = self._get_col_row(i)
529-
530465
if share_all:
531-
if self.axes_all:
532-
sharex = self.axes_all[0]
533-
sharey = self.axes_all[0]
534-
else:
535-
sharex = None
536-
sharey = None
466+
sharex = sharey = axes_array[0, 0]
537467
else:
538-
sharex = self._column_refax[col]
539-
sharey = self._row_refax[row]
540-
541-
ax = axes_class(fig, rect, sharex=sharex, sharey=sharey)
542-
543-
self.axes_all.append(ax)
544-
self.axes_column[col].append(ax)
545-
self.axes_row[row].append(ax)
546-
547-
if share_all:
548-
if self._refax is None:
549-
self._refax = ax
550-
if sharex is None:
551-
self._column_refax[col] = ax
552-
if sharey is None:
553-
self._row_refax[row] = ax
554-
555-
cax = self._defaultCbarAxesClass(fig, rect,
556-
orientation=self._colorbar_location)
557-
self.cbar_axes.append(cax)
558-
468+
sharex = axes_array[0, col]
469+
sharey = axes_array[row, 0]
470+
axes_array[row, col] = axes_class(
471+
fig, rect, sharex=sharex, sharey=sharey)
472+
self.axes_all = axes_array.ravel().tolist()
473+
self.axes_column = axes_array.T.tolist()
474+
self.axes_row = axes_array.tolist()
559475
self.axes_llc = self.axes_column[0][-1]
560476

477+
self.cbar_axes = [
478+
self._defaultCbarAxesClass(fig, rect,
479+
orientation=self._colorbar_location)
480+
for _ in range(self.ngrids)]
481+
561482
self._update_locators()
562483

563484
if add_all:

0 commit comments

Comments
 (0)