Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 3420565

Browse files
authored
Merge pull request #22876 from meeseeksmachine/auto-backport-of-pr-22560-on-v3.5.x
Backport PR #22560 on branch v3.5.x (Improve pandas/xarray/... conversion)
2 parents 0fe45ab + 1e23977 commit 3420565

File tree

7 files changed

+60
-21
lines changed

7 files changed

+60
-21
lines changed

lib/matplotlib/axes/_axes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7933,8 +7933,8 @@ def violinplot(self, dataset, positions=None, vert=True, widths=0.5,
79337933
"""
79347934

79357935
def _kde_method(X, coords):
7936-
if hasattr(X, 'values'): # support pandas.Series
7937-
X = X.values
7936+
# Unpack in case of e.g. Pandas or xarray object
7937+
X = cbook._unpack_to_numpy(X)
79387938
# fallback gracefully if the vector contains only one value
79397939
if np.all(X[0] == X):
79407940
return (X[0] == coords).astype(float)

lib/matplotlib/cbook/__init__.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1300,9 +1300,8 @@ def _to_unmasked_float_array(x):
13001300

13011301
def _check_1d(x):
13021302
"""Convert scalars to 1D arrays; pass-through arrays as is."""
1303-
if hasattr(x, 'to_numpy'):
1304-
# if we are given an object that creates a numpy, we should use it...
1305-
x = x.to_numpy()
1303+
# Unpack in case of e.g. Pandas or xarray object
1304+
x = _unpack_to_numpy(x)
13061305
if not hasattr(x, 'shape') or len(x.shape) < 1:
13071306
return np.atleast_1d(x)
13081307
else:
@@ -1321,15 +1320,8 @@ def _reshape_2D(X, name):
13211320
*name* is used to generate the error message for invalid inputs.
13221321
"""
13231322

1324-
# unpack if we have a values or to_numpy method.
1325-
try:
1326-
X = X.to_numpy()
1327-
except AttributeError:
1328-
try:
1329-
if isinstance(X.values, np.ndarray):
1330-
X = X.values
1331-
except AttributeError:
1332-
pass
1323+
# Unpack in case of e.g. Pandas or xarray object
1324+
X = _unpack_to_numpy(X)
13331325

13341326
# Iterate over columns for ndarrays.
13351327
if isinstance(X, np.ndarray):
@@ -2275,3 +2267,20 @@ def _picklable_class_constructor(mixin_class, fmt, attr_name, base_class):
22752267
factory = _make_class_factory(mixin_class, fmt, attr_name)
22762268
cls = factory(base_class)
22772269
return cls.__new__(cls)
2270+
2271+
2272+
def _unpack_to_numpy(x):
2273+
"""Internal helper to extract data from e.g. pandas and xarray objects."""
2274+
if isinstance(x, np.ndarray):
2275+
# If numpy, return directly
2276+
return x
2277+
if hasattr(x, 'to_numpy'):
2278+
# Assume that any function to_numpy() do actually return a numpy array
2279+
return x.to_numpy()
2280+
if hasattr(x, 'values'):
2281+
xtmp = x.values
2282+
# For example a dict has a 'values' attribute, but it is not a property
2283+
# so in this case we do not want to return a function
2284+
if isinstance(xtmp, np.ndarray):
2285+
return xtmp
2286+
return x

lib/matplotlib/dates.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -423,9 +423,8 @@ def date2num(d):
423423
The Gregorian calendar is assumed; this is not universal practice.
424424
For details see the module docstring.
425425
"""
426-
if hasattr(d, "values"):
427-
# this unpacks pandas series or dataframes...
428-
d = d.values
426+
# Unpack in case of e.g. Pandas or xarray object
427+
d = cbook._unpack_to_numpy(d)
429428

430429
# make an iterable, but save state to unpack later:
431430
iterable = np.iterable(d)

lib/matplotlib/testing/conftest.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,3 +125,10 @@ def pd():
125125
except ImportError:
126126
pass
127127
return pd
128+
129+
130+
@pytest.fixture
131+
def xr():
132+
"""Fixture to import xarray."""
133+
xr = pytest.importorskip('xarray')
134+
return xr

lib/matplotlib/tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
from matplotlib.testing.conftest import (mpl_test_settings,
22
pytest_configure, pytest_unconfigure,
3-
pd)
3+
pd, xr)

lib/matplotlib/tests/test_cbook.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -668,14 +668,37 @@ def test_reshape2d_pandas(pd):
668668
for x, xnew in zip(X.T, Xnew):
669669
np.testing.assert_array_equal(x, xnew)
670670

671+
672+
def test_reshape2d_xarray(xr):
673+
# separate to allow the rest of the tests to run if no xarray...
671674
X = np.arange(30).reshape(10, 3)
672-
x = pd.DataFrame(X, columns=["a", "b", "c"])
675+
x = xr.DataArray(X, dims=["x", "y"])
673676
Xnew = cbook._reshape_2D(x, 'x')
674677
# Need to check each row because _reshape_2D returns a list of arrays:
675678
for x, xnew in zip(X.T, Xnew):
676679
np.testing.assert_array_equal(x, xnew)
677680

678681

682+
def test_index_of_pandas(pd):
683+
# separate to allow the rest of the tests to run if no pandas...
684+
X = np.arange(30).reshape(10, 3)
685+
x = pd.DataFrame(X, columns=["a", "b", "c"])
686+
Idx, Xnew = cbook.index_of(x)
687+
np.testing.assert_array_equal(X, Xnew)
688+
IdxRef = np.arange(10)
689+
np.testing.assert_array_equal(Idx, IdxRef)
690+
691+
692+
def test_index_of_xarray(xr):
693+
# separate to allow the rest of the tests to run if no xarray...
694+
X = np.arange(30).reshape(10, 3)
695+
x = xr.DataArray(X, dims=["x", "y"])
696+
Idx, Xnew = cbook.index_of(x)
697+
np.testing.assert_array_equal(X, Xnew)
698+
IdxRef = np.arange(10)
699+
np.testing.assert_array_equal(Idx, IdxRef)
700+
701+
679702
def test_contiguous_regions():
680703
a, b, c = 3, 4, 5
681704
# Starts and ends with True

lib/matplotlib/units.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,8 +180,9 @@ class Registry(dict):
180180

181181
def get_converter(self, x):
182182
"""Get the converter interface instance for *x*, or None."""
183-
if hasattr(x, "values"):
184-
x = x.values # Unpack pandas Series and DataFrames.
183+
# Unpack in case of e.g. Pandas or xarray object
184+
x = cbook._unpack_to_numpy(x)
185+
185186
if isinstance(x, np.ndarray):
186187
# In case x in a masked array, access the underlying data (only its
187188
# type matters). If x is a regular ndarray, getdata() just returns

0 commit comments

Comments
 (0)