Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 72819e9

Browse files
vectorizing
1 parent 18ba5d8 commit 72819e9

1 file changed

Lines changed: 48 additions & 61 deletions

File tree

lib/mpl_toolkits/mplot3d/axes3d.py

Lines changed: 48 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -2468,46 +2468,32 @@ def quiver(self, *args, **kwargs):
24682468
:class:`~matplotlib.collections.LineCollection`
24692469
24702470
"""
2471-
def calc_arrow(u, v, w, angle=15):
2471+
def calc_arrow(uvw, angle=15):
24722472
"""
2473-
To calculate the arrow head. (u, v, w) should be unit vector.
2473+
To calculate the arrow head. uvw should be a unit vector.
24742474
"""
2475-
2476-
# this part figures out the axis of rotation to use
2477-
2478-
# use unit vector perpendicular to (u,v,w) when |w|=1, by default
2479-
x, y, z = 0, 1, 0
2480-
2481-
# get the norm
2482-
norm = math.sqrt(v**2 + u**2)
2483-
# normalize it if it is safe
2475+
# get unit direction vector perpendicular to (u,v,w)
2476+
norm = np.linalg.norm(uvw[:2])
24842477
if norm > 0:
2485-
# get unit direction vector perpendicular to (u,v,w)
2486-
x, y = v/norm, -u/norm
2487-
2488-
# this function takes an angle, and rotates the (u,v,w)
2489-
# angle degrees around (x,y,z)
2490-
def rotatefunction(angle):
2491-
ra = math.radians(angle)
2492-
c = math.cos(ra)
2493-
s = math.sin(ra)
2494-
2495-
# construct the rotation matrix
2496-
R = np.matrix([[c+(x**2)*(1-c), x*y*(1-c)-z*s, x*z*(1-c)+y*s],
2497-
[y*x*(1-c)+z*s, c+(y**2)*(1-c), y*z*(1-c)-x*s],
2498-
[z*x*(1-c)-y*s, z*y*(1-c)+x*s, c+(z**2)*(1-c)]])
2499-
2500-
# construct the column vector for (u,v,w)
2501-
line = np.matrix([[u],[v],[w]])
2478+
x = uvw[1] / norm
2479+
y = -uvw[0] / norm
2480+
else:
2481+
x, y = 0, 1
25022482

2503-
# use numpy to multiply them to get the rotated vector
2504-
rotatedline = R*line
2483+
# compute the two arrowhead direction unit vectors
2484+
ra = math.radians(angle)
2485+
c = math.cos(ra)
2486+
s = math.sin(ra)
25052487

2506-
# return the rotated (u,v,w) from the computed matrix
2507-
return (rotatedline[0,0], rotatedline[1,0], rotatedline[2,0])
2488+
# construct the rotation matrices
2489+
Rpos = np.array([[c+(x**2)*(1-c), x*y*(1-c), y*s],
2490+
[y*x*(1-c), c+(y**2)*(1-c), -x*s],
2491+
[-y*s, x*s, c]])
2492+
# opposite rotation negates everything but the diagonal
2493+
Rneg = Rpos * (np.eye(3)*2 - 1)
25082494

2509-
# compute and return the two arrowhead direction unit vectors
2510-
return rotatefunction(angle), rotatefunction(-angle)
2495+
# multiply them to get the rotated vector
2496+
return Rpos.dot(uvw), Rneg.dot(uvw)
25112497

25122498
had_data = self.has_data()
25132499

@@ -2558,9 +2544,6 @@ def rotatefunction(angle):
25582544
# must all in same shape
25592545
assert len(set([k.shape for k in input_args])) == 1
25602546

2561-
xs, ys, zs, us, vs, ws = input_args[:argi]
2562-
lines = []
2563-
25642547
# TODO: num should probably get parameterized
25652548
shaft_dt = np.linspace(0, length, num=20)
25662549
arrow_dt = shaft_dt * arrow_length_ratio
@@ -2572,34 +2555,38 @@ def rotatefunction(angle):
25722555
elif pivot != 'tip':
25732556
raise ValueError('Invalid pivot argument: ' + str(pivot))
25742557

2575-
# for each arrow
2576-
for i in range(xs.shape[0]):
2577-
# calulate body
2578-
xyz = np.array([xs[i], ys[i], zs[i]])
2579-
uvw = np.array([us[i], vs[i], ws[i]], dtype=np.float)
2580-
2581-
# (u,v,w) expected to be normalized, recursive to fix A=0 scenario.
2582-
if np.all(uvw==0):
2583-
# Just don't make a quiver for such a case.
2584-
continue
2585-
2586-
# normalize
2587-
uvw /= np.linalg.norm(uvw)
2588-
2589-
# draw main line
2590-
shaft = xyz - np.outer(shaft_dt, uvw)
2591-
2592-
# draw arrow head
2593-
d1, d2 = calc_arrow(*uvw)
2594-
arrow1 = shaft[0] - np.outer(arrow_dt, d1)
2595-
arrow2 = shaft[0] - np.outer(arrow_dt, d2)
2596-
2597-
lines.extend([shaft, arrow1, arrow2])
2558+
XYZ = np.column_stack(input_args[:3])
2559+
UVW = np.column_stack(input_args[3:argi]).astype(float)
2560+
2561+
# Normalize rows of UVW
2562+
# Note: with numpy 1.9+, could use np.linalg.norm(UVW, axis=1)
2563+
norm = np.sqrt(np.sum(UVW**2, axis=1))
2564+
2565+
# If any row of UVW is all zeros, don't make a quiver for it
2566+
mask = norm > 1e-10
2567+
XYZ = XYZ[mask]
2568+
UVW = UVW[mask] / norm[mask, np.newaxis]
2569+
2570+
if len(XYZ) > 0:
2571+
# compute the shaft lines all at once with an outer product
2572+
shafts = (XYZ - np.multiply.outer(shaft_dt, UVW)).swapaxes(0,1)
2573+
# compute head direction vectors, n heads by 2 sides by 3 dimensions
2574+
head_dirs = np.array([calc_arrow(d) for d in UVW])
2575+
# compute all head lines at once, starting from where the shaft ends
2576+
heads = shafts[:,:1] - np.multiply.outer(arrow_dt, head_dirs)
2577+
# stack left and right head lines together
2578+
heads.shape = (len(arrow_dt), -1, 3)
2579+
# transpose to get a list of lines
2580+
heads = heads.swapaxes(0,1)
2581+
2582+
lines = list(shafts) + list(heads)
2583+
else:
2584+
lines = []
25982585

25992586
linec = art3d.Line3DCollection(lines, *args[argi:], **kwargs)
26002587
self.add_collection(linec)
26012588

2602-
self.auto_scale_xyz(xs, ys, zs, had_data)
2589+
self.auto_scale_xyz(XYZ[:,0], XYZ[:,1], XYZ[:,2], had_data)
26032590

26042591
return linec
26052592

0 commit comments

Comments
 (0)