diff --git a/examples/cluster/plot_agglomerative_clustering_metrics.py b/examples/cluster/plot_agglomerative_clustering_metrics.py index 38fd3682d48ec..f1a77d442dbe8 100644 --- a/examples/cluster/plot_agglomerative_clustering_metrics.py +++ b/examples/cluster/plot_agglomerative_clustering_metrics.py @@ -38,6 +38,7 @@ # License: BSD 3-Clause or CC-0 import matplotlib.pyplot as plt +import matplotlib.patheffects as PathEffects import numpy as np from sklearn.cluster import AgglomerativeClustering @@ -80,18 +81,20 @@ def sqr(x): labels = ("Waveform 1", "Waveform 2", "Waveform 3") +colors = ["#f7bd01", "#377eb8", "#f781bf"] + # Plot the ground-truth labelling plt.figure() plt.axes([0, 0, 1, 1]) -for l, c, n in zip(range(n_clusters), "rgb", labels): - lines = plt.plot(X[y == l].T, c=c, alpha=0.5) +for l, color, n in zip(range(n_clusters), colors, labels): + lines = plt.plot(X[y == l].T, c=color, alpha=0.5) lines[0].set_label(n) plt.legend(loc="best") plt.axis("tight") plt.axis("off") -plt.suptitle("Ground truth", size=20) +plt.suptitle("Ground truth", size=20, y=1) # Plot the distances @@ -106,19 +109,22 @@ def sqr(x): avg_dist /= avg_dist.max() for i in range(n_clusters): for j in range(n_clusters): - plt.text( + t = plt.text( i, j, "%5.3f" % avg_dist[i, j], verticalalignment="center", horizontalalignment="center", ) + t.set_path_effects( + [PathEffects.withStroke(linewidth=5, foreground="w", alpha=0.5)] + ) - plt.imshow(avg_dist, interpolation="nearest", cmap=plt.cm.gnuplot2, vmin=0) + plt.imshow(avg_dist, interpolation="nearest", cmap="cividis", vmin=0) plt.xticks(range(n_clusters), labels, rotation=45) plt.yticks(range(n_clusters), labels) plt.colorbar() - plt.suptitle("Interclass %s distances" % metric, size=18) + plt.suptitle("Interclass %s distances" % metric, size=18, y=1) plt.tight_layout() @@ -130,11 +136,11 @@ def sqr(x): model.fit(X) plt.figure() plt.axes([0, 0, 1, 1]) - for l, c in zip(np.arange(model.n_clusters), "rgbk"): - plt.plot(X[model.labels_ == l].T, c=c, alpha=0.5) + for l, color in zip(np.arange(model.n_clusters), colors): + plt.plot(X[model.labels_ == l].T, c=color, alpha=0.5) plt.axis("tight") plt.axis("off") - plt.suptitle("AgglomerativeClustering(metric=%s)" % metric, size=20) + plt.suptitle("AgglomerativeClustering(metric=%s)" % metric, size=20, y=1) plt.show()