From 8d2ece3937be4c1acd6c777ce28da7b6ff268c69 Mon Sep 17 00:00:00 2001 From: Micky774 Date: Sun, 27 Mar 2022 16:09:42 -0400 Subject: [PATCH 01/16] Added dtype preservation to `Birch` --- sklearn/cluster/_birch.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/sklearn/cluster/_birch.py b/sklearn/cluster/_birch.py index cfdfeab27b15c..b86ab3ed0f7d3 100644 --- a/sklearn/cluster/_birch.py +++ b/sklearn/cluster/_birch.py @@ -469,6 +469,9 @@ class Birch( array([0, 0, 0, 1, 1, 1]) """ + def _more_tags(self): + return {"preserves_dtype": [np.float64, np.float32]} + def __init__( self, *, @@ -714,9 +717,9 @@ def transform(self, X): Transformed data. """ check_is_fitted(self) - self._validate_data(X, accept_sparse="csr", reset=False) + X = self._validate_data(X, accept_sparse="csr", reset=False) with config_context(assume_finite=True): - return euclidean_distances(X, self.subcluster_centers_) + return euclidean_distances(X, self.subcluster_centers_).astype(X.dtype) def _global_clustering(self, X=None): """ From 7ec1bbba26411bd02a40ee8698b704cee85a591b Mon Sep 17 00:00:00 2001 From: Micky774 Date: Sun, 27 Mar 2022 17:51:51 -0400 Subject: [PATCH 02/16] Improved dtype preservation strategy per review comment --- sklearn/cluster/_birch.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/sklearn/cluster/_birch.py b/sklearn/cluster/_birch.py index b86ab3ed0f7d3..80f430c8b3c64 100644 --- a/sklearn/cluster/_birch.py +++ b/sklearn/cluster/_birch.py @@ -555,7 +555,11 @@ def _fit(self, X, partial): first_call = not (partial and has_root) X = self._validate_data( - X, accept_sparse="csr", copy=self.copy, reset=first_call + X, + accept_sparse="csr", + copy=self.copy, + reset=first_call, + dtype=[np.float64, np.float32], ) threshold = self.threshold branching_factor = self.branching_factor @@ -607,7 +611,9 @@ def _fit(self, X, partial): self.root_.append_subcluster(new_subcluster1) self.root_.append_subcluster(new_subcluster2) - centroids = np.concatenate([leaf.centroids_ for leaf in self._get_leaves()]) + centroids = np.concatenate( + [leaf.centroids_ for leaf in self._get_leaves()], dtype=X.dtype + ) self.subcluster_centers_ = centroids self._n_features_out = self.subcluster_centers_.shape[0] @@ -719,7 +725,7 @@ def transform(self, X): check_is_fitted(self) X = self._validate_data(X, accept_sparse="csr", reset=False) with config_context(assume_finite=True): - return euclidean_distances(X, self.subcluster_centers_).astype(X.dtype) + return euclidean_distances(X, self.subcluster_centers_) def _global_clustering(self, X=None): """ From ced85757ec276afa5d0b4c73d1c8911701295320 Mon Sep 17 00:00:00 2001 From: Micky774 Date: Sun, 27 Mar 2022 19:05:59 -0400 Subject: [PATCH 03/16] Modified `_CFNode` to preserve `dtype` via optional param --- sklearn/cluster/_birch.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/sklearn/cluster/_birch.py b/sklearn/cluster/_birch.py index 80f430c8b3c64..156d355bdc40a 100644 --- a/sklearn/cluster/_birch.py +++ b/sklearn/cluster/_birch.py @@ -147,7 +147,7 @@ class _CFNode: """ - def __init__(self, *, threshold, branching_factor, is_leaf, n_features): + def __init__(self, *, threshold, branching_factor, is_leaf, n_features, dtype=None): self.threshold = threshold self.branching_factor = branching_factor self.is_leaf = is_leaf @@ -156,8 +156,11 @@ def __init__(self, *, threshold, branching_factor, is_leaf, n_features): # The list of subclusters, centroids and squared norms # to manipulate throughout. self.subclusters_ = [] - self.init_centroids_ = np.zeros((branching_factor + 1, n_features)) - self.init_sq_norm_ = np.zeros((branching_factor + 1)) + self.dtype = dtype or np.float64 + self.init_centroids_ = np.zeros( + (branching_factor + 1, n_features), dtype=self.dtype + ) + self.init_sq_norm_ = np.zeros((branching_factor + 1), self.dtype) self.squared_norm_ = [] self.prev_leaf_ = None self.next_leaf_ = None @@ -575,6 +578,7 @@ def _fit(self, X, partial): branching_factor=branching_factor, is_leaf=True, n_features=n_features, + dtype=X.dtype, ) # To enable getting back subclusters. @@ -583,6 +587,7 @@ def _fit(self, X, partial): branching_factor=branching_factor, is_leaf=True, n_features=n_features, + dtype=X.dtype, ) self.dummy_leaf_.next_leaf_ = self.root_ self.root_.prev_leaf_ = self.dummy_leaf_ @@ -607,13 +612,12 @@ def _fit(self, X, partial): branching_factor=branching_factor, is_leaf=False, n_features=n_features, + dtype=X.dtype, ) self.root_.append_subcluster(new_subcluster1) self.root_.append_subcluster(new_subcluster2) - centroids = np.concatenate( - [leaf.centroids_ for leaf in self._get_leaves()], dtype=X.dtype - ) + centroids = np.concatenate([leaf.centroids_ for leaf in self._get_leaves()]) self.subcluster_centers_ = centroids self._n_features_out = self.subcluster_centers_.shape[0] From 7b683f25c453f38ed7709d62423cb97590bd096e Mon Sep 17 00:00:00 2001 From: Micky774 Date: Sun, 27 Mar 2022 19:10:43 -0400 Subject: [PATCH 04/16] Updated `_CFNode` dtype preservation and included `_split_node` --- sklearn/cluster/_birch.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/sklearn/cluster/_birch.py b/sklearn/cluster/_birch.py index 156d355bdc40a..93d639443bcf6 100644 --- a/sklearn/cluster/_birch.py +++ b/sklearn/cluster/_birch.py @@ -43,7 +43,7 @@ def _iterate_sparse_X(X): yield row -def _split_node(node, threshold, branching_factor): +def _split_node(node, threshold, branching_factor, dtype=None): """The node has to be split if there is no place for a new subcluster in the node. 1. Two empty nodes and two empty subclusters are initialized. @@ -60,12 +60,14 @@ def _split_node(node, threshold, branching_factor): branching_factor=branching_factor, is_leaf=node.is_leaf, n_features=node.n_features, + dtype=dtype, ) new_node2 = _CFNode( threshold=threshold, branching_factor=branching_factor, is_leaf=node.is_leaf, n_features=node.n_features, + dtype=dtype, ) new_subcluster1.child_ = new_node1 new_subcluster2.child_ = new_node2 @@ -224,7 +226,10 @@ def insert_cf_subcluster(self, subcluster): # subcluster to accommodate the new child. else: new_subcluster1, new_subcluster2 = _split_node( - closest_subcluster.child_, threshold, branching_factor + closest_subcluster.child_, + threshold, + branching_factor, + dtype=self.dtype, ) self.update_split_subclusters( closest_subcluster, new_subcluster1, new_subcluster2 From 2322e60740710a06dfa873b71f5e9b6437ae89e9 Mon Sep 17 00:00:00 2001 From: Micky774 Date: Fri, 1 Apr 2022 13:11:38 -0400 Subject: [PATCH 05/16] Improved `dtype` specification --- sklearn/cluster/_birch.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sklearn/cluster/_birch.py b/sklearn/cluster/_birch.py index 93d639443bcf6..eb198918e97e5 100644 --- a/sklearn/cluster/_birch.py +++ b/sklearn/cluster/_birch.py @@ -60,14 +60,14 @@ def _split_node(node, threshold, branching_factor, dtype=None): branching_factor=branching_factor, is_leaf=node.is_leaf, n_features=node.n_features, - dtype=dtype, + dtype=node.dtype, ) new_node2 = _CFNode( threshold=threshold, branching_factor=branching_factor, is_leaf=node.is_leaf, n_features=node.n_features, - dtype=dtype, + dtype=node.dtype, ) new_subcluster1.child_ = new_node1 new_subcluster2.child_ = new_node2 @@ -149,7 +149,7 @@ class _CFNode: """ - def __init__(self, *, threshold, branching_factor, is_leaf, n_features, dtype=None): + def __init__(self, *, threshold, branching_factor, is_leaf, n_features, dtype): self.threshold = threshold self.branching_factor = branching_factor self.is_leaf = is_leaf From 7f9477c07a4782b8b5beb7fdf9199b762e098b3f Mon Sep 17 00:00:00 2001 From: Meekail Zain <34613774+Micky774@users.noreply.github.com> Date: Sat, 2 Apr 2022 13:40:45 -0400 Subject: [PATCH 06/16] Apply suggestions from code review Co-authored-by: Thomas J. Fan --- sklearn/cluster/_birch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/cluster/_birch.py b/sklearn/cluster/_birch.py index eb198918e97e5..422c9501231c3 100644 --- a/sklearn/cluster/_birch.py +++ b/sklearn/cluster/_birch.py @@ -43,7 +43,7 @@ def _iterate_sparse_X(X): yield row -def _split_node(node, threshold, branching_factor, dtype=None): +def _split_node(node, threshold, branching_factor): """The node has to be split if there is no place for a new subcluster in the node. 1. Two empty nodes and two empty subclusters are initialized. @@ -158,7 +158,7 @@ def __init__(self, *, threshold, branching_factor, is_leaf, n_features, dtype): # The list of subclusters, centroids and squared norms # to manipulate throughout. self.subclusters_ = [] - self.dtype = dtype or np.float64 + self.dtype = dtype self.init_centroids_ = np.zeros( (branching_factor + 1, n_features), dtype=self.dtype ) From 5d3df253b2dbbc59d74c9d64376fb4f498834dc2 Mon Sep 17 00:00:00 2001 From: Micky774 Date: Sat, 2 Apr 2022 13:45:14 -0400 Subject: [PATCH 07/16] Added changelog entry --- doc/whats_new/v1.1.rst | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/doc/whats_new/v1.1.rst b/doc/whats_new/v1.1.rst index 6c03456bcaf0c..471d74b343ea6 100644 --- a/doc/whats_new/v1.1.rst +++ b/doc/whats_new/v1.1.rst @@ -186,6 +186,9 @@ Changelog `-1` and the original warning message is shown. :pr:`22217` by :user:`Meekail Zain `. +- |Enhancement| :class:`cluster.Birch` now preserves dtype for `numpy.float32` + inputs. :pr:`22968` by `Meekail Zain `. + :mod:`sklearn.compose` ...................... @@ -284,7 +287,7 @@ Changelog deprecated. - the default value of the `batch_size` parameter of both will change from 3 to 256 in version 1.3. - + :pr:`18975` by :user:`Jérémie du Boisberranger `. - |Enhancement| :func:`decomposition.dict_learning`, :func:`decomposition.dict_learning_online` From 6affe7c8880f7581219b0f8ba7b2190c546dd350 Mon Sep 17 00:00:00 2001 From: Micky774 Date: Sat, 2 Apr 2022 14:42:31 -0400 Subject: [PATCH 08/16] Removed old argument in call to `_split_node` --- sklearn/cluster/_birch.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sklearn/cluster/_birch.py b/sklearn/cluster/_birch.py index 422c9501231c3..b08264c36319f 100644 --- a/sklearn/cluster/_birch.py +++ b/sklearn/cluster/_birch.py @@ -229,7 +229,6 @@ def insert_cf_subcluster(self, subcluster): closest_subcluster.child_, threshold, branching_factor, - dtype=self.dtype, ) self.update_split_subclusters( closest_subcluster, new_subcluster1, new_subcluster2 From 037c5cf4d25b13650bb28ebc414fc2da2fdcf68d Mon Sep 17 00:00:00 2001 From: Micky774 Date: Sun, 3 Apr 2022 19:31:31 -0400 Subject: [PATCH 09/16] Removed `_CFNode.dtype` - Replaced `_CFNode.dtype` with `_CFNode.init_centroids.dtype` --- sklearn/cluster/_birch.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/sklearn/cluster/_birch.py b/sklearn/cluster/_birch.py index b08264c36319f..77ae75ffedf35 100644 --- a/sklearn/cluster/_birch.py +++ b/sklearn/cluster/_birch.py @@ -60,14 +60,14 @@ def _split_node(node, threshold, branching_factor): branching_factor=branching_factor, is_leaf=node.is_leaf, n_features=node.n_features, - dtype=node.dtype, + dtype=node.init_centroids_.dtype, ) new_node2 = _CFNode( threshold=threshold, branching_factor=branching_factor, is_leaf=node.is_leaf, n_features=node.n_features, - dtype=node.dtype, + dtype=node.init_centroids_.dtype, ) new_subcluster1.child_ = new_node1 new_subcluster2.child_ = new_node2 @@ -158,11 +158,8 @@ def __init__(self, *, threshold, branching_factor, is_leaf, n_features, dtype): # The list of subclusters, centroids and squared norms # to manipulate throughout. self.subclusters_ = [] - self.dtype = dtype - self.init_centroids_ = np.zeros( - (branching_factor + 1, n_features), dtype=self.dtype - ) - self.init_sq_norm_ = np.zeros((branching_factor + 1), self.dtype) + self.init_centroids_ = np.zeros((branching_factor + 1, n_features), dtype=dtype) + self.init_sq_norm_ = np.zeros((branching_factor + 1), dtype) self.squared_norm_ = [] self.prev_leaf_ = None self.next_leaf_ = None From d2382864da96e410ecefea18a135315f1aa4ff3f Mon Sep 17 00:00:00 2001 From: Micky774 Date: Thu, 7 Apr 2022 13:53:29 -0400 Subject: [PATCH 10/16] Moved `_more_tags` method --- sklearn/cluster/_birch.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sklearn/cluster/_birch.py b/sklearn/cluster/_birch.py index 77ae75ffedf35..2bfdd2971e4d4 100644 --- a/sklearn/cluster/_birch.py +++ b/sklearn/cluster/_birch.py @@ -473,9 +473,6 @@ class Birch( array([0, 0, 0, 1, 1, 1]) """ - def _more_tags(self): - return {"preserves_dtype": [np.float64, np.float32]} - def __init__( self, *, @@ -772,3 +769,6 @@ def _global_clustering(self, X=None): if compute_labels: self.labels_ = self._predict(X) + + def _more_tags(self): + return {"preserves_dtype": [np.float64, np.float32]} From 1f50edbf1f74169cb3f078686b2536fb77d58312 Mon Sep 17 00:00:00 2001 From: Micky774 Date: Fri, 8 Apr 2022 20:00:52 -0400 Subject: [PATCH 11/16] Added additional tests --- sklearn/cluster/tests/test_birch.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/sklearn/cluster/tests/test_birch.py b/sklearn/cluster/tests/test_birch.py index e0051704653ae..53e86871b432c 100644 --- a/sklearn/cluster/tests/test_birch.py +++ b/sklearn/cluster/tests/test_birch.py @@ -17,6 +17,7 @@ from sklearn.utils._testing import assert_almost_equal from sklearn.utils._testing import assert_array_equal from sklearn.utils._testing import assert_array_almost_equal +from sklearn.utils._testing import np_assert_allclose def test_n_samples_leaves_roots(): @@ -228,3 +229,19 @@ def test_feature_names_out(): names_out = brc.get_feature_names_out() assert_array_equal([f"birch{i}" for i in range(n_clusters)], names_out) + + +def test_transform_match_across_dtypes(): + X, _ = make_blobs(n_samples=80, n_features=4, random_state=0) + brc = Birch(n_clusters=4) + Y_64 = brc.fit_transform(X) + Y_32 = brc.fit_transform(X.astype(np.float32)) + + np_assert_allclose(Y_64, Y_32, rtol=1e-4) + + +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +def test_subcluster_dtype(dtype): + X = make_blobs(n_samples=80, n_features=4, random_state=0)[0].astype(dtype) + brc = Birch(n_clusters=4) + assert brc.fit(X).subcluster_centers_.dtype == dtype From 17c0c7d72d9db8c5230348ea05dcee4d1e852aec Mon Sep 17 00:00:00 2001 From: Micky774 Date: Fri, 8 Apr 2022 20:14:53 -0400 Subject: [PATCH 12/16] Adjusted test absolute tolerance --- sklearn/cluster/tests/test_birch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/cluster/tests/test_birch.py b/sklearn/cluster/tests/test_birch.py index 53e86871b432c..661728b34cbe6 100644 --- a/sklearn/cluster/tests/test_birch.py +++ b/sklearn/cluster/tests/test_birch.py @@ -237,7 +237,7 @@ def test_transform_match_across_dtypes(): Y_64 = brc.fit_transform(X) Y_32 = brc.fit_transform(X.astype(np.float32)) - np_assert_allclose(Y_64, Y_32, rtol=1e-4) + np_assert_allclose(Y_64, Y_32, rtol=1e-4, atol=1e-6) @pytest.mark.parametrize("dtype", [np.float32, np.float64]) From a1e34004d7fde7365ee73592b09b5a248cf0d11e Mon Sep 17 00:00:00 2001 From: Micky774 Date: Thu, 14 Apr 2022 16:52:55 -0400 Subject: [PATCH 13/16] Improved tests --- sklearn/cluster/tests/test_birch.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/sklearn/cluster/tests/test_birch.py b/sklearn/cluster/tests/test_birch.py index 661728b34cbe6..86078bf24a9da 100644 --- a/sklearn/cluster/tests/test_birch.py +++ b/sklearn/cluster/tests/test_birch.py @@ -17,7 +17,7 @@ from sklearn.utils._testing import assert_almost_equal from sklearn.utils._testing import assert_array_equal from sklearn.utils._testing import assert_array_almost_equal -from sklearn.utils._testing import np_assert_allclose +from sklearn.utils._testing import assert_allclose def test_n_samples_leaves_roots(): @@ -237,11 +237,10 @@ def test_transform_match_across_dtypes(): Y_64 = brc.fit_transform(X) Y_32 = brc.fit_transform(X.astype(np.float32)) - np_assert_allclose(Y_64, Y_32, rtol=1e-4, atol=1e-6) + assert_allclose(Y_64, Y_32, atol=1e-6) -@pytest.mark.parametrize("dtype", [np.float32, np.float64]) -def test_subcluster_dtype(dtype): - X = make_blobs(n_samples=80, n_features=4, random_state=0)[0].astype(dtype) +def test_subcluster_dtype(global_dtype): + X = make_blobs(n_samples=80, n_features=4, random_state=0)[0].astype(global_dtype) brc = Birch(n_clusters=4) - assert brc.fit(X).subcluster_centers_.dtype == dtype + assert brc.fit(X).subcluster_centers_.dtype == global_dtype From 02ff1a3e0c3304be5957f636fef56b67d1ea8d62 Mon Sep 17 00:00:00 2001 From: Meekail Zain <34613774+Micky774@users.noreply.github.com> Date: Fri, 29 Apr 2022 18:01:44 -0400 Subject: [PATCH 14/16] Update sklearn/cluster/tests/test_birch.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Jérémie du Boisberranger <34657725+jeremiedbb@users.noreply.github.com> --- sklearn/cluster/tests/test_birch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/cluster/tests/test_birch.py b/sklearn/cluster/tests/test_birch.py index 86078bf24a9da..6f8da8ef61ad1 100644 --- a/sklearn/cluster/tests/test_birch.py +++ b/sklearn/cluster/tests/test_birch.py @@ -241,6 +241,6 @@ def test_transform_match_across_dtypes(): def test_subcluster_dtype(global_dtype): - X = make_blobs(n_samples=80, n_features=4, random_state=0)[0].astype(global_dtype) + X = make_blobs(n_samples=80, n_features=4, random_state=0)[0].astype(global_dtype, copy=False) brc = Birch(n_clusters=4) assert brc.fit(X).subcluster_centers_.dtype == global_dtype From d45bb0807ae902cb9d304ad8b07d6940ff88d302 Mon Sep 17 00:00:00 2001 From: Micky774 Date: Fri, 29 Apr 2022 18:05:04 -0400 Subject: [PATCH 15/16] Linting --- sklearn/cluster/tests/test_birch.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sklearn/cluster/tests/test_birch.py b/sklearn/cluster/tests/test_birch.py index 6f8da8ef61ad1..c5d88c2bc6f0e 100644 --- a/sklearn/cluster/tests/test_birch.py +++ b/sklearn/cluster/tests/test_birch.py @@ -241,6 +241,8 @@ def test_transform_match_across_dtypes(): def test_subcluster_dtype(global_dtype): - X = make_blobs(n_samples=80, n_features=4, random_state=0)[0].astype(global_dtype, copy=False) + X = make_blobs(n_samples=80, n_features=4, random_state=0)[0].astype( + global_dtype, copy=False + ) brc = Birch(n_clusters=4) assert brc.fit(X).subcluster_centers_.dtype == global_dtype From 81028c9031fb8b63ec9dcef1d37036ace23f21f2 Mon Sep 17 00:00:00 2001 From: jeremiedbb Date: Mon, 2 May 2022 11:41:23 +0200 Subject: [PATCH 16/16] fix position in changelog --- doc/whats_new/v1.2.rst | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/doc/whats_new/v1.2.rst b/doc/whats_new/v1.2.rst index e5aabf7eb95d0..5f45cdf6a7168 100644 --- a/doc/whats_new/v1.2.rst +++ b/doc/whats_new/v1.2.rst @@ -41,6 +41,9 @@ Changelog :pr:`20802` by :user:`Brandon Pokorny `, and :pr:`22965` by :user:`Meekail Zain `. +- |Enhancement| :class:`cluster.Birch` now preserves dtype for `numpy.float32` + inputs. :pr:`22968` by `Meekail Zain `. + Code and Documentation Contributors ----------------------------------- @@ -48,8 +51,3 @@ Thanks to everyone who has contributed to the maintenance and improvement of the project since version 1.1, including: TODO: update at the time of the release. - -:mod:`sklearn.cluster` -...................... -- |Enhancement| :class:`cluster.Birch` now preserves dtype for `numpy.float32` - inputs. :pr:`22968` by `Meekail Zain `.