Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Apply unit conversion early in errorbar(). #19526

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 21 additions & 44 deletions lib/matplotlib/axes/_axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3281,27 +3281,19 @@ def errorbar(self, x, y, yerr=None, xerr=None,
kwargs = {k: v for k, v in kwargs.items() if v is not None}
kwargs.setdefault('zorder', 2)

self._process_unit_info([("x", x), ("y", y)], kwargs, convert=False)

# Make sure all the args are iterable; use lists not arrays to preserve
# units.
if not np.iterable(x):
x = [x]

if not np.iterable(y):
y = [y]

# Casting to object arrays preserves units.
if not isinstance(x, np.ndarray):
x = np.asarray(x, dtype=object)
if not isinstance(y, np.ndarray):
y = np.asarray(y, dtype=object)
if xerr is not None and not isinstance(xerr, np.ndarray):
xerr = np.asarray(xerr, dtype=object)
if yerr is not None and not isinstance(yerr, np.ndarray):
yerr = np.asarray(yerr, dtype=object)
x, y = np.atleast_1d(x, y) # Make sure all the args are iterable.
if len(x) != len(y):
raise ValueError("'x' and 'y' must have the same size")

if xerr is not None:
if not np.iterable(xerr):
xerr = [xerr] * len(x)

if yerr is not None:
if not np.iterable(yerr):
yerr = [yerr] * len(y)

if isinstance(errorevery, Integral):
errorevery = (0, errorevery)
if isinstance(errorevery, tuple):
Expand All @@ -3313,10 +3305,8 @@ def errorbar(self, x, y, yerr=None, xerr=None,
raise ValueError(
f'errorevery={errorevery!r} is a not a tuple of two '
f'integers')

elif isinstance(errorevery, slice):
pass

elif not isinstance(errorevery, str) and np.iterable(errorevery):
# fancy indexing
try:
Expand All @@ -3328,6 +3318,8 @@ def errorbar(self, x, y, yerr=None, xerr=None,
else:
raise ValueError(
f"errorevery={errorevery!r} is not a recognized value")
everymask = np.zeros(len(x), bool)
everymask[errorevery] = True

label = kwargs.pop("label", None)
kwargs['label'] = '_nolegend_'
Expand Down Expand Up @@ -3410,13 +3402,8 @@ def errorbar(self, x, y, yerr=None, xerr=None,
xlolims = np.broadcast_to(xlolims, len(x)).astype(bool)
xuplims = np.broadcast_to(xuplims, len(x)).astype(bool)

everymask = np.zeros(len(x), bool)
everymask[errorevery] = True

def apply_mask(arrays, mask):
# Return, for each array in *arrays*, the elements for which *mask*
# is True, without using fancy indexing.
return [[*itertools.compress(array, mask)] for array in arrays]
# Vectorized fancy-indexer.
def apply_mask(arrays, mask): return [array[mask] for array in arrays]

def extract_err(name, err, data, lolims, uplims):
"""
Expand All @@ -3437,24 +3424,14 @@ def extract_err(name, err, data, lolims, uplims):
Error is only applied on **lower** side when this is True. See
the note in the main docstring about this parameter's name.
"""
try: # Asymmetric error: pair of 1D iterables.
a, b = err
iter(a)
iter(b)
except (TypeError, ValueError):
a = b = err # Symmetric error: 1D iterable.
if np.ndim(a) > 1 or np.ndim(b) > 1:
try:
low, high = np.broadcast_to(err, (2, len(data)))
except ValueError:
raise ValueError(
f"{name}err must be a scalar or a 1D or (2, n) array-like")
# Using list comprehensions rather than arrays to preserve units.
for e in [a, b]:
if len(data) != len(e):
raise ValueError(
f"The lengths of the data ({len(data)}) and the "
f"error {len(e)} do not match")
low = [v if lo else v - e for v, e, lo in zip(data, a, lolims)]
high = [v if up else v + e for v, e, up in zip(data, b, uplims)]
return low, high
f"'{name}err' (shape: {np.shape(err)}) must be a scalar "
f"or a 1D or (2, n) array-like whose shape matches "
f"'{name}' (shape: {np.shape(data)})") from None
return data - low * ~lolims, data + high * ~uplims # low, high

if xerr is not None:
left, right = extract_err('x', xerr, x, xlolims, xuplims)
Expand Down
6 changes: 4 additions & 2 deletions lib/matplotlib/axes/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2312,7 +2312,8 @@ def _process_unit_info(self, datasets=None, kwargs=None, *, convert=True):
----------
datasets : list
List of (axis_name, dataset) pairs (where the axis name is defined
as in `._get_axis_map`.
as in `._get_axis_map`). Individual datasets can also be None
(which gets passed through).
kwargs : dict
Other parameters from which unit info (i.e., the *xunits*,
*yunits*, *zunits* (for 3D axes), *runits* and *thetaunits* (for
Expand Down Expand Up @@ -2359,7 +2360,8 @@ def _process_unit_info(self, datasets=None, kwargs=None, *, convert=True):
for dataset_axis_name, data in datasets:
if dataset_axis_name == axis_name and data is not None:
axis.update_units(data)
return [axis_map[axis_name].convert_units(data) if convert else data
return [axis_map[axis_name].convert_units(data)
if convert and data is not None else data
for axis_name, data in datasets]

def in_axes(self, mouseevent):
Expand Down
8 changes: 8 additions & 0 deletions lib/matplotlib/tests/test_units.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,14 @@ def test_scatter_element0_masked():
fig.canvas.draw()


def test_errorbar_mixed_units():
x = np.arange(10)
y = [datetime(2020, 5, i * 2 + 1) for i in x]
fig, ax = plt.subplots()
ax.errorbar(x, y, timedelta(days=0.5))
fig.canvas.draw()


@check_figures_equal(extensions=["png"])
def test_subclass(fig_test, fig_ref):
class subdate(datetime):
Expand Down