Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit cc2517c

Browse files
committed
Apply unit conversion early in errorbar().
This allow using normal numpy constructs rather than manually looping and broadcasting. _process_unit_info was already special-handling `data is None` in a few places; the change here only handle the (theoretical) extra case where a custom unit converter would fail to properly pass None through.
1 parent d235b02 commit cc2517c

File tree

2 files changed

+18
-46
lines changed

2 files changed

+18
-46
lines changed

lib/matplotlib/axes/_axes.py

Lines changed: 14 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -3281,27 +3281,12 @@ def errorbar(self, x, y, yerr=None, xerr=None,
32813281
kwargs = {k: v for k, v in kwargs.items() if v is not None}
32823282
kwargs.setdefault('zorder', 2)
32833283

3284-
self._process_unit_info([("x", x), ("y", y)], kwargs, convert=False)
3285-
3286-
# Make sure all the args are iterable; use lists not arrays to preserve
3287-
# units.
3288-
if not np.iterable(x):
3289-
x = [x]
3290-
3291-
if not np.iterable(y):
3292-
y = [y]
3293-
3284+
x, y, xerr, yerr = self._process_unit_info(
3285+
[("x", x), ("y", y), ("x", xerr), ("y", yerr)], kwargs)
3286+
x, y = np.atleast_1d(x, y) # Make sure all the args are iterable.
32943287
if len(x) != len(y):
32953288
raise ValueError("'x' and 'y' must have the same size")
32963289

3297-
if xerr is not None:
3298-
if not np.iterable(xerr):
3299-
xerr = [xerr] * len(x)
3300-
3301-
if yerr is not None:
3302-
if not np.iterable(yerr):
3303-
yerr = [yerr] * len(y)
3304-
33053290
if isinstance(errorevery, Integral):
33063291
errorevery = (0, errorevery)
33073292
if isinstance(errorevery, tuple):
@@ -3313,10 +3298,8 @@ def errorbar(self, x, y, yerr=None, xerr=None,
33133298
raise ValueError(
33143299
f'errorevery={errorevery!r} is a not a tuple of two '
33153300
f'integers')
3316-
33173301
elif isinstance(errorevery, slice):
33183302
pass
3319-
33203303
elif not isinstance(errorevery, str) and np.iterable(errorevery):
33213304
# fancy indexing
33223305
try:
@@ -3328,6 +3311,8 @@ def errorbar(self, x, y, yerr=None, xerr=None,
33283311
else:
33293312
raise ValueError(
33303313
f"errorevery={errorevery!r} is not a recognized value")
3314+
everymask = np.zeros(len(x), bool)
3315+
everymask[errorevery] = True
33313316

33323317
label = kwargs.pop("label", None)
33333318
kwargs['label'] = '_nolegend_'
@@ -3410,13 +3395,8 @@ def errorbar(self, x, y, yerr=None, xerr=None,
34103395
xlolims = np.broadcast_to(xlolims, len(x)).astype(bool)
34113396
xuplims = np.broadcast_to(xuplims, len(x)).astype(bool)
34123397

3413-
everymask = np.zeros(len(x), bool)
3414-
everymask[errorevery] = True
3415-
3416-
def apply_mask(arrays, mask):
3417-
# Return, for each array in *arrays*, the elements for which *mask*
3418-
# is True, without using fancy indexing.
3419-
return [[*itertools.compress(array, mask)] for array in arrays]
3398+
# Vectorized fancy-indexer.
3399+
def apply_mask(arrays, mask): return [array[mask] for array in arrays]
34203400

34213401
def extract_err(name, err, data, lolims, uplims):
34223402
"""
@@ -3437,24 +3417,14 @@ def extract_err(name, err, data, lolims, uplims):
34373417
Error is only applied on **lower** side when this is True. See
34383418
the note in the main docstring about this parameter's name.
34393419
"""
3440-
try: # Asymmetric error: pair of 1D iterables.
3441-
a, b = err
3442-
iter(a)
3443-
iter(b)
3444-
except (TypeError, ValueError):
3445-
a = b = err # Symmetric error: 1D iterable.
3446-
if np.ndim(a) > 1 or np.ndim(b) > 1:
3420+
try:
3421+
low, high = np.broadcast_to(err, (2, len(data)))
3422+
except ValueError:
34473423
raise ValueError(
3448-
f"{name}err must be a scalar or a 1D or (2, n) array-like")
3449-
# Using list comprehensions rather than arrays to preserve units.
3450-
for e in [a, b]:
3451-
if len(data) != len(e):
3452-
raise ValueError(
3453-
f"The lengths of the data ({len(data)}) and the "
3454-
f"error {len(e)} do not match")
3455-
low = [v if lo else v - e for v, e, lo in zip(data, a, lolims)]
3456-
high = [v if up else v + e for v, e, up in zip(data, b, uplims)]
3457-
return low, high
3424+
f"'{name}err' (shape: {np.shape(err)}) must be a scalar "
3425+
f"or a 1D or (2, n) array-like whose shape matches "
3426+
f"'{name}' (shape: {np.shape(data)})") from None
3427+
return data - low * ~lolims, data + high * ~uplims # low, high
34583428

34593429
if xerr is not None:
34603430
left, right = extract_err('x', xerr, x, xlolims, xuplims)

lib/matplotlib/axes/_base.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2312,7 +2312,8 @@ def _process_unit_info(self, datasets=None, kwargs=None, *, convert=True):
23122312
----------
23132313
datasets : list
23142314
List of (axis_name, dataset) pairs (where the axis name is defined
2315-
as in `._get_axis_map`.
2315+
as in `._get_axis_map`). Individual datasets can also be None
2316+
(which gets passed through).
23162317
kwargs : dict
23172318
Other parameters from which unit info (i.e., the *xunits*,
23182319
*yunits*, *zunits* (for 3D axes), *runits* and *thetaunits* (for
@@ -2359,7 +2360,8 @@ def _process_unit_info(self, datasets=None, kwargs=None, *, convert=True):
23592360
for dataset_axis_name, data in datasets:
23602361
if dataset_axis_name == axis_name and data is not None:
23612362
axis.update_units(data)
2362-
return [axis_map[axis_name].convert_units(data) if convert else data
2363+
return [axis_map[axis_name].convert_units(data)
2364+
if convert and data is not None else data
23632365
for axis_name, data in datasets]
23642366

23652367
def in_axes(self, mouseevent):

0 commit comments

Comments
 (0)