diff --git a/lib/mpl_toolkits/mplot3d/axes3d.py b/lib/mpl_toolkits/mplot3d/axes3d.py index 73385ec91256..23755d55b6a6 100644 --- a/lib/mpl_toolkits/mplot3d/axes3d.py +++ b/lib/mpl_toolkits/mplot3d/axes3d.py @@ -1984,47 +1984,30 @@ def plot_trisurf(self, *args, **kwargs): args = args[1:] triangles = tri.get_masked_triangles() - xt = tri.x[triangles][..., np.newaxis] - yt = tri.y[triangles][..., np.newaxis] - zt = z[triangles][..., np.newaxis] + xt = tri.x[triangles] + yt = tri.y[triangles] + zt = z[triangles] - verts = np.concatenate((xt, yt, zt), axis=2) - - # Only need these vectors to shade if there is no cmap - if cmap is None and shade: - totpts = len(verts) - v1 = np.empty((totpts, 3)) - v2 = np.empty((totpts, 3)) - # This indexes the vertex points - which_pt = 0 - - colset = [] - for i in xrange(len(verts)): - avgzsum = verts[i,0,2] + verts[i,1,2] + verts[i,2,2] - colset.append(avgzsum / 3.0) - - # Only need vectors to shade if no cmap - if cmap is None and shade: - v1[which_pt] = np.array(verts[i,0]) - np.array(verts[i,1]) - v2[which_pt] = np.array(verts[i,1]) - np.array(verts[i,2]) - which_pt += 1 - - if cmap is None and shade: - normals = np.cross(v1, v2) - else: - normals = [] + # verts = np.stack((xt, yt, zt), axis=-1) + verts = np.concatenate(( + xt[..., np.newaxis], yt[..., np.newaxis], zt[..., np.newaxis] + ), axis=-1) polyc = art3d.Poly3DCollection(verts, *args, **kwargs) if cmap: - colset = np.array(colset) - polyc.set_array(colset) + # average over the three points of each triangle + avg_z = verts[:, :, 2].mean(axis=1) + polyc.set_array(avg_z) if vmin is not None or vmax is not None: polyc.set_clim(vmin, vmax) if norm is not None: polyc.set_norm(norm) else: if shade: + v1 = verts[:, 0, :] - verts[:, 1, :] + v2 = verts[:, 1, :] - verts[:, 2, :] + normals = np.cross(v1, v2) colset = self._shade_colors(color, normals) else: colset = color diff --git a/lib/mpl_toolkits/tests/baseline_images/test_mplot3d/trisurf3d_shaded.png b/lib/mpl_toolkits/tests/baseline_images/test_mplot3d/trisurf3d_shaded.png new file mode 100644 index 000000000000..9faa9d915349 Binary files /dev/null and b/lib/mpl_toolkits/tests/baseline_images/test_mplot3d/trisurf3d_shaded.png differ diff --git a/lib/mpl_toolkits/tests/test_mplot3d.py b/lib/mpl_toolkits/tests/test_mplot3d.py index c157433c752a..b948c7381627 100644 --- a/lib/mpl_toolkits/tests/test_mplot3d.py +++ b/lib/mpl_toolkits/tests/test_mplot3d.py @@ -216,6 +216,25 @@ def test_trisurf3d(): ax.plot_trisurf(x, y, z, cmap=cm.jet, linewidth=0.2) +@image_comparison(baseline_images=['trisurf3d_shaded'], remove_text=True, + tol=0.03, extensions=['png']) +def test_trisurf3d_shaded(): + n_angles = 36 + n_radii = 8 + radii = np.linspace(0.125, 1.0, n_radii) + angles = np.linspace(0, 2*np.pi, n_angles, endpoint=False) + angles = np.repeat(angles[..., np.newaxis], n_radii, axis=1) + angles[:, 1::2] += np.pi/n_angles + + x = np.append(0, (radii*np.cos(angles)).flatten()) + y = np.append(0, (radii*np.sin(angles)).flatten()) + z = np.sin(-x*y) + + fig = plt.figure() + ax = fig.gca(projection='3d') + ax.plot_trisurf(x, y, z, color=[1, 0.5, 0], linewidth=0.2) + + @image_comparison(baseline_images=['wireframe3d'], remove_text=True) def test_wireframe3d(): fig = plt.figure()