diff --git a/lib/mpl_toolkits/mplot3d/axes3d.py b/lib/mpl_toolkits/mplot3d/axes3d.py index c56e4c6b7039..f389e87801fd 100644 --- a/lib/mpl_toolkits/mplot3d/axes3d.py +++ b/lib/mpl_toolkits/mplot3d/axes3d.py @@ -2198,9 +2198,12 @@ def plot_surface(self, X, Y, Z, *, norm=None, vmin=None, vmin, vmax : float, optional Bounds for the normalization. - shade : bool, default: True - Whether to shade the facecolors. Shading is always disabled when - *cmap* is specified. + shade : bool or "auto", default: "auto" + Whether to shade the facecolors. "auto" will shade only if the facecolor + is uniform, i.e. neither *cmap* nor *facecolors* is given. + + Furthermore, shading is generally not compatible with colormapping + and ``shade=True, cmap=...`` will raise an error. lightsource : `~matplotlib.colors.LightSource`, optional The lightsource to use when *shade* is True. @@ -2251,9 +2254,15 @@ def plot_surface(self, X, Y, Z, *, norm=None, vmin=None, fcolors = kwargs.pop('facecolors', None) cmap = kwargs.get('cmap', None) - shade = kwargs.pop('shade', cmap is None) - if shade is None: - raise ValueError("shade cannot be None.") + shade = kwargs.pop('shade', 'auto') + if shade == "auto": + shade = cmap is None and fcolors is None + # Remove the None check as it doesn't seem to be needed + + # Raise error if shade=True and cmap is provided as documented + if shade is True and cmap is not None: + raise ValueError("Shading is not compatible with colormapping. " + "Set shade=False or do not provide a cmap.") colset = [] # the sampled facecolor if (rows - 1) % rstride == 0 and \ diff --git a/lib/mpl_toolkits/mplot3d/tests/test_axes3d.py b/lib/mpl_toolkits/mplot3d/tests/test_axes3d.py index e6d11f793b46..220ea875ff71 100644 --- a/lib/mpl_toolkits/mplot3d/tests/test_axes3d.py +++ b/lib/mpl_toolkits/mplot3d/tests/test_axes3d.py @@ -703,7 +703,7 @@ def test_surface3d_masked(): z = np.ma.masked_less(matrix, 0) norm = mcolors.Normalize(vmax=z.max(), vmin=z.min()) colors = mpl.colormaps["plasma"](norm(z)) - ax.plot_surface(x, y, z, facecolors=colors) + ax.plot_surface(x, y, z, facecolors=colors, shade=True) ax.view_init(30, -80, 0) @@ -2689,3 +2689,87 @@ def test_ndarray_color_kwargs_value_error(): ax = fig.add_subplot(111, projection='3d') ax.scatter(1, 0, 0, color=np.array([0, 0, 0, 1])) fig.canvas.draw() + + +@check_figures_equal() +def test_plot_surface_shade_auto_with_facecolors(fig_test, fig_ref): + """Test that plot_surface with facecolors uses shade=False by default.""" + X = np.linspace(0, 1, 5) + Y = np.linspace(0, 1, 5) + X_mesh, Y_mesh = np.meshgrid(X, Y) + Z = X_mesh + Y_mesh + colors = cm.viridis(X_mesh) + + # Test with facecolors (should have shade=False by default) + ax_test = fig_test.add_subplot(projection='3d') + ax_test.plot_surface(X_mesh, Y_mesh, Z, facecolors=colors) + + # Reference with explicit shade=False + ax_ref = fig_ref.add_subplot(projection='3d') + ax_ref.plot_surface(X_mesh, Y_mesh, Z, facecolors=colors, shade=False) + + +@check_figures_equal() +def test_plot_surface_shade_auto_without_facecolors(fig_test, fig_ref): + """Test that plot_surface without facecolors uses shade=True by default.""" + X = np.linspace(0, 1, 5) + Y = np.linspace(0, 1, 5) + X_mesh, Y_mesh = np.meshgrid(X, Y) + Z = X_mesh + Y_mesh + + # Test without facecolors (should have shade=True by default) + ax_test = fig_test.add_subplot(projection='3d') + ax_test.plot_surface(X_mesh, Y_mesh, Z) + + # Reference with explicit shade=True + ax_ref = fig_ref.add_subplot(projection='3d') + ax_ref.plot_surface(X_mesh, Y_mesh, Z, shade=True) + + +@check_figures_equal() +def test_plot_surface_shade_auto_with_cmap(fig_test, fig_ref): + """Test that plot_surface with cmap uses shade=False by default.""" + X = np.linspace(0, 1, 5) + Y = np.linspace(0, 1, 5) + X_mesh, Y_mesh = np.meshgrid(X, Y) + Z = X_mesh + Y_mesh + + # Test with cmap (should have shade=False by default) + ax_test = fig_test.add_subplot(projection='3d') + ax_test.plot_surface(X_mesh, Y_mesh, Z, cmap=cm.viridis) + + # Reference with explicit shade=False + ax_ref = fig_ref.add_subplot(projection='3d') + ax_ref.plot_surface(X_mesh, Y_mesh, Z, cmap=cm.viridis, shade=False) + + +@check_figures_equal() +def test_plot_surface_shade_override_with_facecolors(fig_test, fig_ref): + """Test that explicit shade parameter overrides auto behavior with facecolors.""" + X = np.linspace(0, 1, 5) + Y = np.linspace(0, 1, 5) + X_mesh, Y_mesh = np.meshgrid(X, Y) + Z = X_mesh + Y_mesh + colors = cm.viridis(X_mesh) + + # Test with explicit shade=True (overrides auto behavior) + ax_test = fig_test.add_subplot(projection='3d') + ax_test.plot_surface(X_mesh, Y_mesh, Z, facecolors=colors, shade=True) + + # Reference with explicit shade=True + ax_ref = fig_ref.add_subplot(projection='3d') + ax_ref.plot_surface(X_mesh, Y_mesh, Z, facecolors=colors, shade=True) + + +def test_plot_surface_shade_with_cmap_raises(): + """Test that shade=True with cmap raises an error.""" + X = np.linspace(0, 1, 5) + Y = np.linspace(0, 1, 5) + X_mesh, Y_mesh = np.meshgrid(X, Y) + Z = X_mesh + Y_mesh + + fig = plt.figure() + ax = fig.add_subplot(projection='3d') + + with pytest.raises(ValueError, match="Shading is not compatible with colormapping"): + ax.plot_surface(X_mesh, Y_mesh, Z, cmap=cm.viridis, shade=True)