diff --git a/lib/matplotlib/axes/_axes.py b/lib/matplotlib/axes/_axes.py index 1c6d333e61c9..4af7c9de8573 100644 --- a/lib/matplotlib/axes/_axes.py +++ b/lib/matplotlib/axes/_axes.py @@ -33,6 +33,7 @@ import matplotlib.ticker as mticker import matplotlib.transforms as mtransforms import matplotlib.tri as mtri +import matplotlib.units as munits from matplotlib.container import BarContainer, ErrorbarContainer, StemContainer from matplotlib.axes._base import _AxesBase, _process_plot_format @@ -636,6 +637,7 @@ def indicate_inset_zoom(self, inset_ax, **kwargs): return rectpatch, connects + @munits._accepts_units(convert_x=['x'], convert_y=['y']) def text(self, x, y, s, fontdict=None, withdash=False, **kwargs): """ Add text to the axes. @@ -731,6 +733,7 @@ def annotate(self, text, xy, *args, **kwargs): annotate.__doc__ = mtext.Annotation.__init__.__doc__ #### Lines and spans + @munits._accepts_units(convert_y=['y']) @docstring.dedent_interpd def axhline(self, y=0, xmin=0, xmax=1, **kwargs): """ @@ -786,14 +789,9 @@ def axhline(self, y=0, xmin=0, xmax=1, **kwargs): if "transform" in kwargs: raise ValueError( "'transform' is not allowed as a kwarg;" - + "axhline generates its own transform.") + "axhline generates its own transform.") ymin, ymax = self.get_ybound() - - # We need to strip away the units for comparison with - # non-unitized bounds - self._process_unit_info(ydata=y, kwargs=kwargs) - yy = self.convert_yunits(y) - scaley = (yy < ymin) or (yy > ymax) + scaley = (y < ymin) or (y > ymax) trans = self.get_yaxis_transform(which='grid') l = mlines.Line2D([xmin, xmax], [y, y], transform=trans, **kwargs) @@ -801,6 +799,7 @@ def axhline(self, y=0, xmin=0, xmax=1, **kwargs): self.autoscale_view(scalex=False, scaley=scaley) return l + @munits._accepts_units(convert_x=['x']) @docstring.dedent_interpd def axvline(self, x=0, ymin=0, ymax=1, **kwargs): """ @@ -855,14 +854,9 @@ def axvline(self, x=0, ymin=0, ymax=1, **kwargs): if "transform" in kwargs: raise ValueError( "'transform' is not allowed as a kwarg;" - + "axvline generates its own transform.") + "axvline generates its own transform.") xmin, xmax = self.get_xbound() - - # We need to strip away the units for comparison with - # non-unitized bounds - self._process_unit_info(xdata=x, kwargs=kwargs) - xx = self.convert_xunits(x) - scalex = (xx < xmin) or (xx > xmax) + scalex = (x < xmin) or (x > xmax) trans = self.get_xaxis_transform(which='grid') l = mlines.Line2D([x, x], [ymin, ymax], transform=trans, **kwargs) @@ -870,6 +864,8 @@ def axvline(self, x=0, ymin=0, ymax=1, **kwargs): self.autoscale_view(scalex=scalex, scaley=False) return l + @munits._accepts_units(convert_x=['xmin', 'xmax'], + convert_y=['ymin', 'ymax']) @docstring.dedent_interpd def axhspan(self, ymin, ymax, xmin=0, xmax=1, **kwargs): """ @@ -911,14 +907,6 @@ def axhspan(self, ymin, ymax, xmin=0, xmax=1, **kwargs): axvspan : Add a vertical span across the axes. """ trans = self.get_yaxis_transform(which='grid') - - # process the unit information - self._process_unit_info([xmin, xmax], [ymin, ymax], kwargs=kwargs) - - # first we need to strip away the units - xmin, xmax = self.convert_xunits([xmin, xmax]) - ymin, ymax = self.convert_yunits([ymin, ymax]) - verts = (xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin) p = mpatches.Polygon(verts, **kwargs) p.set_transform(trans) @@ -926,6 +914,8 @@ def axhspan(self, ymin, ymax, xmin=0, xmax=1, **kwargs): self.autoscale_view(scalex=False) return p + @munits._accepts_units(convert_x=['xmin', 'xmax'], + convert_y=['ymin', 'ymax']) def axvspan(self, xmin, xmax, ymin=0, ymax=1, **kwargs): """ Add a vertical span (rectangle) across the axes. @@ -976,14 +966,6 @@ def axvspan(self, xmin, xmax, ymin=0, ymax=1, **kwargs): """ trans = self.get_xaxis_transform(which='grid') - - # process the unit information - self._process_unit_info([xmin, xmax], [ymin, ymax], kwargs=kwargs) - - # first we need to strip away the units - xmin, xmax = self.convert_xunits([xmin, xmax]) - ymin, ymax = self.convert_yunits([ymin, ymax]) - verts = [(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin)] p = mpatches.Polygon(verts, **kwargs) p.set_transform(trans) @@ -991,6 +973,7 @@ def axvspan(self, xmin, xmax, ymin=0, ymax=1, **kwargs): self.autoscale_view(scaley=False) return p + @munits._accepts_units(convert_x=['xmin', 'xmax'], convert_y=['y']) @_preprocess_data(replace_names=["y", "xmin", "xmax", "colors"], label_namer="y") def hlines(self, y, xmin, xmax, colors='k', linestyles='solid', @@ -1026,14 +1009,6 @@ def hlines(self, y, xmin, xmax, colors='k', linestyles='solid', vlines : vertical lines axhline: horizontal line across the axes """ - - # We do the conversion first since not all unitized data is uniform - # process the unit information - self._process_unit_info([xmin, xmax], y, kwargs=kwargs) - y = self.convert_yunits(y) - xmin = self.convert_xunits(xmin) - xmax = self.convert_xunits(xmax) - if not np.iterable(y): y = [y] if not np.iterable(xmin): @@ -1067,6 +1042,7 @@ def hlines(self, y, xmin, xmax, colors='k', linestyles='solid', return lines + @munits._accepts_units(convert_x=['x'], convert_y=['ymin', 'ymax']) @_preprocess_data(replace_names=["x", "ymin", "ymax", "colors"], label_namer="x") def vlines(self, x, ymin, ymax, colors='k', linestyles='solid', @@ -1104,14 +1080,6 @@ def vlines(self, x, ymin, ymax, colors='k', linestyles='solid', hlines : horizontal lines axvline: vertical line across the axes """ - - self._process_unit_info(xdata=x, ydata=[ymin, ymax], kwargs=kwargs) - - # We do the conversion first since not all unitized data is uniform - x = self.convert_xunits(x) - ymin = self.convert_yunits(ymin) - ymax = self.convert_yunits(ymax) - if not np.iterable(x): x = [x] if not np.iterable(ymin): @@ -1144,6 +1112,8 @@ def vlines(self, x, ymin, ymax, colors='k', linestyles='solid', return lines + @munits._accepts_units(convert_x=['positions'], + convert_y=['lineoffsets', 'linelengths']) @_preprocess_data(replace_names=["positions", "lineoffsets", "linelengths", "linewidths", "colors", "linestyles"], @@ -1233,15 +1203,6 @@ def eventplot(self, positions, orientation='horizontal', lineoffsets=1, .. plot:: gallery/lines_bars_and_markers/eventplot_demo.py """ - self._process_unit_info(xdata=positions, - ydata=[lineoffsets, linelengths], - kwargs=kwargs) - - # We do the conversion first since not all unitized data is uniform - positions = self.convert_xunits(positions) - lineoffsets = self.convert_yunits(lineoffsets) - linelengths = self.convert_yunits(linelengths) - if not np.iterable(positions): positions = [positions] elif any(np.iterable(position) for position in positions): @@ -1984,6 +1945,7 @@ def xcorr(self, x, y, normed=True, detrend=mlab.detrend_none, #### Specialized plotting + @munits._accepts_units(convert_x=['x'], convert_y=['y']) @_preprocess_data(replace_names=["x", "y"], label_namer="y") def step(self, x, y, *args, where='pre', **kwargs): """ @@ -2453,6 +2415,7 @@ def barh(self, y, width, height=0.8, left=None, *, align="center", align=align, **kwargs) return patches + @munits._accepts_units(convert_x=['xranges'], convert_y=['yrange']) @_preprocess_data(label_namer=None) @docstring.dedent_interpd def broken_barh(self, xranges, yrange, **kwargs): @@ -2512,11 +2475,6 @@ def broken_barh(self, xranges, yrange, **kwargs): ydata = cbook.safe_first_element(yrange) else: ydata = None - self._process_unit_info(xdata=xdata, - ydata=ydata, - kwargs=kwargs) - xranges = self.convert_xunits(xranges) - yrange = self.convert_yunits(yrange) col = mcoll.BrokenBarHCollection(xranges, yrange, **kwargs) self.add_collection(col, autolim=True) @@ -4006,6 +3964,7 @@ def dopatch(xs, ys, **kwargs): return dict(whiskers=whiskers, caps=caps, boxes=boxes, medians=medians, fliers=fliers, means=means) + @munits._accepts_units(convert_x=['x'], convert_y=['y']) @_preprocess_data(replace_names=["x", "y", "s", "linewidths", "edgecolors", "c", "facecolor", "facecolors", "color"], @@ -4149,10 +4108,6 @@ def scatter(self, x, y, s=None, c=None, marker=None, cmap=None, norm=None, if edgecolors is None and not rcParams['_internal.classic_mode']: edgecolors = 'face' - self._process_unit_info(xdata=x, ydata=y, kwargs=kwargs) - x = self.convert_xunits(x) - y = self.convert_yunits(y) - # np.ma.ravel yields an ndarray, not a masked array, # unless its argument is a masked array. xy_shape = (np.shape(x), np.shape(y)) @@ -4303,6 +4258,7 @@ def scatter(self, x, y, s=None, c=None, marker=None, cmap=None, norm=None, return collection + @munits._accepts_units(convert_x=['x'], convert_y=['y']) @_preprocess_data(replace_names=["x", "y"], label_namer="y") @docstring.dedent_interpd def hexbin(self, x, y, C=None, gridsize=100, bins=None, @@ -4431,8 +4387,6 @@ def hexbin(self, x, y, C=None, gridsize=100, bins=None, %(Collection)s """ - self._process_unit_info(xdata=x, ydata=y, kwargs=kwargs) - x, y, C = cbook.delete_masked_points(x, y, C) # Set the size of the hexagon grid @@ -4921,6 +4875,7 @@ def fill(self, *args, **kwargs): self.autoscale_view() return patches + @munits._accepts_units(convert_x=['x'], convert_y=['y1', 'y2']) @_preprocess_data(replace_names=["x", "y1", "y2", "where"], label_namer=None) @docstring.dedent_interpd @@ -5014,14 +4969,10 @@ def fill_between(self, x, y1, y2=0, where=None, interpolate=False, kwargs['facecolor'] = \ self._get_patches_for_fill.get_next_color() - # Handle united data, such as dates - self._process_unit_info(xdata=x, ydata=y1, kwargs=kwargs) - self._process_unit_info(ydata=y2) - # Convert the arrays so we can work with them - x = ma.masked_invalid(self.convert_xunits(x)) - y1 = ma.masked_invalid(self.convert_yunits(y1)) - y2 = ma.masked_invalid(self.convert_yunits(y2)) + x = ma.masked_invalid(x) + y1 = ma.masked_invalid(y1) + y2 = ma.masked_invalid(y2) for name, array in [('x', x), ('y1', y1), ('y2', y2)]: if array.ndim > 1: @@ -5104,6 +5055,7 @@ def get_interp_point(ind): self.autoscale_view() return collection + @munits._accepts_units(convert_x=['x1', 'x2'], convert_y=['y']) @_preprocess_data(replace_names=["y", "x1", "x2", "where"], label_namer=None) @docstring.dedent_interpd @@ -5197,14 +5149,10 @@ def fill_betweenx(self, y, x1, x2=0, where=None, kwargs['facecolor'] = \ self._get_patches_for_fill.get_next_color() - # Handle united data, such as dates - self._process_unit_info(ydata=y, xdata=x1, kwargs=kwargs) - self._process_unit_info(xdata=x2) - # Convert the arrays so we can work with them - y = ma.masked_invalid(self.convert_yunits(y)) - x1 = ma.masked_invalid(self.convert_xunits(x1)) - x2 = ma.masked_invalid(self.convert_xunits(x2)) + y = ma.masked_invalid(y) + x1 = ma.masked_invalid(x1) + x2 = ma.masked_invalid(x2) for name, array in [('y', y), ('x1', x1), ('x2', x2)]: if array.ndim > 1: diff --git a/lib/matplotlib/axes/_base.py b/lib/matplotlib/axes/_base.py index b732c6e5779d..e4bedea66587 100644 --- a/lib/matplotlib/axes/_base.py +++ b/lib/matplotlib/axes/_base.py @@ -27,6 +27,7 @@ import matplotlib.font_manager as font_manager import matplotlib.text as mtext import matplotlib.image as mimage +import matplotlib.units as munits from matplotlib.rcsetup import cycler, validate_axisbelow @@ -3041,24 +3042,19 @@ def get_xlim(self): """ return tuple(self.viewLim.intervalx) - def _validate_converted_limits(self, limit, convert): + def _validate_converted_limits(self, converted_limit): """ Raise ValueError if converted limits are non-finite. Note that this function also accepts None as a limit argument. - - Returns - ------- - The limit value after call to convert(), or None if limit is None. - """ - if limit is not None: - converted_limit = convert(limit) - if (isinstance(converted_limit, Real) - and not np.isfinite(converted_limit)): + if converted_limit is not None: + if (isinstance(converted_limit, float) and + (not np.isreal(converted_limit) or + not np.isfinite(converted_limit))): raise ValueError("Axis limits cannot be NaN or Inf") - return converted_limit + @munits._accepts_units(convert_x=['left', 'right']) def set_xlim(self, left=None, right=None, emit=True, auto=False, *, xmin=None, xmax=None): """ @@ -3136,9 +3132,8 @@ def set_xlim(self, left=None, right=None, emit=True, auto=False, raise TypeError('Cannot pass both `xmax` and `right`') right = xmax - self._process_unit_info(xdata=(left, right)) - left = self._validate_converted_limits(left, self.convert_xunits) - right = self._validate_converted_limits(right, self.convert_xunits) + self._validate_converted_limits(left) + self._validate_converted_limits(right) old_left, old_right = self.get_xlim() if left is None: @@ -3393,6 +3388,7 @@ def get_ylim(self): """ return tuple(self.viewLim.intervaly) + @munits._accepts_units(convert_y=['bottom', 'top']) def set_ylim(self, bottom=None, top=None, emit=True, auto=False, *, ymin=None, ymax=None): """ @@ -3469,8 +3465,8 @@ def set_ylim(self, bottom=None, top=None, emit=True, auto=False, raise TypeError('Cannot pass both `ymax` and `top`') top = ymax - bottom = self._validate_converted_limits(bottom, self.convert_yunits) - top = self._validate_converted_limits(top, self.convert_yunits) + self._validate_converted_limits(bottom) + self._validate_converted_limits(top) old_bottom, old_top = self.get_ylim() diff --git a/lib/matplotlib/units.py b/lib/matplotlib/units.py index 12d7a7e2617e..ff5cf3a6dca9 100644 --- a/lib/matplotlib/units.py +++ b/lib/matplotlib/units.py @@ -43,12 +43,86 @@ def default_units(x, axis): """ from numbers import Number +import inspect +import functools import numpy as np from matplotlib import cbook +def _accepts_units(convert_x=[], convert_y=[]): + """ + A decorator for functions and methods that accept units. The parameters + indicated in *convert_x* and *convert_y* are used to update the axis + unit information, are converted, and then handed on to the decorated + function. + + The first argument of the decorated function must be an Axes. + + Parameters + ---------- + convert_x, convert_y : list + A list of integers or strings, indicating the arguments to be converted + """ + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + axes = args[0] + # Bind the incoming arguments to the function signature + bound_args = inspect.signature(func).bind(*args, **kwargs) + bound_args.apply_defaults() + # Get the original arguments - these will be modified later + arguments = bound_args.arguments + # Check for data kwarg + has_data = arguments.get('data') is not None + if has_data: + data = arguments['data'] + + # Helper method to process unit info, and convert *original_data* + def _process_info(original_data, axis): + if original_data is None: + return + if axis == 'x': + axes._process_unit_info(xdata=original_data, kwargs=kwargs) + converted_data = axes.convert_xunits(original_data) + elif axis == 'y': + axes._process_unit_info(ydata=original_data, kwargs=kwargs) + converted_data = axes.convert_yunits(original_data) + return converted_data + + # Loop through each argument to be converted, update the axis + # unit info, convert argument, and replace in *arguments* with + # converted values + for arg in convert_x: + if has_data and arguments[arg] in data: + data_arg = arguments[arg] + data[data_arg] = _process_info(data[data_arg], 'x') + else: + arguments[arg] = _process_info(arguments[arg], 'x') + + for arg in convert_y: + if has_data and arguments[arg] in data: + data_arg = arguments[arg] + data[data_arg] = _process_info(data[data_arg], 'y') + else: + arguments[arg] = _process_info(arguments[arg], 'y') + + if has_data: + arguments['data'] = data + # Update the arguments with converted values + bound_args.arguments = arguments + + # Give updated values to the original function + args = bound_args.args + kwargs = bound_args.kwargs + kwargs.pop('xunits', None) + kwargs.pop('yunits', None) + return func(*args, **kwargs) + return wrapper + return decorator + + class AxisInfo(object): """ Information to support default axis labeling, tick labeling, and diff --git a/lib/mpl_toolkits/mplot3d/axes3d.py b/lib/mpl_toolkits/mplot3d/axes3d.py index 6c8f1c9e0278..447087ae82a5 100644 --- a/lib/mpl_toolkits/mplot3d/axes3d.py +++ b/lib/mpl_toolkits/mplot3d/axes3d.py @@ -25,6 +25,7 @@ import matplotlib.projections as proj import matplotlib.scale as mscale import matplotlib.transforms as mtransforms +import matplotlib.units as munits from matplotlib.axes import Axes, rcParams from matplotlib.colors import Normalize, LightSource from matplotlib.transforms import Bbox @@ -595,6 +596,7 @@ def _determine_lims(self, xmin=None, xmax=None, *args, **kwargs): xmax += 0.05 return (xmin, xmax) + @munits._accepts_units(convert_x=['left', 'right']) def set_xlim3d(self, left=None, right=None, emit=True, auto=False, *, xmin=None, xmax=None): """ @@ -618,9 +620,8 @@ def set_xlim3d(self, left=None, right=None, emit=True, auto=False, raise TypeError('Cannot pass both `xmax` and `right`') right = xmax - self._process_unit_info(xdata=(left, right)) - left = self._validate_converted_limits(left, self.convert_xunits) - right = self._validate_converted_limits(right, self.convert_xunits) + self._validate_converted_limits(left) + self._validate_converted_limits(right) old_left, old_right = self.get_xlim() if left is None: @@ -653,6 +654,7 @@ def set_xlim3d(self, left=None, right=None, emit=True, auto=False, return left, right set_xlim = set_xlim3d + @munits._accepts_units(convert_y=['bottom', 'top']) def set_ylim3d(self, bottom=None, top=None, emit=True, auto=False, *, ymin=None, ymax=None): """ @@ -676,9 +678,8 @@ def set_ylim3d(self, bottom=None, top=None, emit=True, auto=False, raise TypeError('Cannot pass both `ymax` and `top`') top = ymax - self._process_unit_info(ydata=(bottom, top)) - bottom = self._validate_converted_limits(bottom, self.convert_yunits) - top = self._validate_converted_limits(top, self.convert_yunits) + self._validate_converted_limits(bottom) + self._validate_converted_limits(top) old_bottom, old_top = self.get_ylim() if bottom is None: @@ -735,8 +736,10 @@ def set_zlim3d(self, bottom=None, top=None, emit=True, auto=False, top = zmax self._process_unit_info(zdata=(bottom, top)) - bottom = self._validate_converted_limits(bottom, self.convert_zunits) - top = self._validate_converted_limits(top, self.convert_zunits) + bottom = self.convert_zunits(bottom) + top = self.convert_zunits(top) + self._validate_converted_limits(bottom) + self._validate_converted_limits(top) old_bottom, old_top = self.get_zlim() if bottom is None: