diff --git a/lib/matplotlib/gridspec.py b/lib/matplotlib/gridspec.py index b2f1b5d8ff2e..9cd9b81f9aca 100644 --- a/lib/matplotlib/gridspec.py +++ b/lib/matplotlib/gridspec.py @@ -309,10 +309,10 @@ def subplots(self, *, sharex=False, sharey=False, squeeze=True, if squeeze: # Discarding unneeded dimensions that equal 1. If we only have one # subplot, just return it instead of a 1-element array. - return axarr.item() if axarr.size == 1 else axarr.squeeze() + return axarr.item() if axarr.size == 1 else AxesArray(axarr.squeeze()) else: # Returned axis array will be always 2-d, even if nrows=ncols=1. - return axarr + return AxesArray(axarr) class GridSpec(GridSpecBase): @@ -736,3 +736,70 @@ def subgridspec(self, nrows, ncols, **kwargs): fig.add_subplot(gssub[0, i]) """ return GridSpecFromSubplotSpec(nrows, ncols, self, **kwargs) + + +class AxesArray: + """ + A container for a 1D or 2D grid arrangement of Axes. + + This is used as the return type of ``subplots()``. + + Formerly, ``subplots()`` returned a numpy array of Axes. For a transition period, + AxesArray will act like a numpy array, but all functions and properties that + are not listed explicitly below are deprecated. + """ + def __init__(self, array): + self._array = array + + @staticmethod + def _ensure_wrapped(ax_or_axs): + if isinstance(ax_or_axs, np.ndarray): + return AxesArray(ax_or_axs) + else: + return ax_or_axs + + def __getitem__(self, index): + return self._ensure_wrapped(self._array[index]) + + @property + def __array_struct__(self): + return self._array.__array_struct__ + + @property + def ndim(self): + return self._array.ndim + + @property + def shape(self): + return self._array.shape + + @property + def size(self): + return self._array.size + + @property + def flat(self): + return self._array.flat + + @property + def flatten(self): + """[Disouraged] Use ``axs.flat`` instead.""" + return self._array.flatten + + @property + def ravel(self): + """[Disouraged] Use ``axs.flat`` instead.""" + return self._array.ravel + + @property + def __iter__(self): + return iter([self._ensure_wrapped(row) for row in self._array]) + + def __getattr__(self, item): + # forward all other attributes to the underlying array + # (this is a temporary measure to allow a smooth transition) + attr = getattr(self._array, item) + _api.warn_deprecated("3.9", + message=f"Using {item!r} on AxesArray is deprecated.", + pending=True) + return attr diff --git a/lib/matplotlib/tests/test_subplots.py b/lib/matplotlib/tests/test_subplots.py index cf5f4b902e24..d067edd2543e 100644 --- a/lib/matplotlib/tests/test_subplots.py +++ b/lib/matplotlib/tests/test_subplots.py @@ -3,6 +3,7 @@ import numpy as np import pytest +import matplotlib as mpl from matplotlib.axes import Axes, SubplotBase import matplotlib.pyplot as plt from matplotlib.testing.decorators import check_figures_equal, image_comparison @@ -283,3 +284,80 @@ def test_old_subplot_compat(): assert not isinstance(fig.add_axes(rect=[0, 0, 1, 1]), SubplotBase) with pytest.raises(TypeError): Axes(fig, [0, 0, 1, 1], rect=[0, 0, 1, 1]) + + +class TestAxesArray: + @staticmethod + def contain_same_axes(axs1, axs2): + return all(ax1 is ax2 for ax1, ax2 in zip(axs1.flat, axs2.flat)) + + def test_1d(self): + axs = plt.figure().subplots(1, 3) + # shape and size + assert axs.shape == (3,) + assert axs.size == 3 + assert axs.ndim == 1 + # flat + assert all(isinstance(ax, Axes) for ax in axs.flat) + assert len(set(id(ax) for ax in axs.flat)) == 3 + # flatten + assert all(isinstance(ax, Axes) for ax in axs.flatten()) + assert len(set(id(ax) for ax in axs.flatten())) == 3 + # ravel + assert all(isinstance(ax, Axes) for ax in axs.ravel()) + assert len(set(id(ax) for ax in axs.ravel())) == 3 + # single index + assert all(isinstance(axs[i], Axes) for i in range(axs.size)) + assert len(set(axs[i] for i in range(axs.size))) == 3 +# iteration + assert all(ax1 is ax2 for ax1, ax2 in zip(axs, axs.flat)) + + def test_1d_no_squeeze(self): + axs = plt.figure().subplots(1, 3, squeeze=False) + # shape and size + assert axs.shape == (1, 3) + assert axs.size == 3 + assert axs.ndim == 2 + # flat + assert all(isinstance(ax, Axes) for ax in axs.flat) + assert len(set(id(ax) for ax in axs.flat)) == 3 + # 2d indexing + assert axs[0, 0] is axs.flat[0] + assert axs[0, 2] is axs.flat[-1] + # single index + axs_type = type(axs) + assert type(axs[0]) is axs_type + assert axs[0].shape == (3,) + # iteration + assert all(self.contain_same_axes(axi, axs[i]) for i, axi in enumerate(axs)) + + def test_2d(self): + axs = plt.figure().subplots(2, 3) + # shape and size + assert axs.shape == (2, 3) + assert axs.size == 6 + assert axs.ndim == 2 + # flat + assert all(isinstance(ax, Axes) for ax in axs.flat) + assert len(set(id(ax) for ax in axs.flat)) == 6 + # flatten + assert all(isinstance(ax, Axes) for ax in axs.flatten()) + assert len(set(id(ax) for ax in axs.flatten())) == 6 + # ravel + assert all(isinstance(ax, Axes) for ax in axs.ravel()) + assert len(set(id(ax) for ax in axs.ravel())) == 6 + # 2d indexing + assert axs[0, 0] is axs.flat[0] + assert axs[1, 2] is axs.flat[-1] + # single index + axs_type = type(axs) + assert type(axs[0]) is axs_type + assert axs[0].shape == (3,) + # iteration + assert all(self.contain_same_axes(axi, axs[i]) for i, axi in enumerate(axs)) + + def test_deprecated(self): + axs = plt.figure().subplots(2, 2) + with pytest.warns(PendingDeprecationWarning, + match="Using 'diagonal' on AxesArray"): + axs.diagonal()