Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit d03b0e0

Browse files
committed
Refactor c_array checking; Lower N in color len mismatch test
1 parent e218bcf commit d03b0e0

2 files changed

Lines changed: 17 additions & 9 deletions

File tree

lib/matplotlib/axes/_axes.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3952,6 +3952,7 @@ def scatter(self, x, y, s=None, c=None, marker=None, cmap=None, norm=None,
39523952

39533953
# np.ma.ravel yields an ndarray, not a masked array,
39543954
# unless its argument is a masked array.
3955+
xy_shape = (np.shape(x), np.shape(y))
39553956
x = np.ma.ravel(x)
39563957
y = np.ma.ravel(y)
39573958
if x.size != y.size:
@@ -3974,18 +3975,22 @@ def scatter(self, x, y, s=None, c=None, marker=None, cmap=None, norm=None,
39743975
else:
39753976
try:
39763977
c_array = np.asanyarray(c, dtype=float)
3978+
if c_array.shape in xy_shape:
3979+
c = np.ma.ravel(c_array)
3980+
else:
3981+
# Wrong size; it must not be intended for mapping.
3982+
c_array = None
39773983
except ValueError:
39783984
# Failed to make a floating-point array; c must be color specs.
39793985
c_array = None
3980-
else:
3981-
if c_array.size == x.size:
3982-
c = np.ma.ravel(c_array)
3983-
elif c_array.size not in (3, 4):
3984-
# Wrong size. Not a rgb/rgba and not same size as x
3985-
raise ValueError("x and c must be the same size")
39863986

39873987
if c_array is None:
39883988
colors = c # must be acceptable as PathCollection facecolors
3989+
try:
3990+
mcolors.to_rgba_array(colors)
3991+
except ValueError:
3992+
# c not acceptable as PathCollection facecolor
3993+
raise ValueError("c not acceptable as color sequence")
39893994
else:
39903995
colors = None # use cmap, norm after collection is created
39913996

lib/matplotlib/tests/test_axes.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4801,9 +4801,12 @@ def test_fillbetween_cycle():
48014801

48024802
@cleanup
48034803
def test_color_length_mismatch():
4804-
N = 500
4805-
x, y = np.random.rand(N), np.random.rand(N)
4806-
colors = np.random.rand(N+1)
4804+
N = 5
4805+
x, y = np.arange(N), np.arange(N)
4806+
colors = np.arange(N+1)
48074807
fig, ax = plt.subplots()
48084808
with pytest.raises(ValueError):
48094809
ax.scatter(x, y, c=colors)
4810+
c_rgb = (0.5, 0.5, 0.5)
4811+
ax.scatter(x, y, c=c_rgb)
4812+
ax.scatter(x, y, c=[c_rgb] * N)

0 commit comments

Comments
 (0)