From 0a15321f98a10cbd792a002bbb31b59b4ea10528 Mon Sep 17 00:00:00 2001 From: Shivachauhan17 Date: Tue, 28 Feb 2023 11:41:53 +0530 Subject: [PATCH 1/4] add parameter validation to dump_svmlight_file --- sklearn/datasets/_svmlight_format_io.py | 11 ++++++++++- sklearn/tests/test_public_functions.py | 1 + 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/sklearn/datasets/_svmlight_format_io.py b/sklearn/datasets/_svmlight_format_io.py index 2a141e1732ff7..efdbc416922bd 100644 --- a/sklearn/datasets/_svmlight_format_io.py +++ b/sklearn/datasets/_svmlight_format_io.py @@ -25,6 +25,7 @@ from .. import __version__ from ..utils import check_array, IS_PYPY +from ..utils._param_validation import validate_params,StrOptions if not IS_PYPY: from ._svmlight_format_fast import ( @@ -403,7 +404,15 @@ def _dump_svmlight(X, y, f, multilabel, one_based, comment, query_id): y_is_sp, ) - +@validate_params({ + "X":["array-like","sparse matrix"], + "y":["array-like","sparse matrix"], + "f":[str,StrOptions({"file"})], + "zero_based":[bool,True], + "comment":[str,None], + "query_id":["array-like",None], + "multilabel":[bool,False], +}) def dump_svmlight_file( X, y, diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index dae1fdb2e6164..d2e60f37f2e57 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -107,6 +107,7 @@ def _check_function_param_validation( "sklearn.datasets.make_classification", "sklearn.datasets.make_friedman1", "sklearn.datasets.make_sparse_coded_signal", + "sklearn.datasets.dump_svmlight_file", "sklearn.decomposition.sparse_encode", "sklearn.feature_extraction.grid_to_graph", "sklearn.feature_extraction.img_to_graph", From a3c81991ed48cba7dcb4f80e14a98c16d4f22d8b Mon Sep 17 00:00:00 2001 From: jeremiedbb Date: Tue, 28 Feb 2023 15:45:28 +0100 Subject: [PATCH 2/4] fix constraints --- sklearn/datasets/_svmlight_format_io.py | 29 ++++++++++++++----------- sklearn/tests/test_public_functions.py | 2 +- 2 files changed, 17 insertions(+), 14 deletions(-) diff --git a/sklearn/datasets/_svmlight_format_io.py b/sklearn/datasets/_svmlight_format_io.py index efdbc416922bd..b8d9dde2fa6c3 100644 --- a/sklearn/datasets/_svmlight_format_io.py +++ b/sklearn/datasets/_svmlight_format_io.py @@ -25,7 +25,7 @@ from .. import __version__ from ..utils import check_array, IS_PYPY -from ..utils._param_validation import validate_params,StrOptions +from ..utils._param_validation import validate_params, StrOptions, HasMethods if not IS_PYPY: from ._svmlight_format_fast import ( @@ -404,15 +404,18 @@ def _dump_svmlight(X, y, f, multilabel, one_based, comment, query_id): y_is_sp, ) -@validate_params({ - "X":["array-like","sparse matrix"], - "y":["array-like","sparse matrix"], - "f":[str,StrOptions({"file"})], - "zero_based":[bool,True], - "comment":[str,None], - "query_id":["array-like",None], - "multilabel":[bool,False], -}) + +@validate_params( + { + "X": ["array-like", "sparse matrix"], + "y": ["array-like", "sparse matrix"], + "f": [str, HasMethods(["write"])], + "zero_based": [bool], + "comment": [str, bytes, None], + "query_id": ["array-like", None], + "multilabel": [bool], + } +) def dump_svmlight_file( X, y, @@ -437,7 +440,7 @@ def dump_svmlight_file( Training vectors, where `n_samples` is the number of samples and `n_features` is the number of features. - y : {array-like, sparse matrix}, shape = [n_samples (, n_labels)] + y : {array-like, sparse matrix}, shape = (n_samples,) or (n_samples, n_labels) Target values. Class labels must be an integer or float, or array-like objects of integer or float for multilabel classifications. @@ -451,7 +454,7 @@ def dump_svmlight_file( Whether column indices should be written zero-based (True) or one-based (False). - comment : str, default=None + comment : str or bytes, default=None Comment to insert at the top of the file. This should be either a Unicode string, which will be encoded as UTF-8, or an ASCII byte string. @@ -468,7 +471,7 @@ def dump_svmlight_file( https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multilabel.html). .. versionadded:: 0.17 - parameter *multilabel* to support multilabel datasets. + parameter `multilabel` to support multilabel datasets. """ if comment is not None: # Convert comment string to list of lines in UTF-8. diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index d2e60f37f2e57..ed288b33319bd 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -102,12 +102,12 @@ def _check_function_param_validation( "sklearn.cluster.ward_tree", "sklearn.covariance.empirical_covariance", "sklearn.covariance.shrunk_covariance", + "sklearn.datasets.dump_svmlight_file", "sklearn.datasets.fetch_california_housing", "sklearn.datasets.fetch_kddcup99", "sklearn.datasets.make_classification", "sklearn.datasets.make_friedman1", "sklearn.datasets.make_sparse_coded_signal", - "sklearn.datasets.dump_svmlight_file", "sklearn.decomposition.sparse_encode", "sklearn.feature_extraction.grid_to_graph", "sklearn.feature_extraction.img_to_graph", From c471fe720704594c82f5aa81ba7f76949608f550 Mon Sep 17 00:00:00 2001 From: jeremiedbb Date: Tue, 28 Feb 2023 15:49:10 +0100 Subject: [PATCH 3/4] boolean --- sklearn/datasets/_svmlight_format_io.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/datasets/_svmlight_format_io.py b/sklearn/datasets/_svmlight_format_io.py index b8d9dde2fa6c3..d40eea042d1fc 100644 --- a/sklearn/datasets/_svmlight_format_io.py +++ b/sklearn/datasets/_svmlight_format_io.py @@ -410,10 +410,10 @@ def _dump_svmlight(X, y, f, multilabel, one_based, comment, query_id): "X": ["array-like", "sparse matrix"], "y": ["array-like", "sparse matrix"], "f": [str, HasMethods(["write"])], - "zero_based": [bool], + "zero_based": ["boolean"], "comment": [str, bytes, None], "query_id": ["array-like", None], - "multilabel": [bool], + "multilabel": ["boolean"], } ) def dump_svmlight_file( From 220d684295ea29ecbd09d19a8e8c124ce2894bc0 Mon Sep 17 00:00:00 2001 From: jeremiedbb Date: Tue, 28 Feb 2023 15:51:38 +0100 Subject: [PATCH 4/4] lint --- sklearn/datasets/_svmlight_format_io.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/datasets/_svmlight_format_io.py b/sklearn/datasets/_svmlight_format_io.py index d40eea042d1fc..991832c23c389 100644 --- a/sklearn/datasets/_svmlight_format_io.py +++ b/sklearn/datasets/_svmlight_format_io.py @@ -25,7 +25,7 @@ from .. import __version__ from ..utils import check_array, IS_PYPY -from ..utils._param_validation import validate_params, StrOptions, HasMethods +from ..utils._param_validation import validate_params, HasMethods if not IS_PYPY: from ._svmlight_format_fast import (