From 7b510447aa47bca80f727e98bbb0a3bdf0ee39f0 Mon Sep 17 00:00:00 2001 From: Oscar Gustafsson Date: Thu, 24 Feb 2022 23:35:50 +0100 Subject: [PATCH 1/2] Improve pandas and xarray conversion --- lib/matplotlib/axes/_axes.py | 4 ++-- lib/matplotlib/cbook/__init__.py | 33 +++++++++++++++++++----------- lib/matplotlib/dates.py | 5 ++--- lib/matplotlib/testing/conftest.py | 7 +++++++ lib/matplotlib/tests/conftest.py | 2 +- lib/matplotlib/tests/test_cbook.py | 25 +++++++++++++++++++++- lib/matplotlib/units.py | 5 +++-- 7 files changed, 60 insertions(+), 21 deletions(-) diff --git a/lib/matplotlib/axes/_axes.py b/lib/matplotlib/axes/_axes.py index 67779775b8a5..6c1c3285eec5 100644 --- a/lib/matplotlib/axes/_axes.py +++ b/lib/matplotlib/axes/_axes.py @@ -7907,8 +7907,8 @@ def violinplot(self, dataset, positions=None, vert=True, widths=0.5, """ def _kde_method(X, coords): - if hasattr(X, 'values'): # support pandas.Series - X = X.values + # Unpack in case of e.g. Pandas or xarray object + X = cbook._unpack_to_numpy(X) # fallback gracefully if the vector contains only one value if np.all(X[0] == X): return (X[0] == coords).astype(float) diff --git a/lib/matplotlib/cbook/__init__.py b/lib/matplotlib/cbook/__init__.py index 47ff088a1cd4..228df1bf637c 100644 --- a/lib/matplotlib/cbook/__init__.py +++ b/lib/matplotlib/cbook/__init__.py @@ -1311,9 +1311,8 @@ def _to_unmasked_float_array(x): def _check_1d(x): """Convert scalars to 1D arrays; pass-through arrays as is.""" - if hasattr(x, 'to_numpy'): - # if we are given an object that creates a numpy, we should use it... - x = x.to_numpy() + # Unpack in case of e.g. Pandas or xarray object + x = _unpack_to_numpy(x) if not hasattr(x, 'shape') or len(x.shape) < 1: return np.atleast_1d(x) else: @@ -1332,15 +1331,8 @@ def _reshape_2D(X, name): *name* is used to generate the error message for invalid inputs. """ - # unpack if we have a values or to_numpy method. 
- try: - X = X.to_numpy() - except AttributeError: - try: - if isinstance(X.values, np.ndarray): - X = X.values - except AttributeError: - pass + # Unpack in case of e.g. Pandas or xarray object + X = _unpack_to_numpy(X) # Iterate over columns for ndarrays. if isinstance(X, np.ndarray): @@ -2231,3 +2223,20 @@ def _picklable_class_constructor(mixin_class, fmt, attr_name, base_class): factory = _make_class_factory(mixin_class, fmt, attr_name) cls = factory(base_class) return cls.__new__(cls) + + +def _unpack_to_numpy(x): + """Internal helper to extract data from e.g. pandas and xarray objects.""" + if isinstance(x, np.ndarray): + # If numpy, return directly + return x + if hasattr(x, 'to_numpy'): + # Assume that any to_numpy() method actually returns a numpy array + return x.to_numpy() + if hasattr(x, 'values'): + xtmp = x.values + # For example a dict has a 'values' attribute, but it is not a property + # so in this case we do not want to return a function + if isinstance(xtmp, np.ndarray): + return xtmp + return x diff --git a/lib/matplotlib/dates.py b/lib/matplotlib/dates.py index c6e327327e08..a6a95ad46b81 100644 --- a/lib/matplotlib/dates.py +++ b/lib/matplotlib/dates.py @@ -437,9 +437,8 @@ def date2num(d): The Gregorian calendar is assumed; this is not universal practice. For details see the module docstring. """ - if hasattr(d, "values"): - # this unpacks pandas series or dataframes... - d = d.values + # Unpack in case of e.g. 
Pandas or xarray object + d = cbook._unpack_to_numpy(d) # make an iterable, but save state to unpack later: iterable = np.iterable(d) diff --git a/lib/matplotlib/testing/conftest.py b/lib/matplotlib/testing/conftest.py index 996bfbefef80..01e60fea05e4 100644 --- a/lib/matplotlib/testing/conftest.py +++ b/lib/matplotlib/testing/conftest.py @@ -125,3 +125,10 @@ def pd(): except ImportError: pass return pd + + +@pytest.fixture +def xr(): + """Fixture to import xarray.""" + xr = pytest.importorskip('xarray') + return xr diff --git a/lib/matplotlib/tests/conftest.py b/lib/matplotlib/tests/conftest.py index f051470f777c..06c6d150f31b 100644 --- a/lib/matplotlib/tests/conftest.py +++ b/lib/matplotlib/tests/conftest.py @@ -1,3 +1,3 @@ from matplotlib.testing.conftest import (mpl_test_settings, pytest_configure, pytest_unconfigure, - pd) + pd, xr) diff --git a/lib/matplotlib/tests/test_cbook.py b/lib/matplotlib/tests/test_cbook.py index dcb855f73ce7..65956ba9ceea 100644 --- a/lib/matplotlib/tests/test_cbook.py +++ b/lib/matplotlib/tests/test_cbook.py @@ -680,14 +680,37 @@ def test_reshape2d_pandas(pd): for x, xnew in zip(X.T, Xnew): np.testing.assert_array_equal(x, xnew) + +def test_reshape2d_xarray(xr): + # separate to allow the rest of the tests to run if no xarray... X = np.arange(30).reshape(10, 3) - x = pd.DataFrame(X, columns=["a", "b", "c"]) + x = xr.DataArray(X, dims=["x", "y"]) Xnew = cbook._reshape_2D(x, 'x') # Need to check each row because _reshape_2D returns a list of arrays: for x, xnew in zip(X.T, Xnew): np.testing.assert_array_equal(x, xnew) +def test_index_of_pandas(pd): + # separate to allow the rest of the tests to run if no pandas... + X = np.arange(30).reshape(10, 3) + x = pd.DataFrame(X, columns=["a", "b", "c"]) + Idx, Xnew = cbook.index_of(x) + np.testing.assert_array_equal(X, Xnew) + IdxRef = np.arange(10) + np.testing.assert_array_equal(Idx, IdxRef) + + +def test_index_of_xarray(xr): + # separate to allow the rest of the tests to run if no xarray... 
+ X = np.arange(30).reshape(10, 3) + x = xr.DataArray(X, dims=["x", "y"]) + Idx, Xnew = cbook.index_of(x) + np.testing.assert_array_equal(X, Xnew) + IdxRef = np.arange(10) + np.testing.assert_array_equal(Idx, IdxRef) + + def test_contiguous_regions(): a, b, c = 3, 4, 5 # Starts and ends with True diff --git a/lib/matplotlib/units.py b/lib/matplotlib/units.py index f0a0072abf67..910509f20310 100644 --- a/lib/matplotlib/units.py +++ b/lib/matplotlib/units.py @@ -180,8 +180,9 @@ class Registry(dict): def get_converter(self, x): """Get the converter interface instance for *x*, or None.""" - if hasattr(x, "values"): - x = x.values # Unpack pandas Series and DataFrames. + # Unpack in case of e.g. Pandas or xarray object + x = cbook._unpack_to_numpy(x) + if isinstance(x, np.ndarray): # In case x in a masked array, access the underlying data (only its # type matters). If x is a regular ndarray, getdata() just returns From 56af810bdf9e0b49af1ae7afce08f48e6f5eff21 Mon Sep 17 00:00:00 2001 From: Thomas A Caswell Date: Wed, 13 Apr 2022 16:16:52 -0400 Subject: [PATCH 2/2] CI: add xarray to extra dependencies --- requirements/testing/extra.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements/testing/extra.txt b/requirements/testing/extra.txt index e5bb5d86a667..248e86b53e27 100644 --- a/requirements/testing/extra.txt +++ b/requirements/testing/extra.txt @@ -7,3 +7,4 @@ pandas!=0.25.0 pikepdf pytz pywin32; sys.platform == 'win32' +xarray