1
+ """
2
+ K-Means Clustering of MNIST Dataset
3
+ ===================================
4
+
5
+ Example showing how you can perform K-Means clustering on the MNIST dataset.
6
+ """
7
+
8
+ # test_example = false
9
+ # sphinx_gallery_pygfx_docs = 'screenshot'
10
+
11
+ import fastplotlib as fpl
12
+ import numpy as np
13
+ from sklearn .datasets import load_digits
14
+ from sklearn .cluster import KMeans
15
+ from sklearn .decomposition import PCA
16
+
17
# Load the digits dataset bundled with scikit-learn; each sample is a
# flattened 8x8 grayscale image (64 features).
mnist = load_digits()

# split the loaded bunch into the feature matrix and the ground-truth labels
data, labels = mnist["data"], mnist["target"]  # (1797, 64) and (1797,)
23
+
24
# Preview the first 5 digits.
# NOTE: this only gives a feel for the dataset if you are unfamiliar with it,
# the more interesting visualization is further below :D
fig_data = fpl.Figure(shape=(1, 5), size=(900, 300))

# walk the subplots in lockstep with the samples and their labels;
# zip stops once the five subplots are exhausted
for subplot, sample, label in zip(fig_data, data, labels):
    # each flattened sample is shown as an (8, 8) image
    subplot.add_image(sample.reshape(8, 8), cmap="gray", interpolation="linear")
    # use the ground-truth label as the subplot title
    subplot.set_title(f"Label: {label} ")
    # hide the axes and the toolbar
    subplot.axes.visible = False
    subplot.toolbar = False

fig_data.show()
40
+
41
# Project the data from 64 dimensions down to one component per unique digit.
n_digits = len(np.unique(labels))  # 10

# The seed only matters when PCA's "auto" solver selects a randomized solver;
# fixing it keeps the projection identical from run to run in that case.
reduced_data = PCA(n_components=n_digits, random_state=0).fit_transform(data)  # (1797, 10)

# Perform K-Means clustering with one cluster per digit, keeping the best of
# 4 initializations. The seed is fixed so that the cluster assignments, and
# therefore the point colors in the rendered figure, are reproducible.
kmeans = KMeans(n_clusters=n_digits, n_init=4, random_state=0)
# fit the lower-dimension data
kmeans.fit(reduced_data)

# get the centroids (centers of the clusters)
centroids = kmeans.cluster_centers_
53
+
54
# Build a 1x2 figure: subplot [0, 0] holds the K-Means result and subplot
# [0, 1] holds the original image of the selected sample.
figure = fpl.Figure(
    shape=(1, 2),
    size=(700, 400),
    cameras=["3d", "2d"],
    controller_types=[["fly", "panzoom"]],
)

# hide the axes in both subplots
figure[0, 0].axes.visible = False
figure[0, 1].axes.visible = False

figure[0, 0].set_title("K-means clustering of PCA-reduced data")

# plot the centroids, using their first three components as x, y, z
figure[0, 0].add_scatter(
    data=centroids[:, :3],
    colors="white",
    sizes=15,
)
# plot the down-projected data, again using the first three components
digit_scatter = figure[0, 0].add_scatter(
    data=reduced_data[:, :3],
    sizes=5,
    cmap="tab10",  # use a qualitative cmap
    cmap_transform=kmeans.labels_,  # color by the predicted cluster
)
81
+
82
# index of the sample that starts out selected
ix = 0

# show the image of the initially selected sample in subplot [0, 1];
# the graphic is named so that it can be looked up by name later
initial_digit = data[ix].reshape(8, 8)
digit_img = figure[0, 1].add_image(
    data=initial_digit,
    cmap="gray",
    name="digit",
    interpolation="linear",
)

# make the initially selected data point stand out by color and size
digit_scatter.colors[ix] = "magenta"
digit_scatter.sizes[ix] = 10
96
+
97
# event handler that moves the selection to the data point under the pointer
@digit_scatter.add_event_handler("pointer_enter")
def update(ev):
    """Highlight the scatter point reported by the event and show its image."""
    # put every point back to the cluster colors and the base size
    digit_scatter.cmap = "tab10"
    digit_scatter.sizes = 5

    # index of the newly selected vertex, taken from the event's pick info
    picked = ev.pick_info["vertex_index"]

    # emphasize the new selection
    digit_scatter.colors[picked] = "magenta"
    digit_scatter.sizes[picked] = 10

    # replace the displayed digit with the image of the selected sample
    digit_graphic = figure[0, 1]["digit"]
    digit_graphic.data = data[picked].reshape(8, 8)
112
+
113
figure.show()

# NOTE: an `if __name__ == "__main__"` guard is NOT how fastplotlib is used
# interactively; please see the fastplotlib docs for interactive use in
# ipython and jupyter. When this file is executed as a script, the guard
# prints the module docstring and then runs the fastplotlib loop.
if __name__ == "__main__":
    print(__doc__)
    fpl.loop.run()
0 commit comments