Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit af5b6a1

Browse files
Micky774huntzhanelefantinerrrrsjeremiedbbthomasjpfan
authored
ENH Add sparse input support to OPTICS (#22965)
Co-authored-by: huntzhan <[email protected]> Co-authored-by: Clickedbigfoot <[email protected]> Co-authored-by: Jérémie du Boisberranger <[email protected]> Co-authored-by: Thomas J. Fan <[email protected]>
1 parent b1cc92c commit af5b6a1

File tree

3 files changed

+61
-15
lines changed

3 files changed

+61
-15
lines changed

doc/whats_new/v1.2.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,14 @@ Changelog
3333
:pr:`123456` by :user:`Joe Bloggs <joeongithub>`.
3434
where 123456 is the *pull request* number, not the issue number.
3535
36+
:mod:`sklearn.cluster`
37+
......................
38+
39+
- |Enhancement| The `predict` and `fit_predict` methods of :class:`cluster.OPTICS` now
40+
accept sparse data type for input data. :pr:`14736` by :user:`Hunt Zhan <huntzhan>`,
41+
:pr:`20802` by :user:`Brandon Pokorny <Clickedbigfoot>`,
42+
and :pr:`22965` by :user:`Meekail Zain <micky774>`.
43+
3644
Code and Documentation Contributors
3745
-----------------------------------
3846

sklearn/cluster/_optics.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from ..neighbors import NearestNeighbors
2121
from ..base import BaseEstimator, ClusterMixin
2222
from ..metrics import pairwise_distances
23+
from scipy.sparse import issparse, SparseEfficiencyWarning
2324

2425

2526
class OPTICS(ClusterMixin, BaseEstimator):
@@ -81,6 +82,7 @@ class OPTICS(ClusterMixin, BaseEstimator):
8182
'seuclidean', 'sokalmichener', 'sokalsneath', 'sqeuclidean',
8283
'yule']
8384
85+
Sparse matrices are only supported by scikit-learn metrics.
8486
See the documentation for scipy.spatial.distance for details on these
8587
metrics.
8688
@@ -263,10 +265,11 @@ def fit(self, X, y=None):
263265
264266
Parameters
265267
----------
266-
X : ndarray of shape (n_samples, n_features), or \
268+
X : {ndarray, sparse matrix} of shape (n_samples, n_features), or \
267269
(n_samples, n_samples) if metric=’precomputed’
268270
A feature array, or array of distances between samples if
269-
metric='precomputed'.
271+
metric='precomputed'. If a sparse matrix is provided, it will be
272+
converted into CSR format.
270273
271274
y : Ignored
272275
Not used, present for API consistency by convention.
@@ -285,7 +288,13 @@ def fit(self, X, y=None):
285288
)
286289
warnings.warn(msg, DataConversionWarning)
287290

288-
X = self._validate_data(X, dtype=dtype)
291+
X = self._validate_data(X, dtype=dtype, accept_sparse="csr")
292+
if self.metric == "precomputed" and issparse(X):
293+
with warnings.catch_warnings():
294+
warnings.simplefilter("ignore", SparseEfficiencyWarning)
295+
# Set each diagonal to an explicit value so each point is its
296+
# own neighbor
297+
X.setdiag(X.diagonal())
289298
memory = check_memory(self.memory)
290299

