Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit ba43765

Browse files
authored
Merge pull request #18531 from anntzer/units
Unit handling improvements
2 parents 1e19aa8 + 5182c6f commit ba43765

File tree

4 files changed

+107
-155
lines changed

4 files changed

+107
-155
lines changed

lib/matplotlib/axes/_axes.py

Lines changed: 48 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -708,10 +708,8 @@ def axhline(self, y=0, xmin=0, xmax=1, **kwargs):
708708
"argument; axhline generates its own transform.")
709709
ymin, ymax = self.get_ybound()
710710

711-
# We need to strip away the units for comparison with
712-
# non-unitized bounds
713-
self._process_unit_info(ydata=y, kwargs=kwargs)
714-
yy = self.convert_yunits(y)
711+
# Strip away the units for comparison with non-unitized bounds.
712+
yy, = self._process_unit_info([("y", y)], kwargs)
715713
scaley = (yy < ymin) or (yy > ymax)
716714

717715
trans = self.get_yaxis_transform(which='grid')
@@ -777,10 +775,8 @@ def axvline(self, x=0, ymin=0, ymax=1, **kwargs):
777775
"argument; axvline generates its own transform.")
778776
xmin, xmax = self.get_xbound()
779777

780-
# We need to strip away the units for comparison with
781-
# non-unitized bounds
782-
self._process_unit_info(xdata=x, kwargs=kwargs)
783-
xx = self.convert_xunits(x)
778+
# Strip away the units for comparison with non-unitized bounds.
779+
xx, = self._process_unit_info([("x", x)], kwargs)
784780
scalex = (xx < xmin) or (xx > xmax)
785781

786782
trans = self.get_xaxis_transform(which='grid')
@@ -917,19 +913,13 @@ def axhspan(self, ymin, ymax, xmin=0, xmax=1, **kwargs):
917913
--------
918914
axvspan : Add a vertical span across the axes.
919915
"""
916+
# Strip units away.
920917
self._check_no_units([xmin, xmax], ['xmin', 'xmax'])
921-
trans = self.get_yaxis_transform(which='grid')
922-
923-
# process the unit information
924-
self._process_unit_info([xmin, xmax], [ymin, ymax], kwargs=kwargs)
925-
926-
# first we need to strip away the units
927-
xmin, xmax = self.convert_xunits([xmin, xmax])
928-
ymin, ymax = self.convert_yunits([ymin, ymax])
918+
(ymin, ymax), = self._process_unit_info([("y", [ymin, ymax])], kwargs)
929919

930920
verts = (xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin)
931921
p = mpatches.Polygon(verts, **kwargs)
932-
p.set_transform(trans)
922+
p.set_transform(self.get_yaxis_transform(which="grid"))
933923
self.add_patch(p)
934924
self._request_autoscale_view(scalex=False)
935925
return p
@@ -978,19 +968,13 @@ def axvspan(self, xmin, xmax, ymin=0, ymax=1, **kwargs):
978968
>>> axvspan(1.25, 1.55, facecolor='g', alpha=0.5)
979969
980970
"""
971+
# Strip units away.
981972
self._check_no_units([ymin, ymax], ['ymin', 'ymax'])
982-
trans = self.get_xaxis_transform(which='grid')
983-
984-
# process the unit information
985-
self._process_unit_info([xmin, xmax], [ymin, ymax], kwargs=kwargs)
986-
987-
# first we need to strip away the units
988-
xmin, xmax = self.convert_xunits([xmin, xmax])
989-
ymin, ymax = self.convert_yunits([ymin, ymax])
973+
(xmin, xmax), = self._process_unit_info([("x", [xmin, xmax])], kwargs)
990974

