From 8cd22b4d5840b08cd42f793f5a30e23595aa44f2 Mon Sep 17 00:00:00 2001 From: Antony Lee Date: Mon, 21 Sep 2020 15:38:39 +0200 Subject: [PATCH] Apply unit conversion early in errorbar(). This allow using normal numpy constructs rather than manually looping and broadcasting. _process_unit_info was already special-handling `data is None` in a few places; the change here only handle the (theoretical) extra case where a custom unit converter would fail to properly pass None through. --- lib/matplotlib/axes/_axes.py | 65 ++++++++++-------------------- lib/matplotlib/axes/_base.py | 6 ++- lib/matplotlib/tests/test_units.py | 8 ++++ 3 files changed, 33 insertions(+), 46 deletions(-) diff --git a/lib/matplotlib/axes/_axes.py b/lib/matplotlib/axes/_axes.py index 1b0c881245af..859295118c88 100644 --- a/lib/matplotlib/axes/_axes.py +++ b/lib/matplotlib/axes/_axes.py @@ -3281,27 +3281,19 @@ def errorbar(self, x, y, yerr=None, xerr=None, kwargs = {k: v for k, v in kwargs.items() if v is not None} kwargs.setdefault('zorder', 2) - self._process_unit_info([("x", x), ("y", y)], kwargs, convert=False) - - # Make sure all the args are iterable; use lists not arrays to preserve - # units. - if not np.iterable(x): - x = [x] - - if not np.iterable(y): - y = [y] - + # Casting to object arrays preserves units. + if not isinstance(x, np.ndarray): + x = np.asarray(x, dtype=object) + if not isinstance(y, np.ndarray): + y = np.asarray(y, dtype=object) + if xerr is not None and not isinstance(xerr, np.ndarray): + xerr = np.asarray(xerr, dtype=object) + if yerr is not None and not isinstance(yerr, np.ndarray): + yerr = np.asarray(yerr, dtype=object) + x, y = np.atleast_1d(x, y) # Make sure all the args are iterable. if len(x) != len(y): raise ValueError("'x' and 'y' must have the same size") - if xerr is not None: - if not np.iterable(xerr): - xerr = [xerr] * len(x) - - if yerr is not None: - if not np.iterable(yerr): - yerr = [yerr] * len(y) - if isinstance(errorevery, Integral): errorevery = (0, errorevery) if isinstance(errorevery, tuple): @@ -3313,10 +3305,8 @@ def errorbar(self, x, y, yerr=None, xerr=None, raise ValueError( f'errorevery={errorevery!r} is a not a tuple of two ' f'integers') - elif isinstance(errorevery, slice): pass - elif not isinstance(errorevery, str) and np.iterable(errorevery): # fancy indexing try: @@ -3328,6 +3318,8 @@ def errorbar(self, x, y, yerr=None, xerr=None, else: raise ValueError( f"errorevery={errorevery!r} is not a recognized value") + everymask = np.zeros(len(x), bool) + everymask[errorevery] = True label = kwargs.pop("label", None) kwargs['label'] = '_nolegend_' @@ -3410,13 +3402,8 @@ def errorbar(self, x, y, yerr=None, xerr=None, xlolims = np.broadcast_to(xlolims, len(x)).astype(bool) xuplims = np.broadcast_to(xuplims, len(x)).astype(bool) - everymask = np.zeros(len(x), bool) - everymask[errorevery] = True - - def apply_mask(arrays, mask): - # Return, for each array in *arrays*, the elements for which *mask* - # is True, without using fancy indexing. - return [[*itertools.compress(array, mask)] for array in arrays] + # Vectorized fancy-indexer. + def apply_mask(arrays, mask): return [array[mask] for array in arrays] def extract_err(name, err, data, lolims, uplims): """ @@ -3437,24 +3424,14 @@ def extract_err(name, err, data, lolims, uplims): Error is only applied on **lower** side when this is True. See the note in the main docstring about this parameter's name. """ - try: # Asymmetric error: pair of 1D iterables. - a, b = err - iter(a) - iter(b) - except (TypeError, ValueError): - a = b = err # Symmetric error: 1D iterable. - if np.ndim(a) > 1 or np.ndim(b) > 1: + try: + low, high = np.broadcast_to(err, (2, len(data))) + except ValueError: raise ValueError( - f"{name}err must be a scalar or a 1D or (2, n) array-like") - # Using list comprehensions rather than arrays to preserve units. - for e in [a, b]: - if len(data) != len(e): - raise ValueError( - f"The lengths of the data ({len(data)}) and the " - f"error {len(e)} do not match") - low = [v if lo else v - e for v, e, lo in zip(data, a, lolims)] - high = [v if up else v + e for v, e, up in zip(data, b, uplims)] - return low, high + f"'{name}err' (shape: {np.shape(err)}) must be a scalar " + f"or a 1D or (2, n) array-like whose shape matches " + f"'{name}' (shape: {np.shape(data)})") from None + return data - low * ~lolims, data + high * ~uplims # low, high if xerr is not None: left, right = extract_err('x', xerr, x, xlolims, xuplims) diff --git a/lib/matplotlib/axes/_base.py b/lib/matplotlib/axes/_base.py index b85275bce970..2d589de84d5f 100644 --- a/lib/matplotlib/axes/_base.py +++ b/lib/matplotlib/axes/_base.py @@ -2312,7 +2312,8 @@ def _process_unit_info(self, datasets=None, kwargs=None, *, convert=True): ---------- datasets : list List of (axis_name, dataset) pairs (where the axis name is defined - as in `._get_axis_map`. + as in `._get_axis_map`). Individual datasets can also be None + (which gets passed through). kwargs : dict Other parameters from which unit info (i.e., the *xunits*, *yunits*, *zunits* (for 3D axes), *runits* and *thetaunits* (for @@ -2359,7 +2360,8 @@ def _process_unit_info(self, datasets=None, kwargs=None, *, convert=True): for dataset_axis_name, data in datasets: if dataset_axis_name == axis_name and data is not None: axis.update_units(data) - return [axis_map[axis_name].convert_units(data) if convert else data + return [axis_map[axis_name].convert_units(data) + if convert and data is not None else data for axis_name, data in datasets] def in_axes(self, mouseevent): diff --git a/lib/matplotlib/tests/test_units.py b/lib/matplotlib/tests/test_units.py index 3f40a99a2f5a..b0c5998c5338 100644 --- a/lib/matplotlib/tests/test_units.py +++ b/lib/matplotlib/tests/test_units.py @@ -166,6 +166,14 @@ def test_scatter_element0_masked(): fig.canvas.draw() +def test_errorbar_mixed_units(): + x = np.arange(10) + y = [datetime(2020, 5, i * 2 + 1) for i in x] + fig, ax = plt.subplots() + ax.errorbar(x, y, timedelta(days=0.5)) + fig.canvas.draw() + + @check_figures_equal(extensions=["png"]) def test_subclass(fig_test, fig_ref): class subdate(datetime):