diff --git a/sklearn/preprocessing/_label.py b/sklearn/preprocessing/_label.py index b9da2254ad60f..3d1bf8f3064ac 100644 --- a/sklearn/preprocessing/_label.py +++ b/sklearn/preprocessing/_label.py @@ -18,6 +18,7 @@ from ..base import BaseEstimator, TransformerMixin from ..utils.sparsefuncs import min_max_axis +from ..utils._param_validation import Interval, validate_params from ..utils import column_or_1d from ..utils.validation import _num_samples, check_array, check_is_fitted from ..utils.multiclass import unique_labels @@ -422,6 +423,15 @@ def _more_tags(self): return {"X_types": ["1dlabels"]} +@validate_params( + { + "y": ["array-like"], + "classes": ["array-like"], + "neg_label": [Interval(Integral, None, None, closed="neither")], + "pos_label": [Interval(Integral, None, None, closed="neither")], + "sparse_output": ["boolean"], + } +) def label_binarize(y, *, classes, neg_label=0, pos_label=1, sparse_output=False): """Binarize labels in a one-vs-all fashion. diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index 96490d738e69d..b39ea6a0e41d8 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -224,6 +224,7 @@ def _check_function_param_validation( "sklearn.model_selection.train_test_split", "sklearn.preprocessing.add_dummy_feature", "sklearn.preprocessing.binarize", + "sklearn.preprocessing.label_binarize", "sklearn.preprocessing.maxabs_scale", "sklearn.preprocessing.scale", "sklearn.random_projection.johnson_lindenstrauss_min_dim",