Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 44554bd

Browse files
committed
Pass affinity to fix connectivity in linkage tree
Resolves #9308.
1 parent a08555a commit 44554bd

File tree

2 files changed

+32
-1
lines changed

2 files changed

+32
-1
lines changed

sklearn/cluster/hierarchical.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -415,7 +415,9 @@ def linkage_tree(X, connectivity=None, n_components=None,
415415
return children_, 1, n_samples, None, distances
416416
return children_, 1, n_samples, None
417417

418-
connectivity, n_components = _fix_connectivity(X, connectivity)
418+
connectivity, n_components = _fix_connectivity(X,
419+
connectivity,
420+
affinity=affinity)
419421

420422
connectivity = connectivity.tocoo()
421423
# Put the diagonal to zero

sklearn/cluster/tests/test_hierarchical.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -518,3 +518,32 @@ def test_agg_n_clusters():
518518
msg = ("n_clusters should be an integer greater than 0."
519519
" %s was provided." % str(agc.n_clusters))
520520
assert_raise_message(ValueError, msg, agc.fit, X)
521+
522+
523+
def test_affinity_passed_to_fix_connectivity():
524+
# Test that the affinity parameter is actually passed to the pairwise
525+
# function
526+
527+
rng = np.random.RandomState(0)
528+
X = rng.randn(2, 2)
529+
530+
mask = np.zeros((2, 2), dtype=np.bool)
531+
mask[0:1, 0:1] = True
532+
mask[-1:, -1:] = True
533+
mask = mask.reshape(2 ** 2)
534+
connectivity = grid_to_graph(n_x=2, n_y=2, mask=mask,
535+
return_as=np.ndarray)
536+
537+
class FakeAffinity:
538+
def __init__(self):
539+
self.counter = 0
540+
541+
def increment(self, *args, **kwargs):
542+
self.counter += 1
543+
return self.counter
544+
545+
fa = FakeAffinity()
546+
547+
linkage_tree(X, affinity=fa.increment, connectivity=connectivity)
548+
549+
assert_equal(fa.counter, 3)

0 commit comments

Comments
 (0)