Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Add Dataset.plot.streamplot() method #5003

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Mar 30, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,7 @@ Plotting

Dataset.plot.scatter
Dataset.plot.quiver
Dataset.plot.streamplot

DataArray
=========
Expand Down
20 changes: 20 additions & 0 deletions doc/user-guide/plotting.rst
Original file line number Diff line number Diff line change
Expand Up @@ -787,6 +787,26 @@ where ``u`` and ``v`` denote the x and y direction components of the arrow vecto

``scale`` is required for faceted quiver plots. The scale determines the number of data units per arrow length unit, i.e. a smaller scale parameter makes the arrow longer.

Streamplot
~~~~~~~~~~

Visualizing vector fields is also supported with streamline plots:

.. ipython:: python
:okwarning:

@savefig ds_simple_streamplot.png
ds.isel(w=1, z=1).plot.streamplot(x="x", y="y", u="A", v="B")


where ``u`` and ``v`` denote the x and y direction components of the vectors tangent to the streamlines. Again, faceting is also possible:

.. ipython:: python
:okwarning:

@savefig ds_facet_streamplot.png
ds.plot.streamplot(x="x", y="y", u="A", v="B", col="w", row="z")

.. _plot-maps:

Maps
Expand Down
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ New Features
- Support for `dask.graph_manipulation
<https://docs.dask.org/en/latest/graph_manipulation.html>`_ (requires dask >=2021.3)
By `Guido Imperiale <https://github.com/crusaderky>`_
- Add :py:meth:`Dataset.plot.streamplot` for streamplot plots with :py:class:`Dataset`
variables (:pull:`5003`).
By `John Omotani <https://github.com/johnomotani>`_.
- Many of the arguments for the :py:attr:`DataArray.str` methods now support
providing an array-like input. In this case, the array provided to the
arguments is broadcast against the original array and applied elementwise.
Expand Down
78 changes: 72 additions & 6 deletions xarray/plot/dataset_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def _infer_meta_data(ds, x, y, hue, hue_style, add_guide, funcname):
add_colorbar = False
add_legend = False
else:
if add_guide is True and funcname != "quiver":
if add_guide is True and funcname not in ("quiver", "streamplot"):
raise ValueError("Cannot set add_guide when hue is None.")
add_legend = False
add_colorbar = False
Expand All @@ -62,11 +62,23 @@ def _infer_meta_data(ds, x, y, hue, hue_style, add_guide, funcname):
hue_style = "continuous"
elif hue_style != "continuous":
raise ValueError(
"hue_style must be 'continuous' or None for .plot.quiver"
"hue_style must be 'continuous' or None for .plot.quiver or "
".plot.streamplot"
)
else:
add_quiverkey = False

if (add_guide or add_guide is None) and funcname == "streamplot":
if hue:
add_colorbar = True
if not hue_style:
hue_style = "continuous"
elif hue_style != "continuous":
raise ValueError(
"hue_style must be 'continuous' or None for .plot.quiver or "
".plot.streamplot"
)

if hue_style is not None and hue_style not in ["discrete", "continuous"]:
raise ValueError("hue_style must be either None, 'discrete' or 'continuous'.")

