@@ -333,6 +333,7 @@ def __call__(self, axes, renderer):
333
333
renderer )
334
334
335
335
336
+ from matplotlib .gridspec import SubplotSpec , GridSpec
336
337
337
338
class SubplotDivider (Divider ):
338
339
"""
@@ -357,28 +358,48 @@ def __init__(self, fig, *args, **kwargs):
357
358
self .figure = fig
358
359
359
360
if len (args )== 1 :
360
- s = str (args [0 ])
361
- if len (s ) != 3 :
362
- raise ValueError ('Argument to subplot must be a 3 digits long' )
363
- rows , cols , num = map (int , s )
361
+ if isinstance (args [0 ], SubplotSpec ):
362
+ self ._subplotspec = args [0 ]
363
+ else :
364
+ try :
365
+ s = str (int (args [0 ]))
366
+ rows , cols , num = map (int , s )
367
+ except ValueError :
368
+ raise ValueError (
369
+ 'Single argument to subplot must be a 3-digit integer' )
370
+ self ._subplotspec = GridSpec (rows , cols )[num - 1 ]
371
+ # num - 1 for converting from MATLAB to python indexing
364
372
elif len (args )== 3 :
365
373
rows , cols , num = args
374
+ rows = int (rows )
375
+ cols = int (cols )
376
+ if isinstance (num , tuple ) and len (num ) == 2 :
377
+ num = [int (n ) for n in num ]
378
+ self ._subplotspec = GridSpec (rows , cols )[num [0 ]- 1 :num [1 ]]
379
+ else :
380
+ self ._subplotspec = GridSpec (rows , cols )[int (num )- 1 ]
381
+ # num - 1 for converting from MATLAB to python indexing
366
382
else :
367
- raise ValueError ( 'Illegal argument to subplot' )
383
+ raise ValueError ('Illegal argument(s) to subplot: %s' % (args ,))
384
+
368
385
386
+ # total = rows*cols
387
+ # num -= 1 # convert from matlab to python indexing
388
+ # # ie num in range(0,total)
389
+ # if num >= total:
390
+ # raise ValueError( 'Subplot number exceeds total subplots')
391
+ # self._rows = rows
392
+ # self._cols = cols
393
+ # self._num = num
369
394
370
- total = rows * cols
371
- num -= 1 # convert from matlab to python indexing
372
- # ie num in range(0,total)
373
- if num >= total :
374
- raise ValueError ( 'Subplot number exceeds total subplots' )
375
- self ._rows = rows
376
- self ._cols = cols
377
- self ._num = num
395
+ # self.update_params()
378
396
397
+
398
+ # sets self.fixbox
379
399
self .update_params ()
380
400
381
401
pos = self .figbox .bounds
402
+
382
403
horizontal = kwargs .pop ("horizontal" , [])
383
404
vertical = kwargs .pop ("vertical" , [])
384
405
aspect = kwargs .pop ("aspect" , None )
@@ -393,40 +414,67 @@ def __init__(self, fig, *args, **kwargs):
393
414
394
415
def get_position (self ):
395
416
"return the bounds of the subplot box"
396
- self .update_params ()
417
+
418
+ self .update_params () # update self.figbox
397
419
return self .figbox .bounds
398
420
399
421
422
+ # def update_params(self):
423
+ # 'update the subplot position from fig.subplotpars'
424
+
425
+ # rows = self._rows
426
+ # cols = self._cols
427
+ # num = self._num
428
+
429
+ # pars = self.figure.subplotpars
430
+ # left = pars.left
431
+ # right = pars.right
432
+ # bottom = pars.bottom
433
+ # top = pars.top
434
+ # wspace = pars.wspace
435
+ # hspace = pars.hspace
436
+ # totWidth = right-left
437
+ # totHeight = top-bottom
438
+
439
+ # figH = totHeight/(rows + hspace*(rows-1))
440
+ # sepH = hspace*figH
441
+
442
+ # figW = totWidth/(cols + wspace*(cols-1))
443
+ # sepW = wspace*figW
444
+
445
+ # rowNum, colNum = divmod(num, cols)
446
+
447
+ # figBottom = top - (rowNum+1)*figH - rowNum*sepH
448
+ # figLeft = left + colNum*(figW + sepW)
449
+
450
+ # self.figbox = mtransforms.Bbox.from_bounds(figLeft, figBottom,
451
+ # figW, figH)
452
+
400
453
def update_params (self ):
401
454
'update the subplot position from fig.subplotpars'
402
455
403
- rows = self ._rows
404
- cols = self ._cols
405
- num = self ._num
456
+ self .figbox = self .get_subplotspec ().get_position (self .figure )
406
457
407
- pars = self .figure .subplotpars
408
- left = pars .left
409
- right = pars .right
410
- bottom = pars .bottom
411
- top = pars .top
412
- wspace = pars .wspace
413
- hspace = pars .hspace
414
- totWidth = right - left
415
- totHeight = top - bottom
458
+ def get_geometry (self ):
459
+ 'get the subplot geometry, eg 2,2,3'
460
+ rows , cols , num1 , num2 = self .get_subplotspec ().get_geometry ()
461
+ return rows , cols , num1 + 1 # for compatibility
416
462
417
- figH = totHeight / (rows + hspace * (rows - 1 ))
418
- sepH = hspace * figH
419
-
420
- figW = totWidth / (cols + wspace * (cols - 1 ))
421
- sepW = wspace * figW
463
+ # COVERAGE NOTE: Never used internally or from examples
464
+ def change_geometry (self , numrows , numcols , num ):
465
+ 'change subplot geometry, eg. from 1,1,1 to 2,2,3'
466
+ self ._subplotspec = GridSpec (numrows , numcols )[num - 1 ]
467
+ self .update_params ()
468
+ self .set_position (self .figbox )
422
469
423
- rowNum , colNum = divmod (num , cols )
470
+ def get_subplotspec (self ):
471
+ 'get the SubplotSpec instance'
472
+ return self ._subplotspec
424
473
425
- figBottom = top - (rowNum + 1 )* figH - rowNum * sepH
426
- figLeft = left + colNum * (figW + sepW )
474
+ def set_subplotspec (self , subplotspec ):
475
+ 'set the SubplotSpec instance'
476
+ self ._subplotspec = subplotspec
427
477
428
- self .figbox = mtransforms .Bbox .from_bounds (figLeft , figBottom ,
429
- figW , figH )
430
478
431
479
432
480
class AxesDivider (Divider ):
0 commit comments