Closed
Description
Describe the bug
According to the documentation, MeanShift.predict should accept a sparse X, but the code raises a TypeError when a sparse matrix is passed. Apologies if this is a non-issue or has already been fixed.
This is the same issue as in AffinityPropagation.predict (issue #20049), which was fixed by PR #20117.
Steps/Code to Reproduce
MeanShift.predict with a sparse X throws an exception:
>>> from sklearn.cluster import MeanShift
>>> import numpy as np
>>> X = np.array([[1, 1], [2, 1], [1, 0], [4, 7], [3, 5], [3, 6]])
>>> clustering = MeanShift(bandwidth=2).fit(X)
>>> clustering.labels_
array([1, 1, 1, 0, 0, 0])
>>> import scipy.sparse
>>> clustering.predict(scipy.sparse.csr_matrix([[0, 0], [5, 5]]))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/cluster/_mean_shift.py", line 466, in predict
    X = self._validate_data(X, reset=False) #,accept_sparse=['csr', 'csc', 'coo']) #ANA
  File "/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/base.py", line 421, in _validate_data
    X = check_array(X, **check_params)
  File "/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/utils/validation.py", line 63, in inner_f
    return f(*args, **kwargs)
  File "/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/utils/validation.py", line 593, in check_array
    array = _ensure_sparse_format(array, accept_sparse=accept_sparse,
  File "/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/utils/validation.py", line 360, in _ensure_sparse_format
    raise TypeError('A sparse matrix was passed, but dense '
TypeError: A sparse matrix was passed, but dense data is required. Use X.toarray() to convert to a dense numpy array.
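Until predict accepts sparse input, the conversion suggested by the error message works as a stopgap. Below is a minimal sketch of that workaround; `predict_dense_fallback` is a hypothetical helper written for this report, not part of scikit-learn, and it assumes the sparse matrix is small enough to densify in memory.

```python
import numpy as np
import scipy.sparse
from sklearn.cluster import MeanShift


def predict_dense_fallback(estimator, X):
    """Hypothetical helper: call predict as-is, densify only if sparse input is rejected."""
    try:
        return estimator.predict(X)
    except TypeError:
        # Re-raise anything that is not the sparse-input rejection.
        if not scipy.sparse.issparse(X):
            raise
        # Same conversion the error message recommends.
        return estimator.predict(X.toarray())


X = np.array([[1, 1], [2, 1], [1, 0], [4, 7], [3, 5], [3, 6]])
clustering = MeanShift(bandwidth=2).fit(X)

X_new = scipy.sparse.csr_matrix([[0, 0], [5, 5]])
pred = predict_dense_fallback(clustering, X_new)
print(pred)
```

On a version where sparse input is already accepted, the helper simply passes X through, so it should be safe to leave in place across upgrades.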
Expected Results
If we change sklearn/cluster/_mean_shift.py#L466 to
X = self._validate_data(X, accept_sparse='csr', reset=False)
or
X = self._validate_data(X, accept_sparse=['csr', 'csc', 'coo'], reset=False)
then predict works with sparse input:
>>> from sklearn.cluster import MeanShift
>>> import numpy as np
>>> X = np.array([[1, 1], [2, 1], [1, 0], [4, 7], [3, 5], [3, 6]])
>>> clustering = MeanShift(bandwidth=2).fit(X)
>>> clustering.labels_
array([1, 1, 1, 0, 0, 0])
>>> import scipy.sparse
>>> clustering.predict(scipy.sparse.csr_matrix([[0, 0], [5, 5]]))
array([1, 0])
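My understanding, which I have not confirmed against every version of the source, is that predict only assigns each sample to its nearest cluster center, so nothing after validation should need dense data. The sketch below checks that assumption outside the estimator by calling `pairwise_distances_argmin` on the fitted `cluster_centers_` with both sparse and dense input; it illustrates the idea and is not a copy of the library's implementation.

```python
import numpy as np
import scipy.sparse
from sklearn.cluster import MeanShift
from sklearn.metrics import pairwise_distances_argmin

X = np.array([[1, 1], [2, 1], [1, 0], [4, 7], [3, 5], [3, 6]])
clustering = MeanShift(bandwidth=2).fit(X)

X_sparse = scipy.sparse.csr_matrix([[0, 0], [5, 5]], dtype=np.float64)

# Nearest-center assignment computed directly on the sparse matrix.
nearest_sparse = pairwise_distances_argmin(X_sparse, clustering.cluster_centers_)

# Same assignment on the densified matrix, for comparison.
nearest_dense = pairwise_distances_argmin(
    X_sparse.toarray(), clustering.cluster_centers_
)

print(nearest_sparse, nearest_dense)
```

If the two arrays agree, as I would expect, the only obstacle to sparse input is the validation call itself.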
Actual Results
The exception shown above.
Versions
System:
python: 3.9.2 (v3.9.2:1a79785e3e, Feb 19 2021, 09:06:10) [Clang 6.0 (clang-600.0.57)]
executable: /Library/Frameworks/Python.framework/Versions/3.9/bin/python3.9
machine: macOS-10.11.6-x86_64-i386-64bit
Python dependencies:
pip: 21.2.3
setuptools: 49.2.1
sklearn: 0.24.1
numpy: 1.19.5
scipy: 1.6.1
Cython: None
pandas: 1.2.4
matplotlib: None
joblib: 1.0.1
threadpoolctl: 2.1.0
Built with OpenMP: False