Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

[API, MAINT] Deprecate usage of y_prob and probas_pred in sklearn.metrics #28092

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 29 commits into from
Feb 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
8dacf5e
Adding api deprecation
adam2392 Jan 10, 2024
91fe2f5
Update pr number
adam2392 Jan 10, 2024
213f9fb
Fix lint
adam2392 Jan 10, 2024
16ff269
Merge branch 'main' into ndim
adam2392 Jan 11, 2024
19e92c5
Fix warnings
adam2392 Jan 11, 2024
60f303f
Merge branch 'ndim' of https://github.com/adam2392/scikit-learn into …
adam2392 Jan 11, 2024
7272f57
Merge branch 'main' into ndim
adam2392 Jan 18, 2024
06917fa
Merge branch 'main' into ndim
adam2392 Jan 19, 2024
cec7fff
Merge branch 'main' into ndim
adam2392 Jan 22, 2024
ee67505
Merge branch 'main' into ndim
adam2392 Jan 25, 2024
5073f51
Apply suggestions from code review
adam2392 Jan 25, 2024
fda3ef0
Merge branch 'ndim' of https://github.com/adam2392/scikit-learn into …
adam2392 Jan 25, 2024
5c0ecca
Address comments
adam2392 Jan 25, 2024
c279517
Move changelog
adam2392 Jan 25, 2024
d6be8ab
Move changelog
adam2392 Jan 25, 2024
ee4b64d
1.7
adam2392 Jan 25, 2024
7e6bfa1
Merge branch 'main' into ndim
adam2392 Jan 25, 2024
3ded4e6
Allow correct comparison
adam2392 Jan 25, 2024
538b658
Merge branch 'ndim' of https://github.com/adam2392/scikit-learn into …
adam2392 Jan 25, 2024
397cb2f
Merge branch 'main' into ndim
adam2392 Jan 29, 2024
f85eb0e
Merge branch 'main' into ndim
adam2392 Feb 1, 2024
d6f3b56
Merge branch 'main' into ndim
adam2392 Feb 13, 2024
2cc25be
Merge branch 'main' into ndim
adam2392 Feb 15, 2024
8bb3c50
Apply suggestions from code review
adam2392 Feb 16, 2024
a6e7eea
Fix unit test
adam2392 Feb 16, 2024
545013d
Fix unit test
adam2392 Feb 16, 2024
ba59056
Merge branch 'main' into ndim
adam2392 Feb 16, 2024
50dea0f
Fix unit tests
adam2392 Feb 16, 2024
62dc74e
Fix lint' -s
adam2392 Feb 16, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions doc/whats_new/v1.5.rst
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,14 @@ Changelog
:class:`~calibration.CalibrationDisplay`.
:pr:`28051` by :user:`Pierre de Fréminville <pidefrem>`.

- |API| :func:`metrics.precision_recall_curve` deprecated the keyword argument `probas_pred`
in favor of `y_score`. `probas_pred` will be removed in version 1.7.
:pr:`28092` by :user:`Adam Li <adam2392>`.

- |API| :func:`metrics.brier_score_loss` deprecated the keyword argument `y_prob`
in favor of `y_proba`. `y_prob` will be removed in version 1.7.
:pr:`28092` by :user:`Adam Li <adam2392>`.

:mod:`sklearn.model_selection`
..............................