Expand Down Expand Up @@ -186,7 +198,7 @@ def _dsplot(plotfunc):
x, y : str
Variable names for x, y axis.
u, v : str, optional
Variable names for quiver plots
Variable names for quiver or streamplot plots
hue: str, optional
Variable by which to color scattered points
hue_style: str, optional
Expand Down Expand Up @@ -338,8 +350,11 @@ def newplotfunc(
else:
cmap_params_subset = {}

if (u is not None or v is not None) and plotfunc.__name__ != "quiver":
raise ValueError("u, v are only allowed for quiver plots.")
if (u is not None or v is not None) and plotfunc.__name__ not in (
"quiver",
"streamplot",
):
raise ValueError("u, v are only allowed for quiver or streamplot plots.")

primitive = plotfunc(
ds=ds,
Expand Down Expand Up @@ -383,7 +398,7 @@ def newplotfunc(
coordinates="figure",
)

if plotfunc.__name__ == "quiver":
if plotfunc.__name__ in ("quiver", "streamplot"):
title = ds[u]._title_for_slice()
else:
title = ds[x]._title_for_slice()
Expand Down Expand Up @@ -526,3 +541,54 @@ def quiver(ds, x, y, ax, u, v, **kwargs):
kwargs.setdefault("pivot", "middle")
hdl = ax.quiver(*args, **kwargs, **cmap_params)
return hdl


@_dsplot
def streamplot(ds, x, y, ax, u, v, **kwargs):
""" Quiver plot with Dataset variables."""
import matplotlib as mpl

if x is None or y is None or u is None or v is None:
raise ValueError("Must specify x, y, u, v for streamplot plots.")

# Matplotlib's streamplot has strong restrictions on what x and y can be, so need to
# get arrays transposed the 'right' way around. 'x' cannot vary within 'rows', so
# the dimension of x must be the second dimension. 'y' cannot vary with 'columns' so
# the dimension of y must be the first dimension. If x and y are both 2d, assume the
# user has got them right already.
if len(ds[x].dims) == 1:
xdim = ds[x].dims[0]
if len(ds[y].dims) == 1:
ydim = ds[y].dims[0]
if xdim is not None and ydim is None:
ydim = set(ds[y].dims) - set([xdim])
if ydim is not None and xdim is None:
xdim = set(ds[x].dims) - set([ydim])

x, y, u, v = broadcast(ds[x], ds[y], ds[u], ds[v])

if xdim is not None and ydim is not None:
# Need to ensure the arrays are transposed correctly
x = x.transpose(ydim, xdim)
y = y.transpose(ydim, xdim)
u = u.transpose(ydim, xdim)
v = v.transpose(ydim, xdim)

args = [x.values, y.values, u.values, v.values]
hue = kwargs.pop("hue")
cmap_params = kwargs.pop("cmap_params")

if hue:
kwargs["color"] = ds[hue].values

# TODO: Fix this by always returning a norm with vmin, vmax in cmap_params
if not cmap_params["norm"]:
cmap_params["norm"] = mpl.colors.Normalize(
cmap_params.pop("vmin"), cmap_params.pop("vmax")
)

kwargs.pop("hue_style")
hdl = ax.streamplot(*args, **kwargs, **cmap_params)

# Return .lines so colorbar creation works properly
return hdl.lines
55 changes: 55 additions & 0 deletions xarray/tests/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2212,6 +2212,61 @@ def test_facetgrid(self):
self.ds.plot.quiver(x="x", y="y", u="u", v="v", row="row", col="col")


@requires_matplotlib
class TestDatasetStreamplotPlots(PlotTestCase):
@pytest.fixture(autouse=True)
def setUp(self):
das = [
DataArray(
np.random.randn(3, 3, 2, 2),
dims=["x", "y", "row", "col"],
coords=[range(k) for k in [3, 3, 2, 2]],
)
for _ in [1, 2]
]
ds = Dataset({"u": das[0], "v": das[1]})
ds.x.attrs["units"] = "xunits"
ds.y.attrs["units"] = "yunits"
ds.col.attrs["units"] = "colunits"
ds.row.attrs["units"] = "rowunits"
ds.u.attrs["units"] = "uunits"
ds.v.attrs["units"] = "vunits"
ds["mag"] = np.hypot(ds.u, ds.v)
self.ds = ds

def test_streamline(self):
with figure_context():
hdl = self.ds.isel(row=0, col=0).plot.streamplot(x="x", y="y", u="u", v="v")
assert isinstance(hdl, mpl.collections.LineCollection)
with raises_regex(ValueError, "specify x, y, u, v"):
self.ds.isel(row=0, col=0).plot.streamplot(x="x", y="y", u="u")

with raises_regex(ValueError, "hue_style"):
self.ds.isel(row=0, col=0).plot.streamplot(
x="x", y="y", u="u", v="v", hue="mag", hue_style="discrete"
)

def test_facetgrid(self):
with figure_context():
fg = self.ds.plot.streamplot(
x="x", y="y", u="u", v="v", row="row", col="col", hue="mag"
)
for handle in fg._mappables:
assert isinstance(handle, mpl.collections.LineCollection)

with figure_context():
fg = self.ds.plot.streamplot(
x="x",
y="y",
u="u",
v="v",
row="row",
col="col",
hue="mag",
add_guide=False,
)


@requires_matplotlib
class TestDatasetScatterPlots(PlotTestCase):
@pytest.fixture(autouse=True)
Expand Down