Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit d17d0f9

Browse files
FIX cross_validate with multimetric scoring returns the non-failed scorers results even if some fail (#23101)
Co-authored-by: Guillaume Lemaitre <[email protected]>
1 parent 6261579 commit d17d0f9

File tree

6 files changed

+122
-27
lines changed

6 files changed

+122
-27
lines changed

doc/whats_new/v1.3.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,13 @@ Changelog
4848
:class:`ensemble.ExtraTreesClassifier` and :class:`ensemble.ExtraTreesRegressor`.
4949
:pr:`25177` by :user:`Tim Head <betatim>`.
5050

51+
:mod:`sklearn.model_selection`
52+
..............................
53+
- |Fix| :func:`model_selection.cross_validate` with multimetric scoring now
54+
returns proper scores for the non-failing scorers when some scorers fail,
55+
instead of `error_score` values.
56+
:pr:`23101` by :user:`András Simon <simonandras>` and `Thomas Fan`_.
57+
5158
:mod:`sklearn.pipeline`
5259
.......................
5360
- |Feature| :class:`pipeline.FeatureUnion` can now use indexing notation (e.g.

sklearn/inspection/_permutation_importance.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,7 @@ def permutation_importance(
252252
scorer = check_scoring(estimator, scoring=scoring)
253253
else:
254254
scorers_dict = _check_multimetric_scoring(estimator, scoring)
255-
scorer = _MultimetricScorer(**scorers_dict)
255+
scorer = _MultimetricScorer(scorers=scorers_dict)
256256

257257
baseline_score = _weights_scorer(scorer, estimator, X, y, sample_weight)
258258

sklearn/metrics/_scorer.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from collections.abc import Iterable
2222
from functools import partial
2323
from collections import Counter
24+
from traceback import format_exc
2425

2526
import numpy as np
2627
import copy
@@ -91,10 +92,16 @@ class _MultimetricScorer:
9192
----------
9293
scorers : dict
9394
Dictionary mapping names to callable scorers.
95+
96+
raise_exc : bool, default=True
97+
Whether to raise the exception in `__call__` or not. If set to `False`,
98+
a formatted string of the exception details is returned as the result of
99+
the failing scorer.
94100
"""
95101

96-
def __init__(self, **scorers):
102+
def __init__(self, *, scorers, raise_exc=True):
97103
self._scorers = scorers
104+
self._raise_exc = raise_exc
98105

99106
def __call__(self, estimator, *args, **kwargs):
100107
"""Evaluate predicted target values."""
@@ -103,11 +110,18 @@ def __call__(self, estimator, *args, **kwargs):
103110
cached_call = partial(_cached_call, cache)
104111

105112
for name, scorer in self._scorers.items():
106-
if isinstance(scorer, _BaseScorer):
107-
score = scorer._score(cached_call, estimator, *args, **kwargs)
108-
else:
109-
score = scorer(estimator, *args, **kwargs)
110-
scores[name] = score
113+
try:
114+
if isinstance(scorer, _BaseScorer):
115+
score = scorer._score(cached_call, estimator, *args, **kwargs)
116+
else:
117+
score = scorer(estimator, *args, **kwargs)
118+
scores[name] = score
119+
except Exception as e:
120+
if self._raise_exc:
121+
raise e
122+
else:
123+
scores[name] = format_exc()
124+
111125
return scores
112126

113127
def _use_cache(self, estimator):

sklearn/metrics/tests/test_score_objects.py

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -786,7 +786,7 @@ def test_multimetric_scorer_calls_method_once(
786786
mock_est.classes_ = np.array([0, 1])
787787

788788
scorer_dict = _check_multimetric_scoring(LogisticRegression(), scorers)
789-
multi_scorer = _MultimetricScorer(**scorer_dict)
789+
multi_scorer = _MultimetricScorer(scorers=scorer_dict)
790790
results = multi_scorer(mock_est, X, y)
791791

792792
assert set(scorers) == set(results) # compare dict keys
@@ -813,7 +813,7 @@ def predict_proba(self, X):
813813

814814
scorers = ["roc_auc", "neg_log_loss"]
815815
scorer_dict = _check_multimetric_scoring(clf, scorers)
816-
scorer = _MultimetricScorer(**scorer_dict)
816+
scorer = _MultimetricScorer(scorers=scorer_dict)
817817
scorer(clf, X, y)
818818

819819
assert predict_proba_call_cnt == 1
@@ -836,7 +836,7 @@ def predict(self, X):
836836

837837
scorers = {"neg_mse": "neg_mean_squared_error", "r2": "roc_auc"}
838838
scorer_dict = _check_multimetric_scoring(clf, scorers)
839-
scorer = _MultimetricScorer(**scorer_dict)
839+
scorer = _MultimetricScorer(scorers=scorer_dict)
840840
scorer(clf, X, y)
841841

842842
assert predict_called_cnt == 1
@@ -859,7 +859,7 @@ def test_multimetric_scorer_sanity_check():
859859
clf.fit(X, y)
860860

861861
scorer_dict = _check_multimetric_scoring(clf, scorers)
862-
multi_scorer = _MultimetricScorer(**scorer_dict)
862+
multi_scorer = _MultimetricScorer(scorers=scorer_dict)
863863

864864
result = multi_scorer(clf, X, y)
865865

@@ -873,6 +873,49 @@ def test_multimetric_scorer_sanity_check():
873873
assert_allclose(value, separate_scores[score_name])
874874

875875

876+
@pytest.mark.parametrize("raise_exc", [True, False])
877+
def test_multimetric_scorer_exception_handling(raise_exc):
878+
"""Check that the calling of the `_MultimetricScorer` returns
879+
exception messages in the result dict for the failing scorers
880+
when `raise_exc` is `False`, and that, when `raise_exc` is `True`,
881+
the proper exception is raised instead.
882+
"""
883+
scorers = {
884+
"failing_1": "neg_mean_squared_log_error",
885+
"non_failing": "neg_median_absolute_error",
886+
"failing_2": "neg_mean_squared_log_error",
887+
}
888+
889+
X, y = make_classification(
890+
n_samples=50, n_features=2, n_redundant=0, random_state=0
891+
)
892+
y *= -1 # neg_mean_squared_log_error fails if y contains negative values
893+
894+
clf = DecisionTreeClassifier().fit(X, y)
895+
896+
scorer_dict = _check_multimetric_scoring(clf, scorers)
897+
multi_scorer = _MultimetricScorer(scorers=scorer_dict, raise_exc=raise_exc)
898+
899+
error_msg = (
900+
"Mean Squared Logarithmic Error cannot be used when targets contain"
901+
" negative values."
902+
)
903+
904+
if raise_exc:
905+
with pytest.raises(ValueError, match=error_msg):
906+
multi_scorer(clf, X, y)
907+
else:
908+
result = multi_scorer(clf, X, y)
909+
910+
exception_message_1 = result["failing_1"]
911+
score = result["non_failing"]
912+
exception_message_2 = result["failing_2"]
913+
914+
assert isinstance(exception_message_1, str) and error_msg in exception_message_1
915+
assert isinstance(score, float)
916+
assert isinstance(exception_message_2, str) and error_msg in exception_message_2
917+
918+
876919
@pytest.mark.parametrize(
877920
"scorer_name, metric",
878921
[

sklearn/model_selection/_validation.py

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -758,27 +758,45 @@ def _score(estimator, X_test, y_test, scorer, error_score="raise"):
758758
"""
759759
if isinstance(scorer, dict):
760760
# will cache method calls if needed. scorer() returns a dict
761-
scorer = _MultimetricScorer(**scorer)
761+
scorer = _MultimetricScorer(scorers=scorer, raise_exc=(error_score == "raise"))
762762

763763
try:
764764
if y_test is None:
765765
scores = scorer(estimator, X_test)
766766
else:
767767
scores = scorer(estimator, X_test, y_test)
768768
except Exception:
769-
if error_score == "raise":
769+
if isinstance(scorer, _MultimetricScorer):
770+
# If `_MultimetricScorer` raises an exception, the `error_score`
771+
# parameter is equal to "raise".
770772
raise
771773
else:
772-
if isinstance(scorer, _MultimetricScorer):
773-
scores = {name: error_score for name in scorer._scorers}
774+
if error_score == "raise":
775+
raise
774776
else:
775777
scores = error_score
776-
warnings.warn(
777-
"Scoring failed. The score on this train-test partition for "
778-
f"these parameters will be set to {error_score}. Details: \n"
779-
f"{format_exc()}",
780-
UserWarning,
781-
)
778+
warnings.warn(
779+
"Scoring failed. The score on this train-test partition for "
780+
f"these parameters will be set to {error_score}. Details: \n"
781+
f"{format_exc()}",
782+
UserWarning,
783+
)
784+
785+
# Check non-raised error messages in `_MultimetricScorer`
786+
if isinstance(scorer, _MultimetricScorer):
787+
exception_messages = [
788+
(name, str_e) for name, str_e in scores.items() if isinstance(str_e, str)
789+
]
790+
if exception_messages:
791+
# error_score != "raise"
792+
for name, str_e in exception_messages:
793+
scores[name] = error_score
794+
warnings.warn(
795+
"Scoring failed. The score on this train-test partition for "
796+
f"these parameters will be set to {error_score}. Details: \n"
797+
f"{str_e}",
798+
UserWarning,
799+
)
782800

783801
error_msg = "scoring must return a number, got %s (%s) instead. (scorer=%s)"
784802
if isinstance(scores, dict):

sklearn/model_selection/tests/test_validation.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2232,15 +2232,22 @@ def test_cross_val_score_failing_scorer(error_score):
22322232
def test_cross_validate_failing_scorer(
22332233
error_score, return_train_score, with_multimetric
22342234
):
2235-
# check that an estimator can fail during scoring in `cross_validate` and
2236-
# that we can optionally replaced it with `error_score`
2235+
# Check that an estimator can fail during scoring in `cross_validate` and
2236+
# that we can optionally replace it with `error_score`. In the multimetric
2237+
# case also check the result of a non-failing scorer where the other scorers
2238+
# are failing.
22372239
X, y = load_iris(return_X_y=True)
22382240
clf = LogisticRegression(max_iter=5).fit(X, y)
22392241

22402242
error_msg = "This scorer is supposed to fail!!!"
22412243
failing_scorer = partial(_failing_scorer, error_msg=error_msg)
22422244
if with_multimetric:
2243-
scoring = {"score_1": failing_scorer, "score_2": failing_scorer}
2245+
non_failing_scorer = make_scorer(mean_squared_error)
2246+
scoring = {
2247+
"score_1": failing_scorer,
2248+
"score_2": non_failing_scorer,
2249+
"score_3": failing_scorer,
2250+
}
22442251
else:
22452252
scoring = failing_scorer
22462253

@@ -2272,9 +2279,15 @@ def test_cross_validate_failing_scorer(
22722279
)
22732280
for key in results:
22742281
if "_score" in key:
2275-
# check the test (and optionally train score) for all
2276-
# scorers that should be assigned to `error_score`.
2277-
assert_allclose(results[key], error_score)
2282+
if "_score_2" in key:
2283+
# check the test (and optionally train) score for the
2284+
# scorer that should be non-failing
2285+
for i in results[key]:
2286+
assert isinstance(i, float)
2287+
else:
2288+
# check the test (and optionally train) score for all
2289+
# scorers that should be assigned to `error_score`.
2290+
assert_allclose(results[key], error_score)
22782291

22792292

22802293
def three_params_scorer(i, j, k):

0 commit comments

Comments
 (0)