From 5a1123a52d1a73695d2c09f94461295356deb67f Mon Sep 17 00:00:00 2001 From: "wishyut.pitawanik" Date: Thu, 2 Mar 2023 23:14:45 +0100 Subject: [PATCH 1/2] added parameter validation for metrics.coverage_error --- sklearn/metrics/_ranking.py | 11 +++++++++-- sklearn/tests/test_public_functions.py | 1 + 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/sklearn/metrics/_ranking.py b/sklearn/metrics/_ranking.py index 7ec583177328c..4b8202d60d53e 100644 --- a/sklearn/metrics/_ranking.py +++ b/sklearn/metrics/_ranking.py @@ -1165,6 +1165,13 @@ def label_ranking_average_precision_score(y_true, y_score, *, sample_weight=None return out +@validate_params( + { + "y_true": ["array-like"], + "y_score": ["array-like"], + "sample_weight": ["array-like", None], + } +) def coverage_error(y_true, y_score, *, sample_weight=None): """Coverage error measure. @@ -1183,10 +1190,10 @@ def coverage_error(y_true, y_score, *, sample_weight=None): Parameters ---------- - y_true : ndarray of shape (n_samples, n_labels) + y_true : array-like of shape (n_samples, n_labels) True binary labels in binary indicator format. - y_score : ndarray of shape (n_samples, n_labels) + y_score : array-like of shape (n_samples, n_labels) Target scores, can either be probability estimates of the positive class, confidence values, or non-thresholded measure of decisions (as returned by "decision_function" on some classifiers). diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index 8729ce1f0869e..12bd3d367854e 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -121,6 +121,7 @@ def _check_function_param_validation( "sklearn.metrics.balanced_accuracy_score", "sklearn.metrics.cluster.contingency_matrix", "sklearn.metrics.cohen_kappa_score", + "sklearn.metrics.coverage_error", "sklearn.metrics.confusion_matrix", "sklearn.metrics.d2_pinball_score", "sklearn.metrics.det_curve", From cd8fb32920e67e53357a9de0addd92405efe11b3 Mon Sep 17 00:00:00 2001 From: "wishyut.pitawanik" Date: Thu, 2 Mar 2023 23:19:56 +0100 Subject: [PATCH 2/2] fixed alphabetical order --- sklearn/tests/test_public_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index 12bd3d367854e..836139b85b341 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -121,8 +121,8 @@ def _check_function_param_validation( "sklearn.metrics.balanced_accuracy_score", "sklearn.metrics.cluster.contingency_matrix", "sklearn.metrics.cohen_kappa_score", - "sklearn.metrics.coverage_error", "sklearn.metrics.confusion_matrix", + "sklearn.metrics.coverage_error", "sklearn.metrics.d2_pinball_score", "sklearn.metrics.det_curve", "sklearn.metrics.f1_score",