Expand Down
58 changes: 46 additions & 12 deletions sklearn/metrics/_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,13 @@
column_or_1d,
)
from ..utils._array_api import _union1d, _weighted_sum, get_namespace
from ..utils._param_validation import Interval, Options, StrOptions, validate_params
from ..utils._param_validation import (
Hidden,
Interval,
Options,
StrOptions,
validate_params,
)
from ..utils.extmath import _nanaverage
from ..utils.multiclass import type_of_target, unique_labels
from ..utils.sparsefuncs import count_nonzero
Expand Down Expand Up @@ -3146,13 +3152,16 @@ def hinge_loss(y_true, pred_decision, *, labels=None, sample_weight=None):
@validate_params(
{
"y_true": ["array-like"],
"y_prob": ["array-like"],
"y_proba": ["array-like", Hidden(None)],
"sample_weight": ["array-like", None],
"pos_label": [Real, str, "boolean", None],
"y_prob": ["array-like", Hidden(StrOptions({"deprecated"}))],
},
prefer_skip_nested_validation=True,
)
def brier_score_loss(y_true, y_prob, *, sample_weight=None, pos_label=None):
def brier_score_loss(
y_true, y_proba=None, *, sample_weight=None, pos_label=None, y_prob="deprecated"
):
"""Compute the Brier score loss.

The smaller the Brier score loss, the better, hence the naming with "loss".
Expand Down Expand Up @@ -3180,7 +3189,7 @@ def brier_score_loss(y_true, y_prob, *, sample_weight=None, pos_label=None):
y_true : array-like of shape (n_samples,)
True targets.

y_prob : array-like of shape (n_samples,)
y_proba : array-like of shape (n_samples,)
Probabilities of the positive class.

sample_weight : array-like of shape (n_samples,), default=None
Expand All @@ -3196,6 +3205,13 @@ def brier_score_loss(y_true, y_prob, *, sample_weight=None, pos_label=None):
* otherwise, `pos_label` defaults to the greater label,
i.e. `np.unique(y_true)[-1]`.

y_prob : array-like of shape (n_samples,)
Probabilities of the positive class.

.. deprecated:: 1.5
`y_prob` is deprecated and will be removed in 1.7. Use
`y_proba` instead.

Returns
-------
score : float
Expand All @@ -3222,11 +3238,29 @@ def brier_score_loss(y_true, y_prob, *, sample_weight=None, pos_label=None):
>>> brier_score_loss(y_true, np.array(y_prob) > 0.5)
0.0
"""
# TODO(1.7): remove in 1.7 and reset y_proba to be required
# Deprecation shim: resolve the legacy `y_prob` keyword into `y_proba`.
# Note: validate_params has already rejected any `y_prob` that is neither
# array-like nor the "deprecated" sentinel string, so a `str` here means the
# caller did not pass `y_prob`.
if y_proba is not None and not isinstance(y_prob, str):
    raise ValueError(
        "`y_prob` and `y_proba` cannot be both specified. Please use `y_proba` only"
        " as `y_prob` is deprecated in v1.5 and will be removed in v1.7."
    )
if y_proba is None:
    if isinstance(y_prob, str):
        # Neither argument was supplied: fail explicitly instead of warning
        # about a keyword the caller never used and then forwarding the
        # sentinel string as if it were data.
        raise TypeError("brier_score_loss() missing 1 required argument: 'y_proba'")
    warnings.warn(
        (
            # The trailing space keeps the two sentences apart once the
            # adjacent literals are concatenated.
            "y_prob was deprecated in version 1.5 and will be removed in 1.7. "
            "Please use ``y_proba`` instead."
        ),
        FutureWarning,
    )
    y_proba = y_prob

y_true = column_or_1d(y_true)
y_prob = column_or_1d(y_prob)
y_proba = column_or_1d(y_proba)
assert_all_finite(y_true)
assert_all_finite(y_prob)
check_consistent_length(y_true, y_prob, sample_weight)
assert_all_finite(y_proba)
check_consistent_length(y_true, y_proba, sample_weight)

y_type = type_of_target(y_true, input_name="y_true")
if y_type != "binary":
Expand All @@ -3235,10 +3269,10 @@ def brier_score_loss(y_true, y_prob, *, sample_weight=None, pos_label=None):
f"is {y_type}."
)

if y_prob.max() > 1:
raise ValueError("y_prob contains values greater than 1.")
if y_prob.min() < 0:
raise ValueError("y_prob contains values less than 0.")
if y_proba.max() > 1:
raise ValueError("y_proba contains values greater than 1.")
if y_proba.min() < 0:
raise ValueError("y_proba contains values less than 0.")

