From 0b6834d8d5304d80f7a67d7ab662eb630a055d86 Mon Sep 17 00:00:00 2001
From: dlitsidis
Date: Thu, 29 Dec 2022 14:24:54 +0200
Subject: [PATCH 1/9] param validation for multilabel_confusion_matrix

---
 sklearn/metrics/_classification.py     | 9 +++++++++
 sklearn/tests/test_public_functions.py | 1 +
 2 files changed, 10 insertions(+)

diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py
index 947bc814fe92b..9d1fd713f61c5 100644
--- a/sklearn/metrics/_classification.py
+++ b/sklearn/metrics/_classification.py
@@ -392,6 +392,15 @@ def confusion_matrix(
     return cm
 
 
+@validate_params(
+    {
+        "y_true": ["array-like", "sparse matrix"],
+        "y_pred": ["array-like", "sparse matrix"],
+        "sample_weight": ["array-like", None],
+        "labels": ["array-like", None],
+        "samplewise": ["boolean"],
+    }
+)
 def multilabel_confusion_matrix(
     y_true, y_pred, *, sample_weight=None, labels=None, samplewise=False
 ):
diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py
index 2f031f04a81bf..96de09cb3301c 100644
--- a/sklearn/tests/test_public_functions.py
+++ b/sklearn/tests/test_public_functions.py
@@ -109,6 +109,7 @@ def _check_function_param_validation(
     "sklearn.metrics.confusion_matrix",
     "sklearn.metrics.mean_absolute_error",
     "sklearn.metrics.mean_tweedie_deviance",
+    "sklearn.metrics.multilabel_confusion_matrix"
     "sklearn.metrics.mutual_info_score",
     "sklearn.metrics.r2_score",
     "sklearn.metrics.roc_curve",

From 74bf0bd13555fb1111abe0fd0139d583483914b4 Mon Sep 17 00:00:00 2001
From: Jérémie du Boisberranger <34657725+jeremiedbb@users.noreply.github.com>
Date: Thu, 29 Dec 2022 14:13:11 +0100
Subject: [PATCH 2/9] Update sklearn/tests/test_public_functions.py

---
 sklearn/tests/test_public_functions.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py
index 96de09cb3301c..6584632cf0779 100644
--- a/sklearn/tests/test_public_functions.py
+++ b/sklearn/tests/test_public_functions.py
@@ -109,7 +109,7 @@ def _check_function_param_validation(
     "sklearn.metrics.confusion_matrix",
     "sklearn.metrics.mean_absolute_error",
     "sklearn.metrics.mean_tweedie_deviance",
-    "sklearn.metrics.multilabel_confusion_matrix"
+    "sklearn.metrics.multilabel_confusion_matrix",
     "sklearn.metrics.mutual_info_score",
     "sklearn.metrics.r2_score",
     "sklearn.metrics.roc_curve",
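With patches 1-2 applied, multilabel_confusion_matrix declares its parameter constraints up front and is registered in the param-validation test list. The sketch below illustrates what the decorator is meant to buy; the exact exception class raised on a bad argument is an assumption about scikit-learn's shared validate_params machinery and is not shown in this diff.

    import numpy as np
    from sklearn.metrics import multilabel_confusion_matrix

    y_true = np.array([[1, 0, 1], [0, 1, 0]])
    y_pred = np.array([[1, 0, 0], [0, 1, 1]])

    # Accepted: array-like y_true/y_pred and a boolean samplewise flag,
    # matching the constraints declared in patch 1.
    print(multilabel_confusion_matrix(y_true, y_pred, samplewise=False))

    # Expected to be rejected before the function body runs, since samplewise
    # is constrained to "boolean" (assumed: a parameter-validation error).
    try:
        multilabel_confusion_matrix(y_true, y_pred, samplewise="yes")
    except Exception as exc:
        print(type(exc).__name__)
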
From 482434bdd128fb56e09cbe6478ed506fcb64a799 Mon Sep 17 00:00:00 2001
From: dlitsidis
Date: Fri, 30 Dec 2022 12:39:14 +0200
Subject: [PATCH 3/9] param validation for det_curve

---
 sklearn/metrics/_ranking.py            | 12 ++++++++++--
 sklearn/tests/test_public_functions.py |  1 +
 2 files changed, 11 insertions(+), 2 deletions(-)

diff --git a/sklearn/metrics/_ranking.py b/sklearn/metrics/_ranking.py
index 94f3b8b4bd3a0..44bbb0c79aa28 100644
--- a/sklearn/metrics/_ranking.py
+++ b/sklearn/metrics/_ranking.py
@@ -21,7 +21,7 @@
 
 import warnings
 from functools import partial
-from numbers import Real
+from numbers import Real, Integral
 
 import numpy as np
 from scipy.sparse import csr_matrix, issparse
@@ -34,7 +34,7 @@
 from ..utils.multiclass import type_of_target
 from ..utils.extmath import stable_cumsum
 from ..utils.sparsefuncs import count_nonzero
-from ..utils._param_validation import validate_params
+from ..utils._param_validation import Interval, validate_params
 from ..exceptions import UndefinedMetricWarning
 from ..preprocessing import label_binarize
 from ..utils._encode import _encode, _unique
@@ -239,6 +239,14 @@ def _binary_uninterpolated_average_precision(
     )
 
 
+@validate_params(
+    {
+        "y_true": ["array-like"],
+        "y_score": ["array-like"],
+        "pos_label": ["array-like", None],
+        "sample_weight": [Interval(Integral, 1, None, closed="left"), None],
+    }
+)
 def det_curve(y_true, y_score, pos_label=None, sample_weight=None):
     """Compute error rates for different probability thresholds.
 
diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py
index 6584632cf0779..84eeb53d3726d 100644
--- a/sklearn/tests/test_public_functions.py
+++ b/sklearn/tests/test_public_functions.py
@@ -107,6 +107,7 @@ def _check_function_param_validation(
     "sklearn.metrics.auc",
     "sklearn.metrics.cohen_kappa_score",
     "sklearn.metrics.confusion_matrix",
+    "sklearn.metrics.det_curve"
     "sklearn.metrics.mean_absolute_error",
     "sklearn.metrics.mean_tweedie_deviance",
     "sklearn.metrics.multilabel_confusion_matrix",

From bd411e9292737b3d43937e6c718ffde91df2f856 Mon Sep 17 00:00:00 2001
From: dlitsidis
Date: Fri, 30 Dec 2022 20:57:41 +0200
Subject: [PATCH 4/9] update for det_curve

---
 sklearn/metrics/_ranking.py            | 4 ++--
 sklearn/tests/test_public_functions.py | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/sklearn/metrics/_ranking.py b/sklearn/metrics/_ranking.py
index 44bbb0c79aa28..f1bd32bc9ed85 100644
--- a/sklearn/metrics/_ranking.py
+++ b/sklearn/metrics/_ranking.py
@@ -243,8 +243,8 @@ def _binary_uninterpolated_average_precision(
     {
         "y_true": ["array-like"],
         "y_score": ["array-like"],
-        "pos_label": ["array-like", None],
-        "sample_weight": [Interval(Integral, 1, None, closed="left"), None],
+        "pos_label": [Interval(Integral, str, None, closed="left"), None],
+        "sample_weight": ["array-like", None],
     }
 )
 def det_curve(y_true, y_score, pos_label=None, sample_weight=None):
diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py
index 84eeb53d3726d..286df8c0cd76b 100644
--- a/sklearn/tests/test_public_functions.py
+++ b/sklearn/tests/test_public_functions.py
@@ -107,7 +107,7 @@ def _check_function_param_validation(
     "sklearn.metrics.auc",
     "sklearn.metrics.cohen_kappa_score",
     "sklearn.metrics.confusion_matrix",
-    "sklearn.metrics.det_curve"
+    "sklearn.metrics.det_curve",
     "sklearn.metrics.mean_absolute_error",
     "sklearn.metrics.mean_tweedie_deviance",
     "sklearn.metrics.multilabel_confusion_matrix",

From 6361a85524dd3a63fbc37850df216e6fc21e41ed Mon Sep 17 00:00:00 2001
From: Jérémie du Boisberranger <34657725+jeremiedbb@users.noreply.github.com>
Date: Sat, 31 Dec 2022 15:18:12 +0100
Subject: [PATCH 5/9] Update _ranking.py

---
 sklearn/metrics/_ranking.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/sklearn/metrics/_ranking.py b/sklearn/metrics/_ranking.py
index f1bd32bc9ed85..ebe477a7e2de7 100644
--- a/sklearn/metrics/_ranking.py
+++ b/sklearn/metrics/_ranking.py
@@ -34,7 +34,7 @@
 from ..utils.multiclass import type_of_target
 from ..utils.extmath import stable_cumsum
 from ..utils.sparsefuncs import count_nonzero
-from ..utils._param_validation import Interval, validate_params
+from ..utils._param_validation import validate_params
 from ..exceptions import UndefinedMetricWarning
 from ..preprocessing import label_binarize
 from ..utils._encode import _encode, _unique
@@ -243,7 +243,7 @@ def _binary_uninterpolated_average_precision(
     {
         "y_true": ["array-like"],
         "y_score": ["array-like"],
-        "pos_label": [Interval(Integral, str, None, closed="left"), None],
+        "pos_label": [Integral, str, None],
         "sample_weight": ["array-like", None],
     }
 )
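After patches 3-5, det_curve validates y_true/y_score as array-like, pos_label as an integer, a string, or None, and sample_weight as array-like or None. A minimal sketch of the intended effect follows; the concrete error type is assumed from scikit-learn's common validation helpers and is not spelled out in the diff.

    import numpy as np
    from sklearn.metrics import det_curve

    y_true = np.array([0, 0, 1, 1])
    y_score = np.array([0.1, 0.4, 0.35, 0.8])

    # Valid call: pos_label may be an int, a str, or None per the final constraints.
    fpr, fnr, thresholds = det_curve(y_true, y_score, pos_label=1)
    print(fpr, fnr, thresholds)

    # A pos_label that is neither an integer, a string, nor None should now be
    # rejected by the decorator instead of failing deep inside the function
    # (assumed: an InvalidParameterError naming the accepted types).
    try:
        det_curve(y_true, y_score, pos_label=[0, 1])
    except Exception as exc:
        print(type(exc).__name__)
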
["array-like", None], } ) From a5d94d234de3bb73ad262ca707a43b4161b7fe3c Mon Sep 17 00:00:00 2001 From: dlitsidis Date: Sun, 1 Jan 2023 15:16:34 +0200 Subject: [PATCH 6/9] param validation for fetch_california_housing --- sklearn/datasets/_california_housing.py | 11 +++++++++++ sklearn/tests/test_public_functions.py | 1 + 2 files changed, 12 insertions(+) diff --git a/sklearn/datasets/_california_housing.py b/sklearn/datasets/_california_housing.py index f3f7d0e57c502..0befba91adb1c 100644 --- a/sklearn/datasets/_california_housing.py +++ b/sklearn/datasets/_california_housing.py @@ -37,10 +37,13 @@ from ._base import RemoteFileMetadata from ._base import load_descr from ..utils import Bunch +from ..utils._param_validation import validate_params # The original data can be found at: # https://www.dcc.fc.up.pt/~ltorgo/Regression/cal_housing.tgz + + ARCHIVE = RemoteFileMetadata( filename="cal_housing.tgz", url="https://ndownloader.figshare.com/files/5976036", @@ -50,6 +53,14 @@ logger = logging.getLogger(__name__) +@validate_params( + { + "data_home": [str, None], + "download_if_missing": ["boolean"], + "return_X_y": ["boolean"], + "as_frame": ["boolean"], + } +) def fetch_california_housing( *, data_home=None, download_if_missing=True, return_X_y=False, as_frame=False ): diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index b5c39f3ecc7a0..57452912cec78 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -99,6 +99,7 @@ def _check_function_param_validation( "sklearn.cluster.estimate_bandwidth", "sklearn.cluster.kmeans_plusplus", "sklearn.covariance.empirical_covariance", + "sklearn.datasets.fetch_california_housing", "sklearn.datasets.make_sparse_coded_signal", "sklearn.decomposition.sparse_encode", "sklearn.feature_extraction.grid_to_graph", From 564cdae490d1a123144052db0bd75fbddd1adcb5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=A9mie=20du=20Boisberranger?= <34657725+jeremiedbb@users.noreply.github.com> Date: Mon, 2 Jan 2023 16:19:51 +0100 Subject: [PATCH 7/9] unrelated --- sklearn/datasets/_california_housing.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/sklearn/datasets/_california_housing.py b/sklearn/datasets/_california_housing.py index 0befba91adb1c..98e3bfcb37a98 100644 --- a/sklearn/datasets/_california_housing.py +++ b/sklearn/datasets/_california_housing.py @@ -42,8 +42,6 @@ # The original data can be found at: # https://www.dcc.fc.up.pt/~ltorgo/Regression/cal_housing.tgz - - ARCHIVE = RemoteFileMetadata( filename="cal_housing.tgz", url="https://ndownloader.figshare.com/files/5976036", From 4985c30d4d19a32edd2df30da7473a614d58cf41 Mon Sep 17 00:00:00 2001 From: dlitsidis Date: Fri, 6 Jan 2023 14:01:21 +0200 Subject: [PATCH 8/9] param validation for average_precision_score --- sklearn/metrics/_ranking.py | 11 ++++++++++- sklearn/tests/test_public_functions.py | 1 + 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/sklearn/metrics/_ranking.py b/sklearn/metrics/_ranking.py index ebe477a7e2de7..83f1c19b1ac4b 100644 --- a/sklearn/metrics/_ranking.py +++ b/sklearn/metrics/_ranking.py @@ -34,7 +34,7 @@ from ..utils.multiclass import type_of_target from ..utils.extmath import stable_cumsum from ..utils.sparsefuncs import count_nonzero -from ..utils._param_validation import validate_params +from ..utils._param_validation import validate_params, StrOptions from ..exceptions import UndefinedMetricWarning from ..preprocessing import label_binarize from ..utils._encode import 
@@ -112,6 +112,15 @@ def auc(x, y):
     return area
 
 
+@validate_params(
+    {
+        "y_true": ["array-like"],
+        "y_score": ["array-like"],
+        "average": [StrOptions({"micro", "samples", "weighted", "macro"})],
+        "pos_label": [Integral, str, 1],
+        "sample_weight": ["array-like", None],
+    }
+)
 def average_precision_score(
     y_true, y_score, *, average="macro", pos_label=1, sample_weight=None
 ):
diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py
index 33760e1644e24..3f07a47fd4d1e 100644
--- a/sklearn/tests/test_public_functions.py
+++ b/sklearn/tests/test_public_functions.py
@@ -108,6 +108,7 @@ def _check_function_param_validation(
     "sklearn.feature_extraction.image.extract_patches_2d",
     "sklearn.metrics.accuracy_score",
     "sklearn.metrics.auc",
+    "sklearn.metrics.average_precision_score",
    "sklearn.metrics.cohen_kappa_score",
     "sklearn.metrics.confusion_matrix",
     "sklearn.metrics.det_curve",

From e400c6ebaf9281ae491aa48741156a3d1ef81dc7 Mon Sep 17 00:00:00 2001
From: dlitsidis
Date: Fri, 6 Jan 2023 16:49:43 +0200
Subject: [PATCH 9/9] update for average_precision_score

---
 sklearn/metrics/_ranking.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/sklearn/metrics/_ranking.py b/sklearn/metrics/_ranking.py
index 83f1c19b1ac4b..4ca4aa15257c5 100644
--- a/sklearn/metrics/_ranking.py
+++ b/sklearn/metrics/_ranking.py
@@ -116,8 +116,8 @@ def auc(x, y):
     {
         "y_true": ["array-like"],
         "y_score": ["array-like"],
-        "average": [StrOptions({"micro", "samples", "weighted", "macro"})],
-        "pos_label": [Integral, str, 1],
+        "average": [StrOptions({"micro", "samples", "weighted", "macro"}), None],
+        "pos_label": [Real, str, "boolean"],
         "sample_weight": ["array-like", None],
     }
 )
@@ -146,10 +146,10 @@ def average_precision_score(
 
     Parameters
     ----------
-    y_true : ndarray of shape (n_samples,) or (n_samples, n_classes)
+    y_true : array-like of shape (n_samples,) or (n_samples, n_classes)
         True binary labels or binary label indicators.
 
-    y_score : ndarray of shape (n_samples,) or (n_samples, n_classes)
+    y_score : array-like of shape (n_samples,) or (n_samples, n_classes)
         Target scores, can either be probability estimates of the positive
         class, confidence values, or non-thresholded measure of decisions
         (as returned by :term:`decision_function` on some classifiers).
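Patches 8-9 give average_precision_score the same treatment: average is restricted to the documented string options or None, pos_label may be a real number, a string, or a boolean, and the docstring now advertises array-like inputs. A short usage sketch under those constraints; the exception type raised for an invalid average is again an assumption about the shared validation machinery rather than something stated in the diff.

    import numpy as np
    from sklearn.metrics import average_precision_score

    y_true = np.array([0, 0, 1, 1])
    y_score = np.array([0.1, 0.4, 0.35, 0.8])

    # "macro" is one of the allowed StrOptions; array-like inputs are accepted.
    print(average_precision_score(y_true, y_score, average="macro"))

    # An unsupported averaging strategy should be caught by the decorator
    # (assumed: a parameter-validation error listing the allowed options).
    try:
        average_precision_score(y_true, y_score, average="median")
    except Exception as exc:
        print(type(exc).__name__)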