diff --git a/doc/_embedded_plots/grouped_bar.py b/doc/_embedded_plots/grouped_bar.py new file mode 100644 index 000000000000..f02e269328d2 --- /dev/null +++ b/doc/_embedded_plots/grouped_bar.py @@ -0,0 +1,15 @@ +import matplotlib.pyplot as plt + +categories = ['A', 'B'] +data0 = [1.0, 3.0] +data1 = [1.4, 3.4] +data2 = [1.8, 3.8] + +fig, ax = plt.subplots(figsize=(4, 2.2)) +ax.grouped_bar( + [data0, data1, data2], + tick_labels=categories, + labels=['dataset 0', 'dataset 1', 'dataset 2'], + colors=['#1f77b4', '#58a1cf', '#abd0e6'], +) +ax.legend() diff --git a/doc/api/axes_api.rst b/doc/api/axes_api.rst index f389226d907a..b543bd35e36f 100644 --- a/doc/api/axes_api.rst +++ b/doc/api/axes_api.rst @@ -67,6 +67,7 @@ Basic Axes.bar Axes.barh Axes.bar_label + Axes.grouped_bar Axes.stem Axes.eventplot diff --git a/doc/api/pyplot_summary.rst b/doc/api/pyplot_summary.rst index d0def34c4995..629f927a95c3 100644 --- a/doc/api/pyplot_summary.rst +++ b/doc/api/pyplot_summary.rst @@ -61,6 +61,7 @@ Basic bar barh bar_label + grouped_bar stem eventplot pie diff --git a/doc/users/next_whats_new/grouped_bar.rst b/doc/users/next_whats_new/grouped_bar.rst new file mode 100644 index 000000000000..af57c71b8a3a --- /dev/null +++ b/doc/users/next_whats_new/grouped_bar.rst @@ -0,0 +1,26 @@ +Grouped bar charts +------------------ + +The new method `~.Axes.grouped_bar()` simplifies the creation of grouped bar charts +significantly. It supports different input data types (lists of datasets, dicts of +datasets, data in 2D arrays, pandas DataFrames), and allows for easy customization +of placement via controllable distances between bars and between bar groups. + +Example: + +.. plot:: + :include-source: true + :alt: Diagram of a grouped bar chart of 3 datasets with 2 categories. + + import matplotlib.pyplot as plt + + categories = ['A', 'B'] + datasets = { + 'dataset 0': [1, 11], + 'dataset 1': [3, 13], + 'dataset 2': [5, 15], + } + + fig, ax = plt.subplots() + ax.grouped_bar(datasets, tick_labels=categories) + ax.legend() diff --git a/galleries/examples/lines_bars_and_markers/barchart.py b/galleries/examples/lines_bars_and_markers/barchart.py index f2157a89c0cd..dbb0f5bbbadd 100644 --- a/galleries/examples/lines_bars_and_markers/barchart.py +++ b/galleries/examples/lines_bars_and_markers/barchart.py @@ -10,7 +10,6 @@ # data from https://allisonhorst.github.io/palmerpenguins/ import matplotlib.pyplot as plt -import numpy as np species = ("Adelie", "Chinstrap", "Gentoo") penguin_means = { @@ -19,22 +18,15 @@ 'Flipper Length': (189.95, 195.82, 217.19), } -x = np.arange(len(species)) # the label locations -width = 0.25 # the width of the bars -multiplier = 0 - fig, ax = plt.subplots(layout='constrained') -for attribute, measurement in penguin_means.items(): - offset = width * multiplier - rects = ax.bar(x + offset, measurement, width, label=attribute) - ax.bar_label(rects, padding=3) - multiplier += 1 +res = ax.grouped_bar(penguin_means, tick_labels=species, group_spacing=1) +for container in res.bar_containers: + ax.bar_label(container, padding=3) -# Add some text for labels, title and custom x-axis tick labels, etc. +# Add some text for labels, title, etc. ax.set_ylabel('Length (mm)') ax.set_title('Penguin attributes by species') -ax.set_xticks(x + width, species) ax.legend(loc='upper left', ncols=3) ax.set_ylim(0, 250) diff --git a/lib/matplotlib/axes/_axes.py b/lib/matplotlib/axes/_axes.py index f03de3236c8d..6a73e2e20be9 100644 --- a/lib/matplotlib/axes/_axes.py +++ b/lib/matplotlib/axes/_axes.py @@ -64,6 +64,23 @@ def _make_axes_method(func): return func +class _GroupedBarReturn: + """ + A provisional result object for `.Axes.grouped_bar`. + + This is a placeholder for a future better return type. We try to build in + backward compatibility / migration possibilities. + + The only public interfaces are the ``bar_containers`` attribute and the + ``remove()`` method. + """ + def __init__(self, bar_containers): + self.bar_containers = bar_containers + + def remove(self): + [b.remove() for b in self.bars] + + @_docstring.interpd class Axes(_AxesBase): """ @@ -2489,6 +2506,7 @@ def bar(self, x, height, width=0.8, bottom=None, *, align="center", See Also -------- barh : Plot a horizontal bar plot. + grouped_bar : Plot multiple datasets as grouped bar plot. Notes ----- @@ -3048,6 +3066,311 @@ def broken_barh(self, xranges, yrange, **kwargs): return col + @_docstring.interpd + def grouped_bar(self, heights, *, positions=None, group_spacing=1.5, bar_spacing=0, + tick_labels=None, labels=None, orientation="vertical", colors=None, + **kwargs): + """ + Make a grouped bar plot. + + .. versionadded:: 3.11 + + The API is still provisional. We may still fine-tune some aspects based on + user-feedback. + + Grouped bar charts visualize a collection of categorical datasets. Each value + in a dataset belongs to a distinct category and these categories are the same + across all datasets. The categories typically have string names, but could + also be dates or index keys. The values in each dataset are represented by a + sequence of bars of the same color. The bars of all datasets are grouped + together by their shared categories. The category names are drawn as the tick + labels for each bar group. Each dataset has a distinct bar color, and can + optionally get a label that is used for the legend. + + Example: + + .. code-block:: python + + grouped_bar([dataset_1, dataset_2, dataset_3], + tick_labels=['A', 'B'], + labels=['dataset 1', 'dataset 2', 'dataset 3']) + + .. plot:: _embedded_plots/grouped_bar.py + + Parameters + ---------- + heights : list of array-like or dict of array-like or 2D array \ +or pandas.DataFrame + The heights for all x and groups. One of: + + - list of array-like: A list of datasets, each dataset must have + the same number of elements. + + .. code-block:: none + + # category_A, category_B + dataset_0 = [value_0_A, value_0_B] + dataset_1 = [value_1_A, value_1_B] + dataset_2 = [value_2_A, value_2_B] + + Example call:: + + grouped_bar([dataset_0, dataset_1, dataset_2]) + + - dict of array-like: A mapping from names to datasets. Each dataset + (dict value) must have the same number of elements. + + Example call: + + .. code-block:: python + + data_dict = {'ds0': dataset_0, 'ds1': dataset_1, 'ds2': dataset_2} + grouped_bar(data_dict) + + The names are used as *labels*, i.e. this is equivalent to + + .. code-block:: python + + grouped_bar(data_dict.values(), labels=data_dict.keys()) + + When using a dict input, you must not pass *labels* explicitly. + + - a 2D array: The rows are the categories, the columns are the different + datasets. + + .. code-block:: none + + dataset_0 dataset_1 dataset_2 + category_A ds0_a ds1_a ds2_a + category_B ds0_b ds1_b ds2_b + + Example call: + + .. code-block:: python + + categories = ["A", "B"] + dataset_labels = ["dataset_0", "dataset_1", "dataset_2"] + array = np.random.random((2, 3)) + grouped_bar(array, tick_labels=categories, labels=dataset_labels) + + - a `pandas.DataFrame`. + + The index is used for the categories, the columns are used for the + datasets. + + .. code-block:: python + + df = pd.DataFrame( + np.random.random((2, 3)), + index=["A", "B"], + columns=["dataset_0", "dataset_1", "dataset_2"] + ) + grouped_bar(df) + + i.e. this is equivalent to + + .. code-block:: + + grouped_bar(df.to_numpy(), tick_labels=df.index, labels=df.columns) + + Note that ``grouped_bar(df)`` produces a structurally equivalent plot like + ``df.plot.bar()``. + + positions : array-like, optional + The center positions of the bar groups. The values have to be equidistant. + If not given, a sequence of integer positions 0, 1, 2, ... is used. + + tick_labels : list of str, optional + The category labels, which are placed on ticks at the center *positions* + of the bar groups. If not set, the axis ticks (positions and labels) are + left unchanged. + + labels : list of str, optional + The labels of the datasets, i.e. the bars within one group. + These will show up in the legend. + + group_spacing : float, default: 1.5 + The space between two bar groups as multiples of bar width. + + The default value of 1.5 thus means that there's a gap of + 1.5 bar widths between bar groups. + + bar_spacing : float, default: 0 + The space between bars as multiples of bar width. + + orientation : {"vertical", "horizontal"}, default: "vertical" + The direction of the bars. + + colors : list of :mpltype:`color`, optional + A sequence of colors to be cycled through and used to color bars + of the different datasets. The sequence need not be exactly the + same length as the number of provided y, in which case the colors + will repeat from the beginning. + + If not specified, the colors from the Axes property cycle will be used. + + **kwargs : `.Rectangle` properties + + %(Rectangle:kwdoc)s + + Returns + ------- + _GroupedBarReturn + + A provisional result object. This will be refined in the future. + For now, the guaranteed API on the returned object is limited to + + - the attribute ``bar_containers``, which is a list of + `.BarContainer`, i.e. the results of the individual `~.Axes.bar` + calls for each dataset. + + - a ``remove()`` method, that remove all bars from the Axes. + See also `.Artist.remove()`. + + See Also + -------- + bar : A lower-level API for bar plots, with more degrees of freedom like + individual bar sizes and colors. + + Notes + ----- + For a better understanding, we compare the `~.Axes.grouped_bar` API with + those of `~.Axes.bar` and `~.Axes.boxplot`. + + **Comparison to bar()** + + `~.Axes.grouped_bar` intentionally deviates from the `~.Axes.bar` API in some + aspects. ``bar(x, y)`` is a lower-level API and places bars with height *y* + at explicit positions *x*. It also allows to specify individual bar widths + and colors. This kind of detailed control and flexibility is difficult to + manage and often not needed when plotting multiple datasets as a grouped bar + plot. Therefore, ``grouped_bar`` focusses on the abstraction of bar plots + as visualization of categorical data. + + The following examples may help to transfer from ``bar`` to + ``grouped_bar``. + + Positions are de-emphasized due to categories, and default to integer values. + If you have used ``range(N)`` as positions, you can leave that value out:: + + bar(range(N), heights) + grouped_bar([heights]) + + If needed, positions can be passed as keyword arguments:: + + bar(x, heights) + grouped_bar([heights], positions=x) + + To place category labels in `~.Axes.bar` you could use the argument + *tick_label* or use a list of category names as *x*. + `~.Axes.grouped_bar` expects them in the argument *tick_labels*:: + + bar(range(N), heights, tick_label=["A", "B"]) + bar(["A", "B"], heights) + grouped_bar([heights], tick_labels=["A", "B"]) + + Dataset labels, which are shown in the legend, are still passed via the + *label* parameter:: + + bar(..., label="dataset") + grouped_bar(..., label=["dataset"]) + + **Comparison to boxplot()** + + Both, `~.Axes.grouped_bar` and `~.Axes.boxplot` visualize categorical data + from multiple datasets. The basic API on *tick_labels* and *positions* + is the same, so that you can easily switch between plotting all + individual values as `~.Axes.grouped_bar` or the statistical distribution + per category as `~.Axes.boxplot`:: + + grouped_bar(values, positions=..., tick_labels=...) + boxplot(values, positions=..., tick_labels=...) + + """ + if cbook._is_pandas_dataframe(heights): + if labels is None: + labels = heights.columns.tolist() + if tick_labels is None: + tick_labels = heights.index.tolist() + heights = heights.to_numpy().T + elif hasattr(heights, 'keys'): # dict + if labels is not None: + raise ValueError( + "'labels' cannot be used if 'heights' are a mapping") + labels = heights.keys() + heights = list(heights.values()) + elif hasattr(heights, 'shape'): # numpy array + heights = heights.T + + num_datasets = len(heights) + num_groups = len(next(iter(heights))) # inferred from first dataset + + # validate that all datasets have the same length, i.e. num_groups + # - can be skipped if heights is an array + if not hasattr(heights, 'shape'): + for i, dataset in enumerate(heights): + if len(dataset) != num_groups: + raise ValueError( + "'heights' contains datasets with different number of " + f"elements. dataset 0 has {num_groups} elements but " + f"dataset {i} has {len(dataset)} elements." + ) + + if positions is None: + group_centers = np.arange(num_groups) + group_distance = 1 + else: + group_centers = np.asanyarray(positions) + if len(group_centers) > 1: + d = np.diff(group_centers) + if not np.allclose(d, d.mean()): + raise ValueError("'positions' must be equidistant") + group_distance = d[0] + else: + group_distance = 1 + + _api.check_in_list(["vertical", "horizontal"], orientation=orientation) + + if colors is None: + colors = itertools.cycle([None]) + else: + # Note: This is equivalent to the behavior in stackplot + # TODO: do we want to be more restrictive and check lengths? + colors = itertools.cycle(colors) + + bar_width = (group_distance / + (num_datasets + (num_datasets - 1) * bar_spacing + group_spacing)) + bar_spacing_abs = bar_spacing * bar_width + margin_abs = 0.5 * group_spacing * bar_width + + if labels is None: + labels = [None] * num_datasets + else: + assert len(labels) == num_datasets + + # place the bars, but only use numerical positions, categorical tick labels + # are handled separately below + bar_containers = [] + for i, (hs, label, color) in enumerate( + zip(heights, labels, colors)): + lefts = (group_centers - 0.5 * group_distance + margin_abs + + i * (bar_width + bar_spacing_abs)) + if orientation == "vertical": + bc = self.bar(lefts, hs, width=bar_width, align="edge", + label=label, color=color, **kwargs) + else: + bc = self.barh(lefts, hs, height=bar_width, align="edge", + label=label, color=color, **kwargs) + bar_containers.append(bc) + + if tick_labels is not None: + if orientation == "vertical": + self.xaxis.set_ticks(group_centers, labels=tick_labels) + else: + self.yaxis.set_ticks(group_centers, labels=tick_labels) + + return _GroupedBarReturn(bar_containers) + @_preprocess_data() def stem(self, *args, linefmt=None, markerfmt=None, basefmt=None, bottom=0, label=None, orientation='vertical'): diff --git a/lib/matplotlib/axes/_axes.pyi b/lib/matplotlib/axes/_axes.pyi index 1877cc192b15..75ae4a821ec6 100644 --- a/lib/matplotlib/axes/_axes.pyi +++ b/lib/matplotlib/axes/_axes.pyi @@ -39,6 +39,12 @@ from typing import Any, Literal, overload import numpy as np from numpy.typing import ArrayLike from matplotlib.typing import ColorType, MarkerType, LineStyleType +import pandas as pd + + +class _GroupedBarReturn: + def __init__(self, bar_containers: list[BarContainer]) -> None: ... + def remove(self) -> None: ... class Axes(_AxesBase): def get_title(self, loc: Literal["left", "center", "right"] = ...) -> str: ... @@ -279,6 +285,19 @@ class Axes(_AxesBase): data=..., **kwargs ) -> PolyCollection: ... + def grouped_bar( + self, + heights : Sequence[ArrayLike] | dict[str, ArrayLike] | np.ndarray | pd.DataFrame, + *, + positions : ArrayLike | None = ..., + tick_labels : Sequence[str] | None = ..., + labels : Sequence[str] | None = ..., + group_spacing : float | None = ..., + bar_spacing : float | None = ..., + orientation: Literal["vertical", "horizontal"] = ..., + colors: Iterable[ColorType] | None = ..., + **kwargs + ) -> list[BarContainer]: ... def stem( self, *args: ArrayLike | str, diff --git a/lib/matplotlib/pyplot.py b/lib/matplotlib/pyplot.py index 1e8cf869e6b4..3c1bfdf953b0 100644 --- a/lib/matplotlib/pyplot.py +++ b/lib/matplotlib/pyplot.py @@ -94,6 +94,7 @@ import PIL.Image from numpy.typing import ArrayLike + import pandas as pd import matplotlib.axes import matplotlib.artist @@ -3388,6 +3389,33 @@ def grid( gca().grid(visible=visible, which=which, axis=axis, **kwargs) +# Autogenerated by boilerplate.py. Do not edit as changes will be lost. +@_copy_docstring_and_deprecators(Axes.grouped_bar) +def grouped_bar( + heights: Sequence[ArrayLike] | dict[str, ArrayLike] | np.ndarray | pd.DataFrame, + *, + positions: ArrayLike | None = None, + group_spacing: float | None = 1.5, + bar_spacing: float | None = 0, + tick_labels: Sequence[str] | None = None, + labels: Sequence[str] | None = None, + orientation: Literal["vertical", "horizontal"] = "vertical", + colors: Iterable[ColorType] | None = None, + **kwargs, +) -> list[BarContainer]: + return gca().grouped_bar( + heights, + positions=positions, + group_spacing=group_spacing, + bar_spacing=bar_spacing, + tick_labels=tick_labels, + labels=labels, + orientation=orientation, + colors=colors, + **kwargs, + ) + + # Autogenerated by boilerplate.py. Do not edit as changes will be lost. @_copy_docstring_and_deprecators(Axes.hexbin) def hexbin( diff --git a/lib/matplotlib/tests/baseline_images/test_axes/grouped_bar.png b/lib/matplotlib/tests/baseline_images/test_axes/grouped_bar.png new file mode 100644 index 000000000000..19d676a6b662 Binary files /dev/null and b/lib/matplotlib/tests/baseline_images/test_axes/grouped_bar.png differ diff --git a/lib/matplotlib/tests/test_axes.py b/lib/matplotlib/tests/test_axes.py index 38857e846c55..315f81265390 100644 --- a/lib/matplotlib/tests/test_axes.py +++ b/lib/matplotlib/tests/test_axes.py @@ -2144,6 +2144,91 @@ def test_bar_datetime_start(): assert isinstance(ax.xaxis.get_major_formatter(), mdates.AutoDateFormatter) +@image_comparison(["grouped_bar.png"], style="mpl20") +def test_grouped_bar(): + data = { + 'data1': [1, 2, 3], + 'data2': [1.2, 2.2, 3.2], + 'data3': [1.4, 2.4, 3.4], + } + + fig, ax = plt.subplots() + ax.grouped_bar(data, tick_labels=['A', 'B', 'C'], + group_spacing=0.5, bar_spacing=0.1, + colors=['#1f77b4', '#58a1cf', '#abd0e6']) + ax.set_yticks([]) + + +@check_figures_equal(extensions=["png"]) +def test_grouped_bar_list_of_datasets(fig_test, fig_ref): + categories = ['A', 'B'] + data1 = [1, 1.2] + data2 = [2, 2.4] + data3 = [3, 3.6] + + ax = fig_test.subplots() + ax.grouped_bar([data1, data2, data3], tick_labels=categories, + labels=["data1", "data2", "data3"]) + ax.legend() + + ax = fig_ref.subplots() + label_pos = np.array([0, 1]) + bar_width = 1 / (3 + 1.5) # 3 bars + 1.5 group_spacing + data_shift = -1 * bar_width + np.array([0, bar_width, 2 * bar_width]) + ax.bar(label_pos + data_shift[0], data1, width=bar_width, label="data1") + ax.bar(label_pos + data_shift[1], data2, width=bar_width, label="data2") + ax.bar(label_pos + data_shift[2], data3, width=bar_width, label="data3") + ax.set_xticks(label_pos, categories) + ax.legend() + + +@check_figures_equal(extensions=["png"]) +def test_grouped_bar_dict_of_datasets(fig_test, fig_ref): + categories = ['A', 'B'] + data_dict = dict(data1=[1, 1.2], data2=[2, 2.4], data3=[3, 3.6]) + + ax = fig_test.subplots() + ax.grouped_bar(data_dict, tick_labels=categories) + ax.legend() + + ax = fig_ref.subplots() + ax.grouped_bar(data_dict.values(), tick_labels=categories, labels=data_dict.keys()) + ax.legend() + + +@check_figures_equal(extensions=["png"]) +def test_grouped_bar_array(fig_test, fig_ref): + categories = ['A', 'B'] + array = np.array([[1, 2, 3], [1.2, 2.4, 3.6]]) + labels = ['data1', 'data2', 'data3'] + + ax = fig_test.subplots() + ax.grouped_bar(array, tick_labels=categories, labels=labels) + ax.legend() + + ax = fig_ref.subplots() + list_of_datasets = [column for column in array.T] + ax.grouped_bar(list_of_datasets, tick_labels=categories, labels=labels) + ax.legend() + + +@check_figures_equal(extensions=["png"]) +def test_grouped_bar_dataframe(fig_test, fig_ref, pd): + categories = ['A', 'B'] + labels = ['data1', 'data2', 'data3'] + df = pd.DataFrame([[1, 2, 3], [1.2, 2.4, 3.6]], + index=categories, columns=labels) + + ax = fig_test.subplots() + ax.grouped_bar(df) + ax.legend() + + ax = fig_ref.subplots() + list_of_datasets = [df[col].to_numpy() for col in df.columns] + ax.grouped_bar(list_of_datasets, tick_labels=categories, labels=labels) + ax.legend() + + def test_boxplot_dates_pandas(pd): # smoke test for boxplot and dates in pandas data = np.random.rand(5, 2) diff --git a/tools/boilerplate.py b/tools/boilerplate.py index 962ae899c458..bc8d8d5d98c4 100644 --- a/tools/boilerplate.py +++ b/tools/boilerplate.py @@ -238,6 +238,7 @@ def boilerplate_gen(): 'fill_between', 'fill_betweenx', 'grid', + 'grouped_bar', 'hexbin', 'hist', 'stairs',