@@ -2439,10 +2439,11 @@ def quiver(self, *args, **kwargs):
2439
2439
Arguments:
2440
2440
2441
2441
*X*, *Y*, *Z*:
2442
- The x, y and z coordinates of the arrow locations
2442
+ The x, y and z coordinates of the arrow locations (default is
2443
+ tip of arrow; see *pivot* kwarg)
2443
2444
2444
2445
*U*, *V*, *W*:
2445
- The direction vector that the arrow is pointing
2446
+ The x, y and z components of the arrow vectors
2446
2447
2447
2448
The arguments could be array-like or scalars, so long as they
2448
2449
they can be broadcast together. The arguments can also be
@@ -2459,60 +2460,40 @@ def quiver(self, *args, **kwargs):
2459
2460
The ratio of the arrow head with respect to the quiver,
2460
2461
default to 0.3
2461
2462
2463
+ *pivot*: [ 'tail' | 'middle' | 'tip' ]
2464
+ The part of the arrow that is at the grid point; the arrow
2465
+ rotates about this point, hence the name *pivot*.
2466
+
2462
2467
Any additional keyword arguments are delegated to
2463
2468
:class:`~matplotlib.collections.LineCollection`
2464
2469
2465
2470
"""
2466
- def calc_arrow (u , v , w , angle = 15 ):
2471
+ def calc_arrow (uvw , angle = 15 ):
2467
2472
"""
2468
- To calculate the arrow head. (u, v, w) should be unit vector.
2473
+ To calculate the arrow head. uvw should be a unit vector.
2469
2474
"""
2470
-
2471
- # this part figures out the axis of rotation to use
2472
-
2473
- # use unit vector perpendicular to (u,v,w) when |w|=1, by default
2474
- x , y , z = 0 , 1 , 0
2475
-
2476
- # get the norm
2477
- norm = math .sqrt (v ** 2 + u ** 2 )
2478
- # normalize it if it is safe
2475
+ # get unit direction vector perpendicular to (u,v,w)
2476
+ norm = np .linalg .norm (uvw [:2 ])
2479
2477
if norm > 0 :
2480
- # get unit direction vector perpendicular to (u,v,w)
2481
- x , y = v / norm , - u / norm
2482
-
2483
- # this function takes an angle, and rotates the (u,v,w)
2484
- # angle degrees around (x,y,z)
2485
- def rotatefunction (angle ):
2486
- ra = math .radians (angle )
2487
- c = math .cos (ra )
2488
- s = math .sin (ra )
2489
-
2490
- # construct the rotation matrix
2491
- R = np .matrix ([[c + (x ** 2 )* (1 - c ), x * y * (1 - c )- z * s , x * z * (1 - c )+ y * s ],
2492
- [y * x * (1 - c )+ z * s , c + (y ** 2 )* (1 - c ), y * z * (1 - c )- x * s ],
2493
- [z * x * (1 - c )- y * s , z * y * (1 - c )+ x * s , c + (z ** 2 )* (1 - c )]])
2494
-
2495
- # construct the column vector for (u,v,w)
2496
- line = np .matrix ([[u ],[v ],[w ]])
2497
-
2498
- # use numpy to multiply them to get the rotated vector
2499
- rotatedline = R * line
2478
+ x = uvw [1 ] / norm
2479
+ y = - uvw [0 ] / norm
2480
+ else :
2481
+ x , y = 0 , 1
2500
2482
2501
- # return the rotated (u,v,w) from the computed matrix
2502
- return (rotatedline [0 ,0 ], rotatedline [1 ,0 ], rotatedline [2 ,0 ])
2483
+ # compute the two arrowhead direction unit vectors
2484
+ ra = math .radians (angle )
2485
+ c = math .cos (ra )
2486
+ s = math .sin (ra )
2503
2487
2504
- # compute and return the two arrowhead direction unit vectors
2505
- return rotatefunction (angle ), rotatefunction (- angle )
2488
+ # construct the rotation matrices
2489
+ Rpos = np .array ([[c + (x ** 2 )* (1 - c ), x * y * (1 - c ), y * s ],
2490
+ [y * x * (1 - c ), c + (y ** 2 )* (1 - c ), - x * s ],
2491
+ [- y * s , x * s , c ]])
2492
+ # opposite rotation negates everything but the diagonal
2493
+ Rneg = Rpos * (np .eye (3 )* 2 - 1 )
2506
2494
2507
- def point_vector_to_line (point , vector , length ):
2508
- """
2509
- use a point and vector to generate lines
2510
- """
2511
- lines = []
2512
- for var in np .linspace (0 , length , num = 2 ):
2513
- lines .append (list (zip (* (point - var * vector ))))
2514
- lines = np .array (lines ).swapaxes (0 , 1 )
2515
- return lines .tolist ()
2495
+ # multiply them to get the rotated vector
2496
+ return Rpos .dot (uvw ), Rneg .dot (uvw )
2516
2497
2517
2498
had_data = self .has_data ()
2518
2499
@@ -2521,6 +2502,8 @@ def point_vector_to_line(point, vector, length):
2521
2502
length = kwargs .pop ('length' , 1 )
2522
2503
# arrow length ratio to the shaft length
2523
2504
arrow_length_ratio = kwargs .pop ('arrow_length_ratio' , 0.3 )
2505
+ # pivot point
2506
+ pivot = kwargs .pop ('pivot' , 'tip' )
2524
2507
2525
2508
# handle args
2526
2509
argi = 6
@@ -2561,60 +2544,49 @@ def point_vector_to_line(point, vector, length):
2561
2544
# must all in same shape
2562
2545
assert len (set ([k .shape for k in input_args ])) == 1
2563
2546
2564
- xs , ys , zs , us , vs , ws = input_args [:argi ]
2565
- lines = []
2566
-
2567
- # for each arrow
2568
- for i in range (xs .shape [0 ]):
2569
- # calulate body
2570
- x = xs [i ]
2571
- y = ys [i ]
2572
- z = zs [i ]
2573
- u = us [i ]
2574
- v = vs [i ]
2575
- w = ws [i ]
2576
-
2577
- # (u,v,w) expected to be normalized, recursive to fix A=0 scenario.
2578
- if u == 0 and v == 0 and w == 0 :
2579
- # Just don't make a quiver for such a case.
2580
- continue
2581
-
2582
- # normalize
2583
- norm = math .sqrt (u ** 2 + v ** 2 + w ** 2 )
2584
- u /= norm
2585
- v /= norm
2586
- w /= norm
2587
-
2588
- # draw main line
2589
- t = np .linspace (0 , length , num = 20 )
2590
- lx = x - t * u
2591
- ly = y - t * v
2592
- lz = z - t * w
2593
- line = list (zip (lx , ly , lz ))
2594
- lines .append (line )
2595
-
2596
- d1 , d2 = calc_arrow (u , v , w )
2597
- ua1 , va1 , wa1 = d1 [0 ], d1 [1 ], d1 [2 ]
2598
- ua2 , va2 , wa2 = d2 [0 ], d2 [1 ], d2 [2 ]
2599
-
2600
- # TODO: num should probably get parameterized
2601
- t = np .linspace (0 , length * arrow_length_ratio , num = 20 )
2602
- la1x = x - t * ua1
2603
- la1y = y - t * va1
2604
- la1z = z - t * wa1
2605
- la2x = x - t * ua2
2606
- la2y = y - t * va2
2607
- la2z = z - t * wa2
2608
-
2609
- line = list (zip (la1x , la1y , la1z ))
2610
- lines .append (line )
2611
- line = list (zip (la2x , la2y , la2z ))
2612
- lines .append (line )
2547
+ # TODO: num should probably get parameterized
2548
+ shaft_dt = np .linspace (0 , length , num = 20 )
2549
+ arrow_dt = shaft_dt * arrow_length_ratio
2550
+
2551
+ if pivot == 'tail' :
2552
+ shaft_dt -= length
2553
+ elif pivot == 'middle' :
2554
+ shaft_dt -= length / 2.
2555
+ elif pivot != 'tip' :
2556
+ raise ValueError ('Invalid pivot argument: ' + str (pivot ))
2557
+
2558
+ XYZ = np .column_stack (input_args [:3 ])
2559
+ UVW = np .column_stack (input_args [3 :argi ]).astype (float )
2560
+
2561
+ # Normalize rows of UVW
2562
+ # Note: with numpy 1.9+, could use np.linalg.norm(UVW, axis=1)
2563
+ norm = np .sqrt (np .sum (UVW ** 2 , axis = 1 ))
2564
+
2565
+ # If any row of UVW is all zeros, don't make a quiver for it
2566
+ mask = norm > 1e-10
2567
+ XYZ = XYZ [mask ]
2568
+ UVW = UVW [mask ] / norm [mask ].reshape ((- 1 , 1 ))
2569
+
2570
+ if len (XYZ ) > 0 :
2571
+ # compute the shaft lines all at once with an outer product
2572
+ shafts = (XYZ - np .multiply .outer (shaft_dt , UVW )).swapaxes (0 , 1 )
2573
+ # compute head direction vectors, n heads by 2 sides by 3 dimensions
2574
+ head_dirs = np .array ([calc_arrow (d ) for d in UVW ])
2575
+ # compute all head lines at once, starting from where the shaft ends
2576
+ heads = shafts [:, :1 ] - np .multiply .outer (arrow_dt , head_dirs )
2577
+ # stack left and right head lines together
2578
+ heads .shape = (len (arrow_dt ), - 1 , 3 )
2579
+ # transpose to get a list of lines
2580
+ heads = heads .swapaxes (0 , 1 )
2581
+
2582
+ lines = list (shafts ) + list (heads )
2583
+ else :
2584
+ lines = []
2613
2585
2614
2586
linec = art3d .Line3DCollection (lines , * args [argi :], ** kwargs )
2615
2587
self .add_collection (linec )
2616
2588
2617
- self .auto_scale_xyz (xs , ys , zs , had_data )
2589
+ self .auto_scale_xyz (XYZ [:, 0 ], XYZ [:, 1 ], XYZ [:, 2 ] , had_data )
2618
2590
2619
2591
return linec
2620
2592
0 commit comments