diff --git a/lib/matplotlib/figure.py b/lib/matplotlib/figure.py index 6dc6e0361529..0d630c36ae2c 100644 --- a/lib/matplotlib/figure.py +++ b/lib/matplotlib/figure.py @@ -1422,17 +1422,20 @@ def tight_layout(self, renderer=None, pad=1.08, h_pad=None, w_pad=None, rect=Non labels) will fit into. Default is (0, 0, 1, 1). """ - from tight_layout import get_renderer, get_tight_layout_figure + from tight_layout import (get_renderer, get_tight_layout_figure, + get_subplotspec_list) - subplot_axes = [ax for ax in self.axes if isinstance(ax, SubplotBase)] - if len(subplot_axes) < len(self.axes): - warnings.warn("tight_layout can only process Axes that descend " - "from SubplotBase; results might be incorrect.") + subplotspec_list = get_subplotspec_list(self.axes) + if None in subplotspec_list: + warnings.warn("This figure includes Axes that are not " + "compatible with tight_layout, so its " + "results might be incorrect.") if renderer is None: renderer = get_renderer(self) - kwargs = get_tight_layout_figure(self, subplot_axes, renderer, + kwargs = get_tight_layout_figure(self, self.axes, subplotspec_list, + renderer, pad=pad, h_pad=h_pad, w_pad=w_pad, rect=rect) diff --git a/lib/matplotlib/tight_layout.py b/lib/matplotlib/tight_layout.py index 0398a8c1a283..e249a632f80f 100644 --- a/lib/matplotlib/tight_layout.py +++ b/lib/matplotlib/tight_layout.py @@ -209,7 +209,33 @@ def get_renderer(fig): return renderer -def get_tight_layout_figure(fig, axes_list, renderer, +def get_subplotspec_list(axes_list): + """ + Return a list of subplotspec from the given list of axes. For an + instance of axes that does not support subplotspec, None is + inserted in the list. + + """ + subplotspec_list = [] + for ax in axes_list: + axes_or_locator = ax.get_axes_locator() + if axes_or_locator is None: + axes_or_locator = ax + + if hasattr(axes_or_locator, "get_subplotspec"): + subplotspec = axes_or_locator.get_subplotspec() + subplotspec = subplotspec.get_topmost_subplotspec() + if subplotspec.get_gridspec().locally_modified_subplot_params(): + subplotspec = None + else: + subplotspec = None + + subplotspec_list.append(subplotspec) + + return subplotspec_list + + +def get_tight_layout_figure(fig, axes_list, subplotspec_list, renderer, pad=1.08, h_pad=None, w_pad=None, rect=None): """ Return subplot parameters for tight-layouted-figure with specified @@ -221,6 +247,9 @@ def get_tight_layout_figure(fig, axes_list, renderer, *axes_list* : a list of axes + *subplotspec_list* : a list of subplotspec associated with each + axes in axes_list + *renderer* : renderer instance *pad* : float @@ -238,27 +267,20 @@ def get_tight_layout_figure(fig, axes_list, renderer, """ - subplotspec_list = [] subplot_list = [] nrows_list = [] ncols_list = [] ax_bbox_list = [] - subplot_dict = {} # for axes_grid1, multiple axes can share - # same subplot_interface. Thus we need to - # join them together. + subplot_dict = {} # multiple axes can share + # same subplot_interface (e.g, axes_grid1). Thus + # we need to join them together. - for ax in axes_list: - locator = ax.get_axes_locator() - if hasattr(locator, "get_subplotspec"): - subplotspec = locator.get_subplotspec().get_topmost_subplotspec() - elif hasattr(ax, "get_subplotspec"): - subplotspec = ax.get_subplotspec().get_topmost_subplotspec() - else: - continue + subplotspec_list2 = [] - if (subplotspec is None) or \ - subplotspec.get_gridspec().locally_modified_subplot_params(): + for ax, subplotspec in zip(axes_list, + subplotspec_list): + if subplotspec is None: continue subplots = subplot_dict.setdefault(subplotspec, []) @@ -267,7 +289,7 @@ def get_tight_layout_figure(fig, axes_list, renderer, myrows, mycols, _, _ = subplotspec.get_geometry() nrows_list.append(myrows) ncols_list.append(mycols) - subplotspec_list.append(subplotspec) + subplotspec_list2.append(subplotspec) subplot_list.append(subplots) ax_bbox_list.append(subplotspec.get_position(fig)) @@ -277,7 +299,7 @@ def get_tight_layout_figure(fig, axes_list, renderer, max_ncols = max(ncols_list) num1num2_list = [] - for subplotspec in subplotspec_list: + for subplotspec in subplotspec_list2: rows, cols, num1, num2 = subplotspec.get_geometry() div_row, mod_row = divmod(max_nrows, rows) div_col, mod_col = divmod(max_ncols, cols)