Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit fd9ac3e

Browse files
jklymakMeeseeksDev[bot]
authored and
MeeseeksDev[bot]
committed
Backport PR #11017: Doc: Adding annotated heatmap example
1 parent edd5854 commit fd9ac3e

File tree

1 file changed

+312
-0
lines changed

1 file changed

+312
-0
lines changed
Lines changed: 312 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,312 @@
1+
"""
2+
===========================
3+
Creating annotated heatmaps
4+
===========================
5+
6+
It is often desirable to show data which depends on two independent
7+
variables as a color coded image plot. This is often referred to as a
8+
heatmap. If the data is categorical, this would be called a categorical
9+
heatmap.
10+
Matplotlib's :meth:`imshow <matplotlib.axes.Axes.imshow>` function makes
11+
production of such plots particularly easy.
12+
13+
The following examples show how to create a heatmap with annotations.
14+
We will start with an easy example and expand it to be usable as a
15+
universal function.
16+
"""
17+
18+
19+
##############################################################################
20+
#
21+
# A simple categorical heatmap
22+
# ----------------------------
23+
#
24+
# We may start by defining some data. What we need is a 2D list or array
25+
# which defines the data to color code. We then also need two lists or arrays
26+
# of categories; of course the number of elements in those lists
27+
# need to match the data along the respective axes.
28+
# The heatmap itself is an :meth:`imshow <matplotlib.axes.Axes.imshow>` plot
29+
# with the labels set to the categories we have.
30+
# Note that it is important to set both, the tick locations
31+
# (:meth:`set_xticks<matplotlib.axes.Axes.set_xticks>`) as well as the
32+
# tick labels (:meth:`set_xticklabels<matplotlib.axes.Axes.set_xticklabels>`),
33+
# otherwise they would become out of sync. The locations are just
34+
# the ascending integer numbers, while the ticklabels are the labels to show.
35+
# Finally we can label the data itself by creating a
36+
# :class:`~matplotlib.text.Text` within each cell showing the value of
37+
# that cell.
38+
39+
40+
import numpy as np
41+
import matplotlib
42+
import matplotlib.pyplot as plt
43+
# sphinx_gallery_thumbnail_number = 2
44+
45+
vegetables = ["cucumber", "tomato", "lettuce", "asparagus",
46+
"potato", "wheat", "barley"]
47+
farmers = ["Farmer Joe", "Upland Bros.", "Smith Gardening",
48+
"Agrifun", "Organiculture", "BioGoods Ltd.", "Cornylee Corp."]
49+
50+
harvest = np.array([[0.8, 2.4, 2.5, 3.9, 0.0, 4.0, 0.0],
51+
[2.4, 0.0, 4.0, 1.0, 2.7, 0.0, 0.0],
52+
[1.1, 2.4, 0.8, 4.3, 1.9, 4.4, 0.0],
53+
[0.6, 0.0, 0.3, 0.0, 3.1, 0.0, 0.0],
54+
[0.7, 1.7, 0.6, 2.6, 2.2, 6.2, 0.0],
55+
[1.3, 1.2, 0.0, 0.0, 0.0, 3.2, 5.1],
56+
[0.1, 2.0, 0.0, 1.4, 0.0, 1.9, 6.3]])
57+
58+
59+
fig, ax = plt.subplots()
60+
im = ax.imshow(harvest)
61+
62+
# We want to show all ticks...
63+
ax.set_xticks(np.arange(len(farmers)))
64+
ax.set_yticks(np.arange(len(vegetables)))
65+
# ... and label them with the respective list entries
66+
ax.set_xticklabels(farmers)
67+
ax.set_yticklabels(vegetables)
68+
69+
# Rotate the tick labels and set their alignment.
70+
plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
71+
rotation_mode="anchor")
72+
73+
# Loop over data dimensions and create text annotations.
74+
for i in range(len(vegetables)):
75+
for j in range(len(farmers)):
76+
text = ax.text(j, i, harvest[i, j],
77+
ha="center", va="center", color="w")
78+
79+
ax.set_title("Harvest of local farmers (in tons/year)")
80+
fig.tight_layout()
81+
plt.show()
82+
83+
84+
#############################################################################
85+
# Using the helper function code style
86+
# ------------------------------------
87+
#
88+
# As discussed in the :ref:`Coding styles <coding_styles>`
89+
# one might want to reuse such code to create some kind of heatmap
90+
# for different input data and/or on different axes.
91+
# We create a function that takes the data and the row and column labels as
92+
# input, and allows arguments that are used to customize the plot
93+
#
94+
# Here, in addition to the above we also want to create a colorbar and
95+
# position the labels above of the heatmap instead of below it.
96+
# The annotations shall get different colors depending on a threshold
97+
# for better contrast against the pixel color.
98+
# Finally, we turn the surrounding axes spines off and create
99+
# a grid of white lines to separate the cells.
100+
101+
102+
def heatmap(data, row_labels, col_labels, ax=None,
103+
cbar_kw={}, cbarlabel="", **kwargs):
104+
"""
105+
Create a heatmap from a numpy array and two lists of labels.
106+
107+
Arguments:
108+
data : A 2D numpy array of shape (N,M)
109+
row_labels : A list or array of length N with the labels
110+
for the rows
111+
col_labels : A list or array of length M with the labels
112+
for the columns
113+
Optional arguments:
114+
ax : A matplotlib.axes.Axes instance to which the heatmap
115+
is plotted. If not provided, use current axes or
116+
create a new one.
117+
cbar_kw : A dictionary with arguments to
118+
:meth:`matplotlib.Figure.colorbar`.
119+
cbarlabel : The label for the colorbar
120+
All other arguments are directly passed on to the imshow call.
121+
"""
122+
123+
if not ax:
124+
ax = plt.gca()
125+
126+
# Plot the heatmap
127+
im = ax.imshow(data, **kwargs)
128+
129+
# Create colorbar
130+
cbar = ax.figure.colorbar(im, ax=ax, **cbar_kw)
131+
cbar.ax.set_ylabel(cbarlabel, rotation=-90, va="bottom")
132+
133+
# We want to show all ticks...
134+
ax.set_xticks(np.arange(data.shape[1]))
135+
ax.set_yticks(np.arange(data.shape[0]))
136+
# ... and label them with the respective list entries.
137+
ax.set_xticklabels(col_labels)
138+
ax.set_yticklabels(row_labels)
139+
140+
# Let the horizontal axes labeling appear on top.
141+
ax.tick_params(top=True, bottom=False,
142+
labeltop=True, labelbottom=False)
143+
144+
# Rotate the tick labels and set their alignment.
145+
plt.setp(ax.get_xticklabels(), rotation=-30, ha="right",
146+
rotation_mode="anchor")
147+
148+
# Turn spines off and create white grid.
149+
for edge, spine in ax.spines.items():
150+
spine.set_visible(False)
151+
152+
ax.set_xticks(np.arange(data.shape[1]+1)-.5, minor=True)
153+
ax.set_yticks(np.arange(data.shape[0]+1)-.5, minor=True)
154+
ax.grid(which="minor", color="w", linestyle='-', linewidth=3)
155+
ax.tick_params(which="minor", bottom=False, left=False)
156+
157+
return im, cbar
158+
159+
160+
def annotate_heatmap(im, data=None, valfmt="{x:.2f}",
161+
textcolors=["black", "white"],
162+
threshold=None, **textkw):
163+
"""
164+
A function to annotate a heatmap.
165+
166+
Arguments:
167+
im : The AxesImage to be labeled.
168+
Optional arguments:
169+
data : Data used to annotate. If None, the image's data is used.
170+
valfmt : The format of the annotations inside the heatmap.
171+
This should either use the string format method, e.g.
172+
"$ {x:.2f}", or be a :class:`matplotlib.ticker.Formatter`.
173+
textcolors : A list or array of two color specifications. The first is
174+
used for values below a threshold, the second for those
175+
above.
176+
threshold : Value in data units according to which the colors from
177+
textcolors are applied. If None (the default) uses the
178+
middle of the colormap as separation.
179+
180+
Further arguments are passed on to the created text labels.
181+
"""
182+
183+
if not isinstance(data, (list, np.ndarray)):
184+
data = im.get_array()
185+
186+
# Normalize the threshold to the images color range.
187+
if threshold is not None:
188+
threshold = im.norm(threshold)
189+
else:
190+
threshold = im.norm(data.max())/2.
191+
192+
# Set default alignment to center, but allow it to be
193+
# overwritten by textkw.
194+
kw = dict(horizontalalignment="center",
195+
verticalalignment="center")
196+
kw.update(textkw)
197+
198+
# Get the formatter in case a string is supplied
199+
if isinstance(valfmt, str):
200+
valfmt = matplotlib.ticker.StrMethodFormatter(valfmt)
201+
202+
# Loop over the data and create a `Text` for each "pixel".
203+
# Change the text's color depending on the data.
204+
texts = []
205+
for i in range(data.shape[0]):
206+
for j in range(data.shape[1]):
207+
kw.update(color=textcolors[im.norm(data[i, j]) > threshold])
208+
text = im.axes.text(j, i, valfmt(data[i, j], None), **kw)
209+
texts.append(text)
210+
211+
return texts
212+
213+
214+
##########################################################################
215+
# The above now allows us to keep the actual plot creation pretty compact.
216+
#
217+
218+
fig, ax = plt.subplots()
219+
220+
im, cbar = heatmap(harvest, vegetables, farmers, ax=ax,
221+
cmap="YlGn", cbarlabel="harvest [t/year]")
222+
texts = annotate_heatmap(im, valfmt="{x:.1f} t")
223+
224+
fig.tight_layout()
225+
plt.show()
226+
227+
228+
#############################################################################
229+
# Some more complex heatmap examples
230+
# ----------------------------------
231+
#
232+
# In the following we show the versitality of the previously created
233+
# functions by applying it in different cases and using different arguments.
234+
#
235+
236+
np.random.seed(19680801)
237+
238+
fig, ((ax, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(8, 6))
239+
240+
# Replicate the above example with a different font size and colormap.
241+
242+
im, _ = heatmap(harvest, vegetables, farmers, ax=ax,
243+
cmap="Wistia", cbarlabel="harvest [t/year]")
244+
annotate_heatmap(im, valfmt="{x:.1f}", size=7)
245+
246+
# Create some new data, give further arguments to imshow (vmin),
247+
# use an integer format on the annotations and provide some colors.
248+
249+
data = np.random.randint(2, 100, size=(7, 7))
250+
y = ["Book {}".format(i) for i in range(1, 8)]
251+
x = ["Store {}".format(i) for i in list("ABCDEFG")]
252+
im, _ = heatmap(data, y, x, ax=ax2, vmin=0,
253+
cmap="magma_r", cbarlabel="weekly sold copies")
254+
annotate_heatmap(im, valfmt="{x:d}", size=7, threshold=20,
255+
textcolors=["red", "white"])
256+
257+
# Sometimes even the data itself is categorical. Here we use a
258+
# :class:`matplotlib.colors.BoundaryNorm` to get the data into classes
259+
# and use this to colorize the plot, but also to obtain the class
260+
# labels from an array of classes.
261+
262+
data = np.random.randn(6, 6)
263+
y = ["Prod. {}".format(i) for i in range(10, 70, 10)]
264+
x = ["Cycle {}".format(i) for i in range(1, 7)]
265+
266+
qrates = np.array(list("ABCDEFG"))
267+
norm = matplotlib.colors.BoundaryNorm(np.linspace(-3.5, 3.5, 8), 7)
268+
fmt = matplotlib.ticker.FuncFormatter(lambda x, pos: qrates[::-1][norm(x)])
269+
270+
im, _ = heatmap(data, y, x, ax=ax3,
271+
cmap=plt.get_cmap("PiYG", 7), norm=norm,
272+
cbar_kw=dict(ticks=np.arange(-3, 4), format=fmt),
273+
cbarlabel="Quality Rating")
274+
275+
annotate_heatmap(im, valfmt=fmt, size=9, fontweight="bold", threshold=-1,
276+
textcolors=["red", "black"])
277+
278+
# We can nicely plot a correlation matrix. Since this is bound by -1 and 1,
279+
# we use those as vmin and vmax. We may also remove leading zeros and hide
280+
# the diagonal elements (which are all 1) by using a
281+
# :class:`matplotlib.ticker.FuncFormatter`.
282+
283+
corr_matrix = np.corrcoef(np.random.rand(6, 5))
284+
im, _ = heatmap(corr_matrix, vegetables, vegetables, ax=ax4,
285+
cmap="PuOr", vmin=-1, vmax=1,
286+
cbarlabel="correlation coeff.")
287+
288+
289+
def func(x, pos):
290+
return "{:.2f}".format(x).replace("0.", ".").replace("1.00", "")
291+
292+
annotate_heatmap(im, valfmt=matplotlib.ticker.FuncFormatter(func), size=7)
293+
294+
295+
plt.tight_layout()
296+
plt.show()
297+
298+
299+
#############################################################################
300+
#
301+
# ------------
302+
#
303+
# References
304+
# """"""""""
305+
#
306+
# The usage of the following functions and methods is shown in this example:
307+
308+
309+
matplotlib.axes.Axes.imshow
310+
matplotlib.pyplot.imshow
311+
matplotlib.figure.Figure.colorbar
312+
matplotlib.pyplot.colorbar

0 commit comments

Comments
 (0)