Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 11072f2

Browse files
authored
Merge pull request #12422 from efiring/scatter_color
FIX/API: scatter with invalid data
2 parents 739b86a + a2cec14 commit 11072f2

File tree

9 files changed

+174
-47
lines changed

9 files changed

+174
-47
lines changed
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
PathCollections created with `~.Axes.scatter` now keep track of invalid points
2+
``````````````````````````````````````````````````````````````````````````````
3+
4+
Previously, points with nonfinite (infinite or nan) coordinates would not be
5+
included in the offsets (as returned by `PathCollection.get_offsets`) of a
6+
`PathCollection` created by `~.Axes.scatter`, and points with nonfinite values
7+
(as specified by the *c* kwarg) would not be included in the array (as returned
8+
by `PathCollection.get_array`)
9+
10+
Such points are now included, but masked out by returning a masked array.
11+
12+
If the *plotnonfinite* kwarg to `~.Axes.scatter` is set, then points with
13+
nonfinite values are plotted using the bad color of the `PathCollection`\ 's
14+
colormap (as set by `Colormap.set_bad`).

examples/units/basic_units.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,10 @@ def get_compressed_copy(self, mask):
174174
def convert_to(self, unit):
175175
if unit == self.unit or not unit:
176176
return self
177-
new_value = self.unit.convert_value_to(self.value, unit)
177+
try:
178+
new_value = self.unit.convert_value_to(self.value, unit)
179+
except AttributeError:
180+
new_value = self
178181
return TaggedValue(new_value, unit)
179182

180183
def get_value(self):
@@ -345,7 +348,20 @@ def convert(val, unit, axis):
345348
if units.ConversionInterface.is_numlike(val):
346349
return val
347350
if np.iterable(val):
348-
return [thisval.convert_to(unit).get_value() for thisval in val]
351+
if isinstance(val, np.ma.MaskedArray):
352+
val = val.astype(float).filled(np.nan)
353+
out = np.empty(len(val))
354+
for i, thisval in enumerate(val):
355+
if np.ma.is_masked(thisval):
356+
out[i] = np.nan
357+
else:
358+
try:
359+
out[i] = thisval.convert_to(unit).get_value()
360+
except AttributeError:
361+
out[i] = thisval
362+
return out
363+
if np.ma.is_masked(val):
364+
return np.nan
349365
else:
350366
return val.convert_to(unit).get_value()
351367

examples/units/units_scatter.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,8 @@
2727
ax2.scatter(xsecs, xsecs, yunits=hertz)
2828
ax2.axis([0, 10, 0, 1])
2929

30-
ax3.scatter(xsecs, xsecs, yunits=hertz)
31-
ax3.yaxis.set_units(minutes)
32-
ax3.axis([0, 10, 0, 1])
30+
ax3.scatter(xsecs, xsecs, yunits=minutes)
31+
ax3.axis([0, 10, 0, 0.2])
3332

3433
fig.tight_layout()
3534
plt.show()

