11from numbers import Number
22
3+ import numpy as np
4+
35import matplotlib as mpl
46from matplotlib import cbook
57import matplotlib .ticker as ticker
@@ -183,75 +185,37 @@ def __init__(self, fig,
183185
184186 self ._init_axes_pad (axes_pad )
185187
186- if direction not in ["column" , "row" ]:
187- raise Exception ("" )
188-
188+ cbook ._check_in_list (["column" , "row" ], direction = direction )
189189 self ._direction = direction
190190
191191 if axes_class is None :
192192 axes_class = self ._defaultAxesClass
193193
194- self .axes_all = []
195- self .axes_column = [[] for _ in range (self ._ncols )]
196- self .axes_row = [[] for _ in range (self ._nrows )]
197-
198- h = []
199- v = []
200- if isinstance (rect , (str , Number )):
201- self ._divider = SubplotDivider (fig , rect , horizontal = h , vertical = v ,
202- aspect = False )
203- elif isinstance (rect , SubplotSpec ):
204- self ._divider = SubplotDivider (fig , rect , horizontal = h , vertical = v ,
205- aspect = False )
194+ kw = dict (horizontal = [], vertical = [], aspect = False )
195+ if isinstance (rect , (str , Number , SubplotSpec )):
196+ self ._divider = SubplotDivider (fig , rect , ** kw )
206197 elif len (rect ) == 3 :
207- kw = dict (horizontal = h , vertical = v , aspect = False )
208198 self ._divider = SubplotDivider (fig , * rect , ** kw )
209199 elif len (rect ) == 4 :
210- self ._divider = Divider (fig , rect , horizontal = h , vertical = v ,
211- aspect = False )
200+ self ._divider = Divider (fig , rect , ** kw )
212201 else :
213202 raise Exception ("" )
214203
215204 rect = self ._divider .get_position ()
216205
217- # reference axes
218- self ._column_refax = [None for _ in range (self ._ncols )]
219- self ._row_refax = [None for _ in range (self ._nrows )]
220- self ._refax = None
221-
206+ axes_array = np .full ((self ._nrows , self ._ncols ), None , dtype = object )
222207 for i in range (self .ngrids ):
223-
224208 col , row = self ._get_col_row (i )
225-
226209 if share_all :
227- sharex = self ._refax
228- sharey = self ._refax
210+ sharex = sharey = axes_array [0 , 0 ]
229211 else :
230- if share_x :
231- sharex = self ._column_refax [col ]
232- else :
233- sharex = None
234-
235- if share_y :
236- sharey = self ._row_refax [row ]
237- else :
238- sharey = None
239-
240- ax = axes_class (fig , rect , sharex = sharex , sharey = sharey )
241-
242- if share_all :
243- if self ._refax is None :
244- self ._refax = ax
245- else :
246- if sharex is None :
247- self ._column_refax [col ] = ax
248- if sharey is None :
249- self ._row_refax [row ] = ax
250-
251- self .axes_all .append (ax )
252- self .axes_column [col ].append (ax )
253- self .axes_row [row ].append (ax )
254-
212+ sharex = axes_array [0 , col ] if share_x else None
213+ sharey = axes_array [row , 0 ] if share_y else None
214+ axes_array [row , col ] = axes_class (
215+ fig , rect , sharex = sharex , sharey = sharey )
216+ self .axes_all = axes_array .ravel ().tolist ()
217+ self .axes_column = axes_array .T .tolist ()
218+ self .axes_row = axes_array .tolist ()
255219 self .axes_llc = self .axes_column [0 ][- 1 ]
256220
257221 self ._update_locators ()
@@ -272,27 +236,19 @@ def _init_axes_pad(self, axes_pad):
272236 def _update_locators (self ):
273237
274238 h = []
275-
276239 h_ax_pos = []
277-
278- for _ in self ._column_refax :
279- #if h: h.append(Size.Fixed(self._axes_pad))
240+ for _ in range (self ._ncols ):
280241 if h :
281242 h .append (self ._horiz_pad_size )
282-
283243 h_ax_pos .append (len (h ))
284-
285244 sz = Size .Scaled (1 )
286245 h .append (sz )
287246
288247 v = []
289-
290248 v_ax_pos = []
291- for _ in self ._row_refax [::- 1 ]:
292- #if v: v.append(Size.Fixed(self._axes_pad))
249+ for _ in range (self ._nrows ):
293250 if v :
294251 v .append (self ._vert_pad_size )
295-
296252 v_ax_pos .append (len (v ))
297253 sz = Size .Scaled (1 )
298254 v .append (sz )
@@ -512,79 +468,44 @@ def __init__(self, fig,
512468
513469 self ._init_axes_pad (axes_pad )
514470
515- if direction not in ["column" , "row" ]:
516- raise Exception ("" )
517-
471+ cbook ._check_in_list (["column" , "row" ], direction = direction )
518472 self ._direction = direction
519473
520474 if axes_class is None :
521475 axes_class = self ._defaultAxesClass
522476
523- self .axes_all = []
524- self .axes_column = [[] for _ in range (self ._ncols )]
525- self .axes_row = [[] for _ in range (self ._nrows )]
526-
527- self .cbar_axes = []
528-
529- h = []
530- v = []
531- if isinstance (rect , (str , Number )):
532- self ._divider = SubplotDivider (fig , rect , horizontal = h , vertical = v ,
533- aspect = aspect )
534- elif isinstance (rect , SubplotSpec ):
535- self ._divider = SubplotDivider (fig , rect , horizontal = h , vertical = v ,
536- aspect = aspect )
477+ kw = dict (horizontal = [], vertical = [], aspect = aspect )
478+ if isinstance (rect , (str , Number , SubplotSpec )):
479+ self ._divider = SubplotDivider (fig , rect , ** kw )
537480 elif len (rect ) == 3 :
538- kw = dict (horizontal = h , vertical = v , aspect = aspect )
539481 self ._divider = SubplotDivider (fig , * rect , ** kw )
540482 elif len (rect ) == 4 :
541- self ._divider = Divider (fig , rect , horizontal = h , vertical = v ,
542- aspect = aspect )
483+ self ._divider = Divider (fig , rect , ** kw )
543484 else :
544485 raise Exception ("" )
545486
546487 rect = self ._divider .get_position ()
547488
548- # reference axes
549- self ._column_refax = [None for _ in range (self ._ncols )]
550- self ._row_refax = [None for _ in range (self ._nrows )]
551- self ._refax = None
552-
489+ axes_array = np .full ((self ._nrows , self ._ncols ), None , dtype = object )
553490 for i in range (self .ngrids ):
554-
555491 col , row = self ._get_col_row (i )
556-
557492 if share_all :
558- if self .axes_all :
559- sharex = self .axes_all [0 ]
560- sharey = self .axes_all [0 ]
561- else :
562- sharex = None
563- sharey = None
493+ sharex = sharey = axes_array [0 , 0 ]
564494 else :
565- sharex = self ._column_refax [col ]
566- sharey = self ._row_refax [row ]
567-
568- ax = axes_class (fig , rect , sharex = sharex , sharey = sharey )
569-
570- self .axes_all .append (ax )
571- self .axes_column [col ].append (ax )
572- self .axes_row [row ].append (ax )
573-
574- if share_all :
575- if self ._refax is None :
576- self ._refax = ax
577- if sharex is None :
578- self ._column_refax [col ] = ax
579- if sharey is None :
580- self ._row_refax [row ] = ax
581-
582- cax = self ._defaultCbarAxesClass (fig , rect ,
583- orientation = self ._colorbar_location )
584- self .cbar_axes .append (cax )
585-
495+ sharex = axes_array [0 , col ]
496+ sharey = axes_array [row , 0 ]
497+ axes_array [row , col ] = axes_class (
498+ fig , rect , sharex = sharex , sharey = sharey )
499+ self .axes_all = axes_array .ravel ().tolist ()
500+ self .axes_column = axes_array .T .tolist ()
501+ self .axes_row = axes_array .tolist ()
586502 self .axes_llc = self .axes_column [0 ][- 1 ]
587503
504+ self .cbar_axes = [
505+ self ._defaultCbarAxesClass (fig , rect ,
506+ orientation = self ._colorbar_location )
507+ for _ in range (self .ngrids )]
508+
588509 self ._update_locators ()
589510
590511 if add_all :
0 commit comments