Description
I have found a way to make `streamplot` draw unbroken streamlines and would like to submit a PR. Before doing so, I'd like input from a developer on whether this would be worthwhile and whether the approach I have in mind makes sense.
More details
I'm plotting a vector field with `streamplot`, and no matter how large I set the density, it interrupts streamlines at seemingly random points. My understanding is that `streamplot` uses a grid and allows only one streamline to pass through each cell of that grid. So if two streamlines get too close to each other, one of them has to end.
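To make that rule concrete, here is a toy model of it in plain Python, as I understand it. It is a simplification, not Matplotlib's code: each trajectory is given as a precomputed list of `(row, col)` grid cells (the real `streamplot` integrates the vector field and traces both directions from a seed), and the hypothetical helper `trace_with_mask` cuts a trajectory as soon as it enters a cell already claimed by an earlier one.

```python
def trace_with_mask(trajectories, nrows, ncols):
    """Toy model of the one-streamline-per-cell rule (hypothetical helper).

    `trajectories` is a list of lists of (row, col) cells, drawn in order.
    Returns the (possibly truncated) trajectories that actually get drawn.
    """
    occupied = [[False] * ncols for _ in range(nrows)]
    drawn = []
    for cells in trajectories:
        kept = []
        for r, c in cells:
            if occupied[r][c]:
                break  # cell already used by another streamline: stop here
            occupied[r][c] = True
            kept.append((r, c))
        drawn.append(kept)
    return drawn


# A horizontal streamline along row 2, then a vertical one along column 3.
horizontal = [(2, c) for c in range(6)]
vertical = [(r, 3) for r in range(6)]
print(trace_with_mask([horizontal, vertical], 6, 6)[1])
# -> [(0, 3), (1, 3)]
```

The vertical streamline is cut where it meets the horizontal one, even though nothing in the vector field ends there.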
What I'd really like is that if any part of a given streamline is plotted, then the whole thing is plotted. That is, streamlines should continue in both directions until they leave the plot region. You might think you could achieve this by setting a large value of the `density` argument. But it seems that `density` also determines how many starting points are selected for streamlines, so this doesn't help. I'm using Matplotlib 2.0.0, but I've also tried a clone of the current development version.
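A little arithmetic shows why raising `density` doesn't help. The `streamplot` docstring says that with `density=1` the domain is divided into a 30x30 grid; as I understand it, each grid cell is also a candidate starting point, so `density` scales the number of seeds together with the mask resolution. The helper below is hypothetical and only illustrates that relationship:

```python
def mask_grid(density):
    """Hypothetical helper: mask size and candidate seed count for a density.

    Assumes the documented rule that density=1 gives a 30x30 grid, and
    (an assumption) one candidate starting point per grid cell.
    """
    if isinstance(density, (int, float)):
        dx = dy = density
    else:
        dx, dy = density  # separate x and y densities
    nx, ny = int(30 * dx), int(30 * dy)
    return nx, ny, nx * ny


for d in (1, 2, 4):
    print(d, mask_grid(d))
# 1 (30, 30, 900)
# 2 (60, 60, 3600)
# 4 (120, 120, 14400)
```

Doubling `density` halves the cell size, so streamlines can get closer before one is cut, but it also quadruples the candidate seeds, so the plot mostly gains more streamlines that are still broken.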
I've found a way to get what I want by changing just the last line of this snippet (from the `StreamMask` class):
```python
def _update_trajectory(self, xm, ym):
    """Update current trajectory position in mask.

    If the new position has already been filled, raise `InvalidIndexError`.
    """
    if self._current_xy != (xm, ym):
        if self[ym, xm] == 0:
            self._traj.append((ym, xm))
            self._mask[ym, xm] = 1
            self._current_xy = (xm, ym)
        else:
            raise InvalidIndexError
```
to
```python
def _update_trajectory(self, xm, ym):
    """Update current trajectory position in mask.

    If the new position has already been filled, raise `InvalidIndexError`.
    """
    if self._current_xy != (xm, ym):
        if self[ym, xm] == 0:
            self._traj.append((ym, xm))
            self._mask[ym, xm] = 1
            self._current_xy = (xm, ym)
        else:
            # Cell already occupied: keep integrating instead of raising
            # InvalidIndexError, so the streamline is not cut here.
            pass
```
I'd like to submit a PR that allows the user to do this. My suggestion would be:
- Add a keyword argument `no_broken_streamlines` (or similar) that defaults to `False`.
- If it is set to `True`, `streamplot` behaves as with my patch above.
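Here is a minimal sketch of the proposed switch, assuming the flag is simply passed down to the mask object. `ToyStreamMask` is a hypothetical, stripped-down stand-in for Matplotlib's private `StreamMask` (only what `_update_trajectory` needs is reproduced, using plain lists instead of a NumPy array), `draws_fully` is a hypothetical test helper, and `no_broken_streamlines` is the keyword name suggested above, not an existing Matplotlib parameter:

```python
class InvalidIndexError(Exception):
    """Raised when a trajectory enters a cell that is already occupied."""


class ToyStreamMask:
    """Hypothetical, simplified stand-in for matplotlib's StreamMask."""

    def __init__(self, nx, ny, no_broken_streamlines=False):
        self._mask = [[0] * nx for _ in range(ny)]
        self._traj = []
        self._current_xy = None
        self.no_broken_streamlines = no_broken_streamlines

    def __getitem__(self, index):
        ym, xm = index
        return self._mask[ym][xm]

    def _start_trajectory(self, xm, ym):
        self._traj = []
        self._update_trajectory(xm, ym)

    def _update_trajectory(self, xm, ym):
        """Record (xm, ym); on an occupied cell, raise unless the flag is set."""
        if self._current_xy != (xm, ym):
            if self[ym, xm] == 0:
                self._traj.append((ym, xm))
                self._mask[ym][xm] = 1
                self._current_xy = (xm, ym)
            elif not self.no_broken_streamlines:
                raise InvalidIndexError


def draws_fully(mask, cells):
    """Return True if every (xm, ym) cell is accepted without a break."""
    try:
        mask._start_trajectory(*cells[0])
        for xm, ym in cells[1:]:
            mask._update_trajectory(xm, ym)
    except InvalidIndexError:
        return False
    return True


row = [(x, 2) for x in range(5)]     # first streamline, along ym = 2
column = [(2, y) for y in range(5)]  # second one crosses it at (2, 2)

for flag in (False, True):
    mask = ToyStreamMask(5, 5, no_broken_streamlines=flag)
    draws_fully(mask, row)
    print(flag, draws_fully(mask, column))
# False False
# True True
```

Because the flag only changes what happens inside `_update_trajectory`, my understanding is that starting points would still only be chosen in empty cells, so `density` would keep controlling how many streamlines are drawn.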
To reproduce the figure at the top, use this code:
```python
import numpy as np
import matplotlib.pyplot as plt

N = 400
# Grid over the (h, hu) plane
h, hu = np.meshgrid(np.linspace(0.1, 6, N), np.linspace(-3, 3, N))
g = 1.0
u = hu / h

# Vector field (dh, dhu) = (1, u - sqrt(g*h))
dh = np.ones_like(h)
dhu = u - np.sqrt(g * h)

plt.streamplot(h, hu, dh, dhu, density=1.0, arrowstyle='-')
plt.axis('tight')
plt.show()
```