Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 6e36226

Browse files
committed
Merge pull request #3735 from perimosocordiae/patch-1
ENH: add pivot kwarg to 3d quiver plot
2 parents a8e028d + b0902ec commit 6e36226

File tree

5 files changed

+102
-96
lines changed

5 files changed

+102
-96
lines changed

doc/users/whats_new.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,11 @@ Support for legend for PolyCollection and stackplot
4848
Added a `legend_handler` for :class:`~matplotlib.collections.PolyCollection` as well as a `labels` argument to
4949
:func:`~matplotlib.axes.Axes.stackplot`.
5050

51+
Support for alternate pivots in mplot3d quiver plot
52+
``````````````````````````````````````````````````
53+
Added a :code:`pivot` kwarg to :func:`~mpl_toolkits.mplot3d.Axes3D.quiver`
54+
that controls the pivot point around which the quiver line rotates. This also
55+
determines the placement of the arrow head along the quiver line.
5156

5257
.. _whats-new-1-4:
5358

lib/mpl_toolkits/mplot3d/axes3d.py

Lines changed: 68 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -2439,10 +2439,11 @@ def quiver(self, *args, **kwargs):
24392439
Arguments:
24402440
24412441
*X*, *Y*, *Z*:
2442-
The x, y and z coordinates of the arrow locations
2442+
The x, y and z coordinates of the arrow locations (default is
2443+
tip of arrow; see *pivot* kwarg)
24432444
24442445
*U*, *V*, *W*:
2445-
The direction vector that the arrow is pointing
2446+
The x, y and z components of the arrow vectors
24462447
24472448
The arguments could be array-like or scalars, so long as they
24482449
they can be broadcast together. The arguments can also be
@@ -2459,60 +2460,40 @@ def quiver(self, *args, **kwargs):
24592460
The ratio of the arrow head with respect to the quiver,
24602461
default to 0.3
24612462
2463+
*pivot*: [ 'tail' | 'middle' | 'tip' ]
2464+
The part of the arrow that is at the grid point; the arrow
2465+
rotates about this point, hence the name *pivot*.
2466+
24622467
Any additional keyword arguments are delegated to
24632468
:class:`~matplotlib.collections.LineCollection`
24642469
24652470
"""
2466-
def calc_arrow(u, v, w, angle=15):
2471+
def calc_arrow(uvw, angle=15):
24672472
"""
2468-
To calculate the arrow head. (u, v, w) should be unit vector.
2473+
To calculate the arrow head. uvw should be a unit vector.
24692474
"""
2470-
2471-
# this part figures out the axis of rotation to use
2472-
2473-
# use unit vector perpendicular to (u,v,w) when |w|=1, by default
2474-
x, y, z = 0, 1, 0
2475-
2476-
# get the norm
2477-
norm = math.sqrt(v**2 + u**2)
2478-
# normalize it if it is safe
2475+
# get unit direction vector perpendicular to (u,v,w)
2476+
norm = np.linalg.norm(uvw[:2])
24792477
if norm > 0:
2480-
# get unit direction vector perpendicular to (u,v,w)
2481-
x, y = v/norm, -u/norm
2482-
2483-
# this function takes an angle, and rotates the (u,v,w)
2484-
# angle degrees around (x,y,z)
2485-
def rotatefunction(angle):
2486-
ra = math.radians(angle)
2487-
c = math.cos(ra)
2488-
s = math.sin(ra)
2489-
2490-
# construct the rotation matrix
2491-
R = np.matrix([[c+(x**2)*(1-c), x*y*(1-c)-z*s, x*z*(1-c)+y*s],
2492-
[y*x*(1-c)+z*s, c+(y**2)*(1-c), y*z*(1-c)-x*s],
2493-
[z*x*(1-c)-y*s, z*y*(1-c)+x*s, c+(z**2)*(1-c)]])
2494-
2495-
# construct the column vector for (u,v,w)
2496-
line = np.matrix([[u],[v],[w]])
2497-
2498-
# use numpy to multiply them to get the rotated vector
2499-
rotatedline = R*line
2478+
x = uvw[1] / norm
2479+
y = -uvw[0] / norm
2480+
else:
2481+
x, y = 0, 1
25002482

2501-
# return the rotated (u,v,w) from the computed matrix
2502-
return (rotatedline[0,0], rotatedline[1,0], rotatedline[2,0])
2483+
# compute the two arrowhead direction unit vectors
2484+
ra = math.radians(angle)
2485+
c = math.cos(ra)
2486+
s = math.sin(ra)
25032487

2504-
# compute and return the two arrowhead direction unit vectors
2505-
return rotatefunction(angle), rotatefunction(-angle)
2488+
# construct the rotation matrices
2489+
Rpos = np.array([[c+(x**2)*(1-c), x*y*(1-c), y*s],
2490+
[y*x*(1-c), c+(y**2)*(1-c), -x*s],
2491+
[-y*s, x*s, c]])
2492+
# opposite rotation negates everything but the diagonal
2493+
Rneg = Rpos * (np.eye(3)*2 - 1)
25062494

2507-
def point_vector_to_line(point, vector, length):
2508-
"""
2509-
use a point and vector to generate lines
2510-
"""
2511-
lines = []
2512-
for var in np.linspace(0, length, num=2):
2513-
lines.append(list(zip(*(point - var * vector))))
2514-
lines = np.array(lines).swapaxes(0, 1)
2515-
return lines.tolist()
2495+
# multiply them to get the rotated vector
2496+
return Rpos.dot(uvw), Rneg.dot(uvw)
25162497

25172498
had_data = self.has_data()
25182499

@@ -2521,6 +2502,8 @@ def point_vector_to_line(point, vector, length):
25212502
length = kwargs.pop('length', 1)
25222503
# arrow length ratio to the shaft length
25232504
arrow_length_ratio = kwargs.pop('arrow_length_ratio', 0.3)
2505+
# pivot point
2506+
pivot = kwargs.pop('pivot', 'tip')
25242507

25252508
# handle args
25262509
argi = 6
@@ -2561,60 +2544,49 @@ def point_vector_to_line(point, vector, length):
25612544
# must all in same shape
25622545
assert len(set([k.shape for k in input_args])) == 1
25632546

2564-
xs, ys, zs, us, vs, ws = input_args[:argi]
2565-
lines = []
2566-
2567-
# for each arrow
2568-
for i in range(xs.shape[0]):
2569-
# calulate body
2570-
x = xs[i]
2571-
y = ys[i]
2572-
z = zs[i]
2573-
u = us[i]
2574-
v = vs[i]
2575-
w = ws[i]
2576-
2577-
# (u,v,w) expected to be normalized, recursive to fix A=0 scenario.
2578-
if u == 0 and v == 0 and w == 0:
2579-
# Just don't make a quiver for such a case.
2580-
continue
2581-
2582-
# normalize
2583-
norm = math.sqrt(u ** 2 + v ** 2 + w ** 2)
2584-
u /= norm
2585-
v /= norm
2586-
w /= norm
2587-
2588-
# draw main line
2589-
t = np.linspace(0, length, num=20)
2590-
lx = x - t * u
2591-
ly = y - t * v
2592-
lz = z - t * w
2593-
line = list(zip(lx, ly, lz))
2594-
lines.append(line)
2595-
2596-
d1, d2 = calc_arrow(u, v, w)
2597-
ua1, va1, wa1 = d1[0], d1[1], d1[2]
2598-
ua2, va2, wa2 = d2[0], d2[1], d2[2]
2599-
2600-
# TODO: num should probably get parameterized
2601-
t = np.linspace(0, length * arrow_length_ratio, num=20)
2602-
la1x = x - t * ua1
2603-
la1y = y - t * va1
2604-
la1z = z - t * wa1
2605-
la2x = x - t * ua2
2606-
la2y = y - t * va2
2607-
la2z = z - t * wa2
2608-
2609-
line = list(zip(la1x, la1y, la1z))
2610-
lines.append(line)
2611-
line = list(zip(la2x, la2y, la2z))
2612-
lines.append(line)
2547+
# TODO: num should probably get parameterized
2548+
shaft_dt = np.linspace(0, length, num=20)
2549+
arrow_dt = shaft_dt * arrow_length_ratio
2550+
2551+
if pivot == 'tail':
2552+
shaft_dt -= length
2553+
elif pivot == 'middle':
2554+
shaft_dt -= length/2.
2555+
elif pivot != 'tip':
2556+
raise ValueError('Invalid pivot argument: ' + str(pivot))
2557+
2558+
XYZ = np.column_stack(input_args[:3])
2559+
UVW = np.column_stack(input_args[3:argi]).astype(float)
2560+
2561+
# Normalize rows of UVW
2562+
# Note: with numpy 1.9+, could use np.linalg.norm(UVW, axis=1)
2563+
norm = np.sqrt(np.sum(UVW**2, axis=1))
2564+
2565+
# If any row of UVW is all zeros, don't make a quiver for it
2566+
mask = norm > 1e-10
2567+
XYZ = XYZ[mask]
2568+
UVW = UVW[mask] / norm[mask].reshape((-1, 1))
2569+
2570+
if len(XYZ) > 0:
2571+
# compute the shaft lines all at once with an outer product
2572+
shafts = (XYZ - np.multiply.outer(shaft_dt, UVW)).swapaxes(0, 1)
2573+
# compute head direction vectors, n heads by 2 sides by 3 dimensions
2574+
head_dirs = np.array([calc_arrow(d) for d in UVW])
2575+
# compute all head lines at once, starting from where the shaft ends
2576+
heads = shafts[:, :1] - np.multiply.outer(arrow_dt, head_dirs)
2577+
# stack left and right head lines together
2578+
heads.shape = (len(arrow_dt), -1, 3)
2579+
# transpose to get a list of lines
2580+
heads = heads.swapaxes(0, 1)
2581+
2582+
lines = list(shafts) + list(heads)
2583+
else:
2584+
lines = []
26132585

26142586
linec = art3d.Line3DCollection(lines, *args[argi:], **kwargs)
26152587
self.add_collection(linec)
26162588

2617-
self.auto_scale_xyz(xs, ys, zs, had_data)
2589+
self.auto_scale_xyz(XYZ[:, 0], XYZ[:, 1], XYZ[:, 2], had_data)
26182590

26192591
return linec
26202592

lib/mpl_toolkits/tests/test_mplot3d.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,35 @@ def test_quiver3d_masked():
203203

204204
ax.quiver(x, y, z, u, v, w, length=0.1)
205205

206+
@image_comparison(baseline_images=['quiver3d_pivot_middle'], remove_text=True,
207+
extensions=['png'])
208+
def test_quiver3d_pivot_middle():
209+
fig = plt.figure()
210+
ax = fig.gca(projection='3d')
211+
212+
x, y, z = np.ogrid[-1:0.8:10j, -1:0.8:10j, -1:0.6:3j]
213+
214+
u = np.sin(np.pi * x) * np.cos(np.pi * y) * np.cos(np.pi * z)
215+
v = -np.cos(np.pi * x) * np.sin(np.pi * y) * np.cos(np.pi * z)
216+
w = (np.sqrt(2.0 / 3.0) * np.cos(np.pi * x) * np.cos(np.pi * y) *
217+
np.sin(np.pi * z))
218+
219+
ax.quiver(x, y, z, u, v, w, length=0.1, pivot='middle')
220+
221+
@image_comparison(baseline_images=['quiver3d_pivot_tail'], remove_text=True,
222+
extensions=['png'])
223+
def test_quiver3d_pivot_tail():
224+
fig = plt.figure()
225+
ax = fig.gca(projection='3d')
226+
227+
x, y, z = np.ogrid[-1:0.8:10j, -1:0.8:10j, -1:0.6:3j]
228+
229+
u = np.sin(np.pi * x) * np.cos(np.pi * y) * np.cos(np.pi * z)
230+
v = -np.cos(np.pi * x) * np.sin(np.pi * y) * np.cos(np.pi * z)
231+
w = (np.sqrt(2.0 / 3.0) * np.cos(np.pi * x) * np.cos(np.pi * y) *
232+
np.sin(np.pi * z))
233+
234+
ax.quiver(x, y, z, u, v, w, length=0.1, pivot='tail')
206235

207236
if __name__ == '__main__':
208237
import nose

0 commit comments

Comments
 (0)