diff --git a/doc/api.rst b/doc/api.rst index 4af9bb01208..a140d9e2b81 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -243,6 +243,7 @@ Plotting Dataset.plot.scatter Dataset.plot.quiver + Dataset.plot.streamplot DataArray ========= diff --git a/doc/user-guide/plotting.rst b/doc/user-guide/plotting.rst index f5f1168df23..098c63d0e40 100644 --- a/doc/user-guide/plotting.rst +++ b/doc/user-guide/plotting.rst @@ -787,6 +787,26 @@ where ``u`` and ``v`` denote the x and y direction components of the arrow vecto ``scale`` is required for faceted quiver plots. The scale determines the number of data units per arrow length unit, i.e. a smaller scale parameter makes the arrow longer. +Streamplot +~~~~~~~~~~ + +Visualizing vector fields is also supported with streamline plots: + +.. ipython:: python + :okwarning: + + @savefig ds_simple_streamplot.png + ds.isel(w=1, z=1).plot.streamplot(x="x", y="y", u="A", v="B") + + +where ``u`` and ``v`` denote the x and y direction components of the vectors tangent to the streamlines. Again, faceting is also possible: + +.. ipython:: python + :okwarning: + + @savefig ds_facet_streamplot.png + ds.plot.streamplot(x="x", y="y", u="A", v="B", col="w", row="z") + .. _plot-maps: Maps diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 77d6296acac..bcef55c9aec 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -31,6 +31,9 @@ New Features - Support for `dask.graph_manipulation `_ (requires dask >=2021.3) By `Guido Imperiale `_ +- Add :py:meth:`Dataset.plot.streamplot` for streamplot plots with :py:class:`Dataset` + variables (:pull:`5003`). + By `John Omotani `_. - Many of the arguments for the :py:attr:`DataArray.str` methods now support providing an array-like input. In this case, the array provided to the arguments is broadcast against the original array and applied elementwise. diff --git a/xarray/plot/dataset_plot.py b/xarray/plot/dataset_plot.py index 59d3ca98f23..e5261a960cb 100644 --- a/xarray/plot/dataset_plot.py +++ b/xarray/plot/dataset_plot.py @@ -49,7 +49,7 @@ def _infer_meta_data(ds, x, y, hue, hue_style, add_guide, funcname): add_colorbar = False add_legend = False else: - if add_guide is True and funcname != "quiver": + if add_guide is True and funcname not in ("quiver", "streamplot"): raise ValueError("Cannot set add_guide when hue is None.") add_legend = False add_colorbar = False @@ -62,11 +62,23 @@ def _infer_meta_data(ds, x, y, hue, hue_style, add_guide, funcname): hue_style = "continuous" elif hue_style != "continuous": raise ValueError( - "hue_style must be 'continuous' or None for .plot.quiver" + "hue_style must be 'continuous' or None for .plot.quiver or " + ".plot.streamplot" ) else: add_quiverkey = False + if (add_guide or add_guide is None) and funcname == "streamplot": + if hue: + add_colorbar = True + if not hue_style: + hue_style = "continuous" + elif hue_style != "continuous": + raise ValueError( + "hue_style must be 'continuous' or None for .plot.quiver or " + ".plot.streamplot" + ) + if hue_style is not None and hue_style not in ["discrete", "continuous"]: raise ValueError("hue_style must be either None, 'discrete' or 'continuous'.") @@ -186,7 +198,7 @@ def _dsplot(plotfunc): x, y : str Variable names for x, y axis. u, v : str, optional - Variable names for quiver plots + Variable names for quiver or streamplot plots hue: str, optional Variable by which to color scattered points hue_style: str, optional @@ -338,8 +350,11 @@ def newplotfunc( else: cmap_params_subset = {} - if (u is not None or v is not None) and plotfunc.__name__ != "quiver": - raise ValueError("u, v are only allowed for quiver plots.") + if (u is not None or v is not None) and plotfunc.__name__ not in ( + "quiver", + "streamplot", + ): + raise ValueError("u, v are only allowed for quiver or streamplot plots.") primitive = plotfunc( ds=ds, @@ -383,7 +398,7 @@ def newplotfunc( coordinates="figure", ) - if plotfunc.__name__ == "quiver": + if plotfunc.__name__ in ("quiver", "streamplot"): title = ds[u]._title_for_slice() else: title = ds[x]._title_for_slice() @@ -526,3 +541,54 @@ def quiver(ds, x, y, ax, u, v, **kwargs): kwargs.setdefault("pivot", "middle") hdl = ax.quiver(*args, **kwargs, **cmap_params) return hdl + + +@_dsplot +def streamplot(ds, x, y, ax, u, v, **kwargs): + """ Quiver plot with Dataset variables.""" + import matplotlib as mpl + + if x is None or y is None or u is None or v is None: + raise ValueError("Must specify x, y, u, v for streamplot plots.") + + # Matplotlib's streamplot has strong restrictions on what x and y can be, so need to + # get arrays transposed the 'right' way around. 'x' cannot vary within 'rows', so + # the dimension of x must be the second dimension. 'y' cannot vary with 'columns' so + # the dimension of y must be the first dimension. If x and y are both 2d, assume the + # user has got them right already. + if len(ds[x].dims) == 1: + xdim = ds[x].dims[0] + if len(ds[y].dims) == 1: + ydim = ds[y].dims[0] + if xdim is not None and ydim is None: + ydim = set(ds[y].dims) - set([xdim]) + if ydim is not None and xdim is None: + xdim = set(ds[x].dims) - set([ydim]) + + x, y, u, v = broadcast(ds[x], ds[y], ds[u], ds[v]) + + if xdim is not None and ydim is not None: + # Need to ensure the arrays are transposed correctly + x = x.transpose(ydim, xdim) + y = y.transpose(ydim, xdim) + u = u.transpose(ydim, xdim) + v = v.transpose(ydim, xdim) + + args = [x.values, y.values, u.values, v.values] + hue = kwargs.pop("hue") + cmap_params = kwargs.pop("cmap_params") + + if hue: + kwargs["color"] = ds[hue].values + + # TODO: Fix this by always returning a norm with vmin, vmax in cmap_params + if not cmap_params["norm"]: + cmap_params["norm"] = mpl.colors.Normalize( + cmap_params.pop("vmin"), cmap_params.pop("vmax") + ) + + kwargs.pop("hue_style") + hdl = ax.streamplot(*args, **kwargs, **cmap_params) + + # Return .lines so colorbar creation works properly + return hdl.lines diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 705b2d5e2e7..40b46d0c953 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -2212,6 +2212,61 @@ def test_facetgrid(self): self.ds.plot.quiver(x="x", y="y", u="u", v="v", row="row", col="col") +@requires_matplotlib +class TestDatasetStreamplotPlots(PlotTestCase): + @pytest.fixture(autouse=True) + def setUp(self): + das = [ + DataArray( + np.random.randn(3, 3, 2, 2), + dims=["x", "y", "row", "col"], + coords=[range(k) for k in [3, 3, 2, 2]], + ) + for _ in [1, 2] + ] + ds = Dataset({"u": das[0], "v": das[1]}) + ds.x.attrs["units"] = "xunits" + ds.y.attrs["units"] = "yunits" + ds.col.attrs["units"] = "colunits" + ds.row.attrs["units"] = "rowunits" + ds.u.attrs["units"] = "uunits" + ds.v.attrs["units"] = "vunits" + ds["mag"] = np.hypot(ds.u, ds.v) + self.ds = ds + + def test_streamline(self): + with figure_context(): + hdl = self.ds.isel(row=0, col=0).plot.streamplot(x="x", y="y", u="u", v="v") + assert isinstance(hdl, mpl.collections.LineCollection) + with raises_regex(ValueError, "specify x, y, u, v"): + self.ds.isel(row=0, col=0).plot.streamplot(x="x", y="y", u="u") + + with raises_regex(ValueError, "hue_style"): + self.ds.isel(row=0, col=0).plot.streamplot( + x="x", y="y", u="u", v="v", hue="mag", hue_style="discrete" + ) + + def test_facetgrid(self): + with figure_context(): + fg = self.ds.plot.streamplot( + x="x", y="y", u="u", v="v", row="row", col="col", hue="mag" + ) + for handle in fg._mappables: + assert isinstance(handle, mpl.collections.LineCollection) + + with figure_context(): + fg = self.ds.plot.streamplot( + x="x", + y="y", + u="u", + v="v", + row="row", + col="col", + hue="mag", + add_guide=False, + ) + + @requires_matplotlib class TestDatasetScatterPlots(PlotTestCase): @pytest.fixture(autouse=True)