Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit a786018

Browse files
committed
Apply unit decorator to more functions
Add some more unit decorators Add unit decorator to mplot3d
1 parent 50a8da5 commit a786018

File tree

4 files changed

+39
-73
lines changed

4 files changed

+39
-73
lines changed

lib/matplotlib/axes/_axes.py

Lines changed: 17 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,7 @@ def legend(self, *args, **kwargs):
385385
def _remove_legend(self, legend):
386386
self.legend_ = None
387387

388+
@munits._accepts_units(convert_x=['x'], convert_y=['y'])
388389
def text(self, x, y, s, fontdict=None, withdash=False, **kwargs):
389390
"""
390391
Add text to the axes.
@@ -619,6 +620,8 @@ def axvline(self, x=0, ymin=0, ymax=1, **kwargs):
619620
self.autoscale_view(scalex=scalex, scaley=False)
620621
return l
621622

623+
@munits._accepts_units(convert_x=['xmin', 'xmax'],
624+
convert_y=['ymin', 'ymax'])
622625
@docstring.dedent_interpd
623626
def axhspan(self, ymin, ymax, xmin=0, xmax=1, **kwargs):
624627
"""
@@ -660,21 +663,15 @@ def axhspan(self, ymin, ymax, xmin=0, xmax=1, **kwargs):
660663
axvspan : Add a vertical span across the axes.
661664
"""
662665
trans = self.get_yaxis_transform(which='grid')
663-
664-
# process the unit information
665-
self._process_unit_info([xmin, xmax], [ymin, ymax], kwargs=kwargs)
666-
667-
# first we need to strip away the units
668-
xmin, xmax = self.convert_xunits([xmin, xmax])
669-
ymin, ymax = self.convert_yunits([ymin, ymax])
670-
671666
verts = (xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin)
672667
p = mpatches.Polygon(verts, **kwargs)
673668
p.set_transform(trans)
674669
self.add_patch(p)
675670
self.autoscale_view(scalex=False)
676671
return p
677672

673+
@munits._accepts_units(convert_x=['xmin', 'xmax'],
674+
convert_y=['ymin', 'ymax'])
678675
def axvspan(self, xmin, xmax, ymin=0, ymax=1, **kwargs):
679676
"""
680677
Add a vertical span (rectangle) across the axes.
@@ -725,21 +722,14 @@ def axvspan(self, xmin, xmax, ymin=0, ymax=1, **kwargs):
725722
726723
"""
727724
trans = self.get_xaxis_transform(which='grid')
728-
729-
# process the unit information
730-
self._process_unit_info([xmin, xmax], [ymin, ymax], kwargs=kwargs)
731-
732-
# first we need to strip away the units
733-
xmin, xmax = self.convert_xunits([xmin, xmax])
734-
ymin, ymax = self.convert_yunits([ymin, ymax])
735-
736725
verts = [(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin)]
737726
p = mpatches.Polygon(verts, **kwargs)
738727
p.set_transform(trans)
739728
self.add_patch(p)
740729
self.autoscale_view(scaley=False)
741730
return p
742731

732+
@munits._accepts_units(convert_x=['xmin', 'xmax'], convert_y=['y'])
743733
@_preprocess_data(replace_names=["y", "xmin", "xmax", "colors"],
744734
label_namer="y")
745735
def hlines(self, y, xmin, xmax, colors='k', linestyles='solid',
@@ -775,14 +765,6 @@ def hlines(self, y, xmin, xmax, colors='k', linestyles='solid',
775765
vlines : vertical lines
776766
axhline: horizontal line across the axes
777767
"""
778-
779-
# We do the conversion first since not all unitized data is uniform
780-
# process the unit information
781-
self._process_unit_info([xmin, xmax], y, kwargs=kwargs)
782-
y = self.convert_yunits(y)
783-
xmin = self.convert_xunits(xmin)
784-
xmax = self.convert_xunits(xmax)
785-
786768
if not iterable(y):
787769
y = [y]
788770
if not iterable(xmin):
@@ -816,6 +798,7 @@ def hlines(self, y, xmin, xmax, colors='k', linestyles='solid',
816798

817799
return lines
818800

801+
@munits._accepts_units(convert_x=['x'], convert_y=['ymin', 'ymax'])
819802
@_preprocess_data(replace_names=["x", "ymin", "ymax", "colors"],
820803
label_namer="x")
821804
def vlines(self, x, ymin, ymax, colors='k', linestyles='solid',
@@ -853,14 +836,6 @@ def vlines(self, x, ymin, ymax, colors='k', linestyles='solid',
853836
hlines : horizontal lines
854837
axvline: vertical line across the axes
855838
"""
856-
857-
self._process_unit_info(xdata=x, ydata=[ymin, ymax], kwargs=kwargs)
858-
859-
# We do the conversion first since not all unitized data is uniform
860-
x = self.convert_xunits(x)
861-
ymin = self.convert_yunits(ymin)
862-
ymax = self.convert_yunits(ymax)
863-
864839
if not iterable(x):
865840
x = [x]
866841
if not iterable(ymin):
@@ -893,6 +868,8 @@ def vlines(self, x, ymin, ymax, colors='k', linestyles='solid',
893868

894869
return lines
895870

871+
@munits._accepts_units(convert_x=['positions'],
872+
convert_y=['lineoffsets', 'linelengths'])
896873
@_preprocess_data(replace_names=["positions", "lineoffsets",
897874
"linelengths", "linewidths",
898875
"colors", "linestyles"],
@@ -982,15 +959,6 @@ def eventplot(self, positions, orientation='horizontal', lineoffsets=1,
982959
983960
.. plot:: gallery/lines_bars_and_markers/eventplot_demo.py
984961
"""
985-
self._process_unit_info(xdata=positions,
986-
ydata=[lineoffsets, linelengths],
987-
kwargs=kwargs)
988-
989-
# We do the conversion first since not all unitized data is uniform
990-
positions = self.convert_xunits(positions)
991-
lineoffsets = self.convert_yunits(lineoffsets)
992-
linelengths = self.convert_yunits(linelengths)
993-
994962
if not iterable(positions):
995963
positions = [positions]
996964
elif any(iterable(position) for position in positions):
@@ -4628,6 +4596,7 @@ def fill(self, *args, **kwargs):
46284596
self.autoscale_view()
46294597
return patches
46304598

4599+
@munits._accepts_units(convert_x=['x'], convert_y=['y1', 'y2'])
46314600
@_preprocess_data(replace_names=["x", "y1", "y2", "where"],
46324601
label_namer=None)
46334602
@docstring.dedent_interpd
@@ -4721,14 +4690,10 @@ def fill_between(self, x, y1, y2=0, where=None, interpolate=False,
47214690
kwargs['facecolor'] = \
47224691
self._get_patches_for_fill.get_next_color()
47234692

4724-
# Handle united data, such as dates
4725-
self._process_unit_info(xdata=x, ydata=y1, kwargs=kwargs)
4726-
self._process_unit_info(ydata=y2)
4727-
47284693
# Convert the arrays so we can work with them
4729-
x = ma.masked_invalid(self.convert_xunits(x))
4730-
y1 = ma.masked_invalid(self.convert_yunits(y1))
4731-
y2 = ma.masked_invalid(self.convert_yunits(y2))
4694+
x = ma.masked_invalid(x)
4695+
y1 = ma.masked_invalid(y1)
4696+
y2 = ma.masked_invalid(y2)
47324697

47334698
for name, array in [('x', x), ('y1', y1), ('y2', y2)]:
47344699
if array.ndim > 1:
@@ -4811,6 +4776,7 @@ def get_interp_point(ind):
48114776
self.autoscale_view()
48124777
return collection
48134778

4779+
@munits._accepts_units(convert_x=['x1', 'x2'], convert_y=['y'])
48144780
@_preprocess_data(replace_names=["y", "x1", "x2", "where"],
48154781
label_namer=None)
48164782
@docstring.dedent_interpd
@@ -4904,14 +4870,10 @@ def fill_betweenx(self, y, x1, x2=0, where=None,
49044870
kwargs['facecolor'] = \
49054871
self._get_patches_for_fill.get_next_color()
49064872

4907-
# Handle united data, such as dates
4908-
self._process_unit_info(ydata=y, xdata=x1, kwargs=kwargs)
4909-
self._process_unit_info(xdata=x2)
4910-
49114873
# Convert the arrays so we can work with them
4912-
y = ma.masked_invalid(self.convert_yunits(y))
4913-
x1 = ma.masked_invalid(self.convert_xunits(x1))
4914-
x2 = ma.masked_invalid(self.convert_xunits(x2))
4874+
y = ma.masked_invalid(y)
4875+
x1 = ma.masked_invalid(x1)
4876+
x2 = ma.masked_invalid(x2)
49154877

49164878
for name, array in [('y', y), ('x1', x1), ('x2', x2)]:
49174879
if array.ndim > 1:

lib/matplotlib/axes/_base.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import matplotlib.font_manager as font_manager
2727
import matplotlib.text as mtext
2828
import matplotlib.image as mimage
29+
import matplotlib.units as munits
2930
from matplotlib.offsetbox import OffsetBox
3031
from matplotlib.artist import allow_rasterization
3132
from matplotlib.legend import Legend
@@ -3007,20 +3008,19 @@ def get_xlim(self):
30073008
"""
30083009
return tuple(self.viewLim.intervalx)
30093010

3010-
def _validate_converted_limits(self, limit, convert):
3011+
def _validate_converted_limits(self, converted_limit):
30113012
"""
30123013
Raise ValueError if converted limits are non-finite.
30133014
30143015
Note that this function also accepts None as a limit argument.
30153016
"""
3016-
if limit is not None:
3017-
converted_limit = convert(limit)
3017+
if converted_limit is not None:
30183018
if (isinstance(converted_limit, float) and
30193019
(not np.isreal(converted_limit) or
30203020
not np.isfinite(converted_limit))):
30213021
raise ValueError("Axis limits cannot be NaN or Inf")
3022-
return converted_limit
30233022

3023+
@munits._accepts_units(convert_x=['left', 'right'])
30243024
def set_xlim(self, left=None, right=None, emit=True, auto=False, **kw):
30253025
"""
30263026
Set the data limits for the x-axis
@@ -3088,9 +3088,8 @@ def set_xlim(self, left=None, right=None, emit=True, auto=False, **kw):
30883088
if right is None and iterable(left):
30893089
left, right = left
30903090

3091-
self._process_unit_info(xdata=(left, right))
3092-
left = self._validate_converted_limits(left, self.convert_xunits)
3093-
right = self._validate_converted_limits(right, self.convert_xunits)
3091+
self._validate_converted_limits(left)
3092+
self._validate_converted_limits(right)
30943093

30953094
old_left, old_right = self.get_xlim()
30963095
if left is None:
@@ -3351,6 +3350,7 @@ def get_ylim(self):
33513350
"""
33523351
return tuple(self.viewLim.intervaly)
33533352

3353+
@munits._accepts_units(convert_y=['bottom', 'top'])
33543354
def set_ylim(self, bottom=None, top=None, emit=True, auto=False, **kw):
33553355
"""
33563356
Set the data limits for the y-axis
@@ -3417,8 +3417,8 @@ def set_ylim(self, bottom=None, top=None, emit=True, auto=False, **kw):
34173417
if top is None and iterable(bottom):
34183418
bottom, top = bottom
34193419

3420-
bottom = self._validate_converted_limits(bottom, self.convert_yunits)
3421-
top = self._validate_converted_limits(top, self.convert_yunits)
3420+
self._validate_converted_limits(bottom)
3421+
self._validate_converted_limits(top)
34223422

34233423
old_bottom, old_top = self.get_ylim()
34243424

lib/matplotlib/units.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def default_units(x, axis):
5151
from matplotlib.cbook import iterable, safe_first_element
5252

5353

54-
def _accepts_units(convert_x, convert_y):
54+
def _accepts_units(convert_x=[], convert_y=[]):
5555
"""
5656
A decorator for functions and methods that accept units. The parameters
5757
indicated in *convert_x* and *convert_y* are used to update the axis
@@ -69,6 +69,7 @@ def wrapper(*args, **kwargs):
6969
axes = args[0]
7070
# Bind the incoming arguments to the function signature
7171
bound_args = inspect.signature(func).bind(*args, **kwargs)
72+
bound_args.apply_defaults()
7273
# Get the original arguments - these will be modified later
7374
arguments = bound_args.arguments
7475
# Check for data kwarg

lib/mpl_toolkits/mplot3d/axes3d.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import matplotlib.docstring as docstring
2727
import matplotlib.scale as mscale
2828
import matplotlib.transforms as mtransforms
29+
import matplotlib.units as munits
2930
from matplotlib.axes import Axes, rcParams
3031
from matplotlib.colors import Normalize, LightSource
3132
from matplotlib.transforms import Bbox
@@ -603,6 +604,7 @@ def _determine_lims(self, xmin=None, xmax=None, *args, **kwargs):
603604
xmax += 0.05
604605
return (xmin, xmax)
605606

607+
@munits._accepts_units(convert_x=['left', 'right'])
606608
def set_xlim3d(self, left=None, right=None, emit=True, auto=False, **kw):
607609
"""
608610
Set 3D x limits.
@@ -620,9 +622,8 @@ def set_xlim3d(self, left=None, right=None, emit=True, auto=False, **kw):
620622
if right is None and cbook.iterable(left):
621623
left, right = left
622624

623-
self._process_unit_info(xdata=(left, right))
624-
left = self._validate_converted_limits(left, self.convert_xunits)
625-
right = self._validate_converted_limits(right, self.convert_xunits)
625+
self._validate_converted_limits(left)
626+
self._validate_converted_limits(right)
626627

627628
old_left, old_right = self.get_xlim()
628629
if left is None:
@@ -655,6 +656,7 @@ def set_xlim3d(self, left=None, right=None, emit=True, auto=False, **kw):
655656
return left, right
656657
set_xlim = set_xlim3d
657658

659+
@munits._accepts_units(convert_y=['bottom', 'top'])
658660
def set_ylim3d(self, bottom=None, top=None, emit=True, auto=False, **kw):
659661
"""
660662
Set 3D y limits.
@@ -672,9 +674,8 @@ def set_ylim3d(self, bottom=None, top=None, emit=True, auto=False, **kw):
672674
if top is None and cbook.iterable(bottom):
673675
bottom, top = bottom
674676

675-
self._process_unit_info(ydata=(bottom, top))
676-
bottom = self._validate_converted_limits(bottom, self.convert_yunits)
677-
top = self._validate_converted_limits(top, self.convert_yunits)
677+
self._validate_converted_limits(bottom)
678+
self._validate_converted_limits(top)
678679

679680
old_bottom, old_top = self.get_ylim()
680681
if bottom is None:
@@ -725,8 +726,10 @@ def set_zlim3d(self, bottom=None, top=None, emit=True, auto=False, **kw):
725726
bottom, top = bottom
726727

727728
self._process_unit_info(zdata=(bottom, top))
728-
bottom = self._validate_converted_limits(bottom, self.convert_zunits)
729-
top = self._validate_converted_limits(top, self.convert_zunits)
729+
bottom = self.convert_zunits(bottom)
730+
top = self.convert_zunits(top)
731+
self._validate_converted_limits(bottom)
732+
self._validate_converted_limits(top)
730733

731734
old_bottom, old_top = self.get_zlim()
732735
if bottom is None:

0 commit comments

Comments
 (0)