Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 36151f6

Browse files
committed
MAINT: Moved comput_label test to common tests
1 parent b9474d0 commit 36151f6

File tree

3 files changed

+18
-13
lines changed

3 files changed

+18
-13
lines changed

sklearn/cluster/tests/test_birch.py

-13
Original file line numberDiff line numberDiff line change
@@ -158,16 +158,3 @@ def test_threshold():
158158
brc = Birch(threshold=5.0, n_clusters=None)
159159
brc.fit(X)
160160
check_threshold(brc, 5.)
161-
162-
163-
def test_compute_label_predict():
164-
"""Test predict is invariant of the param 'compute_labels'"""
165-
X, y = make_blobs(n_samples=80, centers=4)
166-
brc1 = Birch(threshold=0.5, n_clusters=None, compute_labels=True)
167-
brc1.fit(X)
168-
brc1_labels = brc1.predict(X)
169-
170-
brc2 = Birch(threshold=0.5, n_clusters=None, compute_labels=False)
171-
brc2.fit(X)
172-
brc2_labels = brc2.predict(X)
173-
assert_almost_equal(v_measure_score(brc1_labels, brc2_labels), 1.0)

sklearn/tests/test_common.py

+2
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
check_regressors_classifiers_sparse_data,
3333
check_transformer,
3434
check_clustering,
35+
check_clusterer_compute_labels_predict,
3536
check_regressors_int,
3637
check_regressors_train,
3738
check_regressors_pickle,
@@ -123,6 +124,7 @@ def test_clustering():
123124
for name, Alg in clustering:
124125
# test whether any classifier overwrites his init parameters during fit
125126
yield check_cluster_overwrite_params, name, Alg
127+
yield check_clusterer_compute_labels_predict, name, Alg
126128
if name not in ('WardAgglomeration', "FeatureAgglomeration"):
127129
# this is clustering on the features
128130
# let's not test that here.

sklearn/utils/estimator_checks.py

+16
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,22 @@ def check_clustering(name, Alg):
420420
assert_array_equal(pred, pred2)
421421

422422

423+
def check_clusterer_compute_labels_predict(name, Clusterer):
424+
"""Check that predict is invariant of compute_labels"""
425+
X, y = make_blobs(n_samples=20, random_state=0)
426+
clusterer = Clusterer()
427+
428+
if hasattr(clusterer, "compute_labels"):
429+
# MiniBatchKMeans
430+
if hasattr(clusterer, "random_state"):
431+
clusterer.set_params(random_state=0)
432+
433+
X_pred1 = clusterer.fit(X).predict(X)
434+
clusterer.set_params(compute_labels=False)
435+
X_pred2 = clusterer.fit(X).predict(X)
436+
assert_array_equal(X_pred1, X_pred2)
437+
438+
423439
def check_classifiers_one_label(name, Classifier):
424440
error_string_fit = "Classifier can't train when only one class is present."
425441
error_string_predict = ("Classifier can't predict when only one class is "

0 commit comments

Comments
 (0)