diff --git a/sklearn/cluster/tests/test_birch.py b/sklearn/cluster/tests/test_birch.py index 99f9b1cdbc9fe..c2f3c06d15ba7 100644 --- a/sklearn/cluster/tests/test_birch.py +++ b/sklearn/cluster/tests/test_birch.py @@ -17,9 +17,10 @@ from sklearn.utils._testing import assert_allclose -def test_n_samples_leaves_roots(global_random_seed): +def test_n_samples_leaves_roots(global_random_seed, global_dtype): # Sanity check for the number of samples in leaves and roots X, y = make_blobs(n_samples=10, random_state=global_random_seed) + X = X.astype(global_dtype, copy=False) brc = Birch() brc.fit(X) n_samples_root = sum([sc.n_samples_ for sc in brc.root_.subclusters_]) @@ -30,9 +31,10 @@ def test_n_samples_leaves_roots(global_random_seed): assert n_samples_root == X.shape[0] -def test_partial_fit(global_random_seed): +def test_partial_fit(global_random_seed, global_dtype): # Test that fit is equivalent to calling partial_fit multiple times X, y = make_blobs(n_samples=100, random_state=global_random_seed) + X = X.astype(global_dtype, copy=False) brc = Birch(n_clusters=3) brc.fit(X) brc_partial = Birch(n_clusters=None) @@ -47,10 +49,11 @@ def test_partial_fit(global_random_seed): assert_array_equal(brc_partial.subcluster_labels_, brc.subcluster_labels_) -def test_birch_predict(global_random_seed): +def test_birch_predict(global_random_seed, global_dtype): # Test the predict method predicts the nearest centroid. rng = np.random.RandomState(global_random_seed) X = generate_clustered_data(n_clusters=3, n_features=3, n_samples_per_cluster=10) + X = X.astype(global_dtype, copy=False) # n_samples * n_samples_per_cluster shuffle_indices = np.arange(30) @@ -58,6 +61,10 @@ def test_birch_predict(global_random_seed): X_shuffle = X[shuffle_indices, :] brc = Birch(n_clusters=4, threshold=1.0) brc.fit(X_shuffle) + + # Birch must preserve inputs' dtype + assert brc.subcluster_centers_.dtype == global_dtype + assert_array_equal(brc.labels_, brc.predict(X_shuffle)) centroids = brc.subcluster_centers_ nearest_centroid = brc.subcluster_labels_[ @@ -66,9 +73,10 @@ def test_birch_predict(global_random_seed): assert_allclose(v_measure_score(nearest_centroid, brc.labels_), 1.0) -def test_n_clusters(global_random_seed): +def test_n_clusters(global_random_seed, global_dtype): # Test that n_clusters param works properly X, y = make_blobs(n_samples=100, centers=10, random_state=global_random_seed) + X = X.astype(global_dtype, copy=False) brc1 = Birch(n_clusters=10) brc1.fit(X) assert len(brc1.subcluster_centers_) > 10 @@ -88,9 +96,10 @@ def test_n_clusters(global_random_seed): brc4.fit(X) -def test_sparse_X(global_random_seed): +def test_sparse_X(global_random_seed, global_dtype): # Test that sparse and dense data give same results X, y = make_blobs(n_samples=100, centers=10, random_state=global_random_seed) + X = X.astype(global_dtype, copy=False) brc = Birch(n_clusters=10) brc.fit(X) @@ -98,6 +107,9 @@ def test_sparse_X(global_random_seed): brc_sparse = Birch(n_clusters=10) brc_sparse.fit(csr) + # Birch must preserve inputs' dtype + assert brc_sparse.subcluster_centers_.dtype == global_dtype + assert_array_equal(brc.labels_, brc_sparse.labels_) assert_allclose(brc.subcluster_centers_, brc_sparse.subcluster_centers_) @@ -122,9 +134,10 @@ def check_branching_factor(node, branching_factor): check_branching_factor(cluster.child_, branching_factor) -def test_branching_factor(global_random_seed): +def test_branching_factor(global_random_seed, global_dtype): # Test that nodes have at max branching_factor number of subclusters X, y = make_blobs(random_state=global_random_seed) + X = X.astype(global_dtype, copy=False) branching_factor = 9 # Purposefully set a low threshold to maximize the subclusters. @@ -146,9 +159,10 @@ def check_threshold(birch_instance, threshold): current_leaf = current_leaf.next_leaf_ -def test_threshold(global_random_seed): +def test_threshold(global_random_seed, global_dtype): # Test that the leaf subclusters have a threshold lesser than radius X, y = make_blobs(n_samples=80, centers=4, random_state=global_random_seed) + X = X.astype(global_dtype, copy=False) brc = Birch(threshold=0.5, n_clusters=None) brc.fit(X) check_threshold(brc, 0.5)