
FEA add temperature scaling to CalibratedClassifierCV #31068

Open · wants to merge 64 commits into main

Conversation

@virchan virchan (Member) commented Mar 25, 2025

Reference Issues/PRs

Closes #28574

What does this implement/fix? Explain your changes.

This PR adds temperature scaling to scikit-learn's CalibratedClassifierCV:

Temperature scaling can be enabled by setting method = "temperature" in CalibratedClassifierCV:

from sklearn.calibration import CalibratedClassifierCV
from sklearn.datasets import make_classification
from sklearn.frozen import FrozenEstimator
from sklearn.model_selection import train_test_split
from sklearn.svm import LinearSVC

X, y = make_classification(random_state=42)

# Hold out a calibration set that the base classifier never sees.
X_train, X_calib, y_train, y_calib = train_test_split(X, y, random_state=42)

# Fit the base classifier on the training split.
clf = LinearSVC(random_state=42)
clf.fit(X_train, y_train)

# FrozenEstimator keeps the prefit classifier as-is, so only the temperature
# parameter is learnt on the held-out calibration split.
cal_clf = CalibratedClassifierCV(FrozenEstimator(clf), method="temperature")
cal_clf.fit(X_calib, y_calib)

This method supports both binary and multi-class classification.
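
For the multi-class case, a minimal sketch along the same lines (the three-class data setup below is illustrative and not taken from the PR):

from sklearn.calibration import CalibratedClassifierCV
from sklearn.datasets import make_classification
from sklearn.frozen import FrozenEstimator
from sklearn.model_selection import train_test_split
from sklearn.svm import LinearSVC

# Three classes; n_informative is raised so make_classification can place
# enough class clusters.
X, y = make_classification(n_classes=3, n_informative=4, random_state=42)
X_train, X_calib, y_train, y_calib = train_test_split(X, y, random_state=42)

clf = LinearSVC(random_state=42).fit(X_train, y_train)
cal_clf = CalibratedClassifierCV(FrozenEstimator(clf), method="temperature")
cal_clf.fit(X_calib, y_calib)

# Calibrated probabilities have one column per class and sum to one.
proba = cal_clf.predict_proba(X_calib)
assert proba.shape == (len(X_calib), 3)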

Any other comments?

Cc @adrinjalali, @lorentzenchr in advance.

github-actions bot commented Mar 25, 2025

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: b0584d6. Link to the linter CI: here

@virchan virchan (Member, Author) left a comment

A follow-up to my comment on the Array API: I don't think we can support the Array API here, as scipy.optimize.minimize does not appear to support it.

If I missed anything, please let me know—I'd be happy to investigate further.

@virchan virchan marked this pull request as ready for review March 25, 2025 10:55
@ogrisel ogrisel (Member) left a comment

Thanks for the PR. Here is a first pass of feedback:

virchan added 4 commits March 27, 2025 18:14
…fier`.

Updated constructor of `_TemperatureScaling` class.
Updated `test_temperature_scaling` in `test_calibration.py`.
Added `__sklearn_tags__` to `_TemperatureScaling` class.
@virchan virchan (Member, Author) left a comment

I'm still working on addressing the feedback, but I also wanted to share some findings related to it and provide an update.

@lorentzenchr lorentzenchr (Member) left a comment

A few computational things seem off.

virchan added 2 commits April 25, 2025 22:16
Update `minimize` in `_temperature_scaling` to `minimize_scalar`.
Update `test_calibration.py` to check that the optimised inverse temperature is between 0.1 and 10.
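
For context, the switch amounts to a bounded one-dimensional search over the inverse temperature. A rough, self-contained sketch (the bounds, names, and the multinomial negative log-likelihood written with scipy.special.log_softmax are illustrative assumptions, not the PR's exact code):

import numpy as np
from scipy.optimize import minimize_scalar
from scipy.special import log_softmax

def fit_inverse_temperature(logits, y, bounds=(0.1, 10.0)):
    """Fit a scalar inverse temperature by minimising the multinomial NLL."""

    def nll(beta):
        # Mean negative log-likelihood of softmax(beta * logits).
        log_proba = log_softmax(beta * logits, axis=1)
        return -np.mean(log_proba[np.arange(logits.shape[0]), y])

    result = minimize_scalar(nll, bounds=bounds, method="bounded")
    return result.x
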
@virchan virchan (Member, Author) left a comment

There are some CI failures—I'll fix those shortly.

I'm also considering adding a verbose parameter to CalibratedClassifierCV to optionally display convergence info when optimising the inverse temperature beta.

@virchan virchan (Member, Author) left a comment

At this point, CI is complaining about missing test coverage for tags and for the errors raised in _fit_calibrator and _CalibratedClassifier.

I'm not quite sure how to handle the tags part, but I think the other two cases are fine, since the tests should cover warnings raised by the user-facing API, rather than by private functions.

I'll work on the convergence warning for minimize_scalar later.

@virchan virchan (Member, Author) left a comment

Updates on Raising Warnings When scipy.optimize.minimize_scalar Doesn't Converge.

minimize_scalar(method = "bounded") raises the following warnings when optimisation is unsuccessful:

  1. "Maximum number of function calls reached". Controled by the "maxiter" in the options dictionary.
  2. "Nan result encountered".

(Reference: here)

For (1), the default value is maxiter=500. In my testing (over about a week of trial and error), setting maxiter=40 or lower may trigger non-convergence. However, reducing the number of iterations negatively affects the performance of the calibrator in otherwise converging cases. So I don't think it's worth lowering maxiter just to surface a warning in rare failure cases.

For (2), the docstring of HalfMultinomialLoss says:

        loss_i = log(sum(exp(raw_pred_{i, k}), k=0..n_classes-1))
                - sum(y_true_{i, k} * raw_pred_{i, k}, k=0..n_classes-1)

As far as I can tell, the only two ways loss_i could become NaN are:

  • if raw_pred contains np.inf, or is entirely -np.inf.
  • if raw_pred contains NaN.

But in either case, scikit-learn's check_array would catch this and raise an error before minimize_scalar even runs.
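
A quick illustration of that guard, relying on check_array's default rejection of non-finite input:

import numpy as np
from sklearn.utils.validation import check_array

# Raises ValueError before any optimisation could run.
check_array(np.array([[0.5, np.nan]]))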

There's also the edge case where the minimiser lands exactly on a boundary ($\pm 10$), but due to the xtol tolerance in SciPy's optimizer, it ends up returning something like $\pm 9.9999999…$ and still counts as a successful termination.

So overall, I'm leaning towards our current implementation being fine in terms of convergence handling. That said, I might be missing some edge cases or subtle points.

Just to be safe, I've updated the options dictionary in _temperature_scaling to show a warning when convergence fails, and I've also tightened the absolute tolerance xatol for convergence.
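
Roughly, the kind of handling involved (the option values, warning text, and stand-in objective below are assumptions for illustration, not the PR's exact code):

import warnings

from scipy.optimize import minimize_scalar

def objective(beta):
    # Stand-in for the temperature-scaling negative log-likelihood.
    return (beta - 2.0) ** 2

result = minimize_scalar(
    objective,
    bounds=(0.1, 10.0),
    method="bounded",
    options={"maxiter": 500, "xatol": 1e-6},
)
if not result.success:
    warnings.warn(
        f"Inverse temperature optimisation did not converge: {result.message}",
        RuntimeWarning,
    )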

Let me know if I've overlooked anything — I'll keep working on it!

@Ball-Man commented

If I can add a minor comment, while testing it I noticed that the learned temperature is treated as a plain attribute (_TemperatureScaling.beta). As a result, even after fitting, the _TemperatureScaling instance is detected as not fitted by check_is_fitted. This is easily solved by treating temperature as a learned parameter instead, with the trailing underscore syntax (e.g. _TemperatureScaling.beta_). I guess this may be superfluous, since it's an internal estimator. Still, the sigmoid calibrator (as a counterexample) does behave appropriately, learning the parameters a_ and b_. Thanks for your contribution.
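
To illustrate the convention being referred to, here is a toy stand-in (not the PR's _TemperatureScaling):

from sklearn.base import BaseEstimator
from sklearn.exceptions import NotFittedError
from sklearn.utils.validation import check_is_fitted

class _ToyCalibrator(BaseEstimator):
    def fit(self, X, y):
        # A trailing underscore marks an attribute learnt during fit();
        # check_is_fitted() looks for exactly such attributes.
        self.beta_ = 1.0
        return self

calib = _ToyCalibrator()
try:
    check_is_fitted(calib)  # no fitted attributes yet
except NotFittedError:
    pass
check_is_fitted(calib.fit([[0.0]], [0]))  # passes once beta_ is set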

@virchan virchan (Member, Author) left a comment

Thank you for the review, @Ball-Man!

I can confirm that scikit-learn's check_is_fitted function raises an error for temperature scaling, but not for sigmoid or isotonic calibration, after fitting on a calibration dataset.

I've updated the attribute as you suggested.

@virchan virchan (Member, Author) left a comment

I realised that # pragma: no cover lets us keep defensive error-raising branches without affecting CI coverage. So I've added it to both _CalibratedClassifier.predict_proba() and _fit_calibrator.

Specifically, if the method parameter is anything other than "sigmoid", "isotonic", or "temperature", the input validation in CalibratedClassifierCV will raise an error before any of these private functions are even called.

Similarly, I was able to raise an error when minimize_scalar fails to converge, without upsetting CI coverage in the absence of a dedicated test case. So this fix felt reasonable.

Let me know your thoughts, I'm happy to keep working on this!

@OmarManzoor OmarManzoor (Contributor) left a comment

A few comments, otherwise this looks good. Thank you for the updates, @virchan.

calibrators.append(calibrator)

else: # pragma: no cover
Contributor:

If this kind of error is raised in the class's methods before the private functions are called, then I don't think we need to raise this error here? Was it present previously?

Member Author:

This is suggested in #31068 (comment). I also believe it's safer to raise an error in this case.

Contributor:

We could possibly also add a test like test_fit_calibrator_function_raises and directly invoke this private function to test for the error.
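
A rough sketch of such a test; the signature of the private _fit_calibrator helper is assumed here and may differ from the PR:

import numpy as np
import pytest

from sklearn.calibration import _fit_calibrator
from sklearn.linear_model import LogisticRegression

def test_fit_calibrator_function_raises():
    X = np.array([[0.0], [1.0], [2.0], [3.0]])
    y = np.array([0, 0, 1, 1])
    clf = LogisticRegression().fit(X, y)
    predictions = clf.decision_function(X).reshape(-1, 1)

    with pytest.raises(ValueError):
        # Hypothetical call: argument order mirrors the private API only
        # approximately.
        _fit_calibrator(clf, predictions, y, clf.classes_, method="bad")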

Member Author:

Maybe we could ask @lorentzenchr to weigh in here again, in light of #31068 (comment), on whether it's worthwhile to also include tests for private functions.

I'm happy to follow whichever approach the team prefers.

@virchan virchan (Member, Author) left a comment

CI passed!

@OmarManzoor OmarManzoor (Contributor) left a comment

LGTM. Thank you @virchan

Successfully merging this pull request may close these issues.

Implement temperature scaling for (multi-class) calibration
5 participants