Array API support for k-means #26585


Open
ogrisel opened this issue Jun 15, 2023 · 4 comments

@ogrisel
Member

ogrisel commented Jun 15, 2023

This is an early issue to publicly discuss whether we can use the Array API (see #22352) for k-means, in particular to make it run on GPUs using PyTorch.

@fcharras has already started to run some promising experiments using the raw PyTorch API. Maybe you could link to a gist with your code?

Unfortunately, the current state of the Array API is likely too limiting: AFAIK it does not yet expose equivalents of torch.cdist, torch.Tensor.expand and torch.Tensor.scatter_add_.
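
To make the role of these three operations concrete, here is a minimal NumPy sketch of one Lloyd iteration. It is not the code from the experiments mentioned above: the function name `lloyd_step` and the toy data are made up for illustration, and the comments only indicate where each PyTorch operation would plausibly appear in a torch version.

```python
import numpy as np


def lloyd_step(X, centers):
    """One Lloyd iteration on dense data (illustrative sketch only)."""
    # Pairwise squared distances, shape (n_samples, n_clusters).
    # torch.cdist fills this role in PyTorch (it returns non-squared
    # distances); the broadcasting below is what expand would provide.
    diff = X[:, None, :] - centers[None, :, :]
    sq_dist = (diff ** 2).sum(axis=2)

    # Assign each sample to its closest center.
    labels = sq_dist.argmin(axis=1)

    # Accumulate per-cluster sums and counts. np.add.at is an unbuffered
    # scatter-add, the NumPy counterpart of scatter_add_.
    sums = np.zeros_like(centers)
    np.add.at(sums, labels, X)
    counts = np.bincount(labels, minlength=centers.shape[0])

    # Keep the previous center for empty clusters to avoid dividing by zero.
    new_centers = centers.copy()
    non_empty = counts > 0
    new_centers[non_empty] = sums[non_empty] / counts[non_empty, None]
    return labels, new_centers


X = np.array([[0.0, 0.0], [0.0, 1.0], [10.0, 10.0], [10.0, 11.0]])
centers = np.array([[0.0, 0.0], [10.0, 10.0]])
labels, new_centers = lloyd_step(X, centers)
```

The scatter-add step is the part with no obvious portable counterpart: `np.add.at` is NumPy-specific, which is the kind of gap this issue is about.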

The purpose of this issue is to precisely identify what is blocking us with the current state of Array API and discuss potential solutions:

  • use this use case to report our needs to the Array API standardization committee, so that the spec can evolve and benefit everybody;
  • alternatively, explore the use of a multi-dispatch system such as uarray, which is being adopted in SciPy, to make it possible to maintain a PyTorch-specific optimized code path alongside a slower yet generic Array API code path and a NumPy-optimized code path that would rely on our current Cython code;
  • decide that the estimator-level engine API proposed in [DRAFT] Engine plugin API and engine entry point for Lloyd's KMeans #25535 is the only sane way to make this estimator run on GPUs (which I now personally doubt).
@fcharras
Contributor

fcharras commented Jun 20, 2023

Here is a gist with what I think is a rather comprehensive implementation of Lloyd's algorithm in PyTorch:

https://gist.github.com/fcharras/ce1f1df7d15675268827e1fb9b65265b

Scroll down to the bottom of the file for a quick tester.

This function is almost a drop-in replacement for the lloyd private function in the scikit-learn implementation.

For k-means, running this on GPU does offer a speedup over current scikit-learn performance (in the 2x to 5x range, I think), but compared with other GPU implementations of k-means it is a bit underwhelming. Both the "pairwise distance + min lookup" and "weight multiplication + centroid update" steps should be fused rather than called separately with intermediate results materialized in memory, because memory reads and writes are the bottleneck. Fusing them would offer another 2x to 5x speedup.
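
As a rough illustration of why materializing intermediates matters, the NumPy sketch below computes the closest center per sample in chunks, so that only a (chunk_size, n_clusters) block of distances exists at any time. This is not real kernel fusion and is not the gist's code: the name `min_over_pairwise_distance_chunked` and the `chunk_size` parameter are hypothetical, and a fused GPU kernel would instead keep the running minimum on-chip without writing any distance block back to memory.

```python
import numpy as np


def min_over_pairwise_distance_chunked(X, centers, chunk_size=1024):
    """Closest center per sample without storing the full distance matrix."""
    n_samples = X.shape[0]
    labels = np.empty(n_samples, dtype=np.intp)
    min_sq_dist = np.empty(n_samples, dtype=X.dtype)
    centers_sq_norm = (centers ** 2).sum(axis=1)

    for start in range(0, n_samples, chunk_size):
        chunk = X[start:start + chunk_size]
        # ||x - c||^2 = ||x||^2 - 2 x.c + ||c||^2, for one chunk of samples.
        sq_dist = (
            (chunk ** 2).sum(axis=1)[:, None]
            - 2.0 * chunk @ centers.T
            + centers_sq_norm[None, :]
        )
        # Reduce immediately so the distance block can be discarded.
        idx = sq_dist.argmin(axis=1)
        labels[start:start + chunk_size] = idx
        min_sq_dist[start:start + chunk_size] = sq_dist[np.arange(len(idx)), idx]
    return labels, min_sq_dist


rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 8))
centers = rng.normal(size=(5, 8))
labels, min_sq_dist = min_over_pairwise_distance_chunked(X, centers, chunk_size=128)
```

Peak intermediate memory here scales with `chunk_size * n_clusters` instead of `n_samples * n_clusters`, which is the same trade a fused kernel makes at a much finer granularity.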

I'll post a similar gist for KNN. KNN is a different story: to my knowledge, the best brute-force implementations require materializing the pairwise distance matrix in memory and cannot get past the IO bottleneck, so the achievable speedup is more limited, and the PyTorch implementation should be reasonably close to the best you can get.

@ogrisel
Member Author

ogrisel commented Jul 13, 2023

It would be interesting to see if torch.compile could help for _min_over_pairwise_distance.

@ogrisel
Member Author

ogrisel commented Jul 13, 2023

I think it would be interesting to write a special-cased version of the PyTorch implementation for k-means without sample weights because in practice very few people have non-None sample weights.
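
A minimal NumPy sketch of what such a special case could look like for the centroid update follows. The function name `update_centers` and its signature are hypothetical, not taken from the gist or from scikit-learn; the point is only that with `sample_weight=None` the per-sample multiplication disappears and the per-cluster weight reduces to a count.

```python
import numpy as np


def update_centers(X, labels, n_clusters, sample_weight=None):
    """Centroid update with a fast path for sample_weight=None (sketch)."""
    sums = np.zeros((n_clusters, X.shape[1]), dtype=X.dtype)
    if sample_weight is None:
        # Unweighted fast path: no per-sample multiplication, and the
        # per-cluster weight is simply the number of assigned samples.
        np.add.at(sums, labels, X)
        weights = np.bincount(labels, minlength=n_clusters).astype(X.dtype)
    else:
        np.add.at(sums, labels, X * sample_weight[:, None])
        weights = np.bincount(labels, weights=sample_weight, minlength=n_clusters)
    # Empty clusters are left at zero here; handling them is out of scope.
    non_empty = weights > 0
    sums[non_empty] /= weights[non_empty, None]
    return sums


X = np.array([[0.0, 0.0], [2.0, 0.0], [10.0, 10.0]])
labels = np.array([0, 0, 1])
unweighted = update_centers(X, labels, 2)
weighted = update_centers(X, labels, 2, sample_weight=np.ones(3))
```

Both calls give the same centers, but the unweighted path skips the temporary `X * sample_weight[:, None]` array, which saves one full read and write of the data.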

Then it would be nice to have a summary table of run results for various implementations and hardware on the same data.

Maybe let's try a dataset just small enough that cuML does not crash (see rapidsai/cuml#5470).

@ogrisel
Member Author

ogrisel commented Jul 13, 2023

A Triton implementation of https://github.com/soda-inria/sklearn-numba-dpex/tree/main/sklearn_numba_dpex/kmeans would also be interesting to compare against.

@glemaitre glemaitre moved this to Todo in Array API May 17, 2024