lib/matplotlib/axes/_axes.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4180,7 +4180,7 @@ def _parse_scatter_color_args(c, edgecolors, kwargs, xshape, yshape,
41804180
label_namer="y")
41814181
def scatter(self, x, y, s=None, c=None, marker=None, cmap=None, norm=None,
41824182
vmin=None, vmax=None, alpha=None, linewidths=None,
4183-
verts=None, edgecolors=None,
4183+
verts=None, edgecolors=None, *, plotnonfinite=False,
41844184
**kwargs):
41854185
"""
41864186
A scatter plot of *y* vs *x* with varying marker size and/or color.
@@ -4257,6 +4257,10 @@ def scatter(self, x, y, s=None, c=None, marker=None, cmap=None, norm=None,
42574257
For non-filled markers, the *edgecolors* kwarg is ignored and
42584258
forced to 'face' internally.
42594259
4260+
plotnonfinite : boolean, optional, default: False
4261+
Set to plot points with nonfinite *c*, in conjunction with
4262+
`~matplotlib.colors.Colormap.set_bad`.
4263+
42604264
Returns
42614265
-------
42624266
paths : `~matplotlib.collections.PathCollection`
@@ -4310,11 +4314,14 @@ def scatter(self, x, y, s=None, c=None, marker=None, cmap=None, norm=None,
43104314
c, edgecolors, kwargs, xshape, yshape,
43114315
get_next_color_func=self._get_patches_for_fill.get_next_color)
43124316

4313-
# `delete_masked_points` only modifies arguments of the same length as
4314-
# `x`.
4315-
x, y, s, c, colors, edgecolors, linewidths =\
4316-
cbook.delete_masked_points(
4317-
x, y, s, c, colors, edgecolors, linewidths)
4317+
if plotnonfinite and colors is None:
4318+
c = np.ma.masked_invalid(c)
4319+
x, y, s, edgecolors, linewidths = \
4320+
cbook._combine_masks(x, y, s, edgecolors, linewidths)
4321+
else:
4322+
x, y, s, c, colors, edgecolors, linewidths = \
4323+
cbook._combine_masks(
4324+
x, y, s, c, colors, edgecolors, linewidths)
43184325

43194326
scales = s # Renamed for readability below.
43204327

@@ -4340,7 +4347,7 @@ def scatter(self, x, y, s=None, c=None, marker=None, cmap=None, norm=None,
43404347
edgecolors = 'face'
43414348
linewidths = rcParams['lines.linewidth']
43424349

4343-
offsets = np.column_stack([x, y])
4350+
offsets = np.ma.column_stack([x, y])
43444351

43454352
collection = mcoll.PathCollection(
43464353
(path,), scales,
@@ -4358,7 +4365,7 @@ def scatter(self, x, y, s=None, c=None, marker=None, cmap=None, norm=None,
43584365
if norm is not None and not isinstance(norm, mcolors.Normalize):
43594366
raise ValueError(
43604367
"'norm' must be an instance of 'mcolors.Normalize'")
4361-
collection.set_array(np.asarray(c))
4368+
collection.set_array(c)
43624369
collection.set_cmap(cmap)
43634370
collection.set_norm(norm)
43644371

lib/matplotlib/cbook/__init__.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1081,6 +1081,66 @@ def delete_masked_points(*args):
10811081
return margs
10821082

10831083

1084+
def _combine_masks(*args):
1085+
"""
1086+
Find all masked and/or non-finite points in a set of arguments,
1087+
and return the arguments as masked arrays with a common mask.
1088+
1089+
Arguments can be in any of 5 categories:
1090+
1091+
1) 1-D masked arrays
1092+
2) 1-D ndarrays
1093+
3) ndarrays with more than one dimension
1094+
4) other non-string iterables
1095+
5) anything else
1096+
1097+
The first argument must be in one of the first four categories;
1098+
any argument with a length differing from that of the first
1099+
argument (and hence anything in category 5) then will be
1100+
passed through unchanged.
1101+
1102+
Masks are obtained from all arguments of the correct length
1103+
in categories 1, 2, and 4; a point is bad if masked in a masked
1104+
array or if it is a nan or inf. No attempt is made to
1105+
extract a mask from categories 2 and 4 if :meth:`np.isfinite`
1106+
does not yield a Boolean array. Category 3 is included to
1107+
support RGB or RGBA ndarrays, which are assumed to have only
1108+
valid values and which are passed through unchanged.
1109+
1110+
All input arguments that are not passed unchanged are returned
1111+
as masked arrays if any masked points are found, otherwise as
1112+
ndarrays.
1113+
1114+
"""
1115+
if not len(args):
1116+
return ()
1117+
if is_scalar_or_string(args[0]):
1118+
raise ValueError("First argument must be a sequence")
1119+
nrecs = len(args[0])
1120+
margs = [] # Output args; some may be modified.
1121+
seqlist = [False] * len(args) # Flags: True if output will be masked.
1122+
masks = [] # List of masks.
1123+
for i, x in enumerate(args):
1124+
if is_scalar_or_string(x) or len(x) != nrecs:
1125+
margs.append(x) # Leave it unmodified.
1126+
else:
1127+
if isinstance(x, np.ma.MaskedArray) and x.ndim > 1:
1128+
raise ValueError("Masked arrays must be 1-D")
1129+
x = np.asanyarray(x)
1130+
if x.ndim == 1:
1131+
x = safe_masked_invalid(x)
1132+
seqlist[i] = True
1133+
if np.ma.is_masked(x):
1134+
masks.append(np.ma.getmaskarray(x))
1135+
margs.append(x) # Possibly modified.
1136+
if len(masks):
1137+
mask = np.logical_or.reduce(masks)
1138+
for i, x in enumerate(margs):
1139+
if seqlist[i]:
1140+
margs[i] = np.ma.array(x, mask=mask)
1141+
return margs
1142+
1143+
10841144
def boxplot_stats(X, whis=1.5, bootstrap=None, labels=None,
10851145
autorange=False):
10861146
"""

