Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit dca4a97

Browse files
authored
Merge pull request #19526 from anntzer/errorbar-early-units
Apply unit conversion early in errorbar().
2 parents fb0c10d + 8cd22b4 commit dca4a97

File tree

3 files changed

+33
-46
lines changed

3 files changed

+33
-46
lines changed

lib/matplotlib/axes/_axes.py

Lines changed: 21 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -3282,27 +3282,19 @@ def errorbar(self, x, y, yerr=None, xerr=None,
32823282
kwargs = {k: v for k, v in kwargs.items() if v is not None}
32833283
kwargs.setdefault('zorder', 2)
32843284

3285-
self._process_unit_info([("x", x), ("y", y)], kwargs, convert=False)
3286-
3287-
# Make sure all the args are iterable; use lists not arrays to preserve
3288-
# units.
3289-
if not np.iterable(x):
3290-
x = [x]
3291-
3292-
if not np.iterable(y):
3293-
y = [y]
3294-
3285+
# Casting to object arrays preserves units.
3286+
if not isinstance(x, np.ndarray):
3287+
x = np.asarray(x, dtype=object)
3288+
if not isinstance(y, np.ndarray):
3289+
y = np.asarray(y, dtype=object)
3290+
if xerr is not None and not isinstance(xerr, np.ndarray):
3291+
xerr = np.asarray(xerr, dtype=object)
3292+
if yerr is not None and not isinstance(yerr, np.ndarray):
3293+
yerr = np.asarray(yerr, dtype=object)
3294+
x, y = np.atleast_1d(x, y) # Make sure all the args are iterable.
32953295
if len(x) != len(y):
32963296
raise ValueError("'x' and 'y' must have the same size")
32973297

3298-
if xerr is not None:
3299-
if not np.iterable(xerr):
3300-
xerr = [xerr] * len(x)
3301-
3302-
if yerr is not None:
3303-
if not np.iterable(yerr):
3304-
yerr = [yerr] * len(y)
3305-
33063298
if isinstance(errorevery, Integral):
33073299
errorevery = (0, errorevery)
33083300
if isinstance(errorevery, tuple):
@@ -3314,10 +3306,8 @@ def errorbar(self, x, y, yerr=None, xerr=None,
33143306
raise ValueError(
33153307
f'errorevery={errorevery!r} is a not a tuple of two '
33163308
f'integers')
3317-
33183309
elif isinstance(errorevery, slice):
33193310
pass
3320-
33213311
elif not isinstance(errorevery, str) and np.iterable(errorevery):
33223312
# fancy indexing
33233313
try:
@@ -3329,6 +3319,8 @@ def errorbar(self, x, y, yerr=None, xerr=None,
33293319
else:
33303320
raise ValueError(
33313321
f"errorevery={errorevery!r} is not a recognized value")
3322+
everymask = np.zeros(len(x), bool)
3323+
everymask[errorevery] = True
33323324

33333325
label = kwargs.pop("label", None)
33343326
kwargs['label'] = '_nolegend_'
@@ -3412,13 +3404,8 @@ def errorbar(self, x, y, yerr=None, xerr=None,
34123404
xlolims = np.broadcast_to(xlolims, len(x)).astype(bool)
34133405
xuplims = np.broadcast_to(xuplims, len(x)).astype(bool)
34143406

3415-
everymask = np.zeros(len(x), bool)
3416-
everymask[errorevery] = True
3417-
3418-
def apply_mask(arrays, mask):
3419-
# Return, for each array in *arrays*, the elements for which *mask*
3420-
# is True, without using fancy indexing.
3421-
return [[*itertools.compress(array, mask)] for array in arrays]
3407+
# Vectorized fancy-indexer.
3408+
def apply_mask(arrays, mask): return [array[mask] for array in arrays]
34223409

34233410
def extract_err(name, err, data, lolims, uplims):
34243411
"""
@@ -3439,24 +3426,14 @@ def extract_err(name, err, data, lolims, uplims):
34393426
Error is only applied on **lower** side when this is True. See
34403427
the note in the main docstring about this parameter's name.
34413428
"""
3442-
try: # Asymmetric error: pair of 1D iterables.
3443-
a, b = err
3444-
iter(a)
3445-
iter(b)
3446-
except (TypeError, ValueError):
3447-
a = b = err # Symmetric error: 1D iterable.
3448-
if np.ndim(a) > 1 or np.ndim(b) > 1:
3429+
try:
3430+
low, high = np.broadcast_to(err, (2, len(data)))
3431+
except ValueError:
34493432
raise ValueError(
3450-
f"{name}err must be a scalar or a 1D or (2, n) array-like")
3451-
# Using list comprehensions rather than arrays to preserve units.
3452-
for e in [a, b]:
3453-
if len(data) != len(e):
3454-
raise ValueError(
3455-
f"The lengths of the data ({len(data)}) and the "
3456-
f"error {len(e)} do not match")
3457-
low = [v if lo else v - e for v, e, lo in zip(data, a, lolims)]
3458-
high = [v if up else v + e for v, e, up in zip(data, b, uplims)]
3459-
return low, high
3433+
f"'{name}err' (shape: {np.shape(err)}) must be a scalar "
3434+
f"or a 1D or (2, n) array-like whose shape matches "
3435+
f"'{name}' (shape: {np.shape(data)})") from None
3436+
return data - low * ~lolims, data + high * ~uplims # low, high
34603437

34613438
if xerr is not None:
34623439
left, right = extract_err('x', xerr, x, xlolims, xuplims)

lib/matplotlib/axes/_base.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2479,7 +2479,8 @@ def _process_unit_info(self, datasets=None, kwargs=None, *, convert=True):
24792479
----------
24802480
datasets : list
24812481
List of (axis_name, dataset) pairs (where the axis name is defined
2482-
as in `._get_axis_map`.
2482+
as in `._get_axis_map`). Individual datasets can also be None
2483+
(which gets passed through).
24832484
kwargs : dict
24842485
Other parameters from which unit info (i.e., the *xunits*,
24852486
*yunits*, *zunits* (for 3D axes), *runits* and *thetaunits* (for
@@ -2526,7 +2527,8 @@ def _process_unit_info(self, datasets=None, kwargs=None, *, convert=True):
25262527
for dataset_axis_name, data in datasets:
25272528
if dataset_axis_name == axis_name and data is not None:
25282529
axis.update_units(data)
2529-
return [axis_map[axis_name].convert_units(data) if convert else data
2530+
return [axis_map[axis_name].convert_units(data)
2531+
if convert and data is not None else data
25302532
for axis_name, data in datasets]
25312533

25322534
def in_axes(self, mouseevent):

lib/matplotlib/tests/test_units.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,14 @@ def test_scatter_element0_masked():
166166
fig.canvas.draw()
167167

168168

169+
def test_errorbar_mixed_units():
170+
x = np.arange(10)
171+
y = [datetime(2020, 5, i * 2 + 1) for i in x]
172+
fig, ax = plt.subplots()
173+
ax.errorbar(x, y, timedelta(days=0.5))
174+
fig.canvas.draw()
175+
176+
169177
@check_figures_equal(extensions=["png"])
170178
def test_subclass(fig_test, fig_ref):
171179
class subdate(datetime):

0 commit comments

Comments
 (0)