@@ -2526,44 +2526,31 @@ def quiver(self, *args,
2526
2526
Any additional keyword arguments are delegated to
2527
2527
:class:`~matplotlib.collections.LineCollection`
2528
2528
"""
2529
+
2529
2530
def calc_arrows (UVW , angle = 15 ):
2530
2531
# get unit direction vector perpendicular to (u, v, w)
2531
2532
x = UVW [:, 0 ]
2532
2533
y = UVW [:, 1 ]
2533
2534
norm = np .linalg .norm (UVW [:, :2 ], axis = 1 )
2534
2535
x_p = np .divide (y , norm , where = norm != 0 , out = np .zeros_like (x ))
2535
2536
y_p = np .divide (- x , norm , where = norm != 0 , out = np .ones_like (x ))
2536
-
2537
2537
# compute the two arrowhead direction unit vectors
2538
2538
ra = math .radians (angle )
2539
2539
c = math .cos (ra )
2540
2540
s = math .sin (ra )
2541
-
2542
- # construct the rotation matrices
2541
+ # construct the rotation matrices of shape (3, 3, n)
2543
2542
Rpos = np .array (
2544
2543
[[c + (x_p ** 2 ) * (1 - c ), x_p * y_p * (1 - c ), y_p * s ],
2545
2544
[y_p * x_p * (1 - c ), c + (y_p ** 2 ) * (1 - c ), - x_p * s ],
2546
2545
[- y_p * s , x_p * s , np .full_like (x_p , c )]])
2547
- Rpos = Rpos .transpose (2 , 0 , 1 )
2548
-
2549
2546
# opposite rotation negates all the sin terms
2550
2547
Rneg = Rpos .copy ()
2551
- Rneg [:, [0 , 1 , 2 , 2 ], [2 , 2 , 0 , 1 ]] = \
2552
- - Rneg [:, [0 , 1 , 2 , 2 ], [2 , 2 , 0 , 1 ]]
2553
-
2554
- # expand dimensions for batched matrix multiplication
2555
- UVW = np .expand_dims (UVW , axis = - 1 )
2556
-
2557
- # multiply them to get the rotated vector
2558
- Rpos_vecs = np .matmul (Rpos , UVW )
2559
- Rneg_vecs = np .matmul (Rneg , UVW )
2560
-
2561
- # transpose for concatenation
2562
- Rpos_vecs = Rpos_vecs .transpose (0 , 2 , 1 )
2563
- Rneg_vecs = Rneg_vecs .transpose (0 , 2 , 1 )
2564
-
2565
- head_dirs = np .concatenate ([Rpos_vecs , Rneg_vecs ], axis = 1 )
2566
-
2548
+ Rneg [[0 , 1 , 2 , 2 ], [2 , 2 , 0 , 1 ]] *= - 1
2549
+ # Batch n (3, 3) x (3) matrix multiplications ((3, 3, n) x (n, 3)).
2550
+ Rpos_vecs = np .einsum ("ij...,...j->...i" , Rpos , UVW )
2551
+ Rneg_vecs = np .einsum ("ij...,...j->...i" , Rneg , UVW )
2552
+ # Stack into (n, 2, 3) result.
2553
+ head_dirs = np .stack ([Rpos_vecs , Rneg_vecs ], axis = 1 )
2567
2554
return head_dirs
2568
2555
2569
2556
had_data = self .has_data ()
@@ -2630,7 +2617,7 @@ def calc_arrows(UVW, angle=15):
2630
2617
# compute all head lines at once, starting from the shaft ends
2631
2618
heads = shafts [:, :1 ] - np .multiply .outer (arrow_dt , head_dirs )
2632
2619
# stack left and right head lines together
2633
- heads . shape = ( len (arrow_dt ), - 1 , 3 )
2620
+ heads = heads . reshape (( len (arrow_dt ), - 1 , 3 ) )
2634
2621
# transpose to get a list of lines
2635
2622
heads = heads .swapaxes (0 , 1 )
2636
2623
0 commit comments