lib/matplotlib/pyplot.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2835,12 +2835,13 @@ def quiverkey(Q, X, Y, U, label, **kw):
28352835
def scatter(
28362836
x, y, s=None, c=None, marker=None, cmap=None, norm=None,
28372837
vmin=None, vmax=None, alpha=None, linewidths=None, verts=None,
2838-
edgecolors=None, *, data=None, **kwargs):
2838+
edgecolors=None, *, plotnonfinite=False, data=None, **kwargs):
28392839
__ret = gca().scatter(
28402840
x, y, s=s, c=c, marker=marker, cmap=cmap, norm=norm,
28412841
vmin=vmin, vmax=vmax, alpha=alpha, linewidths=linewidths,
2842-
verts=verts, edgecolors=edgecolors, **({"data": data} if data
2843-
is not None else {}), **kwargs)
2842+
verts=verts, edgecolors=edgecolors,
2843+
plotnonfinite=plotnonfinite, **({"data": data} if data is not
2844+
None else {}), **kwargs)
28442845
sci(__ret)
28452846
return __ret
28462847

lib/matplotlib/testing/decorators.py

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -452,19 +452,37 @@ def decorator(func):
452452

453453
_, result_dir = map(Path, _image_directories(func))
454454

455-
@pytest.mark.parametrize("ext", extensions)
456-
def wrapper(ext):
457-
fig_test = plt.figure("test")
458-
fig_ref = plt.figure("reference")
459-
func(fig_test, fig_ref)
460-
test_image_path = str(
461-
result_dir / (func.__name__ + "." + ext))
462-
ref_image_path = str(
463-
result_dir / (func.__name__ + "-expected." + ext))
464-
fig_test.savefig(test_image_path)
465-
fig_ref.savefig(ref_image_path)
466-
_raise_on_image_difference(
467-
ref_image_path, test_image_path, tol=tol)
455+
if len(inspect.signature(func).parameters) == 2:
456+
# Free-standing function.
457+
@pytest.mark.parametrize("ext", extensions)
458+
def wrapper(ext):
459+
fig_test = plt.figure("test")
460+
fig_ref = plt.figure("reference")
461+
func(fig_test, fig_ref)
462+
test_image_path = str(
463+
result_dir / (func.__name__ + "." + ext))
464+
ref_image_path = str(
465+
result_dir / (func.__name__ + "-expected." + ext))
466+
fig_test.savefig(test_image_path)
467+
fig_ref.savefig(ref_image_path)
468+
_raise_on_image_difference(
469+
ref_image_path, test_image_path, tol=tol)
470+
471+
elif len(inspect.signature(func).parameters) == 3:
472+
# Method.
473+
@pytest.mark.parametrize("ext", extensions)
474+
def wrapper(self, ext):
475+
fig_test = plt.figure("test")
476+
fig_ref = plt.figure("reference")
477+
func(self, fig_test, fig_ref)
478+
test_image_path = str(
479+
result_dir / (func.__name__ + "." + ext))
480+
ref_image_path = str(
481+
result_dir / (func.__name__ + "-expected." + ext))
482+
fig_test.savefig(test_image_path)
483+
fig_ref.savefig(ref_image_path)
484+
_raise_on_image_difference(
485+
ref_image_path, test_image_path, tol=tol)
468486

