
adding adaptive learning rate for minibatch k-means #30051


Open
BenJourdan wants to merge 5 commits into main from feature_mbkm_adaptive_lr

Conversation

BenJourdan

Reference Issues/PRs

None

What does this implement/fix? Explain your changes.

This pull request implements a recently proposed adaptive learning rate for mini-batch k-means, which can outperform the default learning rate. It is exposed through a new flag, adaptive_lr, which defaults to False.

Details can be found in this paper, which appeared at ICLR 2023. Extensive experiments are reported in this manuscript (ignore the kernel k-means results there). We also added a benchmark that produces the following plot, which shows the adaptive learning rate is the same as or better than the default on dense datasets.

[benchmark plot: adaptive vs. default learning rate on dense datasets]

Any other comments?

This is a reasonably small code change. We add a flag to the MiniBatchKMeans constructor and to the _k_means_minibatch.pyx Cython file. The learning rate implementation is straightforward. In the benchmarks, the adaptive learning rate appears to take a few more iterations to converge, often resulting in better solutions. When we removed early stopping, we observed that the running time is about the same.
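
For concreteness, a minimal usage sketch (assuming the flag lands with the name and default described above; the dataset and cluster count are placeholders):

from sklearn.cluster import MiniBatchKMeans
from sklearn.datasets import make_blobs

X, _ = make_blobs(n_samples=10_000, centers=50, random_state=0)

# Current behaviour: count-based learning rate (adaptive_lr defaults to False).
default_mbkm = MiniBatchKMeans(n_clusters=50, random_state=0).fit(X)

# Proposed adaptive learning rate introduced by this PR.
adaptive_mbkm = MiniBatchKMeans(n_clusters=50, adaptive_lr=True, random_state=0).fit(X)

print(default_mbkm.inertia_, adaptive_mbkm.inertia_)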

This should be a cleaner version of #30045 (I made a mess since I'm still pretty new to git).


github-actions bot commented Oct 12, 2024

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: 4c795d2. Link to the linter CI: here

@adrinjalali (Member)

The diff is indeed relatively small; however, the paper is quite recent and the improvements are marginal. So I'll let @ogrisel, @lorentzenchr, @jeremiedbb, and @GaelVaroquaux weigh in here.

@adrinjalali added the Needs Decision label on Oct 14, 2024
@ogrisel (Member)

ogrisel commented Oct 14, 2024

I have similar feelings. Unfortunately, arxiv.org has been unresponsive for me since yesterday, so I cannot check the benchmark results from the paper.

@BenJourdan could you please add results for full-batch k-means to your plots? I am wondering if this can allow MB-k-means to reach the same scores as full-batch k-means on those problems.

@BenJourdan (Author)

BenJourdan commented Oct 14, 2024

Here are the results with full-batch k-means added:
results_default_params

Playing with the early stopping tolerance tol also affects runtime/performance. Comparing tol values between the mini-batch and full-batch methods isn't exactly apples to apples, but I imagine it's the first knob users reach for if they are worried about runtime. max_no_improvement will also have an effect.

This was with tol=1e-1 for all the algorithms:
results_tol_1e-1

This was with tol=1e-2:
results_tol_1e-2

tol=1e-3:
results_tol_1e-3

tol=1e-4:
results_tol_1e-4

I can add more experiments varying max_no_improvement if that helps.
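
For reference, the rough shape of these runs (a sketch only, not the actual benchmark script; the dataset, n_clusters, and other parameters are placeholders):

from sklearn.cluster import KMeans, MiniBatchKMeans
from sklearn.datasets import make_blobs

X, _ = make_blobs(n_samples=100_000, centers=100, random_state=0)

for tol in (1e-1, 1e-2, 1e-3, 1e-4):
    full = KMeans(n_clusters=100, tol=tol, random_state=0).fit(X)
    default = MiniBatchKMeans(n_clusters=100, tol=tol, random_state=0).fit(X)
    # adaptive_lr is the flag proposed in this PR.
    adaptive = MiniBatchKMeans(n_clusters=100, tol=tol, adaptive_lr=True, random_state=0).fit(X)
    print(tol, full.inertia_, default.inertia_, adaptive.inertia_)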

@BenJourdan force-pushed the feature_mbkm_adaptive_lr branch from 35c5b19 to 158897a on October 14, 2024 15:47
@ogrisel (Member)

ogrisel commented Oct 17, 2024

Thanks for the update. So from those experiments, it appears that the new lr scheme can empirically help MBKMeans close the (smallish) gap with full-batch KMeans in terms of clustering quality while keeping favorable runtimes for datasets with many data points (e.g. MNIST size or larger).

But since the method was recently published, this PR does not technically meet our inclusion criteria, although we could be less strict in cases where this is an incremental improvement of an existing method implemented in scikit-learn.

I will mention this PR at our next monthly meeting.

@BenJourdan (Author)

What was the verdict @ogrisel?

@adrinjalali (Member)

There were no clear objections to including this, and I think a few of us are in favor of it.

@BenJourdan (Author)

@ogrisel @adrinjalali what happens next? Should I start updating the branch?

@adrinjalali (Member)

@BenJourdan seems like it.

@BenJourdan force-pushed the feature_mbkm_adaptive_lr branch from 65e49cb to 320a4b4 on November 6, 2024 16:06
@BenJourdan (Author)

Should I keep updating the branch until (if lol) someone gets assigned? Not sure what the convention is.

@gregoryschwartzman

Hi @ogrisel, @adrinjalali,

Just checking in: since there were no objections during the meeting and there was some support for inclusion, would it make sense to remove the Needs Decision label and move toward review/approval?

Please let me know what you'd recommend as the next step. Thanks again for your time and feedback so far!

@adrinjalali removed the Needs Decision label on Apr 23, 2025
@adrinjalali (Member)

I've removed the "Needs Decision" label here. I think we can move forward with this.

@gregoryschwartzman

Great! Let us know if we should do anything on our end.

@adrinjalali (Member)

@antoinebaker would you mind having a look at this PR for a review?

@antoinebaker (Contributor) left a comment

Thanks for the PR @BenJourdan! Here is a first round of review.

@antoinebaker (Contributor)

I feel the code could be simplified by first defining a learning rate:

if adaptive_lr:
    # proposed adaptive learning rate
    lr = sqrt(wsum / wsum_batch)
else:
    # count-based learning rate (equivalent to the current update on main)
    lr = wsum / (weight_sums[cluster_idx] + wsum)

and then do the common updates

for feature_idx in range(n_features):
    centers_new[cluster_idx, feature_idx] = (1 - lr) * centers_old[cluster_idx, feature_idx]
for k in range(n_indices):
    sample_idx = indices[k]
    for feature_idx in range(n_features):
        weight_idx = sample_weight[sample_idx] / wsum
        centers_new[cluster_idx, feature_idx] += lr * weight_idx * X[sample_idx, feature_idx]
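
Written out, this update is a convex combination of the old center and the weighted mean of the batch points assigned to that cluster, with the two branches above only changing the step size $\eta$ (i.e. lr):

$$c_{\text{new}} = (1 - \eta)\, c_{\text{old}} + \eta \, \frac{\sum_{i \in B} w_i x_i}{\sum_{i \in B} w_i}$$

where $B$ is the set of batch samples assigned to the cluster and $\sum_{i \in B} w_i$ is wsum.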

@BenJourdan (Author) left a comment

Thanks for the feedback. I'll add most of your suggestions as they are.

Introducing the learning rate to avoid duplicating code is a good idea. However, it's a bit messy since we need to do an optimization that avoids explicitly computing the means of each batch for a given center. I'll have a go at redrafting those parts.

Thanks!

@BenJourdan force-pushed the feature_mbkm_adaptive_lr branch 4 times, most recently from 0bf05ba to 0094c8e on May 1, 2025 15:41
@BenJourdan force-pushed the feature_mbkm_adaptive_lr branch from 0094c8e to 13fe041 on May 1, 2025 15:46
@adrinjalali (Member)

@BenJourdan please (at almost all costs) avoid force-pushing here; it makes review harder.

@BenJourdan (Author)

BenJourdan commented May 6, 2025

Sure, sorry about that. I'm still new to using Git.

I see my mistake was rebasing after getting the review. Sorry about that @antoinebaker.

@BenJourdan requested a review from antoinebaker on May 7, 2025 10:30
@antoinebaker (Contributor)

antoinebaker commented May 28, 2025

Introducing the learning rate to avoid duplicating code is a good idea. However, it's a bit messy since we need to do an optimization that avoids explicitly computing the means of each batch for a given center. I'll have a go at redrafting those parts.

Could you elaborate on why the suggestion in #30051 (comment) is less efficient? I'm not sure I understand what you mean by "avoids explicitly computing the means of each batch for a given center".

I think it's done in both cases by accumulating the updates in place. In the above suggestion:

for k in range(n_indices):
    sample_idx = indices[k]
    for feature_idx in range(n_features):
        weight_idx = sample_weight[sample_idx] / wsum
        centers_new[cluster_idx, feature_idx] += lr * weight_idx * X[sample_idx, feature_idx]

while in main:

# Update cluster with new point members
for k in range(n_indices):
    sample_idx = indices[k]
    for feature_idx in range(n_features):
        centers_new[cluster_idx, feature_idx] += X[sample_idx, feature_idx] * sample_weight[sample_idx]

@BenJourdan (Author)

BenJourdan commented May 29, 2025

Introducing the learning rate to avoid duplicating code is a good idea. However, it's a bit messy since we need to do an optimization that avoids explicitly computing the means of each batch for a given center. I'll have a go at redrafting those parts.

Could you elaborate on why the suggestion in #30051 (comment) is less efficient? I'm not sure I understand what you mean by "avoids explicitly computing the means of each batch for a given center".

I think it's done in both cases by accumulating the updates in place. In the above suggestion:

for k in range(n_indices):
    sample_idx = indices[k]
    for feature_idx in range(n_features):
        weight_idx = sample_weight[sample_idx] / wsum
        centers_new[cluster_idx, feature_idx] += lr * weight_idx * X[sample_idx, feature_idx]

while in main:

# Update cluster with new point members
for k in range(n_indices):
    sample_idx = indices[k]
    for feature_idx in range(n_features):
        centers_new[cluster_idx, feature_idx] += X[sample_idx, feature_idx] * sample_weight[sample_idx]

You were right to question this! I benchmarked your suggestion against what we had, and it was the same if not marginally faster. Originally, we tried something similar but kept seeing a speed penalty. I think that was probably because our attempt stayed closer to the explicit scaling and rescaling that main already does:

if wsum > 0:
    # Undo the previous count-based scaling for this cluster center
    for feature_idx in range(n_features):
        centers_new[cluster_idx, feature_idx] = centers_old[cluster_idx, feature_idx] * weight_sums[cluster_idx]
    # Update cluster with new point members
    for k in range(n_indices):
        sample_idx = indices[k]
        for feature_idx in range(n_features):
            centers_new[cluster_idx, feature_idx] += X[sample_idx, feature_idx] * sample_weight[sample_idx]
    # Update the count statistics for this center
    weight_sums[cluster_idx] += wsum
    # Rescale to compute mean of all points (old and new)
    alpha = 1 / weight_sums[cluster_idx]
    for feature_idx in range(n_features):
        centers_new[cluster_idx, feature_idx] *= alpha
else:
    # No sample was assigned to this cluster in this batch of data
    for feature_idx in range(n_features):
        centers_new[cluster_idx, feature_idx] = centers_old[cluster_idx, feature_idx]
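
(For reference, the scale/rescale above is algebraically the same weighted running mean as the non-adaptive branch of your suggestion: with $W$ the value of weight_sums[cluster_idx] before the update,

$$c_{\text{new}} = \frac{W\, c_{\text{old}} + \sum_{i \in B} w_i x_i}{W + \text{wsum}} = \Big(1 - \tfrac{\text{wsum}}{W + \text{wsum}}\Big)\, c_{\text{old}} + \tfrac{\text{wsum}}{W + \text{wsum}} \cdot \frac{\sum_{i \in B} w_i x_i}{\text{wsum}},$$

which is exactly the lr = wsum / (weight_sums[cluster_idx] + wsum) case.)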

I've switched both dense and sparse versions over to using your suggestion and renamed lr to alpha.

@antoinebaker (Contributor)

Hello @BenJourdan, could you mark as resolved the comments/suggestions up to May 1st, which are obsolete because of the force push? They are currently only marked as "Outdated", and marking them as resolved will make the PR discussion easier to follow, I think. Thanks!

@BenJourdan (Author)

Hello @BenJourdan, could you mark as resolved the comments/suggestions up to May 1st, which are obsolete because of the force push? They are currently only marked as "Outdated", and marking them as resolved will make the PR discussion easier to follow, I think. Thanks!

@antoinebaker I think I've marked everything you mentioned as resolved now. I still see the "Outdated" tags, but there's a "Show resolved" toggle next to them.

@antoinebaker (Contributor)

There seem to be Cython linting issues:

/home/circleci/project/sklearn/cluster/_k_means_minibatch.pyx:136:29: E222 multiple spaces after operator
Problems detected by cython-lint, please fix them
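
In case it's useful: E222 flags two or more spaces after a binary operator. A hypothetical illustration of the kind of line it complains about (not the actual line 136 of the file):

alpha =  sqrt(wsum / wsum_batch)  # E222: multiple spaces after operator
alpha = sqrt(wsum / wsum_batch)   # fixed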

@BenJourdan (Author)

@antoinebaker I've fixed the linting issue and shortened the comment.
