1414import matplotlib .lines as mlines
1515import matplotlib .ticker as ticker
1616
17- from matplotlib .gridspec import SubplotSpec , GridSpec
17+ from matplotlib .gridspec import SubplotSpec
1818
1919from .axes_divider import Size , SubplotDivider , LocatableAxes , Divider
2020
2121
22+ def _extend_axes_pad (value ):
23+ # Check whether a list/tuple/array or scalar has been passed
24+ ret = value
25+ if not hasattr (ret , "__getitem__" ):
26+ ret = (value , value )
27+ return ret
28+
2229def _tick_only (ax , bottom_on , left_on ):
2330 bottom_off = not bottom_on
2431 left_off = not left_on
@@ -200,6 +207,8 @@ def __init__(self, fig,
200207 ================ ======== =========================================
201208 direction "row" [ "row" | "column" ]
202209 axes_pad 0.02 float| pad between axes given in inches
210+ or tuple-like of floats,
211+ (horizontal padding, vertical padding)
203212 add_all True [ True | False ]
204213 share_all False [ True | False ]
205214 share_x True [ True | False ]
@@ -238,8 +247,8 @@ def __init__(self, fig,
238247 axes_class , axes_class_args = axes_class
239248
240249 self .axes_all = []
241- self .axes_column = [[] for i in range (self ._ncols )]
242- self .axes_row = [[] for i in range (self ._nrows )]
250+ self .axes_column = [[] for _ in range (self ._ncols )]
251+ self .axes_row = [[] for _ in range (self ._nrows )]
243252
244253 h = []
245254 v = []
@@ -261,8 +270,8 @@ def __init__(self, fig,
261270 rect = self ._divider .get_position ()
262271
263272 # reference axes
264- self ._column_refax = [None for i in range (self ._ncols )]
265- self ._row_refax = [None for i in range (self ._nrows )]
273+ self ._column_refax = [None for _ in range (self ._ncols )]
274+ self ._row_refax = [None for _ in range (self ._nrows )]
266275 self ._refax = None
267276
268277 for i in range (self .ngrids ):
@@ -310,19 +319,19 @@ def __init__(self, fig,
310319 self .set_label_mode (label_mode )
311320
312321 def _init_axes_pad (self , axes_pad ):
322+ axes_pad = _extend_axes_pad (axes_pad )
313323 self ._axes_pad = axes_pad
314324
315- self ._horiz_pad_size = Size .Fixed (axes_pad )
316- self ._vert_pad_size = Size .Fixed (axes_pad )
325+ self ._horiz_pad_size = Size .Fixed (axes_pad [ 0 ] )
326+ self ._vert_pad_size = Size .Fixed (axes_pad [ 1 ] )
317327
318328 def _update_locators (self ):
319329
320330 h = []
321331
322332 h_ax_pos = []
323- h_cb_pos = []
324333
325- for ax in self ._column_refax :
334+ for _ in self ._column_refax :
326335 #if h: h.append(Size.Fixed(self._axes_pad))
327336 if h :
328337 h .append (self ._horiz_pad_size )
@@ -335,8 +344,7 @@ def _update_locators(self):
335344 v = []
336345
337346 v_ax_pos = []
338- v_cb_pos = []
339- for ax in self ._row_refax [::- 1 ]:
347+ for _ in self ._row_refax [::- 1 ]:
340348 #if v: v.append(Size.Fixed(self._axes_pad))
341349 if v :
342350 v .append (self ._vert_pad_size )
@@ -362,6 +370,10 @@ def _get_col_row(self, n):
362370
363371 return col , row
364372
373+ # Good to propagate __len__ if we have __getitem__
374+ def __len__ (self ):
375+ return len (self .axes_all )
376+
365377 def __getitem__ (self , i ):
366378 return self .axes_all [i ]
367379
@@ -376,11 +388,19 @@ def set_axes_pad(self, axes_pad):
376388 "set axes_pad"
377389 self ._axes_pad = axes_pad
378390
379- self ._horiz_pad_size .fixed_size = axes_pad
380- self ._vert_pad_size .fixed_size = axes_pad
391+ # These two lines actually differ from ones in _init_axes_pad
392+ self ._horiz_pad_size .fixed_size = axes_pad [0 ]
393+ self ._vert_pad_size .fixed_size = axes_pad [1 ]
381394
382395 def get_axes_pad (self ):
383- "get axes_pad"
396+ """
397+ get axes_pad
398+
399+ Returns
400+ -------
401+ tuple
402+ Padding in inches, (horizontal pad, vertical pad)
403+ """
384404 return self ._axes_pad
385405
386406 def set_aspect (self , aspect ):
@@ -484,6 +504,8 @@ def __init__(self, fig,
484504 ================ ======== =========================================
485505 direction "row" [ "row" | "column" ]
486506 axes_pad 0.02 float| pad between axes given in inches
507+ or tuple-like of floats,
508+ (horizontal padding, vertical padding)
487509 add_all True [ True | False ]
488510 share_all False [ True | False ]
489511 aspect True [ True | False ]
@@ -510,12 +532,17 @@ def __init__(self, fig,
510532
511533 self .ngrids = ngrids
512534
535+ axes_pad = _extend_axes_pad (axes_pad )
513536 self ._axes_pad = axes_pad
514537
515538 self ._colorbar_mode = cbar_mode
516539 self ._colorbar_location = cbar_location
517540 if cbar_pad is None :
518- self ._colorbar_pad = axes_pad
541+ # horizontal or vertical arrangement?
542+ if cbar_location in ("left" , "right" ):
543+ self ._colorbar_pad = axes_pad [0 ]
544+ else :
545+ self ._colorbar_pad = axes_pad [1 ]
519546 else :
520547 self ._colorbar_pad = cbar_pad
521548
@@ -538,8 +565,8 @@ def __init__(self, fig,
538565 axes_class , axes_class_args = axes_class
539566
540567 self .axes_all = []
541- self .axes_column = [[] for i in range (self ._ncols )]
542- self .axes_row = [[] for i in range (self ._nrows )]
568+ self .axes_column = [[] for _ in range (self ._ncols )]
569+ self .axes_row = [[] for _ in range (self ._nrows )]
543570
544571 self .cbar_axes = []
545572
@@ -563,8 +590,8 @@ def __init__(self, fig,
563590 rect = self ._divider .get_position ()
564591
565592 # reference axes
566- self ._column_refax = [None for i in range (self ._ncols )]
567- self ._row_refax = [None for i in range (self ._nrows )]
593+ self ._column_refax = [None for _ in range (self ._ncols )]
594+ self ._row_refax = [None for _ in range (self ._nrows )]
568595 self ._refax = None
569596
570597 for i in range (self .ngrids ):
@@ -678,7 +705,7 @@ def _update_locators(self):
678705 v_cb_pos = []
679706 for row , ax in enumerate (self .axes_column [0 ][::- 1 ]):
680707 if v :
681- v .append (self ._horiz_pad_size ) # Size.Fixed(self._axes_pad))
708+ v .append (self ._vert_pad_size ) # Size.Fixed(self._axes_pad))
682709
683710 if ax :
684711 sz = Size .AxesY (ax , aspect = "axes" , ref_ax = self .axes_all [0 ])
@@ -786,7 +813,7 @@ def _update_locators(self):
786813 F .subplots_adjust (left = 0.15 , right = 0.9 )
787814
788815 grid = Grid (F , 111 , # similar to subplot(111)
789- nrows_ncols = (2 , 2 ),
816+ nrows_ncols = (2 , 2 ),
790817 direction = "row" ,
791818 axes_pad = 0.05 ,
792819 add_all = True ,
@@ -802,12 +829,12 @@ def _update_locators(self):
802829 F .subplots_adjust (left = 0.05 , right = 0.98 )
803830
804831 grid = ImageGrid (F , 131 , # similar to subplot(111)
805- nrows_ncols = (2 , 2 ),
806- direction = "row" ,
807- axes_pad = 0.05 ,
808- add_all = True ,
809- label_mode = "1" ,
810- )
832+ nrows_ncols = (2 , 2 ),
833+ direction = "row" ,
834+ axes_pad = 0.05 ,
835+ add_all = True ,
836+ label_mode = "1" ,
837+ )
811838
812839 Z , extent = get_demo_image ()
813840 plt .ioff ()
@@ -821,14 +848,14 @@ def _update_locators(self):
821848 plt .ion ()
822849
823850 grid = ImageGrid (F , 132 , # similar to subplot(111)
824- nrows_ncols = (2 , 2 ),
825- direction = "row" ,
826- axes_pad = 0.0 ,
827- add_all = True ,
828- share_all = True ,
829- label_mode = "1" ,
830- cbar_mode = "single" ,
831- )
851+ nrows_ncols = (2 , 2 ),
852+ direction = "row" ,
853+ axes_pad = 0.0 ,
854+ add_all = True ,
855+ share_all = True ,
856+ label_mode = "1" ,
857+ cbar_mode = "single" ,
858+ )
832859
833860 Z , extent = get_demo_image ()
834861 plt .ioff ()
@@ -844,17 +871,17 @@ def _update_locators(self):
844871 plt .ion ()
845872
846873 grid = ImageGrid (F , 133 , # similar to subplot(122)
847- nrows_ncols = (2 , 2 ),
848- direction = "row" ,
849- axes_pad = 0.1 ,
850- add_all = True ,
851- label_mode = "1" ,
852- share_all = True ,
853- cbar_location = "top" ,
854- cbar_mode = "each" ,
855- cbar_size = "7%" ,
856- cbar_pad = "2%" ,
857- )
874+ nrows_ncols = (2 , 2 ),
875+ direction = "row" ,
876+ axes_pad = 0.1 ,
877+ add_all = True ,
878+ label_mode = "1" ,
879+ share_all = True ,
880+ cbar_location = "top" ,
881+ cbar_mode = "each" ,
882+ cbar_size = "7%" ,
883+ cbar_pad = "2%" ,
884+ )
858885 plt .ioff ()
859886 for i in range (4 ):
860887 im = grid [i ].imshow (Z , extent = extent , interpolation = "nearest" )
0 commit comments