Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 8c97270

Browse files
committed
Apply unit decorator to more functions
Add some more unit decorators Add unit decorator to mplot3d
1 parent 837cb43 commit 8c97270

File tree

4 files changed

+39
-73
lines changed

4 files changed

+39
-73
lines changed

lib/matplotlib/axes/_axes.py

Lines changed: 17 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,7 @@ def legend(self, *args, **kwargs):
385385
def _remove_legend(self, legend):
386386
self.legend_ = None
387387

388+
@munits._accepts_units(convert_x=['x'], convert_y=['y'])
388389
def text(self, x, y, s, fontdict=None, withdash=False, **kwargs):
389390
"""
390391
Add text to the axes.
@@ -619,6 +620,8 @@ def axvline(self, x=0, ymin=0, ymax=1, **kwargs):
619620
self.autoscale_view(scalex=scalex, scaley=False)
620621
return l
621622

623+
@munits._accepts_units(convert_x=['xmin', 'xmax'],
624+
convert_y=['ymin', 'ymax'])
622625
@docstring.dedent_interpd
623626
def axhspan(self, ymin, ymax, xmin=0, xmax=1, **kwargs):
624627
"""
@@ -660,21 +663,15 @@ def axhspan(self, ymin, ymax, xmin=0, xmax=1, **kwargs):
660663
axvspan : Add a vertical span across the axes.
661664
"""
662665
trans = self.get_yaxis_transform(which='grid')
663-
664-
# process the unit information
665-
self._process_unit_info([xmin, xmax], [ymin, ymax], kwargs=kwargs)
666-
667-
# first we need to strip away the units
668-
xmin, xmax = self.convert_xunits([xmin, xmax])
669-
ymin, ymax = self.convert_yunits([ymin, ymax])
670-
671666
verts = (xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin)
672667
p = mpatches.Polygon(verts, **kwargs)
673668
p.set_transform(trans)
674669
self.add_patch(p)
675670
self.autoscale_view(scalex=False)
676671
return p
677672

673+
@munits._accepts_units(convert_x=['xmin', 'xmax'],
674+
convert_y=['ymin', 'ymax'])
678675
def axvspan(self, xmin, xmax, ymin=0, ymax=1, **kwargs):
679676
"""
680677
Add a vertical span (rectangle) across the axes.
@@ -725,21 +722,14 @@ def axvspan(self, xmin, xmax, ymin=0, ymax=1, **kwargs):
725722
726723
"""
727724
trans = self.get_xaxis_transform(which='grid')
728-
729-
# process the unit information
730-
self._process_unit_info([xmin, xmax], [ymin, ymax], kwargs=kwargs)
731-
732-
# first we need to strip away the units
733-
xmin, xmax = self.convert_xunits([xmin, xmax])
734-
ymin, ymax = self.convert_yunits([ymin, ymax])
735-
736725
verts = [(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin)]
737726
p = mpatches.Polygon(verts, **kwargs)
738727
p.set_transform(trans)
739728
self.add_patch(p)
740729
self.autoscale_view(scaley=False)
741730
return p
742731

732+
@munits._accepts_units(convert_x=['xmin', 'xmax'], convert_y=['y'])
743733
@_preprocess_data(replace_names=["y", "xmin", "xmax", "colors"],
744734
label_namer="y")
745735
def hlines(self, y, xmin, xmax, colors='k', linestyles='solid',
@@ -775,14 +765,6 @@ def hlines(self, y, xmin, xmax, colors='k', linestyles='solid',
775765
vlines : vertical lines
776766
axhline: horizontal line across the axes
777767
"""
778-
779-
# We do the conversion first since not all unitized data is uniform
780-
# process the unit information
781-
self._process_unit_info([xmin, xmax], y, kwargs=kwargs)
782-
y = self.convert_yunits(y)
783-
xmin = self.convert_xunits(xmin)
784-
xmax = self.convert_xunits(xmax)
785-
786768
if not iterable(y):
787769
y = [y]
788770
if not iterable(xmin):
@@ -816,6 +798,7 @@ def hlines(self, y, xmin, xmax, colors='k', linestyles='solid',
816798

817799
return lines
818800

801+
@munits._accepts_units(convert_x=['x'], convert_y=['ymin', 'ymax'])
819802
@_preprocess_data(replace_names=["x", "ymin", "ymax", "colors"],
820803
label_namer="x")
821804
def vlines(self, x, ymin, ymax, colors='k', linestyles='solid',
@@ -853,14 +836,6 @@ def vlines(self, x, ymin, ymax, colors='k', linestyles='solid',
853836
hlines : horizontal lines
854837
axvline: vertical line across the axes
855838
"""
856-
857-
self._process_unit_info(xdata=x, ydata=[ymin, ymax], kwargs=kwargs)
858-
859-
# We do the conversion first since not all unitized data is uniform
860-
x = self.convert_xunits(x)
861-
ymin = self.convert_yunits(ymin)
862-
ymax = self.convert_yunits(ymax)
863-
864839
if not iterable(x):
865840
x = [x]
866841
if not iterable(ymin):
@@ -893,6 +868,8 @@ def vlines(self, x, ymin, ymax, colors='k', linestyles='solid',
893868

894869
return lines
895870

871+
@munits._accepts_units(convert_x=['positions'],
872+
convert_y=['lineoffsets', 'linelengths'])
896873
@_preprocess_data(replace_names=["positions", "lineoffsets",
897874
"linelengths", "linewidths",
898875
"colors", "linestyles"],
@@ -982,15 +959,6 @@ def eventplot(self, positions, orientation='horizontal', lineoffsets=1,
982959
983960
.. plot:: gallery/lines_bars_and_markers/eventplot_demo.py
984961
"""
985-
self._process_unit_info(xdata=positions,
986-
ydata=[lineoffsets, linelengths],
987-
kwargs=kwargs)
988-
989-
# We do the conversion first since not all unitized data is uniform
990-
positions = self.convert_xunits(positions)
991-
lineoffsets = self.convert_yunits(lineoffsets)
992-
linelengths = self.convert_yunits(linelengths)
993-
994962
if not iterable(positions):
995963
positions = [positions]
996964
elif any(iterable(position) for position in positions):
@@ -4628,6 +4596,7 @@ def fill(self, *args, **kwargs):
46284596
self.autoscale_view()
46294597
return patches
46304598

4599+
@munits._accepts_units(convert_x=['x'], convert_y=['y1', 'y2'])
46314600
@_preprocess_data(replace_names=["x", "y1", "y2", "where"],
46324601
label_namer=None)
46334602
@docstring.dedent_interpd
@@ -4721,14 +4690,10 @@ def fill_between(self, x, y1, y2=0, where=None, interpolate=False,
47214690
kwargs['facecolor'] = \
47224691
self._get_patches_for_fill.get_next_color()
47234692

4724-
# Handle united data, such as dates
4725-
self._process_unit_info(xdata=x, ydata=y1, kwargs=kwargs)
4726-
self._process_unit_info(ydata=y2)
4727-
47284693
# Convert the arrays so we can work with them
4729-
x = ma.masked_invalid(self.convert_xunits(x))
4730-
y1 = ma.masked_invalid(self.convert_yunits(y1))
4731-
y2 = ma.masked_invalid(self.convert_yunits(y2))
4694+
x = ma.masked_invalid(x)
4695+
y1 = ma.masked_invalid(y1)
4696+
y2 = ma.masked_invalid(y2)
47324697

47334698
for name, array in [('x', x), ('y1', y1), ('y2', y2)]:
47344699
if array.ndim > 1:
@@ -4811,6 +4776,7 @@ def get_interp_point(ind):
48114776
self.autoscale_view()
48124777
return collection
48134778

4779+
@munits._accepts_units(convert_x=['x1', 'x2'], convert_y=['y'])
48144780
@_preprocess_data(replace_names=["y", "x1", "x2", "where"],
48154781
label_namer=None)
48164782
@docstring.dedent_interpd
@@ -4904,14 +4870,10 @@ def fill_betweenx(self, y, x1, x2=0, where=None,
49044870
kwargs['facecolor'] = \
49054871
self._get_patches_for_fill.get_next_color()
49064872

4907-
# Handle united data, such as dates
4908-
self._process_unit_info(ydata=y, xdata=x1, kwargs=kwargs)
4909-
self._process_unit_info(xdata=x2)
4910-
49114873
# Convert the arrays so we can work with them
4912-
y = ma.masked_invalid(self.convert_yunits(y))
4913-
x1 = ma.masked_invalid(self.convert_xunits(x1))
4914-
x2 = ma.masked_invalid(self.convert_xunits(x2))
4874+
y = ma.masked_invalid(y)
4875+
x1 = ma.masked_invalid(x1)
4876+
x2 = ma.masked_invalid(x2)
49154877

49164878
for name, array in [('y', y), ('x1', x1), ('x2', x2)]:
49174879
if array.ndim > 1:

lib/matplotlib/axes/_base.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import matplotlib.font_manager as font_manager
2727
import matplotlib.text as mtext
2828
import matplotlib.image as mimage
29+
import matplotlib.units as munits
2930
from matplotlib.offsetbox import OffsetBox
3031
from matplotlib.artist import allow_rasterization
3132
from matplotlib.legend import Legend
@@ -3007,20 +3008,19 @@ def get_xlim(self):
30073008
"""
30083009
return tuple(self.viewLim.intervalx)
30093010

3010-
def _validate_converted_limits(self, limit, convert):
3011+
def _validate_converted_limits(self, converted_limit):
30113012
"""
30123013
Raise ValueError if converted limits are non-finite.
30133014
30143015
Note that this function also accepts None as a limit argument.
30153016
"""
3016-
if limit is not None:
3017-
converted_limit = convert(limit)
3017+
if converted_limit is not None:
30183018
if (isinstance(converted_limit, float) and
30193019
(not np.isreal(converted_limit) or
30203020
not np.isfinite(converted_limit))):
30213021
raise ValueError("Axis limits cannot be NaN or Inf")
3022-
return converted_limit
30233022

3023+
@munits._accepts_units(convert_x=['left', 'right'])
30243024
def set_xlim(self, left=None, right=None, emit=True, auto=False,
30253025
*, xmin=None, xmax=None):
30263026
"""
@@ -3098,9 +3098,8 @@ def set_xlim(self, left=None, right=None, emit=True, auto=False,
30983098
raise TypeError('Cannot pass both `xmax` and `right`')
30993099
right = xmax
31003100

3101-
self._process_unit_info(xdata=(left, right))
3102-
left = self._validate_converted_limits(left, self.convert_xunits)
3103-
right = self._validate_converted_limits(right, self.convert_xunits)
3101+
self._validate_converted_limits(left)
3102+
self._validate_converted_limits(right)
31043103

31053104
old_left, old_right = self.get_xlim()
31063105
if left is None:
@@ -3361,6 +3360,7 @@ def get_ylim(self):
33613360
"""
33623361
return tuple(self.viewLim.intervaly)
33633362

3363+
@munits._accepts_units(convert_y=['bottom', 'top'])
33643364
def set_ylim(self, bottom=None, top=None, emit=True, auto=False,
33653365
*, ymin=None, ymax=None):
33663366
"""
@@ -3437,8 +3437,8 @@ def set_ylim(self, bottom=None, top=None, emit=True, auto=False,
34373437
raise TypeError('Cannot pass both `ymax` and `top`')
34383438
top = ymax
34393439

3440-
bottom = self._validate_converted_limits(bottom, self.convert_yunits)
3441-
top = self._validate_converted_limits(top, self.convert_yunits)
3440+
self._validate_converted_limits(bottom)
3441+
self._validate_converted_limits(top)
34423442

34433443
old_bottom, old_top = self.get_ylim()
34443444

lib/matplotlib/units.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def default_units(x, axis):
5151
from matplotlib.cbook import iterable, safe_first_element
5252

5353

54-
def _accepts_units(convert_x, convert_y):
54+
def _accepts_units(convert_x=[], convert_y=[]):
5555
"""
5656
A decorator for functions and methods that accept units. The parameters
5757
indicated in *convert_x* and *convert_y* are used to update the axis
@@ -69,6 +69,7 @@ def wrapper(*args, **kwargs):
6969
axes = args[0]
7070
# Bind the incoming arguments to the function signature
7171
bound_args = inspect.signature(func).bind(*args, **kwargs)
72+
bound_args.apply_defaults()
7273
# Get the original arguments - these will be modified later
7374
arguments = bound_args.arguments
7475
# Check for data kwarg

lib/mpl_toolkits/mplot3d/axes3d.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import matplotlib.docstring as docstring
2727
import matplotlib.scale as mscale
2828
import matplotlib.transforms as mtransforms
29+
import matplotlib.units as munits
2930
from matplotlib.axes import Axes, rcParams
3031
from matplotlib.colors import Normalize, LightSource
3132
from matplotlib.transforms import Bbox
@@ -597,6 +598,7 @@ def _determine_lims(self, xmin=None, xmax=None, *args, **kwargs):
597598
xmax += 0.05
598599
return (xmin, xmax)
599600

601+
@munits._accepts_units(convert_x=['left', 'right'])
600602
def set_xlim3d(self, left=None, right=None, emit=True, auto=False,
601603
*, xmin=None, xmax=None):
602604
"""
@@ -620,9 +622,8 @@ def set_xlim3d(self, left=None, right=None, emit=True, auto=False,
620622
raise TypeError('Cannot pass both `xmax` and `right`')
621623
right = xmax
622624

623-
self._process_unit_info(xdata=(left, right))
624-
left = self._validate_converted_limits(left, self.convert_xunits)
625-
right = self._validate_converted_limits(right, self.convert_xunits)
625+
self._validate_converted_limits(left)
626+
self._validate_converted_limits(right)
626627

627628
old_left, old_right = self.get_xlim()
628629
if left is None:
@@ -655,6 +656,7 @@ def set_xlim3d(self, left=None, right=None, emit=True, auto=False,
655656
return left, right
656657
set_xlim = set_xlim3d
657658

659+
@munits._accepts_units(convert_y=['bottom', 'top'])
658660
def set_ylim3d(self, bottom=None, top=None, emit=True, auto=False,
659661
*, ymin=None, ymax=None):
660662
"""
@@ -678,9 +680,8 @@ def set_ylim3d(self, bottom=None, top=None, emit=True, auto=False,
678680
raise TypeError('Cannot pass both `ymax` and `top`')
679681
top = ymax
680682

681-
self._process_unit_info(ydata=(bottom, top))
682-
bottom = self._validate_converted_limits(bottom, self.convert_yunits)
683-
top = self._validate_converted_limits(top, self.convert_yunits)
683+
self._validate_converted_limits(bottom)
684+
self._validate_converted_limits(top)
684685

685686
old_bottom, old_top = self.get_ylim()
686687
if bottom is None:
@@ -737,8 +738,10 @@ def set_zlim3d(self, bottom=None, top=None, emit=True, auto=False,
737738
top = zmax
738739

739740
self._process_unit_info(zdata=(bottom, top))
740-
bottom = self._validate_converted_limits(bottom, self.convert_zunits)
741-
top = self._validate_converted_limits(top, self.convert_zunits)
741+
bottom = self.convert_zunits(bottom)
742+
top = self.convert_zunits(top)
743+
self._validate_converted_limits(bottom)
744+
self._validate_converted_limits(top)
742745

743746
old_bottom, old_top = self.get_zlim()
744747
if bottom is None:

0 commit comments

Comments
 (0)