From 174dc7b4c160a98920ffab2df7b64300496b676d Mon Sep 17 00:00:00 2001 From: Breno Freitas Date: Thu, 13 Jul 2017 21:33:08 -0400 Subject: [PATCH] Pass affinity to fix connectivity in linkage tree Resolves #9308. --- sklearn/cluster/hierarchical.py | 9 ++++---- sklearn/cluster/tests/test_hierarchical.py | 27 ++++++++++++++++++++++ 2 files changed, 32 insertions(+), 4 deletions(-) diff --git a/sklearn/cluster/hierarchical.py b/sklearn/cluster/hierarchical.py index 2195fe8ee3d85..59d8811cc728d 100644 --- a/sklearn/cluster/hierarchical.py +++ b/sklearn/cluster/hierarchical.py @@ -30,8 +30,7 @@ # For non fully-connected graphs -def _fix_connectivity(X, connectivity, n_components=None, - affinity="euclidean"): +def _fix_connectivity(X, connectivity, affinity): """ Fixes the connectivity matrix @@ -190,7 +189,8 @@ def ward_tree(X, connectivity=None, n_clusters=None, return_distance=False): else: return children_, 1, n_samples, None - connectivity, n_components = _fix_connectivity(X, connectivity) + connectivity, n_components = _fix_connectivity(X, connectivity, + affinity='euclidean') if n_clusters is None: n_nodes = 2 * n_samples - 1 else: @@ -415,7 +415,8 @@ def linkage_tree(X, connectivity=None, n_components=None, return children_, 1, n_samples, None, distances return children_, 1, n_samples, None - connectivity, n_components = _fix_connectivity(X, connectivity) + connectivity, n_components = _fix_connectivity(X, connectivity, + affinity=affinity) connectivity = connectivity.tocoo() # Put the diagonal to zero diff --git a/sklearn/cluster/tests/test_hierarchical.py b/sklearn/cluster/tests/test_hierarchical.py index 986b92e0ce9f4..f706966e8e80d 100644 --- a/sklearn/cluster/tests/test_hierarchical.py +++ b/sklearn/cluster/tests/test_hierarchical.py @@ -518,3 +518,30 @@ def test_agg_n_clusters(): msg = ("n_clusters should be an integer greater than 0." " %s was provided." % str(agc.n_clusters)) assert_raise_message(ValueError, msg, agc.fit, X) + + +def test_affinity_passed_to_fix_connectivity(): + # Test that the affinity parameter is actually passed to the pairwise + # function + + size = 2 + rng = np.random.RandomState(0) + X = rng.randn(size, size) + mask = np.array([True, False, False, True]) + + connectivity = grid_to_graph(n_x=size, n_y=size, + mask=mask, return_as=np.ndarray) + + class FakeAffinity: + def __init__(self): + self.counter = 0 + + def increment(self, *args, **kwargs): + self.counter += 1 + return self.counter + + fa = FakeAffinity() + + linkage_tree(X, connectivity=connectivity, affinity=fa.increment) + + assert_equal(fa.counter, 3)