diff --git a/sklearn/decomposition/_fastica.py b/sklearn/decomposition/_fastica.py index 92de875f64ea3..bb0c4ccdd78ea 100644 --- a/sklearn/decomposition/_fastica.py +++ b/sklearn/decomposition/_fastica.py @@ -19,7 +19,7 @@ from ..exceptions import ConvergenceWarning from ..utils import check_array, as_float_array, check_random_state from ..utils.validation import check_is_fitted -from ..utils._param_validation import Hidden, Interval, StrOptions +from ..utils._param_validation import Hidden, Interval, StrOptions, validate_params __all__ = ["fastica", "FastICA"] @@ -154,6 +154,14 @@ def _cube(x, fun_args): return x**3, (3 * x**2).mean(axis=-1) +@validate_params( + { + "X": ["array-like"], + "return_X_mean": ["boolean"], + "compute_sources": ["boolean"], + "return_n_iter": ["boolean"], + } +) def fastica( X, n_components=None, @@ -319,6 +327,7 @@ def my_g(x): whiten_solver=whiten_solver, random_state=random_state, ) + est._validate_params() S = est._fit_transform(X, compute_sources=compute_sources) if est._whiten in ["unit-variance", "arbitrary-variance"]: diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index 6dad08db1a5c3..1491576b59cf0 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -134,6 +134,7 @@ def test_function_param_validation(func_module): ("sklearn.cluster.affinity_propagation", "sklearn.cluster.AffinityPropagation"), ("sklearn.covariance.ledoit_wolf", "sklearn.covariance.LedoitWolf"), ("sklearn.covariance.oas", "sklearn.covariance.OAS"), + ("sklearn.decomposition.fastica", "sklearn.decomposition.FastICA"), ]