Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 81eff1f

Browse files
committed
Apply unit decorator to more functions
Add some more unit decorators Add unit decorator to mplot3d
1 parent 8e0a9cf commit 81eff1f

File tree

4 files changed

+41
-75
lines changed

4 files changed

+41
-75
lines changed

lib/matplotlib/axes/_axes.py

Lines changed: 16 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -871,6 +871,8 @@ def axvline(self, x=0, ymin=0, ymax=1, **kwargs):
871871
self.autoscale_view(scalex=scalex, scaley=False)
872872
return l
873873

874+
@munits._accepts_units(convert_x=['xmin', 'xmax'],
875+
convert_y=['ymin', 'ymax'])
874876
@docstring.dedent_interpd
875877
def axhspan(self, ymin, ymax, xmin=0, xmax=1, **kwargs):
876878
"""
@@ -912,21 +914,15 @@ def axhspan(self, ymin, ymax, xmin=0, xmax=1, **kwargs):
912914
axvspan : Add a vertical span across the axes.
913915
"""
914916
trans = self.get_yaxis_transform(which='grid')
915-
916-
# process the unit information
917-
self._process_unit_info([xmin, xmax], [ymin, ymax], kwargs=kwargs)
918-
919-
# first we need to strip away the units
920-
xmin, xmax = self.convert_xunits([xmin, xmax])
921-
ymin, ymax = self.convert_yunits([ymin, ymax])
922-
923917
verts = (xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin)
924918
p = mpatches.Polygon(verts, **kwargs)
925919
p.set_transform(trans)
926920
self.add_patch(p)
927921
self.autoscale_view(scalex=False)
928922
return p
929923

924+
@munits._accepts_units(convert_x=['xmin', 'xmax'],
925+
convert_y=['ymin', 'ymax'])
930926
def axvspan(self, xmin, xmax, ymin=0, ymax=1, **kwargs):
931927
"""
932928
Add a vertical span (rectangle) across the axes.
@@ -977,21 +973,14 @@ def axvspan(self, xmin, xmax, ymin=0, ymax=1, **kwargs):
977973
978974
"""
979975
trans = self.get_xaxis_transform(which='grid')
980-
981-
# process the unit information
982-
self._process_unit_info([xmin, xmax], [ymin, ymax], kwargs=kwargs)
983-
984-
# first we need to strip away the units
985-
xmin, xmax = self.convert_xunits([xmin, xmax])
986-
ymin, ymax = self.convert_yunits([ymin, ymax])
987-
988976
verts = [(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin)]
989977
p = mpatches.Polygon(verts, **kwargs)
990978
p.set_transform(trans)
991979
self.add_patch(p)
992980
self.autoscale_view(scaley=False)
993981
return p
994982

