From 51a3bcb55bef8b4dbb7210304c637b504fe0a8d7 Mon Sep 17 00:00:00 2001
From: Guillaume Lemaitre
Date: Wed, 3 Nov 2021 11:47:55 +0100
Subject: [PATCH 01/12] MNT refactor _get_response_values [ci skip]

---
 sklearn/metrics/_plot/base.py          | 252 ++++++++++++++-----------
 sklearn/utils/__init__.py              | 110 +++++++++++
 sklearn/utils/_mocking.py              |  43 +++++
 sklearn/utils/tests/test_mocking.py    |  30 ++-
 sklearn/utils/tests/test_utils.py      | 145 +++++++++++++-
 sklearn/utils/tests/test_validation.py |  82 +++++++-
 sklearn/utils/validation.py            |  54 +++++-
 7 files changed, 607 insertions(+), 109 deletions(-)

diff --git a/sklearn/metrics/_plot/base.py b/sklearn/metrics/_plot/base.py
index 60377e3b10f66..3cbec426242a7 100644
--- a/sklearn/metrics/_plot/base.py
+++ b/sklearn/metrics/_plot/base.py
@@ -1,116 +1,158 @@
 from ...base import is_classifier
+from ...exceptions import NotFittedError
+from ...utils.multiclass import type_of_target
+from ...utils.validation import check_is_fitted


-def _check_classifier_response_method(estimator, response_method):
-    """Return prediction method from the response_method
+def _check_estimator_target(estimator, y):
+    """Helper to check that estimator is a binary classifier and y is binary.

-    Parameters
-    ----------
-    estimator: object
-        Classifier to check
+    This function is kept outside of the `BinaryClassifierCurveDisplayMixin`
+    class below so that the displays and the plotting functions raise
+    consistent error messages.

-    response_method: {'auto', 'predict_proba', 'decision_function'}
-        Specifies whether to use :term:`predict_proba` or
-        :term:`decision_function` as the target response. If set to 'auto',
-        :term:`predict_proba` is tried first and if it does not exist
-        :term:`decision_function` is tried next.
-
-    Returns
-    -------
-    prediction_method: callable
-        prediction method of estimator
+    FIXME: Move into `BinaryClassifierCurveDisplayMixin.from_estimator` when
+    the plotting functions are removed in 1.2.
     """
+    try:
+        check_is_fitted(estimator)
+    except NotFittedError as e:
+        raise NotFittedError(
+            f"This {estimator.__class__.__name__} instance is not fitted yet. Call "
+            "'fit' with appropriate arguments before using it with the plotting "
+            "functionalities."
+        ) from e

-    if response_method not in ("predict_proba", "decision_function", "auto"):
+    if not is_classifier(estimator):
+        raise ValueError(
+            "These plotting functionalities only support a binary classifier. "
+            f"Got a {estimator.__class__.__name__} instead."
+        )
+    elif len(estimator.classes_) != 2:
         raise ValueError(
-            "response_method must be 'predict_proba', 'decision_function' or 'auto'"
+            f"This {estimator.__class__.__name__} instance is not a binary "
+            "classifier. It was fitted on a multiclass problem with "
+            f"{len(estimator.classes_)} classes."
+        )
+    elif type_of_target(y) != "binary":
+        raise ValueError(
+            f"The target y is not binary. Got {type_of_target(y)} type of target."
) - error_msg = "response method {} is not defined in {}" - if response_method != "auto": - prediction_method = getattr(estimator, response_method, None) - if prediction_method is None: - raise ValueError( - error_msg.format(response_method, estimator.__class__.__name__) - ) - else: - predict_proba = getattr(estimator, "predict_proba", None) - decision_function = getattr(estimator, "decision_function", None) - prediction_method = predict_proba or decision_function - if prediction_method is None: - raise ValueError( - error_msg.format( - "decision_function or predict_proba", estimator.__class__.__name__ - ) - ) - - return prediction_method - - -def _get_response(X, estimator, response_method, pos_label=None): - """Return response and positive label. - - Parameters - ---------- - X : {array-like, sparse matrix} of shape (n_samples, n_features) - Input values. - - estimator : estimator instance - Fitted classifier or a fitted :class:`~sklearn.pipeline.Pipeline` - in which the last estimator is a classifier. - - response_method: {'auto', 'predict_proba', 'decision_function'} - Specifies whether to use :term:`predict_proba` or - :term:`decision_function` as the target response. If set to 'auto', - :term:`predict_proba` is tried first and if it does not exist - :term:`decision_function` is tried next. - - pos_label : str or int, default=None - The class considered as the positive class when computing - the metrics. By default, `estimators.classes_[1]` is - considered as the positive class. - - Returns - ------- - y_pred: ndarray of shape (n_samples,) - Target scores calculated from the provided response_method - and pos_label. - - pos_label: str or int - The class considered as the positive class when computing - the metrics. - """ - classification_error = ( - "Expected 'estimator' to be a binary classifier, but got" - f" {estimator.__class__.__name__}" - ) - if not is_classifier(estimator): - raise ValueError(classification_error) - - prediction_method = _check_classifier_response_method(estimator, response_method) - y_pred = prediction_method(X) - if pos_label is not None: - try: - class_idx = estimator.classes_.tolist().index(pos_label) - except ValueError as e: - raise ValueError( - "The class provided by 'pos_label' is unknown. Got " - f"{pos_label} instead of one of {set(estimator.classes_)}" - ) from e - else: - class_idx = 1 - pos_label = estimator.classes_[class_idx] - - if y_pred.ndim != 1: # `predict_proba` - y_pred_shape = y_pred.shape[1] - if y_pred_shape != 2: - raise ValueError( - f"{classification_error} fit on multiclass ({y_pred_shape} classes)" - " data" - ) - y_pred = y_pred[:, class_idx] - elif pos_label == estimator.classes_[0]: # `decision_function` - y_pred *= -1 - - return y_pred, pos_label +# from ...base import is_classifier + + +# def _check_classifier_response_method(estimator, response_method): +# """Return prediction method from the response_method + +# Parameters +# ---------- +# estimator: object +# Classifier to check + +# response_method: {'auto', 'predict_proba', 'decision_function'} +# Specifies whether to use :term:`predict_proba` or +# :term:`decision_function` as the target response. If set to 'auto', +# :term:`predict_proba` is tried first and if it does not exist +# :term:`decision_function` is tried next. 
+ +# Returns +# ------- +# prediction_method: callable +# prediction method of estimator +# """ + +# if response_method not in ("predict_proba", "decision_function", "auto"): +# raise ValueError( +# "response_method must be 'predict_proba', 'decision_function' or 'auto'" +# ) + +# error_msg = "response method {} is not defined in {}" +# if response_method != "auto": +# prediction_method = getattr(estimator, response_method, None) +# if prediction_method is None: +# raise ValueError( +# error_msg.format(response_method, estimator.__class__.__name__) +# ) +# else: +# predict_proba = getattr(estimator, "predict_proba", None) +# decision_function = getattr(estimator, "decision_function", None) +# prediction_method = predict_proba or decision_function +# if prediction_method is None: +# raise ValueError( +# error_msg.format( +# "decision_function or predict_proba", estimator.__class__.__name__ +# ) +# ) + +# return prediction_method + + +# def _get_response(X, estimator, response_method, pos_label=None): +# """Return response and positive label. + +# Parameters +# ---------- +# X : {array-like, sparse matrix} of shape (n_samples, n_features) +# Input values. + +# estimator : estimator instance +# Fitted classifier or a fitted :class:`~sklearn.pipeline.Pipeline` +# in which the last estimator is a classifier. + +# response_method: {'auto', 'predict_proba', 'decision_function'} +# Specifies whether to use :term:`predict_proba` or +# :term:`decision_function` as the target response. If set to 'auto', +# :term:`predict_proba` is tried first and if it does not exist +# :term:`decision_function` is tried next. + +# pos_label : str or int, default=None +# The class considered as the positive class when computing +# the metrics. By default, `estimators.classes_[1]` is +# considered as the positive class. + +# Returns +# ------- +# y_pred: ndarray of shape (n_samples,) +# Target scores calculated from the provided response_method +# and pos_label. + +# pos_label: str or int +# The class considered as the positive class when computing +# the metrics. +# """ +# classification_error = ( +# "Expected 'estimator' to be a binary classifier, but got" +# f" {estimator.__class__.__name__}" +# ) + +# if not is_classifier(estimator): +# raise ValueError(classification_error) + +# prediction_method = _check_classifier_response_method(estimator, response_method) +# y_pred = prediction_method(X) +# if pos_label is not None: +# try: +# class_idx = estimator.classes_.tolist().index(pos_label) +# except ValueError as e: +# raise ValueError( +# "The class provided by 'pos_label' is unknown. 
Got " +# f"{pos_label} instead of one of {set(estimator.classes_)}" +# ) from e +# else: +# class_idx = 1 +# pos_label = estimator.classes_[class_idx] + +# if y_pred.ndim != 1: # `predict_proba` +# y_pred_shape = y_pred.shape[1] +# if y_pred_shape != 2: +# raise ValueError( +# f"{classification_error} fit on multiclass ({y_pred_shape} classes)" +# " data" +# ) +# y_pred = y_pred[:, class_idx] +# elif pos_label == estimator.classes_[0]: # `decision_function` +# y_pred *= -1 + +# return y_pred, pos_label diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index 8290318d35deb..48b152f67053d 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -27,10 +27,12 @@ from .deprecation import deprecated from .fixes import np_version, parse_version from ._estimator_html_repr import estimator_html_repr +from .multiclass import type_of_target from .validation import ( as_float_array, assert_all_finite, check_random_state, + _check_response_method, column_or_1d, check_array, check_consistent_length, @@ -1244,3 +1246,111 @@ def is_abstract(c): # itemgetter is used to ensure the sort does not extend to the 2nd item of # the tuple return sorted(set(estimators), key=itemgetter(0)) + + +def _get_response_values( + estimator, + X, + y_true, + response_method=None, + pos_label=None, +): + """Compute the response values of a classifier or a regressor. + + The response values are predictions, one scalar value for each sample in X + that depends on the specific choice of `response_method`. + + This helper only accepts multiclass classifiers with the `predict` response + method. + + If `estimator` is a binary classifier, also return the label for the + effective positive class. + + .. versionadded:: 1.1 + + Parameters + ---------- + estimator : estimator instance + Fitted classifier or regressor or a fitted :class:`~sklearn.pipeline.Pipeline` + in which the last estimator is a classifier. + + X : {array-like, sparse matrix} of shape (n_samples, n_features) + Input values. + + y_true : array-like of shape (n_samples,) + The true label. + + response_method : {"predict_proba", "decision_function", "predict"} or \ + list of such str, default=None + Specifies the response method to use get prediction from an estimator + (i.e. :term:`predict_proba`, :term:`decision_function` or + :term:`predict`). Possible choices are: + + - if `str`, it corresponds to the name to the method to return; + - if a list of `str`, it provides the method names in order of + preference. The method returned corresponds to the first method in + the list and which is implemented by `estimator`; + - if `None`, :term:`predict_proba` is tried first and if it does not + exist :term:`decision_function` is tried next and :term:`predict` + last. + + pos_label : str or int, default=None + The class considered as the positive class when computing + the metrics. By default, `estimators.classes_[1]` is + considered as the positive class. + + Returns + ------- + y_pred : ndarray of shape (n_samples,) + Target scores calculated from the provided response_method + and `pos_label`. + + pos_label : str, int or None + The class considered as the positive class when computing + the metrics. Returns `None` if `estimator` is a regressor. + + Raises + ------ + ValueError + If `pos_label` is not a valid label. + If the shape of `y_pred` is not consistent for binary classifier. + If the response method can be applied to a classifier only and + `estimator` is a regressor. 
+ """ + from sklearn.base import is_classifier # noqa + + if is_classifier(estimator): + y_type = type_of_target(y_true) + prediction_method = _check_response_method(estimator, response_method) + y_pred = prediction_method(X) + classes = estimator.classes_ + + if pos_label is not None and pos_label not in classes.tolist(): + raise ValueError( + f"pos_label={pos_label} is not a valid label: It should be " + f"one of {classes}" + ) + elif pos_label is None and y_type == "binary": + pos_label = pos_label if pos_label is not None else classes[-1] + + if prediction_method.__name__ == "predict_proba": + if y_type == "binary" and y_pred.shape[1] <= 2: + if y_pred.shape[1] == 2: + col_idx = np.flatnonzero(classes == pos_label)[0] + y_pred = y_pred[:, col_idx] + else: + err_msg = ( + f"Got predict_proba of shape {y_pred.shape}, but need " + "classifier with two classes." + ) + raise ValueError(err_msg) + elif prediction_method.__name__ == "decision_function": + if y_type == "binary": + if pos_label == classes[0]: + y_pred *= -1 + else: + if response_method not in ("predict", None): + raise ValueError(f"{estimator.__class__.__name__} should be a classifier") + y_pred, pos_label = estimator.predict(X), None + + return y_pred, pos_label diff --git a/sklearn/utils/_mocking.py b/sklearn/utils/_mocking.py index 33a73f77d2d47..a0084fc37c2ba 100644 --- a/sklearn/utils/_mocking.py +++ b/sklearn/utils/_mocking.py @@ -1,6 +1,7 @@ import numpy as np from ..base import BaseEstimator, ClassifierMixin +from .metaestimators import available_if from .validation import _num_samples, check_array, check_is_fitted @@ -331,3 +332,45 @@ def predict_proba(self, X): def _more_tags(self): return {"_skip_test": True} + + +def _check_response(method): + def check(self): + if self.response_methods is not None and method in self.response_methods: + return True + return False + + return check + + +class _MockEstimatorOnOffPrediction(BaseEstimator): + """Estimator for which we can turn on/off the prediction methods. + Parameters + ---------- + response_methods: list of \ + {"predict", "predict_proba", "decision_function"}, default=None + List containing the response implemented by the estimator. When, the + response is in the list, it will return the name of the response method + when called. Otherwise, an `AttributeError` is raised. It allows to + use `getattr` as any conventional estimator. By default, no response + methods are mocked. 
+ """ + + def __init__(self, response_methods=None): + self.response_methods = response_methods + + def fit(self, X, y): + self.classes_ = np.unique(y) + return self + + @available_if(_check_response("predict")) + def predict(self, X): + return "predict" + + @available_if(_check_response("predict_proba")) + def predict_proba(self, X): + return "predict_proba" + + @available_if(_check_response("decision_function")) + def decision_function(self, X): + return "decision_function" diff --git a/sklearn/utils/tests/test_mocking.py b/sklearn/utils/tests/test_mocking.py index 0aeeeaa572460..c1ceb55d7c21e 100644 --- a/sklearn/utils/tests/test_mocking.py +++ b/sklearn/utils/tests/test_mocking.py @@ -10,7 +10,10 @@ from sklearn.utils import _safe_indexing from sklearn.utils._testing import _convert_container -from sklearn.utils._mocking import CheckingClassifier +from sklearn.utils._mocking import ( + _MockEstimatorOnOffPrediction, + CheckingClassifier, +) @pytest.fixture @@ -178,3 +181,28 @@ def test_checking_classifier_methods_to_check(iris, methods_to_check, predict_me getattr(clf, predict_method)(X) else: getattr(clf, predict_method)(X) + + +@pytest.mark.parametrize( + "response_methods", + [ + ["predict"], + ["predict", "predict_proba"], + ["predict", "decision_function"], + ["predict", "predict_proba", "decision_function"], + ], +) +def test_mock_estimator_on_off_prediction(iris, response_methods): + X, y = iris + estimator = _MockEstimatorOnOffPrediction(response_methods=response_methods) + + estimator.fit(X, y) + assert hasattr(estimator, "classes_") + assert_array_equal(estimator.classes_, np.unique(y)) + + possible_responses = ["predict", "predict_proba", "decision_function"] + for response in possible_responses: + if response in response_methods: + assert hasattr(estimator, response) + else: + assert not hasattr(estimator, response) diff --git a/sklearn/utils/tests/test_utils.py b/sklearn/utils/tests/test_utils.py index 24638d26b6138..74eaa14394d93 100644 --- a/sklearn/utils/tests/test_utils.py +++ b/sklearn/utils/tests/test_utils.py @@ -8,7 +8,18 @@ import numpy as np import scipy.sparse as sp +from sklearn.datasets import ( + make_classification, + make_regression, +) +from sklearn.linear_model import ( + LinearRegression, + LogisticRegression, +) +from sklearn.tree import DecisionTreeClassifier + from sklearn.utils._testing import ( + assert_allclose, assert_array_equal, assert_allclose_dense_sparse, assert_no_warnings, @@ -30,8 +41,12 @@ from sklearn.utils import is_scalar_nan from sklearn.utils import _to_object_array from sklearn.utils import _approximate_mode +from sklearn.utils import _get_response_values from sklearn.utils.fixes import parse_version -from sklearn.utils._mocking import MockDataFrame +from sklearn.utils._mocking import ( + _MockEstimatorOnOffPrediction, + MockDataFrame, +) from sklearn.utils._testing import SkipTest from sklearn import config_context @@ -725,3 +740,131 @@ def test_to_object_array(sequence): assert isinstance(out, np.ndarray) assert out.dtype.kind == "O" assert out.ndim == 1 + + +@pytest.mark.parametrize("response_method", ["decision_function", "predict_proba"]) +def test_get_response_values_regressor_error(response_method): + """Check the error message with regressor an not supported response + method.""" + my_estimator = _MockEstimatorOnOffPrediction(response_methods=[response_method]) + X, y = "mocking_data", "mocking_target" + err_msg = f"{my_estimator.__class__.__name__} should be a classifier" + with pytest.raises(ValueError, match=err_msg): + 
+        _get_response_values(my_estimator, X, y, response_method=response_method)
+
+
+@pytest.mark.parametrize("response_method", ["predict", None])
+def test_get_response_values_regressor(response_method):
+    """Check the behaviour of `_get_response_values` with a regressor."""
+    X, y = make_regression(n_samples=10, random_state=0)
+    regressor = LinearRegression().fit(X, y)
+    y_pred, pos_label = _get_response_values(
+        regressor,
+        X,
+        y,
+        response_method=response_method,
+    )
+    assert_allclose(y_pred, regressor.predict(X))
+    assert pos_label is None
+
+
+@pytest.mark.parametrize(
+    "response_method",
+    [None, "predict_proba", "decision_function", "predict"],
+)
+def test_get_response_values_classifier_unknown_pos_label(response_method):
+    """Check that `_get_response_values` raises the proper error message with
+    a classifier."""
+    X, y = make_classification(n_samples=10, n_classes=2, random_state=0)
+    classifier = LogisticRegression().fit(X, y)
+
+    # provide a `pos_label` which is not in `y`
+    err_msg = r"pos_label=whatever is not a valid label: It should be one of \[0 1\]"
+    with pytest.raises(ValueError, match=err_msg):
+        _get_response_values(
+            classifier,
+            X,
+            y,
+            response_method=response_method,
+            pos_label="whatever",
+        )
+
+
+def test_get_response_values_classifier_inconsistent_y_pred_for_binary_proba():
+    """Check that `_get_response_values` will raise an error when `y_pred` has a
+    single class with `predict_proba`."""
+    X, y_two_class = make_classification(n_samples=10, n_classes=2, random_state=0)
+    y_single_class = np.zeros_like(y_two_class)
+    classifier = DecisionTreeClassifier().fit(X, y_single_class)
+
+    err_msg = (
+        r"Got predict_proba of shape \(10, 1\), but need classifier with "
+        r"two classes"
+    )
+    with pytest.raises(ValueError, match=err_msg):
+        _get_response_values(
+            classifier, X, y_two_class, response_method="predict_proba"
+        )
+
+
+def test_get_response_values_binary_classifier_decision_function():
+    """Check the behaviour of `_get_response_values` with `decision_function`
+    and a binary classifier.
+ """ + X, y = make_classification( + n_samples=10, + n_classes=2, + weights=[0.3, 0.7], + random_state=0, + ) + classifier = LogisticRegression().fit(X, y) + response_method = "decision_function" + + # default `pos_label` + y_pred, pos_label = _get_response_values( + classifier, X, y, response_method=response_method, pos_label=None + ) + assert_allclose(y_pred, classifier.decision_function(X)) + assert pos_label == 1 + + # when forcing `pos_label=classifier.classes_[0]` + y_pred, pos_label = _get_response_values( + classifier, + X, + y, + response_method=response_method, + pos_label=classifier.classes_[0], + ) + assert_allclose(y_pred, classifier.decision_function(X) * -1) + assert pos_label == 0 + + +def test_get_response_values_binary_classifier_predict_proba(): + """Check that `_get_response_values` with `predict_proba` and binary + classifier.""" + X, y = make_classification( + n_samples=10, + n_classes=2, + weights=[0.3, 0.7], + random_state=0, + ) + classifier = LogisticRegression().fit(X, y) + response_method = "predict_proba" + + # default `pos_label` + y_pred, pos_label = _get_response_values( + classifier, X, y, response_method=response_method, pos_label=None + ) + assert_allclose(y_pred, classifier.predict_proba(X)[:, 1]) + assert pos_label == 1 + + # when forcing `pos_label=classifier.classes_[0]` + y_pred, pos_label = _get_response_values( + classifier, + X, + y, + response_method=response_method, + pos_label=classifier.classes_[0], + ) + assert_allclose(y_pred, classifier.predict_proba(X)[:, 0]) + assert pos_label == 0 diff --git a/sklearn/utils/tests/test_validation.py b/sklearn/utils/tests/test_validation.py index 18f88373b02f3..155efb2461728 100644 --- a/sklearn/utils/tests/test_validation.py +++ b/sklearn/utils/tests/test_validation.py @@ -24,7 +24,13 @@ from sklearn.utils import as_float_array, check_array, check_symmetric from sklearn.utils import check_X_y from sklearn.utils import deprecated -from sklearn.utils._mocking import MockDataFrame + +# TODO: add this estimator into the _mocking module in a further refactoring +from sklearn.metrics.tests.test_score_objects import EstimatorWithFit +from sklearn.utils._mocking import ( + MockDataFrame, + _MockEstimatorOnOffPrediction, +) from sklearn.utils.fixes import parse_version from sklearn.utils.estimator_checks import _NotAnArray from sklearn.random_projection import _sparse_random_matrix @@ -53,6 +59,7 @@ _get_feature_names, _check_feature_names_in, _check_fit_params, + _check_response_method, ) from sklearn.base import BaseEstimator import sklearn @@ -1685,3 +1692,76 @@ def test_check_feature_names_in_pandas(): with pytest.raises(ValueError, match="input_features is not equal to"): est.get_feature_names_out(["x1", "x2", "x3"]) + + +def test_check_response_method_unknown_method(): + """Check the error message when passing an unknown response method.""" + err_msg = ( + "RandomForestRegressor has none of the following attributes: unknown_method." + ) + with pytest.raises(AttributeError, match=err_msg): + _check_response_method(RandomForestRegressor(), "unknown_method") + + +@pytest.mark.parametrize( + "response_method", ["decision_function", "predict_proba", "predict", None] +) +def test_check_response_method_not_supported_response_method(response_method): + """Check the error message when a response method is not supported by the + estimator.""" + err_msg = "EstimatorWithFit has none of the following attributes: {}." 
diff --git a/sklearn/utils/tests/test_validation.py b/sklearn/utils/tests/test_validation.py
index 18f88373b02f3..155efb2461728 100644
--- a/sklearn/utils/tests/test_validation.py
+++ b/sklearn/utils/tests/test_validation.py
@@ -24,7 +24,13 @@
 from sklearn.utils import as_float_array, check_array, check_symmetric
 from sklearn.utils import check_X_y
 from sklearn.utils import deprecated
-from sklearn.utils._mocking import MockDataFrame
+
+# TODO: add this estimator into the _mocking module in a further refactoring
+from sklearn.metrics.tests.test_score_objects import EstimatorWithFit
+from sklearn.utils._mocking import (
+    MockDataFrame,
+    _MockEstimatorOnOffPrediction,
+)
 from sklearn.utils.fixes import parse_version
 from sklearn.utils.estimator_checks import _NotAnArray
 from sklearn.random_projection import _sparse_random_matrix
@@ -53,6 +59,7 @@
     _get_feature_names,
     _check_feature_names_in,
     _check_fit_params,
+    _check_response_method,
 )
 from sklearn.base import BaseEstimator
 import sklearn
@@ -1685,3 +1692,76 @@ def test_check_feature_names_in_pandas():

     with pytest.raises(ValueError, match="input_features is not equal to"):
         est.get_feature_names_out(["x1", "x2", "x3"])
+
+
+def test_check_response_method_unknown_method():
+    """Check the error message when passing an unknown response method."""
+    err_msg = (
+        "RandomForestRegressor has none of the following attributes: unknown_method."
+    )
+    with pytest.raises(AttributeError, match=err_msg):
+        _check_response_method(RandomForestRegressor(), "unknown_method")
+
+
+@pytest.mark.parametrize(
+    "response_method", ["decision_function", "predict_proba", "predict", None]
+)
+def test_check_response_method_not_supported_response_method(response_method):
+    """Check the error message when a response method is not supported by the
+    estimator."""
+    err_msg = "EstimatorWithFit has none of the following attributes: {}."
+    if response_method is None:
+        err_msg = err_msg.format("predict_proba, decision_function, predict")
+    else:
+        err_msg = err_msg.format(response_method)
+    with pytest.raises(AttributeError, match=err_msg):
+        _check_response_method(EstimatorWithFit(), response_method)
+
+
+@pytest.mark.parametrize(
+    "response_methods, expected_method_name",
+    [
+        (["predict_proba", "decision_function", "predict"], "predict_proba"),
+        (["decision_function", "predict"], "decision_function"),
+        (["predict_proba", "predict"], "predict_proba"),
+        (["predict_proba"], "predict_proba"),
+        (["decision_function"], "decision_function"),
+        (["predict"], "predict"),
+    ],
+)
+def test_check_response_method_order_None(response_methods, expected_method_name):
+    """Check the order in which response methods are tried when passing None."""
+    my_estimator = _MockEstimatorOnOffPrediction(response_methods)
+
+    X = "mocking_data"
+    method_name_predicting = _check_response_method(my_estimator, None)(X)
+    assert method_name_predicting == expected_method_name
+
+
+def test_check_response_method_list_str():
+    """Check that we can pass an ordered list of methods."""
+    method_implemented = ["predict_proba"]
+    my_estimator = _MockEstimatorOnOffPrediction(method_implemented)
+
+    X = "mocking_data"
+
+    # raise an error when none of the methods are defined
+    response_method = ["decision_function", "predict"]
+    err_msg = (
+        "_MockEstimatorOnOffPrediction has none of the following attributes: "
+        f"{', '.join(response_method)}."
+    )
+    with pytest.raises(AttributeError, match=err_msg):
+        _check_response_method(my_estimator, response_method)(X)
+
+    # check that we don't get an issue when one of the methods is defined
+    response_method = ["decision_function", "predict_proba"]
+    method_name_predicting = _check_response_method(my_estimator, response_method)(X)
+    assert method_name_predicting == "predict_proba"
+
+    # check the order of the methods returned
+    method_implemented = ["predict_proba", "predict"]
+    my_estimator = _MockEstimatorOnOffPrediction(method_implemented)
+    response_method = ["decision_function", "predict", "predict_proba"]
+    method_name_predicting = _check_response_method(my_estimator, response_method)(X)
+    assert method_name_predicting == "predict"
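
For reference, a sketch of how `_check_response_method` (implemented in the
next hunk) resolves the method to call; the estimators below are illustrative,
not part of the patch:

    from sklearn.linear_model import LogisticRegression
    from sklearn.svm import LinearSVC
    from sklearn.utils.validation import _check_response_method

    X, y = [[0.0], [1.0], [2.0], [3.0]], [0, 0, 1, 1]

    # A list of names is an ordered preference: the first method implemented
    # by the estimator wins. LinearSVC has no predict_proba, so
    # decision_function is returned.
    method = _check_response_method(
        LinearSVC().fit(X, y), ["predict_proba", "decision_function"]
    )
    assert method.__name__ == "decision_function"

    # With response_method=None, the order is predict_proba,
    # decision_function, then predict.
    method = _check_response_method(LogisticRegression().fit(X, y), None)
    assert method.__name__ == "predict_proba"
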
diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py
index 4d36d78a2d458..17bca143c14fe 100644
--- a/sklearn/utils/validation.py
+++ b/sklearn/utils/validation.py
@@ -9,7 +9,7 @@
 #          Sylvain Marie
 # License: BSD 3 clause

-from functools import wraps
+from functools import reduce, wraps
 import warnings
 import numbers
 import operator
@@ -1679,6 +1679,58 @@ def _check_sample_weight(
     return sample_weight


+def _check_response_method(estimator, response_method=None):
+    """Check if `response_method` is available in estimator and return it.
+
+    .. versionadded:: 1.1
+
+    Parameters
+    ----------
+    estimator : estimator instance
+        Classifier or regressor to check.
+
+    response_method : {"predict_proba", "decision_function", "predict"} or \
+            list of such str, default=None
+        Specifies the response method to use to get predictions from an
+        estimator (i.e. :term:`predict_proba`, :term:`decision_function` or
+        :term:`predict`). Possible choices are:
+
+        - if `str`, it corresponds to the name of the method to return;
+        - if a list of `str`, it provides the method names in order of
+          preference. The method returned corresponds to the first method in
+          the list that is implemented by `estimator`;
+        - if `None`, :term:`predict_proba` is tried first and if it does not
+          exist, :term:`decision_function` is tried next and :term:`predict`
+          last.
+
+    Returns
+    -------
+    prediction_method : callable
+        Prediction method of estimator.
+
+    Raises
+    ------
+    AttributeError
+        If `response_method` is not available in `estimator`.
+    """
+    if response_method is None:
+        list_methods = ["predict_proba", "decision_function", "predict"]
+    elif isinstance(response_method, str):
+        list_methods = [response_method]
+    else:
+        list_methods = response_method
+
+    prediction_method = [getattr(estimator, method, None) for method in list_methods]
+    prediction_method = reduce(lambda x, y: x or y, prediction_method)
+    if prediction_method is None:
+        raise AttributeError(
+            f"{estimator.__class__.__name__} has none of the following attributes: "
+            f"{', '.join(list_methods)}."
+        )
+
+    return prediction_method
+
+
 def _allclose_dense_sparse(x, y, rtol=1e-7, atol=1e-9):
     """Check allclose for sparse and dense data.

From 8675defa7ca0056fad994af9a2ecddff19956518 Mon Sep 17 00:00:00 2001
From: Guillaume Lemaitre
Date: Wed, 3 Nov 2021 14:25:31 +0100
Subject: [PATCH 02/12] iter

---
 sklearn/calibration.py                        |  25 ++--
 sklearn/ensemble/_stacking.py                 |  30 ++---
 sklearn/metrics/_plot/base.py                 | 118 ------------------
 sklearn/metrics/_plot/det_curve.py            |  34 +++--
 .../metrics/_plot/precision_recall_curve.py   |  37 ++++--
 sklearn/metrics/_plot/roc_curve.py            |  32 ++++-
 sklearn/metrics/_plot/tests/test_base.py      |  76 +----------
 .../_plot/tests/test_common_curve_display.py  | 102 +++++++++++++--
 .../_plot/tests/test_plot_curve_common.py     |  89 +++++++++----
 .../_plot/tests/test_plot_precision_recall.py |  68 ----------
 .../tests/test_precision_recall_display.py    |  43 -------
 sklearn/tests/test_calibration.py             |  40 +-----
 12 files changed, 272 insertions(+), 422 deletions(-)

diff --git a/sklearn/calibration.py b/sklearn/calibration.py
index 8bc1b9842de6c..ecd5ff70b1a5f 100644
--- a/sklearn/calibration.py
+++ b/sklearn/calibration.py
@@ -25,7 +25,6 @@
     RegressorMixin,
     clone,
     MetaEstimatorMixin,
-    is_classifier,
 )
 from .preprocessing import label_binarize, LabelEncoder
 from .utils import (
@@ -35,7 +34,10 @@
     check_matplotlib_support,
 )

-from .utils.multiclass import check_classification_targets
+from .utils.multiclass import (
+    check_classification_targets,
+    type_of_target,
+)
 from .utils.fixes import delayed
 from .utils.validation import (
     _check_sample_weight,
@@ -43,12 +45,12 @@
     check_consistent_length,
     check_is_fitted,
 )
-from .utils import _safe_indexing
+from .utils import _get_response_values, _safe_indexing
 from .isotonic import IsotonicRegression
 from .svm import LinearSVC
 from .model_selection import check_cv, cross_val_predict
 from .metrics._base import _check_pos_label_consistency
-from .metrics._plot.base import _get_response
+from .metrics._plot.base import _check_estimator_target


 class CalibratedClassifierCV(ClassifierMixin, MetaEstimatorMixin, BaseEstimator):
@@ -1237,11 +1239,10 @@ def from_estimator(
         method_name = f"{cls.__name__}.from_estimator"
         check_matplotlib_support(method_name)

-        if not is_classifier(estimator):
-            raise ValueError("'estimator' should be a fitted classifier.")
+        _check_estimator_target(estimator, y)

-        y_prob, pos_label = _get_response(
-            X, estimator, response_method="predict_proba", pos_label=pos_label
+        y_prob, pos_label = _get_response_values(
+            estimator, X, y, response_method="predict_proba",
pos_label=pos_label ) name = name if name is not None else estimator.__class__.__name__ @@ -1354,9 +1355,15 @@ def from_predictions( >>> disp = CalibrationDisplay.from_predictions(y_test, y_prob) >>> plt.show() """ - method_name = f"{cls.__name__}.from_estimator" + method_name = f"{cls.__name__}.from_predictions" check_matplotlib_support(method_name) + if type_of_target(y_true) != "binary": + raise ValueError( + f"The target y is not binary. Got {type_of_target(y_true)} type of" + " target." + ) + prob_true, prob_pred = calibration_curve( y_true, y_prob, n_bins=n_bins, strategy=strategy, pos_label=pos_label ) diff --git a/sklearn/ensemble/_stacking.py b/sklearn/ensemble/_stacking.py index 507794bd4e092..7be95673a093b 100644 --- a/sklearn/ensemble/_stacking.py +++ b/sklearn/ensemble/_stacking.py @@ -30,8 +30,11 @@ from ..utils import Bunch from ..utils.metaestimators import if_delegate_has_method from ..utils.multiclass import check_classification_targets -from ..utils.validation import check_is_fitted -from ..utils.validation import column_or_1d +from ..utils.validation import ( + _check_response_method, + check_is_fitted, + column_or_1d, +) from ..utils.fixes import delayed @@ -104,21 +107,14 @@ def _concatenate_predictions(self, X, predictions): def _method_name(name, estimator, method): if estimator == "drop": return None - if method == "auto": - if getattr(estimator, "predict_proba", None): - return "predict_proba" - elif getattr(estimator, "decision_function", None): - return "decision_function" - else: - return "predict" - else: - if not hasattr(estimator, method): - raise ValueError( - "Underlying estimator {} does not implement the method {}.".format( - name, method - ) - ) - return method + method = None if method == "auto" else method + try: + method_name = _check_response_method(estimator, method).__name__ + except AttributeError as e: + raise ValueError( + f"Underlying estimator {name} does not implement the method {method}." + ) from e + return method_name def fit(self, X, y, sample_weight=None): """Fit the estimators. diff --git a/sklearn/metrics/_plot/base.py b/sklearn/metrics/_plot/base.py index 3cbec426242a7..0a5cfa3ae3804 100644 --- a/sklearn/metrics/_plot/base.py +++ b/sklearn/metrics/_plot/base.py @@ -38,121 +38,3 @@ def _check_estimator_target(estimator, y): raise ValueError( f"The target y is not binary. Got {type_of_target(y)} type of target." ) - - -# from ...base import is_classifier - - -# def _check_classifier_response_method(estimator, response_method): -# """Return prediction method from the response_method - -# Parameters -# ---------- -# estimator: object -# Classifier to check - -# response_method: {'auto', 'predict_proba', 'decision_function'} -# Specifies whether to use :term:`predict_proba` or -# :term:`decision_function` as the target response. If set to 'auto', -# :term:`predict_proba` is tried first and if it does not exist -# :term:`decision_function` is tried next. 
- -# Returns -# ------- -# prediction_method: callable -# prediction method of estimator -# """ - -# if response_method not in ("predict_proba", "decision_function", "auto"): -# raise ValueError( -# "response_method must be 'predict_proba', 'decision_function' or 'auto'" -# ) - -# error_msg = "response method {} is not defined in {}" -# if response_method != "auto": -# prediction_method = getattr(estimator, response_method, None) -# if prediction_method is None: -# raise ValueError( -# error_msg.format(response_method, estimator.__class__.__name__) -# ) -# else: -# predict_proba = getattr(estimator, "predict_proba", None) -# decision_function = getattr(estimator, "decision_function", None) -# prediction_method = predict_proba or decision_function -# if prediction_method is None: -# raise ValueError( -# error_msg.format( -# "decision_function or predict_proba", estimator.__class__.__name__ -# ) -# ) - -# return prediction_method - - -# def _get_response(X, estimator, response_method, pos_label=None): -# """Return response and positive label. - -# Parameters -# ---------- -# X : {array-like, sparse matrix} of shape (n_samples, n_features) -# Input values. - -# estimator : estimator instance -# Fitted classifier or a fitted :class:`~sklearn.pipeline.Pipeline` -# in which the last estimator is a classifier. - -# response_method: {'auto', 'predict_proba', 'decision_function'} -# Specifies whether to use :term:`predict_proba` or -# :term:`decision_function` as the target response. If set to 'auto', -# :term:`predict_proba` is tried first and if it does not exist -# :term:`decision_function` is tried next. - -# pos_label : str or int, default=None -# The class considered as the positive class when computing -# the metrics. By default, `estimators.classes_[1]` is -# considered as the positive class. - -# Returns -# ------- -# y_pred: ndarray of shape (n_samples,) -# Target scores calculated from the provided response_method -# and pos_label. - -# pos_label: str or int -# The class considered as the positive class when computing -# the metrics. -# """ -# classification_error = ( -# "Expected 'estimator' to be a binary classifier, but got" -# f" {estimator.__class__.__name__}" -# ) - -# if not is_classifier(estimator): -# raise ValueError(classification_error) - -# prediction_method = _check_classifier_response_method(estimator, response_method) -# y_pred = prediction_method(X) -# if pos_label is not None: -# try: -# class_idx = estimator.classes_.tolist().index(pos_label) -# except ValueError as e: -# raise ValueError( -# "The class provided by 'pos_label' is unknown. Got " -# f"{pos_label} instead of one of {set(estimator.classes_)}" -# ) from e -# else: -# class_idx = 1 -# pos_label = estimator.classes_[class_idx] - -# if y_pred.ndim != 1: # `predict_proba` -# y_pred_shape = y_pred.shape[1] -# if y_pred_shape != 2: -# raise ValueError( -# f"{classification_error} fit on multiclass ({y_pred_shape} classes)" -# " data" -# ) -# y_pred = y_pred[:, class_idx] -# elif pos_label == estimator.classes_[0]: # `decision_function` -# y_pred *= -1 - -# return y_pred, pos_label diff --git a/sklearn/metrics/_plot/det_curve.py b/sklearn/metrics/_plot/det_curve.py index 92e84ce9b7974..ab09c2b15e4d1 100644 --- a/sklearn/metrics/_plot/det_curve.py +++ b/sklearn/metrics/_plot/det_curve.py @@ -1,12 +1,16 @@ import scipy as sp -from .base import _get_response +from .base import _check_estimator_target from .. 
import det_curve from .._base import _check_pos_label_consistency -from ...utils import check_matplotlib_support -from ...utils import deprecated +from ...utils import ( + check_matplotlib_support, + deprecated, + _get_response_values, +) +from ...utils.multiclass import type_of_target class DetCurveDisplay: @@ -168,11 +172,16 @@ def from_estimator( """ check_matplotlib_support(f"{cls.__name__}.from_estimator") + _check_estimator_target(estimator, y) + if response_method == "auto": + response_method = ["predict_proba", "decision_function"] + name = estimator.__class__.__name__ if name is None else name - y_pred, pos_label = _get_response( - X, + y_pred, pos_label = _get_response_values( estimator, + X, + y, response_method, pos_label=pos_label, ) @@ -265,6 +274,13 @@ def from_predictions( >>> plt.show() """ check_matplotlib_support(f"{cls.__name__}.from_predictions") + + if type_of_target(y_true) != "binary": + raise ValueError( + f"The target y is not binary. Got {type_of_target(y_true)} type of" + " target." + ) + fpr, fnr, _ = det_curve( y_true, y_pred, @@ -454,8 +470,12 @@ def plot_det_curve( """ check_matplotlib_support("plot_det_curve") - y_pred, pos_label = _get_response( - X, estimator, response_method, pos_label=pos_label + _check_estimator_target(estimator, y) + if response_method == "auto": + response_method = ["predict_proba", "decision_function"] + + y_pred, pos_label = _get_response_values( + estimator, X, y, response_method, pos_label=pos_label ) fpr, fnr, _ = det_curve( diff --git a/sklearn/metrics/_plot/precision_recall_curve.py b/sklearn/metrics/_plot/precision_recall_curve.py index eaf8240062174..fcce47162ec06 100644 --- a/sklearn/metrics/_plot/precision_recall_curve.py +++ b/sklearn/metrics/_plot/precision_recall_curve.py @@ -1,12 +1,16 @@ -from sklearn.base import is_classifier -from .base import _get_response +from .base import _check_estimator_target from .. import average_precision_score from .. import precision_recall_curve from .._base import _check_pos_label_consistency from .._classification import check_consistent_length -from ...utils import check_matplotlib_support, deprecated +from ...utils import ( + check_matplotlib_support, + deprecated, + _get_response_values, +) +from ...utils.multiclass import type_of_target class PrecisionRecallDisplay: @@ -235,11 +239,15 @@ def from_estimator( """ method_name = f"{cls.__name__}.from_estimator" check_matplotlib_support(method_name) - if not is_classifier(estimator): - raise ValueError(f"{method_name} only supports classifiers") - y_pred, pos_label = _get_response( - X, + + _check_estimator_target(estimator, y) + if response_method == "auto": + response_method = ["predict_proba", "decision_function"] + + y_pred, pos_label = _get_response_values( estimator, + X, + y, response_method, pos_label=pos_label, ) @@ -325,6 +333,12 @@ def from_predictions( """ check_matplotlib_support(f"{cls.__name__}.from_predictions") + if type_of_target(y_true) != "binary": + raise ValueError( + f"The target y is not binary. Got {type_of_target(y_true)} type of" + " target." 
+ ) + check_consistent_length(y_true, y_pred, sample_weight) pos_label = _check_pos_label_consistency(pos_label, y_true) @@ -430,8 +444,13 @@ def plot_precision_recall_curve( """ check_matplotlib_support("plot_precision_recall_curve") - y_pred, pos_label = _get_response( - X, estimator, response_method, pos_label=pos_label + _check_estimator_target(estimator, y) + + if response_method == "auto": + response_method = ["predict_proba", "decision_function"] + + y_pred, pos_label = _get_response_values( + estimator, X, y, response_method, pos_label=pos_label ) precision, recall, _ = precision_recall_curve( diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index 6c39e6bc152cd..9c40b854d3bdc 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -1,10 +1,15 @@ -from .base import _get_response +from .base import _check_estimator_target from .. import auc from .. import roc_curve from .._base import _check_pos_label_consistency -from ...utils import check_matplotlib_support, deprecated +from ...utils import ( + check_matplotlib_support, + deprecated, + _get_response_values, +) +from ...utils.multiclass import type_of_target class RocCurveDisplay: @@ -226,11 +231,16 @@ def from_estimator( """ check_matplotlib_support(f"{cls.__name__}.from_estimator") + _check_estimator_target(estimator, y) + if response_method == "auto": + response_method = ["predict_proba", "decision_function"] + name = estimator.__class__.__name__ if name is None else name - y_pred, pos_label = _get_response( - X, + y_pred, pos_label = _get_response_values( estimator, + X, + y, response_method=response_method, pos_label=pos_label, ) @@ -330,6 +340,12 @@ def from_predictions( """ check_matplotlib_support(f"{cls.__name__}.from_predictions") + if type_of_target(y_true) != "binary": + raise ValueError( + f"The target y is not binary. Got {type_of_target(y_true)} type of" + " target." 
+ ) + fpr, tpr, _ = roc_curve( y_true, y_pred, @@ -451,8 +467,12 @@ def plot_roc_curve( """ check_matplotlib_support("plot_roc_curve") - y_pred, pos_label = _get_response( - X, estimator, response_method, pos_label=pos_label + _check_estimator_target(estimator, y) + if response_method == "auto": + response_method = ["predict_proba", "decision_function"] + + y_pred, pos_label = _get_response_values( + estimator, X, y, response_method, pos_label=pos_label ) fpr, tpr, _ = roc_curve( diff --git a/sklearn/metrics/_plot/tests/test_base.py b/sklearn/metrics/_plot/tests/test_base.py index 2f67d7dd223f4..5a1ad81acb575 100644 --- a/sklearn/metrics/_plot/tests/test_base.py +++ b/sklearn/metrics/_plot/tests/test_base.py @@ -1,75 +1 @@ -import numpy as np -import pytest - -from sklearn.datasets import load_iris -from sklearn.linear_model import LogisticRegression -from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor - -from sklearn.metrics._plot.base import _get_response - - -@pytest.mark.parametrize( - "estimator, err_msg, params", - [ - ( - DecisionTreeRegressor(), - "Expected 'estimator' to be a binary classifier", - {"response_method": "auto"}, - ), - ( - DecisionTreeClassifier(), - "The class provided by 'pos_label' is unknown.", - {"response_method": "auto", "pos_label": "unknown"}, - ), - ( - DecisionTreeClassifier(), - "fit on multiclass", - {"response_method": "predict_proba"}, - ), - ], -) -def test_get_response_error(estimator, err_msg, params): - """Check that we raise the proper error messages in `_get_response`.""" - X, y = load_iris(return_X_y=True) - - estimator.fit(X, y) - with pytest.raises(ValueError, match=err_msg): - _get_response(X, estimator, **params) - - -def test_get_response_predict_proba(): - """Check the behaviour of `_get_response` using `predict_proba`.""" - X, y = load_iris(return_X_y=True) - X_binary, y_binary = X[:100], y[:100] - - classifier = DecisionTreeClassifier().fit(X_binary, y_binary) - y_proba, pos_label = _get_response( - X_binary, classifier, response_method="predict_proba" - ) - np.testing.assert_allclose(y_proba, classifier.predict_proba(X_binary)[:, 1]) - assert pos_label == 1 - - y_proba, pos_label = _get_response( - X_binary, classifier, response_method="predict_proba", pos_label=0 - ) - np.testing.assert_allclose(y_proba, classifier.predict_proba(X_binary)[:, 0]) - assert pos_label == 0 - - -def test_get_response_decision_function(): - """Check the behaviour of `get_response` using `decision_function`.""" - X, y = load_iris(return_X_y=True) - X_binary, y_binary = X[:100], y[:100] - - classifier = LogisticRegression().fit(X_binary, y_binary) - y_score, pos_label = _get_response( - X_binary, classifier, response_method="decision_function" - ) - np.testing.assert_allclose(y_score, classifier.decision_function(X_binary)) - assert pos_label == 1 - - y_score, pos_label = _get_response( - X_binary, classifier, response_method="decision_function", pos_label=0 - ) - np.testing.assert_allclose(y_score, classifier.decision_function(X_binary) * -1) - assert pos_label == 0 +"""some file""" diff --git a/sklearn/metrics/_plot/tests/test_common_curve_display.py b/sklearn/metrics/_plot/tests/test_common_curve_display.py index 5ed036b77f4d0..119c8c8a2a1fb 100644 --- a/sklearn/metrics/_plot/tests/test_common_curve_display.py +++ b/sklearn/metrics/_plot/tests/test_common_curve_display.py @@ -1,3 +1,4 @@ +import numpy as np import pytest from sklearn.base import ClassifierMixin, clone @@ -7,8 +8,9 @@ from sklearn.linear_model import LogisticRegression 
 from sklearn.pipeline import make_pipeline
 from sklearn.preprocessing import StandardScaler
-from sklearn.tree import DecisionTreeClassifier
+from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor

+from sklearn.calibration import CalibrationDisplay
 from sklearn.metrics import (
     DetCurveDisplay,
     PrecisionRecallDisplay,
@@ -28,40 +30,84 @@ def data_binary(data):


 @pytest.mark.parametrize(
-    "Display", [DetCurveDisplay, PrecisionRecallDisplay, RocCurveDisplay]
+    "Display",
+    [CalibrationDisplay, DetCurveDisplay, PrecisionRecallDisplay, RocCurveDisplay],
 )
-def test_display_curve_error_non_binary(pyplot, data, Display):
+def test_display_curve_error_classifier(pyplot, data, data_binary, Display):
     """Check that a proper error is raised when only binary classification is
     supported."""
     X, y = data
+    X_binary, y_binary = data_binary
+
+    # Case 1: multiclass classifier with multiclass target
     clf = DecisionTreeClassifier().fit(X, y)

     msg = (
-        "Expected 'estimator' to be a binary classifier, but got DecisionTreeClassifier"
+        "This DecisionTreeClassifier instance is not a binary classifier. It was "
+        f"fitted on a multiclass problem with {len(np.unique(y))} classes."
     )
     with pytest.raises(ValueError, match=msg):
         Display.from_estimator(clf, X, y)

+    # Case 2: multiclass classifier with binary target
+    with pytest.raises(ValueError, match=msg):
+        Display.from_estimator(clf, X_binary, y_binary)
+
+    # Case 3: binary classifier with multiclass target
+    clf = DecisionTreeClassifier().fit(X_binary, y_binary)
+    msg = "The target y is not binary. Got multiclass type of target."
+    with pytest.raises(ValueError, match=msg):
+        Display.from_estimator(clf, X, y)
+
+
+@pytest.mark.parametrize(
+    "Display",
+    [CalibrationDisplay, DetCurveDisplay, PrecisionRecallDisplay, RocCurveDisplay],
+)
+def test_display_curve_error_regression(pyplot, data_binary, Display):
+    """Check that we raise an error with a regressor."""
+
+    # Case 1: regressor
+    X, y = data_binary
+    regressor = DecisionTreeRegressor().fit(X, y)
+
+    msg = (
+        "These plotting functionalities only support a binary classifier. Got a "
+        "DecisionTreeRegressor instead."
+    )
+    with pytest.raises(ValueError, match=msg):
+        Display.from_estimator(regressor, X, y)
+
+    # Case 2: regression target
+    classifier = DecisionTreeClassifier().fit(X, y)
+    # Force `y_true` to be seen as a regression problem
+    y = y + 0.5
+    msg = "The target y is not binary. Got continuous type of target."
+    with pytest.raises(ValueError, match=msg):
+        Display.from_estimator(classifier, X, y)
+    with pytest.raises(ValueError, match=msg):
+        Display.from_predictions(y, regressor.fit(X, y).predict(X))
+

 @pytest.mark.parametrize(
     "response_method, msg",
     [
         (
             "predict_proba",
-            "response method predict_proba is not defined in MyClassifier",
+            "MyClassifier has none of the following attributes: predict_proba.",
         ),
         (
             "decision_function",
-            "response method decision_function is not defined in MyClassifier",
+            "MyClassifier has none of the following attributes: decision_function.",
         ),
         (
             "auto",
-            "response method decision_function or predict_proba is not "
-            "defined in MyClassifier",
+            "MyClassifier has none of the following attributes: predict_proba, "
+            "decision_function.",
         ),
         (
             "bad_method",
-            "response_method must be 'predict_proba', 'decision_function' or 'auto'",
+            "MyClassifier has none of the following attributes: bad_method.",
         ),
     ],
 )
@@ -86,7 +132,7 @@ def fit(self, X, y):

     clf = MyClassifier().fit(X, y)

-    with pytest.raises(ValueError, match=msg):
+    with pytest.raises(AttributeError, match=msg):
         Display.from_estimator(clf, X, y, response_method=response_method)

@@ -135,7 +181,8 @@ def test_display_curve_estimator_name_multiple_calls(
     ],
 )
 @pytest.mark.parametrize(
-    "Display", [DetCurveDisplay, PrecisionRecallDisplay, RocCurveDisplay]
+    "Display",
+    [CalibrationDisplay, DetCurveDisplay, PrecisionRecallDisplay, RocCurveDisplay],
 )
 def test_display_curve_not_fitted_errors(pyplot, data_binary, clf, Display):
     """Check that a proper error is raised when the classifier is not
@@ -150,3 +197,36 @@ def test_display_curve_not_fitted_errors(pyplot, data_binary, clf, Display):
     disp = Display.from_estimator(model, X, y)
     assert model.__class__.__name__ in disp.line_.get_label()
     assert disp.estimator_name == model.__class__.__name__
+
+
+@pytest.mark.parametrize(
+    "Display", [DetCurveDisplay, PrecisionRecallDisplay, RocCurveDisplay]
+)
+def test_display_curve_n_samples_consistency(pyplot, data_binary, Display):
+    """Check the error raised when `X`, `y`, or `sample_weight` have
+    inconsistent lengths."""
+    X, y = data_binary
+    classifier = DecisionTreeClassifier().fit(X, y)
+
+    msg = "Found input variables with inconsistent numbers of samples"
+    with pytest.raises(ValueError, match=msg):
+        Display.from_estimator(classifier, X[:-2], y)
+    with pytest.raises(ValueError, match=msg):
+        Display.from_estimator(classifier, X, y[:-2])
+    with pytest.raises(ValueError, match=msg):
+        Display.from_estimator(classifier, X, y, sample_weight=np.ones(X.shape[0] - 2))
+
+
+@pytest.mark.parametrize(
+    "Display", [DetCurveDisplay, PrecisionRecallDisplay, RocCurveDisplay]
+)
+def test_display_curve_error_pos_label(pyplot, data_binary, Display):
+    """Check the consistency of the error message when `pos_label` should be
+    specified."""
+    X, y = data_binary
+    y = y + 10
+
+    classifier = DecisionTreeClassifier().fit(X, y)
+    y_pred = classifier.predict_proba(X)[:, -1]
+    msg = r"y_true takes value in {10, 11} and pos_label is not specified"
+    with pytest.raises(ValueError, match=msg):
+        Display.from_predictions(y, y_pred)
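
The error cases asserted in the tests above all reduce to the
`_check_estimator_target` helper from the first commit. A runnable sketch of
two of them (requires matplotlib; the data and estimators are illustrative,
not part of the patch):

    import pytest
    from sklearn.datasets import load_iris
    from sklearn.metrics import RocCurveDisplay
    from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor

    X, y = load_iris(return_X_y=True)  # three classes

    # A classifier fitted on more than two classes is rejected.
    with pytest.raises(ValueError, match="not a binary classifier"):
        RocCurveDisplay.from_estimator(DecisionTreeClassifier().fit(X, y), X, y)

    # A regressor is rejected before any response method is looked up.
    with pytest.raises(ValueError, match="only support a binary classifier"):
        RocCurveDisplay.from_estimator(DecisionTreeRegressor().fit(X, y), X, y)
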
diff --git a/sklearn/metrics/_plot/tests/test_plot_curve_common.py b/sklearn/metrics/_plot/tests/test_plot_curve_common.py
index d430acd42596c..be0a6b8e46fed 100644
--- a/sklearn/metrics/_plot/tests/test_plot_curve_common.py
+++ b/sklearn/metrics/_plot/tests/test_plot_curve_common.py
@@ -1,3 +1,4 @@
+import numpy as np
 import pytest

 from sklearn.base import ClassifierMixin
@@ -8,13 +9,14 @@
 from sklearn.linear_model import LogisticRegression
 from sklearn.pipeline import make_pipeline
 from sklearn.preprocessing import StandardScaler
-from sklearn.tree import DecisionTreeClassifier
+from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor

-from sklearn.metrics import plot_det_curve
-from sklearn.metrics import plot_roc_curve
+from sklearn.metrics import plot_det_curve, plot_roc_curve, plot_precision_recall_curve

 pytestmark = pytest.mark.filterwarnings(
     "ignore:Function plot_roc_curve is deprecated",
+    "ignore:Function plot_det_curve is deprecated",
+    "ignore:Function plot_precision_recall_curve is deprecated",
 )

@@ -29,44 +31,87 @@ def data_binary(data):
     return X[y < 2], y[y < 2]


-@pytest.mark.filterwarnings("ignore: Function plot_det_curve is deprecated")
-@pytest.mark.parametrize("plot_func", [plot_det_curve, plot_roc_curve])
-def test_plot_curve_error_non_binary(pyplot, data, plot_func):
+@pytest.mark.parametrize(
+    "plot_func", [plot_det_curve, plot_roc_curve, plot_precision_recall_curve]
+)
+def test_plot_curve_error_classifier(pyplot, data, data_binary, plot_func):
+    """Check that a proper error is raised when only binary classification is
+    supported."""
     X, y = data
-    clf = DecisionTreeClassifier()
-    clf.fit(X, y)
+    X_binary, y_binary = data_binary
+
+    # Case 1: multiclass classifier with multiclass target
+    clf = DecisionTreeClassifier().fit(X, y)

     msg = (
-        "Expected 'estimator' to be a binary classifier, but got DecisionTreeClassifier"
+        "This DecisionTreeClassifier instance is not a binary classifier. It was "
+        f"fitted on a multiclass problem with {len(np.unique(y))} classes."
     )
     with pytest.raises(ValueError, match=msg):
         plot_func(clf, X, y)

+    # Case 2: multiclass classifier with binary target
+    with pytest.raises(ValueError, match=msg):
+        plot_func(clf, X_binary, y_binary)
+
+    # Case 3: binary classifier with multiclass target
+    clf = DecisionTreeClassifier().fit(X_binary, y_binary)
+    msg = "The target y is not binary. Got multiclass type of target."
+    with pytest.raises(ValueError, match=msg):
+        plot_func(clf, X, y)
+
+
+@pytest.mark.parametrize(
+    "plot_func", [plot_det_curve, plot_roc_curve, plot_precision_recall_curve]
+)
+def test_plot_curve_error_regression(pyplot, data_binary, plot_func):
+    """Check that we raise an error with a regressor."""
+
+    # Case 1: regressor
+    X, y = data_binary
+    regressor = DecisionTreeRegressor().fit(X, y)
+
+    msg = (
+        "These plotting functionalities only support a binary classifier. Got a "
+        "DecisionTreeRegressor instead."
+    )
+    with pytest.raises(ValueError, match=msg):
+        plot_func(regressor, X, y)
+
+    # Case 2: regression target
+    classifier = DecisionTreeClassifier().fit(X, y)
+    # Force `y_true` to be seen as a regression problem
+    y = y + 0.5
+    msg = "The target y is not binary. Got continuous type of target."
+ with pytest.raises(ValueError, match=msg): + plot_func(classifier, X, y) + -@pytest.mark.filterwarnings("ignore: Function plot_det_curve is deprecated") @pytest.mark.parametrize( "response_method, msg", [ ( "predict_proba", - "response method predict_proba is not defined in MyClassifier", + "MyClassifier has none of the following attributes: predict_proba.", ), ( "decision_function", - "response method decision_function is not defined in MyClassifier", + "MyClassifier has none of the following attributes: decision_function.", ), ( "auto", - "response method decision_function or predict_proba is not " - "defined in MyClassifier", + "MyClassifier has none of the following attributes: predict_proba, " + "decision_function.", ), ( "bad_method", - "response_method must be 'predict_proba', 'decision_function' or 'auto'", + "MyClassifier has none of the following attributes: bad_method", ), ], ) -@pytest.mark.parametrize("plot_func", [plot_det_curve, plot_roc_curve]) +@pytest.mark.parametrize( + "plot_func", [plot_det_curve, plot_roc_curve, plot_precision_recall_curve] +) def test_plot_curve_error_no_response( pyplot, data_binary, @@ -83,12 +128,13 @@ def fit(self, X, y): clf = MyClassifier().fit(X, y) - with pytest.raises(ValueError, match=msg): + with pytest.raises(AttributeError, match=msg): plot_func(clf, X, y, response_method=response_method) -@pytest.mark.filterwarnings("ignore: Function plot_det_curve is deprecated") -@pytest.mark.parametrize("plot_func", [plot_det_curve, plot_roc_curve]) +@pytest.mark.parametrize( + "plot_func", [plot_det_curve, plot_roc_curve, plot_precision_recall_curve] +) def test_plot_curve_estimator_name_multiple_calls(pyplot, data_binary, plot_func): # non-regression test checking that the `name` used when calling # `plot_func` is used as well when calling `disp.plot()` @@ -106,7 +152,6 @@ def test_plot_curve_estimator_name_multiple_calls(pyplot, data_binary, plot_func assert clf_name in disp.line_.get_label() -@pytest.mark.filterwarnings("ignore: Function plot_det_curve is deprecated") @pytest.mark.parametrize( "clf", [ @@ -117,7 +162,9 @@ def test_plot_curve_estimator_name_multiple_calls(pyplot, data_binary, plot_func ), ], ) -@pytest.mark.parametrize("plot_func", [plot_det_curve, plot_roc_curve]) +@pytest.mark.parametrize( + "plot_func", [plot_det_curve, plot_roc_curve, plot_precision_recall_curve] +) def test_plot_det_curve_not_fitted_errors(pyplot, data_binary, clf, plot_func): X, y = data_binary # clone since we parametrize the test and the classifier will be fitted diff --git a/sklearn/metrics/_plot/tests/test_plot_precision_recall.py b/sklearn/metrics/_plot/tests/test_plot_precision_recall.py index 1d687b0c31abc..f2e3f79fbb5d3 100644 --- a/sklearn/metrics/_plot/tests/test_plot_precision_recall.py +++ b/sklearn/metrics/_plot/tests/test_plot_precision_recall.py @@ -2,13 +2,11 @@ import numpy as np from numpy.testing import assert_allclose -from sklearn.base import BaseEstimator, ClassifierMixin from sklearn.metrics import plot_precision_recall_curve from sklearn.metrics import average_precision_score from sklearn.metrics import precision_recall_curve from sklearn.datasets import make_classification from sklearn.datasets import load_breast_cancer -from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor from sklearn.linear_model import LogisticRegression from sklearn.model_selection import train_test_split from sklearn.exceptions import NotFittedError @@ -26,72 +24,6 @@ ) -def test_errors(pyplot): - X, y_multiclass = make_classification( - 
n_classes=3, n_samples=50, n_informative=3, random_state=0 - ) - y_binary = y_multiclass == 0 - - # Unfitted classifier - binary_clf = DecisionTreeClassifier() - with pytest.raises(NotFittedError): - plot_precision_recall_curve(binary_clf, X, y_binary) - binary_clf.fit(X, y_binary) - - multi_clf = DecisionTreeClassifier().fit(X, y_multiclass) - - # Fitted multiclass classifier with binary data - msg = ( - "Expected 'estimator' to be a binary classifier, but got DecisionTreeClassifier" - ) - with pytest.raises(ValueError, match=msg): - plot_precision_recall_curve(multi_clf, X, y_binary) - - reg = DecisionTreeRegressor().fit(X, y_multiclass) - msg = ( - "Expected 'estimator' to be a binary classifier, but got DecisionTreeRegressor" - ) - with pytest.raises(ValueError, match=msg): - plot_precision_recall_curve(reg, X, y_binary) - - -@pytest.mark.parametrize( - "response_method, msg", - [ - ( - "predict_proba", - "response method predict_proba is not defined in MyClassifier", - ), - ( - "decision_function", - "response method decision_function is not defined in MyClassifier", - ), - ( - "auto", - "response method decision_function or predict_proba is not " - "defined in MyClassifier", - ), - ( - "bad_method", - "response_method must be 'predict_proba', 'decision_function' or 'auto'", - ), - ], -) -def test_error_bad_response(pyplot, response_method, msg): - X, y = make_classification(n_classes=2, n_samples=50, random_state=0) - - class MyClassifier(ClassifierMixin, BaseEstimator): - def fit(self, X, y): - self.fitted_ = True - self.classes_ = [0, 1] - return self - - clf = MyClassifier().fit(X, y) - - with pytest.raises(ValueError, match=msg): - plot_precision_recall_curve(clf, X, y, response_method=response_method) - - @pytest.mark.parametrize("response_method", ["predict_proba", "decision_function"]) @pytest.mark.parametrize("with_sample_weight", [True, False]) def test_plot_precision_recall(pyplot, response_method, with_sample_weight): diff --git a/sklearn/metrics/_plot/tests/test_precision_recall_display.py b/sklearn/metrics/_plot/tests/test_precision_recall_display.py index 165e2b75df36e..2e38da7f82ca7 100644 --- a/sklearn/metrics/_plot/tests/test_precision_recall_display.py +++ b/sklearn/metrics/_plot/tests/test_precision_recall_display.py @@ -9,7 +9,6 @@ from sklearn.model_selection import train_test_split from sklearn.pipeline import make_pipeline from sklearn.preprocessing import StandardScaler -from sklearn.svm import SVC, SVR from sklearn.utils import shuffle from sklearn.metrics import PrecisionRecallDisplay, plot_precision_recall_curve @@ -21,48 +20,6 @@ ) -def test_precision_recall_display_validation(pyplot): - """Check that we raise the proper error when validating parameters.""" - X, y = make_classification( - n_samples=100, n_informative=5, n_classes=5, random_state=0 - ) - - with pytest.raises(NotFittedError): - PrecisionRecallDisplay.from_estimator(SVC(), X, y) - - regressor = SVR().fit(X, y) - y_pred_regressor = regressor.predict(X) - classifier = SVC(probability=True).fit(X, y) - y_pred_classifier = classifier.predict_proba(X)[:, -1] - - err_msg = "PrecisionRecallDisplay.from_estimator only supports classifiers" - with pytest.raises(ValueError, match=err_msg): - PrecisionRecallDisplay.from_estimator(regressor, X, y) - - err_msg = "Expected 'estimator' to be a binary classifier, but got SVC" - with pytest.raises(ValueError, match=err_msg): - PrecisionRecallDisplay.from_estimator(classifier, X, y) - - err_msg = "{} format is not supported" - with pytest.raises(ValueError, 
match=err_msg.format("continuous")): - # Force `y_true` to be seen as a regression problem - PrecisionRecallDisplay.from_predictions(y + 0.5, y_pred_classifier, pos_label=1) - with pytest.raises(ValueError, match=err_msg.format("multiclass")): - PrecisionRecallDisplay.from_predictions(y, y_pred_regressor, pos_label=1) - - err_msg = "Found input variables with inconsistent numbers of samples" - with pytest.raises(ValueError, match=err_msg): - PrecisionRecallDisplay.from_predictions(y, y_pred_classifier[::2]) - - X, y = make_classification(n_classes=2, n_samples=50, random_state=0) - y += 10 - classifier.fit(X, y) - y_pred_classifier = classifier.predict_proba(X)[:, -1] - err_msg = r"y_true takes value in {10, 11} and pos_label is not specified" - with pytest.raises(ValueError, match=err_msg): - PrecisionRecallDisplay.from_predictions(y, y_pred_classifier) - - # FIXME: Remove in 1.2 def test_plot_precision_recall_curve_deprecation(pyplot): """Check that we raise a FutureWarning when calling diff --git a/sklearn/tests/test_calibration.py b/sklearn/tests/test_calibration.py index b6b5c482b1eb5..b53425d0735fc 100644 --- a/sklearn/tests/test_calibration.py +++ b/sklearn/tests/test_calibration.py @@ -27,7 +27,7 @@ RandomForestRegressor, VotingClassifier, ) -from sklearn.linear_model import LogisticRegression, LinearRegression +from sklearn.linear_model import LogisticRegression from sklearn.tree import DecisionTreeClassifier from sklearn.svm import LinearSVC from sklearn.pipeline import Pipeline, make_pipeline @@ -189,7 +189,7 @@ def test_parallel_execution(data, method, ensemble): X, y = data X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42) - base_estimator = LinearSVC(random_state=42) + base_estimator = make_pipeline(StandardScaler(), LinearSVC(random_state=42)) cal_clf_parallel = CalibratedClassifierCV( base_estimator, method=method, n_jobs=2, ensemble=ensemble @@ -634,42 +634,6 @@ def iris_data_binary(iris_data): return X[y < 2], y[y < 2] -def test_calibration_display_validation(pyplot, iris_data, iris_data_binary): - X, y = iris_data - X_binary, y_binary = iris_data_binary - - reg = LinearRegression().fit(X, y) - msg = "'estimator' should be a fitted classifier" - with pytest.raises(ValueError, match=msg): - CalibrationDisplay.from_estimator(reg, X, y) - - clf = LinearSVC().fit(X, y) - msg = "response method predict_proba is not defined in" - with pytest.raises(ValueError, match=msg): - CalibrationDisplay.from_estimator(clf, X, y) - - clf = LogisticRegression() - with pytest.raises(NotFittedError): - CalibrationDisplay.from_estimator(clf, X, y) - - -@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"]) -def test_calibration_display_non_binary(pyplot, iris_data, constructor_name): - X, y = iris_data - clf = DecisionTreeClassifier() - clf.fit(X, y) - y_prob = clf.predict_proba(X) - - if constructor_name == "from_estimator": - msg = "to be a binary classifier, but got" - with pytest.raises(ValueError, match=msg): - CalibrationDisplay.from_estimator(clf, X, y) - else: - msg = "y should be a 1d array, got an array of shape" - with pytest.raises(ValueError, match=msg): - CalibrationDisplay.from_predictions(y, y_prob) - - @pytest.mark.parametrize("n_bins", [5, 10]) @pytest.mark.parametrize("strategy", ["uniform", "quantile"]) def test_calibration_display_compute(pyplot, iris_data_binary, n_bins, strategy): From 64dfc86bf7dd9bda03ae977a9e7b466564970ba8 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 3 Nov 2021 15:11:28 +0100 
Subject: [PATCH 03/12] iter --- sklearn/metrics/_plot/tests/test_base.py | 46 +++++++++++++++++++++++- 1 file changed, 45 insertions(+), 1 deletion(-) diff --git a/sklearn/metrics/_plot/tests/test_base.py b/sklearn/metrics/_plot/tests/test_base.py index 5a1ad81acb575..0505c047b3e13 100644 --- a/sklearn/metrics/_plot/tests/test_base.py +++ b/sklearn/metrics/_plot/tests/test_base.py @@ -1 +1,45 @@ -"""some file""" +import pytest + +from sklearn.datasets import load_iris +from sklearn.exceptions import NotFittedError +from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor + +from sklearn.metrics._plot.base import _check_estimator_target + +X, y = load_iris(return_X_y=True) +X_binary, y_binary = X[:100], y[:100] + + +@pytest.mark.parametrize( + "estimator, target, err_type, err_msg", + [ + ( + DecisionTreeClassifier(), + y_binary, + NotFittedError, + "This DecisionTreeClassifier instance is not fitted yet", + ), + ( + DecisionTreeRegressor().fit(X_binary, y_binary), + y_binary, + ValueError, + "This plotting functionalities only support a binary classifier", + ), + ( + DecisionTreeClassifier().fit(X, y), + y, + ValueError, + "This DecisionTreeClassifier instance is not a binary classifier", + ), + ( + DecisionTreeClassifier().fit(X_binary, y_binary), + y, + ValueError, + "The target y is not binary", + ), + ], +) +def test_check_estimator_target(estimator, target, err_type, err_msg): + """Check that we raise the expected error when checking the estimator and target.""" + with pytest.raises(err_type, match=err_msg): + _check_estimator_target(estimator, target) From a7e46da177b65135613138fbbb8b086f423f7949 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 29 Nov 2021 16:03:29 +0100 Subject: [PATCH 04/12] reuse type_of_target --- sklearn/calibration.py | 10 +++++----- sklearn/metrics/_plot/base.py | 17 +++++------------ sklearn/metrics/_plot/det_curve.py | 12 ++++++------ sklearn/metrics/_plot/precision_recall_curve.py | 12 ++++++------ sklearn/metrics/_plot/roc_curve.py | 12 ++++++------ sklearn/metrics/_plot/tests/test_base.py | 6 +++--- 6 files changed, 31 insertions(+), 38 deletions(-) diff --git a/sklearn/calibration.py b/sklearn/calibration.py index 61ba0c9a37209..8c7db349cf4c8 100644 --- a/sklearn/calibration.py +++ b/sklearn/calibration.py @@ -49,7 +49,7 @@ from .svm import LinearSVC from .model_selection import check_cv, cross_val_predict from .metrics._base import _check_pos_label_consistency -from .metrics._plot.base import _check_estimator_target +from .metrics._plot.base import _check_estimator_and_target_is_binary class CalibratedClassifierCV(ClassifierMixin, MetaEstimatorMixin, BaseEstimator): @@ -1218,7 +1218,7 @@ def from_estimator( method_name = f"{cls.__name__}.from_estimator" check_matplotlib_support(method_name) - _check_estimator_target(estimator, y) + _check_estimator_and_target_is_binary(estimator, y) y_prob, pos_label = _get_response_values( estimator, X, y, response_method="predict_proba", pos_label=pos_label @@ -1337,10 +1337,10 @@ def from_predictions( method_name = f"{cls.__name__}.from_predictions" check_matplotlib_support(method_name) - if type_of_target(y_true) != "binary": + target_type = type_of_target(y_true) + if target_type != "binary": raise ValueError( - f"The target y is not binary. Got {type_of_target(y_true)} type of" - " target." + f"The target y is not binary. Got {target_type} type of target." 
) prob_true, prob_pred = calibration_curve( diff --git a/sklearn/metrics/_plot/base.py b/sklearn/metrics/_plot/base.py index 0a5cfa3ae3804..6d279e0aeaa22 100644 --- a/sklearn/metrics/_plot/base.py +++ b/sklearn/metrics/_plot/base.py @@ -4,16 +4,8 @@ from ...utils.validation import check_is_fitted -def _check_estimator_target(estimator, y): - """Helper to check that estimator is a binary classifier and y is binary. - - This function is aside from the class `BinaryClassifierCurveDisplayMixin` - below because it allows to have consistent error messages between the - displays and the plotting functions. - - FIXME: Move into `BinaryClassifierCurveDisplayMixin.from_estimator` when - the plotting functions will be removed in 1.2. - """ +def _check_estimator_and_target_is_binary(estimator, y): + """Helper to check that estimator is a binary classifier and y is binary.""" try: check_is_fitted(estimator) except NotFittedError as e: @@ -34,7 +26,8 @@ def _check_estimator_target(estimator, y): "classifier. It was fitted on multiclass problem with " f"{len(estimator.classes_)} classes." ) - elif type_of_target(y) != "binary": + target_type = type_of_target(y) + if target_type != "binary": raise ValueError( - f"The target y is not binary. Got {type_of_target(y)} type of target." + f"The target y is not binary. Got {target_type} type of target." ) diff --git a/sklearn/metrics/_plot/det_curve.py b/sklearn/metrics/_plot/det_curve.py index ab09c2b15e4d1..7065979f39ce1 100644 --- a/sklearn/metrics/_plot/det_curve.py +++ b/sklearn/metrics/_plot/det_curve.py @@ -1,6 +1,6 @@ import scipy as sp -from .base import _check_estimator_target +from .base import _check_estimator_and_target_is_binary from .. import det_curve from .._base import _check_pos_label_consistency @@ -172,7 +172,7 @@ def from_estimator( """ check_matplotlib_support(f"{cls.__name__}.from_estimator") - _check_estimator_target(estimator, y) + _check_estimator_and_target_is_binary(estimator, y) if response_method == "auto": response_method = ["predict_proba", "decision_function"] @@ -275,10 +275,10 @@ def from_predictions( """ check_matplotlib_support(f"{cls.__name__}.from_predictions") - if type_of_target(y_true) != "binary": + target_type = type_of_target(y_true) + if target_type != "binary": raise ValueError( - f"The target y is not binary. Got {type_of_target(y_true)} type of" - " target." + f"The target y is not binary. Got {target_type} type of target." ) fpr, fnr, _ = det_curve( @@ -470,7 +470,7 @@ def plot_det_curve( """ check_matplotlib_support("plot_det_curve") - _check_estimator_target(estimator, y) + _check_estimator_and_target_is_binary(estimator, y) if response_method == "auto": response_method = ["predict_proba", "decision_function"] diff --git a/sklearn/metrics/_plot/precision_recall_curve.py b/sklearn/metrics/_plot/precision_recall_curve.py index fcce47162ec06..6a29913023b9d 100644 --- a/sklearn/metrics/_plot/precision_recall_curve.py +++ b/sklearn/metrics/_plot/precision_recall_curve.py @@ -1,4 +1,4 @@ -from .base import _check_estimator_target +from .base import _check_estimator_and_target_is_binary from .. import average_precision_score from .. 
import precision_recall_curve @@ -240,7 +240,7 @@ def from_estimator( method_name = f"{cls.__name__}.from_estimator" check_matplotlib_support(method_name) - _check_estimator_target(estimator, y) + _check_estimator_and_target_is_binary(estimator, y) if response_method == "auto": response_method = ["predict_proba", "decision_function"] @@ -333,10 +333,10 @@ def from_predictions( """ check_matplotlib_support(f"{cls.__name__}.from_predictions") - if type_of_target(y_true) != "binary": + target_type = type_of_target(y_true) + if target_type != "binary": raise ValueError( - f"The target y is not binary. Got {type_of_target(y_true)} type of" - " target." + f"The target y is not binary. Got {target_type} type of target." ) check_consistent_length(y_true, y_pred, sample_weight) @@ -444,7 +444,7 @@ def plot_precision_recall_curve( """ check_matplotlib_support("plot_precision_recall_curve") - _check_estimator_target(estimator, y) + _check_estimator_and_target_is_binary(estimator, y) if response_method == "auto": response_method = ["predict_proba", "decision_function"] diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index 3200de41f8f39..80d96ea8da5e2 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -1,4 +1,4 @@ -from .base import _check_estimator_target +from .base import _check_estimator_and_target_is_binary from .. import auc from .. import roc_curve @@ -231,7 +231,7 @@ def from_estimator( """ check_matplotlib_support(f"{cls.__name__}.from_estimator") - _check_estimator_target(estimator, y) + _check_estimator_and_target_is_binary(estimator, y) if response_method == "auto": response_method = ["predict_proba", "decision_function"] @@ -340,10 +340,10 @@ def from_predictions( """ check_matplotlib_support(f"{cls.__name__}.from_predictions") - if type_of_target(y_true) != "binary": + target_type = type_of_target + if target_type != "binary": raise ValueError( - f"The target y is not binary. Got {type_of_target(y_true)} type of" - " target." + f"The target y is not binary. Got {target_type} type of target." 
) fpr, tpr, _ = roc_curve( @@ -464,7 +464,7 @@ def plot_roc_curve( """ check_matplotlib_support("plot_roc_curve") - _check_estimator_target(estimator, y) + _check_estimator_and_target_is_binary(estimator, y) if response_method == "auto": response_method = ["predict_proba", "decision_function"] diff --git a/sklearn/metrics/_plot/tests/test_base.py b/sklearn/metrics/_plot/tests/test_base.py index 0505c047b3e13..cd4ced34cc358 100644 --- a/sklearn/metrics/_plot/tests/test_base.py +++ b/sklearn/metrics/_plot/tests/test_base.py @@ -4,7 +4,7 @@ from sklearn.exceptions import NotFittedError from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor -from sklearn.metrics._plot.base import _check_estimator_target +from sklearn.metrics._plot.base import _check_estimator_and_target_is_binary X, y = load_iris(return_X_y=True) X_binary, y_binary = X[:100], y[:100] @@ -39,7 +39,7 @@ ), ], ) -def test_check_estimator_target(estimator, target, err_type, err_msg): +def test_check_estimator_and_target_is_binary(estimator, target, err_type, err_msg): """Check that we raise the expected error when checking the estimator and target.""" with pytest.raises(err_type, match=err_msg): - _check_estimator_target(estimator, target) + _check_estimator_and_target_is_binary(estimator, target) From a69176da47dbcd18bcc595b07c5552ed17bd3d5d Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 29 Nov 2021 16:33:33 +0100 Subject: [PATCH 05/12] iter --- sklearn/metrics/_plot/roc_curve.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index 80d96ea8da5e2..4966a18e6592f 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -340,7 +340,7 @@ def from_predictions( """ check_matplotlib_support(f"{cls.__name__}.from_predictions") - target_type = type_of_target + target_type = type_of_target(y_true) if target_type != "binary": raise ValueError( f"The target y is not binary. Got {target_type} type of target." From 8214f8bb046177d8d1fb27e9ac2cf4119f88907b Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 29 Nov 2021 16:47:37 +0100 Subject: [PATCH 06/12] make response_method to not be None --- sklearn/ensemble/_stacking.py | 6 +++++- sklearn/utils/__init__.py | 9 +++----- sklearn/utils/tests/test_utils.py | 7 +++--- sklearn/utils/tests/test_validation.py | 30 ++++---------------------- sklearn/utils/validation.py | 13 ++++------- 5 files changed, 19 insertions(+), 46 deletions(-) diff --git a/sklearn/ensemble/_stacking.py b/sklearn/ensemble/_stacking.py index 7be95673a093b..2f24b2a18fb32 100644 --- a/sklearn/ensemble/_stacking.py +++ b/sklearn/ensemble/_stacking.py @@ -107,7 +107,11 @@ def _concatenate_predictions(self, X, predictions): def _method_name(name, estimator, method): if estimator == "drop": return None - method = None if method == "auto" else method + method = ( + ["predict_proba", "decision_function", "predict"] + if method == "auto" + else method + ) try: method_name = _check_response_method(estimator, method).__name__ except AttributeError as e: diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index af48476d80a2b..e62ff1734de8c 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -1310,7 +1310,7 @@ def _get_response_values( The true label. 
response_method : {"predict_proba", "decision_function", "predict"} or \ - list of such str, default=None + list of such str Specifies the response method to use get prediction from an estimator (i.e. :term:`predict_proba`, :term:`decision_function` or :term:`predict`). Possible choices are: @@ -1318,10 +1318,7 @@ def _get_response_values( - if `str`, it corresponds to the name to the method to return; - if a list of `str`, it provides the method names in order of preference. The method returned corresponds to the first method in - the list and which is implemented by `estimator`; - - if `None`, :term:`predict_proba` is tried first and if it does not - exist :term:`decision_function` is tried next and :term:`predict` - last. + the list and which is implemented by `estimator`. pos_label : str or int, default=None The class considered as the positive class when computing @@ -1378,7 +1375,7 @@ def _get_response_values( if pos_label == classes[0]: y_pred *= -1 else: - if response_method not in ("predict", None): + if response_method != "predict": raise ValueError(f"{estimator.__class__.__name__} should be a classifier") y_pred, pos_label = estimator.predict(X), None diff --git a/sklearn/utils/tests/test_utils.py b/sklearn/utils/tests/test_utils.py index 74eaa14394d93..da69401ae17c3 100644 --- a/sklearn/utils/tests/test_utils.py +++ b/sklearn/utils/tests/test_utils.py @@ -753,8 +753,7 @@ def test_get_response_values_regressor_error(response_method): _get_response_values(my_estimator, X, y, response_method=response_method) -@pytest.mark.parametrize("response_method", ["predict", None]) -def test_get_response_values_regressor(response_method): +def test_get_response_values_regressor(): """Check the behaviour of `_get_response_values` with regressor.""" X, y = make_regression(n_samples=10, random_state=0) regressor = LinearRegression().fit(X, y) @@ -762,7 +761,7 @@ def test_get_response_values_regressor(response_method): regressor, X, y, - response_method=response_method, + response_method="predict", ) assert_allclose(y_pred, regressor.predict(X)) assert pos_label is None @@ -770,7 +769,7 @@ def test_get_response_values_regressor(response_method): @pytest.mark.parametrize( "response_method", - [None, "predict_proba", "decision_function", "predict"], + ["predict_proba", "decision_function", "predict"], ) def test_get_response_values_classifier_unknown_pos_label(response_method): """Check that `_get_response_values` raises the proper error message with diff --git a/sklearn/utils/tests/test_validation.py b/sklearn/utils/tests/test_validation.py index 155efb2461728..807ceb80b6c3d 100644 --- a/sklearn/utils/tests/test_validation.py +++ b/sklearn/utils/tests/test_validation.py @@ -1704,40 +1704,18 @@ def test_check_response_method_unknown_method(): @pytest.mark.parametrize( - "response_method", ["decision_function", "predict_proba", "predict", None] + "response_method", ["decision_function", "predict_proba", "predict"] ) def test_check_response_method_not_supported_response_method(response_method): """Check the error message when a response method is not supported by the estimator.""" - err_msg = "EstimatorWithFit has none of the following attributes: {}." - if response_method is None: - err_msg = err_msg.format("predict_proba, decision_function, predict") - else: - err_msg = err_msg.format(response_method) + err_msg = ( + f"EstimatorWithFit has none of the following attributes: {response_method}." 
+ ) with pytest.raises(AttributeError, match=err_msg): _check_response_method(EstimatorWithFit(), response_method) -@pytest.mark.parametrize( - "response_methods, expected_method_name", - [ - (["predict_proba", "decision_function", "predict"], "predict_proba"), - (["decision_function", "predict"], "decision_function"), - (["predict_proba", "predict"], "predict_proba"), - (["predict_proba", "predict_proba"]), - (["decision_function", "decision_function"]), - (["predict"], "predict"), - ], -) -def test_check_response_method_order_None(response_methods, expected_method_name): - """Check the order of the response method when using None.""" - my_estimator = _MockEstimatorOnOffPrediction(response_methods) - - X = "mocking_data" - method_name_predicting = _check_response_method(my_estimator, None)(X) - assert method_name_predicting == expected_method_name - - def test_check_response_method_list_str(): """Check that we can pass a list of ordered method.""" method_implemented = ["predict_proba"] diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index 511974a0d650f..2e007419da036 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -1684,7 +1684,7 @@ def _check_sample_weight( return sample_weight -def _check_response_method(estimator, response_method=None): +def _check_response_method(estimator, response_method): """Check if `response_method` is available in estimator and return it. .. versionadded:: 1.1 @@ -1695,7 +1695,7 @@ def _check_response_method(estimator, response_method=None): Classifier or regressor to check. response_method : {"predict_proba", "decision_function", "predict"} or \ - list of such str, default=None + list of such str Specifies the response method to use get prediction from an estimator (i.e. :term:`predict_proba`, :term:`decision_function` or :term:`predict`). Possible choices are: @@ -1703,10 +1703,7 @@ def _check_response_method(estimator, response_method=None): - if `str`, it corresponds to the name to the method to return; - if a list of `str`, it provides the method names in order of preference. The method returned corresponds to the first method in - the list and which is implemented by `estimator`; - - if `None`, :term:`predict_proba` is tried first and if it does not - exist :term:`decision_function` is tried next and :term:`predict` - last. + the list and which is implemented by `estimator`. Returns ------- @@ -1718,9 +1715,7 @@ def _check_response_method(estimator, response_method=None): ValueError If `response_method` is not available in `estimator`. """ - if response_method is None: - list_methods = ["predict_proba", "decision_function", "predict"] - elif isinstance(response_method, str): + if isinstance(response_method, str): list_methods = [response_method] else: list_methods = response_method From d7e4912ed3885e1cfafad4f40466ab434faacbbd Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 29 Nov 2021 16:48:58 +0100 Subject: [PATCH 07/12] Update sklearn/utils/_mocking.py Co-authored-by: Thomas J. 
Fan --- sklearn/utils/_mocking.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/sklearn/utils/_mocking.py b/sklearn/utils/_mocking.py index a0084fc37c2ba..e49c108fae56c 100644 --- a/sklearn/utils/_mocking.py +++ b/sklearn/utils/_mocking.py @@ -336,9 +336,7 @@ def _more_tags(self): def _check_response(method): def check(self): - if self.response_methods is not None and method in self.response_methods: - return True - return False + return self.response_methods is not None and method in self.response_methods: return check From 75bd0e18b4b34ab2d9f8884cc32d9dce54c112ca Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 29 Nov 2021 17:08:36 +0100 Subject: [PATCH 08/12] iter --- sklearn/utils/__init__.py | 2 +- sklearn/utils/_mocking.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index e62ff1734de8c..2e70c2622076f 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -1281,7 +1281,7 @@ def _get_response_values( estimator, X, y_true, - response_method=None, + response_method, pos_label=None, ): """Compute the response values of a classifier or a regressor. diff --git a/sklearn/utils/_mocking.py b/sklearn/utils/_mocking.py index e49c108fae56c..ef98a2dd71b91 100644 --- a/sklearn/utils/_mocking.py +++ b/sklearn/utils/_mocking.py @@ -336,7 +336,7 @@ def _more_tags(self): def _check_response(method): def check(self): - return self.response_methods is not None and method in self.response_methods: + return self.response_methods is not None and method in self.response_methods return check From 04b54c914a4658f6bfca47f50d6878c27980278b Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 29 Nov 2021 17:28:29 +0100 Subject: [PATCH 09/12] iter --- sklearn/calibration.py | 11 +++++-- sklearn/metrics/_plot/base.py | 29 +++++++++++++++++-- sklearn/metrics/_plot/det_curve.py | 9 ++++-- .../metrics/_plot/precision_recall_curve.py | 9 ++++-- sklearn/metrics/_plot/roc_curve.py | 9 ++++-- sklearn/utils/__init__.py | 17 ++++++++--- 6 files changed, 65 insertions(+), 19 deletions(-) diff --git a/sklearn/calibration.py b/sklearn/calibration.py index 8c7db349cf4c8..da1d5d77f016e 100644 --- a/sklearn/calibration.py +++ b/sklearn/calibration.py @@ -1218,10 +1218,15 @@ def from_estimator( method_name = f"{cls.__name__}.from_estimator" check_matplotlib_support(method_name) - _check_estimator_and_target_is_binary(estimator, y) - + target_type = type_of_target(y) + _check_estimator_and_target_is_binary(estimator, y, target_type=target_type) y_prob, pos_label = _get_response_values( - estimator, X, y, response_method="predict_proba", pos_label=pos_label + estimator, + X, + y, + response_method="predict_proba", + pos_label=pos_label, + target_type=target_type, ) name = name if name is not None else estimator.__class__.__name__ diff --git a/sklearn/metrics/_plot/base.py b/sklearn/metrics/_plot/base.py index 6d279e0aeaa22..790ba4660d5dd 100644 --- a/sklearn/metrics/_plot/base.py +++ b/sklearn/metrics/_plot/base.py @@ -4,8 +4,29 @@ from ...utils.validation import check_is_fitted -def _check_estimator_and_target_is_binary(estimator, y): - """Helper to check that estimator is a binary classifier and y is binary.""" +def _check_estimator_and_target_is_binary(estimator, y, target_type=None): + """Helper to check that estimator is a binary classifier and y is binary. + + Parameters + ---------- + estimator : estimator instance + An estimator that should be used to predict the target. 
+
+    y : ndarray
+        The associated target.
+
+    target_type : str, default=None
+        The type of the target `y` as returned by
+        :func:`~sklearn.utils.multiclass.type_of_target`. If `None`, the type
+        will be inferred by calling :func:`~sklearn.utils.multiclass.type_of_target`.
+        Providing the type of the target could save time by avoiding a call to
+        the :func:`~sklearn.utils.multiclass.type_of_target` function.
+
+    Raises
+    ------
+    ValueError
+        If the estimator or the target is not binary.
+    """
     try:
         check_is_fitted(estimator)
     except NotFittedError as e:
@@ -26,7 +47,9 @@ def _check_estimator_and_target_is_binary(estimator, y):
         "classifier. It was fitted on multiclass problem with "
         f"{len(estimator.classes_)} classes."
         )
-    target_type = type_of_target(y)
+
+    if target_type is None:
+        target_type = type_of_target(y)
     if target_type != "binary":
         raise ValueError(
             f"The target y is not binary. Got {target_type} type of target."
         )
diff --git a/sklearn/metrics/_plot/det_curve.py b/sklearn/metrics/_plot/det_curve.py
index 7065979f39ce1..b4007f4f13768 100644
--- a/sklearn/metrics/_plot/det_curve.py
+++ b/sklearn/metrics/_plot/det_curve.py
@@ -172,7 +172,8 @@ def from_estimator(
     """
     check_matplotlib_support(f"{cls.__name__}.from_estimator")

-    _check_estimator_and_target_is_binary(estimator, y)
+    target_type = type_of_target(y)
+    _check_estimator_and_target_is_binary(estimator, y, target_type=target_type)

     if response_method == "auto":
         response_method = ["predict_proba", "decision_function"]
@@ -184,6 +185,7 @@ def from_estimator(
         y,
         response_method,
         pos_label=pos_label,
+        target_type=target_type,
     )

     return cls.from_predictions(
@@ -470,12 +472,13 @@ def plot_det_curve(
     """
     check_matplotlib_support("plot_det_curve")

-    _check_estimator_and_target_is_binary(estimator, y)
+    target_type = type_of_target(y)
+    _check_estimator_and_target_is_binary(estimator, y, target_type=target_type)

     if response_method == "auto":
         response_method = ["predict_proba", "decision_function"]

     y_pred, pos_label = _get_response_values(
-        estimator, X, y, response_method, pos_label=pos_label
+        estimator, X, y, response_method, pos_label=pos_label, target_type=target_type
     )

     fpr, fnr, _ = det_curve(
diff --git a/sklearn/metrics/_plot/precision_recall_curve.py b/sklearn/metrics/_plot/precision_recall_curve.py
index 6a29913023b9d..87500bf94a7da 100644
--- a/sklearn/metrics/_plot/precision_recall_curve.py
+++ b/sklearn/metrics/_plot/precision_recall_curve.py
@@ -240,7 +240,8 @@ def from_estimator(
     method_name = f"{cls.__name__}.from_estimator"
     check_matplotlib_support(method_name)

-    _check_estimator_and_target_is_binary(estimator, y)
+    target_type = type_of_target(y)
+    _check_estimator_and_target_is_binary(estimator, y, target_type=target_type)

     if response_method == "auto":
         response_method = ["predict_proba", "decision_function"]
@@ -250,6 +251,7 @@
         y,
         response_method,
         pos_label=pos_label,
+        target_type=target_type,
     )

     name = name if name is not None else estimator.__class__.__name__
@@ -444,13 +446,14 @@ def plot_precision_recall_curve(
     """
     check_matplotlib_support("plot_precision_recall_curve")

-    _check_estimator_and_target_is_binary(estimator, y)
+    target_type = type_of_target(y)
+    _check_estimator_and_target_is_binary(estimator, y, target_type=target_type)

     if response_method == "auto":
         response_method = ["predict_proba", "decision_function"]

     y_pred, pos_label = _get_response_values(
-        estimator, X, y, response_method, pos_label=pos_label
+        estimator, X, y, response_method, pos_label=pos_label, target_type=target_type
     )
     precision, recall, _ = precision_recall_curve(
diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py
index 4966a18e6592f..379ef14d5482f 100644
--- a/sklearn/metrics/_plot/roc_curve.py
+++ b/sklearn/metrics/_plot/roc_curve.py
@@ -231,7 +231,8 @@
     """
     check_matplotlib_support(f"{cls.__name__}.from_estimator")

-    _check_estimator_and_target_is_binary(estimator, y)
+    target_type = type_of_target(y)
+    _check_estimator_and_target_is_binary(estimator, y, target_type=target_type)

     if response_method == "auto":
         response_method = ["predict_proba", "decision_function"]
@@ -243,6 +244,7 @@
         y,
         response_method=response_method,
         pos_label=pos_label,
+        target_type=target_type,
     )

     return cls.from_predictions(
@@ -464,12 +466,13 @@
     """
     check_matplotlib_support("plot_roc_curve")

-    _check_estimator_and_target_is_binary(estimator, y)
+    target_type = type_of_target(y)
+    _check_estimator_and_target_is_binary(estimator, y, target_type=target_type)

     if response_method == "auto":
         response_method = ["predict_proba", "decision_function"]

     y_pred, pos_label = _get_response_values(
-        estimator, X, y, response_method, pos_label=pos_label
+        estimator, X, y, response_method, pos_label=pos_label, target_type=target_type
     )

     fpr, tpr, _ = roc_curve(
diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py
index 2e70c2622076f..2e9fd5c9f0f29 100644
--- a/sklearn/utils/__init__.py
+++ b/sklearn/utils/__init__.py
@@ -1283,6 +1283,7 @@ def _get_response_values(
     y_true,
     response_method,
     pos_label=None,
+    target_type=None,
 ):
     """Compute the response values of a classifier or a regressor.

@@ -1325,6 +1326,13 @@
         the metrics. By default, `estimators.classes_[1]` is
         considered as the positive class.

+    target_type : str, default=None
+        The type of the target `y` as returned by
+        :func:`~sklearn.utils.multiclass.type_of_target`. If `None`, the type
+        will be inferred by calling :func:`~sklearn.utils.multiclass.type_of_target`.
+        Providing the type of the target could save time by avoiding a call to
+        the :func:`~sklearn.utils.multiclass.type_of_target` function.
+ Returns ------- y_pred : ndarray of shape (n_samples,) @@ -1346,7 +1354,8 @@ def _get_response_values( from sklearn.base import is_classifier # noqa if is_classifier(estimator): - y_type = type_of_target(y_true) + if target_type is None: + target_type = type_of_target(y_true) prediction_method = _check_response_method(estimator, response_method) y_pred = prediction_method(X) classes = estimator.classes_ @@ -1356,11 +1365,11 @@ def _get_response_values( f"pos_label={pos_label} is not a valid label: It should be " f"one of {classes}" ) - elif pos_label is None and y_type == "binary": + elif pos_label is None and target_type == "binary": pos_label = pos_label if pos_label is not None else classes[-1] if prediction_method.__name__ == "predict_proba": - if y_type == "binary" and y_pred.shape[1] <= 2: + if target_type == "binary" and y_pred.shape[1] <= 2: if y_pred.shape[1] == 2: col_idx = np.flatnonzero(classes == pos_label)[0] y_pred = y_pred[:, col_idx] @@ -1371,7 +1380,7 @@ def _get_response_values( ) raise ValueError(err_msg) elif prediction_method.__name__ == "decision_function": - if y_type == "binary": + if target_type == "binary": if pos_label == classes[0]: y_pred *= -1 else: From 9e834b4b9ede1d43490d9a05b9282e91562b50f7 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 29 Nov 2021 17:42:38 +0100 Subject: [PATCH 10/12] iter --- sklearn/utils/tests/test_utils.py | 28 ++++++++++++++++++++++------ 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/sklearn/utils/tests/test_utils.py b/sklearn/utils/tests/test_utils.py index da69401ae17c3..ba93d371f4029 100644 --- a/sklearn/utils/tests/test_utils.py +++ b/sklearn/utils/tests/test_utils.py @@ -753,7 +753,8 @@ def test_get_response_values_regressor_error(response_method): _get_response_values(my_estimator, X, y, response_method=response_method) -def test_get_response_values_regressor(): +@pytest.mark.parametrize("target_type", [None, "continuous"]) +def test_get_response_values_regressor(target_type): """Check the behaviour of `_get_response_values` with regressor.""" X, y = make_regression(n_samples=10, random_state=0) regressor = LinearRegression().fit(X, y) @@ -762,6 +763,7 @@ def test_get_response_values_regressor(): X, y, response_method="predict", + target_type=target_type, ) assert_allclose(y_pred, regressor.predict(X)) assert pos_label is None @@ -777,7 +779,7 @@ def test_get_response_values_classifier_unknown_pos_label(response_method): X, y = make_classification(n_samples=10, n_classes=2, random_state=0) classifier = LogisticRegression().fit(X, y) - # provide a `pos_labe` which is not in `y` + # provide a `pos_label` which is not in `y` err_msg = r"pos_label=whatever is not a valid label: It should be one of \[0 1\]" with pytest.raises(ValueError, match=err_msg): _get_response_values( @@ -806,7 +808,8 @@ def test_get_response_values_classifier_inconsistent_y_pred_for_binary_proba(): ) -def test_get_response_values_binary_classifier_decision_function(): +@pytest.mark.parametrize("target_type", [None, "binary"]) +def test_get_response_values_binary_classifier_decision_function(target_type): """Check the behaviour of `_get_response_values` with `decision_function` and binary classifier. 
""" @@ -821,7 +824,12 @@ def test_get_response_values_binary_classifier_decision_function(): # default `pos_label` y_pred, pos_label = _get_response_values( - classifier, X, y, response_method=response_method, pos_label=None + classifier, + X, + y, + response_method=response_method, + pos_label=None, + target_type=target_type, ) assert_allclose(y_pred, classifier.decision_function(X)) assert pos_label == 1 @@ -833,12 +841,14 @@ def test_get_response_values_binary_classifier_decision_function(): y, response_method=response_method, pos_label=classifier.classes_[0], + target_type=target_type, ) assert_allclose(y_pred, classifier.decision_function(X) * -1) assert pos_label == 0 -def test_get_response_values_binary_classifier_predict_proba(): +@pytest.mark.parametrize("target_type", [None, "binary"]) +def test_get_response_values_binary_classifier_predict_proba(target_type): """Check that `_get_response_values` with `predict_proba` and binary classifier.""" X, y = make_classification( @@ -852,7 +862,12 @@ def test_get_response_values_binary_classifier_predict_proba(): # default `pos_label` y_pred, pos_label = _get_response_values( - classifier, X, y, response_method=response_method, pos_label=None + classifier, + X, + y, + response_method=response_method, + pos_label=None, + target_type=target_type, ) assert_allclose(y_pred, classifier.predict_proba(X)[:, 1]) assert pos_label == 1 @@ -864,6 +879,7 @@ def test_get_response_values_binary_classifier_predict_proba(): y, response_method=response_method, pos_label=classifier.classes_[0], + target_type=target_type, ) assert_allclose(y_pred, classifier.predict_proba(X)[:, 0]) assert pos_label == 0 From 34877016b3bf837840b7cc4feeb9bdc1dff0b9c0 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 29 Nov 2021 17:47:03 +0100 Subject: [PATCH 11/12] iter --- sklearn/metrics/_plot/tests/test_base.py | 41 +++++++++++++++++++++++- 1 file changed, 40 insertions(+), 1 deletion(-) diff --git a/sklearn/metrics/_plot/tests/test_base.py b/sklearn/metrics/_plot/tests/test_base.py index cd4ced34cc358..9cc8a7152d8d7 100644 --- a/sklearn/metrics/_plot/tests/test_base.py +++ b/sklearn/metrics/_plot/tests/test_base.py @@ -11,29 +11,68 @@ @pytest.mark.parametrize( - "estimator, target, err_type, err_msg", + "estimator, target, target_type, err_type, err_msg", [ ( DecisionTreeClassifier(), y_binary, + None, + NotFittedError, + "This DecisionTreeClassifier instance is not fitted yet", + ), + ( + DecisionTreeClassifier(), + y_binary, + "binary", NotFittedError, "This DecisionTreeClassifier instance is not fitted yet", ), ( DecisionTreeRegressor().fit(X_binary, y_binary), y_binary, + None, ValueError, "This plotting functionalities only support a binary classifier", ), + ( + DecisionTreeRegressor().fit(X_binary, y_binary), + y_binary, + "binary", + ValueError, + "This plotting functionalities only support a binary classifier", + ), + ( + DecisionTreeClassifier().fit(X, y), + y, + None, + ValueError, + "This DecisionTreeClassifier instance is not a binary classifier", + ), ( DecisionTreeClassifier().fit(X, y), y, + "multiclass", ValueError, "This DecisionTreeClassifier instance is not a binary classifier", ), + ( + DecisionTreeClassifier().fit(X, y), + y_binary, + "multiclass", + ValueError, + "This DecisionTreeClassifier instance is not a binary classifier", + ), + ( + DecisionTreeClassifier().fit(X_binary, y_binary), + y, + None, + ValueError, + "The target y is not binary", + ), ( DecisionTreeClassifier().fit(X_binary, y_binary), y, + "multiclass", ValueError, "The 
target y is not binary", ), From 2920e32428376d8897ca57f1a49635e2672e0d9d Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 29 Nov 2021 17:47:48 +0100 Subject: [PATCH 12/12] iter --- sklearn/metrics/_plot/tests/test_base.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sklearn/metrics/_plot/tests/test_base.py b/sklearn/metrics/_plot/tests/test_base.py index 9cc8a7152d8d7..fd8f5f6b0b337 100644 --- a/sklearn/metrics/_plot/tests/test_base.py +++ b/sklearn/metrics/_plot/tests/test_base.py @@ -78,7 +78,9 @@ ), ], ) -def test_check_estimator_and_target_is_binary(estimator, target, err_type, err_msg): +def test_check_estimator_and_target_is_binary( + estimator, target, target_type, err_type, err_msg +): """Check that we raise the expected error when checking the estimator and target.""" with pytest.raises(err_type, match=err_msg): - _check_estimator_and_target_is_binary(estimator, target) + _check_estimator_and_target_is_binary(estimator, target, target_type)
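
Note on PATCH 06, which removes `response_method=None`: the callers now expand
"auto" into an explicit preference list before calling the helper, and
`_check_response_method` simply returns the first method that the estimator
implements, raising the AttributeError asserted in the updated tests. A
minimal self-contained sketch of that resolution logic (the names
`resolve_response_method` and `OnlyDecisionFunction` are made up for
illustration, this is not the library code):

def resolve_response_method(estimator, response_method):
    # Accept a single method name or an ordered list of names.
    methods = (
        [response_method]
        if isinstance(response_method, str)
        else list(response_method)
    )
    for name in methods:
        method = getattr(estimator, name, None)
        if method is not None:
            return method
    raise AttributeError(
        f"{estimator.__class__.__name__} has none of the following "
        f"attributes: {', '.join(methods)}."
    )

class OnlyDecisionFunction:
    def decision_function(self, X):
        return [0.0 for _ in X]

# "auto" is expanded by the plotting callers into an explicit preference
# list before reaching the helper, so the helper never sees None or "auto".
chosen = resolve_response_method(
    OnlyDecisionFunction(), ["predict_proba", "decision_function"]
)
assert chosen.__name__ == "decision_function"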
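Note on the `target_type` threading introduced in PATCH 09: the entry points
(`from_estimator` and the deprecated `plot_*` functions) call
`type_of_target(y)` once and forward the result to both
`_check_estimator_and_target_is_binary` and `_get_response_values`, so the
target is inspected a single time per call. A minimal sketch of the caching
pattern (`check_target_is_binary` is an illustrative stand-in, not the
private helper changed above):

from sklearn.utils.multiclass import type_of_target

def check_target_is_binary(y, target_type=None):
    # Only inspect y when the caller has not already done so.
    if target_type is None:
        target_type = type_of_target(y)
    if target_type != "binary":
        raise ValueError(
            f"The target y is not binary. Got {target_type} type of target."
        )
    return target_type

y = [0, 1, 1, 0]
target_type = type_of_target(y)          # computed once at the entry point
check_target_is_binary(y, target_type)   # downstream helpers reuse the result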
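Note on the `pos_label` behaviour pinned down by the tests updated in
PATCH 10: for `predict_proba` the probability column matching `pos_label` is
kept, while for `decision_function` the scores are negated when `pos_label`
is `classes_[0]`. A condensed sketch, assuming a fitted binary classifier's
raw output (`select_binary_response` is a hypothetical name):

import numpy as np

def select_binary_response(y_pred, classes, pos_label, method_name):
    # Reduce a binary classifier's raw output to scores for pos_label.
    y_pred = np.asarray(y_pred)
    if pos_label is None:
        pos_label = classes[-1]  # default positive class
    if method_name == "predict_proba":
        # keep the probability column of the requested positive class
        col_idx = np.flatnonzero(classes == pos_label)[0]
        return y_pred[:, col_idx], pos_label
    if method_name == "decision_function" and pos_label == classes[0]:
        # raw scores are oriented towards classes_[1]; flip for the other class
        return -y_pred, pos_label
    return y_pred, pos_label

proba = np.array([[0.8, 0.2], [0.3, 0.7]])
classes = np.array([0, 1])
scores, pos = select_binary_response(proba, classes, 0, "predict_proba")
assert np.allclose(scores, [0.8, 0.3]) and pos == 0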