Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

ENH Added dtype preservation to Birch #22968

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 25 commits into main from Micky774:pdtype_birch
May 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
8d2ece3
Added dtype preservation to `Birch`
Micky774 Mar 27, 2022
7ec1bbb
Improved dtype preservation strategy per review comment
Micky774 Mar 27, 2022
ced8575
Modified `_CFNode` to preserve `dtype` via optional param
Micky774 Mar 27, 2022
7b683f2
Updated `_CFNode` dtype preservation and included `_split_node`
Micky774 Mar 27, 2022
2322e60
Improved `dtype` specification
Micky774 Apr 1, 2022
3817855
Merge branch 'main' into pdtype_birch
Micky774 Apr 1, 2022
7f9477c
Apply suggestions from code review
Micky774 Apr 2, 2022
5d3df25
Added changelog entry
Micky774 Apr 2, 2022
5057740
Merge branch 'main' into pdtype_birch
Micky774 Apr 2, 2022
6affe7c
Removed old argument in call to `_split_node`
Micky774 Apr 2, 2022
037c5cf
Removed `_CFNode.dtype`
Micky774 Apr 3, 2022
d238286
Moved `_more_tags` method
Micky774 Apr 7, 2022
62b4668
Merge branch 'main' into pdtype_birch
Micky774 Apr 8, 2022
1f50edb
Added additional tests
Micky774 Apr 9, 2022
17c0c7d
Adjusted test absolute tolerance
Micky774 Apr 9, 2022
a1e3400
Improved tests
Micky774 Apr 14, 2022
ca165cd
Merge branch 'main' into pdtype_birch
Micky774 Apr 14, 2022
794e3c9
Merge branch 'main' into pdtype_birch
Micky774 Apr 16, 2022
a2bfbd0
Merge branch 'main' into pdtype_birch
Micky774 Apr 22, 2022
02ff1a3
Update sklearn/cluster/tests/test_birch.py
Micky774 Apr 29, 2022
199a3f2
Merge branch 'main' into pdtype_birch
Micky774 Apr 29, 2022
9acacb7
Merge branch 'pdtype_birch' of https://github.com/Micky774/scikit-lea…
Micky774 Apr 29, 2022
d45bb08
Linting
Micky774 Apr 29, 2022
227211b
Merge remote-tracking branch 'upstream/main' into pr/Micky774/22968
jeremiedbb May 2, 2022
81028c9
fix position in changelog
jeremiedbb May 2, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/whats_new/v1.2.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ Changelog
:pr:`20802` by :user:`Brandon Pokorny <Clickedbigfoot>`,
and :pr:`22965` by :user:`Meekail Zain <micky774>`.

- |Enhancement| :class:`cluster.Birch` now preserves dtype for `numpy.float32`
  inputs. :pr:`22968` by :user:`Meekail Zain <micky774>`.

Code and Documentation Contributors
-----------------------------------

Expand Down
26 changes: 20 additions & 6 deletions sklearn/cluster/_birch.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,14 @@ def _split_node(node, threshold, branching_factor):
branching_factor=branching_factor,
is_leaf=node.is_leaf,
n_features=node.n_features,
dtype=node.init_centroids_.dtype,
)
new_node2 = _CFNode(
threshold=threshold,
branching_factor=branching_factor,
is_leaf=node.is_leaf,
n_features=node.n_features,
dtype=node.init_centroids_.dtype,
)
new_subcluster1.child_ = new_node1
new_subcluster2.child_ = new_node2
Expand Down Expand Up @@ -147,7 +149,7 @@ class _CFNode:

