Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 86e1518

Browse files
committed
MAINT: Simplify logic in plot_surface
Previously: * "cell" perimeters were clumsily calculated with duplicates, which were then removed at runtime * code to calculate normals was spread into multiple places * average z was calculated even if not used * repeated conversion between stride and count was done Should have no visible behavior changes
1 parent 024d423 commit 86e1518

File tree

1 file changed

+31
-38
lines changed

1 file changed

+31
-38
lines changed

lib/mpl_toolkits/mplot3d/axes3d.py

Lines changed: 31 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1671,64 +1671,57 @@ def plot_surface(self, X, Y, Z, *args, **kwargs):
16711671
fcolors = self._shade_colors_lightsource(Z, cmap, lightsource)
16721672

16731673
polys = []
1674-
# Only need these vectors to shade if there is no cmap
1675-
if cmap is None and shade :
1676-
totpts = int(np.ceil(float(rows - 1) / rstride) *
1677-
np.ceil(float(cols - 1) / cstride))
1678-
v1 = np.empty((totpts, 3))
1679-
v2 = np.empty((totpts, 3))
1680-
# This indexes the vertex points
1681-
which_pt = 0
1674+
1675+
# evenly spaced, and including both endpoints
1676+
row_inds = list(xrange(0, rows-1, rstride)) + [rows-1]
1677+
col_inds = list(xrange(0, cols-1, cstride)) + [cols-1]
16821678

16831679

1684-
#colset contains the data for coloring: either average z or the facecolor
1680+
#colset contains the sampled facecolors
16851681
colset = []
1686-
for rs in xrange(0, rows-1, rstride):
1687-
for cs in xrange(0, cols-1, cstride):
1682+
for rs, rs_next in zip(row_inds[:-1], row_inds[1:]):
1683+
for cs, cs_next in zip(col_inds[:-1], col_inds[1:]):
16881684
ps = []
16891685
for a in (X, Y, Z):
1690-
ztop = a[rs,cs:min(cols, cs+cstride+1)]
1691-
zleft = a[rs+1:min(rows, rs+rstride+1),
1692-
min(cols-1, cs+cstride)]
1693-
zbase = a[min(rows-1, rs+rstride), cs:min(cols, cs+cstride+1):][::-1]
1694-
zright = a[rs:min(rows-1, rs+rstride):, cs][::-1]
1686+
# the edges of the projected quadrilateral
1687+
# note we use pythons half-open ranges to avoid repeating
1688+
# the corners
1689+
ztop = a[rs, cs:cs_next ]
1690+
zleft = a[rs:rs_next, cs_next ]
1691+
zbase = a[rs_next, cs_next:cs:-1]
1692+
zright = a[rs_next:rs:-1, cs ]
16951693
z = np.concatenate((ztop, zleft, zbase, zright))
16961694
ps.append(z)
16971695

1698-
# The construction leaves the array with duplicate points, which
1699-
# are removed here.
1700-
ps = list(zip(*ps))
1701-
lastp = np.array([])
1702-
ps2 = [ps[0]] + [ps[i] for i in xrange(1, len(ps)) if ps[i] != ps[i-1]]
1703-
avgzsum = sum(p[2] for p in ps2)
1704-
polys.append(ps2)
1696+
# ps = np.stack(ps, axis=-1)
1697+
ps = np.array(ps).T
1698+
polys.append(ps)
17051699

17061700
if fcolors is not None:
17071701
colset.append(fcolors[rs][cs])
1708-
else:
1709-
colset.append(avgzsum / len(ps2))
1710-
1711-
# Only need vectors to shade if no cmap
1712-
if cmap is None and shade:
1713-
i1, i2, i3 = 0, int(len(ps2)/3), int(2*len(ps2)/3)
1714-
v1[which_pt] = np.array(ps2[i1]) - np.array(ps2[i2])
1715-
v2[which_pt] = np.array(ps2[i2]) - np.array(ps2[i3])
1716-
which_pt += 1
1717-
if cmap is None and shade:
1718-
normals = np.cross(v1, v2)
1719-
else :
1720-
normals = []
17211702

17221703
polyc = art3d.Poly3DCollection(polys, *args, **kwargs)
17231704

1705+
if shade and cmap is None:
1706+
v1 = np.empty((len(polys), 3))
1707+
v2 = np.empty((len(polys), 3))
1708+
for poly_i, ps in enumerate(polys):
1709+
# pick three points around the polygon to find the normal at
1710+
# hard to vectorize, since len(ps) is different at the edges
1711+
i1, i2, i3 = 0, int(len(ps)/3), int(2*len(ps)/3)
1712+
v1[poly_i] = ps[i1] - ps[i2]
1713+
v2[poly_i] = ps[i2] - ps[i3]
1714+
normals = np.cross(v1, v2)
1715+
17241716
if fcolors is not None:
17251717
if shade:
17261718
colset = self._shade_colors(colset, normals)
17271719
polyc.set_facecolors(colset)
17281720
polyc.set_edgecolors(colset)
17291721
elif cmap:
1730-
colset = np.array(colset)
1731-
polyc.set_array(colset)
1722+
# doesn't vectorize because polys is jagged
1723+
avg_z = np.array([ps[:,2].mean() for ps in polys])
1724+
polyc.set_array(avg_z)
17321725
if vmin is not None or vmax is not None:
17331726
polyc.set_clim(vmin, vmax)
17341727
if norm is not None:

0 commit comments

Comments
 (0)