Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit bd131f9

Browse files
clewis7 and kushalkolar authored
add kmeans clustering example (#734)
* add kmeans clustering example
* update conf
* switch to tool tip
* switch to linear interp and 3D camera for kmeans
* increase timeout for deploy docs connection
* increase log level
* requested changes

---------

Co-authored-by: Kushal Kolar <[email protected]>
1 parent b564192 commit bd131f9

File tree

3 files changed

+121
-2
lines changed

3 files changed

+121
-2
lines changed

.github/workflows/docs-deploy.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,8 @@ jobs:
9898
server: ${{ secrets.DOCS_SERVER }}
9999
username: ${{ secrets.DOCS_USERNAME }}
100100
password: ${{ secrets.DOCS_PASSWORD }}
101+
log-level: verbose
102+
timeout: 60000
101103
local-dir: docs/build/html/
102104
server-dir: ./ # deploy to the root dir
103105
exclude: | # don't delete the /ver/ dir

docs/source/conf.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,6 @@
8888
templates_path = ["_templates"]
8989
exclude_patterns = []
9090

91-
napoleon_custom_sections = ["Features"]
92-
9391
# -- Options for HTML output -------------------------------------------------
9492
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
9593

examples/machine_learning/kmeans.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
"""
2+
K-Means Clustering of MNIST Dataset
3+
===================================
4+
5+
Example showing how you can perform K-Means clustering on the MNIST dataset.
6+
"""
7+
8+
# test_example = false
9+
# sphinx_gallery_pygfx_docs = 'screenshot'
10+
11+
import fastplotlib as fpl
12+
import numpy as np
13+
from sklearn.datasets import load_digits
14+
from sklearn.cluster import KMeans
15+
from sklearn.decomposition import PCA
16+
17+
# load the data
18+
mnist = load_digits()
19+
20+
# get the data and labels
21+
data = mnist['data'] # (1797, 64)
22+
labels = mnist['target'] # (1797,)
23+
24+
# visualize the first 5 digits
25+
# NOTE: this is just to give a sense of the dataset if you are unfamiliar,
26+
# the more interesting visualization is below :D
27+
fig_data = fpl.Figure(shape=(1, 5), size=(900, 300))
28+
29+
# iterate through each subplot
30+
for i, subplot in enumerate(fig_data):
31+
# reshape each image to (8, 8)
32+
subplot.add_image(data[i].reshape(8,8), cmap="gray", interpolation="linear")
33+
# add the label as a title
34+
subplot.set_title(f"Label: {labels[i]}")
35+
# turn off the axes and toolbar
36+
subplot.axes.visible = False
37+
subplot.toolbar = False
38+
39+
fig_data.show()
40+
41+
# project the data from 64 dimensions down to the number of unique digits
42+
n_digits = len(np.unique(labels)) # 10
43+
44+
reduced_data = PCA(n_components=n_digits).fit_transform(data) # (1797, 10)
45+
46+
# perform K-Means clustering, keeping the best of 4 runs
47+
kmeans = KMeans(n_clusters=n_digits, n_init=4)
48+
# fit the lower-dimension data
49+
kmeans.fit(reduced_data)
50+
51+
# get the centroids (center of the clusters)
52+
centroids = kmeans.cluster_centers_
53+
54+
# plot the kmeans result and corresponding original image
55+
figure = fpl.Figure(
56+
shape=(1,2),
57+
size=(700, 400),
58+
cameras=["3d", "2d"],
59+
controller_types=[["fly", "panzoom"]]
60+
)
61+
62+
# hide the axes in both subplots
63+
figure[0, 0].axes.visible = False
64+
figure[0, 1].axes.visible = False
65+
66+
figure[0, 0].set_title(f"K-means clustering of PCA-reduced data")
67+
68+
# plot the centroids
69+
figure[0, 0].add_scatter(
70+
data=np.vstack([centroids[:, 0], centroids[:, 1], centroids[:, 2]]).T,
71+
colors="white",
72+
sizes=15
73+
)
74+
# plot the down-projected data
75+
digit_scatter = figure[0,0].add_scatter(
76+
data=np.vstack([reduced_data[:, 0], reduced_data[:, 1], reduced_data[:, 2]]).T,
77+
sizes=5,
78+
cmap="tab10", # use a qualitative cmap
79+
cmap_transform=kmeans.labels_, # color by the predicted cluster
80+
)
81+
82+
# initial index
83+
ix = 0
84+
85+
# plot the initial image
86+
digit_img = figure[0, 1].add_image(
87+
data=data[ix].reshape(8,8),
88+
cmap="gray",
89+
name="digit",
90+
interpolation="linear"
91+
)
92+
93+
# change the color and size of the initial selected data point
94+
digit_scatter.colors[ix] = "magenta"
95+
digit_scatter.sizes[ix] = 10
96+
97+
# define event handler to update the selected data point
98+
@digit_scatter.add_event_handler("pointer_enter")
def update(ev):
    """Highlight the scatter point under the pointer and show its digit image.

    Registered on ``digit_scatter`` for the ``"pointer_enter"`` event.
    ``ev.pick_info["vertex_index"]`` is read as the index of the entered
    point; that index is used both to restyle the point and to look up the
    matching row of ``data``.
    """
    # put every point back to the base look before marking the new one
    digit_scatter.cmap = "tab10"
    digit_scatter.sizes = 5

    # index of the newly selected point
    selected = ev.pick_info["vertex_index"]

    # make the selection stand out
    digit_scatter.colors[selected] = "magenta"
    digit_scatter.sizes[selected] = 10

    # swap the image in the right-hand subplot for the selected sample
    digit_graphic = figure[0, 1]["digit"]
    digit_graphic.data = data[selected].reshape(8, 8)
112+
113+
figure.show()
114+
115+
# NOTE: `if __name__ == "__main__"` is NOT how to use fastplotlib interactively
116+
# please see our docs for using fastplotlib interactively in ipython and jupyter
117+
if __name__ == "__main__":
118+
print(__doc__)
119+
fpl.loop.run()

0 commit comments

Comments
 (0)