11from numbers import Number
22
3+ import numpy as np
4+
35import matplotlib as mpl
46from matplotlib import cbook
57import matplotlib .ticker as ticker
@@ -156,75 +158,37 @@ def __init__(self, fig,
156158
157159 self ._init_axes_pad (axes_pad )
158160
159- if direction not in ["column" , "row" ]:
160- raise Exception ("" )
161-
161+ cbook ._check_in_list (["column" , "row" ], direction = direction )
162162 self ._direction = direction
163163
164164 if axes_class is None :
165165 axes_class = self ._defaultAxesClass
166166
167- self .axes_all = []
168- self .axes_column = [[] for _ in range (self ._ncols )]
169- self .axes_row = [[] for _ in range (self ._nrows )]
170-
171- h = []
172- v = []
173- if isinstance (rect , (str , Number )):
174- self ._divider = SubplotDivider (fig , rect , horizontal = h , vertical = v ,
175- aspect = False )
176- elif isinstance (rect , SubplotSpec ):
177- self ._divider = SubplotDivider (fig , rect , horizontal = h , vertical = v ,
178- aspect = False )
167+ kw = dict (horizontal = [], vertical = [], aspect = False )
168+ if isinstance (rect , (str , Number , SubplotSpec )):
169+ self ._divider = SubplotDivider (fig , rect , ** kw )
179170 elif len (rect ) == 3 :
180- kw = dict (horizontal = h , vertical = v , aspect = False )
181171 self ._divider = SubplotDivider (fig , * rect , ** kw )
182172 elif len (rect ) == 4 :
183- self ._divider = Divider (fig , rect , horizontal = h , vertical = v ,
184- aspect = False )
173+ self ._divider = Divider (fig , rect , ** kw )
185174 else :
186175 raise Exception ("" )
187176
188177 rect = self ._divider .get_position ()
189178
190- # reference axes
191- self ._column_refax = [None for _ in range (self ._ncols )]
192- self ._row_refax = [None for _ in range (self ._nrows )]
193- self ._refax = None
194-
179+ axes_array = np .full ((self ._nrows , self ._ncols ), None , dtype = object )
195180 for i in range (self .ngrids ):
196-
197181 col , row = self ._get_col_row (i )
198-
199182 if share_all :
200- sharex = self ._refax
201- sharey = self ._refax
183+ sharex = sharey = axes_array [0 , 0 ]
202184 else :
203- if share_x :
204- sharex = self ._column_refax [col ]
205- else :
206- sharex = None
207-
208- if share_y :
209- sharey = self ._row_refax [row ]
210- else :
211- sharey = None
212-
213- ax = axes_class (fig , rect , sharex = sharex , sharey = sharey )
214-
215- if share_all :
216- if self ._refax is None :
217- self ._refax = ax
218- else :
219- if sharex is None :
220- self ._column_refax [col ] = ax
221- if sharey is None :
222- self ._row_refax [row ] = ax
223-
224- self .axes_all .append (ax )
225- self .axes_column [col ].append (ax )
226- self .axes_row [row ].append (ax )
227-
185+ sharex = axes_array [0 , col ] if share_x else None
186+ sharey = axes_array [row , 0 ] if share_y else None
187+ axes_array [row , col ] = axes_class (
188+ fig , rect , sharex = sharex , sharey = sharey )
189+ self .axes_all = axes_array .ravel ().tolist ()
190+ self .axes_column = axes_array .T .tolist ()
191+ self .axes_row = axes_array .tolist ()
228192 self .axes_llc = self .axes_column [0 ][- 1 ]
229193
230194 self ._update_locators ()
@@ -245,27 +209,19 @@ def _init_axes_pad(self, axes_pad):
245209 def _update_locators (self ):
246210
247211 h = []
248-
249212 h_ax_pos = []
250-
251- for _ in self ._column_refax :
252- #if h: h.append(Size.Fixed(self._axes_pad))
213+ for _ in range (self ._ncols ):
253214 if h :
254215 h .append (self ._horiz_pad_size )
255-
256216 h_ax_pos .append (len (h ))
257-
258217 sz = Size .Scaled (1 )
259218 h .append (sz )
260219
261220 v = []
262-
263221 v_ax_pos = []
264- for _ in self ._row_refax [::- 1 ]:
265- #if v: v.append(Size.Fixed(self._axes_pad))
222+ for _ in range (self ._nrows ):
266223 if v :
267224 v .append (self ._vert_pad_size )
268-
269225 v_ax_pos .append (len (v ))
270226 sz = Size .Scaled (1 )
271227 v .append (sz )
@@ -485,79 +441,44 @@ def __init__(self, fig,
485441
486442 self ._init_axes_pad (axes_pad )
487443
488- if direction not in ["column" , "row" ]:
489- raise Exception ("" )
490-
444+ cbook ._check_in_list (["column" , "row" ], direction = direction )
491445 self ._direction = direction
492446
493447 if axes_class is None :
494448 axes_class = self ._defaultAxesClass
495449
496- self .axes_all = []
497- self .axes_column = [[] for _ in range (self ._ncols )]
498- self .axes_row = [[] for _ in range (self ._nrows )]
499-
500- self .cbar_axes = []
501-
502- h = []
503- v = []
504- if isinstance (rect , (str , Number )):
505- self ._divider = SubplotDivider (fig , rect , horizontal = h , vertical = v ,
506- aspect = aspect )
507- elif isinstance (rect , SubplotSpec ):
508- self ._divider = SubplotDivider (fig , rect , horizontal = h , vertical = v ,
509- aspect = aspect )
450+ kw = dict (horizontal = [], vertical = [], aspect = aspect )
451+ if isinstance (rect , (str , Number , SubplotSpec )):
452+ self ._divider = SubplotDivider (fig , rect , ** kw )
510453 elif len (rect ) == 3 :
511- kw = dict (horizontal = h , vertical = v , aspect = aspect )
512454 self ._divider = SubplotDivider (fig , * rect , ** kw )
513455 elif len (rect ) == 4 :
514- self ._divider = Divider (fig , rect , horizontal = h , vertical = v ,
515- aspect = aspect )
456+ self ._divider = Divider (fig , rect , ** kw )
516457 else :
517458 raise Exception ("" )
518459
519460 rect = self ._divider .get_position ()
520461
521- # reference axes
522- self ._column_refax = [None for _ in range (self ._ncols )]
523- self ._row_refax = [None for _ in range (self ._nrows )]
524- self ._refax = None
525-
462+ axes_array = np .full ((self ._nrows , self ._ncols ), None , dtype = object )
526463 for i in range (self .ngrids ):
527-
528464 col , row = self ._get_col_row (i )
529-
530465 if share_all :
531- if self .axes_all :
532- sharex = self .axes_all [0 ]
533- sharey = self .axes_all [0 ]
534- else :
535- sharex = None
536- sharey = None
466+ sharex = sharey = axes_array [0 , 0 ]
537467 else :
538- sharex = self ._column_refax [col ]
539- sharey = self ._row_refax [row ]
540-
541- ax = axes_class (fig , rect , sharex = sharex , sharey = sharey )
542-
543- self .axes_all .append (ax )
544- self .axes_column [col ].append (ax )
545- self .axes_row [row ].append (ax )
546-
547- if share_all :
548- if self ._refax is None :
549- self ._refax = ax
550- if sharex is None :
551- self ._column_refax [col ] = ax
552- if sharey is None :
553- self ._row_refax [row ] = ax
554-
555- cax = self ._defaultCbarAxesClass (fig , rect ,
556- orientation = self ._colorbar_location )
557- self .cbar_axes .append (cax )
558-
468+ sharex = axes_array [0 , col ]
469+ sharey = axes_array [row , 0 ]
470+ axes_array [row , col ] = axes_class (
471+ fig , rect , sharex = sharex , sharey = sharey )
472+ self .axes_all = axes_array .ravel ().tolist ()
473+ self .axes_column = axes_array .T .tolist ()
474+ self .axes_row = axes_array .tolist ()
559475 self .axes_llc = self .axes_column [0 ][- 1 ]
560476
477+ self .cbar_axes = [
478+ self ._defaultCbarAxesClass (fig , rect ,
479+ orientation = self ._colorbar_location )
480+ for _ in range (self .ngrids )]
481+
561482 self ._update_locators ()
562483
563484 if add_all :
0 commit comments