Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit b884fc8

Browse files
committed
Fix selection of arrow's linewidth and color
1 parent 72abc5d commit b884fc8

2 files changed

Lines changed: 21 additions & 19 deletions

File tree

examples/pylab_examples/streamplot_demo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
ax1.streamplot(X, Y, U, V)
1414

1515
lw = 5*speed/speed.max()
16-
ax2.streamplot(X, Y, U, V, density=0.6, color=U, linewidth=lw)
16+
ax2.streamplot(X, Y, U, V, density=0.6, color='k', linewidth=lw)
1717

1818
plt.show()
1919

lib/matplotlib/streamplot.py

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -98,39 +98,41 @@ def streamplot(axes, x, y, u, v, density=1, linewidth=1, color='k', cmap=None,
9898
line_kw = {}
9999
arrow_kw = dict(arrowstyle=arrowstyle, mutation_scale=10*arrowsize)
100100

101+
if not type(linewidth) == np.ndarray:
102+
line_kw['linewidth'] = linewidth
103+
arrow_kw['linewidth'] = linewidth
104+
105+
if not type(color) == np.ndarray:
106+
line_kw['color'] = color
107+
arrow_kw['color'] = color
108+
101109
for t in trajectories:
102110
tgx = np.array(t[0])
103111
tgy = np.array(t[1])
104-
105-
if type(linewidth) == np.ndarray:
106-
line_kw['linewidth'] = interpgrid(linewidth, tgx, tgy)[:-1]
107-
arrow_kw['linewidth'] = line_kw['linewidth'][len(tgx) / 2]
108-
else:
109-
line_kw['linewidth'] = linewidth
110-
arrow_kw['linewidth'] = linewidth
111-
112-
if type(color) == np.ndarray:
113-
line_kw['color'] = cmap(norm(interpgrid(color, tgx, tgy)[:-1]))
114-
arrow_kw['color'] = line_kw['color'][len(tgx) / 2]
115-
else:
116-
line_kw['color'] = color
117-
arrow_kw['color'] = color
118-
119112
# Rescale from grid-coordinates to data-coordinates.
120113
tx = np.array(t[0]) * grid.dx + grid.x_origin
121114
ty = np.array(t[1]) * grid.dy + grid.y_origin
122115

123116
points = np.transpose([tx, ty]).reshape(-1, 1, 2)
124117
segments = np.concatenate([points[:-1], points[1:]], axis=1)
125118

126-
lc = matplotlib.collections.LineCollection(segments, **line_kw)
127-
axes.add_collection(lc)
128-
129119
## Add arrows half way along each trajectory.
130120
s = np.cumsum(np.sqrt(np.diff(tx)**2 + np.diff(ty)**2))
131121
n = np.searchsorted(s, s[-1] / 2.)
132122
arrow_tail = (tx[n], ty[n])
133123
arrow_head = (np.mean(tx[n:n+2]), np.mean(ty[n:n+2]))
124+
125+
if type(linewidth) == np.ndarray:
126+
line_kw['linewidth'] = interpgrid(linewidth, tgx, tgy)[:-1]
127+
arrow_kw['linewidth'] = line_kw['linewidth'][n]
128+
129+
if type(color) == np.ndarray:
130+
line_kw['color'] = cmap(norm(interpgrid(color, tgx, tgy)[:-1]))
131+
arrow_kw['color'] = line_kw['color'][n]
132+
133+
lc = matplotlib.collections.LineCollection(segments, **line_kw)
134+
axes.add_collection(lc)
135+
134136
p = mpp.FancyArrowPatch(arrow_tail, arrow_head, **arrow_kw)
135137
axes.add_patch(p)
136138

0 commit comments

Comments
 (0)