Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 8cd22b4

Browse files
committed
Apply unit conversion early in errorbar().
This allow using normal numpy constructs rather than manually looping and broadcasting. _process_unit_info was already special-handling `data is None` in a few places; the change here only handle the (theoretical) extra case where a custom unit converter would fail to properly pass None through.
1 parent d235b02 commit 8cd22b4

File tree

3 files changed

+33
-46
lines changed

3 files changed

+33
-46
lines changed

lib/matplotlib/axes/_axes.py

Lines changed: 21 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -3281,27 +3281,19 @@ def errorbar(self, x, y, yerr=None, xerr=None,
32813281
kwargs = {k: v for k, v in kwargs.items() if v is not None}
32823282
kwargs.setdefault('zorder', 2)
32833283

3284-
self._process_unit_info([("x", x), ("y", y)], kwargs, convert=False)
3285-
3286-
# Make sure all the args are iterable; use lists not arrays to preserve
3287-
# units.
3288-
if not np.iterable(x):
3289-
x = [x]
3290-
3291-
if not np.iterable(y):
3292-
y = [y]
3293-
3284+
# Casting to object arrays preserves units.
3285+
if not isinstance(x, np.ndarray):
3286+
x = np.asarray(x, dtype=object)
3287+
if not isinstance(y, np.ndarray):
3288+
y = np.asarray(y, dtype=object)
3289+
if xerr is not None and not isinstance(xerr, np.ndarray):
3290+
xerr = np.asarray(xerr, dtype=object)
3291+
if yerr is not None and not isinstance(yerr, np.ndarray):
3292+
yerr = np.asarray(yerr, dtype=object)
3293+
x, y = np.atleast_1d(x, y) # Make sure all the args are iterable.
32943294
if len(x) != len(y):
32953295
raise ValueError("'x' and 'y' must have the same size")
32963296

3297-
if xerr is not None:
3298-
if not np.iterable(xerr):
3299-
xerr = [xerr] * len(x)
3300-
3301-
if yerr is not None:
3302-
if not np.iterable(yerr):
3303-
yerr = [yerr] * len(y)
3304-
33053297
if isinstance(errorevery, Integral):
33063298
errorevery = (0, errorevery)
33073299
if isinstance(errorevery, tuple):
@@ -3313,10 +3305,8 @@ def errorbar(self, x, y, yerr=None, xerr=None,
33133305
raise ValueError(
33143306
f'errorevery={errorevery!r} is a not a tuple of two '
33153307
f'integers')
3316-
33173308
elif isinstance(errorevery, slice):
33183309
pass
3319-
33203310
elif not isinstance(errorevery, str) and np.iterable(errorevery):
33213311
# fancy indexing
33223312
try:
@@ -3328,6 +3318,8 @@ def errorbar(self, x, y, yerr=None, xerr=None,
33283318
else:
33293319
raise ValueError(
33303320
f"errorevery={errorevery!r} is not a recognized value")
3321+
everymask = np.zeros(len(x), bool)
3322+
everymask[errorevery] = True
33313323

33323324
label = kwargs.pop("label", None)
33333325
kwargs['label'] = '_nolegend_'
@@ -3410,13 +3402,8 @@ def errorbar(self, x, y, yerr=None, xerr=None,
34103402
xlolims = np.broadcast_to(xlolims, len(x)).astype(bool)
34113403
xuplims = np.broadcast_to(xuplims, len(x)).astype(bool)
34123404

3413-
everymask = np.zeros(len(x), bool)
3414-
everymask[errorevery] = True
3415-
3416-
def apply_mask(arrays, mask):
3417-
# Return, for each array in *arrays*, the elements for which *mask*
3418-
# is True, without using fancy indexing.
3419-
return [[*itertools.compress(array, mask)] for array in arrays]
3405+
# Vectorized fancy-indexer.
3406+
def apply_mask(arrays, mask): return [array[mask] for array in arrays]
34203407

34213408
def extract_err(name, err, data, lolims, uplims):
34223409
"""
@@ -3437,24 +3424,14 @@ def extract_err(name, err, data, lolims, uplims):
34373424
Error is only applied on **lower** side when this is True. See
34383425
the note in the main docstring about this parameter's name.
34393426
"""
3440-
try: # Asymmetric error: pair of 1D iterables.
3441-
a, b = err
3442-
iter(a)
3443-
iter(b)
3444-
except (TypeError, ValueError):
3445-
a = b = err # Symmetric error: 1D iterable.
3446-
if np.ndim(a) > 1 or np.ndim(b) > 1:
3427+
try:
3428+
low, high = np.broadcast_to(err, (2, len(data)))
3429+
except ValueError:
34473430
raise ValueError(
3448-
f"{name}err must be a scalar or a 1D or (2, n) array-like")
3449-
# Using list comprehensions rather than arrays to preserve units.
3450-
for e in [a, b]:
3451-
if len(data) != len(e):
3452-
raise ValueError(
3453-
f"The lengths of the data ({len(data)}) and the "
3454-
f"error {len(e)} do not match")
3455-
low = [v if lo else v - e for v, e, lo in zip(data, a, lolims)]
3456-
high = [v if up else v + e for v, e, up in zip(data, b, uplims)]
3457-
return low, high
3431+
f"'{name}err' (shape: {np.shape(err)}) must be a scalar "
3432+
f"or a 1D or (2, n) array-like whose shape matches "
3433+
f"'{name}' (shape: {np.shape(data)})") from None
3434+
return data - low * ~lolims, data + high * ~uplims # low, high
34583435

34593436
if xerr is not None:
34603437
left, right = extract_err('x', xerr, x, xlolims, xuplims)

lib/matplotlib/axes/_base.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2312,7 +2312,8 @@ def _process_unit_info(self, datasets=None, kwargs=None, *, convert=True):
23122312
----------
23132313
datasets : list
23142314
List of (axis_name, dataset) pairs (where the axis name is defined
2315-
as in `._get_axis_map`.
2315+
as in `._get_axis_map`). Individual datasets can also be None
2316+
(which gets passed through).
23162317
kwargs : dict
23172318
Other parameters from which unit info (i.e., the *xunits*,
23182319
*yunits*, *zunits* (for 3D axes), *runits* and *thetaunits* (for
@@ -2359,7 +2360,8 @@ def _process_unit_info(self, datasets=None, kwargs=None, *, convert=True):
23592360
for dataset_axis_name, data in datasets:
23602361
if dataset_axis_name == axis_name and data is not None:
23612362
axis.update_units(data)
2362-
return [axis_map[axis_name].convert_units(data) if convert else data
2363+
return [axis_map[axis_name].convert_units(data)
2364+
if convert and data is not None else data
23632365
for axis_name, data in datasets]
23642366

23652367
def in_axes(self, mouseevent):

lib/matplotlib/tests/test_units.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,14 @@ def test_scatter_element0_masked():
166166
fig.canvas.draw()
167167

168168

169+
def test_errorbar_mixed_units():
170+
x = np.arange(10)
171+
y = [datetime(2020, 5, i * 2 + 1) for i in x]
172+
fig, ax = plt.subplots()
173+
ax.errorbar(x, y, timedelta(days=0.5))
174+
fig.canvas.draw()
175+
176+
169177
@check_figures_equal(extensions=["png"])
170178
def test_subclass(fig_test, fig_ref):
171179
class subdate(datetime):

0 commit comments

Comments
 (0)