Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 69d5ef4

Browse files
committed
Merge pull request matplotlib#3300 from WeatherGod/mplot3d/quiver_test_fix
Quiver3d fixes
2 parents 2094ee0 + 1d9af2d commit 69d5ef4

File tree

2 files changed

+16
-21
lines changed

2 files changed

+16
-21
lines changed

lib/mpl_toolkits/mplot3d/axes3d.py

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2523,9 +2523,11 @@ def point_vector_to_line(point, vector, length):
25232523
arrow_length_ratio = kwargs.pop('arrow_length_ratio', 0.3)
25242524

25252525
# handle args
2526-
if len(args) < 6:
2527-
ValueError('Wrong number of arguments')
25282526
argi = 6
2527+
if len(args) < argi:
2528+
ValueError('Wrong number of arguments. Expected %d got %d' %
2529+
(argi, len(args)))
2530+
25292531
# first 6 arguments are X, Y, Z, U, V, W
25302532
input_args = args[:argi]
25312533
# if any of the args are scalar, convert into list
@@ -2549,24 +2551,17 @@ def point_vector_to_line(point, vector, length):
25492551

25502552
if any(len(v) == 0 for v in input_args):
25512553
# No quivers, so just make an empty collection and return early
2552-
linec = art3d.Line3DCollection([], *args[6:], **kwargs)
2554+
linec = art3d.Line3DCollection([], *args[argi:], **kwargs)
25532555
self.add_collection(linec)
25542556
return linec
25552557

2556-
points = input_args[:3]
2557-
vectors = input_args[3:]
2558-
2559-
# Below assertions must be true before proceed
2558+
# Following assertions must be true before proceeding
25602559
# must all be ndarray
25612560
assert all(isinstance(k, np.ndarray) for k in input_args)
25622561
# must all in same shape
25632562
assert len(set([k.shape for k in input_args])) == 1
25642563

2565-
# X, Y, Z, U, V, W
2566-
coords = (np.array(k) if not isinstance(k, np.ndarray) else k
2567-
for k in args)
2568-
coords = [k.flatten() for k in coords]
2569-
xs, ys, zs, us, vs, ws = coords
2564+
xs, ys, zs, us, vs, ws = input_args[:argi]
25702565
lines = []
25712566

25722567
# for each arrow
@@ -2578,12 +2573,11 @@ def point_vector_to_line(point, vector, length):
25782573
u = us[i]
25792574
v = vs[i]
25802575
w = ws[i]
2581-
if any(k is np.ma.masked for k in [x, y, z, u, v, w]):
2582-
continue
25832576

25842577
# (u,v,w) expected to be normalized, recursive to fix A=0 scenario.
25852578
if u == 0 and v == 0 and w == 0:
2586-
raise ValueError("u,v,w can't be all zero")
2579+
# Just don't make a quiver for such a case.
2580+
continue
25872581

25882582
# normalize
25892583
norm = math.sqrt(u ** 2 + v ** 2 + w ** 2)
@@ -2603,6 +2597,7 @@ def point_vector_to_line(point, vector, length):
26032597
ua1, va1, wa1 = d1[0], d1[1], d1[2]
26042598
ua2, va2, wa2 = d2[0], d2[1], d2[2]
26052599

2600+
# TODO: num should probably get parameterized
26062601
t = np.linspace(0, length * arrow_length_ratio, num=20)
26072602
la1x = x - t * ua1
26082603
la1y = y - t * va1
@@ -2616,7 +2611,7 @@ def point_vector_to_line(point, vector, length):
26162611
line = list(zip(la2x, la2y, la2z))
26172612
lines.append(line)
26182613

2619-
linec = art3d.Line3DCollection(lines, *args[6:], **kwargs)
2614+
linec = art3d.Line3DCollection(lines, *args[argi:], **kwargs)
26202615
self.add_collection(linec)
26212616

26222617
self.auto_scale_xyz(xs, ys, zs, had_data)

lib/mpl_toolkits/tests/test_mplot3d.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -162,8 +162,7 @@ def test_quiver3d():
162162
fig = plt.figure()
163163
ax = fig.gca(projection='3d')
164164

165-
x, y, z = np.meshgrid(np.arange(-1, 1, 0.2), np.arange(-1, 1, 0.2),
166-
np.arange(-1, 1, 0.8))
165+
x, y, z = np.ogrid[-1:0.8:10j, -1:0.8:10j, -1:0.6:3j]
167166

168167
u = np.sin(np.pi * x) * np.cos(np.pi * y) * np.cos(np.pi * z)
169168
v = -np.cos(np.pi * x) * np.sin(np.pi * y) * np.cos(np.pi * z)
@@ -177,7 +176,7 @@ def test_quiver3d_empty():
177176
fig = plt.figure()
178177
ax = fig.gca(projection='3d')
179178

180-
x, y, z = np.meshgrid([], [], [])
179+
x, y, z = np.ogrid[-1:0.8:0j, -1:0.8:0j, -1:0.6:0j]
181180

182181
u = np.sin(np.pi * x) * np.cos(np.pi * y) * np.cos(np.pi * z)
183182
v = -np.cos(np.pi * x) * np.sin(np.pi * y) * np.cos(np.pi * z)
@@ -191,8 +190,9 @@ def test_quiver3d_masked():
191190
fig = plt.figure()
192191
ax = fig.gca(projection='3d')
193192

194-
x, y, z = np.meshgrid(np.arange(-1, 1, 0.2), np.arange(-1, 1, 0.2),
195-
np.arange(-1, 1, 0.8))
193+
# Using mgrid here instead of ogrid because masked_where doesn't
194+
# seem to like broadcasting very much...
195+
x, y, z = np.mgrid[-1:0.8:10j, -1:0.8:10j, -1:0.6:3j]
196196

197197
u = np.sin(np.pi * x) * np.cos(np.pi * y) * np.cos(np.pi * z)
198198
v = -np.cos(np.pi * x) * np.sin(np.pi * y) * np.cos(np.pi * z)

0 commit comments

Comments
 (0)