Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 494b12b

Browse files
committed
Refactor c_array checking; Lower N in color len mismatch test
1 parent 546372d commit 494b12b

File tree

2 files changed

+17
-9
lines changed

2 files changed

+17
-9
lines changed

lib/matplotlib/axes/_axes.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3949,6 +3949,7 @@ def scatter(self, x, y, s=None, c=None, marker=None, cmap=None, norm=None,
39493949

39503950
# np.ma.ravel yields an ndarray, not a masked array,
39513951
# unless its argument is a masked array.
3952+
xy_shape = (np.shape(x), np.shape(y))
39523953
x = np.ma.ravel(x)
39533954
y = np.ma.ravel(y)
39543955
if x.size != y.size:
@@ -3971,18 +3972,22 @@ def scatter(self, x, y, s=None, c=None, marker=None, cmap=None, norm=None,
39713972
else:
39723973
try:
39733974
c_array = np.asanyarray(c, dtype=float)
3975+
if c_array.shape in xy_shape:
3976+
c = np.ma.ravel(c_array)
3977+
else:
3978+
# Wrong size; it must not be intended for mapping.
3979+
c_array = None
39743980
except ValueError:
39753981
# Failed to make a floating-point array; c must be color specs.
39763982
c_array = None
3977-
else:
3978-
if c_array.size == x.size:
3979-
c = np.ma.ravel(c_array)
3980-
elif c_array.size not in (3, 4):
3981-
# Wrong size. Not a rgb/rgba and not same size as x
3982-
raise ValueError("x and c must be the same size")
39833983

39843984
if c_array is None:
39853985
colors = c # must be acceptable as PathCollection facecolors
3986+
try:
3987+
mcolors.to_rgba_array(colors)
3988+
except ValueError:
3989+
# c not acceptable as PathCollection facecolor
3990+
raise ValueError("c not acceptable as color sequence")
39863991
else:
39873992
colors = None # use cmap, norm after collection is created
39883993

lib/matplotlib/tests/test_axes.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4717,9 +4717,12 @@ def test_fillbetween_cycle():
47174717

47184718
@cleanup
47194719
def test_color_length_mismatch():
4720-
N = 500
4721-
x, y = np.random.rand(N), np.random.rand(N)
4722-
colors = np.random.rand(N+1)
4720+
N = 5
4721+
x, y = np.arange(N), np.arange(N)
4722+
colors = np.arange(N+1)
47234723
fig, ax = plt.subplots()
47244724
with pytest.raises(ValueError):
47254725
ax.scatter(x, y, c=colors)
4726+
c_rgb = (0.5, 0.5, 0.5)
4727+
ax.scatter(x, y, c=c_rgb)
4728+
ax.scatter(x, y, c=[c_rgb] * N)

0 commit comments

Comments
 (0)