diff --git a/doc/api.rst b/doc/api.rst index 9cb02441d37..9add7a96109 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -241,6 +241,7 @@ Plotting :template: autosummary/accessor_method.rst Dataset.plot.scatter + Dataset.plot.quiver DataArray ========= diff --git a/doc/plotting.rst b/doc/plotting.rst index 3699f794ae8..2ada3e25431 100644 --- a/doc/plotting.rst +++ b/doc/plotting.rst @@ -715,6 +715,9 @@ Consider this dataset ds +Scatter +~~~~~~~ + Suppose we want to scatter ``A`` against ``B`` .. ipython:: python @@ -762,6 +765,27 @@ Faceting is also possible For more advanced scatter plots, we recommend converting the relevant data variables to a pandas DataFrame and using the extensive plotting capabilities of ``seaborn``. +Quiver +~~~~~~ + +Visualizing vector fields is supported with quiver plots: + +.. ipython:: python + :okwarning: + + @savefig ds_simple_quiver.png + ds.isel(w=1, z=1).plot.quiver(x="x", y="y", u="A", v="B") + + +where ``u`` and ``v`` denote the x and y direction components of the arrow vectors. Again, faceting is also possible: + +.. ipython:: python + :okwarning: + + @savefig ds_facet_quiver.png + ds.plot.quiver(x="x", y="y", u="A", v="B", col="w", row="z", scale=4) + +``scale`` is required for faceted quiver plots. The scale determines the number of data units per arrow length unit, i.e. a smaller scale parameter makes the arrow longer. .. _plot-maps: diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 1bca3aec68e..f5344aa266e 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -65,13 +65,11 @@ New Features contain missing values; 8x faster in our benchmark, and 2x faster than pandas. (:pull:`4746`); By `Maximilian Roos `_. - -- Performance improvement when constructing DataArrays. Significantly speeds up repr for Datasets with large number of variables. - By `Deepak Cherian `_ +- Add :py:meth:`Dataset.plot.quiver` for quiver plots with :py:class:`Dataset` variables. + By `Deepak Cherian `_. - add ``"drop_conflicts"`` to the strategies supported by the ``combine_attrs`` kwarg (:issue:`4749`, :pull:`4827`). By `Justus Magin `_. - By `Deepak Cherian `_. - :py:meth:`DataArray.swap_dims` & :py:meth:`Dataset.swap_dims` now accept dims in the form of kwargs as well as a dict, like most similar methods. By `Maximilian Roos `_. @@ -152,6 +150,8 @@ Internal Changes all resources. (:pull:`#4809`), By `Alessandro Amici `_. - Ensure warnings cannot be turned into exceptions in :py:func:`testing.assert_equal` and the other ``assert_*`` functions (:pull:`4864`). By `Mathias Hauser `_. +- Performance improvement when constructing DataArrays. Significantly speeds up repr for Datasets with large number of variables. + By `Deepak Cherian `_ .. _whats-new.0.16.2: diff --git a/xarray/plot/dataset_plot.py b/xarray/plot/dataset_plot.py index 6d942e1b0fa..59d3ca98f23 100644 --- a/xarray/plot/dataset_plot.py +++ b/xarray/plot/dataset_plot.py @@ -7,6 +7,7 @@ from .facetgrid import _easy_facetgrid from .utils import ( _add_colorbar, + _get_nice_quiver_magnitude, _is_numeric, _process_cmap_cbar_kwargs, get_axis, @@ -17,7 +18,7 @@ _MARKERSIZE_RANGE = np.array([18.0, 72.0]) -def _infer_meta_data(ds, x, y, hue, hue_style, add_guide): +def _infer_meta_data(ds, x, y, hue, hue_style, add_guide, funcname): dvars = set(ds.variables.keys()) error_msg = " must be one of ({:s})".format(", ".join(dvars)) @@ -48,11 +49,24 @@ def _infer_meta_data(ds, x, y, hue, hue_style, add_guide): add_colorbar = False add_legend = False else: - if add_guide is True: + if add_guide is True and funcname != "quiver": raise ValueError("Cannot set add_guide when hue is None.") add_legend = False add_colorbar = False + if (add_guide or add_guide is None) and funcname == "quiver": + add_quiverkey = True + if hue: + add_colorbar = True + if not hue_style: + hue_style = "continuous" + elif hue_style != "continuous": + raise ValueError( + "hue_style must be 'continuous' or None for .plot.quiver" + ) + else: + add_quiverkey = False + if hue_style is not None and hue_style not in ["discrete", "continuous"]: raise ValueError("hue_style must be either None, 'discrete' or 'continuous'.") @@ -66,6 +80,7 @@ def _infer_meta_data(ds, x, y, hue, hue_style, add_guide): return { "add_colorbar": add_colorbar, "add_legend": add_legend, + "add_quiverkey": add_quiverkey, "hue_label": hue_label, "hue_style": hue_style, "xlabel": label_from_attrs(ds[x]), @@ -170,6 +185,8 @@ def _dsplot(plotfunc): ds : Dataset x, y : str Variable names for x, y axis. + u, v : str, optional + Variable names for quiver plots hue: str, optional Variable by which to color scattered points hue_style: str, optional @@ -250,6 +267,8 @@ def newplotfunc( ds, x=None, y=None, + u=None, + v=None, hue=None, hue_style=None, col=None, @@ -282,7 +301,9 @@ def newplotfunc( if _is_facetgrid: # facetgrid call meta_data = kwargs.pop("meta_data") else: - meta_data = _infer_meta_data(ds, x, y, hue, hue_style, add_guide) + meta_data = _infer_meta_data( + ds, x, y, hue, hue_style, add_guide, funcname=plotfunc.__name__ + ) hue_style = meta_data["hue_style"] @@ -317,13 +338,18 @@ def newplotfunc( else: cmap_params_subset = {} + if (u is not None or v is not None) and plotfunc.__name__ != "quiver": + raise ValueError("u, v are only allowed for quiver plots.") + primitive = plotfunc( ds=ds, x=x, y=y, + ax=ax, + u=u, + v=v, hue=hue, hue_style=hue_style, - ax=ax, cmap_params=cmap_params_subset, **kwargs, ) @@ -344,6 +370,25 @@ def newplotfunc( cbar_kwargs["label"] = meta_data.get("hue_label", None) _add_colorbar(primitive, ax, cbar_ax, cbar_kwargs, cmap_params) + if meta_data["add_quiverkey"]: + magnitude = _get_nice_quiver_magnitude(ds[u], ds[v]) + units = ds[u].attrs.get("units", "") + ax.quiverkey( + primitive, + X=0.85, + Y=0.9, + U=magnitude, + label=f"{magnitude}\n{units}", + labelpos="E", + coordinates="figure", + ) + + if plotfunc.__name__ == "quiver": + title = ds[u]._title_for_slice() + else: + title = ds[x]._title_for_slice() + ax.set_title(title) + return primitive @functools.wraps(newplotfunc) @@ -351,6 +396,8 @@ def plotmethod( _PlotMethods_obj, x=None, y=None, + u=None, + v=None, hue=None, hue_style=None, col=None, @@ -398,7 +445,7 @@ def plotmethod( @_dsplot -def scatter(ds, x, y, ax, **kwargs): +def scatter(ds, x, y, ax, u, v, **kwargs): """ Scatter Dataset data variables against each other. """ @@ -450,3 +497,32 @@ def scatter(ds, x, y, ax, **kwargs): ) return primitive + + +@_dsplot +def quiver(ds, x, y, ax, u, v, **kwargs): + """ Quiver plot with Dataset variables.""" + import matplotlib as mpl + + if x is None or y is None or u is None or v is None: + raise ValueError("Must specify x, y, u, v for quiver plots.") + + x, y, u, v = broadcast(ds[x], ds[y], ds[u], ds[v]) + + args = [x.values, y.values, u.values, v.values] + hue = kwargs.pop("hue") + cmap_params = kwargs.pop("cmap_params") + + if hue: + args.append(ds[hue].values) + + # TODO: Fix this by always returning a norm with vmin, vmax in cmap_params + if not cmap_params["norm"]: + cmap_params["norm"] = mpl.colors.Normalize( + cmap_params.pop("vmin"), cmap_params.pop("vmax") + ) + + kwargs.pop("hue_style") + kwargs.setdefault("pivot", "middle") + hdl = ax.quiver(*args, **kwargs, **cmap_params) + return hdl diff --git a/xarray/plot/facetgrid.py b/xarray/plot/facetgrid.py index bfa400d7ba4..2d3c0595026 100644 --- a/xarray/plot/facetgrid.py +++ b/xarray/plot/facetgrid.py @@ -6,6 +6,7 @@ from ..core.formatting import format_item from .utils import ( + _get_nice_quiver_magnitude, _infer_xy_labels, _process_cmap_cbar_kwargs, import_matplotlib_pyplot, @@ -195,7 +196,11 @@ def __init__( self.axes = axes self.row_names = row_names self.col_names = col_names + + # guides self.figlegend = None + self.quiverkey = None + self.cbar = None # Next the private variables self._single_group = single_group @@ -327,14 +332,15 @@ def map_dataset( from .dataset_plot import _infer_meta_data, _parse_size kwargs["add_guide"] = False - kwargs["_is_facetgrid"] = True if kwargs.get("markersize", None): kwargs["size_mapping"] = _parse_size( self.data[kwargs["markersize"]], kwargs.pop("size_norm", None) ) - meta_data = _infer_meta_data(self.data, x, y, hue, hue_style, add_guide) + meta_data = _infer_meta_data( + self.data, x, y, hue, hue_style, add_guide, funcname=func.__name__ + ) kwargs["meta_data"] = meta_data if hue and meta_data["hue_style"] == "continuous": @@ -344,6 +350,12 @@ def map_dataset( kwargs["meta_data"]["cmap_params"] = cmap_params kwargs["meta_data"]["cbar_kwargs"] = cbar_kwargs + kwargs["_is_facetgrid"] = True + + if func.__name__ == "quiver" and "scale" not in kwargs: + raise ValueError("Please provide scale.") + # TODO: come up with an algorithm for reasonable scale choice + for d, ax in zip(self.name_dicts.flat, self.axes.flat): # None is the sentinel value if d is not None: @@ -365,6 +377,9 @@ def map_dataset( elif meta_data["add_colorbar"]: self.add_colorbar(label=self._hue_label, **cbar_kwargs) + if meta_data["add_quiverkey"]: + self.add_quiverkey(kwargs["u"], kwargs["v"]) + return self def _finalize_grid(self, *axlabels): @@ -380,30 +395,22 @@ def _finalize_grid(self, *axlabels): self._finalized = True - def add_legend(self, **kwargs): - figlegend = self.fig.legend( - handles=self._mappables[-1], - labels=list(self._hue_var.values), - title=self._hue_label, - loc="center right", - **kwargs, - ) - - self.figlegend = figlegend + def _adjust_fig_for_guide(self, guide): # Draw the plot to set the bounding boxes correctly - self.fig.draw(self.fig.canvas.get_renderer()) + renderer = self.fig.canvas.get_renderer() + self.fig.draw(renderer) # Calculate and set the new width of the figure so the legend fits - legend_width = figlegend.get_window_extent().width / self.fig.dpi + guide_width = guide.get_window_extent(renderer).width / self.fig.dpi figure_width = self.fig.get_figwidth() - self.fig.set_figwidth(figure_width + legend_width) + self.fig.set_figwidth(figure_width + guide_width) # Draw the plot again to get the new transformations - self.fig.draw(self.fig.canvas.get_renderer()) + self.fig.draw(renderer) # Now calculate how much space we need on the right side - legend_width = figlegend.get_window_extent().width / self.fig.dpi - space_needed = legend_width / (figure_width + legend_width) + 0.02 + guide_width = guide.get_window_extent(renderer).width / self.fig.dpi + space_needed = guide_width / (figure_width + guide_width) + 0.02 # margin = .01 # _space_needed = margin + space_needed right = 1 - space_needed @@ -411,6 +418,16 @@ def add_legend(self, **kwargs): # Place the subplot axes to give space for the legend self.fig.subplots_adjust(right=right) + def add_legend(self, **kwargs): + self.figlegend = self.fig.legend( + handles=self._mappables[-1], + labels=list(self._hue_var.values), + title=self._hue_label, + loc="center right", + **kwargs, + ) + self._adjust_fig_for_guide(self.figlegend) + def add_colorbar(self, **kwargs): """Draw a colorbar""" kwargs = kwargs.copy() @@ -426,6 +443,26 @@ def add_colorbar(self, **kwargs): ) return self + def add_quiverkey(self, u, v, **kwargs): + kwargs = kwargs.copy() + + magnitude = _get_nice_quiver_magnitude(self.data[u], self.data[v]) + units = self.data[u].attrs.get("units", "") + self.quiverkey = self.axes.flat[-1].quiverkey( + self._mappables[-1], + X=0.8, + Y=0.9, + U=magnitude, + label=f"{magnitude}\n{units}", + labelpos="E", + coordinates="figure", + ) + + # TODO: does not work because self.quiverkey.get_window_extent(renderer) = 0 + # https://github.com/matplotlib/matplotlib/issues/18530 + # self._adjust_fig_for_guide(self.quiverkey.text) + return self + def set_axis_labels(self, x_var=None, y_var=None): """Set axis labels on the left column and bottom row of the grid.""" if x_var is not None: diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index ffe796987c5..5510cf7f219 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -841,3 +841,12 @@ def _process_cmap_cbar_kwargs( } return cmap_params, cbar_kwargs + + +def _get_nice_quiver_magnitude(u, v): + import matplotlib as mpl + + ticker = mpl.ticker.MaxNLocator(3) + mean = np.mean(np.hypot(u.values, v.values)) + magnitude = ticker.tick_values(0, mean)[-2] + return magnitude diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 47b15446f1d..705b2d5e2e7 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -1475,7 +1475,7 @@ def test_facetgrid_cbar_kwargs(self): ) # catch contour case - if hasattr(g, "cbar"): + if g.cbar is not None: assert get_colorbar_label(g.cbar) == "test_label" def test_facetgrid_no_cbar_ax(self): @@ -2152,6 +2152,66 @@ def test_wrong_num_of_dimensions(self): self.darray.plot.line(row="row", hue="hue") +@requires_matplotlib +class TestDatasetQuiverPlots(PlotTestCase): + @pytest.fixture(autouse=True) + def setUp(self): + das = [ + DataArray( + np.random.randn(3, 3, 4, 4), + dims=["x", "y", "row", "col"], + coords=[range(k) for k in [3, 3, 4, 4]], + ) + for _ in [1, 2] + ] + ds = Dataset({"u": das[0], "v": das[1]}) + ds.x.attrs["units"] = "xunits" + ds.y.attrs["units"] = "yunits" + ds.col.attrs["units"] = "colunits" + ds.row.attrs["units"] = "rowunits" + ds.u.attrs["units"] = "uunits" + ds.v.attrs["units"] = "vunits" + ds["mag"] = np.hypot(ds.u, ds.v) + self.ds = ds + + def test_quiver(self): + with figure_context(): + hdl = self.ds.isel(row=0, col=0).plot.quiver(x="x", y="y", u="u", v="v") + assert isinstance(hdl, mpl.quiver.Quiver) + with raises_regex(ValueError, "specify x, y, u, v"): + self.ds.isel(row=0, col=0).plot.quiver(x="x", y="y", u="u") + + with raises_regex(ValueError, "hue_style"): + self.ds.isel(row=0, col=0).plot.quiver( + x="x", y="y", u="u", v="v", hue="mag", hue_style="discrete" + ) + + def test_facetgrid(self): + with figure_context(): + fg = self.ds.plot.quiver( + x="x", y="y", u="u", v="v", row="row", col="col", scale=1, hue="mag" + ) + for handle in fg._mappables: + assert isinstance(handle, mpl.quiver.Quiver) + assert "uunits" in fg.quiverkey.text.get_text() + + with figure_context(): + fg = self.ds.plot.quiver( + x="x", + y="y", + u="u", + v="v", + row="row", + col="col", + scale=1, + hue="mag", + add_guide=False, + ) + assert fg.quiverkey is None + with raises_regex(ValueError, "Please provide scale"): + self.ds.plot.quiver(x="x", y="y", u="u", v="v", row="row", col="col") + + @requires_matplotlib class TestDatasetScatterPlots(PlotTestCase): @pytest.fixture(autouse=True) @@ -2194,7 +2254,13 @@ def test_accessor(self): def test_add_guide(self, add_guide, hue_style, legend, colorbar): meta_data = _infer_meta_data( - self.ds, x="A", y="B", hue="hue", hue_style=hue_style, add_guide=add_guide + self.ds, + x="A", + y="B", + hue="hue", + hue_style=hue_style, + add_guide=add_guide, + funcname="scatter", ) assert meta_data["add_legend"] is legend assert meta_data["add_colorbar"] is colorbar @@ -2273,6 +2339,9 @@ def test_facetgrid_hue_style(self): def test_scatter(self, x, y, hue, markersize): self.ds.plot.scatter(x, y, hue=hue, markersize=markersize) + with raises_regex(ValueError, "u, v"): + self.ds.plot.scatter(x, y, u="col", v="row") + def test_non_numeric_legend(self): ds2 = self.ds.copy() ds2["hue"] = ["a", "b", "c", "d"]