@@ -98,39 +98,41 @@ def streamplot(axes, x, y, u, v, density=1, linewidth=1, color='k', cmap=None,
9898 line_kw = {}
9999 arrow_kw = dict (arrowstyle = arrowstyle , mutation_scale = 10 * arrowsize )
100100
101+ if not type (linewidth ) == np .ndarray :
102+ line_kw ['linewidth' ] = linewidth
103+ arrow_kw ['linewidth' ] = linewidth
104+
105+ if not type (color ) == np .ndarray :
106+ line_kw ['color' ] = color
107+ arrow_kw ['color' ] = color
108+
101109 for t in trajectories :
102110 tgx = np .array (t [0 ])
103111 tgy = np .array (t [1 ])
104-
105- if type (linewidth ) == np .ndarray :
106- line_kw ['linewidth' ] = interpgrid (linewidth , tgx , tgy )[:- 1 ]
107- arrow_kw ['linewidth' ] = line_kw ['linewidth' ][len (tgx ) / 2 ]
108- else :
109- line_kw ['linewidth' ] = linewidth
110- arrow_kw ['linewidth' ] = linewidth
111-
112- if type (color ) == np .ndarray :
113- line_kw ['color' ] = cmap (norm (interpgrid (color , tgx , tgy )[:- 1 ]))
114- arrow_kw ['color' ] = line_kw ['color' ][len (tgx ) / 2 ]
115- else :
116- line_kw ['color' ] = color
117- arrow_kw ['color' ] = color
118-
119112 # Rescale from grid-coordinates to data-coordinates.
120113 tx = np .array (t [0 ]) * grid .dx + grid .x_origin
121114 ty = np .array (t [1 ]) * grid .dy + grid .y_origin
122115
123116 points = np .transpose ([tx , ty ]).reshape (- 1 , 1 , 2 )
124117 segments = np .concatenate ([points [:- 1 ], points [1 :]], axis = 1 )
125118
126- lc = matplotlib .collections .LineCollection (segments , ** line_kw )
127- axes .add_collection (lc )
128-
129119 ## Add arrows half way along each trajectory.
130120 s = np .cumsum (np .sqrt (np .diff (tx )** 2 + np .diff (ty )** 2 ))
131121 n = np .searchsorted (s , s [- 1 ] / 2. )
132122 arrow_tail = (tx [n ], ty [n ])
133123 arrow_head = (np .mean (tx [n :n + 2 ]), np .mean (ty [n :n + 2 ]))
124+
125+ if type (linewidth ) == np .ndarray :
126+ line_kw ['linewidth' ] = interpgrid (linewidth , tgx , tgy )[:- 1 ]
127+ arrow_kw ['linewidth' ] = line_kw ['linewidth' ][n ]
128+
129+ if type (color ) == np .ndarray :
130+ line_kw ['color' ] = cmap (norm (interpgrid (color , tgx , tgy )[:- 1 ]))
131+ arrow_kw ['color' ] = line_kw ['color' ][n ]
132+
133+ lc = matplotlib .collections .LineCollection (segments , ** line_kw )
134+ axes .add_collection (lc )
135+
134136 p = mpp .FancyArrowPatch (arrow_tail , arrow_head , ** arrow_kw )
135137 axes .add_patch (p )
136138
0 commit comments