-
-
Notifications
You must be signed in to change notification settings - Fork 26k
MNT refactor _get_response_values #21538
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
51a3bcb
8675def
64dfc86
5a4f33f
a7e46da
a69176d
8214f8b
d7e4912
75bd0e1
04b54c9
9e834b4
3487701
2920e32
8863fad
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -30,8 +30,11 @@ | |||||||||||||||
from ..utils import Bunch | ||||||||||||||||
from ..utils.metaestimators import available_if | ||||||||||||||||
from ..utils.multiclass import check_classification_targets | ||||||||||||||||
from ..utils.validation import check_is_fitted | ||||||||||||||||
from ..utils.validation import column_or_1d | ||||||||||||||||
from ..utils.validation import ( | ||||||||||||||||
_check_response_method, | ||||||||||||||||
check_is_fitted, | ||||||||||||||||
column_or_1d, | ||||||||||||||||
) | ||||||||||||||||
from ..utils.fixes import delayed | ||||||||||||||||
from ..utils.validation import _check_feature_names_in | ||||||||||||||||
|
||||||||||||||||
|
@@ -120,21 +123,18 @@ def _concatenate_predictions(self, X, predictions): | |||||||||||||||
def _method_name(name, estimator, method): | ||||||||||||||||
if estimator == "drop": | ||||||||||||||||
return None | ||||||||||||||||
if method == "auto": | ||||||||||||||||
if getattr(estimator, "predict_proba", None): | ||||||||||||||||
return "predict_proba" | ||||||||||||||||
elif getattr(estimator, "decision_function", None): | ||||||||||||||||
return "decision_function" | ||||||||||||||||
else: | ||||||||||||||||
return "predict" | ||||||||||||||||
else: | ||||||||||||||||
if not hasattr(estimator, method): | ||||||||||||||||
raise ValueError( | ||||||||||||||||
"Underlying estimator {} does not implement the method {}.".format( | ||||||||||||||||
name, method | ||||||||||||||||
) | ||||||||||||||||
) | ||||||||||||||||
return method | ||||||||||||||||
method = ( | ||||||||||||||||
["predict_proba", "decision_function", "predict"] | ||||||||||||||||
if method == "auto" | ||||||||||||||||
else method | ||||||||||||||||
) | ||||||||||||||||
Comment on lines
+126
to
+130
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A bit lighter to parse IMO (and also what is used in other places in this PR):
Suggested change
|
||||||||||||||||
try: | ||||||||||||||||
method_name = _check_response_method(estimator, method).__name__ | ||||||||||||||||
except AttributeError as e: | ||||||||||||||||
raise ValueError( | ||||||||||||||||
f"Underlying estimator {name} does not implement the method {method}." | ||||||||||||||||
) from e | ||||||||||||||||
return method_name | ||||||||||||||||
|
||||||||||||||||
def fit(self, X, y, sample_weight=None): | ||||||||||||||||
"""Fit the estimators. | ||||||||||||||||
|
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
@@ -1,116 +1,56 @@ | ||||||||||
from ...base import is_classifier | ||||||||||
from ...exceptions import NotFittedError | ||||||||||
from ...utils.multiclass import type_of_target | ||||||||||
from ...utils.validation import check_is_fitted | ||||||||||
|
||||||||||
|
||||||||||
def _check_classifier_response_method(estimator, response_method): | ||||||||||
"""Return prediction method from the response_method | ||||||||||
def _check_estimator_and_target_is_binary(estimator, y, target_type=None): | ||||||||||
"""Helper to check that estimator is a binary classifier and y is binary. | ||||||||||
|
||||||||||
Parameters | ||||||||||
---------- | ||||||||||
estimator: object | ||||||||||
Classifier to check | ||||||||||
|
||||||||||
response_method: {'auto', 'predict_proba', 'decision_function'} | ||||||||||
Specifies whether to use :term:`predict_proba` or | ||||||||||
:term:`decision_function` as the target response. If set to 'auto', | ||||||||||
:term:`predict_proba` is tried first and if it does not exist | ||||||||||
:term:`decision_function` is tried next. | ||||||||||
|
||||||||||
Returns | ||||||||||
------- | ||||||||||
prediction_method: callable | ||||||||||
prediction method of estimator | ||||||||||
""" | ||||||||||
|
||||||||||
if response_method not in ("predict_proba", "decision_function", "auto"): | ||||||||||
raise ValueError( | ||||||||||
"response_method must be 'predict_proba', 'decision_function' or 'auto'" | ||||||||||
) | ||||||||||
|
||||||||||
error_msg = "response method {} is not defined in {}" | ||||||||||
if response_method != "auto": | ||||||||||
prediction_method = getattr(estimator, response_method, None) | ||||||||||
if prediction_method is None: | ||||||||||
raise ValueError( | ||||||||||
error_msg.format(response_method, estimator.__class__.__name__) | ||||||||||
) | ||||||||||
else: | ||||||||||
predict_proba = getattr(estimator, "predict_proba", None) | ||||||||||
decision_function = getattr(estimator, "decision_function", None) | ||||||||||
prediction_method = predict_proba or decision_function | ||||||||||
if prediction_method is None: | ||||||||||
raise ValueError( | ||||||||||
error_msg.format( | ||||||||||
"decision_function or predict_proba", estimator.__class__.__name__ | ||||||||||
) | ||||||||||
) | ||||||||||
|
||||||||||
return prediction_method | ||||||||||
|
||||||||||
|
||||||||||
def _get_response(X, estimator, response_method, pos_label=None): | ||||||||||
"""Return response and positive label. | ||||||||||
|
||||||||||
Parameters | ||||||||||
---------- | ||||||||||
X : {array-like, sparse matrix} of shape (n_samples, n_features) | ||||||||||
Input values. | ||||||||||
|
||||||||||
estimator : estimator instance | ||||||||||
Fitted classifier or a fitted :class:`~sklearn.pipeline.Pipeline` | ||||||||||
in which the last estimator is a classifier. | ||||||||||
|
||||||||||
response_method: {'auto', 'predict_proba', 'decision_function'} | ||||||||||
Specifies whether to use :term:`predict_proba` or | ||||||||||
:term:`decision_function` as the target response. If set to 'auto', | ||||||||||
:term:`predict_proba` is tried first and if it does not exist | ||||||||||
:term:`decision_function` is tried next. | ||||||||||
|
||||||||||
pos_label : str or int, default=None | ||||||||||
The class considered as the positive class when computing | ||||||||||
the metrics. By default, `estimators.classes_[1]` is | ||||||||||
considered as the positive class. | ||||||||||
|
||||||||||
Returns | ||||||||||
------- | ||||||||||
y_pred: ndarray of shape (n_samples,) | ||||||||||
Target scores calculated from the provided response_method | ||||||||||
and pos_label. | ||||||||||
|
||||||||||
pos_label: str or int | ||||||||||
The class considered as the positive class when computing | ||||||||||
the metrics. | ||||||||||
An estimator that should be used to predict the target. | ||||||||||
|
||||||||||
y : ndarray | ||||||||||
The associated target. | ||||||||||
|
||||||||||
target_type : str, default=None | ||||||||||
The type of the target `y` as returned by | ||||||||||
:func:`~sklearn.utils.multiclass.type_of_target`. If `None`, the type | ||||||||||
will be inferred by calling :func:`~sklearn.utils.multiclass.type_of_target`. | ||||||||||
Providing the type of the target could save time by avoid calling the | ||||||||||
:func:`~sklearn.utils.multiclass.type_of_target` function. | ||||||||||
|
||||||||||
Raises | ||||||||||
------ | ||||||||||
ValueError | ||||||||||
If the estimator or the target are not binary. | ||||||||||
""" | ||||||||||
classification_error = ( | ||||||||||
"Expected 'estimator' to be a binary classifier, but got" | ||||||||||
f" {estimator.__class__.__name__}" | ||||||||||
) | ||||||||||
try: | ||||||||||
check_is_fitted(estimator) | ||||||||||
except NotFittedError as e: | ||||||||||
raise NotFittedError( | ||||||||||
f"This {estimator.__class__.__name__} instance is not fitted yet. Call " | ||||||||||
"'fit' with appropriate arguments before intending to use it to plotting " | ||||||||||
"functionalities." | ||||||||||
Comment on lines
+35
to
+36
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe a bit lighter to read (better suggestion welcome):
Suggested change
|
||||||||||
) from e | ||||||||||
|
||||||||||
if not is_classifier(estimator): | ||||||||||
raise ValueError(classification_error) | ||||||||||
|
||||||||||
prediction_method = _check_classifier_response_method(estimator, response_method) | ||||||||||
y_pred = prediction_method(X) | ||||||||||
if pos_label is not None: | ||||||||||
try: | ||||||||||
class_idx = estimator.classes_.tolist().index(pos_label) | ||||||||||
except ValueError as e: | ||||||||||
raise ValueError( | ||||||||||
"The class provided by 'pos_label' is unknown. Got " | ||||||||||
f"{pos_label} instead of one of {set(estimator.classes_)}" | ||||||||||
) from e | ||||||||||
else: | ||||||||||
class_idx = 1 | ||||||||||
pos_label = estimator.classes_[class_idx] | ||||||||||
|
||||||||||
if y_pred.ndim != 1: # `predict_proba` | ||||||||||
y_pred_shape = y_pred.shape[1] | ||||||||||
if y_pred_shape != 2: | ||||||||||
raise ValueError( | ||||||||||
f"{classification_error} fit on multiclass ({y_pred_shape} classes)" | ||||||||||
" data" | ||||||||||
) | ||||||||||
y_pred = y_pred[:, class_idx] | ||||||||||
elif pos_label == estimator.classes_[0]: # `decision_function` | ||||||||||
y_pred *= -1 | ||||||||||
raise ValueError( | ||||||||||
"This plotting functionalities only support a binary classifier. " | ||||||||||
f"Got a {estimator.__class__.__name__} instead." | ||||||||||
) | ||||||||||
elif len(estimator.classes_) != 2: | ||||||||||
raise ValueError( | ||||||||||
f"This {estimator.__class__.__name__} instance is not a binary " | ||||||||||
"classifier. It was fitted on multiclass problem with " | ||||||||||
f"{len(estimator.classes_)} classes." | ||||||||||
) | ||||||||||
|
||||||||||
return y_pred, pos_label | ||||||||||
if target_type is None: | ||||||||||
target_type = type_of_target(y) | ||||||||||
if target_type != "binary": | ||||||||||
raise ValueError( | ||||||||||
f"The target y is not binary. Got {target_type} type of target." | ||||||||||
) |
Original file line number | Diff line number | Diff line change | ||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -1,12 +1,16 @@ | ||||||||||||
import scipy as sp | ||||||||||||
|
||||||||||||
from .base import _get_response | ||||||||||||
from .base import _check_estimator_and_target_is_binary | ||||||||||||
|
||||||||||||
from .. import det_curve | ||||||||||||
from .._base import _check_pos_label_consistency | ||||||||||||
|
||||||||||||
from ...utils import check_matplotlib_support | ||||||||||||
from ...utils import deprecated | ||||||||||||
from ...utils import ( | ||||||||||||
check_matplotlib_support, | ||||||||||||
deprecated, | ||||||||||||
_get_response_values, | ||||||||||||
) | ||||||||||||
from ...utils.multiclass import type_of_target | ||||||||||||
|
||||||||||||
|
||||||||||||
class DetCurveDisplay: | ||||||||||||
|
@@ -168,13 +172,20 @@ def from_estimator( | |||||||||||
""" | ||||||||||||
check_matplotlib_support(f"{cls.__name__}.from_estimator") | ||||||||||||
|
||||||||||||
target_type = type_of_target(y) | ||||||||||||
_check_estimator_and_target_is_binary(estimator, y, target_type=target_type) | ||||||||||||
if response_method == "auto": | ||||||||||||
response_method = ["predict_proba", "decision_function"] | ||||||||||||
|
||||||||||||
name = estimator.__class__.__name__ if name is None else name | ||||||||||||
|
||||||||||||
y_pred, pos_label = _get_response( | ||||||||||||
X, | ||||||||||||
y_pred, pos_label = _get_response_values( | ||||||||||||
estimator, | ||||||||||||
X, | ||||||||||||
y, | ||||||||||||
response_method, | ||||||||||||
pos_label=pos_label, | ||||||||||||
target_type=target_type, | ||||||||||||
) | ||||||||||||
|
||||||||||||
return cls.from_predictions( | ||||||||||||
|
@@ -265,6 +276,13 @@ def from_predictions( | |||||||||||
>>> plt.show() | ||||||||||||
""" | ||||||||||||
check_matplotlib_support(f"{cls.__name__}.from_predictions") | ||||||||||||
|
||||||||||||
target_type = type_of_target(y_true) | ||||||||||||
if target_type != "binary": | ||||||||||||
raise ValueError( | ||||||||||||
f"The target y is not binary. Got {target_type} type of target." | ||||||||||||
) | ||||||||||||
Comment on lines
+280
to
+284
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we need this check here, since we do the check in scikit-learn/sklearn/metrics/_ranking.py Lines 310 to 314 in 845b1fa
|
||||||||||||
|
||||||||||||
fpr, fnr, _ = det_curve( | ||||||||||||
y_true, | ||||||||||||
y_pred, | ||||||||||||
|
@@ -454,8 +472,13 @@ def plot_det_curve( | |||||||||||
""" | ||||||||||||
check_matplotlib_support("plot_det_curve") | ||||||||||||
|
||||||||||||
y_pred, pos_label = _get_response( | ||||||||||||
X, estimator, response_method, pos_label=pos_label | ||||||||||||
target_type = type_of_target(y) | ||||||||||||
_check_estimator_and_target_is_binary(estimator, y, target_type=target_type) | ||||||||||||
if response_method == "auto": | ||||||||||||
response_method = ["predict_proba", "decision_function"] | ||||||||||||
|
||||||||||||
y_pred, pos_label = _get_response_values( | ||||||||||||
estimator, X, y, response_method, pos_label=pos_label, target_type=target_type | ||||||||||||
) | ||||||||||||
|
||||||||||||
fpr, fnr, _ = det_curve( | ||||||||||||
|
Uh oh!
There was an error while loading. Please reload this page.