Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit c4ad02b

Browse files
committed
Simplifications to quiver3d.
- No need to convert input_args to lists: broadcast_arrays will handle the broadcasting just fine on scalars, and later using np.ravel() will ensure that the result is at least 1D. - The assertions don't really add much (and are clearly always true based on the code just above). - np.linspace(x, y, 2) is a pretty obfuscated way to write np.array([x, y]).
1 parent da58453 commit c4ad02b

File tree

1 file changed

+3
-12
lines changed

1 file changed

+3
-12
lines changed

lib/mpl_toolkits/mplot3d/axes3d.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2637,15 +2637,12 @@ def calc_arrow(uvw, angle=15):
26372637

26382638
# first 6 arguments are X, Y, Z, U, V, W
26392639
input_args = args[:argi]
2640-
# if any of the args are scalar, convert into list
2641-
input_args = [[k] if isinstance(k, (int, float)) else k
2642-
for k in input_args]
26432640

26442641
# extract the masks, if any
26452642
masks = [k.mask for k in input_args
26462643
if isinstance(k, np.ma.MaskedArray)]
26472644
# broadcast to match the shape
2648-
bcast = np.broadcast_arrays(*(input_args + masks))
2645+
bcast = np.broadcast_arrays(*input_args, *masks)
26492646
input_args = bcast[:argi]
26502647
masks = bcast[argi:]
26512648
if masks:
@@ -2655,21 +2652,15 @@ def calc_arrow(uvw, angle=15):
26552652
input_args = [np.ma.array(k, mask=mask).compressed()
26562653
for k in input_args]
26572654
else:
2658-
input_args = [k.flatten() for k in input_args]
2655+
input_args = [np.ravel(k) for k in input_args]
26592656

26602657
if any(len(v) == 0 for v in input_args):
26612658
# No quivers, so just make an empty collection and return early
26622659
linec = art3d.Line3DCollection([], *args[argi:], **kwargs)
26632660
self.add_collection(linec)
26642661
return linec
26652662

2666-
# Following assertions must be true before proceeding
2667-
# must all be ndarray
2668-
assert all(isinstance(k, np.ndarray) for k in input_args)
2669-
# must all in same shape
2670-
assert len({k.shape for k in input_args}) == 1
2671-
2672-
shaft_dt = np.linspace(0, length, num=2)
2663+
shaft_dt = np.array([0, length])
26732664
arrow_dt = shaft_dt * arrow_length_ratio
26742665

26752666
if pivot == 'tail':

0 commit comments

Comments
 (0)