From a56558932d78c3dca172525867070ce44be7be77 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 7 Aug 2020 09:42:50 +0200 Subject: [PATCH 01/19] TST wip --- sklearn/metrics/_scorer.py | 20 +++++-- sklearn/metrics/tests/test_score_objects.py | 58 +++++++++++++++++++++ 2 files changed, 75 insertions(+), 3 deletions(-) diff --git a/sklearn/metrics/_scorer.py b/sklearn/metrics/_scorer.py index 9ad57f4611e52..5d32cc788cd72 100644 --- a/sklearn/metrics/_scorer.py +++ b/sklearn/metrics/_scorer.py @@ -239,7 +239,9 @@ def _score(self, method_caller, clf, X, y, sample_weight=None): y_pred = method_caller(clf, "predict_proba", X) if y_type == "binary": if y_pred.shape[1] == 2: - y_pred = y_pred[:, 1] + pos_label = self._kwargs.get("pos_label", clf.classes_[1]) + col_idx = np.flatnonzero(clf.classes_ == pos_label)[0] + y_pred = y_pred[:, col_idx] elif y_pred.shape[1] == 1: # not multiclass raise ValueError('got predict_proba of shape {},' ' but need classifier with two' @@ -298,16 +300,28 @@ def _score(self, method_caller, clf, X, y, sample_weight=None): try: y_pred = method_caller(clf, "decision_function", X) - # For multi-output multi-class estimator if isinstance(y_pred, list): + # For multi-output multi-class estimator y_pred = np.vstack([p for p in y_pred]).T + elif ( + y_type == "binary" + and "pos_label" in self._kwargs + and self._kwargs["pos_label"] == clf.classes_[0] + ): + # The positive class is not the `pos_label` seen by the + # classifier and we need to inverse the predictions + y_pred *= -1 except (NotImplementedError, AttributeError): y_pred = method_caller(clf, "predict_proba", X) if y_type == "binary": if y_pred.shape[1] == 2: - y_pred = y_pred[:, 1] + pos_label = self._kwargs.get( + "pos_label", clf.classes_[1] + ) + col_idx = np.flatnonzero(clf.classes_ == pos_label)[0] + y_pred = y_pred[:, col_idx] else: raise ValueError('got predict_proba of shape {},' ' but need classifier with two' diff --git a/sklearn/metrics/tests/test_score_objects.py 
b/sklearn/metrics/tests/test_score_objects.py index 67900b7cb77c3..2ee064ecd9b12 100644 --- a/sklearn/metrics/tests/test_score_objects.py +++ b/sklearn/metrics/tests/test_score_objects.py @@ -618,6 +618,8 @@ def test_multimetric_scorer_calls_method_once(scorers, expected_predict_count, mock_est.predict = predict_func mock_est.predict_proba = predict_proba_func mock_est.decision_function = decision_function_func + # add the classes that would be found during fit + mock_est.classes_ = np.array([0, 1]) scorer_dict = _check_multimetric_scoring(LogisticRegression(), scorers) multi_scorer = _MultimetricScorer(**scorer_dict) @@ -747,3 +749,59 @@ def test_multiclass_roc_no_proba_scorer_errors(scorer_name): msg = "'Perceptron' object has no attribute 'predict_proba'" with pytest.raises(AttributeError, match=msg): scorer(lr, X, y) + + +def _make_imbalanced_string_dataset(): + from sklearn.datasets import load_breast_cancer + from sklearn.utils import shuffle + + X, y = load_breast_cancer(return_X_y=True) + # create an highly imbalanced + idx_positive = np.flatnonzero(y == 1) + idx_negative = np.flatnonzero(y == 0) + idx_selected = np.hstack([idx_negative, idx_positive[:25]]) + X, y = X[idx_selected], y[idx_selected] + X, y = shuffle(X, y, random_state=42) + # only use 2 features to make the problem even harder + X = X[:, :2] + y = np.array( + ["cancer" if c == 1 else "not cancer" for c in y], dtype=object + ) + X_train, X_test, y_train, y_test = train_test_split( + X, y, stratify=y, random_state=0, + ) + return X_train, X_test, y_train, y_test + + +def test_xxx(): + from sklearn.metrics import average_precision_score + X_train, X_test, y_train, y_test = _make_imbalanced_string_dataset() + + classifier = LogisticRegression().fit(X_train, y_train) + y_proba = classifier.predict_proba(X_test) + y_decision_function = classifier.decision_function(X_test) + + pos_label = "cancer" + y_proba = y_proba[:, 0] + y_decision_function *= -1 + + assert classifier.classes_[0] == pos_label + 
+ ap_proba = average_precision_score(y_test, y_proba, pos_label=pos_label) + ap_decision_function = average_precision_score( + y_test, y_decision_function, pos_label=pos_label + ) + assert ap_proba == pytest.approx(ap_decision_function) + + average_precision_scorer = make_scorer( + average_precision_score, needs_threshold=True, + ) + with pytest.raises(ValueError): + average_precision_scorer(classifier, X_test, y_test) + + average_precision_scorer = make_scorer( + average_precision_score, needs_threshold=True, pos_label=pos_label + ) + ap_scorer = average_precision_scorer(classifier, X_test, y_test) + + assert ap_scorer == pytest.approx(ap_proba) From 6c0db49123b74ae32998d8ccb8c40aba59653073 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 7 Aug 2020 10:43:45 +0200 Subject: [PATCH 02/19] TST wip --- sklearn/metrics/tests/test_score_objects.py | 79 +++++++++++++++------ 1 file changed, 58 insertions(+), 21 deletions(-) diff --git a/sklearn/metrics/tests/test_score_objects.py b/sklearn/metrics/tests/test_score_objects.py index 2ee064ecd9b12..4f146c34fe606 100644 --- a/sklearn/metrics/tests/test_score_objects.py +++ b/sklearn/metrics/tests/test_score_objects.py @@ -16,9 +16,18 @@ from sklearn.utils._testing import ignore_warnings from sklearn.base import BaseEstimator -from sklearn.metrics import (f1_score, r2_score, roc_auc_score, fbeta_score, - log_loss, precision_score, recall_score, - jaccard_score) +from sklearn.metrics import ( + average_precision_score, + brier_score_loss, + f1_score, + fbeta_score, + jaccard_score, + log_loss, + precision_score, + r2_score, + recall_score, + roc_auc_score, +) from sklearn.metrics import cluster as cluster_module from sklearn.metrics import check_scoring from sklearn.metrics._scorer import (_PredictScorer, _passthrough_scorer, @@ -751,7 +760,8 @@ def test_multiclass_roc_no_proba_scorer_errors(scorer_name): scorer(lr, X, y) -def _make_imbalanced_string_dataset(): +@pytest.fixture +def fitted_clf_predictions(): from 
sklearn.datasets import load_breast_cancer from sklearn.utils import shuffle @@ -770,38 +780,65 @@ def _make_imbalanced_string_dataset(): X_train, X_test, y_train, y_test = train_test_split( X, y, stratify=y, random_state=0, ) - return X_train, X_test, y_train, y_test + classifier = LogisticRegression().fit(X_train, y_train) + y_pred_proba = classifier.predict_proba(X_test) + y_pred_decision = classifier.decision_function(X_test) + return classifier, X_test, y_test, y_pred_proba, y_pred_decision -def test_xxx(): - from sklearn.metrics import average_precision_score - X_train, X_test, y_train, y_test = _make_imbalanced_string_dataset() - classifier = LogisticRegression().fit(X_train, y_train) - y_proba = classifier.predict_proba(X_test) - y_decision_function = classifier.decision_function(X_test) +def test_average_precision_pos_label(fitted_clf_predictions): + clf, X_test, y_test, y_pred_proba, y_pred_decision = fitted_clf_predictions pos_label = "cancer" - y_proba = y_proba[:, 0] - y_decision_function *= -1 - - assert classifier.classes_[0] == pos_label - - ap_proba = average_precision_score(y_test, y_proba, pos_label=pos_label) + # we need to select the positive column or reverse the decision values + y_pred_proba = y_pred_proba[:, 0] + y_pred_decision = y_pred_decision * -1 + assert clf.classes_[0] == pos_label + + # check that when calling the scoring function, probability estimates and + # decision values lead to the same results + ap_proba = average_precision_score( + y_test, y_pred_proba, pos_label=pos_label + ) ap_decision_function = average_precision_score( - y_test, y_decision_function, pos_label=pos_label + y_test, y_pred_decision, pos_label=pos_label ) assert ap_proba == pytest.approx(ap_decision_function) + # create a scorer which would require to pass a `pos_label` + # check that it fails if `pos_label` is not provided average_precision_scorer = make_scorer( average_precision_score, needs_threshold=True, ) - with pytest.raises(ValueError): - 
average_precision_scorer(classifier, X_test, y_test) + err_msg = "pos_label=1 is invalid. Set it to a label in y_true." + with pytest.raises(ValueError, match=err_msg): + average_precision_scorer(clf, X_test, y_test) + # otherwise, the scorer should give the same results than calling the + # scoring function average_precision_scorer = make_scorer( average_precision_score, needs_threshold=True, pos_label=pos_label ) - ap_scorer = average_precision_scorer(classifier, X_test, y_test) + ap_scorer = average_precision_scorer(clf, X_test, y_test) assert ap_scorer == pytest.approx(ap_proba) + + +def test_brier_score_loss_pos_label(fitted_clf_predictions): + clf, X_test, y_test, y_pred_proba, y_pred_decision = fitted_clf_predictions + + pos_label = "cancer" + # we need to select the positive column or reverse the decision values + y_pred_proba = y_pred_proba[:, 0] + y_pred_decision = y_pred_decision * -1 + assert clf.classes_[0] == pos_label + + print(brier_score_loss(y_test, y_pred_proba, pos_label=pos_label)) + brier_scorer = make_scorer( + brier_score_loss, + needs_proba=True, + greater_is_better=False, + pos_label=pos_label, + ) + print(brier_scorer(clf, X_test, y_test)) From c41e999205f9b6f1e8ece6d0fcf6baa0cde52361 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 7 Aug 2020 11:01:36 +0200 Subject: [PATCH 03/19] TST wip --- sklearn/metrics/tests/test_score_objects.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/sklearn/metrics/tests/test_score_objects.py b/sklearn/metrics/tests/test_score_objects.py index 4f146c34fe606..040a64c298482 100644 --- a/sklearn/metrics/tests/test_score_objects.py +++ b/sklearn/metrics/tests/test_score_objects.py @@ -829,16 +829,18 @@ def test_brier_score_loss_pos_label(fitted_clf_predictions): clf, X_test, y_test, y_pred_proba, y_pred_decision = fitted_clf_predictions pos_label = "cancer" - # we need to select the positive column or reverse the decision values - y_pred_proba = 
y_pred_proba[:, 0] - y_pred_decision = y_pred_decision * -1 assert clf.classes_[0] == pos_label - print(brier_score_loss(y_test, y_pred_proba, pos_label=pos_label)) + # brier score loss is symmetric + brier_pos_cancer = brier_score_loss( + y_test, y_pred_proba[:, 0], pos_label="cancer" + ) + brier_pos_not_cancer = brier_score_loss( + y_test, y_pred_proba[:, 1], pos_label="not cancer" + ) + assert brier_pos_cancer == brier_pos_not_cancer + brier_scorer = make_scorer( - brier_score_loss, - needs_proba=True, - greater_is_better=False, - pos_label=pos_label, + brier_score_loss, needs_proba=True, pos_label=pos_label, ) - print(brier_scorer(clf, X_test, y_test)) + assert brier_scorer(clf, X_test, y_test) == pytest.approx(brier_pos_cancer) From f53f8336ae3200de7a9d0d007ed1745356744f5d Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 7 Aug 2020 11:05:56 +0200 Subject: [PATCH 04/19] TST PEP8 + comments --- sklearn/metrics/tests/test_score_objects.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/sklearn/metrics/tests/test_score_objects.py b/sklearn/metrics/tests/test_score_objects.py index 040a64c298482..a1a5eaeac271c 100644 --- a/sklearn/metrics/tests/test_score_objects.py +++ b/sklearn/metrics/tests/test_score_objects.py @@ -788,6 +788,9 @@ def fitted_clf_predictions(): def test_average_precision_pos_label(fitted_clf_predictions): + # check that _ThresholdScorer will lead to the right score when passing + # `pos_label`. Currently, only `average_precision_score` is defined to + # be such a scorer. clf, X_test, y_test, y_pred_proba, y_pred_decision = fitted_clf_predictions pos_label = "cancer" @@ -826,7 +829,10 @@ def test_average_precision_pos_label(fitted_clf_predictions): def test_brier_score_loss_pos_label(fitted_clf_predictions): - clf, X_test, y_test, y_pred_proba, y_pred_decision = fitted_clf_predictions + # check that _ProbaScorer leads to the right score when `pos_label` is + # provided. 
Currently only the `brier_score_loss` is defined to be such + # a scorer. + clf, X_test, y_test, y_pred_proba, _ = fitted_clf_predictions pos_label = "cancer" assert clf.classes_[0] == pos_label From dd4e9fe00538f97487329c540aac1bb47d9ee280 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 7 Aug 2020 11:22:54 +0200 Subject: [PATCH 05/19] TST force to use predict_proba as well --- sklearn/metrics/tests/test_score_objects.py | 22 ++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/sklearn/metrics/tests/test_score_objects.py b/sklearn/metrics/tests/test_score_objects.py index a1a5eaeac271c..135e5bd329458 100644 --- a/sklearn/metrics/tests/test_score_objects.py +++ b/sklearn/metrics/tests/test_score_objects.py @@ -1,10 +1,11 @@ +from copy import deepcopy import pickle import tempfile import shutil import os import numbers from unittest.mock import Mock -from functools import partial +from functools import partial, partialmethod import numpy as np import pytest @@ -827,6 +828,25 @@ def test_average_precision_pos_label(fitted_clf_predictions): assert ap_scorer == pytest.approx(ap_proba) + # The above scorer call is using `clf.decision_function`. We will force + # it to use `clf.predict_proba`. 
+ clf_without_predict_proba = deepcopy(clf) + + def _predict_proba(self, X): + raise NotImplementedError + + clf_without_predict_proba.predict_proba = partial( + _predict_proba, clf_without_predict_proba + ) + # sanity check + with pytest.raises(NotImplementedError): + clf_without_predict_proba.predict_proba(X_test) + + ap_scorer = average_precision_scorer( + clf_without_predict_proba, X_test, y_test + ) + assert ap_scorer == pytest.approx(ap_proba) + def test_brier_score_loss_pos_label(fitted_clf_predictions): # check that _ProbaScorer leads to the right score when `pos_label` is From e32cfa757b40673d0dc6b7e7eea44b8c1f39b880 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 7 Aug 2020 11:26:37 +0200 Subject: [PATCH 06/19] DOC add whats + PEP8 --- doc/whats_new/v0.24.rst | 5 +++++ sklearn/metrics/tests/test_score_objects.py | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/doc/whats_new/v0.24.rst b/doc/whats_new/v0.24.rst index 347b30bff5685..a6b05b4d646b5 100644 --- a/doc/whats_new/v0.24.rst +++ b/doc/whats_new/v0.24.rst @@ -268,6 +268,11 @@ Changelog class to be used when computing the roc auc statistics. :pr:`17651` by :user:`Clara Matos `. +- |Fix| Fix a bug which was not selected the appropriate probability estimates + or reversing the decision values if `pos_label` was provided and it was not + corresponding to `classifier.classes_[1]`. + :pr:`#18114` by :user:`Guillaume Lemaitre `. + :mod:`sklearn.model_selection` .............................. 
diff --git a/sklearn/metrics/tests/test_score_objects.py b/sklearn/metrics/tests/test_score_objects.py index 135e5bd329458..a7d978a084ff6 100644 --- a/sklearn/metrics/tests/test_score_objects.py +++ b/sklearn/metrics/tests/test_score_objects.py @@ -5,7 +5,7 @@ import os import numbers from unittest.mock import Mock -from functools import partial, partialmethod +from functools import partial import numpy as np import pytest From 07915e99e8218abf1bbc14194bdd6046036b3637 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 7 Aug 2020 17:01:02 +0200 Subject: [PATCH 07/19] TST add some tolerance since the average of squared in diff ordered --- sklearn/metrics/tests/test_score_objects.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/metrics/tests/test_score_objects.py b/sklearn/metrics/tests/test_score_objects.py index a7d978a084ff6..cda860c5610be 100644 --- a/sklearn/metrics/tests/test_score_objects.py +++ b/sklearn/metrics/tests/test_score_objects.py @@ -864,7 +864,7 @@ def test_brier_score_loss_pos_label(fitted_clf_predictions): brier_pos_not_cancer = brier_score_loss( y_test, y_pred_proba[:, 1], pos_label="not cancer" ) - assert brier_pos_cancer == brier_pos_not_cancer + assert brier_pos_cancer == pytest.approx(brier_pos_not_cancer) brier_scorer = make_scorer( brier_score_loss, needs_proba=True, pos_label=pos_label, From fc1c4227b1978bf11efc0cb41780f59e4525af6b Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 10 Aug 2020 10:42:19 +0200 Subject: [PATCH 08/19] STY add better error message and refactor code --- sklearn/metrics/_scorer.py | 74 +++++++++++++-------- sklearn/metrics/tests/test_score_objects.py | 21 ++++++ 2 files changed, 67 insertions(+), 28 deletions(-) diff --git a/sklearn/metrics/_scorer.py b/sklearn/metrics/_scorer.py index 5d32cc788cd72..c8c0e4d4a11f5 100644 --- a/sklearn/metrics/_scorer.py +++ b/sklearn/metrics/_scorer.py @@ -127,6 +127,38 @@ def __init__(self, score_func, sign, kwargs): 
self._score_func = score_func self._sign = sign + @staticmethod + def _check_pos_label(pos_label, classes): + if pos_label not in classes: + raise ValueError( + f"pos_label should be present in the target when the " + f"classifier was trained. Got pos_label={pos_label} while the " + f"possible classes are {classes}." + ) + + def _select_proba(self, y_pred, classes, support_multi_class): + """Select the column of y_pred when probabilities are provided.""" + if y_pred.shape[1] == 2: + pos_label = self._kwargs.get("pos_label", classes[1]) + self._check_pos_label(pos_label, classes) + col_idx = np.flatnonzero(classes == pos_label)[0] + y_pred = y_pred[:, col_idx] + else: + err_msg = ( + f"Got predict_proba of shape {y_pred.shape}, but need " + f"classifier with two classes for {self._score_func.__name__} " + f"scoring" + ) + if support_multi_class and y_pred.shape[1] == 1: + # In _ProbaScorer, y_true can be tagged as binary while the + # y_pred is multi_class. This case is supported when label is + # provided. 
+ raise ValueError(err_msg) + else: + raise ValueError(err_msg) + + return y_pred + def __repr__(self): kwargs_string = "".join([", %s=%s" % (str(k), str(v)) for k, v in self._kwargs.items()]) @@ -239,14 +271,9 @@ def _score(self, method_caller, clf, X, y, sample_weight=None): y_pred = method_caller(clf, "predict_proba", X) if y_type == "binary": if y_pred.shape[1] == 2: - pos_label = self._kwargs.get("pos_label", clf.classes_[1]) - col_idx = np.flatnonzero(clf.classes_ == pos_label)[0] - y_pred = y_pred[:, col_idx] - elif y_pred.shape[1] == 1: # not multiclass - raise ValueError('got predict_proba of shape {},' - ' but need classifier with two' - ' classes for {} scoring'.format( - y_pred.shape, self._score_func.__name__)) + self._select_proba( + y_pred, clf.classes_, support_multi_class=True + ) if sample_weight is not None: return self._sign * self._score_func(y, y_pred, sample_weight=sample_weight, @@ -303,31 +330,22 @@ def _score(self, method_caller, clf, X, y, sample_weight=None): if isinstance(y_pred, list): # For multi-output multi-class estimator y_pred = np.vstack([p for p in y_pred]).T - elif ( - y_type == "binary" - and "pos_label" in self._kwargs - and self._kwargs["pos_label"] == clf.classes_[0] - ): - # The positive class is not the `pos_label` seen by the - # classifier and we need to inverse the predictions - y_pred *= -1 + elif y_type == "binary" and "pos_label" in self._kwargs: + self._check_pos_label( + self._kwargs["pos_label"], clf.classes + ) + if self._kwargs["pos_label"] == clf.classes_[0]: + # The positive class is not the `pos_label` seen by the + # classifier and we need to inverse the predictions + y_pred *= -1 except (NotImplementedError, AttributeError): y_pred = method_caller(clf, "predict_proba", X) if y_type == "binary": - if y_pred.shape[1] == 2: - pos_label = self._kwargs.get( - "pos_label", clf.classes_[1] - ) - col_idx = np.flatnonzero(clf.classes_ == pos_label)[0] - y_pred = y_pred[:, col_idx] - else: - raise ValueError('got 
predict_proba of shape {},' - ' but need classifier with two' - ' classes for {} scoring'.format( - y_pred.shape, - self._score_func.__name__)) + y_pred = self._select_proba( + y_pred, clf.classes_, support_multi_class=False, + ) elif isinstance(y_pred, list): y_pred = np.vstack([p[:, -1] for p in y_pred]).T diff --git a/sklearn/metrics/tests/test_score_objects.py b/sklearn/metrics/tests/test_score_objects.py index cda860c5610be..9521dcdf3a064 100644 --- a/sklearn/metrics/tests/test_score_objects.py +++ b/sklearn/metrics/tests/test_score_objects.py @@ -870,3 +870,24 @@ def test_brier_score_loss_pos_label(fitted_clf_predictions): brier_score_loss, needs_proba=True, pos_label=pos_label, ) assert brier_scorer(clf, X_test, y_test) == pytest.approx(brier_pos_cancer) + + +@pytest.mark.parametrize( + "scorer", + [ + make_scorer( + average_precision_score, needs_threshold=True, pos_label="xxx" + ), + make_scorer(brier_score_loss, needs_proba=True, pos_label="xxx"), + ], + ids=["ThresholdScorer", "ProbaScorer"], +) +def test_scorer_select_proba_error(scorer): + X, y = make_classification( + n_classes=2, n_informative=3, n_samples=20, random_state=0 + ) + lr = LogisticRegression(multi_class="multinomial").fit(X, y) + + err_msg = "pos_label should be present in the target" + with pytest.raises(ValueError, match=err_msg): + scorer(lr, X, y) From aa5cd1655042ba2d4425b08bc571e2517a277840 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 10 Aug 2020 11:11:34 +0200 Subject: [PATCH 09/19] fix --- sklearn/metrics/_scorer.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/sklearn/metrics/_scorer.py b/sklearn/metrics/_scorer.py index c8c0e4d4a11f5..fc16db7bfbbe7 100644 --- a/sklearn/metrics/_scorer.py +++ b/sklearn/metrics/_scorer.py @@ -154,7 +154,7 @@ def _select_proba(self, y_pred, classes, support_multi_class): # y_pred is multi_class. This case is supported when label is # provided. 
raise ValueError(err_msg) - else: + elif not support_multi_class: raise ValueError(err_msg) return y_pred @@ -270,10 +270,9 @@ def _score(self, method_caller, clf, X, y, sample_weight=None): y_type = type_of_target(y) y_pred = method_caller(clf, "predict_proba", X) if y_type == "binary": - if y_pred.shape[1] == 2: - self._select_proba( - y_pred, clf.classes_, support_multi_class=True - ) + y_pred = self._select_proba( + y_pred, clf.classes_, support_multi_class=True + ) if sample_weight is not None: return self._sign * self._score_func(y, y_pred, sample_weight=sample_weight, @@ -332,7 +331,7 @@ def _score(self, method_caller, clf, X, y, sample_weight=None): y_pred = np.vstack([p for p in y_pred]).T elif y_type == "binary" and "pos_label" in self._kwargs: self._check_pos_label( - self._kwargs["pos_label"], clf.classes + self._kwargs["pos_label"], clf.classes_ ) if self._kwargs["pos_label"] == clf.classes_[0]: # The positive class is not the `pos_label` seen by the From a669ecfb95c1bd9d4f258818dfc3b208d510d82b Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 10 Aug 2020 11:38:12 +0200 Subject: [PATCH 10/19] fix --- sklearn/metrics/_scorer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/metrics/_scorer.py b/sklearn/metrics/_scorer.py index fc16db7bfbbe7..dbd9d73326308 100644 --- a/sklearn/metrics/_scorer.py +++ b/sklearn/metrics/_scorer.py @@ -129,7 +129,7 @@ def __init__(self, score_func, sign, kwargs): @staticmethod def _check_pos_label(pos_label, classes): - if pos_label not in classes: + if pos_label not in list(classes): raise ValueError( f"pos_label should be present in the target when the " f"classifier was trained. 
Got pos_label={pos_label} while the " From a477e7bfa66636fb215be7b72da8bde7b72689c8 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 11 Aug 2020 14:12:50 +0200 Subject: [PATCH 11/19] Update sklearn/metrics/tests/test_score_objects.py Co-authored-by: Olivier Grisel --- sklearn/metrics/tests/test_score_objects.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/metrics/tests/test_score_objects.py b/sklearn/metrics/tests/test_score_objects.py index 9521dcdf3a064..a996af52dd257 100644 --- a/sklearn/metrics/tests/test_score_objects.py +++ b/sklearn/metrics/tests/test_score_objects.py @@ -767,7 +767,7 @@ def fitted_clf_predictions(): from sklearn.utils import shuffle X, y = load_breast_cancer(return_X_y=True) - # create an highly imbalanced + # create an highly imbalanced classification task idx_positive = np.flatnonzero(y == 1) idx_negative = np.flatnonzero(y == 0) idx_selected = np.hstack([idx_negative, idx_positive[:25]]) From 09b47bbaeebfc9ae441613956ee225a408bc2888 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 11 Aug 2020 14:29:34 +0200 Subject: [PATCH 12/19] add test for PredictScorer --- sklearn/metrics/tests/test_score_objects.py | 30 ++++++++++++++++++--- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/sklearn/metrics/tests/test_score_objects.py b/sklearn/metrics/tests/test_score_objects.py index a996af52dd257..615445eead601 100644 --- a/sklearn/metrics/tests/test_score_objects.py +++ b/sklearn/metrics/tests/test_score_objects.py @@ -782,17 +782,19 @@ def fitted_clf_predictions(): X, y, stratify=y, random_state=0, ) classifier = LogisticRegression().fit(X_train, y_train) + y_pred = classifier.predict(X_test) y_pred_proba = classifier.predict_proba(X_test) y_pred_decision = classifier.decision_function(X_test) - return classifier, X_test, y_test, y_pred_proba, y_pred_decision + return classifier, X_test, y_test, y_pred, y_pred_proba, y_pred_decision def 
test_average_precision_pos_label(fitted_clf_predictions): # check that _ThresholdScorer will lead to the right score when passing # `pos_label`. Currently, only `average_precision_score` is defined to # be such a scorer. - clf, X_test, y_test, y_pred_proba, y_pred_decision = fitted_clf_predictions + clf, X_test, y_test, _, y_pred_proba, y_pred_decision = \ + fitted_clf_predictions pos_label = "cancer" # we need to select the positive column or reverse the decision values @@ -852,7 +854,7 @@ def test_brier_score_loss_pos_label(fitted_clf_predictions): # check that _ProbaScorer leads to the right score when `pos_label` is # provided. Currently only the `brier_score_loss` is defined to be such # a scorer. - clf, X_test, y_test, y_pred_proba, _ = fitted_clf_predictions + clf, X_test, y_test, _, y_pred_proba, _ = fitted_clf_predictions pos_label = "cancer" assert clf.classes_[0] == pos_label @@ -872,6 +874,26 @@ def test_brier_score_loss_pos_label(fitted_clf_predictions): assert brier_scorer(clf, X_test, y_test) == pytest.approx(brier_pos_cancer) +@pytest.mark.parametrize( + "score_func", [f1_score, precision_score, recall_score, jaccard_score] +) +def test_non_symmetric_metric_pos_label(score_func, fitted_clf_predictions): + # check that _PredictScorer leads to the right score when `pos_label` is + # provided. We check for all possible metric supported. 
+ clf, X_test, y_test, y_pred, _, _ = fitted_clf_predictions + + pos_label = "cancer" + assert clf.classes_[0] == pos_label + + score_pos_cancer = score_func(y_test, y_pred, pos_label="cancer") + score_pos_not_cancer = score_func(y_test, y_pred, pos_label="not cancer") + + assert score_pos_cancer != pytest.approx(score_pos_not_cancer) + + scorer = make_scorer(score_func, pos_label=pos_label) + assert scorer(clf, X_test, y_test) == pytest.approx(score_pos_cancer) + + @pytest.mark.parametrize( "scorer", [ @@ -883,6 +905,8 @@ def test_brier_score_loss_pos_label(fitted_clf_predictions): ids=["ThresholdScorer", "ProbaScorer"], ) def test_scorer_select_proba_error(scorer): + # check that we raise the the proper error when passing an unknown + # pos_label X, y = make_classification( n_classes=2, n_informative=3, n_samples=20, random_state=0 ) From 42e7f00a34ba3ee76d3d04021a88640431c4ba87 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 18 Aug 2020 10:55:21 +0200 Subject: [PATCH 13/19] apply olivier suggestions --- doc/whats_new/v0.24.rst | 8 ++-- sklearn/metrics/_scorer.py | 9 ++-- sklearn/metrics/tests/test_score_objects.py | 47 +++++++++++++++++---- 3 files changed, 47 insertions(+), 17 deletions(-) diff --git a/doc/whats_new/v0.24.rst b/doc/whats_new/v0.24.rst index a6b05b4d646b5..37b3b4fd1cad9 100644 --- a/doc/whats_new/v0.24.rst +++ b/doc/whats_new/v0.24.rst @@ -268,9 +268,11 @@ Changelog class to be used when computing the roc auc statistics. :pr:`17651` by :user:`Clara Matos `. -- |Fix| Fix a bug which was not selected the appropriate probability estimates - or reversing the decision values if `pos_label` was provided and it was not - corresponding to `classifier.classes_[1]`. +- |Fix| Fix scorers that accept a pos_label parameter and compute their metrics + from values returned by `decision_function` or `predict_proba`. Previously, + they would return erroneous values when pos_label was not corresponding to + `classifier.classes_[1]`. 
This is especially important when training + classifiers directly with string labeled target classes. :pr:`#18114` by :user:`Guillaume Lemaitre `. :mod:`sklearn.model_selection` diff --git a/sklearn/metrics/_scorer.py b/sklearn/metrics/_scorer.py index dbd9d73326308..796444f612a6d 100644 --- a/sklearn/metrics/_scorer.py +++ b/sklearn/metrics/_scorer.py @@ -131,9 +131,7 @@ def __init__(self, score_func, sign, kwargs): def _check_pos_label(pos_label, classes): if pos_label not in list(classes): raise ValueError( - f"pos_label should be present in the target when the " - f"classifier was trained. Got pos_label={pos_label} while the " - f"possible classes are {classes}." + f"pos_label={pos_label} is not a valid label: {classes}" ) def _select_proba(self, y_pred, classes, support_multi_class): @@ -334,8 +332,9 @@ def _score(self, method_caller, clf, X, y, sample_weight=None): self._kwargs["pos_label"], clf.classes_ ) if self._kwargs["pos_label"] == clf.classes_[0]: - # The positive class is not the `pos_label` seen by the - # classifier and we need to inverse the predictions + # The implicit positive class of the binary classifier + # does not match `pos_label`: we need to invert the + # predictions y_pred *= -1 except (NotImplementedError, AttributeError): diff --git a/sklearn/metrics/tests/test_score_objects.py b/sklearn/metrics/tests/test_score_objects.py index 615445eead601..52bafb160bfdb 100644 --- a/sklearn/metrics/tests/test_score_objects.py +++ b/sklearn/metrics/tests/test_score_objects.py @@ -762,7 +762,32 @@ def test_multiclass_roc_no_proba_scorer_errors(scorer_name): @pytest.fixture -def fitted_clf_predictions(): +def string_labeled_classification_problem(): + """Train a classifier on binary problem with string target. + + The classifier is trained on a binary classification problem where the + minority class of interest has a string label that is intentionally not the + greatest class label using the lexicographic order. 
+ + In addition, the dataset is imbalanced to better identify problems when + using non-symmetric performance metrics such as f1-score, average precision + and so on. + + Returns + ------- + classifier : estimator object + Trained classifier on the binary problem. + X_test : ndarray of shape (n_samples, n_features) + Data to be used as testing set in tests. + y_test : ndarray of shape (n_samples,), dtype=object + Binary target where labels are strings. + y_pred : ndarray of shape (n_samples,), dtype=object + Prediction of `classifier` when predicting for `X_test`. + y_pred_proba : ndarray of shape (n_samples, 2), dtype=np.float64 + Probabilities of `classifier` when predicting for `X_test`. + y_pred_decision : ndarray of shape (n_samples,), dtype=np.float64 + Decision function values of `classifier` when predicting on `X_test`. + """ from sklearn.datasets import load_breast_cancer from sklearn.utils import shuffle @@ -789,12 +814,12 @@ def fitted_clf_predictions(): return classifier, X_test, y_test, y_pred, y_pred_proba, y_pred_decision -def test_average_precision_pos_label(fitted_clf_predictions): +def test_average_precision_pos_label(string_labeled_classification_problem): # check that _ThresholdScorer will lead to the right score when passing # `pos_label`. Currently, only `average_precision_score` is defined to # be such a scorer. clf, X_test, y_test, _, y_pred_proba, y_pred_decision = \ - fitted_clf_predictions + string_labeled_classification_problem pos_label = "cancer" # we need to select the positive column or reverse the decision values @@ -850,11 +875,12 @@ def _predict_proba(self, X): assert ap_scorer == pytest.approx(ap_proba) -def test_brier_score_loss_pos_label(fitted_clf_predictions): +def test_brier_score_loss_pos_label(string_labeled_classification_problem): # check that _ProbaScorer leads to the right score when `pos_label` is # provided. Currently only the `brier_score_loss` is defined to be such # a scorer. 
- clf, X_test, y_test, _, y_pred_proba, _ = fitted_clf_predictions + clf, X_test, y_test, _, y_pred_proba, _ = \ + string_labeled_classification_problem pos_label = "cancer" assert clf.classes_[0] == pos_label @@ -877,10 +903,12 @@ def test_brier_score_loss_pos_label(fitted_clf_predictions): @pytest.mark.parametrize( "score_func", [f1_score, precision_score, recall_score, jaccard_score] ) -def test_non_symmetric_metric_pos_label(score_func, fitted_clf_predictions): +def test_non_symmetric_metric_pos_label( + score_func, string_labeled_classification_problem +): # check that _PredictScorer leads to the right score when `pos_label` is # provided. We check for all possible metric supported. - clf, X_test, y_test, y_pred, _, _ = fitted_clf_predictions + clf, X_test, y_test, y_pred, _, _ = string_labeled_classification_problem pos_label = "cancer" assert clf.classes_[0] == pos_label @@ -901,8 +929,9 @@ def test_non_symmetric_metric_pos_label(score_func, fitted_clf_predictions): average_precision_score, needs_threshold=True, pos_label="xxx" ), make_scorer(brier_score_loss, needs_proba=True, pos_label="xxx"), + make_scorer(f1_score, pos_label="xxx") ], - ids=["ThresholdScorer", "ProbaScorer"], + ids=["ThresholdScorer", "ProbaScorer", "PredictScorer"], ) def test_scorer_select_proba_error(scorer): # check that we raise the the proper error when passing an unknown @@ -912,6 +941,6 @@ def test_scorer_select_proba_error(scorer): ) lr = LogisticRegression(multi_class="multinomial").fit(X, y) - err_msg = "pos_label should be present in the target" + err_msg = "is not a valid label" with pytest.raises(ValueError, match=err_msg): scorer(lr, X, y) From e9d787308e97789f97fd9e4ff092545c0b6dbc55 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 18 Aug 2020 11:26:57 +0200 Subject: [PATCH 14/19] use list --- sklearn/metrics/_classification.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/metrics/_classification.py 
b/sklearn/metrics/_classification.py index a7bad09ed98d0..5db6a5798abd0 100644 --- a/sklearn/metrics/_classification.py +++ b/sklearn/metrics/_classification.py @@ -1252,7 +1252,7 @@ def _check_set_wise_labels(y_true, y_pred, average, labels, pos_label): str(average_options)) y_type, y_true, y_pred = _check_targets(y_true, y_pred) - present_labels = unique_labels(y_true, y_pred) + present_labels = unique_labels(y_true, y_pred).tolist() if average == 'binary': if y_type == 'binary': if pos_label not in present_labels: From 70257689cc84eeb61a5a39e7d8ce21443c7aacd3 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 19 Aug 2020 14:31:27 +0200 Subject: [PATCH 15/19] fix --- sklearn/metrics/tests/test_classification.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/sklearn/metrics/tests/test_classification.py b/sklearn/metrics/tests/test_classification.py index 6677f3119dacd..3b512e10a271d 100644 --- a/sklearn/metrics/tests/test_classification.py +++ b/sklearn/metrics/tests/test_classification.py @@ -4,7 +4,6 @@ from itertools import chain from itertools import permutations import warnings -import re import numpy as np from scipy import linalg @@ -1247,7 +1246,7 @@ def test_multilabel_hamming_loss(): def test_jaccard_score_validation(): y_true = np.array([0, 1, 0, 1, 1]) y_pred = np.array([0, 1, 0, 1, 1]) - err_msg = r"pos_label=2 is not a valid label: array\(\[0, 1\]\)" + err_msg = r"pos_label=2 is not a valid label: \[0, 1\]" with pytest.raises(ValueError, match=err_msg): jaccard_score(y_true, y_pred, average='binary', pos_label=2) @@ -2262,9 +2261,12 @@ def test_brier_score_loss(): # ensure to raise an error for multiclass y_true y_true = np.array([0, 1, 2, 0]) y_pred = np.array([0.8, 0.6, 0.4, 0.2]) - error_message = ("Only binary classification is supported. 
Labels " - "in y_true: {}".format(np.array([0, 1, 2]))) - with pytest.raises(ValueError, match=re.escape(error_message)): + error_message = ( + "Only binary classification is supported. The type of the target is " + "multiclass" + ) + + with pytest.raises(ValueError, match=error_message): brier_score_loss(y_true, y_pred) # calculate correctly when there's only one class in y_true From 536753f29e137d1742edca5f2902ae7ca24a4686 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 19 Aug 2020 15:09:58 +0200 Subject: [PATCH 16/19] fix --- sklearn/metrics/tests/test_classification.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/metrics/tests/test_classification.py b/sklearn/metrics/tests/test_classification.py index 3b512e10a271d..69fd423df21d8 100644 --- a/sklearn/metrics/tests/test_classification.py +++ b/sklearn/metrics/tests/test_classification.py @@ -2262,8 +2262,8 @@ def test_brier_score_loss(): y_true = np.array([0, 1, 2, 0]) y_pred = np.array([0.8, 0.6, 0.4, 0.2]) error_message = ( - "Only binary classification is supported. The type of the target is " - "multiclass" + "Only binary classification is supported. Labels in y_true: " + "\[0 1 2\]" ) with pytest.raises(ValueError, match=error_message): From 6a12a1ff396e7891763bfd9f1f50fd784d24900f Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 19 Aug 2020 15:11:36 +0200 Subject: [PATCH 17/19] PEP8 --- sklearn/metrics/tests/test_classification.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/metrics/tests/test_classification.py b/sklearn/metrics/tests/test_classification.py index 69fd423df21d8..e093c4107a5b0 100644 --- a/sklearn/metrics/tests/test_classification.py +++ b/sklearn/metrics/tests/test_classification.py @@ -2262,8 +2262,8 @@ def test_brier_score_loss(): y_true = np.array([0, 1, 2, 0]) y_pred = np.array([0.8, 0.6, 0.4, 0.2]) error_message = ( - "Only binary classification is supported. 
Labels in y_true: " - "\[0 1 2\]" + r"Only binary classification is supported. Labels in y_true: " + r"\[0 1 2\]" ) with pytest.raises(ValueError, match=error_message): From e08f679af0d9dd32fef66c203492956ae257ab05 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 20 Aug 2020 11:21:48 +0200 Subject: [PATCH 18/19] MNT refactor response method in scorer --- sklearn/metrics/_base.py | 143 ++++++++++++++++++ sklearn/metrics/_plot/base.py | 114 -------------- .../metrics/_plot/precision_recall_curve.py | 2 +- sklearn/metrics/_plot/roc_curve.py | 2 +- sklearn/metrics/_scorer.py | 143 ++++++++---------- 5 files changed, 212 insertions(+), 192 deletions(-) delete mode 100644 sklearn/metrics/_plot/base.py diff --git a/sklearn/metrics/_base.py b/sklearn/metrics/_base.py index 21d0ab38f6a91..89cf51c474099 100644 --- a/sklearn/metrics/_base.py +++ b/sklearn/metrics/_base.py @@ -16,6 +16,7 @@ import numpy as np +from ..base import is_classifier, is_regressor from ..utils import check_array, check_consistent_length from ..utils.multiclass import type_of_target @@ -200,3 +201,145 @@ def _average_multiclass_ovo_score(binary_metric, y_true, y_score, pair_scores[ix] = (a_true_score + b_true_score) / 2 return np.average(pair_scores, weights=prevalence) + + +def _check_classifier_response_method(estimator, response_method): + """Return prediction method from the response_method + + Parameters + ---------- + estimator: object + Classifier to check + + response_method: {'auto', 'predict_proba', 'decision_function', 'predict'} + Specifies whether to use :term:`predict_proba` or + :term:`decision_function` as the target response. If set to 'auto', + :term:`predict_proba` is tried first and if it does not exist + :term:`decision_function` is tried next and :term:`predict` last. 
+ + Returns + ------- + prediction_method: callable + prediction method of estimator + """ + + possible_response_methods = ( + "predict", "predict_proba", "decision_function", "auto" + ) + if response_method not in possible_response_methods: + raise ValueError( + f"response_method must be one of " + f"{','.join(possible_response_methods)}." + ) + + error_msg = "response method {} is not defined in {}" + if response_method != "auto": + prediction_method = getattr(estimator, response_method, None) + if prediction_method is None: + raise ValueError( + error_msg.format(response_method, estimator.__class__.__name__) + ) + else: + predict_proba = getattr(estimator, 'predict_proba', None) + decision_function = getattr(estimator, 'decision_function', None) + predict = getattr(estimator, 'predict', None) + prediction_method = predict_proba or decision_function or predict + if prediction_method is None: + raise ValueError( + error_msg.format( + "decision_function, predict_proba or predict", + estimator.__class__.__name__ + ) + ) + + return prediction_method + + +def _get_response( + estimator, + X, + y_true, + response_method, + pos_label=None, + support_multi_class=False, +): + """Return response and positive label. + + Parameters + ---------- + estimator : estimator instance + Fitted classifier or a fitted :class:`~sklearn.pipeline.Pipeline` + in which the last estimator is a classifier. + + X : {array-like, sparse matrix} of shape (n_samples, n_features) + Input values. + + y_true : array-like of shape (n_samples,) + The true label. + + response_method: {'auto', 'predict_proba', 'decision_function', 'predict'} + Specifies whether to use :term:`predict_proba` or + :term:`decision_function` as the target response. If set to 'auto', + :term:`predict_proba` is tried first and if it does not exist + :term:`decision_function` is tried next and :term:`predict` last. 
+ + pos_label : str or int, default=None + The class considered as the positive class when computing + the metrics. By default, `estimator.classes_[1]` is + considered as the positive class. + + support_multi_class : bool, default=False + Allow a binary `y_true` with more than two `predict_proba` columns. + + Returns + ------- + y_pred : ndarray of shape (n_samples,) + Target scores calculated from the provided response_method + and pos_label. + + pos_label : str or int + The class considered as the positive class when computing + the metrics. + """ + if is_regressor(estimator): + if response_method not in ("predict", "auto"): + raise ValueError( + f"{estimator.__class__.__name__} should be a classifier" + ) + return estimator.predict(X), None + else: + y_type = type_of_target(y_true) + classes = estimator.classes_ + prediction_method = _check_classifier_response_method( + estimator, response_method + ) + y_pred = prediction_method(X) + + if pos_label is not None and pos_label not in classes: + raise ValueError( + f"The class provided by 'pos_label' is unknown. Got " + f"{pos_label} instead of one of {classes}." 
+ ) + + if prediction_method.__name__ == "predict_proba": + if y_type == "binary": + pos_label = pos_label if pos_label is not None else classes[-1] + if y_pred.shape[1] == 2: + col_idx = np.flatnonzero(classes == pos_label)[0] + y_pred = y_pred[:, col_idx] + else: + err_msg = ( + f"Got predict_proba of shape {y_pred.shape}, but need " + f"classifier with two classes" + ) + if support_multi_class and y_pred.shape[1] == 1: + raise ValueError(err_msg) + elif not support_multi_class: + raise ValueError(err_msg) + elif prediction_method.__name__ == "decision_function": + if y_type == "binary": + pos_label = pos_label if pos_label is not None else classes[-1] + if pos_label == classes[0]: + y_pred *= -1 + + return y_pred, pos_label diff --git a/sklearn/metrics/_plot/base.py b/sklearn/metrics/_plot/base.py deleted file mode 100644 index 0e44a7715a1ed..0000000000000 --- a/sklearn/metrics/_plot/base.py +++ /dev/null @@ -1,114 +0,0 @@ -import numpy as np - -from sklearn.base import is_classifier - - -def _check_classifier_response_method(estimator, response_method): - """Return prediction method from the response_method - - Parameters - ---------- - estimator: object - Classifier to check - - response_method: {'auto', 'predict_proba', 'decision_function'} - Specifies whether to use :term:`predict_proba` or - :term:`decision_function` as the target response. If set to 'auto', - :term:`predict_proba` is tried first and if it does not exist - :term:`decision_function` is tried next. 
- - Returns - ------- - prediction_method: callable - prediction method of estimator - """ - - if response_method not in ("predict_proba", "decision_function", "auto"): - raise ValueError("response_method must be 'predict_proba', " - "'decision_function' or 'auto'") - - error_msg = "response method {} is not defined in {}" - if response_method != "auto": - prediction_method = getattr(estimator, response_method, None) - if prediction_method is None: - raise ValueError(error_msg.format(response_method, - estimator.__class__.__name__)) - else: - predict_proba = getattr(estimator, 'predict_proba', None) - decision_function = getattr(estimator, 'decision_function', None) - prediction_method = predict_proba or decision_function - if prediction_method is None: - raise ValueError(error_msg.format( - "decision_function or predict_proba", - estimator.__class__.__name__)) - - return prediction_method - - -def _get_response(X, estimator, response_method, pos_label=None): - """Return response and positive label. - - Parameters - ---------- - X : {array-like, sparse matrix} of shape (n_samples, n_features) - Input values. - - estimator : estimator instance - Fitted classifier or a fitted :class:`~sklearn.pipeline.Pipeline` - in which the last estimator is a classifier. - - response_method: {'auto', 'predict_proba', 'decision_function'} - Specifies whether to use :term:`predict_proba` or - :term:`decision_function` as the target response. If set to 'auto', - :term:`predict_proba` is tried first and if it does not exist - :term:`decision_function` is tried next. - - pos_label : str or int, default=None - The class considered as the positive class when computing - the metrics. By default, `estimators.classes_[1]` is - considered as the positive class. - - Returns - ------- - y_pred: ndarray of shape (n_samples,) - Target scores calculated from the provided response_method - and pos_label. 
- - pos_label: str or int - The class considered as the positive class when computing - the metrics. - """ - classification_error = ( - "{} should be a binary classifier".format(estimator.__class__.__name__) - ) - - if not is_classifier(estimator): - raise ValueError(classification_error) - - prediction_method = _check_classifier_response_method( - estimator, response_method) - - y_pred = prediction_method(X) - - if pos_label is not None and pos_label not in estimator.classes_: - raise ValueError( - f"The class provided by 'pos_label' is unknown. Got " - f"{pos_label} instead of one of {estimator.classes_}" - ) - - if y_pred.ndim != 1: # `predict_proba` - if y_pred.shape[1] != 2: - raise ValueError(classification_error) - if pos_label is None: - pos_label = estimator.classes_[1] - y_pred = y_pred[:, 1] - else: - class_idx = np.flatnonzero(estimator.classes_ == pos_label) - y_pred = y_pred[:, class_idx] - else: - if pos_label is None: - pos_label = estimator.classes_[1] - elif pos_label == estimator.classes_[0]: - y_pred *= -1 - - return y_pred, pos_label diff --git a/sklearn/metrics/_plot/precision_recall_curve.py b/sklearn/metrics/_plot/precision_recall_curve.py index b355283407802..d75558437dee6 100644 --- a/sklearn/metrics/_plot/precision_recall_curve.py +++ b/sklearn/metrics/_plot/precision_recall_curve.py @@ -1,4 +1,4 @@ -from .base import _get_response +from .._base import _get_response from .. import average_precision_score from .. import precision_recall_curve diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index 6246209912694..51ee896c6ca0d 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -1,4 +1,4 @@ -from .base import _get_response +from .._base import _get_response from .. import auc from .. 
import roc_curve diff --git a/sklearn/metrics/_scorer.py b/sklearn/metrics/_scorer.py index 796444f612a6d..9a1903f86d38d 100644 --- a/sklearn/metrics/_scorer.py +++ b/sklearn/metrics/_scorer.py @@ -21,9 +21,11 @@ from collections.abc import Iterable from functools import partial from collections import Counter +from inspect import signature import numpy as np +from ._base import _get_response from . import (r2_score, median_absolute_error, max_error, mean_absolute_error, mean_squared_error, mean_squared_log_error, mean_poisson_deviance, mean_gamma_deviance, accuracy_score, @@ -46,16 +48,17 @@ from ..base import is_regressor -def _cached_call(cache, estimator, method, *args, **kwargs): +def _cached_call(cache, estimator, *args, **kwargs): """Call estimator with method and args and kwargs.""" + response_method = kwargs["response_method"] if cache is None: - return getattr(estimator, method)(*args, **kwargs) + return _get_response(estimator, *args, **kwargs) try: - return cache[method] + return cache[response_method] except KeyError: - result = getattr(estimator, method)(*args, **kwargs) - cache[method] = result + result = _get_response(estimator, *args, **kwargs) + cache[response_method] = result return result @@ -83,8 +86,9 @@ def __call__(self, estimator, *args, **kwargs): for name, scorer in self._scorers.items(): if isinstance(scorer, _BaseScorer): - score = scorer._score(cached_call, estimator, - *args, **kwargs) + score = scorer._score( + cached_call, estimator, *args, **kwargs + ) else: score = scorer(estimator, *args, **kwargs) scores[name] = score @@ -127,35 +131,15 @@ def __init__(self, score_func, sign, kwargs): self._score_func = score_func self._sign = sign - @staticmethod - def _check_pos_label(pos_label, classes): - if pos_label not in list(classes): - raise ValueError( - f"pos_label={pos_label} is not a valid label: {classes}" - ) - - def _select_proba(self, y_pred, classes, support_multi_class): - """Select the column of y_pred when probabilities are 
provided.""" - if y_pred.shape[1] == 2: - pos_label = self._kwargs.get("pos_label", classes[1]) - self._check_pos_label(pos_label, classes) - col_idx = np.flatnonzero(classes == pos_label)[0] - y_pred = y_pred[:, col_idx] + def _get_pos_label(self): + score_func_params = signature(self._score_func).parameters + if "pos_label" in self._kwargs: + pos_label = self._kwargs["pos_label"] + elif "pos_label" in score_func_params: + pos_label = score_func_params["pos_label"].default else: - err_msg = ( - f"Got predict_proba of shape {y_pred.shape}, but need " - f"classifier with two classes for {self._score_func.__name__} " - f"scoring" - ) - if support_multi_class and y_pred.shape[1] == 1: - # In _ProbaScorer, y_true can be tagged as binary while the - # y_pred is multi_class. This case is supported when label is - # provided. - raise ValueError(err_msg) - elif not support_multi_class: - raise ValueError(err_msg) - - return y_pred + pos_label = None + return pos_label def __repr__(self): kwargs_string = "".join([", %s=%s" % (str(k), str(v)) @@ -188,8 +172,13 @@ def __call__(self, estimator, X, y_true, sample_weight=None): score : float Score function applied to prediction of estimator on X. """ - return self._score(partial(_cached_call, None), estimator, X, y_true, - sample_weight=sample_weight) + return self._score( + partial(_cached_call, None), + estimator, + X, + y_true, + sample_weight=sample_weight + ) def _factory_args(self): """Return non-default make_scorer arguments for repr.""" @@ -224,15 +213,15 @@ def _score(self, method_caller, estimator, X, y_true, sample_weight=None): score : float Score function applied to prediction of estimator on X. 
""" - - y_pred = method_caller(estimator, "predict", X) + y_pred, _ = method_caller(estimator, X, y_true, response_method="predict") if sample_weight is not None: - return self._sign * self._score_func(y_true, y_pred, - sample_weight=sample_weight, - **self._kwargs) + return self._sign * self._score_func( + y_true, y_pred, sample_weight=sample_weight, **self._kwargs + ) else: - return self._sign * self._score_func(y_true, y_pred, - **self._kwargs) + return self._sign * self._score_func( + y_true, y_pred, **self._kwargs + ) class _ProbaScorer(_BaseScorer): @@ -264,17 +253,19 @@ def _score(self, method_caller, clf, X, y, sample_weight=None): score : float Score function applied to prediction of estimator on X. """ + y_pred, _ = method_caller( + clf, + X, + y, + response_method="predict_proba", + pos_label=self._get_pos_label(), + support_multi_class=True, + ) - y_type = type_of_target(y) - y_pred = method_caller(clf, "predict_proba", X) - if y_type == "binary": - y_pred = self._select_proba( - y_pred, clf.classes_, support_multi_class=True - ) if sample_weight is not None: - return self._sign * self._score_func(y, y_pred, - sample_weight=sample_weight, - **self._kwargs) + return self._sign * self._score_func( + y, y_pred, sample_weight=sample_weight, **self._kwargs + ) else: return self._sign * self._score_func(y, y_pred, **self._kwargs) @@ -319,38 +310,38 @@ def _score(self, method_caller, clf, X, y, sample_weight=None): raise ValueError("{0} format is not supported".format(y_type)) if is_regressor(clf): - y_pred = method_caller(clf, "predict", X) + y_pred, _ = method_caller(clf, X, y, response_method="predict") else: + pos_label = self._get_pos_label() try: - y_pred = method_caller(clf, "decision_function", X) + y_pred, _ = method_caller( + clf, + X, + y, + response_method="decision_function", + pos_label=pos_label, + ) if isinstance(y_pred, list): # For multi-output multi-class estimator y_pred = np.vstack([p for p in y_pred]).T - elif y_type == "binary" and 
"pos_label" in self._kwargs: - self._check_pos_label( - self._kwargs["pos_label"], clf.classes_ - ) - if self._kwargs["pos_label"] == clf.classes_[0]: - # The implicit positive class of the binary classifier - # does not match `pos_label`: we need to invert the - # predictions - y_pred *= -1 - - except (NotImplementedError, AttributeError): - y_pred = method_caller(clf, "predict_proba", X) - - if y_type == "binary": - y_pred = self._select_proba( - y_pred, clf.classes_, support_multi_class=False, - ) - elif isinstance(y_pred, list): + + except (NotImplementedError, AttributeError, ValueError): + y_pred, _ = method_caller( + clf, + X, + y, + response_method="predict_proba", + pos_label=pos_label, + ) + + if isinstance(y_pred, list): y_pred = np.vstack([p[:, -1] for p in y_pred]).T if sample_weight is not None: - return self._sign * self._score_func(y, y_pred, - sample_weight=sample_weight, - **self._kwargs) + return self._sign * self._score_func( + y, y_pred, sample_weight=sample_weight, **self._kwargs + ) else: return self._sign * self._score_func(y, y_pred, **self._kwargs) From 0b8fd88156e41434e4cf1288e685d2d696711b54 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 20 Aug 2020 11:59:58 +0200 Subject: [PATCH 19/19] iter --- sklearn/metrics/_base.py | 29 +++++++++++---------- sklearn/metrics/_scorer.py | 4 ++- sklearn/metrics/tests/test_score_objects.py | 16 +++++++++--- 3 files changed, 30 insertions(+), 19 deletions(-) diff --git a/sklearn/metrics/_base.py b/sklearn/metrics/_base.py index 89cf51c474099..ae3f4140f4974 100644 --- a/sklearn/metrics/_base.py +++ b/sklearn/metrics/_base.py @@ -16,8 +16,9 @@ import numpy as np -from ..base import is_classifier, is_regressor -from ..utils import check_array, check_consistent_length +from ..base import is_classifier +from ..utils import check_array +from ..utils import check_consistent_length from ..utils.multiclass import type_of_target @@ -208,10 +209,10 @@ def _check_classifier_response_method(estimator, 
response_method): Parameters ---------- - estimator: object - Classifier to check + estimator : estimator instance + Classifier to check. - response_method: {'auto', 'predict_proba', 'decision_function', 'predict'} + response_method : {'auto', 'predict_proba', 'decision_function', 'predict'} Specifies whether to use :term:`predict_proba` or :term:`decision_function` as the target response. If set to 'auto', :term:`predict_proba` is tried first and if it does not exist @@ -219,8 +220,8 @@ def _check_classifier_response_method(estimator, response_method): Returns ------- - prediction_method: callable - prediction method of estimator + prediction_method : callable + Prediction method of estimator. """ possible_response_methods = ( @@ -301,13 +302,7 @@ def _get_response( The class considered as the positive class when computing the metrics. """ - if is_regressor(estimator): - if response_method not in ("predict", "auto"): - raise ValueError( - f"{estimator.__class__.__name__} should be a classifier" - ) - return estimator.predict(X), None - else: + if is_classifier(estimator): y_type = type_of_target(y_true) classes = estimator.classes_ prediction_method = _check_classifier_response_method( @@ -341,5 +336,11 @@ def _get_response( pos_label = pos_label if pos_label is not None else classes[-1] if pos_label == classes[0]: y_pred *= -1 + else: + if response_method not in ("predict", "auto"): + raise ValueError( + f"{estimator.__class__.__name__} should be a classifier" + ) + y_pred, pos_label = estimator.predict(X), None return y_pred, pos_label diff --git a/sklearn/metrics/_scorer.py b/sklearn/metrics/_scorer.py index 9a1903f86d38d..ffbc4bb0e5f8c 100644 --- a/sklearn/metrics/_scorer.py +++ b/sklearn/metrics/_scorer.py @@ -213,7 +213,9 @@ def _score(self, method_caller, estimator, X, y_true, sample_weight=None): score : float Score function applied to prediction of estimator on X. 
""" - y_pred, _ = method_caller(estimator, X, y_true, response_method="predict") + y_pred, _ = method_caller( + estimator, X, y_true, response_method="predict" + ) if sample_weight is not None: return self._sign * self._score_func( y_true, y_pred, sample_weight=sample_weight, **self._kwargs diff --git a/sklearn/metrics/tests/test_score_objects.py b/sklearn/metrics/tests/test_score_objects.py index 52bafb160bfdb..97c35bf48db7c 100644 --- a/sklearn/metrics/tests/test_score_objects.py +++ b/sklearn/metrics/tests/test_score_objects.py @@ -147,12 +147,14 @@ class EstimatorWithoutFit: class EstimatorWithFit(BaseEstimator): """Dummy estimator to test scoring validators""" def fit(self, X, y): + self.classes_ = np.unique(y) return self class EstimatorWithFitAndScore: """Dummy estimator to test scoring validators""" def fit(self, X, y): + self.classes_ = np.unique(y) return self def score(self, X, y): @@ -162,6 +164,7 @@ def score(self, X, y): class EstimatorWithFitAndPredict: """Dummy estimator to test scoring validators""" def fit(self, X, y): + self.classes_ = np.unique(y) self.y = y return self @@ -616,13 +619,18 @@ def test_multimetric_scorer_calls_method_once(scorers, expected_predict_count, X, y = np.array([[1], [1], [0], [0], [0]]), np.array([0, 1, 1, 1, 0]) mock_est = Mock() + mock_est._estimator_type = "classifier" fit_func = Mock(return_value=mock_est) + fit_func.__name__ = "fit" predict_func = Mock(return_value=y) + predict_func.__name__ = "predict" pos_proba = np.random.rand(X.shape[0]) proba = np.c_[1 - pos_proba, pos_proba] predict_proba_func = Mock(return_value=proba) + predict_proba_func.__name__ = "predict_proba" decision_function_func = Mock(return_value=pos_proba) + decision_function_func.__name__ = "decision_function" mock_est.fit = fit_func mock_est.predict = predict_func @@ -756,8 +764,8 @@ def test_multiclass_roc_no_proba_scorer_errors(scorer_name): X, y = make_classification(n_classes=3, n_informative=3, n_samples=20, random_state=0) lr = 
Perceptron().fit(X, y) - msg = "'Perceptron' object has no attribute 'predict_proba'" - with pytest.raises(AttributeError, match=msg): + msg = "response method predict_proba is not defined in Perceptron" + with pytest.raises(ValueError, match=msg): scorer(lr, X, y) @@ -842,7 +850,7 @@ def test_average_precision_pos_label(string_labeled_classification_problem): average_precision_scorer = make_scorer( average_precision_score, needs_threshold=True, ) - err_msg = "pos_label=1 is invalid. Set it to a label in y_true." + err_msg = "pos_label" with pytest.raises(ValueError, match=err_msg): average_precision_scorer(clf, X_test, y_test) @@ -941,6 +949,6 @@ def test_scorer_select_proba_error(scorer): ) lr = LogisticRegression(multi_class="multinomial").fit(X, y) - err_msg = "is not a valid label" + err_msg = "pos_label" with pytest.raises(ValueError, match=err_msg): scorer(lr, X, y)