File tree 3 files changed +18
-13
lines changed
3 files changed +18
-13
lines changed Original file line number Diff line number Diff line change @@ -158,16 +158,3 @@ def test_threshold():
158
158
brc = Birch (threshold = 5.0 , n_clusters = None )
159
159
brc .fit (X )
160
160
check_threshold (brc , 5. )
161
-
162
-
163
- def test_compute_label_predict ():
164
- """Test predict is invariant of the param 'compute_labels'"""
165
- X , y = make_blobs (n_samples = 80 , centers = 4 )
166
- brc1 = Birch (threshold = 0.5 , n_clusters = None , compute_labels = True )
167
- brc1 .fit (X )
168
- brc1_labels = brc1 .predict (X )
169
-
170
- brc2 = Birch (threshold = 0.5 , n_clusters = None , compute_labels = False )
171
- brc2 .fit (X )
172
- brc2_labels = brc2 .predict (X )
173
- assert_almost_equal (v_measure_score (brc1_labels , brc2_labels ), 1.0 )
Original file line number Diff line number Diff line change 32
32
check_regressors_classifiers_sparse_data ,
33
33
check_transformer ,
34
34
check_clustering ,
35
+ check_clusterer_compute_labels_predict ,
35
36
check_regressors_int ,
36
37
check_regressors_train ,
37
38
check_regressors_pickle ,
@@ -123,6 +124,7 @@ def test_clustering():
123
124
for name , Alg in clustering :
124
125
# test whether any classifier overwrites his init parameters during fit
125
126
yield check_cluster_overwrite_params , name , Alg
127
+ yield check_clusterer_compute_labels_predict , name , Alg
126
128
if name not in ('WardAgglomeration' , "FeatureAgglomeration" ):
127
129
# this is clustering on the features
128
130
# let's not test that here.
Original file line number Diff line number Diff line change @@ -420,6 +420,22 @@ def check_clustering(name, Alg):
420
420
assert_array_equal (pred , pred2 )
421
421
422
422
423
+ def check_clusterer_compute_labels_predict (name , Clusterer ):
424
+ """Check that predict is invariant of compute_labels"""
425
+ X , y = make_blobs (n_samples = 20 , random_state = 0 )
426
+ clusterer = Clusterer ()
427
+
428
+ if hasattr (clusterer , "compute_labels" ):
429
+ # MiniBatchKMeans
430
+ if hasattr (clusterer , "random_state" ):
431
+ clusterer .set_params (random_state = 0 )
432
+
433
+ X_pred1 = clusterer .fit (X ).predict (X )
434
+ clusterer .set_params (compute_labels = False )
435
+ X_pred2 = clusterer .fit (X ).predict (X )
436
+ assert_array_equal (X_pred1 , X_pred2 )
437
+
438
+
423
439
def check_classifiers_one_label (name , Classifier ):
424
440
error_string_fit = "Classifier can't train when only one class is present."
425
441
error_string_predict = ("Classifier can't predict when only one class is "
You can’t perform that action at this time.
0 commit comments