Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit f484fb4

Browse files
committed
Further shorten quiver3d computation...
... by using einsum instead of juggling axis order. This has essentially no effect on performance.
1 parent 30282cf commit f484fb4

File tree

1 file changed

+9
-22
lines changed

1 file changed

+9
-22
lines changed

lib/mpl_toolkits/mplot3d/axes3d.py

Lines changed: 9 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2526,44 +2526,31 @@ def quiver(self, *args,
25262526
Any additional keyword arguments are delegated to
25272527
:class:`~matplotlib.collections.LineCollection`
25282528
"""
2529+
25292530
def calc_arrows(UVW, angle=15):
25302531
# get unit direction vector perpendicular to (u, v, w)
25312532
x = UVW[:, 0]
25322533
y = UVW[:, 1]
25332534
norm = np.linalg.norm(UVW[:, :2], axis=1)
25342535
x_p = np.divide(y, norm, where=norm != 0, out=np.zeros_like(x))
25352536
y_p = np.divide(-x, norm, where=norm != 0, out=np.ones_like(x))
2536-
25372537
# compute the two arrowhead direction unit vectors
25382538
ra = math.radians(angle)
25392539
c = math.cos(ra)
25402540
s = math.sin(ra)
2541-
2542-
# construct the rotation matrices
2541+
# construct the rotation matrices of shape (3, 3, n)
25432542
Rpos = np.array(
25442543
[[c + (x_p ** 2) * (1 - c), x_p * y_p * (1 - c), y_p * s],
25452544
[y_p * x_p * (1 - c), c + (y_p ** 2) * (1 - c), -x_p * s],
25462545
[-y_p * s, x_p * s, np.full_like(x_p, c)]])
2547-
Rpos = Rpos.transpose(2, 0, 1)
2548-
25492546
# opposite rotation negates all the sin terms
25502547
Rneg = Rpos.copy()
2551-
Rneg[:, [0, 1, 2, 2], [2, 2, 0, 1]] = \
2552-
-Rneg[:, [0, 1, 2, 2], [2, 2, 0, 1]]
2553-
2554-
# expand dimensions for batched matrix multiplication
2555-
UVW = np.expand_dims(UVW, axis=-1)
2556-
2557-
# multiply them to get the rotated vector
2558-
Rpos_vecs = np.matmul(Rpos, UVW)
2559-
Rneg_vecs = np.matmul(Rneg, UVW)
2560-
2561-
# transpose for concatenation
2562-
Rpos_vecs = Rpos_vecs.transpose(0, 2, 1)
2563-
Rneg_vecs = Rneg_vecs.transpose(0, 2, 1)
2564-
2565-
head_dirs = np.concatenate([Rpos_vecs, Rneg_vecs], axis=1)
2566-
2548+
Rneg[[0, 1, 2, 2], [2, 2, 0, 1]] *= -1
2549+
# Batch n (3, 3) x (3) matrix multiplications ((3, 3, n) x (n, 3)).
2550+
Rpos_vecs = np.einsum("ij...,...j->...i", Rpos, UVW)
2551+
Rneg_vecs = np.einsum("ij...,...j->...i", Rneg, UVW)
2552+
# Stack into (n, 2, 3) result.
2553+
head_dirs = np.stack([Rpos_vecs, Rneg_vecs], axis=1)
25672554
return head_dirs
25682555

25692556
had_data = self.has_data()
@@ -2630,7 +2617,7 @@ def calc_arrows(UVW, angle=15):
26302617
# compute all head lines at once, starting from the shaft ends
26312618
heads = shafts[:, :1] - np.multiply.outer(arrow_dt, head_dirs)
26322619
# stack left and right head lines together
2633-
heads.shape = (len(arrow_dt), -1, 3)
2620+
heads = heads.reshape((len(arrow_dt), -1, 3))
26342621
# transpose to get a list of lines
26352622
heads = heads.swapaxes(0, 1)
26362623

0 commit comments

Comments
 (0)