adding adaptive learning rate for minibatch k-means #30051
Conversation
The diff is indeed relatively small; however, the paper is quite recent and the improvements are marginal. So I'll let @ogrisel, @lorentzenchr, @jeremiedbb, and @GaelVaroquaux weigh in here.
I have similar feelings. Unfortunately, arxiv.org has been unresponsive for me since yesterday, so I cannot check the benchmark results from the paper. @BenJourdan, could you please add results for full-batch k-means to your plots? I am wondering whether this allows MB-k-means to reach the same scores as full-batch k-means on those problems.
Force-pushed from 35c5b19 to 158897a (compare)
Thanks for the update. From those experiments, it appears that the new learning-rate scheme can empirically help MBKMeans close the (smallish) gap with full-batch KMeans in terms of clustering quality, while keeping favorable runtimes on datasets with many data points (e.g. MNIST size or larger). But since the method was only recently published, this PR does not technically meet our inclusion criteria, although we could be less strict in cases where it is an incremental improvement of an existing method implemented in scikit-learn. I will mention this PR at our next monthly meeting.
What was the verdict @ogrisel?
There were no clear objections to including this, and I think a few of us are in favor of including it.
@ogrisel @adrinjalali what happens next? Should I start updating the branch?
@BenJourdan seems like it.
Force-pushed from 65e49cb to 320a4b4 (compare)
Should I keep updating the branch until (if, lol) someone gets assigned? Not sure what the convention is.
Hi @ogrisel, @adrinjalali, just checking in: since there were no objections during the meeting and there was some support for inclusion, would it make sense to remove the "Needs Decision" label and move toward review/approval? Please let me know what you'd recommend as the next step. Thanks again for your time and feedback so far!
I've removed the "needs decision" label here. I think we can move forward with this.
Great! Let us know if we should do anything on our end.
@antoinebaker would you mind having a look at this PR for a review?
Thanks for the PR @BenJourdan! Here is a first round of review.
I feel the code could be simplified by first defining a learning rate:

```python
if adaptive_lr:
    lr = sqrt(wsum / wsum_batch)
else:
    lr = wsum / (weight_sums[cluster_idx] + wsum)
```

and then doing the common updates:

```python
for feature_idx in range(n_features):
    centers_new[cluster_idx, feature_idx] = (1 - lr) * centers_old[cluster_idx, feature_idx]

for k in range(n_indices):
    sample_idx = indices[k]
    for feature_idx in range(n_features):
        weight_idx = sample_weight[sample_idx] / wsum
        centers_new[cluster_idx, feature_idx] += lr * weight_idx * X[sample_idx, feature_idx]
```
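For readers skimming the thread, here is a minimal NumPy sketch of what this per-cluster update amounts to. The function name and signature are illustrative only, not the actual Cython kernel in `_k_means_minibatch.pyx`.

```python
import numpy as np

def sketch_center_update(center_old, X_assigned, w_assigned,
                         weight_sum_old, wsum_batch, adaptive_lr):
    """Illustrative per-cluster update (hypothetical helper, not the PR code).

    X_assigned / w_assigned: samples (and weights) of the minibatch assigned
    to this cluster; weight_sum_old: weight accumulated by the cluster over
    past batches; wsum_batch: total weight of the whole minibatch.
    """
    wsum = w_assigned.sum()  # weight assigned to this cluster within the batch
    batch_mean = (w_assigned[:, None] * X_assigned).sum(axis=0) / wsum
    if adaptive_lr:
        # proposed scheme: step size from the cluster's share of the batch weight
        lr = np.sqrt(wsum / wsum_batch)
    else:
        # current scheme: inverse of the accumulated per-cluster weight
        lr = wsum / (weight_sum_old + wsum)
    # convex combination of the old center and the weighted batch mean
    return (1.0 - lr) * center_old + lr * batch_mean
```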
Thanks for the feedback. I'll add most of your suggestions as they are.
Introducing the learning rate to avoid duplicating code is a good idea. However, it's a bit messy, since we need an optimization that avoids explicitly computing the means of each batch for a given center. I'll have a go at redrafting those parts.
Thanks!
Force-pushed from 0bf05ba to 0094c8e (compare)
Force-pushed from 0094c8e to 13fe041 (compare)
@BenJourdan please (at almost all costs) avoid force pushing here. It makes review harder.
Sure, sorry about that. I'm still new to using Git; I see my mistake was rebasing after getting the review. Apologies, @antoinebaker.
Could you elaborate on why the suggestion in #30051 (comment) is less efficient? I'm not sure I get what you mean by "avoids explicitly computing the means of each batch for a given center". I think it's done in both cases by accumulating the updates in place. In the above suggestion:

```python
for k in range(n_indices):
    sample_idx = indices[k]
    for feature_idx in range(n_features):
        weight_idx = sample_weight[sample_idx] / wsum
        centers_new[cluster_idx, feature_idx] += lr * weight_idx * X[sample_idx, feature_idx]
```

while in main the corresponding loop is in sklearn/cluster/_k_means_minibatch.pyx, lines 92 to 96 (a6c2db0).
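For what it's worth, the two formulations being compared are algebraically equivalent for the default learning rate. Below is a small self-contained NumPy check; it is a paraphrase of the discussion, not the verbatim Cython in main, and the variable names are only meant to mirror the snippets above.

```python
import numpy as np

rng = np.random.default_rng(0)
n_assigned, n_features = 8, 3
X_assigned = rng.normal(size=(n_assigned, n_features))  # batch samples assigned to one cluster
sample_weight = rng.uniform(0.5, 2.0, size=n_assigned)
center_old = rng.normal(size=n_features)
weight_sum_old = 10.0                                    # weight accumulated over past batches
wsum = sample_weight.sum()                               # weight of this cluster within the batch

# "Scale, accumulate, rescale" (roughly the formulation in main): undo the old
# normalization, add the weighted batch contributions, then renormalize.
center_a = center_old * weight_sum_old
center_a = center_a + (sample_weight[:, None] * X_assigned).sum(axis=0)
center_a = center_a / (weight_sum_old + wsum)

# "Learning rate" formulation (the suggestion above): a convex combination of
# the old center and the weighted batch mean.
lr = wsum / (weight_sum_old + wsum)
batch_mean = (sample_weight[:, None] * X_assigned).sum(axis=0) / wsum
center_b = (1.0 - lr) * center_old + lr * batch_mean

print(np.allclose(center_a, center_b))  # True: the two formulations agree
```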
Co-authored-by: antoinebaker <[email protected]>
Co-authored-by: antoinebaker <[email protected]>
You were right to question this! I benchmarked your suggestion against what we had, and it was the same, if not marginally faster. Originally we tried something similar but kept seeing a speed penalty; I think this was probably because we tried something closer to the explicit scaling and rescaling that main was doing beforehand (sklearn/cluster/_k_means_minibatch.pyx, lines 88 to 108 at a6c2db0).
I've switched both dense and sparse versions over to using your suggestion and renamed
Hello @BenJourdan, could you mark as resolved the comments/suggestions up to May 1st, which are obsolete because of the force push? They are currently only marked as "Outdated", and marking them as resolved will make the PR discussion easier to follow, I think. Thanks!
@antoinebaker I think I've marked everything you mentioned as resolved now. I still see the "Outdated" tags, but there's a "Show resolved" tag next to them.
There seem to be Cython linting issues.
@antoinebaker I've fixed the linting issue and shortened the comment.
Reference Issues/PRs
None
What does this implement/fix? Explain your changes.
This PR implements a recently proposed learning rate for minibatch k-means that can outperform the default learning rate. We expose it through an adaptive_lr flag that defaults to False. Details can be found in this paper, which appeared at ICLR 2023; extensive experiments can be found in this manuscript (ignore the kernel k-means results). We also added a benchmark that produces a plot showing the adaptive learning rate is the same as or better than the default on dense datasets.
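For illustration, a minimal usage sketch. This is hypothetical: the adaptive_lr parameter only exists on this PR's branch and is not part of a released scikit-learn.

```python
# Hypothetical usage: `adaptive_lr` is the flag proposed in this PR, not a
# parameter available in a released scikit-learn version.
from sklearn.cluster import MiniBatchKMeans
from sklearn.datasets import make_blobs

X, _ = make_blobs(n_samples=100_000, centers=50, random_state=0)

mbk_default = MiniBatchKMeans(n_clusters=50, random_state=0).fit(X)
mbk_adaptive = MiniBatchKMeans(n_clusters=50, adaptive_lr=True, random_state=0).fit(X)

# Compare final inertia (lower is better); per the benchmark mentioned above,
# the adaptive scheme often matches or improves on the default on dense data.
print(mbk_default.inertia_, mbk_adaptive.inertia_)
```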
Any other comments?
This is a reasonably small code change: we add a flag to the MiniBatchKMeans constructor and to the _k_means_minibatch.pyx Cython file. The learning-rate implementation itself is straightforward. In the benchmarks, the adaptive learning rate appears to take a few more iterations to converge, often resulting in better solutions; when we removed early stopping, we observed that the running time is about the same.
This should be a cleaner version of #30045 (I made a mess since I'm still pretty new to git).