
Commit 1212500

ENH: Add grouped_bar() method
This is a WIP to implement #24313. It will be updated incrementally. As a first step, I've designed the data and label input API. Feedback is welcome.
1 parent 9b8a8c7 commit 1212500

File tree: 2 files changed, 192 additions & 0 deletions
Lines changed: 104 additions & 0 deletions
@@ -0,0 +1,104 @@
"""
=================
Grouped bar chart
=================

This example serves to develop and discuss the API. It is geared towards illustrating
API usage and design decisions only during the development phase. It's not intended
to go into the final PR in this form.

Case 1: multiple separate datasets
----------------------------------

"""
import matplotlib.pyplot as plt
import numpy as np

x = ['A', 'B']
data1 = [1, 1.2]
data2 = [2, 2.4]
data3 = [3, 3.6]


fig, axs = plt.subplots(1, 2)

# current solution: manual positioning with multiple bar() calls
label_pos = np.array([0, 1])
bar_width = 0.8 / 3
data_shift = -1*bar_width + np.array([0, bar_width, 2*bar_width])
axs[0].bar(label_pos + data_shift[0], data1, width=bar_width, label="data1")
axs[0].bar(label_pos + data_shift[1], data2, width=bar_width, label="data2")
axs[0].bar(label_pos + data_shift[2], data3, width=bar_width, label="data3")
axs[0].set_xticks(label_pos, x)
axs[0].legend()

# grouped_bar() with list of datasets
# note also that this is a straightforward generalization of the single-dataset case:
# bar(x, data1, label="data1")
axs[1].grouped_bar(x, [data1, data2, data3], dataset_labels=["data1", "data2", "data3"])


# %%
# Case 1b: multiple datasets as dict
# ----------------------------------
# Instead of carrying a list of datasets and a list of dataset labels, users may
# want to organize their datasets in a dict.

datasets = {
    'data1': data1,
    'data2': data2,
    'data3': data3,
}

# %%
# While you can feed keys and values into the above API, it may be convenient to pass
# the whole dict as "data" and automatically extract the labels from the keys:

fig, axs = plt.subplots(1, 2)

# explicitly extract values and labels from a dict and feed them to grouped_bar():
axs[0].grouped_bar(x, datasets.values(), dataset_labels=datasets.keys())
# accepting a dict as input
axs[1].grouped_bar(x, datasets)

# %%
# Case 2: 2D array data
# ---------------------
# When receiving a 2D array, we interpret the data as
#
# .. code-block:: none
#
#                  dataset_0  dataset_1  dataset_2
#     x[0]='A'     ds0_a      ds1_a      ds2_a
#     x[1]='B'     ds0_b      ds1_b      ds2_b
#
# This is consistent with the standard data science interpretation of instances
# on the vertical and features on the horizontal. It also matches how pandas
# interprets rows and columns.
#
# Note that a list of individual datasets and a 2D array behave structurally
# differently, i.e. when turning a list into a numpy array, you have to transpose
# that array to get the correct representation. These two calls behave the same::
#
#     grouped_bar(x, [data1, data2])
#     grouped_bar(x, np.array([data1, data2]).T)
#
# This is a conscious decision, because the commonly understood dimension ordering
# semantics of a "list of datasets" and a 2D array of datasets are different.

x = ['A', 'B']
data = np.array([
    [1, 2, 3],
    [1.2, 2.4, 3.6],
])
columns = ["data1", "data2", "data3"]

fig, ax = plt.subplots()
ax.grouped_bar(x, data, dataset_labels=columns)

# %%
# This creates the same plot as pandas (the code cannot be executed here because
# pandas is not a doc dependency)::
#
#     df = pd.DataFrame(data, index=x, columns=columns)
#     df.plot.bar()
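
The Case 2 notes above state that ``grouped_bar(x, [data1, data2])`` and ``grouped_bar(x, np.array([data1, data2]).T)`` behave the same. The following is a minimal sketch to check that claim side by side; it is not part of this commit and assumes a matplotlib build that already contains the WIP grouped_bar() from this diff. The two axes should come out identical.

import matplotlib.pyplot as plt
import numpy as np

x = ['A', 'B']
data1, data2, data3 = [1, 1.2], [2, 2.4], [3, 3.6]

fig, axs = plt.subplots(1, 2)
# list of datasets: the outer iteration runs over the datasets
axs[0].grouped_bar(x, [data1, data2, data3],
                   dataset_labels=["data1", "data2", "data3"])
# equivalent 2D array: rows map to x, columns to datasets, hence the transpose
axs[1].grouped_bar(x, np.array([data1, data2, data3]).T,
                   dataset_labels=["data1", "data2", "data3"])
for ax in axs:
    ax.legend()
plt.show()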

lib/matplotlib/axes/_axes.py

Lines changed: 88 additions & 0 deletions
@@ -3000,6 +3000,94 @@ def broken_barh(self, xranges, yrange, **kwargs):

        return col

    def grouped_bar(self, x, heights, dataset_labels=None):
        """
        Make a grouped bar plot.

        Parameters
        ----------
        x : array-like of str
            The labels.
        heights : list of array-like or dict of array-like or 2D array
            The heights for all x and datasets. One of:

            - list of array-like: A list of datasets, each dataset must have
              ``len(x)`` elements.

              .. code-block:: none

                  x = ['a', 'b']
                  dataset_labels = ['ds0', 'ds1', 'ds2']

                  # dataset_labels:  ds0        ds1        ds2
                  heights = [dataset_0, dataset_1, dataset_2]

                  #             x[0]   x[1]
                  dataset_0 = [ds0_a, ds0_b]

                  #             x[0]   x[1]
                  heights = [[ds0_a, ds0_b],  # dataset_0
                             [ds1_a, ds1_b],  # dataset_1
                             [ds2_a, ds2_b],  # dataset_2
                             ]

            - dict of array-like: A mapping from names to datasets, each dataset
              (dict value) must have ``len(x)`` elements::

                  dataset_labels = heights.keys()
                  heights = heights.values()

            - a 2D array: rows map to *x*, columns are the different datasets.

              .. code-block:: none

                           dataset_0  dataset_1  dataset_2
                  x[0]='a'   ds0_a      ds1_a      ds2_a
                  x[1]='b'   ds0_b      ds1_b      ds2_b

              Note that this is consistent with pandas. These two calls produce
              the same bar plot structure::

                  grouped_bar(x, array, dataset_labels=dataset_labels)
                  pd.DataFrame(array, index=x, columns=dataset_labels).plot.bar()

            In all cases, the iteration runs over the datasets, and each
            individual array-like holds the values for that dataset, one per
            entry in *x*.
        dataset_labels : array-like of str, optional
            The labels of the datasets.
        """
        if hasattr(heights, 'keys'):
            # dict input: take the labels from the keys
            if dataset_labels is not None:
                raise ValueError(
                    "'dataset_labels' cannot be used if 'heights' is a mapping")
            dataset_labels = heights.keys()
            heights = heights.values()
        elif hasattr(heights, 'shape'):
            # 2D array input: rows map to x, columns to datasets,
            # so transpose to iterate over datasets
            heights = heights.T

        num_labels = len(x)
        num_datasets = len(heights)

        for dataset in heights:
            assert len(dataset) == num_labels

        margin = 0.1
        bar_width = (1 - 2 * margin) / num_datasets
        block_centers = np.arange(num_labels)

        if dataset_labels is None:
            dataset_labels = [None] * num_datasets
        else:
            assert len(dataset_labels) == num_datasets

        for i, (hs, dataset_label) in enumerate(zip(heights, dataset_labels)):
            # left edge of the i-th bar within each unit-wide group
            lefts = block_centers - 0.5 + margin + i * bar_width
            self.bar(lefts, hs, width=bar_width, align="edge", label=dataset_label)

        self.xaxis.set_ticks(block_centers, labels=x)

        # TODO: does not return anything for now

    @_preprocess_data()
    def stem(self, *args, linefmt=None, markerfmt=None, basefmt=None, bottom=0,
             label=None, orientation='vertical'):
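
The positioning arithmetic in the new method (a 0.1 margin on each side of a unit-wide group, with the bars sharing the remaining width edge to edge) can be reproduced with plain bar() calls on a released matplotlib. The standalone sketch below mirrors that logic for illustration only; the data and the "data1".."data3" labels are made up and it is not part of the commit.

import matplotlib.pyplot as plt
import numpy as np

x = ['A', 'B']
heights = [[1, 1.2], [2, 2.4], [3, 3.6]]     # three datasets, len(x) values each
num_labels, num_datasets = len(x), len(heights)

margin = 0.1                                 # gap on each side of a unit-wide group
bar_width = (1 - 2 * margin) / num_datasets  # bars share the remaining width
block_centers = np.arange(num_labels)        # one slot per x label: 0, 1, ...

fig, ax = plt.subplots()
for i, hs in enumerate(heights):
    # left edge of the i-th bar within each group, as in the method above
    lefts = block_centers - 0.5 + margin + i * bar_width
    ax.bar(lefts, hs, width=bar_width, align="edge", label=f"data{i + 1}")
ax.set_xticks(block_centers, labels=x)
ax.legend()
plt.show()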
