Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit b0078aa

Browse files
committed
subplots() returns AxesArray
subplots() used to return a numpy array of Axes, which has some drawbacks. The numpy array is mainly used as a 2D container structure that allows 2D indexing. Apart from that, it's not particularly well suited: - Many of the numpy functions do not work on Axes. - Some functions work, but have awkward semantics; e.g. len() gives the number of rows. - We can't add our own functionality. AxesArray introduces a facade to the underlying array to allow us to customize the API. For the beginning, the API is 100% compatible with the previous numpy array behavior, but we deprecate everything except for a few reasonable methods.
1 parent ffd3b12 commit b0078aa

File tree

2 files changed

+120
-2
lines changed

2 files changed

+120
-2
lines changed

lib/matplotlib/gridspec.py

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import copy
1414
import logging
1515
from numbers import Integral
16+
import warnings
1617

1718
import numpy as np
1819

@@ -309,10 +310,10 @@ def subplots(self, *, sharex=False, sharey=False, squeeze=True,
309310
if squeeze:
310311
# Discarding unneeded dimensions that equal 1. If we only have one
311312
# subplot, just return it instead of a 1-element array.
312-
return axarr.item() if axarr.size == 1 else axarr.squeeze()
313+
return axarr.item() if axarr.size == 1 else AxesArray(axarr.squeeze())
313314
else:
314315
# Returned axis array will be always 2-d, even if nrows=ncols=1.
315-
return axarr
316+
return AxesArray(axarr)
316317

317318

318319
class GridSpec(GridSpecBase):
@@ -734,3 +735,54 @@ def subgridspec(self, nrows, ncols, **kwargs):
734735
fig.add_subplot(gssub[0, i])
735736
"""
736737
return GridSpecFromSubplotSpec(nrows, ncols, self, **kwargs)
738+
739+
740+
class AxesArray:
741+
"""
742+
A container for a 1D or 2D grid arrangement of Axes.
743+
744+
This is used as the return type of ``subplots()``.
745+
746+
Formerly, ``subplots()`` returned a numpy array of Axes. For a transition period,
747+
AxesArray will act like a numpy array, but all functions and properties that
748+
are not listed explicitly below are deprecated.
749+
"""
750+
def __init__(self, array):
751+
self._array = array
752+
753+
@staticmethod
754+
def _ensure_wrapped(ax_or_axs):
755+
if isinstance(ax_or_axs, np.ndarray):
756+
return AxesArray(ax_or_axs)
757+
else:
758+
return ax_or_axs
759+
760+
def __getitem__(self, index):
761+
return self._ensure_wrapped(self._array[index])
762+
763+
@property
764+
def ndim(self):
765+
return self._array.ndim
766+
767+
@property
768+
def shape(self):
769+
return self._array.shape
770+
771+
@property
772+
def size(self):
773+
return self._array.size
774+
775+
@property
776+
def flat(self):
777+
return self._array.flat
778+
779+
def __iter__(self):
780+
return iter([self._ensure_wrapped(row) for row in self._array])
781+
782+
def __getattr__(self, item):
783+
# forward all other attributes to the underlying array
784+
# (this is a temporary measure to allow a smooth transition)
785+
attr = getattr(self._array, item)
786+
_api.warn_deprecated("3.8",
787+
message=f"Using {item!r} on AxesArray is deprecated.")
788+
return attr

lib/matplotlib/tests/test_subplots.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import numpy as np
44
import pytest
55

6+
import matplotlib as mpl
67
from matplotlib.axes import Axes, SubplotBase
78
import matplotlib.pyplot as plt
89
from matplotlib.testing.decorators import check_figures_equal, image_comparison
@@ -260,3 +261,68 @@ def test_old_subplot_compat():
260261
assert not isinstance(fig.add_axes(rect=[0, 0, 1, 1]), SubplotBase)
261262
with pytest.raises(TypeError):
262263
Axes(fig, [0, 0, 1, 1], rect=[0, 0, 1, 1])
264+
265+
266+
class TestAxesArray:
267+
@staticmethod
268+
def contain_same_axes(axs1, axs2):
269+
return all(ax1 is ax2 for ax1, ax2 in zip(axs1.flat, axs2.flat))
270+
271+
def test_1d(self):
272+
axs = plt.figure().subplots(1, 3)
273+
# shape and size
274+
assert axs.shape == (3,)
275+
assert axs.size == 3
276+
assert axs.ndim == 1
277+
# flat
278+
assert all(isinstance(ax, Axes) for ax in axs.flat)
279+
assert len(set(id(ax) for ax in axs.flat)) == 3
280+
# single index
281+
assert all(isinstance(axs[i], Axes) for i in range(axs.size))
282+
assert len(set(axs[i] for i in range(axs.size))) == 3
283+
# iteration
284+
assert all(ax1 is ax2 for ax1, ax2 in zip(axs, axs.flat))
285+
286+
def test_1d_no_squeeze(self):
287+
axs = plt.figure().subplots(1, 3, squeeze=False)
288+
# shape and size
289+
assert axs.shape == (1, 3)
290+
assert axs.size == 3
291+
assert axs.ndim == 2
292+
# flat
293+
assert all(isinstance(ax, Axes) for ax in axs.flat)
294+
assert len(set(id(ax) for ax in axs.flat)) == 3
295+
# 2d indexing
296+
assert axs[0, 0] is axs.flat[0]
297+
assert axs[0, 2] is axs.flat[-1]
298+
# single index
299+
axs_type = type(axs)
300+
assert type(axs[0]) is axs_type
301+
assert axs[0].shape == (3,)
302+
# iteration
303+
assert all(self.contain_same_axes(axi, axs[i]) for i, axi in enumerate(axs))
304+
305+
def test_2d(self):
306+
axs = plt.figure().subplots(2, 3)
307+
# shape and size
308+
assert axs.shape == (2, 3)
309+
assert axs.size == 6
310+
assert axs.ndim == 2
311+
# flat
312+
assert all(isinstance(ax, Axes) for ax in axs.flat)
313+
assert len(set(id(ax) for ax in axs.flat)) == 6
314+
# 2d indexing
315+
assert axs[0, 0] is axs.flat[0]
316+
assert axs[1, 2] is axs.flat[-1]
317+
# single index
318+
axs_type = type(axs)
319+
assert type(axs[0]) is axs_type
320+
assert axs[0].shape == (3,)
321+
# iteration
322+
assert all(self.contain_same_axes(axi, axs[i]) for i, axi in enumerate(axs))
323+
324+
def test_deprecated(self):
325+
axs = plt.figure().subplots(2, 3)
326+
with pytest.warns(mpl.MatplotlibDeprecationWarning,
327+
match="Using 'flatten' on AxesArray"):
328+
axs.flatten()

0 commit comments

Comments
 (0)