From 6940129590622325d1cd7fa2303a0494e1844884 Mon Sep 17 00:00:00 2001 From: rprkh Date: Thu, 13 Oct 2022 21:41:11 +0530 Subject: [PATCH 1/4] improve colorblind friendliness --- .../plot_agglomerative_clustering_metrics.py | 20 ++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/examples/cluster/plot_agglomerative_clustering_metrics.py b/examples/cluster/plot_agglomerative_clustering_metrics.py index 38fd3682d48ec..665d41d0ad1d1 100644 --- a/examples/cluster/plot_agglomerative_clustering_metrics.py +++ b/examples/cluster/plot_agglomerative_clustering_metrics.py @@ -80,18 +80,20 @@ def sqr(x): labels = ("Waveform 1", "Waveform 2", "Waveform 3") +colors = ["#f7bd01", "#377eb8", "#f781bf"] + # Plot the ground-truth labelling plt.figure() plt.axes([0, 0, 1, 1]) -for l, c, n in zip(range(n_clusters), "rgb", labels): - lines = plt.plot(X[y == l].T, c=c, alpha=0.5) +for l, color, n in zip(range(n_clusters), colors, labels): + lines = plt.plot(X[y == l].T, c=color, alpha=0.5) lines[0].set_label(n) plt.legend(loc="best") plt.axis("tight") plt.axis("off") -plt.suptitle("Ground truth", size=20) +plt.suptitle("Ground truth", size=20, y=1) # Plot the distances @@ -114,27 +116,27 @@ def sqr(x): horizontalalignment="center", ) - plt.imshow(avg_dist, interpolation="nearest", cmap=plt.cm.gnuplot2, vmin=0) + plt.imshow(avg_dist, interpolation="nearest", cmap="cividis", vmin=0) plt.xticks(range(n_clusters), labels, rotation=45) plt.yticks(range(n_clusters), labels) plt.colorbar() - plt.suptitle("Interclass %s distances" % metric, size=18) + plt.suptitle("Interclass %s distances" % metric, size=18, y=1) plt.tight_layout() # Plot clustering results for index, metric in enumerate(["cosine", "euclidean", "cityblock"]): model = AgglomerativeClustering( - n_clusters=n_clusters, linkage="average", metric=metric + n_clusters=n_clusters, linkage="average", affinity=metric ) model.fit(X) plt.figure() plt.axes([0, 0, 1, 1]) - for l, c in zip(np.arange(model.n_clusters), "rgbk"): - plt.plot(X[model.labels_ == l].T, c=c, alpha=0.5) + for l, color in zip(np.arange(model.n_clusters), colors): + plt.plot(X[model.labels_ == l].T, c=color, alpha=0.5) plt.axis("tight") plt.axis("off") - plt.suptitle("AgglomerativeClustering(metric=%s)" % metric, size=20) + plt.suptitle("AgglomerativeClustering(metric=%s)" % metric, size=20, y=1) plt.show() From b7b085b65697cd3f3a0a23ab939e192caec0f72e Mon Sep 17 00:00:00 2001 From: rprkh Date: Sat, 15 Oct 2022 17:56:40 +0530 Subject: [PATCH 2/4] patheffect border around text --- examples/cluster/plot_agglomerative_clustering_metrics.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/cluster/plot_agglomerative_clustering_metrics.py b/examples/cluster/plot_agglomerative_clustering_metrics.py index 665d41d0ad1d1..84db84b435b73 100644 --- a/examples/cluster/plot_agglomerative_clustering_metrics.py +++ b/examples/cluster/plot_agglomerative_clustering_metrics.py @@ -38,6 +38,7 @@ # License: BSD 3-Clause or CC-0 import matplotlib.pyplot as plt +import matplotlib.patheffects as PathEffects import numpy as np from sklearn.cluster import AgglomerativeClustering @@ -108,13 +109,14 @@ def sqr(x): avg_dist /= avg_dist.max() for i in range(n_clusters): for j in range(n_clusters): - plt.text( + t = plt.text( i, j, "%5.3f" % avg_dist[i, j], verticalalignment="center", horizontalalignment="center", ) + t.set_path_effects([PathEffects.withStroke(linewidth=5, foreground='w', alpha=0.5)]) plt.imshow(avg_dist, interpolation="nearest", cmap="cividis", vmin=0) plt.xticks(range(n_clusters), labels, rotation=45) From d40a65a88e1547cdaf53f684be60fc4ebc07316b Mon Sep 17 00:00:00 2001 From: rprkh Date: Sat, 15 Oct 2022 18:00:14 +0530 Subject: [PATCH 3/4] lint --- examples/cluster/plot_agglomerative_clustering_metrics.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/cluster/plot_agglomerative_clustering_metrics.py b/examples/cluster/plot_agglomerative_clustering_metrics.py index 84db84b435b73..22021f37db92c 100644 --- a/examples/cluster/plot_agglomerative_clustering_metrics.py +++ b/examples/cluster/plot_agglomerative_clustering_metrics.py @@ -116,7 +116,9 @@ def sqr(x): verticalalignment="center", horizontalalignment="center", ) - t.set_path_effects([PathEffects.withStroke(linewidth=5, foreground='w', alpha=0.5)]) + t.set_path_effects( + [PathEffects.withStroke(linewidth=5, foreground="w", alpha=0.5)] + ) plt.imshow(avg_dist, interpolation="nearest", cmap="cividis", vmin=0) plt.xticks(range(n_clusters), labels, rotation=45) From d761ab592151c37ce03edb9c86c80b675fa0386c Mon Sep 17 00:00:00 2001 From: rprkh Date: Sat, 15 Oct 2022 18:32:34 +0530 Subject: [PATCH 4/4] fix depracation warning --- examples/cluster/plot_agglomerative_clustering_metrics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/cluster/plot_agglomerative_clustering_metrics.py b/examples/cluster/plot_agglomerative_clustering_metrics.py index 22021f37db92c..f1a77d442dbe8 100644 --- a/examples/cluster/plot_agglomerative_clustering_metrics.py +++ b/examples/cluster/plot_agglomerative_clustering_metrics.py @@ -131,7 +131,7 @@ def sqr(x): # Plot clustering results for index, metric in enumerate(["cosine", "euclidean", "cityblock"]): model = AgglomerativeClustering( - n_clusters=n_clusters, linkage="average", affinity=metric + n_clusters=n_clusters, linkage="average", metric=metric ) model.fit(X) plt.figure()