99import matplotlib .collections as mcollections
1010import matplotlib .patches as patches
1111
12+
1213__all__ = ['streamplot' ]
1314
1415
@@ -46,11 +47,16 @@ def streamplot(axes, x, y, u, v, density=1, linewidth=None, color=None,
4647 *minlength* : float
4748 Minimum length of streamline in axes coordinates.
4849
49- Returns *streamlines* : :class:`~matplotlib.collections.LineCollection`
50- Line collection with all streamlines as a series of line segments.
51- Currently, there is no way to differentiate between line segments
52- on different streamlines (other than manually checking that segments
53- are connected).
50+ Returns
51+ -------
52+ *stream_container* : StreamplotSet
53+ Container object with attributes
54+ lines : `matplotlib.collections.LineCollection` of streamlines
55+ arrows : collection of `matplotlib.patches.FancyArrowPatch` objects
56+ repesenting arrows half-way along stream lines.
57+ This container will probably change in the future to allow changes to
58+ the colormap, alpha, etc. for both lines and arrows, but these changes
59+ should be backward compatible.
5460 """
5561 grid = Grid (x , y )
5662 mask = StreamMask (density )
@@ -108,6 +114,7 @@ def streamplot(axes, x, y, u, v, density=1, linewidth=None, color=None,
108114 cmap = cm .get_cmap (cmap )
109115
110116 streamlines = []
117+ arrows = []
111118 for t in trajectories :
112119 tgx = np .array (t [0 ])
113120 tgy = np .array (t [1 ])
@@ -139,6 +146,7 @@ def streamplot(axes, x, y, u, v, density=1, linewidth=None, color=None,
139146 transform = transform ,
140147 ** arrow_kw )
141148 axes .add_patch (p )
149+ arrows .append (p )
142150
143151 lc = mcollections .LineCollection (streamlines ,
144152 transform = transform ,
@@ -151,7 +159,17 @@ def streamplot(axes, x, y, u, v, density=1, linewidth=None, color=None,
151159
152160 axes .update_datalim (((x .min (), y .min ()), (x .max (), y .max ())))
153161 axes .autoscale_view (tight = True )
154- return lc
162+
163+ ac = matplotlib .collections .PatchCollection (arrows )
164+ stream_container = StreamplotSet (lc , ac )
165+ return stream_container
166+
167+
168+ class StreamplotSet (object ):
169+
170+ def __init__ (self , lines , arrows , ** kwargs ):
171+ self .lines = lines
172+ self .arrows = arrows
155173
156174
157175# Coordinate definitions
0 commit comments