Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 5cd1494

Browse files
committed
Alternative approach; addresses #10381
1 parent c0b5013 commit 5cd1494

File tree

3 files changed

+93
-21
lines changed

3 files changed

+93
-21
lines changed

lib/matplotlib/axes/_axes.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4240,17 +4240,25 @@ def scatter(self, x, y, s=None, c=None, marker=None, cmap=None, norm=None,
42404240
else:
42414241
colors = None # use cmap, norm after collection is created
42424242

4243+
# if plotinvalid and colors == None:
4244+
# # Do full color mapping; don't remove invalid c entries.
4245+
# ind = np.arange(len(c))
4246+
# x, y, s, ind, colors, edgecolors, linewidths =\
4247+
# cbook.delete_masked_points(
4248+
# x, y, s, ind, colors, edgecolors, linewidths)
4249+
# c = np.ma.masked_invalid(c[ind])
4250+
# else:
4251+
# x, y, s, c, colors, edgecolors, linewidths =\
4252+
# cbook.delete_masked_points(
4253+
# x, y, s, c, colors, edgecolors, linewidths)
4254+
42434255
if plotinvalid and colors == None:
4244-
# Do full color mapping; don't remove invalid c entries.
4245-
ind = np.arange(len(c))
4246-
x, y, s, ind, colors, edgecolors, linewidths =\
4247-
cbook.delete_masked_points(
4248-
x, y, s, ind, colors, edgecolors, linewidths)
4249-
c = np.ma.masked_invalid(c[ind])
4256+
c = np.ma.masked_invalid(c)
4257+
x, y, s, colors, edgecolors, linewidths =\
4258+
cbook.combine_masks(x, y, s, colors, edgecolors, linewidths)
42504259
else:
42514260
x, y, s, c, colors, edgecolors, linewidths =\
4252-
cbook.delete_masked_points(
4253-
x, y, s, c, colors, edgecolors, linewidths)
4261+
cbook.combine_masks(x, y, s, c, colors, edgecolors, linewidths)
42544262

42554263
scales = s # Renamed for readability below.
42564264

lib/matplotlib/cbook/__init__.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1044,6 +1044,69 @@ def delete_masked_points(*args):
10441044
return margs
10451045

10461046

1047+
def combine_masks(*args):
1048+
"""
1049+
Find all masked and/or non-finite points in a set of arguments,
1050+
and return the arguments as masked arrays with a common mask.
1051+
1052+
Arguments can be in any of 5 categories:
1053+
1054+
1) 1-D masked arrays
1055+
2) 1-D ndarrays
1056+
3) ndarrays with more than one dimension
1057+
4) other non-string iterables
1058+
5) anything else
1059+
1060+
The first argument must be in one of the first four categories;
1061+
any argument with a length differing from that of the first
1062+
argument (and hence anything in category 5) then will be
1063+
passed through unchanged.
1064+
1065+
Masks are obtained from all arguments of the correct length
1066+
in categories 1, 2, and 4; a point is bad if masked in a masked
1067+
array or if it is a nan or inf. No attempt is made to
1068+
extract a mask from categories 2, 3, and 4 if :meth:`np.isfinite`
1069+
does not yield a Boolean array.
1070+
1071+
All input arguments that are not passed unchanged are returned
1072+
as masked arrays if any masked points are found, otherwise as
1073+
ndarrays.
1074+
1075+
"""
1076+
if not len(args):
1077+
return ()
1078+
if is_scalar_or_string(args[0]):
1079+
raise ValueError("First argument must be a sequence")
1080+
nrecs = len(args[0])
1081+
margs = []
1082+
seqlist = [False] * len(args)
1083+
for i, x in enumerate(args):
1084+
if not isinstance(x, str) and np.iterable(x) and len(x) == nrecs:
1085+
if isinstance(x, np.ma.MaskedArray):
1086+
if x.ndim > 1:
1087+
raise ValueError("Masked arrays must be 1-D")
1088+
x = np.asanyarray(x)
1089+
if x.ndim == 1 and x.dtype.kind == 'f':
1090+
x = np.ma.masked_invalid(x)
1091+
seqlist[i] = True
1092+
margs.append(x)
1093+
masks = [] # list of masks that are True where bad
1094+
for i, x in enumerate(margs):
1095+
if seqlist[i]:
1096+
if x.ndim > 1:
1097+
continue # Don't try to get nan locations unless 1-D.
1098+
if np.ma.is_masked(x):
1099+
masks.append(np.ma.getmaskarray(x))
1100+
if len(masks):
1101+
mask = np.logical_or.reduce(masks)
1102+
if mask.any():
1103+
for i, x in enumerate(margs):
1104+
if seqlist[i]:
1105+
margs[i] = np.ma.array(x)
1106+
margs[i][mask] = np.ma.masked
1107+
return margs
1108+
1109+
10471110
def boxplot_stats(X, whis=1.5, bootstrap=None, labels=None,
10481111
autorange=False):
10491112
"""

lib/matplotlib/tests/test_axes.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5648,19 +5648,20 @@ def test_color_length_mismatch():
56485648
ax.scatter(x, y, c=[c_rgb] * N)
56495649

56505650

5651-
def test_scatter_color_masking():
5652-
x = np.array([1, 2, 3])
5653-
y = np.array([1, np.nan, 3])
5654-
colors = np.array(['k', 'w', 'k'])
5655-
linewidths = np.array([1, 2, 3])
5656-
s = plt.scatter(x, y, color=colors, linewidths=linewidths)
5657-
5658-
facecolors = s.get_facecolors()
5659-
linecolors = s.get_edgecolors()
5660-
linewidths = s.get_linewidths()
5661-
assert_array_equal(facecolors[1], np.array([0, 0, 0, 1]))
5662-
assert_array_equal(linecolors[1], np.array([0, 0, 0, 1]))
5663-
assert linewidths[1] == 3
5651+
# The following test is based on the old behavior of deleting bad points.
5652+
# def test_scatter_color_masking():
5653+
# x = np.array([1, 2, 3])
5654+
# y = np.array([1, np.nan, 3])
5655+
# colors = np.array(['k', 'w', 'k'])
5656+
# linewidths = np.array([1, 2, 3])
5657+
# s = plt.scatter(x, y, color=colors, linewidths=linewidths)
5658+
#
5659+
# facecolors = s.get_facecolors()
5660+
# linecolors = s.get_edgecolors()
5661+
# linewidths = s.get_linewidths()
5662+
# assert_array_equal(facecolors[1], np.array([0, 0, 0, 1]))
5663+
# assert_array_equal(linecolors[1], np.array([0, 0, 0, 1]))
5664+
# assert linewidths[1] == 3
56645665

56655666

56665667
def test_eventplot_legend():

0 commit comments

Comments
 (0)