diff --git a/doc/whats_new/v1.2.rst b/doc/whats_new/v1.2.rst index 6c0e7f559429e..cb9246f11c4fc 100644 --- a/doc/whats_new/v1.2.rst +++ b/doc/whats_new/v1.2.rst @@ -456,6 +456,10 @@ Changelog :pr:`10468` by :user:`Ruben ` and :pr:`22993` by :user:`Jovan Stojanovic `. +- |Efficiency| :class:`neighbors.NearestCentroid` is faster and requires + less memory as it better leverages CPUs' caches to compute predictions. + :pr:`24645` by :user:`Olivier Grisel `. + - |Feature| Adds new function :func:`neighbors.sort_graph_by_row_values` to sort a CSR sparse graph such that each row is stored with increasing values. This is useful to improve efficiency when using precomputed sparse distance diff --git a/sklearn/neighbors/_nearest_centroid.py b/sklearn/neighbors/_nearest_centroid.py index 653662350b38f..4e5ce354cc257 100644 --- a/sklearn/neighbors/_nearest_centroid.py +++ b/sklearn/neighbors/_nearest_centroid.py @@ -13,7 +13,7 @@ from scipy import sparse as sp from ..base import BaseEstimator, ClassifierMixin -from ..metrics.pairwise import pairwise_distances +from ..metrics.pairwise import pairwise_distances_argmin from ..preprocessing import LabelEncoder from ..utils.validation import check_is_fitted from ..utils.sparsefuncs import csc_median_axis_0 @@ -234,5 +234,5 @@ def predict(self, X): X = self._validate_data(X, accept_sparse="csr", reset=False) return self.classes_[ - pairwise_distances(X, self.centroids_, metric=self.metric).argmin(axis=1) + pairwise_distances_argmin(X, self.centroids_, metric=self.metric) ]