Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit f4cca16

Browse files
committed
adjustments according to code-review
1 parent fb51972 commit f4cca16

5 files changed

Lines changed: 16275 additions & 16647 deletions

File tree

examples/mplot3d/quiver3d_demo.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,14 @@
55
fig = plt.figure()
66
ax = fig.gca(projection='3d')
77

8-
x, y, z = np.meshgrid(np.arange(-0.8, 1, 0.2), np.arange(-0.8, 1, 0.2), np.arange(-0.8, 1, 0.8))
8+
x, y, z = np.meshgrid(np.arange(-0.8, 1, 0.2),
9+
np.arange(-0.8, 1, 0.2),
10+
np.arange(-0.8, 1, 0.8))
911

1012
u = np.sin(np.pi * x) * np.cos(np.pi * y) * np.cos(np.pi * z)
1113
v = -np.cos(np.pi * x) * np.sin(np.pi * y) * np.cos(np.pi * z)
12-
w = np.sqrt(2.0 / 3.0) * np.cos(np.pi * x) * np.cos(np.pi * y) * np.sin(np.pi * z)
14+
w = np.sqrt(2.0 / 3.0) * np.cos(np.pi * x) * np.cos(np.pi * y) * \
15+
np.sin(np.pi * z)
1316

1417
ax.quiver(x, y, z, u, v, w, length=0.1)
1518

lib/mpl_toolkits/mplot3d/axes3d.py

Lines changed: 79 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,6 @@
3434
from . import art3d
3535
from . import proj3d
3636
from . import axis3d
37-
from mpl_toolkits.mplot3d.art3d import Line3DCollection
38-
3937

