Thanks to visit codestin.com
Credit goes to github.com

Skip to content

MAINT: Use vectorization in plot_trisurf, simplifying greatly #9991

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 3, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 13 additions & 30 deletions lib/mpl_toolkits/mplot3d/axes3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -1984,47 +1984,30 @@ def plot_trisurf(self, *args, **kwargs):
args = args[1:]

triangles = tri.get_masked_triangles()
xt = tri.x[triangles][..., np.newaxis]
yt = tri.y[triangles][..., np.newaxis]
zt = z[triangles][..., np.newaxis]
xt = tri.x[triangles]
yt = tri.y[triangles]
zt = z[triangles]

verts = np.concatenate((xt, yt, zt), axis=2)

# Only need these vectors to shade if there is no cmap
if cmap is None and shade:
totpts = len(verts)
v1 = np.empty((totpts, 3))
v2 = np.empty((totpts, 3))
# This indexes the vertex points
which_pt = 0

colset = []
for i in xrange(len(verts)):
avgzsum = verts[i,0,2] + verts[i,1,2] + verts[i,2,2]
colset.append(avgzsum / 3.0)

# Only need vectors to shade if no cmap
if cmap is None and shade:
v1[which_pt] = np.array(verts[i,0]) - np.array(verts[i,1])
v2[which_pt] = np.array(verts[i,1]) - np.array(verts[i,2])
which_pt += 1

if cmap is None and shade:
normals = np.cross(v1, v2)
else:
normals = []
# verts = np.stack((xt, yt, zt), axis=-1)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are they 1D or 2D? np.vstack or np.dstack could also be used in those cases.

Copy link
Contributor Author

@eric-wieser eric-wieser Dec 14, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, this was a new change since you last reviewed. Both vstack and dstack have undesirable semantics of guessing the array .ndim, so using concatenate is more precise.

This isn't the only place that np.stack is mentioned in a comment above a np.concatenate((a[..., None], ...)), so might be worth backporting to the very simple

def stack(arrays, axis):
    return np.concatenate([arr[..., np.newaxis] for arr in arrays], axis)

- (oops, not correct)

Not something I want to involve in this PR though

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can just stick the definition in https://github.com/matplotlib/matplotlib/blob/master/lib/matplotlib/cbook/_backports.py if you change your mind.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, and to answer your question, 2D - (N, edges_in_triangle)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We do have a cbook._backports; not sure how strong the need is for it without having done any looking.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I reckon there're a substantial number of cases - I can put together a PR once this is merged

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking again, a stack_last function is all that ever seems to be needed, implemented as the above - so that wouldn't belong in _backports. Suggestions of where to put such a helper?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cbook would be the place.
But given that we'd probably just be using stack(..., -1) "if it was available", I'd just backport stack.
Let us know if you want to do this in this PR or a separate one.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Definitely a separate one

verts = np.concatenate((
xt[..., np.newaxis], yt[..., np.newaxis], zt[..., np.newaxis]
), axis=-1)

polyc = art3d.Poly3DCollection(verts, *args, **kwargs)

if cmap:
colset = np.array(colset)
polyc.set_array(colset)
# average over the three points of each triangle
avg_z = verts[:, :, 2].mean(axis=1)
polyc.set_array(avg_z)
if vmin is not None or vmax is not None:
polyc.set_clim(vmin, vmax)
if norm is not None:
polyc.set_norm(norm)
else:
if shade:
v1 = verts[:, 0, :] - verts[:, 1, :]
v2 = verts[:, 1, :] - verts[:, 2, :]
normals = np.cross(v1, v2)
colset = self._shade_colors(color, normals)
else:
colset = color
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
19 changes: 19 additions & 0 deletions lib/mpl_toolkits/tests/test_mplot3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,25 @@ def test_trisurf3d():
ax.plot_trisurf(x, y, z, cmap=cm.jet, linewidth=0.2)


@image_comparison(baseline_images=['trisurf3d_shaded'], remove_text=True,
tol=0.03, extensions=['png'])
def test_trisurf3d_shaded():
n_angles = 36
n_radii = 8
radii = np.linspace(0.125, 1.0, n_radii)
angles = np.linspace(0, 2*np.pi, n_angles, endpoint=False)
angles = np.repeat(angles[..., np.newaxis], n_radii, axis=1)
angles[:, 1::2] += np.pi/n_angles

x = np.append(0, (radii*np.cos(angles)).flatten())
y = np.append(0, (radii*np.sin(angles)).flatten())
z = np.sin(-x*y)

fig = plt.figure()
ax = fig.gca(projection='3d')
ax.plot_trisurf(x, y, z, color=[1, 0.5, 0], linewidth=0.2)


@image_comparison(baseline_images=['wireframe3d'], remove_text=True)
def test_wireframe3d():
fig = plt.figure()
Expand Down