@@ -2636,15 +2636,12 @@ def calc_arrow(uvw, angle=15):
2636
2636
2637
2637
# first 6 arguments are X, Y, Z, U, V, W
2638
2638
input_args = args [:argi ]
2639
- # if any of the args are scalar, convert into list
2640
- input_args = [[k ] if isinstance (k , (int , float )) else k
2641
- for k in input_args ]
2642
2639
2643
2640
# extract the masks, if any
2644
2641
masks = [k .mask for k in input_args
2645
2642
if isinstance (k , np .ma .MaskedArray )]
2646
2643
# broadcast to match the shape
2647
- bcast = np .broadcast_arrays (* ( input_args + masks ) )
2644
+ bcast = np .broadcast_arrays (* input_args , * masks )
2648
2645
input_args = bcast [:argi ]
2649
2646
masks = bcast [argi :]
2650
2647
if masks :
@@ -2654,21 +2651,15 @@ def calc_arrow(uvw, angle=15):
2654
2651
input_args = [np .ma .array (k , mask = mask ).compressed ()
2655
2652
for k in input_args ]
2656
2653
else :
2657
- input_args = [k . flatten ( ) for k in input_args ]
2654
+ input_args = [np . ravel ( k ) for k in input_args ]
2658
2655
2659
2656
if any (len (v ) == 0 for v in input_args ):
2660
2657
# No quivers, so just make an empty collection and return early
2661
2658
linec = art3d .Line3DCollection ([], * args [argi :], ** kwargs )
2662
2659
self .add_collection (linec )
2663
2660
return linec
2664
2661
2665
- # Following assertions must be true before proceeding
2666
- # must all be ndarray
2667
- assert all (isinstance (k , np .ndarray ) for k in input_args )
2668
- # must all in same shape
2669
- assert len ({k .shape for k in input_args }) == 1
2670
-
2671
- shaft_dt = np .linspace (0 , length , num = 2 )
2662
+ shaft_dt = np .array ([0 , length ])
2672
2663
arrow_dt = shaft_dt * arrow_length_ratio
2673
2664
2674
2665
if pivot == 'tail' :
0 commit comments