Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit d1bf75f

Browse files
committed
Refactor c_array checking; Lower N in color len mismatch test
1 parent e75a3c4 commit d1bf75f

File tree

2 files changed

+17
-9
lines changed

2 files changed

+17
-9
lines changed

lib/matplotlib/axes/_axes.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3992,6 +3992,7 @@ def scatter(self, x, y, s=None, c=None, marker=None, cmap=None, norm=None,
39923992

39933993
# np.ma.ravel yields an ndarray, not a masked array,
39943994
# unless its argument is a masked array.
3995+
xy_shape = (np.shape(x), np.shape(y))
39953996
x = np.ma.ravel(x)
39963997
y = np.ma.ravel(y)
39973998
if x.size != y.size:
@@ -4014,18 +4015,22 @@ def scatter(self, x, y, s=None, c=None, marker=None, cmap=None, norm=None,
40144015
else:
40154016
try:
40164017
c_array = np.asanyarray(c, dtype=float)
4018+
if c_array.shape in xy_shape:
4019+
c = np.ma.ravel(c_array)
4020+
else:
4021+
# Wrong size; it must not be intended for mapping.
4022+
c_array = None
40174023
except ValueError:
40184024
# Failed to make a floating-point array; c must be color specs.
40194025
c_array = None
4020-
else:
4021-
if c_array.size == x.size:
4022-
c = np.ma.ravel(c_array)
4023-
elif c_array.size not in (3, 4):
4024-
# Wrong size. Not a rgb/rgba and not same size as x
4025-
raise ValueError("x and c must be the same size")
40264026

40274027
if c_array is None:
40284028
colors = c # must be acceptable as PathCollection facecolors
4029+
try:
4030+
mcolors.to_rgba_array(colors)
4031+
except ValueError:
4032+
# c not acceptable as PathCollection facecolor
4033+
raise ValueError("c not acceptable as color sequence")
40294034
else:
40304035
colors = None # use cmap, norm after collection is created
40314036

lib/matplotlib/tests/test_axes.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4717,9 +4717,12 @@ def test_fillbetween_cycle():
47174717

47184718
@cleanup
47194719
def test_color_length_mismatch():
4720-
N = 500
4721-
x, y = np.random.rand(N), np.random.rand(N)
4722-
colors = np.random.rand(N+1)
4720+
N = 5
4721+
x, y = np.arange(N), np.arange(N)
4722+
colors = np.arange(N+1)
47234723
fig, ax = plt.subplots()
47244724
with pytest.raises(ValueError):
47254725
ax.scatter(x, y, c=colors)
4726+
c_rgb = (0.5, 0.5, 0.5)
4727+
ax.scatter(x, y, c=c_rgb)
4728+
ax.scatter(x, y, c=[c_rgb] * N)

0 commit comments

Comments
 (0)