Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit bc61047

Browse files
committed
Apply unit decorator to more functions
Add some more unit decorators Add unit decorator to mplot3d
1 parent 44f778b commit bc61047

File tree

4 files changed

+39
-73
lines changed

4 files changed

+39
-73
lines changed

lib/matplotlib/axes/_axes.py

Lines changed: 17 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,7 @@ def legend(self, *args, **kwargs):
385385
def _remove_legend(self, legend):
386386
self.legend_ = None
387387

388+
@munits._accepts_units(convert_x=['x'], convert_y=['y'])
388389
def text(self, x, y, s, fontdict=None, withdash=False, **kwargs):
389390
"""
390391
Add text to the axes.
@@ -619,6 +620,8 @@ def axvline(self, x=0, ymin=0, ymax=1, **kwargs):
619620
self.autoscale_view(scalex=scalex, scaley=False)
620621
return l
621622

623+
@munits._accepts_units(convert_x=['xmin', 'xmax'],
624+
convert_y=['ymin', 'ymax'])
622625
@docstring.dedent_interpd
623626
def axhspan(self, ymin, ymax, xmin=0, xmax=1, **kwargs):
624627
"""
@@ -660,21 +663,15 @@ def axhspan(self, ymin, ymax, xmin=0, xmax=1, **kwargs):
660663
axvspan : Add a vertical span across the axes.
661664
"""
662665
trans = self.get_yaxis_transform(which='grid')
663-
664-
# process the unit information
665-
self._process_unit_info([xmin, xmax], [ymin, ymax], kwargs=kwargs)
666-
667-
# first we need to strip away the units
668-
xmin, xmax = self.convert_xunits([xmin, xmax])
669-
ymin, ymax = self.convert_yunits([ymin, ymax])
670-
671666
verts = (xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin)
672667
p = mpatches.Polygon(verts, **kwargs)
673668
p.set_transform(trans)
674669
self.add_patch(p)
675670
self.autoscale_view(scalex=False)
676671
return p
677672

673+
@munits._accepts_units(convert_x=['xmin', 'xmax'],
674+
convert_y=['ymin', 'ymax'])
678675
def axvspan(self, xmin, xmax, ymin=0, ymax=1, **kwargs):
679676
"""
680677
Add a vertical span (rectangle) across the axes.
@@ -725,21 +722,14 @@ def axvspan(self, xmin, xmax, ymin=0, ymax=1, **kwargs):
725722
726723
"""
727724
trans = self.get_xaxis_transform(which='grid')
728-
729-
# process the unit information
730-
self._process_unit_info([xmin, xmax], [ymin, ymax], kwargs=kwargs)
731-
732-
# first we need to strip away the units
733-
xmin, xmax = self.convert_xunits([xmin, xmax])
734-
ymin, ymax = self.convert_yunits([ymin, ymax])
735-
736725
verts = [(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin)]
737726
p = mpatches.Polygon(verts, **kwargs)
738727
p.set_transform(trans)
739728
self.add_patch(p)
740729
self.autoscale_view(scaley=False)
741730
return p
742731

732+
@munits._accepts_units(convert_x=['xmin', 'xmax'], convert_y=['y'])
743733
@_preprocess_data(replace_names=["y", "xmin", "xmax", "colors"],
744734
label_namer="y")
745735
def hlines(self, y, xmin, xmax, colors='k', linestyles='solid',
@@ -775,14 +765,6 @@ def hlines(self, y, xmin, xmax, colors='k', linestyles='solid',
775765
vlines : vertical lines
776766
axhline: horizontal line across the axes
777767
"""
778-
779-
# We do the conversion first since not all unitized data is uniform
780-
# process the unit information
781-
self._process_unit_info([xmin, xmax], y, kwargs=kwargs)
782-
y = self.convert_yunits(y)
783-
xmin = self.convert_xunits(xmin)
784-
xmax = self.convert_xunits(xmax)
785-
786768
if not iterable(y):
787769
y = [y]
788770
if not iterable(xmin):
@@ -816,6 +798,7 @@ def hlines(self, y, xmin, xmax, colors='k', linestyles='solid',
816798

817799
return lines
818800

801+
@munits._accepts_units(convert_x=['x'], convert_y=['ymin', 'ymax'])
819802
@_preprocess_data(replace_names=["x", "ymin", "ymax", "colors"],
820803
label_namer="x")
821804
def vlines(self, x, ymin, ymax, colors='k', linestyles='solid',
@@ -853,14 +836,6 @@ def vlines(self, x, ymin, ymax, colors='k', linestyles='solid',
853836
hlines : horizontal lines
854837
axvline: vertical line across the axes
855838
"""
856-
857-
self._process_unit_info(xdata=x, ydata=[ymin, ymax], kwargs=kwargs)
858-
859-
# We do the conversion first since not all unitized data is uniform
860-
x = self.convert_xunits(x)
861-
ymin = self.convert_yunits(ymin)
862-
ymax = self.convert_yunits(ymax)
863-
864839
if not iterable(x):
865840
x = [x]
866841
if not iterable(ymin):
@@ -893,6 +868,8 @@ def vlines(self, x, ymin, ymax, colors='k', linestyles='solid',
893868

894869
return lines
895870

871+
@munits._accepts_units(convert_x=['positions'],
872+
convert_y=['lineoffsets', 'linelengths'])
896873
@_preprocess_data(replace_names=["positions", "lineoffsets",
897874
"linelengths", "linewidths",
898875
"colors", "linestyles"],
@@ -981,15 +958,6 @@ def eventplot(self, positions, orientation='horizontal', lineoffsets=1,
981958
982959
.. plot:: gallery/lines_bars_and_markers/eventplot_demo.py
983960
"""
984-
self._process_unit_info(xdata=positions,
985-
ydata=[lineoffsets, linelengths],
986-
kwargs=kwargs)
987-
988-
# We do the conversion first since not all unitized data is uniform
989-
positions = self.convert_xunits(positions)
990-
lineoffsets = self.convert_yunits(lineoffsets)
991-
linelengths = self.convert_yunits(linelengths)
992-
993961
if not iterable(positions):
994962
positions = [positions]
995963
elif any(iterable(position) for position in positions):
@@ -4719,6 +4687,7 @@ def fill(self, *args, **kwargs):
47194687
self.autoscale_view()
47204688
return patches
47214689

4690+
@munits._accepts_units(convert_x=['x'], convert_y=['y1', 'y2'])
47224691
@_preprocess_data(replace_names=["x", "y1", "y2", "where"],
47234692
label_namer=None)
47244693
@docstring.dedent_interpd
@@ -4812,14 +4781,10 @@ def fill_between(self, x, y1, y2=0, where=None, interpolate=False,
48124781
kwargs['facecolor'] = \
48134782
self._get_patches_for_fill.get_next_color()
48144783

4815-
# Handle united data, such as dates
4816-
self._process_unit_info(xdata=x, ydata=y1, kwargs=kwargs)
4817-
self._process_unit_info(ydata=y2)
4818-
48194784
# Convert the arrays so we can work with them
4820-
x = ma.masked_invalid(self.convert_xunits(x))
4821-
y1 = ma.masked_invalid(self.convert_yunits(y1))
4822-
y2 = ma.masked_invalid(self.convert_yunits(y2))
4785+
x = ma.masked_invalid(x)
4786+
y1 = ma.masked_invalid(y1)
4787+
y2 = ma.masked_invalid(y2)
48234788

48244789
for name, array in [('x', x), ('y1', y1), ('y2', y2)]:
48254790
if array.ndim > 1:
@@ -4902,6 +4867,7 @@ def get_interp_point(ind):
49024867
self.autoscale_view()
49034868
return collection
49044869

4870+
@munits._accepts_units(convert_x=['x1', 'x2'], convert_y=['y'])
49054871
@_preprocess_data(replace_names=["y", "x1", "x2", "where"],
49064872
label_namer=None)
49074873
@docstring.dedent_interpd
@@ -4995,14 +4961,10 @@ def fill_betweenx(self, y, x1, x2=0, where=None,
49954961
kwargs['facecolor'] = \
49964962
self._get_patches_for_fill.get_next_color()
49974963

4998-
# Handle united data, such as dates
4999-
self._process_unit_info(ydata=y, xdata=x1, kwargs=kwargs)
5000-
self._process_unit_info(xdata=x2)
5001-
50024964
# Convert the arrays so we can work with them
5003-
y = ma.masked_invalid(self.convert_yunits(y))
5004-
x1 = ma.masked_invalid(self.convert_xunits(x1))
5005-
x2 = ma.masked_invalid(self.convert_xunits(x2))
4965+
y = ma.masked_invalid(y)
4966+
x1 = ma.masked_invalid(x1)
4967+
x2 = ma.masked_invalid(x2)
50064968

50074969
for name, array in [('y', y), ('x1', x1), ('x2', x2)]:
50084970
if array.ndim > 1:

lib/matplotlib/axes/_base.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import matplotlib.font_manager as font_manager
2727
import matplotlib.text as mtext
2828
import matplotlib.image as mimage
29+
import matplotlib.units as munits
2930
from matplotlib.offsetbox import OffsetBox
3031
from matplotlib.artist import allow_rasterization
3132
from matplotlib.legend import Legend
@@ -2993,7 +2994,7 @@ def get_xlim(self):
29932994
"""
29942995
return tuple(self.viewLim.intervalx)
29952996

2996-
def _validate_converted_limits(self, limit, convert):
2997+
def _validate_converted_limits(self, converted_limit):
29972998
"""
29982999
Raise ValueError if converted limits are non-finite.
29993000
@@ -3004,14 +3005,13 @@ def _validate_converted_limits(self, limit, convert):
30043005
The limit value after call to convert(), or None if limit is None.
30053006
30063007
"""
3007-
if limit is not None:
3008-
converted_limit = convert(limit)
3008+
if converted_limit is not None:
30093009
if (isinstance(converted_limit, float) and
30103010
(not np.isreal(converted_limit) or
30113011
not np.isfinite(converted_limit))):
30123012
raise ValueError("Axis limits cannot be NaN or Inf")
3013-
return converted_limit
30143013

3014+
@munits._accepts_units(convert_x=['left', 'right'])
30153015
def set_xlim(self, left=None, right=None, emit=True, auto=False, **kw):
30163016
"""
30173017
Set the data limits for the x-axis
@@ -3079,9 +3079,8 @@ def set_xlim(self, left=None, right=None, emit=True, auto=False, **kw):
30793079
if right is None and iterable(left):
30803080
left, right = left
30813081

3082-
self._process_unit_info(xdata=(left, right))
3083-
left = self._validate_converted_limits(left, self.convert_xunits)
3084-
right = self._validate_converted_limits(right, self.convert_xunits)
3082+
self._validate_converted_limits(left)
3083+
self._validate_converted_limits(right)
30853084

30863085
old_left, old_right = self.get_xlim()
30873086
if left is None:
@@ -3332,6 +3331,7 @@ def get_ylim(self):
33323331
"""
33333332
return tuple(self.viewLim.intervaly)
33343333

3334+
@munits._accepts_units(convert_y=['bottom', 'top'])
33353335
def set_ylim(self, bottom=None, top=None, emit=True, auto=False, **kw):
33363336
"""
33373337
Set the data limits for the y-axis
@@ -3398,8 +3398,8 @@ def set_ylim(self, bottom=None, top=None, emit=True, auto=False, **kw):
33983398
if top is None and iterable(bottom):
33993399
bottom, top = bottom
34003400

3401-
bottom = self._validate_converted_limits(bottom, self.convert_yunits)
3402-
top = self._validate_converted_limits(top, self.convert_yunits)
3401+
self._validate_converted_limits(bottom)
3402+
self._validate_converted_limits(top)
34033403

34043404
old_bottom, old_top = self.get_ylim()
34053405

lib/matplotlib/units.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def default_units(x, axis):
5151
from matplotlib.cbook import iterable, safe_first_element
5252

5353

54-
def _accepts_units(convert_x, convert_y):
54+
def _accepts_units(convert_x=[], convert_y=[]):
5555
"""
5656
A decorator for functions and methods that accept units. The parameters
5757
indicated in *convert_x* and *convert_y* are used to update the axis
@@ -69,6 +69,7 @@ def wrapper(*args, **kwargs):
6969
axes = args[0]
7070
# Bind the incoming arguments to the function signature
7171
bound_args = inspect.signature(func).bind(*args, **kwargs)
72+
bound_args.apply_defaults()
7273
# Get the original arguments - these will be modified later
7374
arguments = bound_args.arguments
7475
# Check for data kwarg

lib/mpl_toolkits/mplot3d/axes3d.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import matplotlib.docstring as docstring
3030
import matplotlib.scale as mscale
3131
import matplotlib.transforms as mtransforms
32+
import matplotlib.units as munits
3233
from matplotlib.axes import Axes, rcParams
3334
from matplotlib.colors import Normalize, LightSource
3435
from matplotlib.transforms import Bbox
@@ -609,6 +610,7 @@ def _determine_lims(self, xmin=None, xmax=None, *args, **kwargs):
609610
xmax += 0.05
610611
return (xmin, xmax)
611612

613+
@munits._accepts_units(convert_x=['left', 'right'])
612614
def set_xlim3d(self, left=None, right=None, emit=True, auto=False, **kw):
613615
"""
614616
Set 3D x limits.
@@ -626,9 +628,8 @@ def set_xlim3d(self, left=None, right=None, emit=True, auto=False, **kw):
626628
if right is None and cbook.iterable(left):
627629
left, right = left
628630

629-
self._process_unit_info(xdata=(left, right))
630-
left = self._validate_converted_limits(left, self.convert_xunits)
631-
right = self._validate_converted_limits(right, self.convert_xunits)
631+
self._validate_converted_limits(left)
632+
self._validate_converted_limits(right)
632633

633634
old_left, old_right = self.get_xlim()
634635
if left is None:
@@ -661,6 +662,7 @@ def set_xlim3d(self, left=None, right=None, emit=True, auto=False, **kw):
661662
return left, right
662663
set_xlim = set_xlim3d
663664

665+
@munits._accepts_units(convert_y=['bottom', 'top'])
664666
def set_ylim3d(self, bottom=None, top=None, emit=True, auto=False, **kw):
665667
"""
666668
Set 3D y limits.
@@ -678,9 +680,8 @@ def set_ylim3d(self, bottom=None, top=None, emit=True, auto=False, **kw):
678680
if top is None and cbook.iterable(bottom):
679681
bottom, top = bottom
680682

681-
self._process_unit_info(ydata=(bottom, top))
682-
bottom = self._validate_converted_limits(bottom, self.convert_yunits)
683-
top = self._validate_converted_limits(top, self.convert_yunits)
683+
self._validate_converted_limits(bottom)
684+
self._validate_converted_limits(top)
684685

685686
old_bottom, old_top = self.get_ylim()
686687
if bottom is None:
@@ -731,8 +732,10 @@ def set_zlim3d(self, bottom=None, top=None, emit=True, auto=False, **kw):
731732
bottom, top = bottom
732733

733734
self._process_unit_info(zdata=(bottom, top))
734-
bottom = self._validate_converted_limits(bottom, self.convert_zunits)
735-
top = self._validate_converted_limits(top, self.convert_zunits)
735+
bottom = self.convert_zunits(bottom)
736+
top = self.convert_zunits(top)
737+
self._validate_converted_limits(bottom)
738+
self._validate_converted_limits(top)
736739

737740
old_bottom, old_top = self.get_zlim()
738741
if bottom is None:

0 commit comments

Comments
 (0)