@@ -2468,46 +2468,32 @@ def quiver(self, *args, **kwargs):
24682468 :class:`~matplotlib.collections.LineCollection`
24692469
24702470 """
2471- def calc_arrow (u , v , w , angle = 15 ):
2471+ def calc_arrow (uvw , angle = 15 ):
24722472 """
2473- To calculate the arrow head. (u, v, w) should be unit vector.
2473+ To calculate the arrow head. uvw should be a unit vector.
24742474 """
2475-
2476- # this part figures out the axis of rotation to use
2477-
2478- # use unit vector perpendicular to (u,v,w) when |w|=1, by default
2479- x , y , z = 0 , 1 , 0
2480-
2481- # get the norm
2482- norm = math .sqrt (v ** 2 + u ** 2 )
2483- # normalize it if it is safe
2475+ # get unit direction vector perpendicular to (u,v,w)
2476+ norm = np .linalg .norm (uvw [:2 ])
24842477 if norm > 0 :
2485- # get unit direction vector perpendicular to (u,v,w)
2486- x , y = v / norm , - u / norm
2487-
2488- # this function takes an angle, and rotates the (u,v,w)
2489- # angle degrees around (x,y,z)
2490- def rotatefunction (angle ):
2491- ra = math .radians (angle )
2492- c = math .cos (ra )
2493- s = math .sin (ra )
2494-
2495- # construct the rotation matrix
2496- R = np .matrix ([[c + (x ** 2 )* (1 - c ), x * y * (1 - c )- z * s , x * z * (1 - c )+ y * s ],
2497- [y * x * (1 - c )+ z * s , c + (y ** 2 )* (1 - c ), y * z * (1 - c )- x * s ],
2498- [z * x * (1 - c )- y * s , z * y * (1 - c )+ x * s , c + (z ** 2 )* (1 - c )]])
2499-
2500- # construct the column vector for (u,v,w)
2501- line = np .matrix ([[u ],[v ],[w ]])
2478+ x = uvw [1 ] / norm
2479+ y = - uvw [0 ] / norm
2480+ else :
2481+ x , y = 0 , 1
25022482
2503- # use numpy to multiply them to get the rotated vector
2504- rotatedline = R * line
2483+ # compute the two arrowhead direction unit vectors
2484+ ra = math .radians (angle )
2485+ c = math .cos (ra )
2486+ s = math .sin (ra )
25052487
2506- # return the rotated (u,v,w) from the computed matrix
2507- return (rotatedline [0 ,0 ], rotatedline [1 ,0 ], rotatedline [2 ,0 ])
2488+ # construct the rotation matrices
2489+ Rpos = np .array ([[c + (x ** 2 )* (1 - c ), x * y * (1 - c ), y * s ],
2490+ [y * x * (1 - c ), c + (y ** 2 )* (1 - c ), - x * s ],
2491+ [- y * s , x * s , c ]])
2492+ # opposite rotation negates everything but the diagonal
2493+ Rneg = Rpos * (np .eye (3 )* 2 - 1 )
25082494
2509- # compute and return the two arrowhead direction unit vectors
2510- return rotatefunction ( angle ), rotatefunction ( - angle )
2495+ # multiply them to get the rotated vector
2496+ return Rpos . dot ( uvw ), Rneg . dot ( uvw )
25112497
25122498 had_data = self .has_data ()
25132499
@@ -2558,9 +2544,6 @@ def rotatefunction(angle):
25582544 # must all in same shape
25592545 assert len (set ([k .shape for k in input_args ])) == 1
25602546
2561- xs , ys , zs , us , vs , ws = input_args [:argi ]
2562- lines = []
2563-
25642547 # TODO: num should probably get parameterized
25652548 shaft_dt = np .linspace (0 , length , num = 20 )
25662549 arrow_dt = shaft_dt * arrow_length_ratio
@@ -2572,34 +2555,38 @@ def rotatefunction(angle):
25722555 elif pivot != 'tip' :
25732556 raise ValueError ('Invalid pivot argument: ' + str (pivot ))
25742557
2575- # for each arrow
2576- for i in range (xs .shape [0 ]):
2577- # calulate body
2578- xyz = np .array ([xs [i ], ys [i ], zs [i ]])
2579- uvw = np .array ([us [i ], vs [i ], ws [i ]], dtype = np .float )
2580-
2581- # (u,v,w) expected to be normalized, recursive to fix A=0 scenario.
2582- if np .all (uvw == 0 ):
2583- # Just don't make a quiver for such a case.
2584- continue
2585-
2586- # normalize
2587- uvw /= np .linalg .norm (uvw )
2588-
2589- # draw main line
2590- shaft = xyz - np .outer (shaft_dt , uvw )
2591-
2592- # draw arrow head
2593- d1 , d2 = calc_arrow (* uvw )
2594- arrow1 = shaft [0 ] - np .outer (arrow_dt , d1 )
2595- arrow2 = shaft [0 ] - np .outer (arrow_dt , d2 )
2596-
2597- lines .extend ([shaft , arrow1 , arrow2 ])
2558+ XYZ = np .column_stack (input_args [:3 ])
2559+ UVW = np .column_stack (input_args [3 :argi ]).astype (float )
2560+
2561+ # Normalize rows of UVW
2562+ # Note: with numpy 1.9+, could use np.linalg.norm(UVW, axis=1)
2563+ norm = np .sqrt (np .sum (UVW ** 2 , axis = 1 ))
2564+
2565+ # If any row of UVW is all zeros, don't make a quiver for it
2566+ mask = norm > 1e-10
2567+ XYZ = XYZ [mask ]
2568+ UVW = UVW [mask ] / norm [mask , np .newaxis ]
2569+
2570+ if len (XYZ ) > 0 :
2571+ # compute the shaft lines all at once with an outer product
2572+ shafts = (XYZ - np .multiply .outer (shaft_dt , UVW )).swapaxes (0 ,1 )
2573+ # compute head direction vectors, n heads by 2 sides by 3 dimensions
2574+ head_dirs = np .array ([calc_arrow (d ) for d in UVW ])
2575+ # compute all head lines at once, starting from where the shaft ends
2576+ heads = shafts [:,:1 ] - np .multiply .outer (arrow_dt , head_dirs )
2577+ # stack left and right head lines together
2578+ heads .shape = (len (arrow_dt ), - 1 , 3 )
2579+ # transpose to get a list of lines
2580+ heads = heads .swapaxes (0 ,1 )
2581+
2582+ lines = list (shafts ) + list (heads )
2583+ else :
2584+ lines = []
25982585
25992586 linec = art3d .Line3DCollection (lines , * args [argi :], ** kwargs )
26002587 self .add_collection (linec )
26012588
2602- self .auto_scale_xyz (xs , ys , zs , had_data )
2589+ self .auto_scale_xyz (XYZ [:, 0 ], XYZ [:, 1 ], XYZ [:, 2 ] , had_data )
26032590
26042591 return linec
26052592
0 commit comments