@@ -2439,10 +2439,11 @@ def quiver(self, *args, **kwargs):
24392439 Arguments:
24402440
24412441 *X*, *Y*, *Z*:
2442- The x, y and z coordinates of the arrow locations
2442+ The x, y and z coordinates of the arrow locations (default is
2443+ tip of arrow; see *pivot* kwarg)
24432444
24442445 *U*, *V*, *W*:
2445- The direction vector that the arrow is pointing
2446+ The x, y and z components of the arrow vectors
24462447
24472448 The arguments could be array-like or scalars, so long as they
24482449 they can be broadcast together. The arguments can also be
@@ -2459,60 +2460,40 @@ def quiver(self, *args, **kwargs):
24592460 The ratio of the arrow head with respect to the quiver,
24602461 default to 0.3
24612462
2463+ *pivot*: [ 'tail' | 'middle' | 'tip' ]
2464+ The part of the arrow that is at the grid point; the arrow
2465+ rotates about this point, hence the name *pivot*.
2466+
24622467 Any additional keyword arguments are delegated to
24632468 :class:`~matplotlib.collections.LineCollection`
24642469
24652470 """
2466- def calc_arrow (u , v , w , angle = 15 ):
2471+ def calc_arrow (uvw , angle = 15 ):
24672472 """
2468- To calculate the arrow head. (u, v, w) should be unit vector.
2473+ To calculate the arrow head. uvw should be a unit vector.
24692474 """
2470-
2471- # this part figures out the axis of rotation to use
2472-
2473- # use unit vector perpendicular to (u,v,w) when |w|=1, by default
2474- x , y , z = 0 , 1 , 0
2475-
2476- # get the norm
2477- norm = math .sqrt (v ** 2 + u ** 2 )
2478- # normalize it if it is safe
2475+ # get unit direction vector perpendicular to (u,v,w)
2476+ norm = np .linalg .norm (uvw [:2 ])
24792477 if norm > 0 :
2480- # get unit direction vector perpendicular to (u,v,w)
2481- x , y = v / norm , - u / norm
2482-
2483- # this function takes an angle, and rotates the (u,v,w)
2484- # angle degrees around (x,y,z)
2485- def rotatefunction (angle ):
2486- ra = math .radians (angle )
2487- c = math .cos (ra )
2488- s = math .sin (ra )
2489-
2490- # construct the rotation matrix
2491- R = np .matrix ([[c + (x ** 2 )* (1 - c ), x * y * (1 - c )- z * s , x * z * (1 - c )+ y * s ],
2492- [y * x * (1 - c )+ z * s , c + (y ** 2 )* (1 - c ), y * z * (1 - c )- x * s ],
2493- [z * x * (1 - c )- y * s , z * y * (1 - c )+ x * s , c + (z ** 2 )* (1 - c )]])
2494-
2495- # construct the column vector for (u,v,w)
2496- line = np .matrix ([[u ],[v ],[w ]])
2497-
2498- # use numpy to multiply them to get the rotated vector
2499- rotatedline = R * line
2478+ x = uvw [1 ] / norm
2479+ y = - uvw [0 ] / norm
2480+ else :
2481+ x , y = 0 , 1
25002482
2501- # return the rotated (u,v,w) from the computed matrix
2502- return (rotatedline [0 ,0 ], rotatedline [1 ,0 ], rotatedline [2 ,0 ])
2483+ # compute the two arrowhead direction unit vectors
2484+ ra = math .radians (angle )
2485+ c = math .cos (ra )
2486+ s = math .sin (ra )
25032487
2504- # compute and return the two arrowhead direction unit vectors
2505- return rotatefunction (angle ), rotatefunction (- angle )
2488+ # construct the rotation matrices
2489+ Rpos = np .array ([[c + (x ** 2 )* (1 - c ), x * y * (1 - c ), y * s ],
2490+ [y * x * (1 - c ), c + (y ** 2 )* (1 - c ), - x * s ],
2491+ [- y * s , x * s , c ]])
2492+ # opposite rotation negates everything but the diagonal
2493+ Rneg = Rpos * (np .eye (3 )* 2 - 1 )
25062494
2507- def point_vector_to_line (point , vector , length ):
2508- """
2509- use a point and vector to generate lines
2510- """
2511- lines = []
2512- for var in np .linspace (0 , length , num = 2 ):
2513- lines .append (list (zip (* (point - var * vector ))))
2514- lines = np .array (lines ).swapaxes (0 , 1 )
2515- return lines .tolist ()
2495+ # multiply them to get the rotated vector
2496+ return Rpos .dot (uvw ), Rneg .dot (uvw )
25162497
25172498 had_data = self .has_data ()
25182499
@@ -2521,6 +2502,8 @@ def point_vector_to_line(point, vector, length):
25212502 length = kwargs .pop ('length' , 1 )
25222503 # arrow length ratio to the shaft length
25232504 arrow_length_ratio = kwargs .pop ('arrow_length_ratio' , 0.3 )
2505+ # pivot point
2506+ pivot = kwargs .pop ('pivot' , 'tip' )
25242507
25252508 # handle args
25262509 argi = 6
@@ -2561,60 +2544,49 @@ def point_vector_to_line(point, vector, length):
25612544 # must all in same shape
25622545 assert len (set ([k .shape for k in input_args ])) == 1
25632546
2564- xs , ys , zs , us , vs , ws = input_args [:argi ]
2565- lines = []
2566-
2567- # for each arrow
2568- for i in range (xs .shape [0 ]):
2569- # calulate body
2570- x = xs [i ]
2571- y = ys [i ]
2572- z = zs [i ]
2573- u = us [i ]
2574- v = vs [i ]
2575- w = ws [i ]
2576-
2577- # (u,v,w) expected to be normalized, recursive to fix A=0 scenario.
2578- if u == 0 and v == 0 and w == 0 :
2579- # Just don't make a quiver for such a case.
2580- continue
2581-
2582- # normalize
2583- norm = math .sqrt (u ** 2 + v ** 2 + w ** 2 )
2584- u /= norm
2585- v /= norm
2586- w /= norm
2587-
2588- # draw main line
2589- t = np .linspace (0 , length , num = 20 )
2590- lx = x - t * u
2591- ly = y - t * v
2592- lz = z - t * w
2593- line = list (zip (lx , ly , lz ))
2594- lines .append (line )
2595-
2596- d1 , d2 = calc_arrow (u , v , w )
2597- ua1 , va1 , wa1 = d1 [0 ], d1 [1 ], d1 [2 ]
2598- ua2 , va2 , wa2 = d2 [0 ], d2 [1 ], d2 [2 ]
2599-
2600- # TODO: num should probably get parameterized
2601- t = np .linspace (0 , length * arrow_length_ratio , num = 20 )
2602- la1x = x - t * ua1
2603- la1y = y - t * va1
2604- la1z = z - t * wa1
2605- la2x = x - t * ua2
2606- la2y = y - t * va2
2607- la2z = z - t * wa2
2608-
2609- line = list (zip (la1x , la1y , la1z ))
2610- lines .append (line )
2611- line = list (zip (la2x , la2y , la2z ))
2612- lines .append (line )
2547+ # TODO: num should probably get parameterized
2548+ shaft_dt = np .linspace (0 , length , num = 20 )
2549+ arrow_dt = shaft_dt * arrow_length_ratio
2550+
2551+ if pivot == 'tail' :
2552+ shaft_dt -= length
2553+ elif pivot == 'middle' :
2554+ shaft_dt -= length / 2.
2555+ elif pivot != 'tip' :
2556+ raise ValueError ('Invalid pivot argument: ' + str (pivot ))
2557+
2558+ XYZ = np .column_stack (input_args [:3 ])
2559+ UVW = np .column_stack (input_args [3 :argi ]).astype (float )
2560+
2561+ # Normalize rows of UVW
2562+ # Note: with numpy 1.9+, could use np.linalg.norm(UVW, axis=1)
2563+ norm = np .sqrt (np .sum (UVW ** 2 , axis = 1 ))
2564+
2565+ # If any row of UVW is all zeros, don't make a quiver for it
2566+ mask = norm > 1e-10
2567+ XYZ = XYZ [mask ]
2568+ UVW = UVW [mask ] / norm [mask ].reshape ((- 1 , 1 ))
2569+
2570+ if len (XYZ ) > 0 :
2571+ # compute the shaft lines all at once with an outer product
2572+ shafts = (XYZ - np .multiply .outer (shaft_dt , UVW )).swapaxes (0 , 1 )
2573+ # compute head direction vectors, n heads by 2 sides by 3 dimensions
2574+ head_dirs = np .array ([calc_arrow (d ) for d in UVW ])
2575+ # compute all head lines at once, starting from where the shaft ends
2576+ heads = shafts [:, :1 ] - np .multiply .outer (arrow_dt , head_dirs )
2577+ # stack left and right head lines together
2578+ heads .shape = (len (arrow_dt ), - 1 , 3 )
2579+ # transpose to get a list of lines
2580+ heads = heads .swapaxes (0 , 1 )
2581+
2582+ lines = list (shafts ) + list (heads )
2583+ else :
2584+ lines = []
26132585
26142586 linec = art3d .Line3DCollection (lines , * args [argi :], ** kwargs )
26152587 self .add_collection (linec )
26162588
2617- self .auto_scale_xyz (xs , ys , zs , had_data )
2589+ self .auto_scale_xyz (XYZ [:, 0 ], XYZ [:, 1 ], XYZ [:, 2 ] , had_data )
26182590
26192591 return linec
26202592
0 commit comments