983+
@munits._accepts_units(convert_x=['xmin', 'xmax'], convert_y=['y'])
995984
@_preprocess_data(replace_names=["y", "xmin", "xmax", "colors"],
996985
label_namer="y")
997986
def hlines(self, y, xmin, xmax, colors='k', linestyles='solid',
@@ -1027,14 +1016,6 @@ def hlines(self, y, xmin, xmax, colors='k', linestyles='solid',
10271016
vlines : vertical lines
10281017
axhline: horizontal line across the axes
10291018
"""
1030-
1031-
# We do the conversion first since not all unitized data is uniform
1032-
# process the unit information
1033-
self._process_unit_info([xmin, xmax], y, kwargs=kwargs)
1034-
y = self.convert_yunits(y)
1035-
xmin = self.convert_xunits(xmin)
1036-
xmax = self.convert_xunits(xmax)
1037-
10381019
if not np.iterable(y):
10391020
y = [y]
10401021
if not np.iterable(xmin):
@@ -1068,6 +1049,7 @@ def hlines(self, y, xmin, xmax, colors='k', linestyles='solid',
10681049

10691050
return lines
10701051

1052+
@munits._accepts_units(convert_x=['x'], convert_y=['ymin', 'ymax'])
10711053
@_preprocess_data(replace_names=["x", "ymin", "ymax", "colors"],
10721054
label_namer="x")
10731055
def vlines(self, x, ymin, ymax, colors='k', linestyles='solid',
@@ -1105,14 +1087,6 @@ def vlines(self, x, ymin, ymax, colors='k', linestyles='solid',
11051087
hlines : horizontal lines
11061088
axvline: vertical line across the axes
11071089
"""
1108-
1109-
self._process_unit_info(xdata=x, ydata=[ymin, ymax], kwargs=kwargs)
1110-
1111-
# We do the conversion first since not all unitized data is uniform
1112-
x = self.convert_xunits(x)
1113-
ymin = self.convert_yunits(ymin)
1114-
ymax = self.convert_yunits(ymax)
1115-
11161090
if not np.iterable(x):
11171091
x = [x]
11181092
if not np.iterable(ymin):
@@ -1145,6 +1119,8 @@ def vlines(self, x, ymin, ymax, colors='k', linestyles='solid',
11451119

11461120
return lines
11471121

1122+
@munits._accepts_units(convert_x=['positions'],
1123+
convert_y=['lineoffsets', 'linelengths'])
11481124
@_preprocess_data(replace_names=["positions", "lineoffsets",
11491125
"linelengths", "linewidths",
11501126
"colors", "linestyles"],
@@ -1234,15 +1210,6 @@ def eventplot(self, positions, orientation='horizontal', lineoffsets=1,
12341210
12351211
.. plot:: gallery/lines_bars_and_markers/eventplot_demo.py
12361212
"""
1237-
self._process_unit_info(xdata=positions,
1238-
ydata=[lineoffsets, linelengths],
1239-
kwargs=kwargs)
1240-
1241-
# We do the conversion first since not all unitized data is uniform
1242-
positions = self.convert_xunits(positions)
1243-
lineoffsets = self.convert_yunits(lineoffsets)
1244-
linelengths = self.convert_yunits(linelengths)
1245-
12461213
if not np.iterable(positions):
12471214
positions = [positions]
12481215
elif any(np.iterable(position) for position in positions):
@@ -4919,6 +4886,7 @@ def fill(self, *args, **kwargs):
49194886
self.autoscale_view()
49204887
return patches
49214888

4889+
@munits._accepts_units(convert_x=['x'], convert_y=['y1', 'y2'])
49224890
@_preprocess_data(replace_names=["x", "y1", "y2", "where"],
49234891
label_namer=None)
49244892
@docstring.dedent_interpd
@@ -5012,14 +4980,10 @@ def fill_between(self, x, y1, y2=0, where=None, interpolate=False,
50124980
kwargs['facecolor'] = \
50134981
self._get_patches_for_fill.get_next_color()
50144982

5015-
# Handle united data, such as dates
5016-
self._process_unit_info(xdata=x, ydata=y1, kwargs=kwargs)
5017-
self._process_unit_info(ydata=y2)
5018-
50194983
# Convert the arrays so we can work with them
5020-
x = ma.masked_invalid(self.convert_xunits(x))
5021-
y1 = ma.masked_invalid(self.convert_yunits(y1))
5022-
y2 = ma.masked_invalid(self.convert_yunits(y2))
4984+
x = ma.masked_invalid(x)
4985+
y1 = ma.masked_invalid(y1)
4986+
y2 = ma.masked_invalid(y2)
50234987

50244988
for name, array in [('x', x), ('y1', y1), ('y2', y2)]:
50254989
if array.ndim > 1:
@@ -5102,6 +5066,7 @@ def get_interp_point(ind):
51025066
self.autoscale_view()
51035067
return collection
51045068

5069+
@munits._accepts_units(convert_x=['x1', 'x2'], convert_y=['y'])
51055070
@_preprocess_data(replace_names=["y", "x1", "x2", "where"],
51065071
label_namer=None)
51075072
@docstring.dedent_interpd
@@ -5195,14 +5160,10 @@ def fill_betweenx(self, y, x1, x2=0, where=None,
51955160
kwargs['facecolor'] = \
51965161
self._get_patches_for_fill.get_next_color()
51975162

5198-
# Handle united data, such as dates
5199-
self._process_unit_info(ydata=y, xdata=x1, kwargs=kwargs)
5200-
self._process_unit_info(xdata=x2)
5201-
52025163
# Convert the arrays so we can work with them
5203-
y = ma.masked_invalid(self.convert_yunits(y))
5204-
x1 = ma.masked_invalid(self.convert_xunits(x1))
5205-
x2 = ma.masked_invalid(self.convert_xunits(x2))
5164+
y = ma.masked_invalid(y)
5165+
x1 = ma.masked_invalid(x1)
5166+
x2 = ma.masked_invalid(x2)
52065167

52075168
for name, array in [('y', y), ('x1', x1), ('x2', x2)]:
52085169
if array.ndim > 1:

lib/matplotlib/axes/_base.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import matplotlib.font_manager as font_manager
2828
import matplotlib.text as mtext
2929
import matplotlib.image as mimage
30+
import matplotlib.units as munits
3031

3132
from matplotlib.rcsetup import cycler, validate_axisbelow
3233

@@ -3041,19 +3042,19 @@ def get_xlim(self):
30413042
"""
30423043
return tuple(self.viewLim.intervalx)
30433044

3044-
def _validate_converted_limits(self, limit, convert):
3045+
def _validate_converted_limits(self, converted_limit):
30453046
"""
30463047
Raise ValueError if converted limits are non-finite.
30473048
30483049
Note that this function also accepts None as a limit argument.
30493050
"""
3050-
if limit is not None:
3051-
converted_limit = convert(limit)
3052-
if (isinstance(converted_limit, Real)
3053-
and not np.isfinite(converted_limit)):
3051+
if converted_limit is not None:
3052+
if (isinstance(converted_limit, float) and
3053+
(not np.isreal(converted_limit) or
3054+
not np.isfinite(converted_limit))):
30543055
raise ValueError("Axis limits cannot be NaN or Inf")
3055-
return converted_limit
30563056

3057+
@munits._accepts_units(convert_x=['left', 'right'])
30573058
def set_xlim(self, left=None, right=None, emit=True, auto=False,
30583059
*, xmin=None, xmax=None):
30593060
"""
@@ -3131,9 +3132,8 @@ def set_xlim(self, left=None, right=None, emit=True, auto=False,
31313132
raise TypeError('Cannot pass both `xmax` and `right`')
31323133
right = xmax
31333134

3134-
self._process_unit_info(xdata=(left, right))
3135-
left = self._validate_converted_limits(left, self.convert_xunits)
3136-
right = self._validate_converted_limits(right, self.convert_xunits)
3135+
self._validate_converted_limits(left)
3136+
self._validate_converted_limits(right)
31373137

31383138
old_left, old_right = self.get_xlim()
31393139
if left is None:
@@ -3388,6 +3388,7 @@ def get_ylim(self):
33883388
"""
33893389
return tuple(self.viewLim.intervaly)
33903390

3391+
@munits._accepts_units(convert_y=['bottom', 'top'])
33913392
def set_ylim(self, bottom=None, top=None, emit=True, auto=False,
33923393
*, ymin=None, ymax=None):
33933394
"""
@@ -3464,8 +3465,8 @@ def set_ylim(self, bottom=None, top=None, emit=True, auto=False,
34643465
raise TypeError('Cannot pass both `ymax` and `top`')
34653466
top = ymax
34663467

3467-
bottom = self._validate_converted_limits(bottom, self.convert_yunits)
3468-
top = self._validate_converted_limits(top, self.convert_yunits)
3468+
self._validate_converted_limits(bottom)
3469+
self._validate_converted_limits(top)
34693470

34703471
old_bottom, old_top = self.get_ylim()
34713472

lib/matplotlib/units.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def default_units(x, axis):
5151
from matplotlib import cbook
5252

5353

54-
def _accepts_units(convert_x, convert_y):
54+
def _accepts_units(convert_x=[], convert_y=[]):
5555
"""
5656
A decorator for functions and methods that accept units. The parameters
5757
indicated in *convert_x* and *convert_y* are used to update the axis
@@ -69,6 +69,7 @@ def wrapper(*args, **kwargs):
6969
axes = args[0]
7070
# Bind the incoming arguments to the function signature
7171
bound_args = inspect.signature(func).bind(*args, **kwargs)
72+
bound_args.apply_defaults()
7273
# Get the original arguments - these will be modified later
7374
arguments = bound_args.arguments
7475
# Check for data kwarg

lib/mpl_toolkits/mplot3d/axes3d.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import matplotlib.projections as proj
2626
import matplotlib.scale as mscale
2727
import matplotlib.transforms as mtransforms
28+
import matplotlib.units as munits
2829
from matplotlib.axes import Axes, rcParams
2930
from matplotlib.colors import Normalize, LightSource
3031
from matplotlib.transforms import Bbox
@@ -595,6 +596,7 @@ def _determine_lims(self, xmin=None, xmax=None, *args, **kwargs):
595596
xmax += 0.05
596597
return (xmin, xmax)
597598

599+
@munits._accepts_units(convert_x=['left', 'right'])
598600
def set_xlim3d(self, left=None, right=None, emit=True, auto=False,
599601
*, xmin=None, xmax=None):
600602
"""
@@ -618,9 +620,8 @@ def set_xlim3d(self, left=None, right=None, emit=True, auto=False,
618620
raise TypeError('Cannot pass both `xmax` and `right`')
619621
right = xmax
620622

621-
self._process_unit_info(xdata=(left, right))
622-
left = self._validate_converted_limits(left, self.convert_xunits)
623-
right = self._validate_converted_limits(right, self.convert_xunits)
623+
self._validate_converted_limits(left)
624+
self._validate_converted_limits(right)
624625

625626
old_left, old_right = self.get_xlim()
626627
if left is None:
@@ -653,6 +654,7 @@ def set_xlim3d(self, left=None, right=None, emit=True, auto=False,
653654
return left, right
654655
set_xlim = set_xlim3d
655656

657+
@munits._accepts_units(convert_y=['bottom', 'top'])
656658
def set_ylim3d(self, bottom=None, top=None, emit=True, auto=False,
657659
*, ymin=None, ymax=None):
658660
"""
@@ -676,9 +678,8 @@ def set_ylim3d(self, bottom=None, top=None, emit=True, auto=False,
676678
raise TypeError('Cannot pass both `ymax` and `top`')
677679
top = ymax
678680

679-
self._process_unit_info(ydata=(bottom, top))
680-
bottom = self._validate_converted_limits(bottom, self.convert_yunits)
681-
top = self._validate_converted_limits(top, self.convert_yunits)
681+
self._validate_converted_limits(bottom)
682+
self._validate_converted_limits(top)
682683

683684
old_bottom, old_top = self.get_ylim()
684685
if bottom is None:
@@ -735,8 +736,10 @@ def set_zlim3d(self, bottom=None, top=None, emit=True, auto=False,
735736
top = zmax
736737

737738
self._process_unit_info(zdata=(bottom, top))
738-
bottom = self._validate_converted_limits(bottom, self.convert_zunits)
739-
top = self._validate_converted_limits(top, self.convert_zunits)
739+
bottom = self.convert_zunits(bottom)
740+
top = self.convert_zunits(top)
741+
self._validate_converted_limits(bottom)
742+
self._validate_converted_limits(top)
740743

741744
old_bottom, old_top = self.get_zlim()
742745
if bottom is None:

0 commit comments

Comments
 (0)