991975
verts = [(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin)]
992976
p = mpatches.Polygon(verts, **kwargs)
993-
p.set_transform(trans)
977+
p.set_transform(self.get_xaxis_transform(which="grid"))
994978
self.add_patch(p)
995979
self._request_autoscale_view(scaley=False)
996980
return p
@@ -1032,11 +1016,8 @@ def hlines(self, y, xmin, xmax, colors=None, linestyles='solid',
10321016
"""
10331017

10341018
# We do the conversion first since not all unitized data is uniform
1035-
# process the unit information
1036-
self._process_unit_info([xmin, xmax], y, kwargs=kwargs)
1037-
y = self.convert_yunits(y)
1038-
xmin = self.convert_xunits(xmin)
1039-
xmax = self.convert_xunits(xmax)
1019+
xmin, xmax, y = self._process_unit_info(
1020+
[("x", xmin), ("x", xmax), ("y", y)], kwargs)
10401021

10411022
if not np.iterable(y):
10421023
y = [y]
@@ -1111,12 +1092,9 @@ def vlines(self, x, ymin, ymax, colors=None, linestyles='solid',
11111092
axvline: vertical line across the axes
11121093
"""
11131094

1114-
self._process_unit_info(xdata=x, ydata=[ymin, ymax], kwargs=kwargs)
1115-
11161095
# We do the conversion first since not all unitized data is uniform
1117-
x = self.convert_xunits(x)
1118-
ymin = self.convert_yunits(ymin)
1119-
ymax = self.convert_yunits(ymax)
1096+
x, ymin, ymax = self._process_unit_info(
1097+
[("x", x), ("y", ymin), ("y", ymax)], kwargs)
11201098

11211099
if not np.iterable(x):
11221100
x = [x]
@@ -1254,14 +1232,9 @@ def eventplot(self, positions, orientation='horizontal', lineoffsets=1,
12541232
--------
12551233
.. plot:: gallery/lines_bars_and_markers/eventplot_demo.py
12561234
"""
1257-
self._process_unit_info(xdata=positions,
1258-
ydata=[lineoffsets, linelengths],
1259-
kwargs=kwargs)
1260-
12611235
# We do the conversion first since not all unitized data is uniform
1262-
positions = self.convert_xunits(positions)
1263-
lineoffsets = self.convert_yunits(lineoffsets)
1264-
linelengths = self.convert_yunits(linelengths)
1236+
positions, lineoffsets, linelengths = self._process_unit_info(
1237+
[("x", positions), ("y", lineoffsets), ("y", linelengths)], kwargs)
12651238

12661239
if not np.iterable(positions):
12671240
positions = [positions]
@@ -2283,11 +2256,13 @@ def bar(self, x, height, width=0.8, bottom=None, *, align="center",
22832256
x = 0
22842257

22852258
if orientation == 'vertical':
2286-
self._process_unit_info(xdata=x, ydata=height, kwargs=kwargs)
2259+
self._process_unit_info(
2260+
[("x", x), ("y", height)], kwargs, convert=False)
22872261
if log:
22882262
self.set_yscale('log', nonpositive='clip')
22892263
elif orientation == 'horizontal':
2290-
self._process_unit_info(xdata=width, ydata=y, kwargs=kwargs)
2264+
self._process_unit_info(
2265+
[("x", width), ("y", y)], kwargs, convert=False)
22912266
if log:
22922267
self.set_xscale('log', nonpositive='clip')
22932268

@@ -2567,9 +2542,8 @@ def broken_barh(self, xranges, yrange, **kwargs):
25672542
ydata = cbook.safe_first_element(yrange)
25682543
else:
25692544
ydata = None
2570-
self._process_unit_info(xdata=xdata,
2571-
ydata=ydata,
2572-
kwargs=kwargs)
2545+
self._process_unit_info(
2546+
[("x", xdata), ("y", ydata)], kwargs, convert=False)
25732547
xranges_conv = []
25742548
for xr in xranges:
25752549
if len(xr) != 2:
@@ -2689,13 +2663,9 @@ def stem(self, *args, linefmt=None, markerfmt=None, basefmt=None, bottom=0,
26892663
locs, heads, *args = args
26902664

26912665
if orientation == 'vertical':
2692-
self._process_unit_info(xdata=locs, ydata=heads)
2693-
locs = self.convert_xunits(locs)
2694-
heads = self.convert_yunits(heads)
2666+
locs, heads = self._process_unit_info([("x", locs), ("y", heads)])
26952667
else:
2696-
self._process_unit_info(xdata=heads, ydata=locs)
2697-
heads = self.convert_xunits(heads)
2698-
locs = self.convert_yunits(locs)
2668+
heads, locs = self._process_unit_info([("x", heads), ("y", locs)])
26992669

27002670
# defaults for formats
27012671
if linefmt is None:
@@ -3179,7 +3149,7 @@ def errorbar(self, x, y, yerr=None, xerr=None,
31793149
if int(offset) != offset:
31803150
raise ValueError("errorevery's starting index must be an integer")
31813151

3182-
self._process_unit_info(xdata=x, ydata=y, kwargs=kwargs)
3152+
self._process_unit_info([("x", x), ("y", y)], kwargs, convert=False)
31833153

31843154
# Make sure all the args are iterable; use lists not arrays to preserve
31853155
# units.
@@ -4346,9 +4316,7 @@ def scatter(self, x, y, s=None, c=None, marker=None, cmap=None, norm=None,
43464316
"""
43474317
# Process **kwargs to handle aliases, conflicts with explicit kwargs:
43484318

4349-
self._process_unit_info(xdata=x, ydata=y, kwargs=kwargs)
4350-
x = self.convert_xunits(x)
4351-
y = self.convert_yunits(y)
4319+
x, y = self._process_unit_info([("x", x), ("y", y)], kwargs)
43524320

43534321
# np.ma.ravel yields an ndarray, not a masked array,
43544322
# unless its argument is a masked array.
@@ -4577,7 +4545,7 @@ def reduce_C_function(C: array) -> float
45774545
%(PolyCollection)s
45784546
45794547
"""
4580-
self._process_unit_info(xdata=x, ydata=y, kwargs=kwargs)
4548+
self._process_unit_info([("x", x), ("y", y)], kwargs, convert=False)
45814549

45824550
x, y, C = cbook.delete_masked_points(x, y, C)
45834551

@@ -4926,9 +4894,7 @@ def quiverkey(self, Q, X, Y, U, label, **kw):
49264894
def _quiver_units(self, args, kw):
49274895
if len(args) > 3:
49284896
x, y = args[0:2]
4929-
self._process_unit_info(xdata=x, ydata=y, kwargs=kw)
4930-
x = self.convert_xunits(x)
4931-
y = self.convert_yunits(y)
4897+
x, y = self._process_unit_info([("x", x), ("y", y)], kw)
49324898
return (x, y) + args[2:]
49334899
return args
49344900

@@ -5114,17 +5080,9 @@ def _fill_between_x_or_y(
51145080
self._get_patches_for_fill.get_next_color()
51155081

51165082
# Handle united data, such as dates
5117-
self._process_unit_info(
5118-
**{f"{ind_dir}data": ind, f"{dep_dir}data": dep1}, kwargs=kwargs)
5119-
self._process_unit_info(
5120-
**{f"{dep_dir}data": dep2})
5121-
5122-
# Convert the arrays so we can work with them
5123-
ind = ma.masked_invalid(getattr(self, f"convert_{ind_dir}units")(ind))
5124-
dep1 = ma.masked_invalid(
5125-
getattr(self, f"convert_{dep_dir}units")(dep1))
5126-
dep2 = ma.masked_invalid(
5127-
getattr(self, f"convert_{dep_dir}units")(dep2))
5083+
ind, dep1, dep2 = map(
5084+
ma.masked_invalid, self._process_unit_info(
5085+
[(ind_dir, ind), (dep_dir, dep1), (dep_dir, dep2)], kwargs))
51285086

51295087
for name, array in [
51305088
(ind_dir, ind), (f"{dep_dir}1", dep1), (f"{dep_dir}2", dep2)]:
@@ -5739,9 +5697,7 @@ def pcolor(self, *args, shading=None, alpha=None, norm=None, cmap=None,
57395697
Ny, Nx = X.shape
57405698

57415699
# unit conversion allows e.g. datetime objects as axis values
5742-
self._process_unit_info(xdata=X, ydata=Y, kwargs=kwargs)
5743-
X = self.convert_xunits(X)
5744-
Y = self.convert_yunits(Y)
5700+
X, Y = self._process_unit_info([("x", X), ("y", Y)], kwargs)
57455701

57465702
# convert to MA, if necessary.
57475703
C = ma.asarray(C)
@@ -6016,9 +5972,7 @@ def pcolormesh(self, *args, alpha=None, norm=None, cmap=None, vmin=None,
60165972
X = X.ravel()
60175973
Y = Y.ravel()
60185974
# unit conversion allows e.g. datetime objects as axis values
6019-
self._process_unit_info(xdata=X, ydata=Y, kwargs=kwargs)
6020-
X = self.convert_xunits(X)
6021-
Y = self.convert_yunits(Y)
5975+
X, Y = self._process_unit_info([("x", X), ("y", Y)], kwargs)
60225976

60235977
# convert to one dimensional arrays
60245978
C = C.ravel()
@@ -6497,16 +6451,23 @@ def hist(self, x, bins=None, range=None, density=False, weights=None,
64976451
x = cbook._reshape_2D(x, 'x')
64986452
nx = len(x) # number of datasets
64996453

6500-
# Process unit information
6501-
# Unit conversion is done individually on each dataset
6502-
self._process_unit_info(xdata=x[0], kwargs=kwargs)
6503-
x = [self.convert_xunits(xi) for xi in x]
6454+
# Process unit information. _process_unit_info sets the unit and
6455+
# converts the first dataset; then we convert each following dataset
6456+
# one at a time.
6457+
if orientation == "vertical":
6458+
convert_units = self.convert_xunits
6459+
x = [*self._process_unit_info([("x", x[0])], kwargs),
6460+
*map(convert_units, x[1:])]
6461+
else: # horizontal
6462+
convert_units = self.convert_yunits
6463+
x = [*self._process_unit_info([("y", x[0])], kwargs),
6464+
*map(convert_units, x[1:])]
65046465

65056466
if bin_range is not None:
6506-
bin_range = self.convert_xunits(bin_range)
6467+
bin_range = convert_units(bin_range)
65076468

65086469
if not cbook.is_scalar_or_string(bins):
6509-
bins = self.convert_xunits(bins)
6470+
bins = convert_units(bins)
65106471

65116472
# We need to do to 'weights' what was done to 'x'
65126473
if weights is not None:
@@ -6787,9 +6748,8 @@ def stairs(self, values, edges=None, *,
67876748
if edges is None:
67886749
edges = np.arange(len(values) + 1)
67896750

6790-
self._process_unit_info(xdata=edges, ydata=values, kwargs=kwargs)
6791-
edges = self.convert_xunits(edges)
6792-
values = self.convert_yunits(values)
6751+
edges, values = self._process_unit_info(
6752+
[("x", edges), ("y", values)], kwargs)
67936753

67946754
patch = mpatches.StepPatch(values,
67956755
edges,

0 commit comments

Comments
 (0)