diff --git a/doc/whats_new/v1.0.rst b/doc/whats_new/v1.0.rst
index 154c32617c4ba..78dc95026c45e 100644
--- a/doc/whats_new/v1.0.rst
+++ b/doc/whats_new/v1.0.rst
@@ -70,6 +70,9 @@ Changelog
   in multicore settings.
   :pr:`19052` by :user:`Yusuke Nagasaka `.
 
+- |Fix| Fixes incorrect multiple data-conversion warnings when clustering
+  boolean data. :pr:`19046` by :user:`Surya Prakash `.
+
 :mod:`sklearn.linear_model`
 ...........................
 
diff --git a/sklearn/cluster/_optics.py b/sklearn/cluster/_optics.py
index 8998963704562..11893dbd70520 100755
--- a/sklearn/cluster/_optics.py
+++ b/sklearn/cluster/_optics.py
@@ -14,6 +14,8 @@
 import warnings
 import numpy as np
 
+from ..exceptions import DataConversionWarning
+from ..metrics.pairwise import PAIRWISE_BOOLEAN_FUNCTIONS
 from ..utils import gen_batches, get_chunk_n_rows
 from ..utils.validation import _deprecate_positional_args
 from ..neighbors import NearestNeighbors
@@ -243,7 +245,17 @@ def fit(self, X, y=None):
         self : instance of OPTICS
             The instance.
         """
-        X = self._validate_data(X, dtype=float)
+
+        dtype = bool if self.metric in PAIRWISE_BOOLEAN_FUNCTIONS else float
+        # X has not been validated yet, so it may be an array-like without a
+        # ``dtype`` attribute (e.g. a list); treat such input as non-boolean.
+        if dtype == bool and getattr(X, "dtype", None) != bool:
+            msg = (f"Data will be converted to boolean for"
+                   f" metric {self.metric}, to avoid this warning,"
+                   f" you may convert the data prior to calling fit.")
+            warnings.warn(msg, DataConversionWarning)
+
+        X = self._validate_data(X, dtype=dtype)
 
         if self.cluster_method not in ['dbscan', 'xi']:
             raise ValueError("cluster_method should be one of"
diff --git a/sklearn/cluster/tests/test_optics.py b/sklearn/cluster/tests/test_optics.py
index 4428b6c00d7eb..d5b30256d4943 100644
--- a/sklearn/cluster/tests/test_optics.py
+++ b/sklearn/cluster/tests/test_optics.py
@@ -10,6 +10,7 @@
 from sklearn.datasets import make_blobs
 from sklearn.cluster import OPTICS
 from sklearn.cluster._optics import _extend_region, _extract_xi_labels
+from sklearn.exceptions import DataConversionWarning
 from sklearn.metrics.cluster import contingency_matrix
 from sklearn.metrics.pairwise import pairwise_distances
 from sklearn.cluster import DBSCAN
@@ -213,6 +214,49 @@ def test_bad_reachability():
         clust.fit(X)
 
 
+def test_nowarn_if_metric_bool_data_bool():
+    # make sure no warning is raised if metric and data are both boolean
+    # non-regression test for
+    # https://github.com/scikit-learn/scikit-learn/issues/18996
+
+    pairwise_metric = 'rogerstanimoto'
+    X = np.random.randint(2, size=(5, 2), dtype=bool)
+
+    with pytest.warns(None) as warn_record:
+        OPTICS(metric=pairwise_metric).fit(X)
+    assert len(warn_record) == 0
+
+
+def test_warn_if_metric_bool_data_no_bool():
+    # make sure a *single* conversion warning is raised if metric is boolean
+    # but data isn't
+    # non-regression test for
+    # https://github.com/scikit-learn/scikit-learn/issues/18996
+
+    pairwise_metric = 'rogerstanimoto'
+    X = np.random.randint(2, size=(5, 2), dtype=int)
+    msg = f"Data will be converted to boolean for metric {pairwise_metric}"
+
+    with pytest.warns(DataConversionWarning, match=msg) as warn_record:
+        OPTICS(metric=pairwise_metric).fit(X)
+    assert len(warn_record) == 1
+
+
+def test_nowarn_if_metric_no_bool():
+    # make sure no conversion warning is raised if
+    # metric isn't boolean, no matter what the data type is
+    pairwise_metric = 'minkowski'
+    X_bool = np.random.randint(2, size=(5, 2), dtype=bool)
+    X_num = np.random.randint(2, size=(5, 2), dtype=int)
+
+    with pytest.warns(None) as warn_record:
+        # fit boolean data
+        OPTICS(metric=pairwise_metric).fit(X_bool)
+        # fit numeric data
+        OPTICS(metric=pairwise_metric).fit(X_num)
+    assert len(warn_record) == 0
+
+
 def test_close_extract():
     # Test extract where extraction eps is close to scaled max_eps
 