From 4a3dd0ff61c424b0eede2024aedb1aa8e5a54fbb Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Mon, 24 Jan 2022 10:02:12 +0100 Subject: [PATCH] MAINT Do not compute distances for uniform weighting --- sklearn/neighbors/_classification.py | 17 +++++++++++++++-- sklearn/neighbors/_regression.py | 8 +++++++- 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/sklearn/neighbors/_classification.py b/sklearn/neighbors/_classification.py index 4f84a16211dbd..bcad8c71aee07 100644 --- a/sklearn/neighbors/_classification.py +++ b/sklearn/neighbors/_classification.py @@ -213,7 +213,14 @@ def predict(self, X): y : ndarray of shape (n_queries,) or (n_queries, n_outputs) Class labels for each data sample. """ - neigh_dist, neigh_ind = self.kneighbors(X) + if self.weights == "uniform": + # In that case, we do not need the distances to perform + # the weighting so we do not compute them. + neigh_ind = self.kneighbors(X, return_distance=False) + neigh_dist = None + else: + neigh_dist, neigh_ind = self.kneighbors(X) + classes_ = self.classes_ _y = self._y if not self.outputs_2d_: @@ -255,7 +262,13 @@ def predict_proba(self, X): The class probabilities of the input samples. Classes are ordered by lexicographic order. """ - neigh_dist, neigh_ind = self.kneighbors(X) + if self.weights == "uniform": + # In that case, we do not need the distances to perform + # the weighting so we do not compute them. + neigh_ind = self.kneighbors(X, return_distance=False) + neigh_dist = None + else: + neigh_dist, neigh_ind = self.kneighbors(X) classes_ = self.classes_ _y = self._y diff --git a/sklearn/neighbors/_regression.py b/sklearn/neighbors/_regression.py index 74cecede2efa4..423d01612b514 100644 --- a/sklearn/neighbors/_regression.py +++ b/sklearn/neighbors/_regression.py @@ -228,7 +228,13 @@ def predict(self, X): y : ndarray of shape (n_queries,) or (n_queries, n_outputs), dtype=int Target values. """ - neigh_dist, neigh_ind = self.kneighbors(X) + if self.weights == "uniform": + # In that case, we do not need the distances to perform + # the weighting so we do not compute them. + neigh_ind = self.kneighbors(X, return_distance=False) + neigh_dist = None + else: + neigh_dist, neigh_ind = self.kneighbors(X) weights = _get_weights(neigh_dist, self.weights)