diff --git a/doc/whats_new/v1.2.rst b/doc/whats_new/v1.2.rst index e3c5d46bc8e2a..5f45cdf6a7168 100644 --- a/doc/whats_new/v1.2.rst +++ b/doc/whats_new/v1.2.rst @@ -41,6 +41,9 @@ Changelog :pr:`20802` by :user:`Brandon Pokorny `, and :pr:`22965` by :user:`Meekail Zain `. +- |Enhancement| :class:`cluster.Birch` now preserves dtype for `numpy.float32` + inputs. :pr:`22968` by :user:`Meekail Zain `. + Code and Documentation Contributors ----------------------------------- diff --git a/sklearn/cluster/_birch.py b/sklearn/cluster/_birch.py index cfdfeab27b15c..2bfdd2971e4d4 100644 --- a/sklearn/cluster/_birch.py +++ b/sklearn/cluster/_birch.py @@ -60,12 +60,14 @@ def _split_node(node, threshold, branching_factor): branching_factor=branching_factor, is_leaf=node.is_leaf, n_features=node.n_features, + dtype=node.init_centroids_.dtype, ) new_node2 = _CFNode( threshold=threshold, branching_factor=branching_factor, is_leaf=node.is_leaf, n_features=node.n_features, + dtype=node.init_centroids_.dtype, ) new_subcluster1.child_ = new_node1 new_subcluster2.child_ = new_node2 @@ -147,7 +149,7 @@ class _CFNode: """ - def __init__(self, *, threshold, branching_factor, is_leaf, n_features): + def __init__(self, *, threshold, branching_factor, is_leaf, n_features, dtype): self.threshold = threshold self.branching_factor = branching_factor self.is_leaf = is_leaf @@ -156,8 +158,8 @@ def __init__(self, *, threshold, branching_factor, is_leaf, n_features): # The list of subclusters, centroids and squared norms # to manipulate throughout. 
self.subclusters_ = [] - self.init_centroids_ = np.zeros((branching_factor + 1, n_features)) - self.init_sq_norm_ = np.zeros((branching_factor + 1)) + self.init_centroids_ = np.zeros((branching_factor + 1, n_features), dtype=dtype) + self.init_sq_norm_ = np.zeros((branching_factor + 1), dtype) self.squared_norm_ = [] self.prev_leaf_ = None self.next_leaf_ = None @@ -221,7 +223,9 @@ def insert_cf_subcluster(self, subcluster): # subcluster to accommodate the new child. else: new_subcluster1, new_subcluster2 = _split_node( - closest_subcluster.child_, threshold, branching_factor + closest_subcluster.child_, + threshold, + branching_factor, ) self.update_split_subclusters( closest_subcluster, new_subcluster1, new_subcluster2 @@ -552,7 +556,11 @@ def _fit(self, X, partial): first_call = not (partial and has_root) X = self._validate_data( - X, accept_sparse="csr", copy=self.copy, reset=first_call + X, + accept_sparse="csr", + copy=self.copy, + reset=first_call, + dtype=[np.float64, np.float32], ) threshold = self.threshold branching_factor = self.branching_factor @@ -568,6 +576,7 @@ def _fit(self, X, partial): branching_factor=branching_factor, is_leaf=True, n_features=n_features, + dtype=X.dtype, ) # To enable getting back subclusters. @@ -576,6 +585,7 @@ def _fit(self, X, partial): branching_factor=branching_factor, is_leaf=True, n_features=n_features, + dtype=X.dtype, ) self.dummy_leaf_.next_leaf_ = self.root_ self.root_.prev_leaf_ = self.dummy_leaf_ @@ -600,6 +610,7 @@ def _fit(self, X, partial): branching_factor=branching_factor, is_leaf=False, n_features=n_features, + dtype=X.dtype, ) self.root_.append_subcluster(new_subcluster1) self.root_.append_subcluster(new_subcluster2) @@ -714,7 +725,7 @@ def transform(self, X): Transformed data. 
""" check_is_fitted(self) - self._validate_data(X, accept_sparse="csr", reset=False) + X = self._validate_data(X, accept_sparse="csr", reset=False) with config_context(assume_finite=True): return euclidean_distances(X, self.subcluster_centers_) @@ -758,3 +769,6 @@ def _global_clustering(self, X=None): if compute_labels: self.labels_ = self._predict(X) + + def _more_tags(self): + return {"preserves_dtype": [np.float64, np.float32]} diff --git a/sklearn/cluster/tests/test_birch.py b/sklearn/cluster/tests/test_birch.py index e0051704653ae..c5d88c2bc6f0e 100644 --- a/sklearn/cluster/tests/test_birch.py +++ b/sklearn/cluster/tests/test_birch.py @@ -17,6 +17,7 @@ from sklearn.utils._testing import assert_almost_equal from sklearn.utils._testing import assert_array_equal from sklearn.utils._testing import assert_array_almost_equal +from sklearn.utils._testing import assert_allclose def test_n_samples_leaves_roots(): @@ -228,3 +229,20 @@ def test_feature_names_out(): names_out = brc.get_feature_names_out() assert_array_equal([f"birch{i}" for i in range(n_clusters)], names_out) + + +def test_transform_match_across_dtypes(): + X, _ = make_blobs(n_samples=80, n_features=4, random_state=0) + brc = Birch(n_clusters=4) + Y_64 = brc.fit_transform(X) + Y_32 = brc.fit_transform(X.astype(np.float32)) + + assert_allclose(Y_64, Y_32, atol=1e-6) + + +def test_subcluster_dtype(global_dtype): + X = make_blobs(n_samples=80, n_features=4, random_state=0)[0].astype( + global_dtype, copy=False + ) + brc = Birch(n_clusters=4) + assert brc.fit(X).subcluster_centers_.dtype == global_dtype