|
| 1 | +""" |
| 2 | +=========================== |
| 3 | +Creating annotated heatmaps |
| 4 | +=========================== |
| 5 | +
|
| 6 | +It is often desirable to show data which depends on two independent |
| 7 | +variables as a color coded image plot. This is often referred to as a |
| 8 | +heatmap. If the data is categorical, this would be called a categorical |
| 9 | +heatmap. |
| 10 | +Matplotlib's :meth:`imshow <matplotlib.axes.Axes.imshow>` function makes |
| 11 | +production of such plots particularly easy. |
| 12 | +
|
| 13 | +The following examples show how to create a heatmap with annotations. |
| 14 | +We will start with an easy example and expand it to be usable as a |
| 15 | +universal function. |
| 16 | +""" |
| 17 | + |
| 18 | + |
| 19 | +############################################################################## |
| 20 | +# |
| 21 | +# A simple categorical heatmap |
| 22 | +# ---------------------------- |
| 23 | +# |
| 24 | +# We may start by defining some data. What we need is a 2D list or array |
| 25 | +# which defines the data to color code. We then also need two lists or arrays |
| 26 | +# of categories; of course the number of elements in those lists |
| 27 | +# need to match the data along the respective axes. |
| 28 | +# The heatmap itself is an :meth:`imshow <matplotlib.axes.Axes.imshow>` plot |
| 29 | +# with the labels set to the categories we have. |
| 30 | +# Note that it is important to set both, the tick locations |
| 31 | +# (:meth:`set_xticks<matplotlib.axes.Axes.set_xticks>`) as well as the |
| 32 | +# tick labels (:meth:`set_xticklabels<matplotlib.axes.Axes.set_xticklabels>`), |
| 33 | +# otherwise they would become out of sync. The locations are just |
| 34 | +# the ascending integer numbers, while the ticklabels are the labels to show. |
| 35 | +# Finally we can label the data itself by creating a |
| 36 | +# :class:`~matplotlib.text.Text` within each cell showing the value of |
| 37 | +# that cell. |
| 38 | +# sphinx_gallery_thumbnail_number = 2 |
| 39 | + |
| 40 | +import numpy as np |
| 41 | +import matplotlib |
| 42 | +import matplotlib.pyplot as plt |
| 43 | + |
| 44 | +vegetables = ["cucumber", "tomato", "lettuce", "asparagus", |
| 45 | + "potato", "wheat", "barley"] |
| 46 | +farmers = ["Farmer Joe", "Upland Bros.", "Smith Gardening", |
| 47 | + "Agrifun", "Organiculture", "BioGoods Ltd.", "Cornylee Corp."] |
| 48 | + |
| 49 | +harvest = np.array([[0.8, 2.4, 2.5, 3.9, 0.0, 4.0, 0.0], |
| 50 | + [2.4, 0.0, 4.0, 1.0, 2.7, 0.0, 0.0], |
| 51 | + [1.1, 2.4, 0.8, 4.3, 1.9, 4.4, 0.0], |
| 52 | + [0.6, 0.0, 0.3, 0.0, 3.1, 0.0, 0.0], |
| 53 | + [0.7, 1.7, 0.6, 2.6, 2.2, 6.2, 0.0], |
| 54 | + [1.3, 1.2, 0.0, 0.0, 0.0, 3.2, 5.1], |
| 55 | + [0.1, 2.0, 0.0, 1.4, 0.0, 1.9, 6.3]]) |
| 56 | + |
| 57 | + |
| 58 | +fig, ax = plt.subplots() |
| 59 | +im = ax.imshow(harvest) |
| 60 | + |
| 61 | +# We want to show all ticks... |
| 62 | +ax.set_xticks(np.arange(len(farmers))) |
| 63 | +ax.set_yticks(np.arange(len(vegetables))) |
| 64 | +# ... and label them with the respective list entries |
| 65 | +ax.set_xticklabels(farmers) |
| 66 | +ax.set_yticklabels(vegetables) |
| 67 | + |
| 68 | +# Rotate the tick labels and set their alignment. |
| 69 | +plt.setp(ax.get_xticklabels(), rotation=45, ha="right", |
| 70 | + rotation_mode="anchor") |
| 71 | + |
| 72 | +# Loop over data dimensions and create text annotations. |
| 73 | +for i in range(len(vegetables)): |
| 74 | + for j in range(len(farmers)): |
| 75 | + text = ax.text(j, i, harvest[i, j], |
| 76 | + ha="center", va="center", color="w") |
| 77 | + |
| 78 | +ax.set_title("Harvest of local farmers (in tons/year)") |
| 79 | +fig.tight_layout() |
| 80 | +plt.show() |
| 81 | + |
| 82 | + |
| 83 | +############################################################################# |
| 84 | +# Using the helper function code style |
| 85 | +# ------------------------------------ |
| 86 | +# |
| 87 | +# As discussed in the :ref:`Coding styles <coding_styles>` |
| 88 | +# one might want to reuse such code to create some kind of heatmap |
| 89 | +# for different input data and/or on different axes. |
| 90 | +# We create a function that takes the data and the row and column labels as |
| 91 | +# input, and allows arguments that are used to customize the plot |
| 92 | +# |
| 93 | +# Here, in addition to the above we also want to create a colorbar and |
| 94 | +# position the labels above of the heatmap instead of below it. |
| 95 | +# The annotations shall get different colors depending on a threshold |
| 96 | +# for better contrast against the pixel color. |
| 97 | +# Finally, we turn the surrounding axes spines off and create |
| 98 | +# a grid of white lines to separate the cells. |
| 99 | + |
| 100 | + |
| 101 | +def heatmap(data, row_labels, col_labels, ax=None, |
| 102 | + cbar_kw={}, cbarlabel="", **kwargs): |
| 103 | + """ |
| 104 | + Create a heatmap from a numpy array and two lists of labels. |
| 105 | +
|
| 106 | + Arguments: |
| 107 | + data : A 2D numpy array of shape (N,M) |
| 108 | + row_labels : A list or array of length N with the labels |
| 109 | + for the rows |
| 110 | + col_labels : A list or array of length M with the labels |
| 111 | + for the columns |
| 112 | + Optional arguments: |
| 113 | + ax : A matplotlib.axes.Axes instance to which the heatmap |
| 114 | + is plotted. If not provided, use current axes or |
| 115 | + create a new one. |
| 116 | + cbar_kw : A dictionary with arguments to |
| 117 | + :meth:`matplotlib.Figure.colorbar`. |
| 118 | + cbarlabel : The label for the colorbar |
| 119 | + All other arguments are directly passed on to the imshow call. |
| 120 | + """ |
| 121 | + |
| 122 | + if not ax: |
| 123 | + ax = plt.gca() |
| 124 | + |
| 125 | + # Plot the heatmap |
| 126 | + im = ax.imshow(data, **kwargs) |
| 127 | + |
| 128 | + # Create colorbar |
| 129 | + cbar = ax.figure.colorbar(im, ax=ax, **cbar_kw) |
| 130 | + cbar.ax.set_ylabel(cbarlabel, rotation=-90, va="bottom") |
| 131 | + |
| 132 | + # We want to show all ticks... |
| 133 | + ax.set_xticks(np.arange(data.shape[1])) |
| 134 | + ax.set_yticks(np.arange(data.shape[0])) |
| 135 | + # ... and label them with the respective list entries. |
| 136 | + ax.set_xticklabels(col_labels) |
| 137 | + ax.set_yticklabels(row_labels) |
| 138 | + |
| 139 | + # Let the horizontal axes labeling appear on top. |
| 140 | + ax.tick_params(top=True, bottom=False, |
| 141 | + labeltop=True, labelbottom=False) |
| 142 | + |
| 143 | + # Rotate the tick labels and set their alignment. |
| 144 | + plt.setp(ax.get_xticklabels(), rotation=-30, ha="right", |
| 145 | + rotation_mode="anchor") |
| 146 | + |
| 147 | + # Turn spines off and create white grid. |
| 148 | + for edge, spine in ax.spines.items(): |
| 149 | + spine.set_visible(False) |
| 150 | + |
| 151 | + ax.set_xticks(np.arange(data.shape[1]+1)-.5, minor=True) |
| 152 | + ax.set_yticks(np.arange(data.shape[0]+1)-.5, minor=True) |
| 153 | + ax.grid(which="minor", color="w", linestyle='-', linewidth=3) |
| 154 | + ax.tick_params(which="minor", bottom=False, left=False) |
| 155 | + |
| 156 | + return im, cbar |
| 157 | + |
| 158 | + |
| 159 | +def annotate_heatmap(im, data=None, valfmt="{x:.2f}", |
| 160 | + textcolors=["black", "white"], |
| 161 | + threshold=None, **textkw): |
| 162 | + """ |
| 163 | + A function to annotate a heatmap. |
| 164 | +
|
| 165 | + Arguments: |
| 166 | + im : The AxesImage to be labeled. |
| 167 | + Optional arguments: |
| 168 | + data : Data used to annotate. If None, the image's data is used. |
| 169 | + valfmt : The format of the annotations inside the heatmap. |
| 170 | + This should either use the string format method, e.g. |
| 171 | + "$ {x:.2f}", or be a :class:`matplotlib.ticker.Formatter`. |
| 172 | + textcolors : A list or array of two color specifications. The first is |
| 173 | + used for values below a threshold, the second for those |
| 174 | + above. |
| 175 | + threshold : Value in data units according to which the colors from |
| 176 | + textcolors are applied. If None (the default) uses the |
| 177 | + middle of the colormap as separation. |
| 178 | +
|
| 179 | + Further arguments are passed on to the created text labels. |
| 180 | + """ |
| 181 | + |
| 182 | + if not isinstance(data, (list, np.ndarray)): |
| 183 | + data = im.get_array() |
| 184 | + |
| 185 | + # Normalize the threshold to the images color range. |
| 186 | + if threshold is not None: |
| 187 | + threshold = im.norm(threshold) |
| 188 | + else: |
| 189 | + threshold = im.norm(data.max())/2. |
| 190 | + |
| 191 | + # Set default alignment to center, but allow it to be |
| 192 | + # overwritten by textkw. |
| 193 | + kw = dict(horizontalalignment="center", |
| 194 | + verticalalignment="center") |
| 195 | + kw.update(textkw) |
| 196 | + |
| 197 | + # Get the formatter in case a string is supplied |
| 198 | + if isinstance(valfmt, str): |
| 199 | + valfmt = matplotlib.ticker.StrMethodFormatter(valfmt) |
| 200 | + |
| 201 | + # Loop over the data and create a `Text` for each "pixel". |
| 202 | + # Change the text's color depending on the data. |
| 203 | + texts = [] |
| 204 | + for i in range(data.shape[0]): |
| 205 | + for j in range(data.shape[1]): |
| 206 | + kw.update(color=textcolors[im.norm(data[i, j]) > threshold]) |
| 207 | + text = im.axes.text(j, i, valfmt(data[i, j], None), **kw) |
| 208 | + texts.append(text) |
| 209 | + |
| 210 | + return texts |
| 211 | + |
| 212 | + |
| 213 | +########################################################################## |
| 214 | +# The above now allows us to keep the actual plot creation pretty compact. |
| 215 | +# |
| 216 | + |
| 217 | +fig, ax = plt.subplots() |
| 218 | + |
| 219 | +im, cbar = heatmap(harvest, vegetables, farmers, ax=ax, |
| 220 | + cmap="YlGn", cbarlabel="harvest [t/year]") |
| 221 | +texts = annotate_heatmap(im, valfmt="{x:.1f} t") |
| 222 | + |
| 223 | +fig.tight_layout() |
| 224 | +plt.show() |
| 225 | + |
| 226 | + |
| 227 | +############################################################################# |
| 228 | +# Some more complex heatmap examples |
| 229 | +# ---------------------------------- |
| 230 | +# |
| 231 | +# In the following we show the versitality of the previously created |
| 232 | +# functions by applying it in different cases and using different arguments. |
| 233 | +# |
| 234 | + |
| 235 | +np.random.seed(19680801) |
| 236 | + |
| 237 | +fig, ((ax, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(8, 6)) |
| 238 | + |
| 239 | +# Replicate the above example with a different font size and colormap. |
| 240 | + |
| 241 | +im, _ = heatmap(harvest, vegetables, farmers, ax=ax, |
| 242 | + cmap="Wistia", cbarlabel="harvest [t/year]") |
| 243 | +annotate_heatmap(im, valfmt="{x:.1f}", size=7) |
| 244 | + |
| 245 | +# Create some new data, give further arguments to imshow (vmin), |
| 246 | +# use an integer format on the annotations and provide some colors. |
| 247 | + |
| 248 | +data = np.random.randint(2, 100, size=(7, 7)) |
| 249 | +y = ["Book {}".format(i) for i in range(1, 8)] |
| 250 | +x = ["Store {}".format(i) for i in list("ABCDEFG")] |
| 251 | +im, _ = heatmap(data, y, x, ax=ax2, vmin=0, |
| 252 | + cmap="magma_r", cbarlabel="weekly sold copies") |
| 253 | +annotate_heatmap(im, valfmt="{x:d}", size=7, threshold=20, |
| 254 | + textcolors=["red", "white"]) |
| 255 | + |
| 256 | +# Sometimes even the data itself is categorical. Here we use a |
| 257 | +# :class:`matplotlib.colors.BoundaryNorm` to get the data into classes |
| 258 | +# and use this to colorize the plot, but also to obtain the class |
| 259 | +# labels from an array of classes. |
| 260 | + |
| 261 | +data = np.random.randn(6, 6) |
| 262 | +y = ["Prod. {}".format(i) for i in range(10, 70, 10)] |
| 263 | +x = ["Cycle {}".format(i) for i in range(1, 7)] |
| 264 | + |
| 265 | +qrates = np.array(list("ABCDEFG")) |
| 266 | +norm = matplotlib.colors.BoundaryNorm(np.linspace(-3.5, 3.5, 8), 7) |
| 267 | +fmt = matplotlib.ticker.FuncFormatter(lambda x, pos: qrates[::-1][norm(x)]) |
| 268 | + |
| 269 | +im, _ = heatmap(data, y, x, ax=ax3, |
| 270 | + cmap=plt.get_cmap("PiYG", 7), norm=norm, |
| 271 | + cbar_kw=dict(ticks=np.arange(-3, 4), format=fmt), |
| 272 | + cbarlabel="Quality Rating") |
| 273 | + |
| 274 | +annotate_heatmap(im, valfmt=fmt, size=9, fontweight="bold", threshold=-1, |
| 275 | + textcolors=["red", "black"]) |
| 276 | + |
| 277 | +# We can nicely plot a correlation matrix. Since this is bound by -1 and 1, |
| 278 | +# we use those as vmin and vmax. We may also remove leading zeros and hide |
| 279 | +# the diagonal elements (which are all 1) by using a |
| 280 | +# :class:`matplotlib.ticker.FuncFormatter`. |
| 281 | + |
| 282 | +corr_matrix = np.corrcoef(np.random.rand(6, 5)) |
| 283 | +im, _ = heatmap(corr_matrix, vegetables, vegetables, ax=ax4, |
| 284 | + cmap="PuOr", vmin=-1, vmax=1, |
| 285 | + cbarlabel="correlation coeff.") |
| 286 | + |
| 287 | + |
| 288 | +def func(x, pos): |
| 289 | + return "{:.2f}".format(x).replace("0.", ".").replace("1.00", "") |
| 290 | + |
| 291 | +annotate_heatmap(im, valfmt=matplotlib.ticker.FuncFormatter(func), size=7) |
| 292 | + |
| 293 | + |
| 294 | +plt.tight_layout() |
| 295 | +plt.show() |
| 296 | + |
| 297 | + |
| 298 | +############################################################################# |
| 299 | +# |
| 300 | +# ------------ |
| 301 | +# |
| 302 | +# References |
| 303 | +# """""""""" |
| 304 | +# |
| 305 | +# The usage of the following functions and methods is shown in this example: |
| 306 | + |
| 307 | + |
| 308 | +matplotlib.axes.Axes.imshow |
| 309 | +matplotlib.pyplot.imshow |
| 310 | +matplotlib.figure.Figure.colorbar |
| 311 | +matplotlib.pyplot.colorbar |
0 commit comments