4038
def unit_bbox():
4139
box = Bbox(np.array([[0, 0], [1, 1]]))
@@ -2447,39 +2445,51 @@ def quiver(self, *args, **kwargs):
24472445
"""
24482446
def calc_arrow(u, v, w, angle=15):
24492447
"""
2450-
To calculate the arrow head.
2451-
2452-
(u,v,w) expected to be normalized, recursive to fix A=0 scenario.
2448+
To calculate the arrow head vectors
24532449
"""
2454-
if u == 0 and v == 0 and w == 0:
2455-
raise ValueError("u,v,w can't be all zero")
2456-
2457-
# angle should never greater than 20
2458-
# safeguard from stack overflow
2450+
# should finish on first recursion
24592451
assert angle <= 20
24602452

2461-
t = math.cos(math.radians(angle))
2453+
norm_uvw = np.linalg.norm([u, v, w], ord=2)
2454+
if norm_uvw == 0:
2455+
raise ValueError("u,v,w can't be all zero")
2456+
u, v, w = np.array([u, v, w]) / norm_uvw
24622457

2463-
#D = 1 - w ** 2
2464-
#A = t ** 2 - w ** 2
2465-
#B = t**2 * D * (1-t**2)
2466-
#C = D * w
2467-
2468-
t2 = t**2
2469-
w2 = w**2
2470-
D = 1 - w2
2458+
cos_angle = math.cos(math.radians(angle))
24712459

2472-
A = t2 - w2
2473-
B = t2 * D * (1 - t2)
2474-
C = D * w
2460+
# using following formula to get the two solutions
2461+
# (C + math.sqrt(B)) / A
2462+
# (C - math.sqrt(B)) / A
2463+
A = cos_angle ** 2 - w ** 2
2464+
B = cos_angle ** 2 * (1 - w ** 2) * (1 - cos_angle ** 2)
2465+
C = (1 - w ** 2) * w
24752466

24762467
if A == 0:
2477-
x1, x2 = calc_arrow(u, v, w, angle=angle + 0.1)
2468+
# to prevent the case of dividing by zero
2469+
# use a small enough delta such that
2470+
# A != 0 for one of
2471+
# angle
2472+
# or
2473+
# angle + delta
2474+
return calc_arrow(u, v, w, angle=angle + 0.1)
24782475
else:
24792476
x1 = (C + math.sqrt(B)) / A
24802477
x2 = (C - math.sqrt(B)) / A
24812478

2482-
return x1, x2
2479+
# normalize the output vectors and convert to ndarray
2480+
norm_func = lambda k: np.array([u, v, k]) / np.linalg.norm([u, v, k], ord=2)
2481+
2482+
return norm_func(x1), norm_func(x2)
2483+
2484+
def point_vector_to_line(point, vector, length):
2485+
"""
2486+
use a point and vector to generate lines
2487+
"""
2488+
lines = []
2489+
for var in np.linspace(0, length, num=20):
2490+
lines.append(list(zip(*(point - var * vector))))
2491+
lines = np.array(lines).swapaxes(0, 1)
2492+
return lines.tolist()
24832493

24842494
had_data = self.has_data()
24852495

@@ -2488,92 +2498,60 @@ def calc_arrow(u, v, w, angle=15):
24882498
length = kwargs.pop('length', 1)
24892499
# arrow length ratio to the shaft length
24902500
arrow_length_ratio = kwargs.pop('arrow_length_ratio', 0.3)
2491-
# zdir
2492-
zdir = kwargs.pop('zdir', 'z')
24932501

24942502
# handle args
2495-
if len(args) != 6:
2503+
if len(args) < 6:
24962504
ValueError('Wrong number of arguments')
2497-
# X, Y, Z, U, V, W
2498-
coords = list(map(lambda k: np.array(k) if not isinstance(k, np.ndarray) else k, args))
2505+
argi = 6
2506+
# first 6 arguments are X, Y, Z, U, V, W
2507+
input_args = args[:argi]
2508+
# if any of the args are scalar, convert into list
2509+
input_args = [[k] if isinstance(k, (int, float)) else k for k in input_args]
2510+
# extract the masks, if any
2511+
masks = [k.mask for k in input_args if isinstance(k, np.ma.MaskedArray)]
2512+
# broadcast to match the shape
2513+
bcast = np.broadcast_arrays(*(input_args + masks))
2514+
input_args = bcast[:argi]
2515+
masks = bcast[argi:]
2516+
if masks:
2517+
# combine the masks into one
2518+
mask = reduce(lambda k, k1: k or k1, masks)
2519+
# put mask on and compress
2520+
input_args = [np.ma.array(k, mask=mask).compressed() for k in input_args]
2521+
else:
2522+
input_args = [k.flatten() for k in input_args]
24992523

2500-
shapes = set([k.shape for k in coords])
2501-
if len(shapes) != 1:
2502-
raise ValueError("unmatched input array shape")
2503-
elif list(shapes)[0] == 0:
2504-
raise ValueError("input are constants")
2524+
points = input_args[:3]
2525+
vectors = input_args[3:]
25052526

2506-
# Below assertion must be true as a safe guard
2507-
assert all([isinstance(k, np.ndarray) for k in coords])
2527+
# Below assertions must be true before proceed
2528+
# must all be ndarray
2529+
assert all([isinstance(k, np.ndarray) for k in input_args])
2530+
# must all in same shape
2531+
assert len(set([k.shape for k in input_args])) == 1
25082532

2509-
coords = [k.flatten() for k in coords]
2510-
xs, ys, zs, us, vs, ws = coords
2511-
lines = []
2533+
# get arrow vectors
2534+
arrow_vectors = list(map(calc_arrow, *[np.nditer(k) for k in vectors]))
2535+
# reshape into set of vectors
2536+
arrow_vectors = np.array(arrow_vectors).reshape((-1, 3)).swapaxes(0, 1)
25122537

2513-
# for each arrow
2514-
for i in xrange(xs.shape[0]):
2515-
# calculate body
2516-
x = xs[i]
2517-
y = ys[i]
2518-
z = zs[i]
2519-
u = us[i]
2520-
v = vs[i]
2521-
w = ws[i]
2522-
if any([k is np.ma.masked for k in [x, y, z, u, v, w]]):
2523-
continue
2524-
2525-
# normalize
2526-
norm = math.sqrt(u ** 2 + v ** 2 + w ** 2)
2527-
if norm == 0:
2528-
norm = 1
2529-
u /= norm
2530-
v /= norm
2531-
w /= norm
2532-
2533-
t = np.linspace(0, length, num=20)
2534-
lx = x - t * u
2535-
ly = y - t * v
2536-
lz = z - t * w
2537-
line = list(zip(lx, ly, lz))
2538-
lines.append(line)
2539-
2540-
# arrow one side
2541-
ua1 = u
2542-
va1 = v
2543-
wa1, wa2 = calc_arrow(u, v, w)
2544-
2545-
# normalize arrowhead 1
2546-
norm = math.sqrt(ua1 ** 2 + va1 ** 2 + wa1 ** 2)
2547-
if norm == 0:
2548-
norm = 1
2549-
ua1_ = ua1/norm
2550-
va1_ = va1/norm
2551-
wa1_ = wa1/norm
2552-
2553-
# normalize arrowhead 2
2554-
norm = math.sqrt(ua1 ** 2 + va1 ** 2 + wa2 ** 2)
2555-
if norm == 0:
2556-
norm = 1
2557-
ua2_ = ua1/norm
2558-
va2_ = va1/norm
2559-
wa2_ = wa2/norm
2560-
2561-
t = np.linspace(0, length * arrow_length_ratio, num=20)
2562-
la1x = x - t * ua1_
2563-
la1y = y - t * va1_
2564-
la1z = z - t * wa1_
2565-
la2x = x - t * ua2_
2566-
la2y = y - t * va2_
2567-
la2z = z - t * wa2_
2568-
2569-
line = list(zip(la1x, la1y, la1z))
2570-
lines.append(line)
2571-
line = list(zip(la2x, la2y, la2z))
2572-
lines.append(line)
2573-
2574-
linec = Line3DCollection(lines, *args[6:], **kwargs)
2538+
lines = []
2539+
# construct the main lines
2540+
lines.extend(point_vector_to_line(np.array(points),
2541+
np.array(vectors),
2542+
length))
2543+
# construct the arrow heads
2544+
lines.extend(point_vector_to_line(np.array(points).repeat(2, axis=1),
2545+
arrow_vectors,
2546+
length * arrow_length_ratio))
2547+
2548+
# Line3DCollection to do the heavy lifting
2549+
linec = art3d.Line3DCollection(lines, *args[argi:], **kwargs)
25752550
self.add_collection(linec)
25762551

2552+
# scale
2553+
2554+
xs, ys, zs = list(zip(*np.array(lines).reshape(-1, 3)))
25772555
self.auto_scale_xyz(xs, ys, zs, had_data)
25782556

25792557
return linec
-8.27 KB
Binary file not shown.
-36.2 KB
Loading

0 commit comments

Comments
 (0)