From 313a8a39b32109930a86b6adb79445c1ac71ac6a Mon Sep 17 00:00:00 2001 From: ROMEEZHOU Date: Tue, 18 Apr 2023 01:42:44 +0800 Subject: [PATCH] MAINT Parameters validation for sklearn.covariance.ledoit_wolf_shrinkage --- sklearn/covariance/_shrunk_covariance.py | 7 +++++++ sklearn/tests/test_public_functions.py | 1 + 2 files changed, 8 insertions(+) diff --git a/sklearn/covariance/_shrunk_covariance.py b/sklearn/covariance/_shrunk_covariance.py index 5cdc9f3d212ad..4bf3d9a490b6b 100644 --- a/sklearn/covariance/_shrunk_covariance.py +++ b/sklearn/covariance/_shrunk_covariance.py @@ -272,6 +272,13 @@ def fit(self, X, y=None): # Ledoit-Wolf estimator +@validate_params( + { + "X": ["array-like"], + "assume_centered": ["boolean"], + "block_size": [Interval(Integral, 1, None, closed="left")], + } +) def ledoit_wolf_shrinkage(X, assume_centered=False, block_size=1000): """Estimate the shrunk Ledoit-Wolf covariance matrix. diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index d46ae07821ac2..af27ad98d47c0 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -117,6 +117,7 @@ def _check_function_param_validation( "sklearn.cluster.cluster_optics_xi", "sklearn.cluster.ward_tree", "sklearn.covariance.empirical_covariance", + "sklearn.covariance.ledoit_wolf_shrinkage", "sklearn.covariance.shrunk_covariance", "sklearn.datasets.dump_svmlight_file", "sklearn.datasets.fetch_20newsgroups",