Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 0e4e418

Browse files
FIX check_estimator fails when validating SGDClassifier with log_loss (#24071)
Co-authored-by: Thomas J. Fan <[email protected]>
1 parent 55b55af commit 0e4e418

File tree

2 files changed

+26
-1
lines changed

2 files changed

+26
-1
lines changed

sklearn/utils/estimator_checks.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3504,7 +3504,21 @@ def check_decision_proba_consistency(name, estimator_orig):
35043504
# inversions in case of machine level differences.
35053505
a = estimator.predict_proba(X_test)[:, 1].round(decimals=10)
35063506
b = estimator.decision_function(X_test).round(decimals=10)
3507-
assert_array_equal(rankdata(a), rankdata(b))
3507+
3508+
rank_proba, rank_score = rankdata(a), rankdata(b)
3509+
try:
3510+
assert_array_almost_equal(rank_proba, rank_score)
3511+
except AssertionError:
3512+
# Sometimes, the rounding applied on the probabilities will have
3513+
# ties that are not present in the scores because it is
3514+
# numerically more precise. In this case, we relax the test by
3515+
# grouping the decision function scores based on the probability
3516+
# rank and check that the score is monotonically increasing.
3517+
grouped_y_score = np.array(
3518+
[b[rank_proba == group].mean() for group in np.unique(rank_proba)]
3519+
)
3520+
sorted_idx = np.argsort(grouped_y_score)
3521+
assert_array_equal(sorted_idx, np.arange(len(sorted_idx)))
35083522

35093523

35103524
def check_outliers_fit_predict(name, estimator_orig):

sklearn/utils/tests/test_estimator_checks.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from sklearn.utils import all_estimators
3737
from sklearn.exceptions import SkipTestWarning
3838
from sklearn.utils.metaestimators import available_if
39+
from sklearn.utils.estimator_checks import check_decision_proba_consistency
3940
from sklearn.utils._param_validation import Interval, StrOptions
4041

4142
from sklearn.utils.estimator_checks import (
@@ -1159,3 +1160,13 @@ class OutlierDetectorWithConstraint(OutlierDetectorWithoutConstraint):
11591160
detector = OutlierDetectorWithConstraint()
11601161
with raises(AssertionError, match=err_msg):
11611162
check_outlier_contamination(detector.__class__.__name__, detector)
1163+
1164+
1165+
def test_decision_proba_tie_ranking():
1166+
"""Check that in case with some probabilities ties, we relax the
1167+
ranking comparison with the decision function.
1168+
Non-regression test for:
1169+
https://github.com/scikit-learn/scikit-learn/issues/24025
1170+
"""
1171+
estimator = SGDClassifier(loss="log_loss")
1172+
check_decision_proba_consistency("SGDClassifier", estimator)

0 commit comments

Comments
 (0)