diff --git a/sklearn/ensemble/_stacking.py b/sklearn/ensemble/_stacking.py index 303015cc9f751..58f489316187a 100644 --- a/sklearn/ensemble/_stacking.py +++ b/sklearn/ensemble/_stacking.py @@ -31,6 +31,7 @@ from ..utils.metaestimators import if_delegate_has_method from ..utils.multiclass import check_classification_targets from ..utils.validation import check_is_fitted +from ..utils.validation import _check_response_method from ..utils.validation import column_or_1d from ..utils.validation import _deprecate_positional_args from ..utils.fixes import delayed @@ -96,18 +97,14 @@ def _concatenate_predictions(self, X, predictions): def _method_name(name, estimator, method): if estimator == 'drop': return None - if method == 'auto': - if getattr(estimator, 'predict_proba', None): - return 'predict_proba' - elif getattr(estimator, 'decision_function', None): - return 'decision_function' - else: - return 'predict' - else: - if not hasattr(estimator, method): - raise ValueError('Underlying estimator {} does not implement ' - 'the method {}.'.format(name, method)) - return method + method = None if method == "auto" else method + try: + method_name = _check_response_method(estimator, method).__name__ + except ValueError as e: + raise ValueError( + f"stack_method {method} not defined in {name}" + ) from e + return method_name def fit(self, X, y, sample_weight=None): """Fit the estimators. diff --git a/sklearn/ensemble/tests/test_stacking.py b/sklearn/ensemble/tests/test_stacking.py index d6b4c385b9073..08bf9f891df15 100644 --- a/sklearn/ensemble/tests/test_stacking.py +++ b/sklearn/ensemble/tests/test_stacking.py @@ -287,7 +287,7 @@ def fit(self, X, y): {'estimators': [('lr', LogisticRegression()), ('svm', SVC(max_iter=5e4))], 'stack_method': 'predict_proba'}, - ValueError, 'does not implement the method predict_proba'), + ValueError, 'stack_method predict_proba not defined in svm'), (y_iris, {'estimators': [('lr', LogisticRegression()), ('cor', NoWeightClassifier())]}, diff --git a/sklearn/inspection/_partial_dependence.py b/sklearn/inspection/_partial_dependence.py index 1e9c0c9718a51..0b0444824e0bd 100644 --- a/sklearn/inspection/_partial_dependence.py +++ b/sklearn/inspection/_partial_dependence.py @@ -22,7 +22,10 @@ from ..utils import _get_column_indices from ..utils.validation import check_is_fitted from ..utils import Bunch -from ..utils.validation import _deprecate_positional_args +from ..utils.validation import ( + _check_response_method, + _deprecate_positional_args, +) from ..tree import DecisionTreeRegressor from ..ensemble import RandomForestRegressor from ..exceptions import NotFittedError @@ -120,29 +123,7 @@ def _partial_dependence_brute(est, grid, features, X, response_method): predictions = [] averaged_predictions = [] - # define the prediction_method (predict, predict_proba, decision_function). - if is_regressor(est): - prediction_method = est.predict - else: - predict_proba = getattr(est, 'predict_proba', None) - decision_function = getattr(est, 'decision_function', None) - if response_method == 'auto': - # try predict_proba, then decision_function if it doesn't exist - prediction_method = predict_proba or decision_function - else: - prediction_method = (predict_proba if response_method == - 'predict_proba' else decision_function) - if prediction_method is None: - if response_method == 'auto': - raise ValueError( - 'The estimator has no predict_proba and no ' - 'decision_function method.' - ) - elif response_method == 'predict_proba': - raise ValueError('The estimator has no predict_proba method.') - else: - raise ValueError( - 'The estimator has no decision_function method.') + prediction_method = _check_response_method(est, response_method) for new_values in grid: X_eval = X.copy() @@ -406,11 +387,15 @@ def partial_dependence(estimator, X, features, *, response_method='auto', 'response_method {} is invalid. Accepted response_method names ' 'are {}.'.format(response_method, ', '.join(accepted_responses))) - if is_regressor(estimator) and response_method != 'auto': - raise ValueError( - "The response_method parameter is ignored for regressors and " - "must be 'auto'." - ) + if is_regressor(estimator): + if response_method != "auto": + raise ValueError( + "The response_method parameter is ignored for regressors and " + "must be 'auto'." + ) + response_method = "predict" + elif response_method == "auto": + response_method = ["predict_proba", "decision_function"] accepted_methods = ('brute', 'recursion', 'auto') if method not in accepted_methods: @@ -454,10 +439,10 @@ def partial_dependence(estimator, X, features, *, response_method='auto', "Only the following estimators support the 'recursion' " "method: {}. Try using method='brute'." .format(', '.join(supported_classes_recursion))) - if response_method == 'auto': + if isinstance(response_method, list): response_method = 'decision_function' - if response_method != 'decision_function': + if is_classifier(estimator) and response_method != 'decision_function': raise ValueError( "With the 'recursion' method, the response_method must be " "'decision_function'. Got {}.".format(response_method) diff --git a/sklearn/inspection/tests/test_partial_dependence.py b/sklearn/inspection/tests/test_partial_dependence.py index 997c61c0e5f8b..0ee4b33908f9d 100644 --- a/sklearn/inspection/tests/test_partial_dependence.py +++ b/sklearn/inspection/tests/test_partial_dependence.py @@ -212,8 +212,9 @@ def test_partial_dependence_helpers(est, method, target_feature): [123]]) if method == 'brute': - pdp, predictions = _partial_dependence_brute(est, grid, features, X, - response_method='auto') + pdp, predictions = _partial_dependence_brute( + est, grid, features, X, response_method='predict' + ) else: pdp = _partial_dependence_recursion(est, grid, features) @@ -415,13 +416,13 @@ def fit(self, X, y): 'response_method blahblah is invalid. Accepted response_method'), (NoPredictProbaNoDecisionFunction(), {'features': [0], 'response_method': 'auto'}, - 'The estimator has no predict_proba and no decision_function method'), + 'response_method predict_proba, decision_function not defined'), (NoPredictProbaNoDecisionFunction(), {'features': [0], 'response_method': 'predict_proba'}, - 'The estimator has no predict_proba method.'), + 'response_method predict_proba not defined'), (NoPredictProbaNoDecisionFunction(), {'features': [0], 'response_method': 'decision_function'}, - 'The estimator has no decision_function method.'), + 'response_method decision_function not defined'), (LinearRegression(), {'features': [0], 'method': 'blahblah'}, 'blahblah is invalid. Accepted method names are brute, recursion, auto'), diff --git a/sklearn/metrics/_plot/base.py b/sklearn/metrics/_plot/base.py deleted file mode 100644 index 0e44a7715a1ed..0000000000000 --- a/sklearn/metrics/_plot/base.py +++ /dev/null @@ -1,114 +0,0 @@ -import numpy as np - -from sklearn.base import is_classifier - - -def _check_classifier_response_method(estimator, response_method): - """Return prediction method from the response_method - - Parameters - ---------- - estimator: object - Classifier to check - - response_method: {'auto', 'predict_proba', 'decision_function'} - Specifies whether to use :term:`predict_proba` or - :term:`decision_function` as the target response. If set to 'auto', - :term:`predict_proba` is tried first and if it does not exist - :term:`decision_function` is tried next. - - Returns - ------- - prediction_method: callable - prediction method of estimator - """ - - if response_method not in ("predict_proba", "decision_function", "auto"): - raise ValueError("response_method must be 'predict_proba', " - "'decision_function' or 'auto'") - - error_msg = "response method {} is not defined in {}" - if response_method != "auto": - prediction_method = getattr(estimator, response_method, None) - if prediction_method is None: - raise ValueError(error_msg.format(response_method, - estimator.__class__.__name__)) - else: - predict_proba = getattr(estimator, 'predict_proba', None) - decision_function = getattr(estimator, 'decision_function', None) - prediction_method = predict_proba or decision_function - if prediction_method is None: - raise ValueError(error_msg.format( - "decision_function or predict_proba", - estimator.__class__.__name__)) - - return prediction_method - - -def _get_response(X, estimator, response_method, pos_label=None): - """Return response and positive label. - - Parameters - ---------- - X : {array-like, sparse matrix} of shape (n_samples, n_features) - Input values. - - estimator : estimator instance - Fitted classifier or a fitted :class:`~sklearn.pipeline.Pipeline` - in which the last estimator is a classifier. - - response_method: {'auto', 'predict_proba', 'decision_function'} - Specifies whether to use :term:`predict_proba` or - :term:`decision_function` as the target response. If set to 'auto', - :term:`predict_proba` is tried first and if it does not exist - :term:`decision_function` is tried next. - - pos_label : str or int, default=None - The class considered as the positive class when computing - the metrics. By default, `estimators.classes_[1]` is - considered as the positive class. - - Returns - ------- - y_pred: ndarray of shape (n_samples,) - Target scores calculated from the provided response_method - and pos_label. - - pos_label: str or int - The class considered as the positive class when computing - the metrics. - """ - classification_error = ( - "{} should be a binary classifier".format(estimator.__class__.__name__) - ) - - if not is_classifier(estimator): - raise ValueError(classification_error) - - prediction_method = _check_classifier_response_method( - estimator, response_method) - - y_pred = prediction_method(X) - - if pos_label is not None and pos_label not in estimator.classes_: - raise ValueError( - f"The class provided by 'pos_label' is unknown. Got " - f"{pos_label} instead of one of {estimator.classes_}" - ) - - if y_pred.ndim != 1: # `predict_proba` - if y_pred.shape[1] != 2: - raise ValueError(classification_error) - if pos_label is None: - pos_label = estimator.classes_[1] - y_pred = y_pred[:, 1] - else: - class_idx = np.flatnonzero(estimator.classes_ == pos_label) - y_pred = y_pred[:, class_idx] - else: - if pos_label is None: - pos_label = estimator.classes_[1] - elif pos_label == estimator.classes_[0]: - y_pred *= -1 - - return y_pred, pos_label diff --git a/sklearn/metrics/_plot/det_curve.py b/sklearn/metrics/_plot/det_curve.py index d9f642e38052a..031e3e1eb4b90 100644 --- a/sklearn/metrics/_plot/det_curve.py +++ b/sklearn/metrics/_plot/det_curve.py @@ -1,11 +1,13 @@ import scipy as sp -from .base import _get_response +from ...base import is_classifier +from ...utils import ( + check_matplotlib_support, + _get_response, +) from .. import det_curve -from ...utils import check_matplotlib_support - class DetCurveDisplay: """DET curve visualization. @@ -209,8 +211,16 @@ def plot_det_curve( """ check_matplotlib_support('plot_det_curve') + if not is_classifier(estimator): + raise ValueError( + f"{estimator.__class__.__name__} should be a binary classifier." + ) + + if response_method == "auto": + response_method = ["predict_proba", "decision_function"] + y_pred, pos_label = _get_response( - X, estimator, response_method, pos_label=pos_label + estimator, X, y, response_method, pos_label=pos_label ) fpr, fnr, _ = det_curve( diff --git a/sklearn/metrics/_plot/precision_recall_curve.py b/sklearn/metrics/_plot/precision_recall_curve.py index dcc20bbce25a7..8388c6654b138 100644 --- a/sklearn/metrics/_plot/precision_recall_curve.py +++ b/sklearn/metrics/_plot/precision_recall_curve.py @@ -1,9 +1,11 @@ -from .base import _get_response - from .. import average_precision_score from .. import precision_recall_curve -from ...utils import check_matplotlib_support +from ...base import is_classifier +from ...utils import ( + check_matplotlib_support, + _get_response, +) from ...utils.validation import _deprecate_positional_args @@ -202,8 +204,17 @@ def plot_precision_recall_curve(estimator, X, y, *, """ check_matplotlib_support("plot_precision_recall_curve") + if not is_classifier(estimator): + raise ValueError( + f"{estimator.__class__.__name__} should be a binary classifier." + ) + + if response_method == "auto": + response_method = ["predict_proba", "decision_function"] + y_pred, pos_label = _get_response( - X, estimator, response_method, pos_label=pos_label) + estimator, X, y, response_method, pos_label=pos_label + ) precision, recall, _ = precision_recall_curve(y, y_pred, pos_label=pos_label, diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index 308ae4f4bf85d..f8b82d81a17e9 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -1,9 +1,11 @@ -from .base import _get_response - from .. import auc from .. import roc_curve -from ...utils import check_matplotlib_support +from ...base import is_classifier +from ...utils import ( + check_matplotlib_support, + _get_response, +) from ...utils.validation import _deprecate_positional_args @@ -209,8 +211,17 @@ def plot_roc_curve(estimator, X, y, *, sample_weight=None, """ check_matplotlib_support('plot_roc_curve') + if not is_classifier(estimator): + raise ValueError( + f"{estimator.__class__.__name__} should be a binary classifier." + ) + + if response_method == "auto": + response_method = ["predict_proba", "decision_function"] + y_pred, pos_label = _get_response( - X, estimator, response_method, pos_label=pos_label) + estimator, X, y, response_method, pos_label=pos_label + ) fpr, tpr, _ = roc_curve(y, y_pred, pos_label=pos_label, sample_weight=sample_weight, diff --git a/sklearn/metrics/_plot/tests/test_plot_curve_common.py b/sklearn/metrics/_plot/tests/test_plot_curve_common.py index c3b56f1724372..8dd919f8bfef5 100644 --- a/sklearn/metrics/_plot/tests/test_plot_curve_common.py +++ b/sklearn/metrics/_plot/tests/test_plot_curve_common.py @@ -11,6 +11,7 @@ from sklearn.tree import DecisionTreeClassifier from sklearn.metrics import plot_det_curve +from sklearn.metrics import plot_precision_recall_curve from sklearn.metrics import plot_roc_curve @@ -25,29 +26,44 @@ def data_binary(data): return X[y < 2], y[y < 2] -@pytest.mark.parametrize("plot_func", [plot_det_curve, plot_roc_curve]) +@pytest.mark.parametrize( + "plot_func", [plot_det_curve, plot_roc_curve, plot_precision_recall_curve] +) def test_plot_curve_error_non_binary(pyplot, data, plot_func): X, y = data clf = DecisionTreeClassifier() clf.fit(X, y) - msg = "DecisionTreeClassifier should be a binary classifier" + msg = "multiclass format is not supported" with pytest.raises(ValueError, match=msg): plot_func(clf, X, y) @pytest.mark.parametrize( "response_method, msg", - [("predict_proba", "response method predict_proba is not defined in " - "MyClassifier"), - ("decision_function", "response method decision_function is not defined " - "in MyClassifier"), - ("auto", "response method decision_function or predict_proba is not " - "defined in MyClassifier"), - ("bad_method", "response_method must be 'predict_proba', " - "'decision_function' or 'auto'")] + [ + ( + "predict_proba", + "response_method predict_proba not defined in " "MyClassifier", + ), + ( + "decision_function", + "response_method decision_function not defined " "in MyClassifier", + ), + ( + "auto", + "response_method predict_proba, decision_function " + "not defined in MyClassifier", + ), + ( + "bad_method", + "response_method bad_method not defined in MyClassifier", + ), + ], +) +@pytest.mark.parametrize( + "plot_func", [plot_det_curve, plot_roc_curve, plot_precision_recall_curve] ) -@pytest.mark.parametrize("plot_func", [plot_det_curve, plot_roc_curve]) def test_plot_curve_error_no_response( pyplot, data_binary, response_method, msg, plot_func, ): @@ -64,7 +80,9 @@ def fit(self, X, y): plot_func(clf, X, y, response_method=response_method) -@pytest.mark.parametrize("plot_func", [plot_det_curve, plot_roc_curve]) +@pytest.mark.parametrize( + "plot_func", [plot_det_curve, plot_roc_curve, plot_precision_recall_curve] +) def test_plot_curve_estimator_name_multiple_calls( pyplot, data_binary, plot_func ): @@ -89,7 +107,9 @@ def test_plot_curve_estimator_name_multiple_calls( make_pipeline(StandardScaler(), LogisticRegression()), make_pipeline(make_column_transformer((StandardScaler(), [0, 1])), LogisticRegression())]) -@pytest.mark.parametrize("plot_func", [plot_det_curve, plot_roc_curve]) +@pytest.mark.parametrize( + "plot_func", [plot_det_curve, plot_roc_curve, plot_precision_recall_curve] +) def test_plot_det_curve_not_fitted_errors(pyplot, data_binary, clf, plot_func): X, y = data_binary # clone since we parametrize the test and the classifier will be fitted diff --git a/sklearn/metrics/_plot/tests/test_plot_precision_recall.py b/sklearn/metrics/_plot/tests/test_plot_precision_recall.py index 48db806df87bf..a4cba0275836b 100644 --- a/sklearn/metrics/_plot/tests/test_plot_precision_recall.py +++ b/sklearn/metrics/_plot/tests/test_plot_precision_recall.py @@ -2,7 +2,6 @@ import numpy as np from numpy.testing import assert_allclose -from sklearn.base import BaseEstimator, ClassifierMixin from sklearn.metrics import plot_precision_recall_curve from sklearn.metrics import PrecisionRecallDisplay from sklearn.metrics import average_precision_score @@ -16,7 +15,6 @@ from sklearn.pipeline import make_pipeline from sklearn.preprocessing import StandardScaler from sklearn.utils import shuffle -from sklearn.compose import make_column_transformer # TODO: Remove when https://github.com/numpy/numpy/issues/14397 is resolved pytestmark = pytest.mark.filterwarnings( @@ -36,44 +34,12 @@ def test_errors(pyplot): plot_precision_recall_curve(binary_clf, X, y_binary) binary_clf.fit(X, y_binary) - multi_clf = DecisionTreeClassifier().fit(X, y_multiclass) - - # Fitted multiclass classifier with binary data - msg = "DecisionTreeClassifier should be a binary classifier" - with pytest.raises(ValueError, match=msg): - plot_precision_recall_curve(multi_clf, X, y_binary) - reg = DecisionTreeRegressor().fit(X, y_multiclass) msg = "DecisionTreeRegressor should be a binary classifier" with pytest.raises(ValueError, match=msg): plot_precision_recall_curve(reg, X, y_binary) -@pytest.mark.parametrize( - "response_method, msg", - [("predict_proba", "response method predict_proba is not defined in " - "MyClassifier"), - ("decision_function", "response method decision_function is not defined " - "in MyClassifier"), - ("auto", "response method decision_function or predict_proba is not " - "defined in MyClassifier"), - ("bad_method", "response_method must be 'predict_proba', " - "'decision_function' or 'auto'")]) -def test_error_bad_response(pyplot, response_method, msg): - X, y = make_classification(n_classes=2, n_samples=50, random_state=0) - - class MyClassifier(ClassifierMixin, BaseEstimator): - def fit(self, X, y): - self.fitted_ = True - self.classes_ = [0, 1] - return self - - clf = MyClassifier().fit(X, y) - - with pytest.raises(ValueError, match=msg): - plot_precision_recall_curve(clf, X, y, response_method=response_method) - - @pytest.mark.parametrize("response_method", ["predict_proba", "decision_function"]) @pytest.mark.parametrize("with_sample_weight", [True, False]) @@ -124,19 +90,6 @@ def test_plot_precision_recall(pyplot, response_method, with_sample_weight): assert disp.line_.get_label() == expected_label -@pytest.mark.parametrize( - "clf", [make_pipeline(StandardScaler(), LogisticRegression()), - make_pipeline(make_column_transformer((StandardScaler(), [0, 1])), - LogisticRegression())]) -def test_precision_recall_curve_pipeline(pyplot, clf): - X, y = make_classification(n_classes=2, n_samples=50, random_state=0) - with pytest.raises(NotFittedError): - plot_precision_recall_curve(clf, X, y) - clf.fit(X, y) - disp = plot_precision_recall_curve(clf, X, y) - assert disp.estimator_name == clf.__class__.__name__ - - def test_precision_recall_curve_string_labels(pyplot): # regression test #15738 cancer = load_breast_cancer() @@ -157,23 +110,6 @@ def test_precision_recall_curve_string_labels(pyplot): assert disp.estimator_name == lr.__class__.__name__ -def test_plot_precision_recall_curve_estimator_name_multiple_calls(pyplot): - # non-regression test checking that the `name` used when calling - # `plot_roc_curve` is used as well when calling `disp.plot()` - X, y = make_classification(n_classes=2, n_samples=50, random_state=0) - clf_name = "my hand-crafted name" - clf = LogisticRegression().fit(X, y) - disp = plot_precision_recall_curve(clf, X, y, name=clf_name) - assert disp.estimator_name == clf_name - pyplot.close("all") - disp.plot() - assert clf_name in disp.line_.get_label() - pyplot.close("all") - clf_name = "another_name" - disp.plot(name=clf_name) - assert clf_name in disp.line_.get_label() - - @pytest.mark.parametrize( "average_precision, estimator_name, expected_label", [ diff --git a/sklearn/metrics/_ranking.py b/sklearn/metrics/_ranking.py index 0364fbba52f63..d0ade5c0daf48 100644 --- a/sklearn/metrics/_ranking.py +++ b/sklearn/metrics/_ranking.py @@ -291,7 +291,8 @@ def det_curve(y_true, y_score, pos_label=None, sample_weight=None): >>> thresholds array([0.35, 0.4 , 0.8 ]) """ - if len(np.unique(y_true)) != 2: + y_true_type = type_of_target(y_true) + if y_true_type == "binary" and len(np.unique(y_true)) != 2: raise ValueError("Only one class present in y_true. Detection error " "tradeoff curve is not defined in that case.") diff --git a/sklearn/metrics/_scorer.py b/sklearn/metrics/_scorer.py index c686d3b7c0b34..bfd9c723dd2f7 100644 --- a/sklearn/metrics/_scorer.py +++ b/sklearn/metrics/_scorer.py @@ -21,6 +21,7 @@ from collections.abc import Iterable from functools import partial from collections import Counter +from inspect import signature import numpy as np @@ -42,21 +43,23 @@ from .cluster import normalized_mutual_info_score from .cluster import fowlkes_mallows_score +from ..utils import _get_response from ..utils.multiclass import type_of_target from ..utils.validation import _deprecate_positional_args from ..base import is_regressor -def _cached_call(cache, estimator, method, *args, **kwargs): +def _cached_call(cache, estimator, *args, **kwargs): """Call estimator with method and args and kwargs.""" + response_method = kwargs["response_method"] if cache is None: - return getattr(estimator, method)(*args, **kwargs) + return _get_response(estimator, *args, **kwargs) try: - return cache[method] + return cache[response_method] except KeyError: - result = getattr(estimator, method)(*args, **kwargs) - cache[method] = result + result = _get_response(estimator, *args, **kwargs) + cache[response_method] = result return result @@ -84,8 +87,9 @@ def __call__(self, estimator, *args, **kwargs): for name, scorer in self._scorers.items(): if isinstance(scorer, _BaseScorer): - score = scorer._score(cached_call, estimator, - *args, **kwargs) + score = scorer._score( + cached_call, estimator, *args, **kwargs + ) else: score = scorer(estimator, *args, **kwargs) scores[name] = score @@ -128,42 +132,15 @@ def __init__(self, score_func, sign, kwargs): self._score_func = score_func self._sign = sign - @staticmethod - def _check_pos_label(pos_label, classes): - if pos_label not in list(classes): - raise ValueError( - f"pos_label={pos_label} is not a valid label: {classes}" - ) - - def _select_proba_binary(self, y_pred, classes): - """Select the column of the positive label in `y_pred` when - probabilities are provided. - - Parameters - ---------- - y_pred : ndarray of shape (n_samples, n_classes) - The prediction given by `predict_proba`. - - classes : ndarray of shape (n_classes,) - The class labels for the estimator. - - Returns - ------- - y_pred : ndarray of shape (n_samples,) - Probability predictions of the positive class. - """ - if y_pred.shape[1] == 2: - pos_label = self._kwargs.get("pos_label", classes[1]) - self._check_pos_label(pos_label, classes) - col_idx = np.flatnonzero(classes == pos_label)[0] - return y_pred[:, col_idx] - - err_msg = ( - f"Got predict_proba of shape {y_pred.shape}, but need " - f"classifier with two classes for {self._score_func.__name__} " - f"scoring" - ) - raise ValueError(err_msg) + def _get_pos_label(self): + score_func_params = signature(self._score_func).parameters + if "pos_label" in self._kwargs: + pos_label = self._kwargs["pos_label"] + elif "pos_label" in score_func_params: + pos_label = score_func_params["pos_label"].default + else: + pos_label = None + return pos_label def __repr__(self): kwargs_string = "".join([", %s=%s" % (str(k), str(v)) @@ -173,19 +150,19 @@ def __repr__(self): "" if self._sign > 0 else ", greater_is_better=False", self._factory_args(), kwargs_string)) - def __call__(self, estimator, X, y_true, sample_weight=None): + def __call__(self, estimator, X, y, sample_weight=None): """Evaluate predicted target values for X relative to y_true. Parameters ---------- estimator : object - Trained estimator to use for scoring. Must have a predict_proba + Trained estimator to use for scoring. Must have a prediction method; the output of that is used to compute the score. X : {array-like, sparse matrix} Test data that will be fed to estimator.predict. - y_true : array-like + y : array-like Gold standard target values for X. sample_weight : array-like of shape (n_samples,), default=None @@ -196,8 +173,13 @@ def __call__(self, estimator, X, y_true, sample_weight=None): score : float Score function applied to prediction of estimator on X. """ - return self._score(partial(_cached_call, None), estimator, X, y_true, - sample_weight=sample_weight) + return self._score( + partial(_cached_call, None), + estimator, + X, + y, + sample_weight=sample_weight + ) def _factory_args(self): """Return non-default make_scorer arguments for repr.""" @@ -205,7 +187,7 @@ def _factory_args(self): class _PredictScorer(_BaseScorer): - def _score(self, method_caller, estimator, X, y_true, sample_weight=None): + def _score(self, method_caller, estimator, X, y, sample_weight=None): """Evaluate predicted target values for X relative to y_true. Parameters @@ -215,13 +197,13 @@ def _score(self, method_caller, estimator, X, y_true, sample_weight=None): arguments, potentially caching results. estimator : object - Trained estimator to use for scoring. Must have a predict_proba + Trained estimator to use for scoring. Must have a `predict` method; the output of that is used to compute the score. X : {array-like, sparse matrix} Test data that will be fed to estimator.predict. - y_true : array-like + y : array-like Gold standard target values for X. sample_weight : array-like of shape (n_samples,), default=None @@ -233,19 +215,20 @@ def _score(self, method_caller, estimator, X, y_true, sample_weight=None): Score function applied to prediction of estimator on X. """ - y_pred = method_caller(estimator, "predict", X) + y_pred, _ = method_caller( + estimator, X, y, response_method="predict" + ) if sample_weight is not None: - return self._sign * self._score_func(y_true, y_pred, - sample_weight=sample_weight, - **self._kwargs) + return self._sign * self._score_func( + y, y_pred, sample_weight=sample_weight, **self._kwargs + ) else: - return self._sign * self._score_func(y_true, y_pred, - **self._kwargs) + return self._sign * self._score_func(y, y_pred, **self._kwargs) class _ProbaScorer(_BaseScorer): def _score(self, method_caller, clf, X, y, sample_weight=None): - """Evaluate predicted probabilities for X relative to y_true. + """Evaluate predicted probabilities for `X` relative to `y_true`. Parameters ---------- @@ -254,14 +237,14 @@ def _score(self, method_caller, clf, X, y, sample_weight=None): arguments, potentially caching results. clf : object - Trained classifier to use for scoring. Must have a predict_proba + Trained classifier to use for scoring. Must have a `predict_proba` method; the output of that is used to compute the score. X : {array-like, sparse matrix} - Test data that will be fed to clf.predict_proba. + Test data that will be fed to `clf.predict_proba`. y : array-like - Gold standard target values for X. These must be class labels, + Gold standard target values for `X`. These must be class labels, not probabilities. sample_weight : array-like, default=None @@ -272,18 +255,17 @@ def _score(self, method_caller, clf, X, y, sample_weight=None): score : float Score function applied to prediction of estimator on X. """ - - y_type = type_of_target(y) - y_pred = method_caller(clf, "predict_proba", X) - if y_type == "binary" and y_pred.shape[1] <= 2: - # `y_type` could be equal to "binary" even in a multi-class - # problem: (when only 2 class are given to `y_true` during scoring) - # Thus, we need to check for the shape of `y_pred`. - y_pred = self._select_proba_binary(y_pred, clf.classes_) + y_pred, _ = method_caller( + clf, + X, + y, + response_method="predict_proba", + pos_label=self._get_pos_label(), + ) if sample_weight is not None: - return self._sign * self._score_func(y, y_pred, - sample_weight=sample_weight, - **self._kwargs) + return self._sign * self._score_func( + y, y_pred, sample_weight=sample_weight, **self._kwargs + ) else: return self._sign * self._score_func(y, y_pred, **self._kwargs) @@ -328,36 +310,38 @@ def _score(self, method_caller, clf, X, y, sample_weight=None): raise ValueError("{0} format is not supported".format(y_type)) if is_regressor(clf): - y_pred = method_caller(clf, "predict", X) + y_pred, _ = method_caller(clf, X, y, response_method="predict") else: + pos_label = self._get_pos_label() try: - y_pred = method_caller(clf, "decision_function", X) + y_pred, _ = method_caller( + clf, + X, + y, + response_method="decision_function", + pos_label=pos_label, + ) if isinstance(y_pred, list): # For multi-output multi-class estimator y_pred = np.vstack([p for p in y_pred]).T - elif y_type == "binary" and "pos_label" in self._kwargs: - self._check_pos_label( - self._kwargs["pos_label"], clf.classes_ - ) - if self._kwargs["pos_label"] == clf.classes_[0]: - # The implicit positive class of the binary classifier - # does not match `pos_label`: we need to invert the - # predictions - y_pred *= -1 - - except (NotImplementedError, AttributeError): - y_pred = method_caller(clf, "predict_proba", X) - - if y_type == "binary": - y_pred = self._select_proba_binary(y_pred, clf.classes_) - elif isinstance(y_pred, list): + + except (NotImplementedError, AttributeError, ValueError): + y_pred, _ = method_caller( + clf, + X, + y, + response_method="predict_proba", + pos_label=pos_label, + ) + + if isinstance(y_pred, list): y_pred = np.vstack([p[:, -1] for p in y_pred]).T if sample_weight is not None: - return self._sign * self._score_func(y, y_pred, - sample_weight=sample_weight, - **self._kwargs) + return self._sign * self._score_func( + y, y_pred, sample_weight=sample_weight, **self._kwargs + ) else: return self._sign * self._score_func(y, y_pred, **self._kwargs) diff --git a/sklearn/metrics/tests/test_score_objects.py b/sklearn/metrics/tests/test_score_objects.py index 3f1401dc08713..7a9e9d4ed8539 100644 --- a/sklearn/metrics/tests/test_score_objects.py +++ b/sklearn/metrics/tests/test_score_objects.py @@ -12,11 +12,18 @@ import joblib from numpy.testing import assert_allclose + +from sklearn.utils._mocking import ( + DummyScorer, + EstimatorWithFit, + EstimatorWithFitAndPredict, + EstimatorWithFitAndScore, + EstimatorWithoutFit, +) from sklearn.utils._testing import assert_almost_equal from sklearn.utils._testing import assert_array_equal from sklearn.utils._testing import ignore_warnings -from sklearn.base import BaseEstimator from sklearn.metrics import ( average_precision_score, brier_score_loss, @@ -140,42 +147,6 @@ def teardown_module(): shutil.rmtree(TEMP_FOLDER) -class EstimatorWithoutFit: - """Dummy estimator to test scoring validators""" - pass - - -class EstimatorWithFit(BaseEstimator): - """Dummy estimator to test scoring validators""" - def fit(self, X, y): - return self - - -class EstimatorWithFitAndScore: - """Dummy estimator to test scoring validators""" - def fit(self, X, y): - return self - - def score(self, X, y): - return 1.0 - - -class EstimatorWithFitAndPredict: - """Dummy estimator to test scoring validators""" - def fit(self, X, y): - self.y = y - return self - - def predict(self, X): - return self.y - - -class DummyScorer: - """Dummy scorer that always returns 1.""" - def __call__(self, est, X, y): - return 1 - - def test_all_scorers_repr(): # Test that all scorers have a working repr for name, scorer in SCORERS.items(): @@ -622,13 +593,18 @@ def test_multimetric_scorer_calls_method_once(scorers, expected_predict_count, X, y = np.array([[1], [1], [0], [0], [0]]), np.array([0, 1, 1, 1, 0]) mock_est = Mock() + mock_est._estimator_type = "classifier" fit_func = Mock(return_value=mock_est) + fit_func.__name__ = "fit" predict_func = Mock(return_value=y) + predict_func.__name__ = "predict" pos_proba = np.random.rand(X.shape[0]) proba = np.c_[1 - pos_proba, pos_proba] predict_proba_func = Mock(return_value=proba) + predict_proba_func.__name__ = "predict_proba" decision_function_func = Mock(return_value=pos_proba) + decision_function_func.__name__ = "decision_function" mock_est.fit = fit_func mock_est.predict = predict_func @@ -762,8 +738,8 @@ def test_multiclass_roc_no_proba_scorer_errors(scorer_name): X, y = make_classification(n_classes=3, n_informative=3, n_samples=20, random_state=0) lr = Perceptron().fit(X, y) - msg = "'Perceptron' object has no attribute 'predict_proba'" - with pytest.raises(AttributeError, match=msg): + msg = "response_method predict_proba not defined in Perceptron" + with pytest.raises(ValueError, match=msg): scorer(lr, X, y) diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index ca2be9d14fe29..52b6c54077b5a 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -28,11 +28,13 @@ from ._estimator_html_repr import estimator_html_repr from .validation import (as_float_array, assert_all_finite, + _check_response_method, check_random_state, column_or_1d, check_array, check_consistent_length, check_X_y, indexable, check_symmetric, check_scalar, _deprecate_positional_args) from .. import get_config +from .multiclass import type_of_target # Do not deprecate parallel_backend and register_parallel_backend as they are @@ -1180,3 +1182,94 @@ def is_abstract(c): # itemgetter is used to ensure the sort does not extend to the 2nd item of # the tuple return sorted(set(estimators), key=itemgetter(0)) + + +def _get_response( + estimator, + X, + y_true, + response_method=None, + pos_label=None, +): + """Return response and positive label. + + Parameters + ---------- + estimator : estimator instance + Fitted classifier or a fitted :class:`~sklearn.pipeline.Pipeline` + in which the last estimator is a classifier. + + X : {array-like, sparse matrix} of shape (n_samples, n_features) + Input values. + + y_true : array-like of shape (n_samples,) + The true label. + + response_method : {'predict_proba', 'decision_function', 'predict'} or \ + list of str, default=None. + Specifies the response method to use get prediction from an estimator + (i.e. :term:`predict_proba`, :term:`decision_function` or + :term:`predict`). + + * if `str`, it corresponds to the name to the method to return. + * if a list of `str`, it provides the method names in order of + preference. The method returned corresponds to the first method in + the list and which is implemented by `estimator`. + * if `None`, :term:`predict_proba` is tried first and if it does not + exist :term:`decision_function` is tried next and :term:`predict` + last. + + pos_label : str or int, default=None + The class considered as the positive class when computing + the metrics. By default, `estimators.classes_[1]` is + considered as the positive class. + + Returns + ------- + y_pred : ndarray of shape (n_samples,) + Target scores calculated from the provided response_method + and pos_label. + + pos_label : str, int or None + The class considered as the positive class when computing + the metrics. Returns `None` with `estimator` is a regressor. + """ + from sklearn.base import is_classifier # noqa + + if is_classifier(estimator): + y_type = type_of_target(y_true) + prediction_method = _check_response_method(estimator, response_method) + y_pred = prediction_method(X) + classes = estimator.classes_ + + if pos_label is not None and pos_label not in classes.tolist(): + raise ValueError( + f"pos_label={pos_label} is not a valid label: It should be " + f"one of {classes}" + ) + elif pos_label is None and y_type == "binary": + pos_label = pos_label if pos_label is not None else classes[-1] + + if prediction_method.__name__ == "predict_proba": + if y_type == "binary" and y_pred.shape[1] <= 2: + if y_pred.shape[1] == 2: + col_idx = np.flatnonzero(classes == pos_label)[0] + y_pred = y_pred[:, col_idx] + else: + err_msg = ( + f"Got predict_proba of shape {y_pred.shape}, but need " + f"classifier with two classes." + ) + raise ValueError(err_msg) + elif prediction_method.__name__ == "decision_function": + if y_type == "binary": + if pos_label == classes[0]: + y_pred *= -1 + else: + if response_method not in ("predict", None): + raise ValueError( + f"{estimator.__class__.__name__} should be a classifier" + ) + y_pred, pos_label = estimator.predict(X), None + + return y_pred, pos_label diff --git a/sklearn/utils/_mocking.py b/sklearn/utils/_mocking.py index 00109051d035e..53d85725c558b 100644 --- a/sklearn/utils/_mocking.py +++ b/sklearn/utils/_mocking.py @@ -321,3 +321,92 @@ def predict_proba(self, X): def _more_tags(self): return {'_skip_test': True} + + +class EstimatorWithoutFit: + """Dummy estimator to test scoring validators""" + pass + + +class EstimatorWithFit(BaseEstimator): + """Dummy estimator to test scoring validators""" + def fit(self, X, y): + self.classes_ = np.unique(y) + return self + + +class EstimatorWithFitAndScore: + """Dummy estimator to test scoring validators""" + def fit(self, X, y): + self.classes_ = np.unique(y) + return self + + def score(self, X, y): + return 1.0 + + +class EstimatorWithFitAndPredict: + """Dummy estimator to test scoring validators""" + def fit(self, X, y): + self.y = y + self.classes_ = np.unique(y) + return self + + def predict(self, X): + return self.y + + +class MockEstimatorOnOffPrediction(BaseEstimator): + """Estimator for which we can turn on/off the prediction methods. + + Parameters + ---------- + response_methods: list of \ + {"predict", "predict_proba", "decision_function"}, default=None + List containing the response implemented by the estimator. When, the + response is in the list, it will return the name of the response method + when called. Otherwise, an `AttributeError` is raised. It allows to + use `getattr` as any conventional estimator. + By default, no response methods are mocked. + """ + def __init__(self, response_methods=None): + self.response_methods = response_methods + + def fit(self, X, y): + self.classes_ = np.unique(y) + return self + + def _predict(self, X): + return "predict" + + def _predict_proba(self, X): + return "predict_proba" + + def _decision_function(self, X): + return "decision_function" + + def _check_response(self, method): + if ( + self.response_methods is not None + and method in self.response_methods + ): + return getattr(self, f"_{method}") + raise AttributeError(f"{method} not implemented.") + + @property + def predict(self): + return self._check_response("predict") + + @property + def predict_proba(self): + return self._check_response("predict_proba") + + @property + def decision_function(self): + return self._check_response("decision_function") + + +class DummyScorer: + """Dummy scorer that always returns 1.""" + def __call__(self, est, X, y): + return 1 diff --git a/sklearn/utils/tests/test_utils.py b/sklearn/utils/tests/test_utils.py index 44e448841cef0..436667f1d7b02 100644 --- a/sklearn/utils/tests/test_utils.py +++ b/sklearn/utils/tests/test_utils.py @@ -8,7 +8,18 @@ import numpy as np import scipy.sparse as sp +from sklearn.datasets import ( + make_classification, + make_regression, +) +from sklearn.linear_model import ( + LinearRegression, + LogisticRegression, +) +from sklearn.tree import DecisionTreeClassifier + from sklearn.utils._testing import (assert_array_equal, + assert_allclose, assert_allclose_dense_sparse, assert_warns_message, assert_no_warnings, @@ -18,6 +29,7 @@ from sklearn.utils import deprecated from sklearn.utils import gen_batches from sklearn.utils import _get_column_indices +from sklearn.utils import _get_response from sklearn.utils import resample from sklearn.utils import safe_mask from sklearn.utils import column_or_1d @@ -28,7 +40,10 @@ from sklearn.utils import get_chunk_n_rows from sklearn.utils import is_scalar_nan from sklearn.utils import _to_object_array -from sklearn.utils._mocking import MockDataFrame +from sklearn.utils._mocking import ( + MockDataFrame, + MockEstimatorOnOffPrediction, +) from sklearn import config_context # toy array @@ -693,3 +708,139 @@ def test_to_object_array(sequence): assert isinstance(out, np.ndarray) assert out.dtype.kind == 'O' assert out.ndim == 1 + + +@pytest.mark.parametrize( + "response_method", ["decision_function", "predict_proba"] +) +def test_get_response_regressor_error(response_method): + """Check the error message with regressor an not supported response + method.""" + my_estimator = MockEstimatorOnOffPrediction( + response_methods=[response_method] + ) + X, y = "mocking_data", "mocking_target" + err_msg = f"{my_estimator.__class__.__name__} should be a classifier" + with pytest.raises(ValueError, match=err_msg): + _get_response(my_estimator, X, y, response_method=response_method) + + +@pytest.mark.parametrize("response_method", ["predict", None]) +def test_get_response_regressor(response_method): + """Check the behaviour of `_get_response` with regressor.""" + X, y = make_regression(n_samples=10, random_state=0) + regressor = LinearRegression().fit(X, y) + y_pred, pos_label = _get_response( + regressor, + X, + y, + response_method=response_method, + ) + assert_allclose(y_pred, regressor.predict(X)) + assert pos_label is None + + +@pytest.mark.parametrize( + "response_method", + [None, "predict_proba", "decision_function", "predict"], +) +def test_get_response_classifier_unknown_pos_label(response_method): + """Check that `_get_response` raises the proper error message with + classifier.""" + X, y = make_classification(n_samples=10, n_classes=2, random_state=0) + classifier = LogisticRegression().fit(X, y) + + # provide a `pos_labe` which is not in `y` + err_msg = ( + r"pos_label=whatever is not a valid label: It should be one of \[0 1\]" + ) + with pytest.raises(ValueError, match=err_msg): + _get_response( + classifier, + X, + y, + response_method=response_method, + pos_label="whatever", + ) + + +def test_get_response_classifier_inconsistent_y_pred_for_binary_proba(): + """Check that `_get_response` will raise an error when `y_pred` has a + single class with `predict_proba`.""" + X, y_two_class = make_classification( + n_samples=10, n_classes=2, random_state=0 + ) + y_single_class = np.zeros_like(y_two_class) + classifier = DecisionTreeClassifier().fit(X, y_single_class) + + err_msg = ( + r"Got predict_proba of shape \(10, 1\), but need classifier with " + r"two classes" + ) + with pytest.raises(ValueError, match=err_msg): + _get_response( + classifier, X, y_two_class, response_method="predict_proba" + ) + + +def test_get_response_binary_classifier_decision_function(): + """Check the behaviour of `_get_response` with `decision_function` + and binary classifier. + """ + X, y = make_classification( + n_samples=10, + n_classes=2, + weights=[0.3, 0.7], + random_state=0, + ) + classifier = LogisticRegression().fit(X, y) + response_method = "decision_function" + + # default `pos_label` + y_pred, pos_label = _get_response( + classifier, X, y, response_method=response_method, pos_label=None + ) + assert_allclose(y_pred, classifier.decision_function(X)) + assert pos_label == 1 + + # when forcing `pos_label=classifier.classes_[0]` + y_pred, pos_label = _get_response( + classifier, + X, + y, + response_method=response_method, + pos_label=classifier.classes_[0], + ) + assert_allclose(y_pred, classifier.decision_function(X) * -1) + assert pos_label == 0 + + +def test_get_response_binary_classifier_predict_proba(): + """Check that `_get_response` with `predict_proba` and binary + classifier.""" + X, y = make_classification( + n_samples=10, + n_classes=2, + weights=[0.3, 0.7], + random_state=0, + ) + classifier = LogisticRegression().fit(X, y) + response_method = "predict_proba" + + # default `pos_label` + y_pred, pos_label = _get_response( + classifier, X, y, response_method=response_method, pos_label=None + ) + assert_allclose(y_pred, classifier.predict_proba(X)[:, 1]) + assert pos_label == 1 + + # when forcing `pos_label=classifier.classes_[0]` + y_pred, pos_label = _get_response( + classifier, + X, + y, + response_method=response_method, + pos_label=classifier.classes_[0], + ) + assert_allclose(y_pred, classifier.predict_proba(X)[:, 0]) + assert pos_label == 0 diff --git a/sklearn/utils/tests/test_validation.py b/sklearn/utils/tests/test_validation.py index 293af1732e1f4..1244152aa9ae2 100644 --- a/sklearn/utils/tests/test_validation.py +++ b/sklearn/utils/tests/test_validation.py @@ -21,7 +21,11 @@ from sklearn.utils import as_float_array, check_array, check_symmetric from sklearn.utils import check_X_y from sklearn.utils import deprecated -from sklearn.utils._mocking import MockDataFrame +from sklearn.utils._mocking import ( + EstimatorWithFit, + MockDataFrame, + MockEstimatorOnOffPrediction, +) from sklearn.utils.fixes import np_version, parse_version from sklearn.utils.estimator_checks import _NotAnArray from sklearn.random_projection import _sparse_random_matrix @@ -39,6 +43,7 @@ check_memory, check_non_negative, _num_samples, + _check_response_method, check_scalar, _check_psd_eigenvalues, _deprecate_positional_args, @@ -46,7 +51,6 @@ _allclose_dense_sparse, FLOAT_DTYPES) from sklearn.utils.validation import _check_fit_params -from sklearn.utils.fixes import parse_version import sklearn @@ -1320,3 +1324,77 @@ def test_check_pandas_sparse_valid(ntype1, ntype2, expected_subtype): dtype=ntype2)}) arr = check_array(df, accept_sparse=['csr', 'csc']) assert np.issubdtype(arr.dtype, expected_subtype) + + +def test_check_response_method_unknown_method(): + """Check the error message when passing an unknown response method.""" + err_msg = "response_method unknown_method not defined" + with pytest.raises(ValueError, match=err_msg): + _check_response_method(RandomForestRegressor(), "unknown_method") + + +@pytest.mark.parametrize( + "response_method", ["decision_function", "predict_proba", "predict", None] +) +def test_check_response_method_not_supported_response_method(response_method): + """Check the error message when a response method is not supported by the + estimator.""" + err_msg = "response_method {} not defined" + if response_method is None: + err_msg = err_msg.format("predict_proba, decision_function, predict") + else: + err_msg = err_msg.format(response_method) + with pytest.raises(ValueError, match=err_msg): + _check_response_method(EstimatorWithFit(), response_method) + + +@pytest.mark.parametrize( + "response_methods, expected_method_name", + [ + (["predict_proba", "decision_function", "predict"], "predict_proba"), + (["decision_function", "predict"], "decision_function"), + (["predict_proba", "predict"], "predict_proba"), + (["predict_proba", "predict_proba"]), + (["decision_function", "decision_function"]), + (["predict"], "predict"), + ], +) +def test_check_response_method_order_None( + response_methods, expected_method_name +): + """Check the order of the response method when using None.""" + my_estimator = MockEstimatorOnOffPrediction(response_methods) + + X = "mocking_data" + method_name_predicting = _check_response_method(my_estimator, None)(X) + assert method_name_predicting == expected_method_name + + +def test_check_response_method_list_str(): + """Check that we can pass a list of ordered method.""" + method_implemented = ["predict_proba"] + my_estimator = MockEstimatorOnOffPrediction(method_implemented) + + X = "mocking_data" + + # raise an error when no methods are defined + response_method = ["decision_function", "predict"] + err_msg = "response_method decision_function, predict not defined" + with pytest.raises(ValueError, match=err_msg): + _check_response_method(my_estimator, response_method)(X) + + # check that we don't get issue when one of the method is defined + response_method = ["decision_function", "predict_proba"] + method_name_predicting = _check_response_method( + my_estimator, response_method + )(X) + assert method_name_predicting == "predict_proba" + + # check the order of the methods returned + method_implemented = ["predict_proba", "predict"] + my_estimator = MockEstimatorOnOffPrediction(method_implemented) + response_method = ["decision_function", "predict", "predict_proba"] + method_name_predicting = _check_response_method( + my_estimator, response_method + )(X) + assert method_name_predicting == "predict" diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index 902a9f4ddf426..8479a03a012bb 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -9,7 +9,7 @@ # Sylvain Marie # License: BSD 3 clause -from functools import wraps +from functools import reduce, wraps import warnings import numbers @@ -1398,3 +1398,50 @@ def _check_fit_params(X, fit_params, indices=None): ) return fit_params_validated + + +def _check_response_method(estimator, response_method=None): + """Return prediction method from the `response_method`. + + Parameters + ---------- + estimator : estimator instance + Classifier or regressor to check. + + response_method : {'predict_proba', 'decision_function', 'predict'} or \ + list of str, default=None. + Specifies the response method to use get prediction from an estimator + (i.e. :term:`predict_proba`, :term:`decision_function` or + :term:`predict`). + + * if `str`, it corresponds to the name to the method to return. + * if a list of `str`, it provides the method names in order of + preference. The method returned corresponds to the first method in + the list and which is implemented by `estimator`. + * if `None`, :term:`predict_proba` is tried first and if it does not + exist :term:`decision_function` is tried next and :term:`predict` + last. + + Returns + ------- + prediction_method : callable + Prediction method of estimator. + """ + if response_method is None: + list_methods = ["predict_proba", "decision_function", "predict"] + elif isinstance(response_method, str): + list_methods = [response_method] + else: + list_methods = response_method + + prediction_method = [ + getattr(estimator, method, None) for method in list_methods + ] + prediction_method = reduce(lambda x, y: x or y, prediction_method) + if prediction_method is None: + raise ValueError( + f"response_method {', '.join(list_methods)} not defined in " + f"{estimator.__class__.__name__}" + ) + + return prediction_method