291300
if self.cluster_method not in ["dbscan", "xi"]:
@@ -603,15 +612,16 @@ def _set_reach_dist(
603612
# Only compute distances to unprocessed neighbors:
604613
if metric == "precomputed":
605614
dists = X[point_index, unproc]
615+
if issparse(dists):
616+
dists.sort_indices()
617+
dists = dists.data
606618
else:
607619
_params = dict() if metric_params is None else metric_params.copy()
608620
if metric == "minkowski" and "p" not in _params:
609621
# the same logic as neighbors, p is ignored if explicitly set
610622
# in the dict params
611623
_params["p"] = p
612-
dists = pairwise_distances(
613-
P, np.take(X, unproc, axis=0), metric=metric, n_jobs=None, **_params
614-
).ravel()
624+
dists = pairwise_distances(P, X[unproc], metric, n_jobs=None, **_params).ravel()
615625

616626
rdists = np.maximum(dists, core_distances_[point_index])
617627
np.around(rdists, decimals=np.finfo(rdists.dtype).precision, out=rdists)

sklearn/cluster/tests/test_optics.py

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# License: BSD 3 clause
44
import numpy as np
55
import pytest
6+
from scipy import sparse
67
import warnings
78

89
from sklearn.datasets import make_blobs
@@ -15,7 +16,7 @@
1516
from sklearn.utils import shuffle
1617
from sklearn.utils._testing import assert_array_equal
1718
from sklearn.utils._testing import assert_allclose
18-
19+
from sklearn.exceptions import EfficiencyWarning
1920
from sklearn.cluster.tests.common import generate_clustered_data
2021

2122

@@ -158,15 +159,19 @@ def test_cluster_hierarchy_(global_dtype):
158159
assert diff / len(X) < 0.05
159160

160161

161-
def test_correct_number_of_clusters():
162+
@pytest.mark.parametrize(
163+
"metric, is_sparse",
164+
[["minkowski", False], ["euclidean", True]],
165+
)
166+
def test_correct_number_of_clusters(metric, is_sparse):
162167
# in 'auto' mode
163168

164169
n_clusters = 3
165170
X = generate_clustered_data(n_clusters=n_clusters)
166171
# Parameters chosen specifically for this task.
167172
# Compute OPTICS
168-
clust = OPTICS(max_eps=5.0 * 6.0, min_samples=4, xi=0.1)
169-
clust.fit(X)
173+
clust = OPTICS(max_eps=5.0 * 6.0, min_samples=4, xi=0.1, metric=metric)
174+
clust.fit(sparse.csr_matrix(X) if is_sparse else X)
170175
# number of clusters, ignoring noise if present
171176
n_clusters_1 = len(set(clust.labels_)) - int(-1 in clust.labels_)
172177
assert n_clusters_1 == n_clusters
@@ -286,18 +291,25 @@ def test_close_extract():
286291

287292
@pytest.mark.parametrize("eps", [0.1, 0.3, 0.5])
288293
@pytest.mark.parametrize("min_samples", [3, 10, 20])
289-
def test_dbscan_optics_parity(eps, min_samples, global_dtype):
294+
@pytest.mark.parametrize(
295+
"metric, is_sparse",
296+
[["minkowski", False], ["euclidean", False], ["euclidean", True]],
297+
)
298+
def test_dbscan_optics_parity(eps, min_samples, metric, is_sparse, global_dtype):
290299
# Test that OPTICS clustering labels are <= 5% difference of DBSCAN
291300

292301
centers = [[1, 1], [-1, -1], [1, -1]]
293302
X, labels_true = make_blobs(
294303
n_samples=750, centers=centers, cluster_std=0.4, random_state=0
295304
)
305+
X = sparse.csr_matrix(X) if is_sparse else X
296306

297307
X = X.astype(global_dtype, copy=False)
298308

299309
# calculate optics with dbscan extract at 0.3 epsilon
300-
op = OPTICS(min_samples=min_samples, cluster_method="dbscan", eps=eps).fit(X)
310+
op = OPTICS(
311+
min_samples=min_samples, cluster_method="dbscan", eps=eps, metric=metric
312+
).fit(X)
301313

302314
# calculate dbscan labels
303315
db = DBSCAN(eps=eps, min_samples=min_samples).fit(X)
@@ -344,7 +356,8 @@ def test_min_cluster_size(min_cluster_size, global_dtype):
344356
assert min(cluster_sizes) >= min_cluster_size
345357
# check behaviour is the same when min_cluster_size is a fraction
346358
clust_frac = OPTICS(
347-
min_samples=9, min_cluster_size=min_cluster_size / redX.shape[0]
359+
min_samples=9,
360+
min_cluster_size=min_cluster_size / redX.shape[0],
348361
)
349362
clust_frac.fit(redX)
350363
assert_array_equal(clust.labels_, clust_frac.labels_)
@@ -356,17 +369,26 @@ def test_min_cluster_size_invalid(min_cluster_size):
356369
with pytest.raises(ValueError, match="must be a positive integer or a "):
357370
clust.fit(X)
358371

372+
clust = OPTICS(min_cluster_size=min_cluster_size, metric="euclidean")
373+
with pytest.raises(ValueError, match="must be a positive integer or a "):
374+
clust.fit(sparse.csr_matrix(X))
375+
359376

360377
def test_min_cluster_size_invalid2():
361378
clust = OPTICS(min_cluster_size=len(X) + 1)
362379
with pytest.raises(ValueError, match="must be no greater than the "):
363380
clust.fit(X)
364381

382+
clust = OPTICS(min_cluster_size=len(X) + 1, metric="euclidean")
383+
with pytest.raises(ValueError, match="must be no greater than the "):
384+
clust.fit(sparse.csr_matrix(X))
385+
365386

366387
def test_processing_order():
367388
# Ensure that we consider all unprocessed points,
368389
# not only direct neighbors. when picking the next point.
369390
Y = [[0], [10], [-10], [25]]
391+
370392
clust = OPTICS(min_samples=3, max_eps=15).fit(Y)
371393
assert_array_equal(clust.reachability_, [np.inf, 10, 10, 15])
372394
assert_array_equal(clust.core_distances_, [10, 15, np.inf, np.inf])
@@ -796,10 +818,16 @@ def test_extract_dbscan(global_dtype):
796818
assert_array_equal(np.sort(np.unique(clust.labels_)), [0, 1, 2, 3])
797819

798820

799-
def test_precomputed_dists(global_dtype):
821+
@pytest.mark.parametrize("is_sparse", [False, True])
822+
def test_precomputed_dists(is_sparse, global_dtype):
800823
redX = X[::2].astype(global_dtype, copy=False)
801824
dists = pairwise_distances(redX, metric="euclidean")
802-
clust1 = OPTICS(min_samples=10, algorithm="brute", metric="precomputed").fit(dists)
825+
dists = sparse.csr_matrix(dists) if is_sparse else dists
826+
with warnings.catch_warnings():
827+
warnings.simplefilter("ignore", EfficiencyWarning)
828+
clust1 = OPTICS(min_samples=10, algorithm="brute", metric="precomputed").fit(
829+
dists
830+
)
803831
clust2 = OPTICS(min_samples=10, algorithm="brute", metric="euclidean").fit(redX)
804832

805833
assert_allclose(clust1.reachability_, clust2.reachability_)

0 commit comments

Comments
 (0)