diff --git a/doc/whats_new/v0.24.rst b/doc/whats_new/v0.24.rst
index 347b30bff5685..37b3b4fd1cad9 100644
--- a/doc/whats_new/v0.24.rst
+++ b/doc/whats_new/v0.24.rst
@@ -268,6 +268,13 @@ Changelog
   class to be used when computing the roc auc statistics.
   :pr:`17651` by :user:`Clara Matos `.

+- |Fix| Fix scorers that accept a `pos_label` parameter and compute their
+  metrics from values returned by `decision_function` or `predict_proba`.
+  Previously, they would return erroneous values when `pos_label` did not
+  correspond to `classifier.classes_[1]`. This is especially important when
+  training classifiers directly with string labeled target classes.
+  :pr:`18114` by :user:`Guillaume Lemaitre `.
+
 :mod:`sklearn.model_selection`
 ..............................

diff --git a/sklearn/metrics/_base.py b/sklearn/metrics/_base.py
index 21d0ab38f6a91..ae3f4140f4974 100644
--- a/sklearn/metrics/_base.py
+++ b/sklearn/metrics/_base.py
@@ -16,7 +16,9 @@

 import numpy as np

-from ..utils import check_array, check_consistent_length
+from ..base import is_classifier
+from ..utils import check_array
+from ..utils import check_consistent_length
 from ..utils.multiclass import type_of_target


@@ -200,3 +202,145 @@ def _average_multiclass_ovo_score(binary_metric, y_true, y_score,
             pair_scores[ix] = (a_true_score + b_true_score) / 2

     return np.average(pair_scores, weights=prevalence)
+
+
+def _check_classifier_response_method(estimator, response_method):
+    """Return prediction method from the response_method
+
+    Parameters
+    ----------
+    estimator : estimator instance
+        Classifier to check.
+
+    response_method : {'auto', 'predict_proba', 'decision_function', 'predict'}
+        Specifies whether to use :term:`predict_proba` or
+        :term:`decision_function` as the target response. If set to 'auto',
+        :term:`predict_proba` is tried first and if it does not exist
+        :term:`decision_function` is tried next and :term:`predict` last.
+
+    Returns
+    -------
+    prediction_method : callable
+        Prediction method of estimator.
+    """
+
+    possible_response_methods = (
+        "predict", "predict_proba", "decision_function", "auto"
+    )
+    if response_method not in possible_response_methods:
+        raise ValueError(
+            f"response_method must be one of "
+            f"{', '.join(possible_response_methods)}."
+        )
+
+    error_msg = "response method {} is not defined in {}"
+    if response_method != "auto":
+        prediction_method = getattr(estimator, response_method, None)
+        if prediction_method is None:
+            raise ValueError(
+                error_msg.format(response_method, estimator.__class__.__name__)
+            )
+    else:
+        predict_proba = getattr(estimator, 'predict_proba', None)
+        decision_function = getattr(estimator, 'decision_function', None)
+        predict = getattr(estimator, 'predict', None)
+        prediction_method = predict_proba or decision_function or predict
+        if prediction_method is None:
+            raise ValueError(
+                error_msg.format(
+                    "decision_function, predict_proba or predict",
+                    estimator.__class__.__name__
+                )
+            )
+
+    return prediction_method
+
+
+def _get_response(
+    estimator,
+    X,
+    y_true,
+    response_method,
+    pos_label=None,
+    support_multi_class=False,
+):
+    """Return response and positive label.
+
+    Parameters
+    ----------
+    estimator : estimator instance
+        Fitted classifier or a fitted :class:`~sklearn.pipeline.Pipeline`
+        in which the last estimator is a classifier.
+
+    X : {array-like, sparse matrix} of shape (n_samples, n_features)
+        Input values.
+
+    y_true : array-like of shape (n_samples,)
+        The true label.
+
+    response_method : {'auto', 'predict_proba', 'decision_function', 'predict'}
+        Specifies whether to use :term:`predict_proba` or
+        :term:`decision_function` as the target response. If set to 'auto',
+        :term:`predict_proba` is tried first and if it does not exist
+        :term:`decision_function` is tried next and :term:`predict` last.
+
+    pos_label : str or int, default=None
+        The class considered as the positive class when computing
+        the metrics. By default, `estimator.classes_[1]` is
+        considered as the positive class.
+
+    support_multi_class : bool, default=False
+        ...
+
+    Returns
+    -------
+    y_pred : ndarray of shape (n_samples,)
+        Target scores calculated from the provided response_method
+        and pos_label.
+
+    pos_label : str or int
+        The class considered as the positive class when computing
+        the metrics.
+    """
+    if is_classifier(estimator):
+        y_type = type_of_target(y_true)
+        classes = estimator.classes_
+        prediction_method = _check_classifier_response_method(
+            estimator, response_method
+        )
+        y_pred = prediction_method(X)
+
+        if pos_label is not None and pos_label not in classes:
+            raise ValueError(
+                f"The class provided by 'pos_label' is unknown. Got "
+                f"{pos_label} instead of one of {classes}."
+            )
+
+        if prediction_method.__name__ == "predict_proba":
+            if y_type == "binary":
+                pos_label = pos_label if pos_label is not None else classes[-1]
+                if y_pred.shape[1] == 2:
+                    col_idx = np.flatnonzero(classes == pos_label)[0]
+                    y_pred = y_pred[:, col_idx]
+                else:
+                    err_msg = (
+                        f"Got predict_proba of shape {y_pred.shape}, but need "
+                        f"classifier with two classes"
+                    )
+                    if support_multi_class and y_pred.shape[1] == 1:
+                        raise ValueError(err_msg)
+                    elif not support_multi_class:
+                        raise ValueError(err_msg)
+        elif prediction_method.__name__ == "decision_function":
+            if y_type == "binary":
+                pos_label = pos_label if pos_label is not None else classes[-1]
+                if pos_label == classes[0]:
+                    y_pred *= -1
+    else:
+        if response_method not in ("predict", "auto"):
+            raise ValueError(
+                f"{estimator.__class__.__name__} should be a classifier"
+            )
+        y_pred, pos_label = estimator.predict(X), None
+
+    return y_pred, pos_label
diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py
index a7bad09ed98d0..5db6a5798abd0 100644
--- a/sklearn/metrics/_classification.py
+++ b/sklearn/metrics/_classification.py
@@ -1252,7 +1252,7 @@ def _check_set_wise_labels(y_true, y_pred, average, labels, pos_label):
                          str(average_options))

     y_type, y_true, y_pred = _check_targets(y_true, y_pred)
-    present_labels = unique_labels(y_true, y_pred)
+    present_labels = unique_labels(y_true, y_pred).tolist()
     if average == 'binary':
         if y_type == 'binary':
             if pos_label not in present_labels:
diff --git a/sklearn/metrics/_plot/base.py b/sklearn/metrics/_plot/base.py
deleted file mode 100644
index 0e44a7715a1ed..0000000000000
--- a/sklearn/metrics/_plot/base.py
+++ /dev/null
@@ -1,114 +0,0 @@
-import numpy as np
-
-from sklearn.base import is_classifier
-
-
-def _check_classifier_response_method(estimator, response_method):
-    """Return prediction method from the response_method
-
-    Parameters
-    ----------
-    estimator: object
-        Classifier to check
-
-    response_method: {'auto', 'predict_proba', 'decision_function'}
-        Specifies whether to use :term:`predict_proba` or
-        :term:`decision_function` as the target response. If set to 'auto',
-        :term:`predict_proba` is tried first and if it does not exist
-        :term:`decision_function` is tried next.
-
-    Returns
-    -------
-    prediction_method: callable
-        prediction method of estimator
-    """
-
-    if response_method not in ("predict_proba", "decision_function", "auto"):
-        raise ValueError("response_method must be 'predict_proba', "
-                         "'decision_function' or 'auto'")
-
-    error_msg = "response method {} is not defined in {}"
-    if response_method != "auto":
-        prediction_method = getattr(estimator, response_method, None)
-        if prediction_method is None:
-            raise ValueError(error_msg.format(response_method,
-                                              estimator.__class__.__name__))
-    else:
-        predict_proba = getattr(estimator, 'predict_proba', None)
-        decision_function = getattr(estimator, 'decision_function', None)
-        prediction_method = predict_proba or decision_function
-        if prediction_method is None:
-            raise ValueError(error_msg.format(
-                "decision_function or predict_proba",
-                estimator.__class__.__name__))
-
-    return prediction_method
-
-
-def _get_response(X, estimator, response_method, pos_label=None):
-    """Return response and positive label.
-
-    Parameters
-    ----------
-    X : {array-like, sparse matrix} of shape (n_samples, n_features)
-        Input values.
-
-    estimator : estimator instance
-        Fitted classifier or a fitted :class:`~sklearn.pipeline.Pipeline`
-        in which the last estimator is a classifier.
-
-    response_method: {'auto', 'predict_proba', 'decision_function'}
-        Specifies whether to use :term:`predict_proba` or
-        :term:`decision_function` as the target response. If set to 'auto',
-        :term:`predict_proba` is tried first and if it does not exist
-        :term:`decision_function` is tried next.
-
-    pos_label : str or int, default=None
-        The class considered as the positive class when computing
-        the metrics. By default, `estimators.classes_[1]` is
-        considered as the positive class.
-
-    Returns
-    -------
-    y_pred: ndarray of shape (n_samples,)
-        Target scores calculated from the provided response_method
-        and pos_label.
-
-    pos_label: str or int
-        The class considered as the positive class when computing
-        the metrics.
-    """
-    classification_error = (
-        "{} should be a binary classifier".format(estimator.__class__.__name__)
-    )
-
-    if not is_classifier(estimator):
-        raise ValueError(classification_error)
-
-    prediction_method = _check_classifier_response_method(
-        estimator, response_method)
-
-    y_pred = prediction_method(X)
-
-    if pos_label is not None and pos_label not in estimator.classes_:
-        raise ValueError(
-            f"The class provided by 'pos_label' is unknown. Got "
-            f"{pos_label} instead of one of {estimator.classes_}"
-        )
-
-    if y_pred.ndim != 1:  # `predict_proba`
-        if y_pred.shape[1] != 2:
-            raise ValueError(classification_error)
-        if pos_label is None:
-            pos_label = estimator.classes_[1]
-            y_pred = y_pred[:, 1]
-        else:
-            class_idx = np.flatnonzero(estimator.classes_ == pos_label)
-            y_pred = y_pred[:, class_idx]
-    else:
-        if pos_label is None:
-            pos_label = estimator.classes_[1]
-        elif pos_label == estimator.classes_[0]:
-            y_pred *= -1
-
-    return y_pred, pos_label
diff --git a/sklearn/metrics/_plot/precision_recall_curve.py b/sklearn/metrics/_plot/precision_recall_curve.py
index b355283407802..d75558437dee6 100644
--- a/sklearn/metrics/_plot/precision_recall_curve.py
+++ b/sklearn/metrics/_plot/precision_recall_curve.py
@@ -1,4 +1,4 @@
-from .base import _get_response
+from .._base import _get_response

 from .. import average_precision_score
 from .. import precision_recall_curve
diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py
index 6246209912694..51ee896c6ca0d 100644
--- a/sklearn/metrics/_plot/roc_curve.py
+++ b/sklearn/metrics/_plot/roc_curve.py
@@ -1,4 +1,4 @@
-from .base import _get_response
+from .._base import _get_response

 from .. import auc
 from .. import roc_curve
diff --git a/sklearn/metrics/_scorer.py b/sklearn/metrics/_scorer.py
index 9ad57f4611e52..ffbc4bb0e5f8c 100644
--- a/sklearn/metrics/_scorer.py
+++ b/sklearn/metrics/_scorer.py
@@ -21,9 +21,11 @@
 from collections.abc import Iterable
 from functools import partial
 from collections import Counter
+from inspect import signature

 import numpy as np

+from ._base import _get_response
 from . import (r2_score, median_absolute_error, max_error, mean_absolute_error,
                mean_squared_error, mean_squared_log_error,
                mean_poisson_deviance, mean_gamma_deviance, accuracy_score,
@@ -46,16 +48,17 @@
 from ..base import is_regressor


-def _cached_call(cache, estimator, method, *args, **kwargs):
+def _cached_call(cache, estimator, *args, **kwargs):
     """Call estimator with method and args and kwargs."""
+    response_method = kwargs["response_method"]
     if cache is None:
-        return getattr(estimator, method)(*args, **kwargs)
+        return _get_response(estimator, *args, **kwargs)

     try:
-        return cache[method]
+        return cache[response_method]
     except KeyError:
-        result = getattr(estimator, method)(*args, **kwargs)
-        cache[method] = result
+        result = _get_response(estimator, *args, **kwargs)
+        cache[response_method] = result
         return result


@@ -83,8 +86,9 @@ def __call__(self, estimator, *args, **kwargs):

         for name, scorer in self._scorers.items():
             if isinstance(scorer, _BaseScorer):
-                score = scorer._score(cached_call, estimator,
-                                      *args, **kwargs)
+                score = scorer._score(
+                    cached_call, estimator, *args, **kwargs
+                )
             else:
                 score = scorer(estimator, *args, **kwargs)
             scores[name] = score
@@ -127,6 +131,16 @@ def __init__(self, score_func, sign, kwargs):
         self._score_func = score_func
         self._sign = sign

+    def _get_pos_label(self):
+        score_func_params = signature(self._score_func).parameters
+        if "pos_label" in self._kwargs:
+            pos_label = self._kwargs["pos_label"]
+        elif "pos_label" in score_func_params:
+            pos_label = score_func_params["pos_label"].default
+        else:
+            pos_label = None
+        return pos_label
+
     def __repr__(self):
         kwargs_string = "".join([", %s=%s" % (str(k), str(v))
                                  for k, v in self._kwargs.items()])
@@ -158,8 +172,13 @@ def __call__(self, estimator, X, y_true, sample_weight=None):
         score : float
             Score function applied to prediction of estimator on X.
         """
-        return self._score(partial(_cached_call, None), estimator, X, y_true,
-                           sample_weight=sample_weight)
+        return self._score(
+            partial(_cached_call, None),
+            estimator,
+            X,
+            y_true,
+            sample_weight=sample_weight
+        )

     def _factory_args(self):
         """Return non-default make_scorer arguments for repr."""
@@ -194,15 +213,17 @@ def _score(self, method_caller, estimator, X, y_true, sample_weight=None):
         score : float
             Score function applied to prediction of estimator on X.
""" - - y_pred = method_caller(estimator, "predict", X) + y_pred, _ = method_caller( + estimator, X, y_true, response_method="predict" + ) if sample_weight is not None: - return self._sign * self._score_func(y_true, y_pred, - sample_weight=sample_weight, - **self._kwargs) + return self._sign * self._score_func( + y_true, y_pred, sample_weight=sample_weight, **self._kwargs + ) else: - return self._sign * self._score_func(y_true, y_pred, - **self._kwargs) + return self._sign * self._score_func( + y_true, y_pred, **self._kwargs + ) class _ProbaScorer(_BaseScorer): @@ -234,21 +255,19 @@ def _score(self, method_caller, clf, X, y, sample_weight=None): score : float Score function applied to prediction of estimator on X. """ + y_pred, _ = method_caller( + clf, + X, + y, + response_method="predict_proba", + pos_label=self._get_pos_label(), + support_multi_class=True, + ) - y_type = type_of_target(y) - y_pred = method_caller(clf, "predict_proba", X) - if y_type == "binary": - if y_pred.shape[1] == 2: - y_pred = y_pred[:, 1] - elif y_pred.shape[1] == 1: # not multiclass - raise ValueError('got predict_proba of shape {},' - ' but need classifier with two' - ' classes for {} scoring'.format( - y_pred.shape, self._score_func.__name__)) if sample_weight is not None: - return self._sign * self._score_func(y, y_pred, - sample_weight=sample_weight, - **self._kwargs) + return self._sign * self._score_func( + y, y_pred, sample_weight=sample_weight, **self._kwargs + ) else: return self._sign * self._score_func(y, y_pred, **self._kwargs) @@ -293,34 +312,38 @@ def _score(self, method_caller, clf, X, y, sample_weight=None): raise ValueError("{0} format is not supported".format(y_type)) if is_regressor(clf): - y_pred = method_caller(clf, "predict", X) + y_pred, _ = method_caller(clf, X, y, response_method="predict") else: + pos_label = self._get_pos_label() try: - y_pred = method_caller(clf, "decision_function", X) + y_pred, _ = method_caller( + clf, + X, + y, + response_method="decision_function", + pos_label=pos_label, + ) - # For multi-output multi-class estimator if isinstance(y_pred, list): + # For multi-output multi-class estimator y_pred = np.vstack([p for p in y_pred]).T - except (NotImplementedError, AttributeError): - y_pred = method_caller(clf, "predict_proba", X) - - if y_type == "binary": - if y_pred.shape[1] == 2: - y_pred = y_pred[:, 1] - else: - raise ValueError('got predict_proba of shape {},' - ' but need classifier with two' - ' classes for {} scoring'.format( - y_pred.shape, - self._score_func.__name__)) - elif isinstance(y_pred, list): + except (NotImplementedError, AttributeError, ValueError): + y_pred, _ = method_caller( + clf, + X, + y, + response_method="predict_proba", + pos_label=pos_label, + ) + + if isinstance(y_pred, list): y_pred = np.vstack([p[:, -1] for p in y_pred]).T if sample_weight is not None: - return self._sign * self._score_func(y, y_pred, - sample_weight=sample_weight, - **self._kwargs) + return self._sign * self._score_func( + y, y_pred, sample_weight=sample_weight, **self._kwargs + ) else: return self._sign * self._score_func(y, y_pred, **self._kwargs) diff --git a/sklearn/metrics/tests/test_classification.py b/sklearn/metrics/tests/test_classification.py index 6677f3119dacd..e093c4107a5b0 100644 --- a/sklearn/metrics/tests/test_classification.py +++ b/sklearn/metrics/tests/test_classification.py @@ -4,7 +4,6 @@ from itertools import chain from itertools import permutations import warnings -import re import numpy as np from scipy import linalg @@ -1247,7 +1246,7 @@ def 
@@ -1247,7 +1246,7 @@ def test_multilabel_hamming_loss():
 def test_jaccard_score_validation():
     y_true = np.array([0, 1, 0, 1, 1])
     y_pred = np.array([0, 1, 0, 1, 1])
-    err_msg = r"pos_label=2 is not a valid label: array\(\[0, 1\]\)"
+    err_msg = r"pos_label=2 is not a valid label: \[0, 1\]"
     with pytest.raises(ValueError, match=err_msg):
         jaccard_score(y_true, y_pred, average='binary', pos_label=2)

@@ -2262,9 +2261,12 @@ def test_brier_score_loss():
     # ensure to raise an error for multiclass y_true
     y_true = np.array([0, 1, 2, 0])
     y_pred = np.array([0.8, 0.6, 0.4, 0.2])
-    error_message = ("Only binary classification is supported. Labels "
-                     "in y_true: {}".format(np.array([0, 1, 2])))
-    with pytest.raises(ValueError, match=re.escape(error_message)):
+    error_message = (
+        r"Only binary classification is supported. Labels in y_true: "
+        r"\[0 1 2\]"
+    )
+
+    with pytest.raises(ValueError, match=error_message):
         brier_score_loss(y_true, y_pred)

     # calculate correctly when there's only one class in y_true
diff --git a/sklearn/metrics/tests/test_score_objects.py b/sklearn/metrics/tests/test_score_objects.py
index 67900b7cb77c3..97c35bf48db7c 100644
--- a/sklearn/metrics/tests/test_score_objects.py
+++ b/sklearn/metrics/tests/test_score_objects.py
@@ -1,3 +1,4 @@
+from copy import deepcopy
 import pickle
 import tempfile
 import shutil
@@ -16,9 +17,18 @@
 from sklearn.utils._testing import ignore_warnings
 from sklearn.base import BaseEstimator

-from sklearn.metrics import (f1_score, r2_score, roc_auc_score, fbeta_score,
-                             log_loss, precision_score, recall_score,
-                             jaccard_score)
+from sklearn.metrics import (
+    average_precision_score,
+    brier_score_loss,
+    f1_score,
+    fbeta_score,
+    jaccard_score,
+    log_loss,
+    precision_score,
+    r2_score,
+    recall_score,
+    roc_auc_score,
+)
 from sklearn.metrics import cluster as cluster_module
 from sklearn.metrics import check_scoring
 from sklearn.metrics._scorer import (_PredictScorer, _passthrough_scorer,
@@ -137,12 +147,14 @@ class EstimatorWithoutFit:
 class EstimatorWithFit(BaseEstimator):
     """Dummy estimator to test scoring validators"""
     def fit(self, X, y):
+        self.classes_ = np.unique(y)
        return self


 class EstimatorWithFitAndScore:
     """Dummy estimator to test scoring validators"""
     def fit(self, X, y):
+        self.classes_ = np.unique(y)
         return self

     def score(self, X, y):
@@ -152,6 +164,7 @@ def score(self, X, y):
 class EstimatorWithFitAndPredict:
     """Dummy estimator to test scoring validators"""
     def fit(self, X, y):
+        self.classes_ = np.unique(y)
         self.y = y
         return self

@@ -606,18 +619,25 @@ def test_multimetric_scorer_calls_method_once(scorers, expected_predict_count,
     X, y = np.array([[1], [1], [0], [0], [0]]), np.array([0, 1, 1, 1, 0])

     mock_est = Mock()
+    mock_est._estimator_type = "classifier"
     fit_func = Mock(return_value=mock_est)
+    fit_func.__name__ = "fit"
     predict_func = Mock(return_value=y)
+    predict_func.__name__ = "predict"

     pos_proba = np.random.rand(X.shape[0])
     proba = np.c_[1 - pos_proba, pos_proba]
     predict_proba_func = Mock(return_value=proba)
+    predict_proba_func.__name__ = "predict_proba"
     decision_function_func = Mock(return_value=pos_proba)
+    decision_function_func.__name__ = "decision_function"

     mock_est.fit = fit_func
     mock_est.predict = predict_func
     mock_est.predict_proba = predict_proba_func
     mock_est.decision_function = decision_function_func
+    # add the classes that would be found during fit
+    mock_est.classes_ = np.array([0, 1])

     scorer_dict = _check_multimetric_scoring(LogisticRegression(), scorers)
     multi_scorer = _MultimetricScorer(**scorer_dict)
@@ -744,6 +764,191 @@ def test_multiclass_roc_no_proba_scorer_errors(scorer_name):
     X, y = make_classification(n_classes=3, n_informative=3, n_samples=20,
                                random_state=0)
     lr = Perceptron().fit(X, y)
-    msg = "'Perceptron' object has no attribute 'predict_proba'"
-    with pytest.raises(AttributeError, match=msg):
+    msg = "response method predict_proba is not defined in Perceptron"
+    with pytest.raises(ValueError, match=msg):
+        scorer(lr, X, y)
+
+
+@pytest.fixture
+def string_labeled_classification_problem():
+    """Train a classifier on a binary problem with a string target.
+
+    The classifier is trained on a binary classification problem where the
+    minority class of interest has a string label that is intentionally not
+    the greatest class label in lexicographic order.
+
+    In addition, the dataset is imbalanced to better identify problems when
+    using non-symmetric performance metrics such as f1-score, average
+    precision and so on.
+
+    Returns
+    -------
+    classifier : estimator object
+        Trained classifier on the binary problem.
+    X_test : ndarray of shape (n_samples, n_features)
+        Data to be used as testing set in tests.
+    y_test : ndarray of shape (n_samples,), dtype=object
+        Binary target where labels are strings.
+    y_pred : ndarray of shape (n_samples,), dtype=object
+        Prediction of `classifier` when predicting for `X_test`.
+    y_pred_proba : ndarray of shape (n_samples, 2), dtype=np.float64
+        Probabilities of `classifier` when predicting for `X_test`.
+    y_pred_decision : ndarray of shape (n_samples,), dtype=np.float64
+        Decision function values of `classifier` when predicting on `X_test`.
+    """
+    from sklearn.datasets import load_breast_cancer
+    from sklearn.utils import shuffle
+
+    X, y = load_breast_cancer(return_X_y=True)
+    # create a highly imbalanced classification task
+    idx_positive = np.flatnonzero(y == 1)
+    idx_negative = np.flatnonzero(y == 0)
+    idx_selected = np.hstack([idx_negative, idx_positive[:25]])
+    X, y = X[idx_selected], y[idx_selected]
+    X, y = shuffle(X, y, random_state=42)
+    # only use 2 features to make the problem even harder
+    X = X[:, :2]
+    y = np.array(
+        ["cancer" if c == 1 else "not cancer" for c in y], dtype=object
+    )
+    X_train, X_test, y_train, y_test = train_test_split(
+        X, y, stratify=y, random_state=0,
+    )
+    classifier = LogisticRegression().fit(X_train, y_train)
+    y_pred = classifier.predict(X_test)
+    y_pred_proba = classifier.predict_proba(X_test)
+    y_pred_decision = classifier.decision_function(X_test)
+
+    return classifier, X_test, y_test, y_pred, y_pred_proba, y_pred_decision
+
+
+def test_average_precision_pos_label(string_labeled_classification_problem):
+    # check that _ThresholdScorer will lead to the right score when passing
+    # `pos_label`. Currently, only `average_precision_score` is defined to
+    # be such a scorer.
+    clf, X_test, y_test, _, y_pred_proba, y_pred_decision = \
+        string_labeled_classification_problem
+
+    pos_label = "cancer"
+    # we need to select the positive column or reverse the decision values
+    y_pred_proba = y_pred_proba[:, 0]
+    y_pred_decision = y_pred_decision * -1
+    assert clf.classes_[0] == pos_label
+
+    # check that when calling the scoring function, probability estimates and
+    # decision values lead to the same results
+    ap_proba = average_precision_score(
+        y_test, y_pred_proba, pos_label=pos_label
+    )
+    ap_decision_function = average_precision_score(
+        y_test, y_pred_decision, pos_label=pos_label
+    )
+    assert ap_proba == pytest.approx(ap_decision_function)
+
+    # create a scorer which would require to pass a `pos_label`
+    # check that it fails if `pos_label` is not provided
+    average_precision_scorer = make_scorer(
+        average_precision_score, needs_threshold=True,
+    )
+    err_msg = "pos_label"
+    with pytest.raises(ValueError, match=err_msg):
+        average_precision_scorer(clf, X_test, y_test)
+
+    # otherwise, the scorer should give the same result as calling the
+    # scoring function
+    average_precision_scorer = make_scorer(
+        average_precision_score, needs_threshold=True, pos_label=pos_label
+    )
+    ap_scorer = average_precision_scorer(clf, X_test, y_test)
+
+    assert ap_scorer == pytest.approx(ap_proba)
+
+    # The above scorer call is using `clf.decision_function`. We will force
+    # it to use `clf.predict_proba`.
+    clf_without_predict_proba = deepcopy(clf)
+
+    def _predict_proba(self, X):
+        raise NotImplementedError
+
+    clf_without_predict_proba.predict_proba = partial(
+        _predict_proba, clf_without_predict_proba
+    )
+    # sanity check
+    with pytest.raises(NotImplementedError):
+        clf_without_predict_proba.predict_proba(X_test)
+
+    ap_scorer = average_precision_scorer(
+        clf_without_predict_proba, X_test, y_test
+    )
+    assert ap_scorer == pytest.approx(ap_proba)
+
+
+def test_brier_score_loss_pos_label(string_labeled_classification_problem):
+    # check that _ProbaScorer leads to the right score when `pos_label` is
+    # provided. Currently only the `brier_score_loss` is defined to be such
+    # a scorer.
+    clf, X_test, y_test, _, y_pred_proba, _ = \
+        string_labeled_classification_problem
+
+    pos_label = "cancer"
+    assert clf.classes_[0] == pos_label
+
+    # brier score loss is symmetric
+    brier_pos_cancer = brier_score_loss(
+        y_test, y_pred_proba[:, 0], pos_label="cancer"
+    )
+    brier_pos_not_cancer = brier_score_loss(
+        y_test, y_pred_proba[:, 1], pos_label="not cancer"
+    )
+    assert brier_pos_cancer == pytest.approx(brier_pos_not_cancer)
+
+    brier_scorer = make_scorer(
+        brier_score_loss, needs_proba=True, pos_label=pos_label,
+    )
+    assert brier_scorer(clf, X_test, y_test) == pytest.approx(brier_pos_cancer)
+
+
+@pytest.mark.parametrize(
+    "score_func", [f1_score, precision_score, recall_score, jaccard_score]
+)
+def test_non_symmetric_metric_pos_label(
+    score_func, string_labeled_classification_problem
+):
+    # check that _PredictScorer leads to the right score when `pos_label` is
+    # provided. We check for all supported metrics.
+    clf, X_test, y_test, y_pred, _, _ = string_labeled_classification_problem
+
+    pos_label = "cancer"
+    assert clf.classes_[0] == pos_label
+
+    score_pos_cancer = score_func(y_test, y_pred, pos_label="cancer")
+    score_pos_not_cancer = score_func(y_test, y_pred, pos_label="not cancer")
+
+    assert score_pos_cancer != pytest.approx(score_pos_not_cancer)
+
+    scorer = make_scorer(score_func, pos_label=pos_label)
+    assert scorer(clf, X_test, y_test) == pytest.approx(score_pos_cancer)
+
+
+@pytest.mark.parametrize(
+    "scorer",
+    [
+        make_scorer(
+            average_precision_score, needs_threshold=True, pos_label="xxx"
+        ),
+        make_scorer(brier_score_loss, needs_proba=True, pos_label="xxx"),
+        make_scorer(f1_score, pos_label="xxx")
+    ],
+    ids=["ThresholdScorer", "ProbaScorer", "PredictScorer"],
+)
+def test_scorer_select_proba_error(scorer):
+    # check that we raise the proper error when passing an unknown
+    # pos_label
+    X, y = make_classification(
+        n_classes=2, n_informative=3, n_samples=20, random_state=0
+    )
+    lr = LogisticRegression(multi_class="multinomial").fit(X, y)
+
+    err_msg = "pos_label"
+    with pytest.raises(ValueError, match=err_msg):
         scorer(lr, X, y)
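Usage sketch (illustration only, not part of the patch): with the changes above, a scorer built with `make_scorer(..., pos_label=...)` agrees with calling the metric directly on appropriately selected or flipped outputs, even when the positive class is `classes_[0]`. The data and labels mirror the `string_labeled_classification_problem` fixture; the `max_iter` value and variable names are arbitrary choices for the example.

import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import average_precision_score, make_scorer
from sklearn.model_selection import train_test_split

# Binary problem with string labels; "cancer" sorts before "not cancer",
# so the class of interest ends up in classes_[0] rather than classes_[1].
X, y = load_breast_cancer(return_X_y=True)
y = np.array(["cancer" if c == 1 else "not cancer" for c in y], dtype=object)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
clf = LogisticRegression(max_iter=10_000).fit(X_train, y_train)
assert clf.classes_[0] == "cancer"

# The scorer flips the decision values according to `pos_label`, so it
# matches the metric computed by hand on the negated decision function.
ap_scorer = make_scorer(
    average_precision_score, needs_threshold=True, pos_label="cancer"
)
score_from_scorer = ap_scorer(clf, X_test, y_test)
score_by_hand = average_precision_score(
    y_test, -clf.decision_function(X_test), pos_label="cancer"
)
assert np.isclose(score_from_scorer, score_by_hand)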