469487
return wrapper
470488

lib/matplotlib/tests/test_axes.py

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1749,6 +1749,34 @@ def test_scatter_color(self):
17491749
with pytest.raises(ValueError):
17501750
plt.scatter([1, 2, 3], [1, 2, 3], color=[1, 2, 3])
17511751

1752+
@check_figures_equal(extensions=["png"])
1753+
def test_scatter_invalid_color(self, fig_test, fig_ref):
1754+
ax = fig_test.subplots()
1755+
cmap = plt.get_cmap("viridis", 16)
1756+
cmap.set_bad("k", 1)
1757+
# Set a nonuniform size to prevent the last call to `scatter` (plotting
1758+
# the invalid points separately in fig_ref) from using the marker
1759+
# stamping fast path, which would result in slightly offset markers.
1760+
ax.scatter(range(4), range(4),
1761+
c=[1, np.nan, 2, np.nan], s=[1, 2, 3, 4],
1762+
cmap=cmap, plotnonfinite=True)
1763+
ax = fig_ref.subplots()
1764+
cmap = plt.get_cmap("viridis", 16)
1765+
ax.scatter([0, 2], [0, 2], c=[1, 2], s=[1, 3], cmap=cmap)
1766+
ax.scatter([1, 3], [1, 3], s=[2, 4], color="k")
1767+
1768+
@check_figures_equal(extensions=["png"])
1769+
def test_scatter_no_invalid_color(self, fig_test, fig_ref):
1770+
# With plotninfinite=False we plot only 2 points.
1771+
ax = fig_test.subplots()
1772+
cmap = plt.get_cmap("viridis", 16)
1773+
cmap.set_bad("k", 1)
1774+
ax.scatter(range(4), range(4),
1775+
c=[1, np.nan, 2, np.nan], s=[1, 2, 3, 4],
1776+
cmap=cmap, plotnonfinite=False)
1777+
ax = fig_ref.subplots()
1778+
ax.scatter([0, 2], [0, 2], c=[1, 2], s=[1, 3], cmap=cmap)
1779+
17521780
# Parameters for *test_scatter_c*. NB: assuming that the
17531781
# scatter plot will have 4 elements. The tuple scheme is:
17541782
# (*c* parameter case, exception regexp key or None if no exception)
@@ -5743,21 +5771,6 @@ def test_color_length_mismatch():
57435771
ax.scatter(x, y, c=[c_rgb] * N)
57445772

57455773

5746-
def test_scatter_color_masking():
5747-
x = np.array([1, 2, 3])
5748-
y = np.array([1, np.nan, 3])
5749-
colors = np.array(['k', 'w', 'k'])
5750-
linewidths = np.array([1, 2, 3])
5751-
s = plt.scatter(x, y, color=colors, linewidths=linewidths)
5752-
5753-
facecolors = s.get_facecolors()
5754-
linecolors = s.get_edgecolors()
5755-
linewidths = s.get_linewidths()
5756-
assert_array_equal(facecolors[1], np.array([0, 0, 0, 1]))
5757-
assert_array_equal(linecolors[1], np.array([0, 0, 0, 1]))
5758-
assert linewidths[1] == 3
5759-
5760-
57615774
def test_eventplot_legend():
57625775
plt.eventplot([1.0], label='Label')
57635776
plt.legend()

lib/matplotlib/tests/test_colorbar.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -197,9 +197,8 @@ def test_colorbar_single_scatter():
197197
# the norm scaling within the colorbar must ensure a
198198
# finite range, otherwise a zero denominator will occur in _locate.
199199
plt.figure()
200-
x = np.arange(4)
201-
y = x.copy()
202-
z = np.ma.masked_greater(np.arange(50, 54), 50)
200+
x = y = [0]
201+
z = [50]
203202
cmap = plt.get_cmap('jet', 16)
204203
cs = plt.scatter(x, y, z, c=z, cmap=cmap)
205204
plt.colorbar(cs)

0 commit comments

Comments
 (0)