Array API support for k-means #26585
Here is a gist with what I think is a rather comprehensive implementation of Lloyd's algorithm in PyTorch: https://gist.github.com/fcharras/ce1f1df7d15675268827e1fb9b65265b. Scroll down to the bottom of the file for a quick tester. This function is almost a drop-in replacement for the private Lloyd function in the scikit-learn implementation.

For k-means, running this on GPU does offer a speedup over current scikit-learn performance (in the 2x-5x range, I think?), but in the realm of GPU implementations of k-means it is a bit underwhelming. Both the "pairwise distance + min lookup" and the "weight multiplication + centroid update" steps should be fused rather than called separately with intermediate results materialized in memory (the memory reads/writes are the bottleneck); fusing them would offer another 2x-5x speedup.

I'll post a similar gist for the KNN. For the KNN it's a different story: to my knowledge, the best brute-force implementations require materializing the pairwise distance matrix in memory and cannot get past that IO bottleneck, so the achievable speedup is more limited, and the PyTorch implementation should be decently close to the best you can get.
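For readers who want the shape of the two unfused steps described above without opening the gist, here is a minimal NumPy sketch (a hypothetical helper, not the gist's actual code): step 1 materializes the full pairwise distance matrix before the min lookup, and step 2 materializes the weighted points before a scatter-add centroid update, exactly the intermediate-results pattern criticized above (`np.add.at` plays the role of `torch.scatter_add_`).

```python
import numpy as np

def lloyd_iteration(X, centroids, sample_weight=None):
    """One unfused Lloyd iteration: both intermediate results
    (the distance matrix and the weighted points) hit memory."""
    n_samples = X.shape[0]
    n_clusters = centroids.shape[0]
    if sample_weight is None:
        sample_weight = np.ones(n_samples)
    # Step 1: pairwise squared distances, then min lookup.
    dists = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)
    labels = dists.argmin(axis=1)
    # Step 2: weight multiplication, then scatter-add centroid update.
    sums = np.zeros_like(centroids)
    np.add.at(sums, labels, X * sample_weight[:, None])
    weights = np.zeros(n_clusters)
    np.add.at(weights, labels, sample_weight)
    new_centroids = sums / np.maximum(weights, 1e-12)[:, None]
    return new_centroids, labels
```

A fused GPU kernel would compute each sample's nearest centroid and accumulate its contribution in one pass, never writing `dists` or `X * sample_weight[:, None]` to global memory.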
It would be interesting to see if
I think it would be interesting to write a special-cased version of the PyTorch implementation of k-means without sample weights, because in practice very few people use non-None sample weights. Then it would be nice to have a summary table of run results for the various implementations and hardware on the same data. Maybe let's try a dataset just small enough that cuML does not crash: rapidsai/cuml#5470.
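To make the suggestion concrete, here is a sketch (hypothetical helper, NumPy for illustration) of what the weight-free centroid update could look like: with all weights equal to 1, the scatter-add reduces to plain `bincount` calls and the weight multiplication disappears entirely.

```python
import numpy as np

def update_centroids_unweighted(X, labels, n_clusters):
    """Centroid update when every sample weight is 1: per-cluster counts
    and per-feature sums via bincount, no weight vector needed."""
    counts = np.bincount(labels, minlength=n_clusters)
    sums = np.stack(
        [np.bincount(labels, weights=X[:, j], minlength=n_clusters)
         for j in range(X.shape[1])],
        axis=1,
    )
    # Guard against empty clusters (scikit-learn relocates them instead).
    return sums / np.maximum(counts, 1)[:, None]
```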
A Triton implementation of https://github.com/soda-inria/sklearn-numba-dpex/tree/main/sklearn_numba_dpex/kmeans would also be interesting to compare against.
This is an early issue to publicly discuss the possibility (or not) of using the Array API (see #22352) for k-means, to make it run on GPUs, using PyTorch in particular.
@fcharras has already started to run some promising experiments using the raw PyTorch API. Maybe you could link to a gist with your code?
Unfortunately, the current state of the Array API is likely too limiting because, AFAIK, it does not yet expose equivalents of `torch.cdist`, `torch.expand`, and `torch.scatter_add_`.

The purpose of this issue is to precisely identify what is blocking us with the current state of the Array API and to discuss potential solutions: