@@ -333,6 +333,7 @@ def __call__(self, axes, renderer):
333333 renderer )
334334
335335
336+ from matplotlib .gridspec import SubplotSpec , GridSpec
336337
337338class SubplotDivider (Divider ):
338339 """
@@ -357,28 +358,48 @@ def __init__(self, fig, *args, **kwargs):
357358 self .figure = fig
358359
359360 if len (args )== 1 :
360- s = str (args [0 ])
361- if len (s ) != 3 :
362- raise ValueError ('Argument to subplot must be a 3 digits long' )
363- rows , cols , num = map (int , s )
361+ if isinstance (args [0 ], SubplotSpec ):
362+ self ._subplotspec = args [0 ]
363+ else :
364+ try :
365+ s = str (int (args [0 ]))
366+ rows , cols , num = map (int , s )
367+ except ValueError :
368+ raise ValueError (
369+ 'Single argument to subplot must be a 3-digit integer' )
370+ self ._subplotspec = GridSpec (rows , cols )[num - 1 ]
371+ # num - 1 for converting from MATLAB to python indexing
364372 elif len (args )== 3 :
365373 rows , cols , num = args
374+ rows = int (rows )
375+ cols = int (cols )
376+ if isinstance (num , tuple ) and len (num ) == 2 :
377+ num = [int (n ) for n in num ]
378+ self ._subplotspec = GridSpec (rows , cols )[num [0 ]- 1 :num [1 ]]
379+ else :
380+ self ._subplotspec = GridSpec (rows , cols )[int (num )- 1 ]
381+ # num - 1 for converting from MATLAB to python indexing
366382 else :
367- raise ValueError ( 'Illegal argument to subplot' )
383+ raise ValueError ('Illegal argument(s) to subplot: %s' % (args ,))
384+
368385
386+ # total = rows*cols
387+ # num -= 1 # convert from matlab to python indexing
388+ # # ie num in range(0,total)
389+ # if num >= total:
390+ # raise ValueError( 'Subplot number exceeds total subplots')
391+ # self._rows = rows
392+ # self._cols = cols
393+ # self._num = num
369394
370- total = rows * cols
371- num -= 1 # convert from matlab to python indexing
372- # ie num in range(0,total)
373- if num >= total :
374- raise ValueError ( 'Subplot number exceeds total subplots' )
375- self ._rows = rows
376- self ._cols = cols
377- self ._num = num
395+ # self.update_params()
378396
397+
398+ # sets self.fixbox
379399 self .update_params ()
380400
381401 pos = self .figbox .bounds
402+
382403 horizontal = kwargs .pop ("horizontal" , [])
383404 vertical = kwargs .pop ("vertical" , [])
384405 aspect = kwargs .pop ("aspect" , None )
@@ -393,40 +414,67 @@ def __init__(self, fig, *args, **kwargs):
393414
394415 def get_position (self ):
395416 "return the bounds of the subplot box"
396- self .update_params ()
417+
418+ self .update_params () # update self.figbox
397419 return self .figbox .bounds
398420
399421
422+ # def update_params(self):
423+ # 'update the subplot position from fig.subplotpars'
424+
425+ # rows = self._rows
426+ # cols = self._cols
427+ # num = self._num
428+
429+ # pars = self.figure.subplotpars
430+ # left = pars.left
431+ # right = pars.right
432+ # bottom = pars.bottom
433+ # top = pars.top
434+ # wspace = pars.wspace
435+ # hspace = pars.hspace
436+ # totWidth = right-left
437+ # totHeight = top-bottom
438+
439+ # figH = totHeight/(rows + hspace*(rows-1))
440+ # sepH = hspace*figH
441+
442+ # figW = totWidth/(cols + wspace*(cols-1))
443+ # sepW = wspace*figW
444+
445+ # rowNum, colNum = divmod(num, cols)
446+
447+ # figBottom = top - (rowNum+1)*figH - rowNum*sepH
448+ # figLeft = left + colNum*(figW + sepW)
449+
450+ # self.figbox = mtransforms.Bbox.from_bounds(figLeft, figBottom,
451+ # figW, figH)
452+
400453 def update_params (self ):
401454 'update the subplot position from fig.subplotpars'
402455
403- rows = self ._rows
404- cols = self ._cols
405- num = self ._num
456+ self .figbox = self .get_subplotspec ().get_position (self .figure )
406457
407- pars = self .figure .subplotpars
408- left = pars .left
409- right = pars .right
410- bottom = pars .bottom
411- top = pars .top
412- wspace = pars .wspace
413- hspace = pars .hspace
414- totWidth = right - left
415- totHeight = top - bottom
458+ def get_geometry (self ):
459+ 'get the subplot geometry, eg 2,2,3'
460+ rows , cols , num1 , num2 = self .get_subplotspec ().get_geometry ()
461+ return rows , cols , num1 + 1 # for compatibility
416462
417- figH = totHeight / (rows + hspace * (rows - 1 ))
418- sepH = hspace * figH
419-
420- figW = totWidth / (cols + wspace * (cols - 1 ))
421- sepW = wspace * figW
463+ # COVERAGE NOTE: Never used internally or from examples
464+ def change_geometry (self , numrows , numcols , num ):
465+ 'change subplot geometry, eg. from 1,1,1 to 2,2,3'
466+ self ._subplotspec = GridSpec (numrows , numcols )[num - 1 ]
467+ self .update_params ()
468+ self .set_position (self .figbox )
422469
423- rowNum , colNum = divmod (num , cols )
470+ def get_subplotspec (self ):
471+ 'get the SubplotSpec instance'
472+ return self ._subplotspec
424473
425- figBottom = top - (rowNum + 1 )* figH - rowNum * sepH
426- figLeft = left + colNum * (figW + sepW )
474+ def set_subplotspec (self , subplotspec ):
475+ 'set the SubplotSpec instance'
476+ self ._subplotspec = subplotspec
427477
428- self .figbox = mtransforms .Bbox .from_bounds (figLeft , figBottom ,
429- figW , figH )
430478
431479
432480class AxesDivider (Divider ):
0 commit comments