Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 698f065

Browse files
committed
Apply unit conversion early in errorbar().
This allow using normal numpy constructs rather than manually looping and broadcasting. _process_unit_info was already special-handling `data is None` in a few places; the change here only handle the (theoretical) extra case where a custom unit converter would fail to properly pass None through.
1 parent d235b02 commit 698f065

File tree

2 files changed

+14
-39
lines changed

2 files changed

+14
-39
lines changed

lib/matplotlib/axes/_axes.py

Lines changed: 10 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -3281,27 +3281,12 @@ def errorbar(self, x, y, yerr=None, xerr=None,
32813281
kwargs = {k: v for k, v in kwargs.items() if v is not None}
32823282
kwargs.setdefault('zorder', 2)
32833283

3284-
self._process_unit_info([("x", x), ("y", y)], kwargs, convert=False)
3285-
3286-
# Make sure all the args are iterable; use lists not arrays to preserve
3287-
# units.
3288-
if not np.iterable(x):
3289-
x = [x]
3290-
3291-
if not np.iterable(y):
3292-
y = [y]
3293-
3284+
x, y, xerr, yerr = self._process_unit_info(
3285+
[("x", x), ("y", y), ("x", xerr), ("y", yerr)], kwargs)
3286+
x, y = np.atleast_1d(x, y) # Make sure all the args are iterable.
32943287
if len(x) != len(y):
32953288
raise ValueError("'x' and 'y' must have the same size")
32963289

3297-
if xerr is not None:
3298-
if not np.iterable(xerr):
3299-
xerr = [xerr] * len(x)
3300-
3301-
if yerr is not None:
3302-
if not np.iterable(yerr):
3303-
yerr = [yerr] * len(y)
3304-
33053290
if isinstance(errorevery, Integral):
33063291
errorevery = (0, errorevery)
33073292
if isinstance(errorevery, tuple):
@@ -3313,10 +3298,8 @@ def errorbar(self, x, y, yerr=None, xerr=None,
33133298
raise ValueError(
33143299
f'errorevery={errorevery!r} is a not a tuple of two '
33153300
f'integers')
3316-
33173301
elif isinstance(errorevery, slice):
33183302
pass
3319-
33203303
elif not isinstance(errorevery, str) and np.iterable(errorevery):
33213304
# fancy indexing
33223305
try:
@@ -3437,24 +3420,14 @@ def extract_err(name, err, data, lolims, uplims):
34373420
Error is only applied on **lower** side when this is True. See
34383421
the note in the main docstring about this parameter's name.
34393422
"""
3440-
try: # Asymmetric error: pair of 1D iterables.
3441-
a, b = err
3442-
iter(a)
3443-
iter(b)
3444-
except (TypeError, ValueError):
3445-
a = b = err # Symmetric error: 1D iterable.
3446-
if np.ndim(a) > 1 or np.ndim(b) > 1:
3423+
try:
3424+
low, high = np.broadcast_to(err, (2, len(data)))
3425+
except ValueError:
34473426
raise ValueError(
3448-
f"{name}err must be a scalar or a 1D or (2, n) array-like")
3449-
# Using list comprehensions rather than arrays to preserve units.
3450-
for e in [a, b]:
3451-
if len(data) != len(e):
3452-
raise ValueError(
3453-
f"The lengths of the data ({len(data)}) and the "
3454-
f"error {len(e)} do not match")
3455-
low = [v if lo else v - e for v, e, lo in zip(data, a, lolims)]
3456-
high = [v if up else v + e for v, e, up in zip(data, b, uplims)]
3457-
return low, high
3427+
f"'{name}err' (shape: {np.shape(err)}) must be a scalar "
3428+
f"or a 1D or (2, n) array-like whose shape matches "
3429+
f"'{name}' (shape: {np.shape(data)})") from None
3430+
return data - low * ~lolims, data + high * ~uplims # low, high
34583431

34593432
if xerr is not None:
34603433
left, right = extract_err('x', xerr, x, xlolims, xuplims)

lib/matplotlib/axes/_base.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2312,7 +2312,8 @@ def _process_unit_info(self, datasets=None, kwargs=None, *, convert=True):
23122312
----------
23132313
datasets : list
23142314
List of (axis_name, dataset) pairs (where the axis name is defined
2315-
as in `._get_axis_map`.
2315+
as in `._get_axis_map`). Individual datasets can also be None
2316+
(which gets passed through).
23162317
kwargs : dict
23172318
Other parameters from which unit info (i.e., the *xunits*,
23182319
*yunits*, *zunits* (for 3D axes), *runits* and *thetaunits* (for
@@ -2359,7 +2360,8 @@ def _process_unit_info(self, datasets=None, kwargs=None, *, convert=True):
23592360
for dataset_axis_name, data in datasets:
23602361
if dataset_axis_name == axis_name and data is not None:
23612362
axis.update_units(data)
2362-
return [axis_map[axis_name].convert_units(data) if convert else data
2363+
return [axis_map[axis_name].convert_units(data)
2364+
if convert and data is not None else data
23632365
for axis_name, data in datasets]
23642366

23652367
def in_axes(self, mouseevent):

0 commit comments

Comments
 (0)