Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit e3a34c5

Browse files
authored
Merge pull request #9049 from dopplershift/fix-8908
BUG: Fix weird behavior with mask and units (Fixes #8908)
2 parents fe0095e + 0525c6b commit e3a34c5

File tree

8 files changed

+105
-53
lines changed

8 files changed

+105
-53
lines changed

lib/matplotlib/axes/_base.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -556,12 +556,12 @@ def __init__(self, fig, rect,
556556
self.update(kwargs)
557557

558558
if self.xaxis is not None:
559-
self._xcid = self.xaxis.callbacks.connect('units finalize',
560-
self.relim)
559+
self._xcid = self.xaxis.callbacks.connect(
560+
'units finalize', lambda: self._on_units_changed(scalex=True))
561561

562562
if self.yaxis is not None:
563-
self._ycid = self.yaxis.callbacks.connect('units finalize',
564-
self.relim)
563+
self._ycid = self.yaxis.callbacks.connect(
564+
'units finalize', lambda: self._on_units_changed(scaley=True))
565565

566566
self.tick_params(
567567
top=rcParams['xtick.top'] and rcParams['xtick.minor.top'],
@@ -1891,6 +1891,15 @@ def add_container(self, container):
18911891
container.set_remove_method(lambda h: self.containers.remove(h))
18921892
return container
18931893

1894+
def _on_units_changed(self, scalex=False, scaley=False):
1895+
"""
1896+
Callback for processing changes to axis units.
1897+
1898+
Currently forces updates of data limits and view limits.
1899+
"""
1900+
self.relim()
1901+
self.autoscale_view(scalex=scalex, scaley=scaley)
1902+
18941903
def relim(self, visible_only=False):
18951904
"""
18961905
Recompute the data limits based on current artists. If you want to

lib/matplotlib/cbook/__init__.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1997,6 +1997,17 @@ def is_math_text(s):
19971997
return even_dollars
19981998

19991999

2000+
def _to_unmasked_float_array(x):
2001+
"""
2002+
Convert a sequence to a float array; if input was a masked array, masked
2003+
values are converted to nans.
2004+
"""
2005+
if hasattr(x, 'mask'):
2006+
return np.ma.asarray(x, float).filled(np.nan)
2007+
else:
2008+
return np.asarray(x, float)
2009+
2010+
20002011
def _check_1d(x):
20012012
'''
20022013
Converts a sequence of less than 1 dimension, to an array of 1
@@ -2283,7 +2294,7 @@ def index_of(y):
22832294
try:
22842295
return y.index.values, y.values
22852296
except AttributeError:
2286-
y = np.atleast_1d(y)
2297+
y = _check_1d(y)
22872298
return np.arange(y.shape[0], dtype=float), y
22882299

22892300

lib/matplotlib/lines.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
from . import artist, colors as mcolors, docstring, rcParams
1717
from .artist import Artist, allow_rasterization
1818
from .cbook import (
19-
iterable, is_numlike, ls_mapper, ls_mapper_r, STEP_LOOKUP_MAP)
19+
_to_unmasked_float_array, iterable, is_numlike, ls_mapper, ls_mapper_r,
20+
STEP_LOOKUP_MAP)
2021
from .markers import MarkerStyle
2122
from .path import Path
2223
from .transforms import Bbox, TransformedPath, IdentityTransform
@@ -648,20 +649,12 @@ def recache_always(self):
648649
def recache(self, always=False):
649650
if always or self._invalidx:
650651
xconv = self.convert_xunits(self._xorig)
651-
if isinstance(self._xorig, np.ma.MaskedArray):
652-
x = np.ma.asarray(xconv, float).filled(np.nan)
653-
else:
654-
x = np.asarray(xconv, float)
655-
x = x.ravel()
652+
x = _to_unmasked_float_array(xconv).ravel()
656653
else:
657654
x = self._x
658655
if always or self._invalidy:
659656
yconv = self.convert_yunits(self._yorig)
660-
if isinstance(self._yorig, np.ma.MaskedArray):
661-
y = np.ma.asarray(yconv, float).filled(np.nan)
662-
else:
663-
y = np.asarray(yconv, float)
664-
y = y.ravel()
657+
y = _to_unmasked_float_array(yconv).ravel()
665658
else:
666659
y = self._y
667660

lib/matplotlib/path.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@
2323
import numpy as np
2424

2525
from . import _path, rcParams
26-
from .cbook import simple_linear_interpolation, maxdict
26+
from .cbook import (_to_unmasked_float_array, simple_linear_interpolation,
27+
maxdict)
2728

2829

2930
class Path(object):
@@ -132,11 +133,7 @@ def __init__(self, vertices, codes=None, _interpolation_steps=1,
132133
Makes the path behave in an immutable way and sets the vertices
133134
and codes as read-only arrays.
134135
"""
135-
if isinstance(vertices, np.ma.MaskedArray):
136-
vertices = vertices.astype(float).filled(np.nan)
137-
else:
138-
vertices = np.asarray(vertices, float)
139-
136+
vertices = _to_unmasked_float_array(vertices)
140137
if (vertices.ndim != 2) or (vertices.shape[1] != 2):
141138
msg = "'vertices' must be a 2D list or array with shape Nx2"
142139
raise ValueError(msg)
@@ -188,11 +185,7 @@ def _fast_from_codes_and_verts(cls, verts, codes, internals=None):
188185
"""
189186
internals = internals or {}
190187
pth = cls.__new__(cls)
191-
if isinstance(verts, np.ma.MaskedArray):
192-
verts = verts.astype(float).filled(np.nan)
193-
else:
194-
verts = np.asarray(verts, float)
195-
pth._vertices = verts
188+
pth._vertices = _to_unmasked_float_array(verts)
196189
pth._codes = codes
197190
pth._readonly = internals.pop('readonly', False)
198191
pth.should_simplify = internals.pop('should_simplify', True)

lib/matplotlib/tests/test_units.py

Lines changed: 58 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
from matplotlib.cbook import iterable
12
import matplotlib.pyplot as plt
3+
from matplotlib.testing.decorators import image_comparison
24
import matplotlib.units as munits
35
import numpy as np
46

@@ -9,49 +11,84 @@
911
from mock import MagicMock
1012

1113

12-
# Tests that the conversion machinery works properly for classes that
13-
# work as a facade over numpy arrays (like pint)
14-
def test_numpy_facade():
15-
# Basic class that wraps numpy array and has units
16-
class Quantity(object):
17-
def __init__(self, data, units):
18-
self.magnitude = data
19-
self.units = units
14+
# Basic class that wraps numpy array and has units
15+
class Quantity(object):
16+
def __init__(self, data, units):
17+
self.magnitude = data
18+
self.units = units
19+
20+
def to(self, new_units):
21+
factors = {('hours', 'seconds'): 3600, ('minutes', 'hours'): 1 / 60,
22+
('minutes', 'seconds'): 60, ('feet', 'miles'): 1 / 5280.,
23+
('feet', 'inches'): 12, ('miles', 'inches'): 12 * 5280}
24+
if self.units != new_units:
25+
mult = factors[self.units, new_units]
26+
return Quantity(mult * self.magnitude, new_units)
27+
else:
28+
return Quantity(self.magnitude, self.units)
29+
30+
def __getattr__(self, attr):
31+
return getattr(self.magnitude, attr)
2032

21-
def to(self, new_units):
22-
return Quantity(self.magnitude, new_units)
33+
def __getitem__(self, item):
34+
return Quantity(self.magnitude[item], self.units)
2335

24-
def __getattr__(self, attr):
25-
return getattr(self.magnitude, attr)
36+
def __array__(self):
37+
return np.asarray(self.magnitude)
2638

27-
def __getitem__(self, item):
28-
return self.magnitude[item]
2939

40+
# Tests that the conversion machinery works properly for classes that
41+
# work as a facade over numpy arrays (like pint)
42+
@image_comparison(baseline_images=['plot_pint'],
43+
extensions=['png'], remove_text=False, style='mpl20')
44+
def test_numpy_facade():
3045
# Create an instance of the conversion interface and
3146
# mock so we can check methods called
3247
qc = munits.ConversionInterface()
3348

3449
def convert(value, unit, axis):
3550
if hasattr(value, 'units'):
36-
return value.to(unit)
51+
return value.to(unit).magnitude
52+
elif iterable(value):
53+
try:
54+
return [v.to(unit).magnitude for v in value]
55+
except AttributeError:
56+
return [Quantity(v, axis.get_units()).to(unit).magnitude
57+
for v in value]
3758
else:
3859
return Quantity(value, axis.get_units()).to(unit).magnitude
3960

4061
qc.convert = MagicMock(side_effect=convert)
41-
qc.axisinfo = MagicMock(return_value=None)
62+
qc.axisinfo = MagicMock(side_effect=lambda u, a: munits.AxisInfo(label=u))
4263
qc.default_units = MagicMock(side_effect=lambda x, a: x.units)
4364

4465
# Register the class
4566
munits.registry[Quantity] = qc
4667

4768
# Simple test
48-
t = Quantity(np.linspace(0, 10), 'sec')
49-
d = Quantity(30 * np.linspace(0, 10), 'm/s')
69+
y = Quantity(np.linspace(0, 30), 'miles')
70+
x = Quantity(np.linspace(0, 5), 'hours')
5071

51-
fig, ax = plt.subplots(1, 1)
52-
l, = plt.plot(t, d)
53-
ax.yaxis.set_units('inch')
72+
fig, ax = plt.subplots()
73+
fig.subplots_adjust(left=0.15) # Make space for label
74+
ax.plot(x, y, 'tab:blue')
75+
ax.axhline(Quantity(26400, 'feet'), color='tab:red')
76+
ax.axvline(Quantity(120, 'minutes'), color='tab:green')
77+
ax.yaxis.set_units('inches')
78+
ax.xaxis.set_units('seconds')
5479

5580
assert qc.convert.called
5681
assert qc.axisinfo.called
5782
assert qc.default_units.called
83+
84+
85+
# Tests gh-8908
86+
@image_comparison(baseline_images=['plot_masked_units'],
87+
extensions=['png'], remove_text=True, style='mpl20')
88+
def test_plot_masked_units():
89+
data = np.linspace(-5, 5)
90+
data_masked = np.ma.array(data, mask=(data > -2) & (data < 2))
91+
data_masked_units = Quantity(data_masked, 'meters')
92+
93+
fig, ax = plt.subplots()
94+
ax.plot(data_masked_units)

lib/mpl_toolkits/mplot3d/axes3d.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -111,13 +111,13 @@ def __init__(self, fig, rect=None, *args, **kwargs):
111111
# func used to format z -- fall back on major formatters
112112
self.fmt_zdata = None
113113

114-
if zscale is not None :
114+
if zscale is not None:
115115
self.set_zscale(zscale)
116116

117-
if self.zaxis is not None :
118-
self._zcid = self.zaxis.callbacks.connect('units finalize',
119-
self.relim)
120-
else :
117+
if self.zaxis is not None:
118+
self._zcid = self.zaxis.callbacks.connect(
119+
'units finalize', lambda: self._on_units_changed(scalez=True))
120+
else:
121121
self._zcid = None
122122

123123
self._ready = 1
@@ -308,6 +308,15 @@ def get_axis_position(self):
308308
zhigh = tc[0][2] > tc[2][2]
309309
return xhigh, yhigh, zhigh
310310

311+
def _on_units_changed(self, scalex=False, scaley=False, scalez=False):
312+
"""
313+
Callback for processing changes to axis units.
314+
315+
Currently forces updates of data limits and view limits.
316+
"""
317+
self.relim()
318+
self.autoscale_view(scalex=scalex, scaley=scaley, scalez=scalez)
319+
311320
def update_datalim(self, xys, **kwargs):
312321
pass
313322

0 commit comments

Comments
 (0)