try:
pos_label = _check_pos_label_consistency(pos_label, y_true)
Expand All @@ -3251,4 +3285,4 @@ def brier_score_loss(y_true, y_prob, *, sample_weight=None, pos_label=None):
else:
raise
y_true = np.array(y_true == pos_label, int)
return np.average((y_true - y_prob) ** 2, weights=sample_weight)
return np.average((y_true - y_proba) ** 2, weights=sample_weight)
47 changes: 42 additions & 5 deletions sklearn/metrics/_ranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
column_or_1d,
)
from ..utils._encode import _encode, _unique
from ..utils._param_validation import Interval, StrOptions, validate_params
from ..utils._param_validation import Hidden, Interval, StrOptions, validate_params
from ..utils.extmath import stable_cumsum
from ..utils.fixes import trapezoid
from ..utils.multiclass import type_of_target
Expand Down Expand Up @@ -865,15 +865,25 @@ def _binary_clf_curve(y_true, y_score, pos_label=None, sample_weight=None):
@validate_params(
{
"y_true": ["array-like"],
"probas_pred": ["array-like"],
"y_score": ["array-like", Hidden(None)],
"pos_label": [Real, str, "boolean", None],
"sample_weight": ["array-like", None],
"drop_intermediate": ["boolean"],
"probas_pred": [
"array-like",
Hidden(StrOptions({"deprecated"})),
],
},
prefer_skip_nested_validation=True,
)
def precision_recall_curve(
y_true, probas_pred, *, pos_label=None, sample_weight=None, drop_intermediate=False
y_true,
y_score=None,
*,
pos_label=None,
sample_weight=None,
drop_intermediate=False,
probas_pred="deprecated",
):
"""Compute precision-recall pairs for different probability thresholds.
Expand Down Expand Up @@ -903,7 +913,7 @@ def precision_recall_curve(
True binary labels. If labels are not either {-1, 1} or {0, 1}, then
pos_label should be explicitly given.
probas_pred : array-like of shape (n_samples,)
y_score : array-like of shape (n_samples,)
Target scores, can either be probability estimates of the positive
class, or non-thresholded measure of decisions (as returned by
`decision_function` on some classifiers).
Expand All @@ -923,6 +933,15 @@ def precision_recall_curve(
.. versionadded:: 1.3
probas_pred : array-like of shape (n_samples,)
Target scores, can either be probability estimates of the positive
class, or non-thresholded measure of decisions (as returned by
`decision_function` on some classifiers).
.. deprecated:: 1.5
`probas_pred` is deprecated and will be removed in 1.7. Use
`y_score` instead.
Returns
-------
precision : ndarray of shape (n_thresholds + 1,)
Expand Down Expand Up @@ -962,8 +981,26 @@ def precision_recall_curve(
>>> thresholds
array([0.1 , 0.35, 0.4 , 0.8 ])
"""
# TODO(1.7): remove in 1.7 and reset y_score to be required
# Deprecation shim: resolve the legacy `probas_pred` keyword into `y_score`.
# Note: validate_params has already rejected any `probas_pred` that is neither
# array-like nor the "deprecated" sentinel string, so a `str` here means the
# caller did not pass `probas_pred`.
if y_score is not None and not isinstance(probas_pred, str):
    raise ValueError(
        "`probas_pred` and `y_score` cannot be both specified. Please use `y_score`"
        " only as `probas_pred` is deprecated in v1.5 and will be removed in v1.7."
    )
if y_score is None:
    if isinstance(probas_pred, str):
        # Neither argument was supplied: fail explicitly instead of warning
        # about a keyword the caller never used and then forwarding the
        # sentinel string as if it were scores.
        raise TypeError(
            "precision_recall_curve() missing 1 required argument: 'y_score'"
        )
    warnings.warn(
        (
            # The trailing space keeps the two sentences apart once the
            # adjacent literals are concatenated.
            "probas_pred was deprecated in version 1.5 and will be removed in 1.7. "
            "Please use ``y_score`` instead."
        ),
        FutureWarning,
    )
    y_score = probas_pred

