From fe4716b48c961b5e530649a1227094dc8689622c Mon Sep 17 00:00:00 2001 From: Mikhail Iljin Date: Sun, 4 Dec 2022 00:20:48 +0000 Subject: [PATCH 1/4] Add parameter validation for metrics.roc_curve --- sklearn/metrics/_ranking.py | 10 ++++++++++ sklearn/tests/test_public_functions.py | 1 + 2 files changed, 11 insertions(+) diff --git a/sklearn/metrics/_ranking.py b/sklearn/metrics/_ranking.py index 4b5451d768e9e..49a861cf6f304 100644 --- a/sklearn/metrics/_ranking.py +++ b/sklearn/metrics/_ranking.py @@ -21,6 +21,7 @@ import warnings from functools import partial +from numbers import Integral import numpy as np from scipy.sparse import csr_matrix, issparse @@ -903,6 +904,15 @@ def precision_recall_curve(y_true, probas_pred, *, pos_label=None, sample_weight return np.hstack((precision[sl], 1)), np.hstack((recall[sl], 0)), thresholds[sl] +@validate_params( + { + "y_true": ["array-like", "sparse matrix"], + "y_score": ["array-like"], + "pos_label": [Integral, str, None], + "sample_weight": ["array-like", None], + "drop_intermediate": ["boolean"] + } +) def roc_curve( y_true, y_score, *, pos_label=None, sample_weight=None, drop_intermediate=True ): diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index 62f9a14a59614..a27b6c4c2a4aa 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -17,6 +17,7 @@ "sklearn.metrics.auc", "sklearn.metrics.mean_absolute_error", "sklearn.metrics.zero_one_loss", + "sklearn.metrics.roc_curve", "sklearn.svm.l1_min_c", ] From 935dbcd7b3d9a3d3a9c5cb344fcbe5db69e2e7d1 Mon Sep 17 00:00:00 2001 From: Mikhail Iljin Date: Mon, 5 Dec 2022 14:49:54 +0000 Subject: [PATCH 2/4] fix styling issue --- sklearn/metrics/_ranking.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/metrics/_ranking.py b/sklearn/metrics/_ranking.py index 49a861cf6f304..28f491d54a67b 100644 --- a/sklearn/metrics/_ranking.py +++ b/sklearn/metrics/_ranking.py @@ -910,7 +910,7 @@ def precision_recall_curve(y_true, probas_pred, *, pos_label=None, sample_weight "y_score": ["array-like"], "pos_label": [Integral, str, None], "sample_weight": ["array-like", None], - "drop_intermediate": ["boolean"] + "drop_intermediate": ["boolean"], } ) def roc_curve( From 4344c491fbaaa6ffc126ea4746cd5e6a5c994b34 Mon Sep 17 00:00:00 2001 From: Mikhail Iljin Date: Tue, 6 Dec 2022 18:47:32 +0000 Subject: [PATCH 3/4] changed types in docstring, removed wrong type "sparse matrix" from validation --- sklearn/metrics/_ranking.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sklearn/metrics/_ranking.py b/sklearn/metrics/_ranking.py index 28f491d54a67b..72553a52c539a 100644 --- a/sklearn/metrics/_ranking.py +++ b/sklearn/metrics/_ranking.py @@ -906,7 +906,7 @@ def precision_recall_curve(y_true, probas_pred, *, pos_label=None, sample_weight @validate_params( { - "y_true": ["array-like", "sparse matrix"], + "y_true": ["array-like"], "y_score": ["array-like"], "pos_label": [Integral, str, None], "sample_weight": ["array-like", None], @@ -924,11 +924,11 @@ def roc_curve( Parameters ---------- - y_true : ndarray of shape (n_samples,) + y_true : array-like of shape (n_samples,) True binary labels. If labels are not either {-1, 1} or {0, 1}, then pos_label should be explicitly given. - y_score : ndarray of shape (n_samples,) + y_score : array-like of shape (n_samples,) Target scores, can either be probability estimates of the positive class, confidence values, or non-thresholded measure of decisions (as returned by "decision_function" on some classifiers). From 1e670279e371be0a01336f513d01c055c79ba143 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=A9mie=20du=20Boisberranger?= <34657725+jeremiedbb@users.noreply.github.com> Date: Tue, 6 Dec 2022 23:26:45 +0100 Subject: [PATCH 4/4] Update sklearn/tests/test_public_functions.py --- sklearn/tests/test_public_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index dd63d2033b930..ff42011427b83 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -101,8 +101,8 @@ def _check_function_param_validation( "sklearn.metrics.accuracy_score", "sklearn.metrics.auc", "sklearn.metrics.mean_absolute_error", - "sklearn.metrics.zero_one_loss", "sklearn.metrics.roc_curve", + "sklearn.metrics.zero_one_loss", "sklearn.model_selection.train_test_split", "sklearn.svm.l1_min_c", ]