"""

def __init__(self, *, threshold, branching_factor, is_leaf, n_features):
def __init__(self, *, threshold, branching_factor, is_leaf, n_features, dtype):
self.threshold = threshold
self.branching_factor = branching_factor
self.is_leaf = is_leaf
Expand All @@ -156,8 +158,8 @@ def __init__(self, *, threshold, branching_factor, is_leaf, n_features):
# The list of subclusters, centroids and squared norms
# to manipulate throughout.
self.subclusters_ = []
self.init_centroids_ = np.zeros((branching_factor + 1, n_features))
self.init_sq_norm_ = np.zeros((branching_factor + 1))
self.init_centroids_ = np.zeros((branching_factor + 1, n_features), dtype=dtype)
self.init_sq_norm_ = np.zeros((branching_factor + 1), dtype)
self.squared_norm_ = []
self.prev_leaf_ = None
self.next_leaf_ = None
Expand Down Expand Up @@ -221,7 +223,9 @@ def insert_cf_subcluster(self, subcluster):
# subcluster to accommodate the new child.
else:
new_subcluster1, new_subcluster2 = _split_node(
closest_subcluster.child_, threshold, branching_factor
closest_subcluster.child_,
threshold,
branching_factor,
)
self.update_split_subclusters(
closest_subcluster, new_subcluster1, new_subcluster2
Expand Down Expand Up @@ -552,7 +556,11 @@ def _fit(self, X, partial):
first_call = not (partial and has_root)

X = self._validate_data(
X, accept_sparse="csr", copy=self.copy, reset=first_call
X,
accept_sparse="csr",
copy=self.copy,
reset=first_call,
dtype=[np.float64, np.float32],
)
threshold = self.threshold
branching_factor = self.branching_factor
Expand All @@ -568,6 +576,7 @@ def _fit(self, X, partial):
branching_factor=branching_factor,
is_leaf=True,
n_features=n_features,
dtype=X.dtype,
)

# To enable getting back subclusters.
Expand All @@ -576,6 +585,7 @@ def _fit(self, X, partial):
branching_factor=branching_factor,
is_leaf=True,
n_features=n_features,
dtype=X.dtype,
)
self.dummy_leaf_.next_leaf_ = self.root_
self.root_.prev_leaf_ = self.dummy_leaf_
Expand All @@ -600,6 +610,7 @@ def _fit(self, X, partial):
branching_factor=branching_factor,
is_leaf=False,
n_features=n_features,
dtype=X.dtype,
)
self.root_.append_subcluster(new_subcluster1)
self.root_.append_subcluster(new_subcluster2)
Expand Down Expand Up @@ -714,7 +725,7 @@ def transform(self, X):
Transformed data.
"""
check_is_fitted(self)
self._validate_data(X, accept_sparse="csr", reset=False)
X = self._validate_data(X, accept_sparse="csr", reset=False)
with config_context(assume_finite=True):
return euclidean_distances(X, self.subcluster_centers_)

Expand Down Expand Up @@ -758,3 +769,6 @@ def _global_clustering(self, X=None):

if compute_labels:
self.labels_ = self._predict(X)

def _more_tags(self):
return {"preserves_dtype": [np.float64, np.float32]}
18 changes: 18 additions & 0 deletions sklearn/cluster/tests/test_birch.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from sklearn.utils._testing import assert_almost_equal
from sklearn.utils._testing import assert_array_equal
from sklearn.utils._testing import assert_array_almost_equal
from sklearn.utils._testing import assert_allclose


def test_n_samples_leaves_roots():
Expand Down Expand Up @@ -228,3 +229,20 @@ def test_feature_names_out():

names_out = brc.get_feature_names_out()
assert_array_equal([f"birch{i}" for i in range(n_clusters)], names_out)


def test_transform_match_across_dtypes():
    """Check that `Birch.fit_transform` agrees for the original and float32 data."""
    X, _ = make_blobs(n_samples=80, n_features=4, random_state=0)
    brc = Birch(n_clusters=4)
    # The same estimator instance is fitted twice: first on X as generated,
    # then on a float32 copy of the same data.
    Y_64 = brc.fit_transform(X)
    Y_32 = brc.fit_transform(X.astype(np.float32))

    # The two results are only expected to match up to a small absolute
    # error, hence the explicit `atol`.
    assert_allclose(Y_64, Y_32, atol=1e-6)


def test_subcluster_dtype(global_dtype):
    """Check that `subcluster_centers_` has the same dtype as the training data.

    NOTE(review): `global_dtype` is presumably a pytest fixture supplying the
    dtype under test -- confirm against the project's conftest.
    """
    X = make_blobs(n_samples=80, n_features=4, random_state=0)[0].astype(
        global_dtype, copy=False
    )
    brc = Birch(n_clusters=4)
    assert brc.fit(X).subcluster_centers_.dtype == global_dtype