Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 7a15095

Browse files
committed
vectorizing some trisurf functions for performance improvement
1 parent eaa2169 commit 7a15095

File tree

1 file changed

+43
-47
lines changed

1 file changed

+43
-47
lines changed

plotly/tools.py

Lines changed: 43 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1464,11 +1464,10 @@ def _find_intermediate_color(lowcolor, highcolor, intermed):
14641464
diff_1 = float(highcolor[1] - lowcolor[1])
14651465
diff_2 = float(highcolor[2] - lowcolor[2])
14661466

1467-
new_tuple = (lowcolor[0] + intermed*diff_0,
1468-
lowcolor[1] + intermed*diff_1,
1469-
lowcolor[2] + intermed*diff_2)
1470-
1471-
return new_tuple
1467+
inter_colors = np.array([lowcolor[0] + intermed * diff_0,
1468+
lowcolor[1] + intermed * diff_1,
1469+
lowcolor[2] + intermed * diff_2])
1470+
return inter_colors
14721471

14731472
@staticmethod
14741473
def _unconvert_from_RGB_255(colors):
@@ -1491,7 +1490,7 @@ def _unconvert_from_RGB_255(colors):
14911490
return un_rgb_colors
14921491

14931492
@staticmethod
1494-
def _map_z2color(zval, colormap, vmin, vmax):
1493+
def _map_z2color(zvals, colormap, vmin, vmax):
14951494
"""
14961495
Returns the color corresponding zval's place between vmin and vmax
14971496
@@ -1508,21 +1507,14 @@ def _map_z2color(zval, colormap, vmin, vmax):
15081507
"of vmax.")
15091508
# find distance t of zval from vmin to vmax where the distance
15101509
# is normalized to be between 0 and 1
1511-
t = (zval - vmin)/float((vmax - vmin))
1512-
t_color = FigureFactory._find_intermediate_color(colormap[0],
1513-
colormap[1],
1514-
t)
1515-
t_color = (t_color[0]*255.0, t_color[1]*255.0, t_color[2]*255.0)
1516-
labelled_color = 'rgb{}'.format(t_color)
1517-
1518-
return labelled_color
1519-
1520-
@staticmethod
1521-
def _tri_indices(simplices):
1522-
"""
1523-
Returns a triplet of lists containing simplex coordinates
1524-
"""
1525-
return ([triplet[c] for triplet in simplices] for c in range(3))
1510+
t = (zvals - vmin) / float((vmax - vmin))
1511+
t_colors = FigureFactory._find_intermediate_color(colormap[0],
1512+
colormap[1],
1513+
t)
1514+
t_colors = t_colors * 255.
1515+
labelled_colors = ['rgb(%s, %s, %s)' % (i, j, k)
1516+
for i, j, k in t_colors.T]
1517+
return labelled_colors
15261518

15271519
@staticmethod
15281520
def _trisurf(x, y, z, simplices, colormap=None, dist_func=None,
@@ -1539,11 +1531,11 @@ def _trisurf(x, y, z, simplices, colormap=None, dist_func=None,
15391531
points3D = np.vstack((x, y, z)).T
15401532

15411533
# vertices of the surface triangles
1542-
tri_vertices = list(map(lambda index: points3D[index], simplices))
1534+
tri_vertices = points3D[simplices]
15431535

15441536
if not dist_func:
15451537
# mean values of z-coordinates of triangle vertices
1546-
mean_dists = [np.mean(tri[:, 2]) for tri in tri_vertices]
1538+
mean_dists = tri_vertices[:, :, 2].mean(-1)
15471539
else:
15481540
# apply user inputted function to calculate
15491541
# custom coloring for triangle vertices
@@ -1559,38 +1551,43 @@ def _trisurf(x, y, z, simplices, colormap=None, dist_func=None,
15591551

15601552
min_mean_dists = np.min(mean_dists)
15611553
max_mean_dists = np.max(mean_dists)
1562-
facecolor = ([FigureFactory._map_z2color(zz, colormap, min_mean_dists,
1563-
max_mean_dists) for zz in mean_dists])
1564-
ii, jj, kk = FigureFactory._tri_indices(simplices)
1554+
facecolor = FigureFactory._map_z2color(mean_dists, colormap,
1555+
min_mean_dists, max_mean_dists)
1556+
ii, jj, kk = zip(*simplices)
15651557

15661558
triangles = graph_objs.Mesh3d(x=x, y=y, z=z, facecolor=facecolor,
15671559
i=ii, j=jj, k=kk, name='')
15681560

1569-
if plot_edges is None: # the triangle sides are not plotted
1561+
if plot_edges is not True: # the triangle sides are not plotted
15701562
return graph_objs.Data([triangles])
15711563

15721564
# define the lists x_edge, y_edge and z_edge, of x, y, resp z
15731565
# coordinates of edge end points for each triangle
15741566
# None separates data corresponding to two consecutive triangles
1575-
lists_coord = ([[[T[k % 3][c] for k in range(4)]+[None]
1576-
for T in tri_vertices] for c in range(3)])
1577-
if x_edge is None:
1578-
x_edge = []
1579-
for array in lists_coord[0]:
1580-
for item in array:
1581-
x_edge.append(item)
1582-
1583-
if y_edge is None:
1584-
y_edge = []
1585-
for array in lists_coord[1]:
1586-
for item in array:
1587-
y_edge.append(item)
1588-
1589-
if z_edge is None:
1590-
z_edge = []
1591-
for array in lists_coord[2]:
1592-
for item in array:
1593-
z_edge.append(item)
1567+
is_none = [ii is None for ii in [x_edge, y_edge, z_edge]]
1568+
if any(is_none):
1569+
if not all(is_none):
1570+
raise ValueError('If any (x_edge, y_edge, z_edge) is None,'
1571+
' all must be None')
1572+
else:
1573+
x_edge = []
1574+
y_edge = []
1575+
z_edge = []
1576+
1577+
# Pull indices we care about, then add a None column to separate tris
1578+
ixs_triangles = [0, 1, 2, 0]
1579+
pull_edges = tri_vertices[:, ixs_triangles, :]
1580+
x_edge_pull = np.hstack([pull_edges[:, :, 0],
1581+
np.tile(None, [pull_edges.shape[0], 1])])
1582+
y_edge_pull = np.hstack([pull_edges[:, :, 1],
1583+
np.tile(None, [pull_edges.shape[0], 1])])
1584+
z_edge_pull = np.hstack([pull_edges[:, :, 2],
1585+
np.tile(None, [pull_edges.shape[0], 1])])
1586+
1587+
# Now unravel the edges into a 1-d vector for plotting
1588+
x_edge = np.hstack([x_edge, x_edge_pull.reshape([1, -1])[0]])
1589+
y_edge = np.hstack([y_edge, y_edge_pull.reshape([1, -1])[0]])
1590+
z_edge = np.hstack([z_edge, z_edge_pull.reshape([1, -1])[0]])
15941591

15951592
# define the lines for plotting
15961593
lines = graph_objs.Scatter3d(
@@ -5621,4 +5618,3 @@ def make_table_annotations(self):
56215618
font=dict(color=font_color),
56225619
showarrow=False))
56235620
return annotations
5624-

0 commit comments

Comments
 (0)