diff --git a/sklearn/cluster/_affinity_propagation.py b/sklearn/cluster/_affinity_propagation.py index 180e37996aa07..586b6c2c905a4 100644 --- a/sklearn/cluster/_affinity_propagation.py +++ b/sklearn/cluster/_affinity_propagation.py @@ -12,8 +12,8 @@ from ..exceptions import ConvergenceWarning from ..base import BaseEstimator, ClusterMixin -from ..utils import as_float_array, check_random_state -from ..utils._param_validation import Interval, StrOptions +from ..utils import check_random_state +from ..utils._param_validation import Interval, StrOptions, validate_params from ..utils.validation import check_is_fitted from ..metrics import euclidean_distances from ..metrics import pairwise_distances_argmin @@ -178,6 +178,12 @@ def _affinity_propagation( # Public API +@validate_params( + { + "S": ["array-like"], + "return_n_iter": ["boolean"], + } +) def affinity_propagation( S, *, @@ -269,13 +275,11 @@ def affinity_propagation( Brendan J. Frey and Delbert Dueck, "Clustering by Passing Messages Between Data Points", Science Feb. 2007 """ - S = as_float_array(S, copy=copy) - estimator = AffinityPropagation( damping=damping, max_iter=max_iter, convergence_iter=convergence_iter, - copy=False, + copy=copy, preference=preference, affinity="precomputed", verbose=verbose, diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index cef75b9be9d4b..6dad08db1a5c3 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -131,6 +131,7 @@ def test_function_param_validation(func_module): PARAM_VALIDATION_CLASS_WRAPPER_LIST = [ ("sklearn.decomposition.non_negative_factorization", "sklearn.decomposition.NMF"), + ("sklearn.cluster.affinity_propagation", "sklearn.cluster.AffinityPropagation"), ("sklearn.covariance.ledoit_wolf", "sklearn.covariance.LedoitWolf"), ("sklearn.covariance.oas", "sklearn.covariance.OAS"), ]