diff --git a/sklearn/cluster/_affinity_propagation.py b/sklearn/cluster/_affinity_propagation.py index 07443d65f0ec4..e660de9f020cb 100644 --- a/sklearn/cluster/_affinity_propagation.py +++ b/sklearn/cluster/_affinity_propagation.py @@ -13,7 +13,9 @@ from ..exceptions import ConvergenceWarning from ..base import BaseEstimator, ClusterMixin from ..utils import as_float_array, check_random_state -from ..utils._param_validation import Interval, StrOptions +from ..utils._param_validation import Interval +from ..utils._param_validation import StrOptions +from ..utils._param_validation import validate_params from ..utils.validation import check_is_fitted from ..metrics import euclidean_distances from ..metrics import pairwise_distances_argmin @@ -34,6 +36,23 @@ def all_equal_similarities(): return all_equal_preferences() and all_equal_similarities() +@validate_params( + { + "S": ["array-like"], + "preference": [ + "array-like", + Interval(Real, None, None, closed="neither"), + None, + ], + "convergence_iter": [Interval(Integral, 1, None, closed="left")], + "max_iter": [Interval(Integral, 1, None, closed="left")], + "damping": [Interval(Real, 0.5, 1.0, closed="both")], + "copy": ["boolean"], + "verbose": ["boolean"], + "return_n_iter": ["boolean"], + "random_state": ["random_state"], + } +) def affinity_propagation( S, *, diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index d4e645c052dab..2babf7b81de3c 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -10,6 +10,7 @@ PARAM_VALIDATION_FUNCTION_LIST = [ "sklearn.cluster.kmeans_plusplus", + "sklearn.cluster.affinity_propagation", ]