diff --git a/sklearn/metrics/pairwise.py b/sklearn/metrics/pairwise.py index 3d01d5eeaf12d..ef31523a478f4 100644 --- a/sklearn/metrics/pairwise.py +++ b/sklearn/metrics/pairwise.py @@ -29,6 +29,7 @@ from ..utils._mask import _get_mask from ..utils.parallel import delayed, Parallel from ..utils.fixes import sp_version, parse_version +from ..utils._param_validation import validate_params from ._pairwise_distances_reduction import ArgKmin from ._pairwise_fast import _chi2_kernel_fast, _sparse_manhattan @@ -1403,6 +1404,7 @@ def cosine_similarity(X, Y=None, dense_output=True): return K +@validate_params({"X": ["array-like"], "Y": ["array-like", None]}) def additive_chi2_kernel(X, Y=None): """Compute the additive chi-squared kernel between observations in X and Y. @@ -1423,7 +1425,7 @@ def additive_chi2_kernel(X, Y=None): X : array-like of shape (n_samples_X, n_features) A feature array. - Y : ndarray of shape (n_samples_Y, n_features), default=None + Y : array-like of shape (n_samples_Y, n_features), default=None An optional second feature array. If `None`, uses `Y=X`. Returns @@ -1451,8 +1453,6 @@ def additive_chi2_kernel(X, Y=None): International Journal of Computer Vision 2007 https://hal.archives-ouvertes.fr/hal-00171412/document """ - if issparse(X) or issparse(Y): - raise ValueError("additive_chi2 does not support sparse matrices.") X, Y = check_pairwise_arrays(X, Y) if (X < 0).any(): raise ValueError("X contains negative values.") diff --git a/sklearn/metrics/tests/test_pairwise.py b/sklearn/metrics/tests/test_pairwise.py index 3624983c4c481..c1ee728f6e71e 100644 --- a/sklearn/metrics/tests/test_pairwise.py +++ b/sklearn/metrics/tests/test_pairwise.py @@ -390,8 +390,6 @@ def test_pairwise_kernels(metric): Y_sparse = csr_matrix(Y) if metric in ["chi2", "additive_chi2"]: # these don't support sparse matrices yet - with pytest.raises(ValueError): - pairwise_kernels(X_sparse, Y=Y_sparse, metric=metric) return K1 = pairwise_kernels(X_sparse, Y=Y_sparse, metric=metric) assert_allclose(K1, K2) @@ -1231,12 +1229,6 @@ def test_chi_square_kernel(): with pytest.raises(ValueError): chi2_kernel([[0, 1]], [[0.2, 0.2, 0.6]]) - # sparse matrices - with pytest.raises(ValueError): - chi2_kernel(csr_matrix(X), csr_matrix(Y)) - with pytest.raises(ValueError): - additive_chi2_kernel(csr_matrix(X), csr_matrix(Y)) - @pytest.mark.parametrize( "kernel", diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index acc44ce60c755..a70ca2fa1c046 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -122,6 +122,7 @@ def _check_function_param_validation( "sklearn.metrics.median_absolute_error", "sklearn.metrics.multilabel_confusion_matrix", "sklearn.metrics.mutual_info_score", + "sklearn.metrics.pairwise.additive_chi2_kernel", "sklearn.metrics.r2_score", "sklearn.metrics.roc_curve", "sklearn.metrics.zero_one_loss",