Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit e0b81c4

Browse files
add example: annotated heatmap
1 parent e2a0813 commit e0b81c4

File tree

1 file changed

+313
-0
lines changed

1 file changed

+313
-0
lines changed
Lines changed: 313 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,313 @@
1+
"""
2+
===========================
3+
Creating annotated heatmaps
4+
===========================
5+
6+
It is often desireable to show data which depends on two independent
7+
variables as a colorcoded image plot. This is often referred to as a
8+
heatmap. If the data is categorical, this would be called a categorical
9+
heatmap.
10+
Matplotlib's :meth:`imshow <matplotlib.axes.Axes.imshow>` function makes
11+
production of such plot particularly easy.
12+
13+
The following examples show how to create a heatmap with annotations.
14+
We will start with an easy example and expand it to be usable as a
15+
universal function.
16+
"""
17+
18+
# sphinx_gallery_thumbnail_number = 2
19+
20+
##############################################################################
21+
#
22+
# A simple categorical heatmap
23+
# ----------------------------
24+
#
25+
# We may start by defining some data. What we need is a 2D list or array
26+
# which defines the data to color code. We then also need two lists or arrays
27+
# of categories; of course the number of elements in those lists
28+
# need to match the data along the respective axes.
29+
# The heatmap itself is an :meth:`imshow <matplotlib.axes.Axes.imshow>` plot
30+
# with the labels set to the categories we have.
31+
# Note that it is important to set both, the tick locations
32+
# (:meth:`set_xticks<matplotlib.axes.Axes.set_xticks>`) as well as the
33+
# tick labels (:meth:`set_xticklabels<matplotlib.axes.Axes.set_xticklabels>`),
34+
# otherwise they would become out of sync. The locations are just
35+
# the ascending integer numbers, while the ticklabels are the labels to show.
36+
# Finally we can label the data itself by creating a
37+
# :class:`~matplotlib.text.Text` within each cell showing the value of
38+
# that cell.
39+
#
40+
41+
import numpy as np
42+
import matplotlib
43+
import matplotlib.pyplot as plt
44+
45+
vegetables = ["cucumber", "tomato", "lettuce", "asparagus",
46+
"potato", "wheat", "barley"]
47+
farmers = ["Farmer Joe", "Upland Bros.", "Smith Gardening",
48+
"Agrifun", "Organiculture", "BioGoods Ltd.", "Cornylee Corp."]
49+
50+
harvest = np.array([[0.8, 2.4, 2.5, 3.9, 0.0, 4.0, 0.0],
51+
[2.4, 0.0, 4.0, 1.0, 2.7, 0.0, 0.0],
52+
[1.1, 2.4, 0.8, 4.3, 1.9, 4.4, 0.0],
53+
[0.6, 0.0, 0.3, 0.0, 3.1, 0.0, 0.0],
54+
[0.7, 1.7, 0.6, 2.6, 2.2, 6.2, 0.0],
55+
[1.3, 1.2, 0.0, 0.0, 0.0, 3.2, 5.1],
56+
[0.1, 2.0, 0.0, 1.4, 0.0, 1.9, 6.3]])
57+
58+
59+
fig, ax = plt.subplots()
60+
im = ax.imshow(harvest)
61+
62+
# We want to show all ticks...
63+
ax.set_xticks(np.arange(len(farmers)))
64+
ax.set_yticks(np.arange(len(vegetables)))
65+
# ... and label them with the respective list entries
66+
ax.set_xticklabels(farmers)
67+
ax.set_yticklabels(vegetables)
68+
69+
# Rotate the tick labels and set their alignment.
70+
plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
71+
rotation_mode="anchor")
72+
73+
# Loop over data dimensions and create text annotations.
74+
for i in range(len(vegetables)):
75+
for j in range(len(farmers)):
76+
text = ax.text(j, i, harvest[i, j],
77+
ha="center", va="center", color="w")
78+
79+
ax.set_title("Harvest of local farmers (in tons/year)")
80+
fig.tight_layout()
81+
plt.show()
82+
83+
84+
#############################################################################
85+
# Using the helper function code style
86+
# ------------------------------------
87+
#
88+
# As discussed in the :ref:`Coding styles <coding_styles>`
89+
# one might want to reuse such code to create some kind of heatmap
90+
# for different input data and/or on different axes.
91+
# We might hence create a function which takes the data and the row and
92+
# column labels as input and allows arguments, which
93+
# may be used to customize the output.
94+
#
95+
# Here, in addition to the above we also want to create a colorbar and
96+
# position the labels above of the heatmap instead of below it.
97+
# The annotations shall
98+
# get different colors depending on a threshold for better contrast against
99+
# the pixel color. Finally, we turn the surrounding axes spines off and create
100+
# a grid of white lines to separate the cells.
101+
102+
103+
def heatmap(data, row_labels, col_labels, ax=None,
104+
cbar_kw={}, cbarlabel="", **kwargs):
105+
"""
106+
Create a heatmap from a numpy array and two lists of labels.
107+
108+
Arguments:
109+
data : A 2D numpy array of shape (N,M)
110+
row_labels : A list or array of length N with the labels
111+
for the rows
112+
col_labels : A list or array of length M with the labels
113+
for the columns
114+
Optional arguments:
115+
ax : A matplotlib.axes.Axes instance to which the heatmap
116+
is plotted. If not provided, use current axes or
117+
create a new one.
118+
cbar_kw : A dictionary with arguments to
119+
:meth:`matplotlib.Figure.colorbar`.
120+
cbarlabel : The label for the colorbar
121+
All other arguments are directly passed on to the imshow call.
122+
"""
123+
124+
if not ax:
125+
ax = plt.gca()
126+
127+
# Plot the heatmap
128+
im = ax.imshow(data, **kwargs)
129+
130+
# Create colorbar
131+
cbar = ax.figure.colorbar(im, ax=ax, **cbar_kw)
132+
cbar.ax.set_ylabel(cbarlabel, rotation=-90, va="bottom")
133+
134+
# We want to show all ticks...
135+
ax.set_xticks(np.arange(data.shape[1]))
136+
ax.set_yticks(np.arange(data.shape[0]))
137+
# ... and label them with the respective list entries.
138+
ax.set_xticklabels(col_labels)
139+
ax.set_yticklabels(row_labels)
140+
141+
# Let the horizontal axes labeling appear on top.
142+
ax.tick_params(top=True, bottom=False,
143+
labeltop=True, labelbottom=False)
144+
145+
# Rotate the tick labels and set their alignment.
146+
plt.setp(ax.get_xticklabels(), rotation=-30, ha="right",
147+
rotation_mode="anchor")
148+
149+
# Turn spines off and create white grid.
150+
for edge, spine in ax.spines.items():
151+
spine.set_visible(False)
152+
153+
ax.set_xticks(np.arange(data.shape[1]+1)-.5, minor=True)
154+
ax.set_yticks(np.arange(data.shape[0]+1)-.5, minor=True)
155+
ax.grid(which="minor", color="w", linestyle='-', linewidth=3)
156+
ax.tick_params(which="minor", bottom=False, left=False)
157+
158+
return im, cbar
159+
160+
161+
def annotate_heatmap(im, data=None, valfmt="{x:.2f}",
162+
textcolors=["black", "white"],
163+
threshold=None, **textkw):
164+
"""
165+
A function to annotate a heatmap.
166+
167+
Arguments:
168+
im : The AxesImage to be labeled.
169+
Optional arguments:
170+
data : Data used to annotate. If None, the image's data is used.
171+
valfmt : The format of the annotations inside the heatmap.
172+
This should either use the string format method, e.g.
173+
"$ {x:.2f}", or be a :class:`matplotlib.ticker.Formatter`.
174+
textcolors : A list or array of two color specifications. The first is
175+
used for values below a threshold, the second for those
176+
above.
177+
threshold : Value in data units according to which the colors from
178+
textcolors are applied. If None (the default) uses the
179+
middle of the colormap as separation.
180+
181+
Further arguments are passed on to the created text labels.
182+
"""
183+
184+
if not isinstance(data, (list, np.ndarray)):
185+
data = im.get_array()
186+
187+
# Normalize the threshold to the images color range.
188+
if not threshold and threshold != 0.0:
189+
threshold = im.norm(data.max())/2.
190+
else:
191+
threshold = im.norm(threshold)
192+
193+
# Set default alignment to center, but allow it to be
194+
# overwritten by textkw.
195+
kw = dict(horizontalalignment="center",
196+
verticalalignment="center")
197+
kw.update(textkw)
198+
199+
# Get the formatter in case a string is supplied
200+
if isinstance(valfmt, str):
201+
valfmt = matplotlib.ticker.StrMethodFormatter(valfmt)
202+
203+
# Loop over the data and create a `Text` for each "pixel".
204+
# Change the text's color depending on the data.
205+
texts = []
206+
for i in range(data.shape[0]):
207+
for j in range(data.shape[1]):
208+
kw.update(dict(color=textcolors[im.norm(data[i, j]) > threshold]))
209+
text = im.axes.text(j, i, valfmt(data[i, j], None), **kw)
210+
texts.append(text)
211+
212+
return texts
213+
214+
215+
##########################################################################
216+
# The above now allows us to keep the actual plot creation pretty compact.
217+
#
218+
219+
fig, ax = plt.subplots()
220+
221+
im, cbar = heatmap(harvest, vegetables, farmers, ax=ax,
222+
cmap="YlGn", cbarlabel="harvest [t/year]")
223+
texts = annotate_heatmap(im, valfmt="{x:.1f} t")
224+
225+
fig.tight_layout()
226+
plt.show()
227+
228+
229+
#############################################################################
230+
# Some more complex heatmap examples
231+
# ----------------------------------
232+
#
233+
# In the following we show the versitality of the previously created
234+
# functions by applying it in different cases and using different arguments.
235+
#
236+
237+
np.random.seed(19680801)
238+
239+
fig, ((ax, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(8, 6))
240+
241+
# Replicate the above example with a different font size and colormap.
242+
243+
im, _ = heatmap(harvest, vegetables, farmers, ax=ax,
244+
cmap="Wistia", cbarlabel="harvest [t/year]")
245+
annotate_heatmap(im, valfmt="{x:.1f}", size=7)
246+
247+
# Create some new data, give further arguments to imshow (vmin),
248+
# use an integer format on the annotations and provide some colors.
249+
250+
data = np.random.randint(2, 100, size=(7, 7))
251+
y = ["Book {}".format(i) for i in range(1, 8)]
252+
x = ["Store {}".format(i) for i in list("ABCDEFG")]
253+
im, _ = heatmap(data, y, x, ax=ax2, vmin=0,
254+
cmap="magma_r", cbarlabel="weekly sold copies")
255+
annotate_heatmap(im, valfmt="{x:d}", size=7, threshold=20,
256+
textcolors=["red", "white"])
257+
258+
# Sometimes even the data itself is categorical. Here we use a
259+
# :class:`matplotlib.colors.BoundaryNorm` to get the data into classes
260+
# and use this to colorize the plot, but also to obtain the class
261+
# labels from an array of classes.
262+
263+
data = np.random.randn(6, 6)
264+
y = ["Prod. {}".format(i) for i in range(10, 70, 10)]
265+
x = ["Cycle {}".format(i) for i in range(1, 7)]
266+
267+
qrates = np.array(list("ABCDEFG"))
268+
norm = matplotlib.colors.BoundaryNorm(np.linspace(-3.5, 3.5, 8), 7)
269+
fmt = matplotlib.ticker.FuncFormatter(lambda x, pos: qrates[::-1][norm(x)])
270+
271+
im, _ = heatmap(data, y, x, ax=ax3,
272+
cmap=plt.get_cmap("PiYG", 7), norm=norm,
273+
cbar_kw=dict(ticks=np.arange(-3, 4), format=fmt),
274+
cbarlabel="Quality Rating")
275+
276+
annotate_heatmap(im, valfmt=fmt, size=9, fontweight="bold", threshold=-1,
277+
textcolors=["red", "black"])
278+
279+
# We can nicely plot a correlation matrix. Since this is bound by -1 and 1,
280+
# we use those as vmin and vmax. We may also remove leading zeros and hide
281+
# the diagonal elements (which are all 1) by using a
282+
# :class:`matplotlib.ticker.FuncFormatter`.
283+
284+
corr_matrix = np.corrcoef(np.random.rand(6, 5))
285+
im, _ = heatmap(corr_matrix, vegetables, vegetables, ax=ax4,
286+
cmap="PuOr", vmin=-1, vmax=1,
287+
cbarlabel="correlation coeff.")
288+
289+
290+
def func(x, pos):
291+
return "{:.2f}".format(x).replace("0.", ".").replace("1.00", "")
292+
293+
annotate_heatmap(im, valfmt=matplotlib.ticker.FuncFormatter(func), size=7)
294+
295+
296+
plt.tight_layout()
297+
plt.show()
298+
299+
300+
#############################################################################
301+
#
302+
# ------------
303+
#
304+
# References
305+
# """"""""""
306+
#
307+
# The usage of the following functions and methods is shown in this example:
308+
309+
310+
matplotlib.axes.Axes.imshow
311+
matplotlib.pyplot.imshow
312+
matplotlib.figure.Figure.colorbar
313+
matplotlib.pyplot.colorbar

0 commit comments

Comments
 (0)