From ae6f6743c700a4d59c09ec7da75afcc81c84f4f0 Mon Sep 17 00:00:00 2001 From: John Omotani Date: Fri, 5 Mar 2021 16:32:14 +0000 Subject: [PATCH 1/5] Add Dataset.plot.streamplot() method --- doc/api.rst | 1 + doc/plotting.rst | 20 +++++++++++ xarray/plot/dataset_plot.py | 70 +++++++++++++++++++++++++++++++++---- xarray/tests/test_plot.py | 55 +++++++++++++++++++++++++++++ 4 files changed, 139 insertions(+), 7 deletions(-) diff --git a/doc/api.rst b/doc/api.rst index 9add7a96109..1ea45af5e95 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -242,6 +242,7 @@ Plotting Dataset.plot.scatter Dataset.plot.quiver + Dataset.plot.streamplot DataArray ========= diff --git a/doc/plotting.rst b/doc/plotting.rst index f5f1168df23..098c63d0e40 100644 --- a/doc/plotting.rst +++ b/doc/plotting.rst @@ -787,6 +787,26 @@ where ``u`` and ``v`` denote the x and y direction components of the arrow vecto ``scale`` is required for faceted quiver plots. The scale determines the number of data units per arrow length unit, i.e. a smaller scale parameter makes the arrow longer. +Streamplot +~~~~~~~~~~ + +Visualizing vector fields is also supported with streamline plots: + +.. ipython:: python + :okwarning: + + @savefig ds_simple_streamplot.png + ds.isel(w=1, z=1).plot.streamplot(x="x", y="y", u="A", v="B") + + +where ``u`` and ``v`` denote the x and y direction components of the vectors tangent to the streamlines. Again, faceting is also possible: + +.. ipython:: python + :okwarning: + + @savefig ds_facet_streamplot.png + ds.plot.streamplot(x="x", y="y", u="A", v="B", col="w", row="z") + .. _plot-maps: Maps diff --git a/xarray/plot/dataset_plot.py b/xarray/plot/dataset_plot.py index 59d3ca98f23..641f504f33f 100644 --- a/xarray/plot/dataset_plot.py +++ b/xarray/plot/dataset_plot.py @@ -49,12 +49,14 @@ def _infer_meta_data(ds, x, y, hue, hue_style, add_guide, funcname): add_colorbar = False add_legend = False else: - if add_guide is True and funcname != "quiver": + if add_guide is True and not (funcname == "quiver" or funcname == "streamplot"): raise ValueError("Cannot set add_guide when hue is None.") add_legend = False add_colorbar = False - if (add_guide or add_guide is None) and funcname == "quiver": + if (add_guide or add_guide is None) and ( + funcname == "quiver" or funcname == "streamplot" + ): add_quiverkey = True if hue: add_colorbar = True @@ -62,7 +64,8 @@ def _infer_meta_data(ds, x, y, hue, hue_style, add_guide, funcname): hue_style = "continuous" elif hue_style != "continuous": raise ValueError( - "hue_style must be 'continuous' or None for .plot.quiver" + "hue_style must be 'continuous' or None for .plot.quiver or " + ".plot.streamplot" ) else: add_quiverkey = False @@ -186,7 +189,7 @@ def _dsplot(plotfunc): x, y : str Variable names for x, y axis. u, v : str, optional - Variable names for quiver plots + Variable names for quiver or streamplot plots hue: str, optional Variable by which to color scattered points hue_style: str, optional @@ -338,8 +341,10 @@ def newplotfunc( else: cmap_params_subset = {} - if (u is not None or v is not None) and plotfunc.__name__ != "quiver": - raise ValueError("u, v are only allowed for quiver plots.") + if (u is not None or v is not None) and not ( + plotfunc.__name__ == "quiver" or plotfunc.__name__ == "streamplot" + ): + raise ValueError("u, v are only allowed for quiver or streamplot plots.") primitive = plotfunc( ds=ds, @@ -383,7 +388,7 @@ def newplotfunc( coordinates="figure", ) - if plotfunc.__name__ == "quiver": + if plotfunc.__name__ == "quiver" or plotfunc.__name__ == "streamplot": title = ds[u]._title_for_slice() else: title = ds[x]._title_for_slice() @@ -526,3 +531,54 @@ def quiver(ds, x, y, ax, u, v, **kwargs): kwargs.setdefault("pivot", "middle") hdl = ax.quiver(*args, **kwargs, **cmap_params) return hdl + + +@_dsplot +def streamplot(ds, x, y, ax, u, v, **kwargs): + """ Quiver plot with Dataset variables.""" + import matplotlib as mpl + + if x is None or y is None or u is None or v is None: + raise ValueError("Must specify x, y, u, v for streamplot plots.") + + # Matplotlib's streamplot has strong restrictions on what x and y can be, so need to + # get arrays transposed the 'right' way around. 'x' cannot vary within 'rows', so + # the dimension of x must be the second dimension. 'y' cannot vary with 'columns' so + # the dimension of y must be the first dimension. If x and y are both 2d, assume the + # user has got them right already. + if len(ds[x].dims) == 1: + xdim = ds[x].dims[0] + if len(ds[y].dims) == 1: + ydim = ds[y].dims[0] + if xdim is not None and ydim is None: + ydim = set(ds[y].dims) - set([xdim]) + if ydim is not None and xdim is None: + xdim = set(ds[x].dims) - set([ydim]) + + x, y, u, v = broadcast(ds[x], ds[y], ds[u], ds[v]) + + if xdim is not None and ydim is not None: + # Need to ensure the arrays are transposed correctly + x = x.transpose(ydim, xdim) + y = y.transpose(ydim, xdim) + u = u.transpose(ydim, xdim) + v = v.transpose(ydim, xdim) + + args = [x.values, y.values, u.values, v.values] + hue = kwargs.pop("hue") + cmap_params = kwargs.pop("cmap_params") + + if hue: + kwargs["color"] = ds[hue].values + + # TODO: Fix this by always returning a norm with vmin, vmax in cmap_params + if not cmap_params["norm"]: + cmap_params["norm"] = mpl.colors.Normalize( + cmap_params.pop("vmin"), cmap_params.pop("vmax") + ) + + kwargs.pop("hue_style") + hdl = ax.streamplot(*args, **kwargs, **cmap_params) + + # Return .lines so colorbar creation works properly + return hdl.lines diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 705b2d5e2e7..40b46d0c953 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -2212,6 +2212,61 @@ def test_facetgrid(self): self.ds.plot.quiver(x="x", y="y", u="u", v="v", row="row", col="col") +@requires_matplotlib +class TestDatasetStreamplotPlots(PlotTestCase): + @pytest.fixture(autouse=True) + def setUp(self): + das = [ + DataArray( + np.random.randn(3, 3, 2, 2), + dims=["x", "y", "row", "col"], + coords=[range(k) for k in [3, 3, 2, 2]], + ) + for _ in [1, 2] + ] + ds = Dataset({"u": das[0], "v": das[1]}) + ds.x.attrs["units"] = "xunits" + ds.y.attrs["units"] = "yunits" + ds.col.attrs["units"] = "colunits" + ds.row.attrs["units"] = "rowunits" + ds.u.attrs["units"] = "uunits" + ds.v.attrs["units"] = "vunits" + ds["mag"] = np.hypot(ds.u, ds.v) + self.ds = ds + + def test_streamline(self): + with figure_context(): + hdl = self.ds.isel(row=0, col=0).plot.streamplot(x="x", y="y", u="u", v="v") + assert isinstance(hdl, mpl.collections.LineCollection) + with raises_regex(ValueError, "specify x, y, u, v"): + self.ds.isel(row=0, col=0).plot.streamplot(x="x", y="y", u="u") + + with raises_regex(ValueError, "hue_style"): + self.ds.isel(row=0, col=0).plot.streamplot( + x="x", y="y", u="u", v="v", hue="mag", hue_style="discrete" + ) + + def test_facetgrid(self): + with figure_context(): + fg = self.ds.plot.streamplot( + x="x", y="y", u="u", v="v", row="row", col="col", hue="mag" + ) + for handle in fg._mappables: + assert isinstance(handle, mpl.collections.LineCollection) + + with figure_context(): + fg = self.ds.plot.streamplot( + x="x", + y="y", + u="u", + v="v", + row="row", + col="col", + hue="mag", + add_guide=False, + ) + + @requires_matplotlib class TestDatasetScatterPlots(PlotTestCase): @pytest.fixture(autouse=True) From 378d6b32182b518d1b3c620a35578191274d29d8 Mon Sep 17 00:00:00 2001 From: John Omotani Date: Fri, 5 Mar 2021 17:42:17 +0000 Subject: [PATCH 2/5] Update whats-new.rst --- doc/whats-new.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 9e59fdc5b35..bd08bcc76eb 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -25,6 +25,9 @@ New Features - Support for `dask.graph_manipulation `_ (requires dask >=2021.3) By `Guido Imperiale `_ +- Add :py:meth:`Dataset.plot.streamplot` for streamplot plots with :py:class:`Dataset` + variables (:pull:`5003`). + By `John Omotani `_. Breaking changes ~~~~~~~~~~~~~~~~ From f0f3764ceeeeb7087f85fbb470a394957b8e3046 Mon Sep 17 00:00:00 2001 From: John Omotani Date: Fri, 5 Mar 2021 18:02:20 +0000 Subject: [PATCH 3/5] Fix colorbar streamplot needs to use separate condition to one for quiver plots. --- xarray/plot/dataset_plot.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/xarray/plot/dataset_plot.py b/xarray/plot/dataset_plot.py index 641f504f33f..16fc045c174 100644 --- a/xarray/plot/dataset_plot.py +++ b/xarray/plot/dataset_plot.py @@ -54,9 +54,7 @@ def _infer_meta_data(ds, x, y, hue, hue_style, add_guide, funcname): add_legend = False add_colorbar = False - if (add_guide or add_guide is None) and ( - funcname == "quiver" or funcname == "streamplot" - ): + if (add_guide or add_guide is None) and funcname == "quiver": add_quiverkey = True if hue: add_colorbar = True @@ -70,6 +68,17 @@ def _infer_meta_data(ds, x, y, hue, hue_style, add_guide, funcname): else: add_quiverkey = False + if (add_guide or add_guide is None) and funcname == "streamplot": + if hue: + add_colorbar = True + if not hue_style: + hue_style = "continuous" + elif hue_style != "continuous": + raise ValueError( + "hue_style must be 'continuous' or None for .plot.quiver or " + ".plot.streamplot" + ) + if hue_style is not None and hue_style not in ["discrete", "continuous"]: raise ValueError("hue_style must be either None, 'discrete' or 'continuous'.") From 8f9b90da49e2c8b695f57ebdd686d2a0a8521896 Mon Sep 17 00:00:00 2001 From: johnomotani Date: Mon, 15 Mar 2021 19:10:26 +0000 Subject: [PATCH 4/5] Apply suggestions from code review Co-authored-by: keewis --- xarray/plot/dataset_plot.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/xarray/plot/dataset_plot.py b/xarray/plot/dataset_plot.py index 16fc045c174..426553f5fd5 100644 --- a/xarray/plot/dataset_plot.py +++ b/xarray/plot/dataset_plot.py @@ -49,7 +49,7 @@ def _infer_meta_data(ds, x, y, hue, hue_style, add_guide, funcname): add_colorbar = False add_legend = False else: - if add_guide is True and not (funcname == "quiver" or funcname == "streamplot"): + if add_guide is True and funcname not in ("quiver", "streamplot"): raise ValueError("Cannot set add_guide when hue is None.") add_legend = False add_colorbar = False @@ -350,8 +350,8 @@ def newplotfunc( else: cmap_params_subset = {} - if (u is not None or v is not None) and not ( - plotfunc.__name__ == "quiver" or plotfunc.__name__ == "streamplot" + if (u is not None or v is not None) and plotfunc.__name__ not in ( + "quiver", "streamplot" ): raise ValueError("u, v are only allowed for quiver or streamplot plots.") @@ -397,7 +397,7 @@ def newplotfunc( coordinates="figure", ) - if plotfunc.__name__ == "quiver" or plotfunc.__name__ == "streamplot": + if plotfunc.__name__ in ("quiver", "streamplot"): title = ds[u]._title_for_slice() else: title = ds[x]._title_for_slice() From 484b467cdda8a96872bb7f57b2551d89799891a9 Mon Sep 17 00:00:00 2001 From: johnomotani Date: Mon, 15 Mar 2021 19:39:50 +0000 Subject: [PATCH 5/5] black fix --- xarray/plot/dataset_plot.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/xarray/plot/dataset_plot.py b/xarray/plot/dataset_plot.py index 426553f5fd5..e5261a960cb 100644 --- a/xarray/plot/dataset_plot.py +++ b/xarray/plot/dataset_plot.py @@ -351,7 +351,8 @@ def newplotfunc( cmap_params_subset = {} if (u is not None or v is not None) and plotfunc.__name__ not in ( - "quiver", "streamplot" + "quiver", + "streamplot", ): raise ValueError("u, v are only allowed for quiver or streamplot plots.")