diff --git a/doc/modules/clustering.rst b/doc/modules/clustering.rst
index 9e970478dcf71..c992e95ebdaba 100644
--- a/doc/modules/clustering.rst
+++ b/doc/modules/clustering.rst
@@ -567,30 +567,24 @@ considers at each step all the possible merges.
    number of features. It is a dimensionality reduction tool, see
    :ref:`data_reduction`.
 
-Different linkage type: Ward, complete and average linkage
------------------------------------------------------------
+Different linkage type: Ward, complete, average, and single linkage
+-------------------------------------------------------------------
 
-:class:`AgglomerativeClustering` supports Ward, average, and complete
+:class:`AgglomerativeClustering` supports Ward, single, average, and complete
 linkage strategies.
 
-.. image:: ../auto_examples/cluster/images/sphx_glr_plot_digits_linkage_001.png
-   :target: ../auto_examples/cluster/plot_digits_linkage.html
+.. image:: ../auto_examples/cluster/images/sphx_glr_plot_linkage_comparison_001.png
+   :target: ../auto_examples/cluster/plot_linkage_comparison.html
    :scale: 43
 
-.. image:: ../auto_examples/cluster/images/sphx_glr_plot_digits_linkage_002.png
-   :target: ../auto_examples/cluster/plot_digits_linkage.html
-   :scale: 43
-
-.. image:: ../auto_examples/cluster/images/sphx_glr_plot_digits_linkage_003.png
-   :target: ../auto_examples/cluster/plot_digits_linkage.html
-   :scale: 43
-
-
 Agglomerative clustering has a "rich get richer" behavior that leads to
-uneven cluster sizes. In this regard, complete linkage is the worst
+uneven cluster sizes. In this regard, single linkage is the worst
 strategy, and Ward gives the most regular sizes. However, the affinity
 (or distance used in clustering) cannot be varied with Ward, thus for non
-Euclidean metrics, average linkage is a good alternative.
+Euclidean metrics, average linkage is a good alternative. Single linkage,
+while not robust to noisy data, can be computed very efficiently and can
+therefore be useful to provide hierarchical clustering of larger datasets.
+Single linkage can also perform well on non-globular data.
 
 .. topic:: Examples:
 
@@ -652,15 +646,16 @@ enable only merging of neighboring pixels on an image, as in the
 
  * :ref:`sphx_glr_auto_examples_cluster_plot_agglomerative_clustering.py`
 
-.. warning:: **Connectivity constraints with average and complete linkage**
+.. warning:: **Connectivity constraints with single, average and complete linkage**
 
-    Connectivity constraints and complete or average linkage can enhance
+    Connectivity constraints and single, complete or average linkage can enhance
     the 'rich getting richer' aspect of agglomerative clustering,
     particularly so if they are built with
     :func:`sklearn.neighbors.kneighbors_graph`. In the limit of a small
     number of clusters, they tend to give a few macroscopically occupied
     clusters and almost empty ones. (see the discussion in
     :ref:`sphx_glr_auto_examples_cluster_plot_agglomerative_clustering.py`).
+    Single linkage is the most brittle linkage option with regard to this issue.
 
 .. image:: ../auto_examples/cluster/images/sphx_glr_plot_agglomerative_clustering_001.png
    :target: ../auto_examples/cluster/plot_agglomerative_clustering.html
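The percolation behaviour described in the warning above can be reproduced directly. A minimal sketch, not part of this patch (dataset and parameter choices are illustrative), contrasting unconstrained single linkage with single linkage restricted to a sparse k-NN connectivity graph:

import numpy as np
from sklearn.cluster import AgglomerativeClustering
from sklearn.datasets import make_blobs
from sklearn.neighbors import kneighbors_graph

X, _ = make_blobs(n_samples=300, centers=3, random_state=0)

# Unconstrained single linkage
labels_full = AgglomerativeClustering(
    n_clusters=3, linkage='single').fit_predict(X)

# Single linkage on a sparse k-NN graph: with few neighbors the merges
# percolate, typically yielding one giant cluster plus near-empty ones
connectivity = kneighbors_graph(X, n_neighbors=3, include_self=False)
labels_knn = AgglomerativeClustering(
    n_clusters=3, connectivity=connectivity,
    linkage='single').fit_predict(X)

# Compare cluster size distributions with and without the constraint
print(np.bincount(labels_full), np.bincount(labels_knn))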
@@ -682,7 +677,7 @@ enable only merging of neighboring pixels on an image, as in the
 Varying the metric
 -------------------
 
-Average and complete linkage can be used with a variety of distances (or
+Single, average and complete linkage can be used with a variety of distances (or
 affinities), in particular Euclidean distance (*l2*), Manhattan distance
 (or Cityblock, or *l1*), cosine distance, or any precomputed affinity
 matrix.
diff --git a/doc/whats_new/v0.20.rst b/doc/whats_new/v0.20.rst
index 7e7d39dbf1759..caec9335d75c6 100644
--- a/doc/whats_new/v0.20.rst
+++ b/doc/whats_new/v0.20.rst
@@ -78,14 +78,15 @@ Model evaluation
 - Added the :func:`metrics.balanced_accuracy_score` metric and a corresponding
   ``'balanced_accuracy'`` scorer for binary classification.
   :issue:`8066` by :user:`xyguo` and :user:`Aman Dalmia <dalmia>`.
 
 - Added :class:`multioutput.RegressorChain` for multi-target
   regression. :issue:`9257` by :user:`Kumar Ashutosh <thechargedneutron>`.
 
-- Added the :class:`preprocessing.TransformedTargetRegressor` which transforms
-  the target y before fitting a regression model. The predictions are mapped
-  back to the original space via an inverse transform. :issue:`9041` by
-  `Andreas Müller`_ and :user:`Guillaume Lemaitre <glemaitre>`.
+Clustering
+
+- :class:`cluster.AgglomerativeClustering` now supports Single Linkage
+  clustering via ``linkage='single'``. :issue:`9372` by
+  :user:`Leland McInnes <lmcinnes>` and :user:`Steve Astels <sastels>`.
+
 
 Enhancements
 ............
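For reference, the option added by this changelog entry is used like any other linkage criterion. A minimal sketch, not part of the patch:

from sklearn.cluster import AgglomerativeClustering
from sklearn.datasets import make_moons

X, _ = make_moons(n_samples=200, noise=0.05, random_state=42)
# Nearest-neighbor chaining lets single linkage follow the two moons
labels = AgglomerativeClustering(n_clusters=2,
                                 linkage='single').fit_predict(X)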
diff --git a/examples/cluster/plot_agglomerative_clustering.py b/examples/cluster/plot_agglomerative_clustering.py
index dfb27d17d1a89..1a9cf22c1e5f7 100644
--- a/examples/cluster/plot_agglomerative_clustering.py
+++ b/examples/cluster/plot_agglomerative_clustering.py
@@ -9,17 +9,18 @@
 Two consequences of imposing a connectivity can be seen. First, clustering
 with a connectivity matrix is much faster.
 
-Second, when using a connectivity matrix, average and complete linkage are
-unstable and tend to create a few clusters that grow very quickly. Indeed,
-average and complete linkage fight this percolation behavior by considering all
-the distances between two clusters when merging them. The connectivity
-graph breaks this mechanism. This effect is more pronounced for very
-sparse graphs (try decreasing the number of neighbors in
-kneighbors_graph) and with complete linkage. In particular, having a very
-small number of neighbors in the graph, imposes a geometry that is
-close to that of single linkage, which is well known to have this
-percolation instability.
-"""
+Second, when using a connectivity matrix, single, average and complete
+linkage are unstable and tend to create a few clusters that grow very
+quickly. Indeed, average and complete linkage fight this percolation behavior
+by considering all the distances between two clusters when merging them
+(while single linkage exaggerates the behaviour by considering only the
+shortest distance between clusters). The connectivity graph breaks this
+mechanism for average and complete linkage, making them resemble the more
+brittle single linkage. This effect is more pronounced for very sparse graphs
+(try decreasing the number of neighbors in kneighbors_graph) and with
+complete linkage. In particular, having a very small number of neighbors in
+the graph imposes a geometry that is close to that of single linkage,
+which is well known to have this percolation instability.
+"""
 # Authors: Gael Varoquaux, Nelle Varoquaux
 # License: BSD 3 clause
@@ -52,8 +53,11 @@
 for connectivity in (None, knn_graph):
     for n_clusters in (30, 3):
         plt.figure(figsize=(10, 4))
-        for index, linkage in enumerate(('average', 'complete', 'ward')):
-            plt.subplot(1, 3, index + 1)
+        for index, linkage in enumerate(('average',
+                                         'complete',
+                                         'ward',
+                                         'single')):
+            plt.subplot(1, 4, index + 1)
             model = AgglomerativeClustering(linkage=linkage,
                                             connectivity=connectivity,
                                             n_clusters=n_clusters)
@@ -62,7 +66,7 @@
             elapsed_time = time.time() - t0
             plt.scatter(X[:, 0], X[:, 1], c=model.labels_,
                         cmap=plt.cm.spectral)
-            plt.title('linkage=%s (time %.2fs)' % (linkage, elapsed_time),
+            plt.title('linkage=%s\n(time %.2fs)' % (linkage, elapsed_time),
                       fontdict=dict(verticalalignment='top'))
             plt.axis('equal')
             plt.axis('off')
diff --git a/examples/cluster/plot_digits_linkage.py b/examples/cluster/plot_digits_linkage.py
index f1fe1783c10e5..ba69d04eb4957 100644
--- a/examples/cluster/plot_digits_linkage.py
+++ b/examples/cluster/plot_digits_linkage.py
@@ -12,8 +12,10 @@
 What this example shows us is the behavior "rich getting richer" of
 agglomerative clustering that tends to create uneven cluster sizes.
-This behavior is especially pronounced for the average linkage strategy,
-that ends up with a couple of singleton clusters.
+This behavior is pronounced for the average linkage strategy,
+which ends up with a couple of singleton clusters, while in the case
+of single linkage we get a single central cluster with all other clusters
+being drawn from noise points around the fringes.
 """
 
 # Authors: Gael Varoquaux
@@ -69,7 +71,7 @@ def plot_clustering(X_red, X, labels, title=None):
     if title is not None:
         plt.title(title, size=17)
     plt.axis('off')
-    plt.tight_layout()
+    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
 
 #----------------------------------------------------------------------
 # 2D embedding of the digits dataset
@@ -79,11 +81,11 @@ def plot_clustering(X_red, X, labels, title=None):
 
 from sklearn.cluster import AgglomerativeClustering
 
-for linkage in ('ward', 'average', 'complete'):
+for linkage in ('ward', 'average', 'complete', 'single'):
     clustering = AgglomerativeClustering(linkage=linkage, n_clusters=10)
     t0 = time()
     clustering.fit(X_red)
-    print("%s : %.2fs" % (linkage, time() - t0))
+    print("%s :\t%.2fs" % (linkage, time() - t0))
 
     plot_clustering(X_red, X, clustering.labels_, "%s linkage" % linkage)
 
diff --git a/examples/cluster/plot_linkage_comparison.py b/examples/cluster/plot_linkage_comparison.py
new file mode 100644
index 0000000000000..471132a0f222f
--- /dev/null
+++ b/examples/cluster/plot_linkage_comparison.py
@@ -0,0 +1,149 @@
+"""
+================================================================
+Comparing different hierarchical linkage methods on toy datasets
+================================================================
+
+This example shows characteristics of different linkage
+methods for hierarchical clustering on datasets that are
+"interesting" but still in 2D.
+
+The main observations to make are:
+
+- single linkage is fast, and can perform well on
+  non-globular data, but it performs poorly in the
+  presence of noise.
+- average and complete linkage perform well on
+  cleanly separated globular clusters, but have mixed
+  results otherwise.
+- Ward is the most effective method for noisy data.
+
+While these examples give some intuition about the
+algorithms, this intuition might not apply to very high
+dimensional data.
+"""
+print(__doc__)
+
+import time
+import warnings
+
+import numpy as np
+import matplotlib.pyplot as plt
+
+from sklearn import cluster, datasets
+from sklearn.preprocessing import StandardScaler
+from itertools import cycle, islice
+
+np.random.seed(0)
+
+######################################################################
+# Generate datasets. We choose the size big enough to see the scalability
+# of the algorithms, but not so big that running times become too long
+
+n_samples = 1500
+noisy_circles = datasets.make_circles(n_samples=n_samples, factor=.5,
+                                      noise=.05)
+noisy_moons = datasets.make_moons(n_samples=n_samples, noise=.05)
+blobs = datasets.make_blobs(n_samples=n_samples, random_state=8)
+no_structure = np.random.rand(n_samples, 2), None
+
+# Anisotropically distributed data
+random_state = 170
+X, y = datasets.make_blobs(n_samples=n_samples, random_state=random_state)
+transformation = [[0.6, -0.6], [-0.4, 0.8]]
+X_aniso = np.dot(X, transformation)
+aniso = (X_aniso, y)
+
+# blobs with varied variances
+varied = datasets.make_blobs(n_samples=n_samples,
+                             cluster_std=[1.0, 2.5, 0.5],
+                             random_state=random_state)
+
+######################################################################
+# Run the clustering and plot
+
+# Set up cluster parameters
+plt.figure(figsize=(9 * 1.3 + 2, 14.5))
+plt.subplots_adjust(left=.02, right=.98, bottom=.001, top=.96, wspace=.05,
+                    hspace=.01)
+
+plot_num = 1
+
+default_base = {'n_neighbors': 10,
+                'n_clusters': 3}
+
+datasets = [
+    (noisy_circles, {'n_clusters': 2}),
+    (noisy_moons, {'n_clusters': 2}),
+    (varied, {'n_neighbors': 2}),
+    (aniso, {'n_neighbors': 2}),
+    (blobs, {}),
+    (no_structure, {})]
+
+for i_dataset, (dataset, algo_params) in enumerate(datasets):
+    # update parameters with dataset-specific values
+    params = default_base.copy()
+    params.update(algo_params)
+
+    X, y = dataset
+
+    # normalize dataset for easier parameter selection
+    X = StandardScaler().fit_transform(X)
+
+    # ============
+    # Create cluster objects
+    # ============
+    ward = cluster.AgglomerativeClustering(
+        n_clusters=params['n_clusters'], linkage='ward')
+    complete = cluster.AgglomerativeClustering(
+        n_clusters=params['n_clusters'], linkage='complete')
+    average = cluster.AgglomerativeClustering(
+        n_clusters=params['n_clusters'], linkage='average')
+    single = cluster.AgglomerativeClustering(
+        n_clusters=params['n_clusters'], linkage='single')
+
+    clustering_algorithms = (
+        ('Single Linkage', single),
+        ('Average Linkage', average),
+        ('Complete Linkage', complete),
+        ('Ward Linkage', ward),
+    )
+
+    for name, algorithm in clustering_algorithms:
+        t0 = time.time()
+
+        # catch warnings related to kneighbors_graph
+        with warnings.catch_warnings():
+            warnings.filterwarnings(
+                "ignore",
+                message="the number of connected components of the " +
+                        "connectivity matrix is [0-9]{1,2}" +
+                        " > 1. Completing it to avoid stopping the tree early.",
+                category=UserWarning)
+            algorithm.fit(X)
+
+        t1 = time.time()
+        if hasattr(algorithm, 'labels_'):
+            y_pred = algorithm.labels_.astype(np.int)
+        else:
+            y_pred = algorithm.predict(X)
+
+        plt.subplot(len(datasets), len(clustering_algorithms), plot_num)
+        if i_dataset == 0:
+            plt.title(name, size=18)
+
+        colors = np.array(list(islice(cycle(['#377eb8', '#ff7f00', '#4daf4a',
+                                             '#f781bf', '#a65628', '#984ea3',
+                                             '#999999', '#e41a1c', '#dede00']),
+                                      int(max(y_pred) + 1))))
+        plt.scatter(X[:, 0], X[:, 1], s=10, color=colors[y_pred])
+
+        plt.xlim(-2.5, 2.5)
+        plt.ylim(-2.5, 2.5)
+        plt.xticks(())
+        plt.yticks(())
+        plt.text(.99, .01, ('%.2fs' % (t1 - t0)).lstrip('0'),
+                 transform=plt.gca().transAxes, size=15,
+                 horizontalalignment='right')
+        plot_num += 1
+
+plt.show()
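Since single linkage is also implemented in scipy, results can be sanity-checked against scipy.cluster.hierarchy; the updated test_scikit_vs_scipy further below compares the trees directly, while this sketch (not part of the patch) shows the idea:

import numpy as np
from scipy.cluster.hierarchy import linkage, fcluster
from sklearn.cluster import AgglomerativeClustering
from sklearn.datasets import make_blobs

X, _ = make_blobs(n_samples=50, random_state=0)

sk_labels = AgglomerativeClustering(n_clusters=3,
                                    linkage='single').fit_predict(X)

Z = linkage(X, method='single')              # scipy's single linkage tree
sp_labels = fcluster(Z, t=3, criterion='maxclust') - 1

# The two labelings agree up to a permutation of the label ids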
diff --git a/sklearn/cluster/_hierarchical.pyx b/sklearn/cluster/_hierarchical.pyx
index 2c04e5921679c..5d42d84944956 100644
--- a/sklearn/cluster/_hierarchical.pyx
+++ b/sklearn/cluster/_hierarchical.pyx
@@ -332,3 +332,120 @@ cdef class WeightedEdge:
                                    self.weight,
                                    self.a, self.b)
 
+
+################################################################################
+# Efficient labelling/conversion of MSTs to single linkage hierarchies
+
+cdef class UnionFind(object):
+
+    cdef ITYPE_t next_label
+    cdef ITYPE_t[:] parent
+    cdef ITYPE_t[:] size
+
+    def __init__(self, N):
+        self.parent = -1 * np.ones(2 * N - 1, dtype=ITYPE, order='C')
+        self.next_label = N
+        self.size = np.hstack((np.ones(N, dtype=ITYPE),
+                               np.zeros(N - 1, dtype=ITYPE)))
+
+    @cython.boundscheck(False)
+    @cython.nonecheck(False)
+    cdef void union(self, ITYPE_t m, ITYPE_t n):
+        self.parent[m] = self.next_label
+        self.parent[n] = self.next_label
+        self.size[self.next_label] = self.size[m] + self.size[n]
+        self.next_label += 1
+
+        return
+
+    @cython.boundscheck(False)
+    @cython.nonecheck(False)
+    cdef ITYPE_t fast_find(self, ITYPE_t n):
+        cdef ITYPE_t p
+        p = n
+        # find the highest node in the linkage graph so far
+        while self.parent[n] != -1:
+            n = self.parent[n]
+        # provide a shortcut up to the highest node
+        while self.parent[p] != n:
+            p, self.parent[p] = self.parent[p], n
+        return n
+
+
+@cython.boundscheck(False)
+@cython.nonecheck(False)
+cpdef np.ndarray[DTYPE_t, ndim=2] _single_linkage_label(
+        np.ndarray[DTYPE_t, ndim=2] L):
+    """
+    Convert a linkage array or MST to a tree by labelling clusters at merges.
+    This is done by using a union-find structure to keep track of merges
+    efficiently. This is the private version of the function that assumes that
+    ``L`` has been properly validated. See ``single_linkage_label`` for the
+    user facing version of this function.
+
+    Parameters
+    ----------
+    L : array of shape (n_samples - 1, 3)
+        The linkage array or MST where each row specifies two samples
+        to be merged and a distance or weight at which the merge occurs. This
+        array is assumed to be sorted by the distance/weight.
+
+    Returns
+    -------
+    A tree in the format used by scipy.cluster.hierarchy.
+ """ + + cdef np.ndarray[DTYPE_t, ndim=2] result_arr + cdef DTYPE_t[:, ::1] result + + cdef ITYPE_t left, left_cluster, right, right_cluster, index + cdef DTYPE_t delta + + result_arr = np.zeros((L.shape[0], 4), dtype=DTYPE) + result = result_arr + U = UnionFind(L.shape[0] + 1) + + for index in range(L.shape[0]): + + left = L[index, 0] + right = L[index, 1] + delta = L[index, 2] + + left_cluster = U.fast_find(left) + right_cluster = U.fast_find(right) + + result[index][0] = left_cluster + result[index][1] = right_cluster + result[index][2] = delta + result[index][3] = U.size[left_cluster] + U.size[right_cluster] + + U.union(left_cluster, right_cluster) + + return result_arr + + +def single_linkage_label(L): + """ + Convert an linkage array or MST to a tree by labelling clusters at merges. + This is done by using a Union find structure to keep track of merges + efficiently. + + Parameters + ---------- + L: array of shape (n_samples - 1, 3) + The linkage array or MST where each row specifies two samples + to be merged and a distance or weight at which the merge occurs. This + array is assumed to be sorted by the distance/weight. + + Returns + ------- + A tree in the format used by scipy.cluster.hierarchy. + """ + # Validate L + if L[:, :2].min() < 0 or L[:, :2].max() >= 2 * L.shape[0] + 1: + raise ValueError("Input MST array is not a validly formatted MST array") + + is_sorted = lambda x: np.all(x[:-1] <= x[1:]) + if not is_sorted(L[:, 2]): + raise ValueError("Input MST array must be sorted by weight") + + return _single_linkage_label(L) \ No newline at end of file diff --git a/sklearn/cluster/hierarchical.py b/sklearn/cluster/hierarchical.py index cb901dd19d4f3..c462f2f2cda2e 100644 --- a/sklearn/cluster/hierarchical.py +++ b/sklearn/cluster/hierarchical.py @@ -80,6 +80,56 @@ def _fix_connectivity(X, connectivity, affinity): return connectivity, n_components +def _single_linkage_tree(connectivity, n_samples, n_nodes, n_clusters, + n_components, return_distance): + """ + Perform single linkage clustering on sparse data via the minimum + spanning tree from scipy.sparse.csgraph, then using union-find to label. + The parent array is then generated by walking through the tree. 
+ """ + from scipy.sparse.csgraph import minimum_spanning_tree + + # explicitly cast connectivity to ensure safety + connectivity = connectivity.astype('float64') + + # Ensure zero distances aren't ignored by setting them to "epsilon" + epsilon_value = np.nextafter(0, 1, dtype=connectivity.data.dtype) + connectivity.data[connectivity.data == 0] = epsilon_value + + # Use scipy.sparse.csgraph to generate a minimum spanning tree + mst = minimum_spanning_tree(connectivity.tocsr()) + + # Convert the graph to scipy.cluster.hierarchy array format + mst = mst.tocoo() + + # Undo the epsilon values + mst.data[mst.data == epsilon_value] = 0 + + mst_array = np.vstack([mst.row, mst.col, mst.data]).T + + # Sort edges of the min_spanning_tree by weight + mst_array = mst_array[np.argsort(mst_array.T[2]), :] + + # Convert edge list into standard hierarchical clustering format + single_linkage_tree = _hierarchical._single_linkage_label(mst_array) + children_ = single_linkage_tree[:, :2].astype(np.int) + + # Compute parents + parent = np.arange(n_nodes, dtype=np.intp) + for i, (left, right) in enumerate(children_, n_samples): + if n_clusters is not None and i >= n_nodes: + break + if left < n_nodes: + parent[left] = i + if right < n_nodes: + parent[right] = i + + if return_distance: + distances = single_linkage_tree[:, 2] + return children_, n_components, n_samples, parent, distances + return children_, n_components, n_samples, parent + + ############################################################################### # Hierarchical tree building functions @@ -288,7 +338,7 @@ def ward_tree(X, connectivity=None, n_clusters=None, return_distance=False): return children, n_components, n_leaves, parent -# average and complete linkage +# single average and complete linkage def linkage_tree(X, connectivity=None, n_components='deprecated', n_clusters=None, linkage='complete', affinity="euclidean", return_distance=False): @@ -323,13 +373,15 @@ def linkage_tree(X, connectivity=None, n_components='deprecated', limited use, and the 'parents' output should rather be used. This option is valid only when specifying a connectivity matrix. - linkage : {"average", "complete"}, optional, default: "complete" + linkage : {"average", "complete", "single"}, optional, default: "complete" Which linkage criteria to use. The linkage criterion determines which distance to use between sets of observation. - average uses the average of the distances of each observation of the two sets - complete or maximum linkage uses the maximum distances between all observations of the two sets. + - single uses the minimum of the distances between all observations + of the two sets. affinity : string or callable, optional, default: "euclidean". which metric to use. 
Can be "euclidean", "manhattan", or any @@ -378,7 +430,8 @@ def linkage_tree(X, connectivity=None, n_components='deprecated', n_samples, n_features = X.shape linkage_choices = {'complete': _hierarchical.max_merge, - 'average': _hierarchical.average_merge} + 'average': _hierarchical.average_merge, + 'single': None} # Single linkage is handled differently try: join_func = linkage_choices[linkage] except KeyError: @@ -434,7 +487,7 @@ def linkage_tree(X, connectivity=None, n_components='deprecated', del diag_mask if affinity == 'precomputed': - distances = X[connectivity.row, connectivity.col] + distances = X[connectivity.row, connectivity.col].astype('float64') else: # FIXME We compute all the distances, while we could have only computed # the "interesting" distances @@ -449,6 +502,10 @@ def linkage_tree(X, connectivity=None, n_components='deprecated', assert n_clusters <= n_samples n_nodes = 2 * n_samples - n_clusters + if linkage == 'single': + return _single_linkage_tree(connectivity, n_samples, n_nodes, + n_clusters, n_components, return_distance) + if return_distance: distances = np.empty(n_nodes - n_samples) # create inertia heap and connection matrix @@ -532,10 +589,16 @@ def _average_linkage(*args, **kwargs): return linkage_tree(*args, **kwargs) +def _single_linkage(*args, **kwargs): + kwargs['linkage'] = 'single' + return linkage_tree(*args, **kwargs) + + _TREE_BUILDERS = dict( ward=ward_tree, complete=_complete_linkage, - average=_average_linkage) + average=_average_linkage, + single=_single_linkage) ############################################################################### @@ -630,7 +693,8 @@ class AgglomerativeClustering(BaseEstimator, ClusterMixin): when varying the number of clusters and using caching, it may be advantageous to compute the full tree. - linkage : {"ward", "complete", "average"}, optional, default: "ward" + linkage : {"ward", "complete", "average", "single"}, optional \ + (default="ward") Which linkage criterion to use. The linkage criterion determines which distance to use between sets of observation. The algorithm will merge the pairs of cluster that minimize this criterion. @@ -640,6 +704,8 @@ class AgglomerativeClustering(BaseEstimator, ClusterMixin): the two sets. - complete or maximum linkage uses the maximum distances between all observations of the two sets. + - single uses the minimum of the distances between all observations + of the two sets. pooling_func : callable, default='deprecated' Ignored. @@ -713,7 +779,7 @@ def fit(self, X, y=None): (self.affinity, )) if self.linkage not in _TREE_BUILDERS: - raise ValueError("Unknown linkage type %s." + raise ValueError("Unknown linkage type %s. " "Valid options are %s" % (self.linkage, _TREE_BUILDERS.keys())) tree_builder = _TREE_BUILDERS[self.linkage] @@ -799,7 +865,8 @@ class FeatureAgglomeration(AgglomerativeClustering, AgglomerationTransform): when varying the number of clusters and using caching, it may be advantageous to compute the full tree. - linkage : {"ward", "complete", "average"}, optional, default "ward" + linkage : {"ward", "complete", "average", "single"}, optional\ + (default="ward") Which linkage criterion to use. The linkage criterion determines which distance to use between sets of features. The algorithm will merge the pairs of cluster that minimize this criterion. @@ -809,6 +876,8 @@ class FeatureAgglomeration(AgglomerativeClustering, AgglomerationTransform): the two sets. - complete or maximum linkage uses the maximum distances between all features of the two sets. 
+       - single uses the minimum of the distances between all features
+         of the two sets.
 
     pooling_func : callable, default np.mean
         This combines the values of agglomerated features into a single
diff --git a/sklearn/cluster/tests/test_hierarchical.py b/sklearn/cluster/tests/test_hierarchical.py
index c4534663236b0..3dcc415424cb9 100644
--- a/sklearn/cluster/tests/test_hierarchical.py
+++ b/sklearn/cluster/tests/test_hierarchical.py
@@ -24,7 +24,7 @@ from sklearn.cluster import ward_tree
 from sklearn.cluster import AgglomerativeClustering, FeatureAgglomeration
 from sklearn.cluster.hierarchical import (_hc_cut, _TREE_BUILDERS,
-                                          linkage_tree)
+                                          linkage_tree, _fix_connectivity)
 from sklearn.feature_extraction.image import grid_to_graph
 from sklearn.metrics.pairwise import PAIRED_DISTANCES, cosine_distances,\
     manhattan_distances, pairwise_distances
@@ -34,6 +34,7 @@
 from sklearn.utils.fast_dict import IntFloatDict
 from sklearn.utils.testing import assert_array_equal
 from sklearn.utils.testing import assert_warns
+from sklearn.datasets import make_moons, make_circles
 
 
 def test_deprecation_of_n_components_in_linkage_tree():
@@ -148,7 +149,7 @@ def test_agglomerative_clustering():
     n_samples = 100
     X = rng.randn(n_samples, 50)
     connectivity = grid_to_graph(*mask.shape)
-    for linkage in ("ward", "complete", "average"):
+    for linkage in ("ward", "complete", "average", "single"):
         clustering = AgglomerativeClustering(n_clusters=10,
                                              connectivity=connectivity,
                                              linkage=linkage)
@@ -248,6 +249,22 @@ def test_ward_agglomeration():
     assert_raises(ValueError, agglo.fit, X[:0])
 
 
+def test_single_linkage_clustering():
+    # Check that we get the correct result in two emblematic cases
+    moons, moon_labels = make_moons(noise=0.05, random_state=42)
+    clustering = AgglomerativeClustering(n_clusters=2, linkage='single')
+    clustering.fit(moons)
+    assert_almost_equal(normalized_mutual_info_score(clustering.labels_,
+                                                     moon_labels), 1)
+
+    circles, circle_labels = make_circles(factor=0.5, noise=0.025,
+                                          random_state=42)
+    clustering = AgglomerativeClustering(n_clusters=2, linkage='single')
+    clustering.fit(circles)
+    assert_almost_equal(normalized_mutual_info_score(clustering.labels_,
+                                                     circle_labels), 1)
+
+
 def assess_same_labelling(cut1, cut2):
     """Util for comparison with scipy"""
     co_clust = []
@@ -279,6 +296,12 @@ def test_scikit_vs_scipy():
             children_ = out[:, :2].astype(np.int)
             children, _, n_leaves, _ = _TREE_BUILDERS[linkage](X, connectivity)
 
+            # Sort the order of child nodes per row for consistency
+            children.sort(axis=1)
+            assert_array_equal(children, children_, 'linkage tree differs'
+                                                    ' from scipy impl for'
+                                                    ' linkage: ' + linkage)
+
             cut = _hc_cut(k, children, n_leaves)
             cut_ = _hc_cut(k, children_, n_leaves)
             assess_same_labelling(cut, cut_)
@@ -287,6 +310,29 @@ def test_scikit_vs_scipy():
     assert_raises(ValueError, _hc_cut, n_leaves + 1, children, n_leaves)
 
 
+def test_identical_points():
+    # Ensure identical points are handled correctly when using mst with
+    # a sparse connectivity matrix
+    X = np.array([[0, 0, 0], [0, 0, 0],
+                  [1, 1, 1], [1, 1, 1],
+                  [2, 2, 2], [2, 2, 2]])
+    true_labels = np.array([0, 0, 1, 1, 2, 2])
+    connectivity = kneighbors_graph(X, n_neighbors=3, include_self=False)
+    connectivity = 0.5 * (connectivity + connectivity.T)
+    connectivity, n_components = _fix_connectivity(X,
+                                                   connectivity,
+                                                   'euclidean')
+
+    for linkage in ('single', 'average', 'complete', 'ward'):
+        clustering = AgglomerativeClustering(n_clusters=3,
+                                             linkage=linkage,
+                                             connectivity=connectivity)
+        clustering.fit(X)
+
+        assert_almost_equal(normalized_mutual_info_score(clustering.labels_,
+                                                         true_labels), 1)
+
+
 def test_connectivity_propagation():
     # Check that connectivity in the ward tree is propagated correctly during
     # merging.
@@ -354,7 +400,7 @@ def test_ward_linkage_tree_return_distance():
 
     assert_array_almost_equal(dist_unstructured, dist_structured)
 
-    for linkage in ['average', 'complete']:
+    for linkage in ['average', 'complete', 'single']:
         structured_items = linkage_tree(
             X, connectivity=connectivity, linkage=linkage,
             return_distance=True)[-1]
@@ -412,7 +458,7 @@ def test_ward_linkage_tree_return_distance():
     assert_array_almost_equal(linkage_X_ward[:, 2], out_X_unstructured[4])
     assert_array_almost_equal(linkage_X_ward[:, 2], out_X_structured[4])
 
-    linkage_options = ['complete', 'average']
+    linkage_options = ['complete', 'average', 'single']
     X_linkage_truth = [linkage_X_complete, linkage_X_average]
     for (linkage, X_truth) in zip(linkage_options, X_linkage_truth):
         out_X_unstructured = linkage_tree(