fps, tps, thresholds = _binary_clf_curve(
y_true, probas_pred, pos_label=pos_label, sample_weight=sample_weight
y_true, y_score, pos_label=pos_label, sample_weight=sample_weight
)

if drop_intermediate and len(fps) > 2:
Expand Down
29 changes: 26 additions & 3 deletions sklearn/metrics/tests/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,16 +88,16 @@ def make_prediction(dataset=None, binary=False):

# run classifier, get class probabilities and label predictions
clf = svm.SVC(kernel="linear", probability=True, random_state=0)
probas_pred = clf.fit(X[:half], y[:half]).predict_proba(X[half:])
y_pred_proba = clf.fit(X[:half], y[:half]).predict_proba(X[half:])

if binary:
# only interested in probabilities of the positive case
# XXX: do we really want a special API for the binary case?
probas_pred = probas_pred[:, 1]
y_pred_proba = y_pred_proba[:, 1]

y_pred = clf.predict(X[half:])
y_true = y[half:]
return y_true, y_pred, probas_pred
return y_true, y_pred, y_pred_proba


###############################################################################
Expand Down Expand Up @@ -2864,3 +2864,26 @@ def test_classification_metric_division_by_zero_nan_validaton(scoring):
X, y = datasets.make_classification(random_state=0)
classifier = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X, y)
cross_val_score(classifier, X, y, scoring=scoring, n_jobs=2, error_score="raise")


# TODO(1.7): remove
def test_brier_score_loss_deprecation_warning():
    """Exercise the deprecated `y_prob` keyword of `brier_score_loss`.

    Passing `y_prob` on its own must emit a `FutureWarning`; passing it
    together with `y_proba` must raise a `ValueError`.
    """
    targets = np.array([0, 1, 1, 0, 1, 1])
    probabilities = np.array([0.1, 0.8, 0.9, 0.3, 1.0, 0.95])

    # The legacy keyword is still accepted but announces its removal.
    expected_warning = "y_prob was deprecated in version 1.5"
    with pytest.warns(FutureWarning, match=expected_warning):
        brier_score_loss(targets, y_prob=probabilities)

    # Supplying both the legacy and the new keyword is rejected.
    expected_error = "`y_prob` and `y_proba` cannot be both specified"
    with pytest.raises(ValueError, match=expected_error):
        brier_score_loss(targets, y_prob=probabilities, y_proba=probabilities)
22 changes: 22 additions & 0 deletions sklearn/metrics/tests/test_ranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -2242,3 +2242,25 @@ def test_roc_curve_with_probablity_estimates(global_random_seed):
y_score = rng.rand(10)
_, _, thresholds = roc_curve(y_true, y_score)
assert np.isinf(thresholds[0])


# TODO(1.7): remove
def test_precision_recall_curve_deprecation_warning():
    """Exercise the deprecated `probas_pred` keyword of `precision_recall_curve`.

    Passing `probas_pred` on its own must emit a `FutureWarning`; passing it
    together with `y_score` must raise a `ValueError`.
    """
    # Only the true labels and the third returned value are needed here.
    labels, _, scores = make_prediction(binary=True)

    # The legacy keyword is still accepted but announces its removal.
    expected_warning = "probas_pred was deprecated in version 1.5"
    with pytest.warns(FutureWarning, match=expected_warning):
        precision_recall_curve(labels, probas_pred=scores)

    # Supplying both the legacy and the new keyword is rejected.
    expected_error = "`probas_pred` and `y_score` cannot be both specified"
    with pytest.raises(ValueError, match=expected_error):
        precision_recall_curve(labels, probas_pred=scores, y_score=scores)