3434from . import art3d
3535from . import proj3d
3636from . import axis3d
37- from mpl_toolkits .mplot3d .art3d import Line3DCollection
38-
3937
4038def unit_bbox ():
4139 box = Bbox (np .array ([[0 , 0 ], [1 , 1 ]]))
@@ -2447,39 +2445,51 @@ def quiver(self, *args, **kwargs):
24472445 """
24482446 def calc_arrow (u , v , w , angle = 15 ):
24492447 """
2450- To calculate the arrow head.
2451-
2452- (u,v,w) expected to be normalized, recursive to fix A=0 scenario.
2448+ To calculate the arrow head vectors
24532449 """
2454- if u == 0 and v == 0 and w == 0 :
2455- raise ValueError ("u,v,w can't be all zero" )
2456-
2457- # angle should never greater than 20
2458- # safeguard from stack overflow
2450+ # should finish on first recursion
24592451 assert angle <= 20
24602452
2461- t = math .cos (math .radians (angle ))
2453+ norm_uvw = np .linalg .norm ([u , v , w ], ord = 2 )
2454+ if norm_uvw == 0 :
2455+ raise ValueError ("u,v,w can't be all zero" )
2456+ u , v , w = np .array ([u , v , w ]) / norm_uvw
24622457
2463- #D = 1 - w ** 2
2464- #A = t ** 2 - w ** 2
2465- #B = t**2 * D * (1-t**2)
2466- #C = D * w
2467-
2468- t2 = t ** 2
2469- w2 = w ** 2
2470- D = 1 - w2
2458+ cos_angle = math .cos (math .radians (angle ))
24712459
2472- A = t2 - w2
2473- B = t2 * D * (1 - t2 )
2474- C = D * w
2460+ # using following formula to get the two solutions
2461+ # (C + math.sqrt(B)) / A
2462+ # (C - math.sqrt(B)) / A
2463+ A = cos_angle ** 2 - w ** 2
2464+ B = cos_angle ** 2 * (1 - w ** 2 ) * (1 - cos_angle ** 2 )
2465+ C = (1 - w ** 2 ) * w
24752466
24762467 if A == 0 :
2477- x1 , x2 = calc_arrow (u , v , w , angle = angle + 0.1 )
2468+ # to prevent the case of dividing by zero
2469+ # use a small enough delta such that
2470+ # A != 0 for one of
2471+ # angle
2472+ # or
2473+ # angle + delta
2474+ return calc_arrow (u , v , w , angle = angle + 0.1 )
24782475 else :
24792476 x1 = (C + math .sqrt (B )) / A
24802477 x2 = (C - math .sqrt (B )) / A
24812478
2482- return x1 , x2
2479+ # normalize the output vectors and convert to ndarray
2480+ norm_func = lambda k : np .array ([u , v , k ]) / np .linalg .norm ([u , v , k ], ord = 2 )
2481+
2482+ return norm_func (x1 ), norm_func (x2 )
2483+
2484+ def point_vector_to_line (point , vector , length ):
2485+ """
2486+ use a point and vector to generate lines
2487+ """
2488+ lines = []
2489+ for var in np .linspace (0 , length , num = 20 ):
2490+ lines .append (list (zip (* (point - var * vector ))))
2491+ lines = np .array (lines ).swapaxes (0 , 1 )
2492+ return lines .tolist ()
24832493
24842494 had_data = self .has_data ()
24852495
@@ -2488,92 +2498,60 @@ def calc_arrow(u, v, w, angle=15):
24882498 length = kwargs .pop ('length' , 1 )
24892499 # arrow length ratio to the shaft length
24902500 arrow_length_ratio = kwargs .pop ('arrow_length_ratio' , 0.3 )
2491- # zdir
2492- zdir = kwargs .pop ('zdir' , 'z' )
24932501
24942502 # handle args
2495- if len (args ) != 6 :
2503+ if len (args ) < 6 :
24962504 ValueError ('Wrong number of arguments' )
2497- # X, Y, Z, U, V, W
2498- coords = list (map (lambda k : np .array (k ) if not isinstance (k , np .ndarray ) else k , args ))
2505+ argi = 6
2506+ # first 6 arguments are X, Y, Z, U, V, W
2507+ input_args = args [:argi ]
2508+ # if any of the args are scalar, convert into list
2509+ input_args = [[k ] if isinstance (k , (int , float )) else k for k in input_args ]
2510+ # extract the masks, if any
2511+ masks = [k .mask for k in input_args if isinstance (k , np .ma .MaskedArray )]
2512+ # broadcast to match the shape
2513+ bcast = np .broadcast_arrays (* (input_args + masks ))
2514+ input_args = bcast [:argi ]
2515+ masks = bcast [argi :]
2516+ if masks :
2517+ # combine the masks into one
2518+ mask = reduce (lambda k , k1 : k or k1 , masks )
2519+ # put mask on and compress
2520+ input_args = [np .ma .array (k , mask = mask ).compressed () for k in input_args ]
2521+ else :
2522+ input_args = [k .flatten () for k in input_args ]
24992523
2500- shapes = set ([k .shape for k in coords ])
2501- if len (shapes ) != 1 :
2502- raise ValueError ("unmatched input array shape" )
2503- elif list (shapes )[0 ] == 0 :
2504- raise ValueError ("input are constants" )
2524+ points = input_args [:3 ]
2525+ vectors = input_args [3 :]
25052526
2506- # Below assertion must be true as a safe guard
2507- assert all ([isinstance (k , np .ndarray ) for k in coords ])
2527+ # Below assertions must be true before proceed
2528+ # must all be ndarray
2529+ assert all ([isinstance (k , np .ndarray ) for k in input_args ])
2530+ # must all in same shape
2531+ assert len (set ([k .shape for k in input_args ])) == 1
25082532
2509- coords = [k .flatten () for k in coords ]
2510- xs , ys , zs , us , vs , ws = coords
2511- lines = []
2533+ # get arrow vectors
2534+ arrow_vectors = list (map (calc_arrow , * [np .nditer (k ) for k in vectors ]))
2535+ # reshape into set of vectors
2536+ arrow_vectors = np .array (arrow_vectors ).reshape ((- 1 , 3 )).swapaxes (0 , 1 )
25122537
2513- # for each arrow
2514- for i in xrange (xs .shape [0 ]):
2515- # calculate body
2516- x = xs [i ]
2517- y = ys [i ]
2518- z = zs [i ]
2519- u = us [i ]
2520- v = vs [i ]
2521- w = ws [i ]
2522- if any ([k is np .ma .masked for k in [x , y , z , u , v , w ]]):
2523- continue
2524-
2525- # normalize
2526- norm = math .sqrt (u ** 2 + v ** 2 + w ** 2 )
2527- if norm == 0 :
2528- norm = 1
2529- u /= norm
2530- v /= norm
2531- w /= norm
2532-
2533- t = np .linspace (0 , length , num = 20 )
2534- lx = x - t * u
2535- ly = y - t * v
2536- lz = z - t * w
2537- line = list (zip (lx , ly , lz ))
2538- lines .append (line )
2539-
2540- # arrow one side
2541- ua1 = u
2542- va1 = v
2543- wa1 , wa2 = calc_arrow (u , v , w )
2544-
2545- # normalize arrowhead 1
2546- norm = math .sqrt (ua1 ** 2 + va1 ** 2 + wa1 ** 2 )
2547- if norm == 0 :
2548- norm = 1
2549- ua1_ = ua1 / norm
2550- va1_ = va1 / norm
2551- wa1_ = wa1 / norm
2552-
2553- # normalize arrowhead 2
2554- norm = math .sqrt (ua1 ** 2 + va1 ** 2 + wa2 ** 2 )
2555- if norm == 0 :
2556- norm = 1
2557- ua2_ = ua1 / norm
2558- va2_ = va1 / norm
2559- wa2_ = wa2 / norm
2560-
2561- t = np .linspace (0 , length * arrow_length_ratio , num = 20 )
2562- la1x = x - t * ua1_
2563- la1y = y - t * va1_
2564- la1z = z - t * wa1_
2565- la2x = x - t * ua2_
2566- la2y = y - t * va2_
2567- la2z = z - t * wa2_
2568-
2569- line = list (zip (la1x , la1y , la1z ))
2570- lines .append (line )
2571- line = list (zip (la2x , la2y , la2z ))
2572- lines .append (line )
2573-
2574- linec = Line3DCollection (lines , * args [6 :], ** kwargs )
2538+ lines = []
2539+ # construct the main lines
2540+ lines .extend (point_vector_to_line (np .array (points ),
2541+ np .array (vectors ),
2542+ length ))
2543+ # construct the arrow heads
2544+ lines .extend (point_vector_to_line (np .array (points ).repeat (2 , axis = 1 ),
2545+ arrow_vectors ,
2546+ length * arrow_length_ratio ))
2547+
2548+ # Line3DCollection to do the heavy lifting
2549+ linec = art3d .Line3DCollection (lines , * args [argi :], ** kwargs )
25752550 self .add_collection (linec )
25762551
2552+ # scale
2553+
2554+ xs , ys , zs = list (zip (* np .array (lines ).reshape (- 1 , 3 )))
25772555 self .auto_scale_xyz (xs , ys , zs , had_data )
25782556
25792557 return linec
0 commit comments