From b0f0781b8679826aa8e1419aaec96c023c63a87e Mon Sep 17 00:00:00 2001 From: Artem Shekhovtsov Date: Tue, 4 Jul 2023 09:44:59 +0300 Subject: [PATCH] FIX: axes3d.scatter color parameter doesn't decrease in size for non-finite coordinate inputs. --- lib/mpl_toolkits/mplot3d/axes3d.py | 6 ++++- lib/mpl_toolkits/mplot3d/tests/test_axes3d.py | 26 +++++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/lib/mpl_toolkits/mplot3d/axes3d.py b/lib/mpl_toolkits/mplot3d/axes3d.py index 25cf17cab126..f2776f26dfc3 100644 --- a/lib/mpl_toolkits/mplot3d/axes3d.py +++ b/lib/mpl_toolkits/mplot3d/axes3d.py @@ -2249,7 +2249,11 @@ def scatter(self, xs, ys, zs=0, zdir='z', s=20, c=None, depthshade=True, *[np.ravel(np.ma.filled(t, np.nan)) for t in [xs, ys, zs]]) s = np.ma.ravel(s) # This doesn't have to match x, y in size. - xs, ys, zs, s, c = cbook.delete_masked_points(xs, ys, zs, s, c) + xs, ys, zs, s, c, color = cbook.delete_masked_points( + xs, ys, zs, s, c, kwargs.get('color', None) + ) + if kwargs.get('color', None): + kwargs['color'] = color # For xs and ys, 2D scatter() will do the copying. if np.may_share_memory(zs_orig, zs): # Avoid unnecessary copies. diff --git a/lib/mpl_toolkits/mplot3d/tests/test_axes3d.py b/lib/mpl_toolkits/mplot3d/tests/test_axes3d.py index dbc0f23876c0..140ef9413408 100644 --- a/lib/mpl_toolkits/mplot3d/tests/test_axes3d.py +++ b/lib/mpl_toolkits/mplot3d/tests/test_axes3d.py @@ -2226,3 +2226,29 @@ def test_mutating_input_arrays_y_and_z(fig_test, fig_ref): y = [0.0, 0.0, 0.0] z = [0.0, 0.0, 0.0] ax2.plot(x, y, z, 'o-') + + +def test_scatter_masked_color(): + """ + Test color parameter usage with non-finite coordinate arrays. + + GH#26236 + """ + + x = [np.nan, 1, 2, 1] + y = [0, np.inf, 2, 1] + z = [0, 1, -np.inf, 1] + colors = [ + [0.0, 0.0, 0.0, 1], + [0.0, 0.0, 0.0, 1], + [0.0, 0.0, 0.0, 1], + [0.0, 0.0, 0.0, 1] + ] + + fig = plt.figure() + ax = fig.add_subplot(projection='3d') + path3d = ax.scatter(x, y, z, color=colors) + + # Assert sizes' equality + assert len(path3d.get_offsets()) ==\ + len(super(type(path3d), path3d).get_facecolors())