diff --git a/sklearn/covariance/_empirical_covariance.py b/sklearn/covariance/_empirical_covariance.py index e3dd51bb74eb9..2af0aadfb890d 100644 --- a/sklearn/covariance/_empirical_covariance.py +++ b/sklearn/covariance/_empirical_covariance.py @@ -17,6 +17,7 @@ from .. import config_context from ..base import BaseEstimator from ..utils import check_array +from ..utils._param_validation import validate_params from ..utils.extmath import fast_logdet from ..metrics.pairwise import pairwise_distances @@ -48,6 +49,12 @@ def log_likelihood(emp_cov, precision): return log_likelihood_ +@validate_params( + { + "X": ["array-like"], + "assume_centered": ["boolean"], + } +) def empirical_covariance(X, *, assume_centered=False): """Compute the Maximum likelihood covariance estimator. diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index ff42011427b83..b4d3d38746dd7 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -96,6 +96,7 @@ def _check_function_param_validation( PARAM_VALIDATION_FUNCTION_LIST = [ "sklearn.cluster.estimate_bandwidth", "sklearn.cluster.kmeans_plusplus", + "sklearn.covariance.empirical_covariance", "sklearn.feature_extraction.grid_to_graph", "sklearn.feature_extraction.img_to_graph", "sklearn.metrics.accuracy_score",