Describe the bug
Consider the following toy dataset and the distances computed by nan_euclidean_distances:
import numpy as np
from sklearn.metrics.pairwise import nan_euclidean_distances
X_train = [[1, 1], [np.nan, 2]]
X_test = [[0, np.nan]]
print(nan_euclidean_distances(X_test, X_train)) # [[1.41421356, nan]]
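For reference, the scikit-learn documentation describes the nan-Euclidean distance as the Euclidean distance over the coordinates present in both samples, rescaled by (total number of coordinates) / (number of present coordinates), with nan returned when no coordinate is shared. The sketch below re-implements that formula with plain NumPy for one pair of samples. The helper name nan_euclidean is hypothetical and not part of scikit-learn. It shows why the second distance is nan: X_test[0] and X_train[1] have no observed coordinate in common.

```python
import numpy as np


def nan_euclidean(x, y):
    """Hypothetical helper: nan-aware Euclidean distance between two 1-D samples."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    # Coordinates observed in both samples
    present = ~np.isnan(x) & ~np.isnan(y)
    n_present = present.sum()
    if n_present == 0:
        # No shared observed coordinate: the distance is undefined
        return np.nan
    sq_dist = np.sum((x[present] - y[present]) ** 2)
    # Rescale by total coordinates / present coordinates
    return np.sqrt(x.size / n_present * sq_dist)


x_test = [0, np.nan]
print(nan_euclidean(x_test, [1, 1]))       # sqrt(2 / 1 * 1) ≈ 1.41421356
print(nan_euclidean(x_test, [np.nan, 2]))  # nan
```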
When weights is set to 'uniform', the second sample in X_train is used as a donor even though its distance to the test sample is nan. See the code below.
However, when weights is set to 'distance', the second sample in X_train is excluded.
This is because, when weights is set to 'distance', the entries of weight_matrix that correspond to samples with a nan distance are set to 0.
See scikit-learn/sklearn/impute/_knn.py, lines 193 to 197 at commit 6614f75.
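The difference can be reproduced with NumPy alone. The sketch below mimics the weighting and averaging step for the single receiver in X_test, assuming that for 'distance' the weights are 1 / distance with nan weights replaced by 0, that for 'uniform' no weight matrix is passed, and that donor values are then averaged with numpy.ma.average. The helper name current_average is hypothetical; this is a simplification, not the code in _knn.py.

```python
import numpy as np


def current_average(donors_dist, donor_values, weights):
    """Hypothetical sketch of the current weighting logic for one column."""
    if weights == "distance":
        # 1 / nan is nan, so donors with a nan distance get a nan weight...
        weight_matrix = 1.0 / donors_dist
        # ...which is then replaced by 0, excluding them from the average
        weight_matrix[np.isnan(weight_matrix)] = 0.0
    else:
        # 'uniform': no weight matrix, so every donor counts equally,
        # including donors whose distance is nan
        weight_matrix = None
    donors = np.ma.array(donor_values, mask=np.isnan(donor_values))
    return np.ma.average(donors, axis=1, weights=weight_matrix).data


# Distances from X_test[0] to the two donors, and their values in column 1
donors_dist = np.array([[np.sqrt(2), np.nan]])
donor_values = np.array([[1.0, 2.0]])
print(current_average(donors_dist, donor_values, "uniform"))   # [1.5]
print(current_average(donors_dist, donor_values, "distance"))  # [1.]
```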
To tackle this, we could also set the weights of samples with nan distances to 0 when weights is set to 'uniform':
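if weight_matrix is None:
    # 'uniform' weights: every donor gets weight 1 ...
    weight_matrix = np.ones_like(donors_dist)
    # ... except donors with a nan distance, which are excluded
    weight_matrix[np.isnan(donors_dist)] = 0.0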
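The proposed change can be checked in isolation with NumPy. The sketch below is a simplified, stand-alone version of the weighting and averaging step; the helper name fixed_average is hypothetical. It assumes 'distance' weights are 1 / distance with nan weights replaced by 0 and that donor values are averaged with numpy.ma.average. For 'uniform', it builds the all-ones weight matrix from the proposal and zeroes the entries whose distance is nan, so both options return 1 for the toy example.

```python
import numpy as np


def fixed_average(donors_dist, donor_values, weights):
    """Hypothetical sketch of the weighting step with the proposed fix."""
    if weights == "distance":
        # Inverse-distance weights; nan distances give nan weights, set to 0
        weight_matrix = 1.0 / donors_dist
        weight_matrix[np.isnan(weight_matrix)] = 0.0
    else:
        # Proposed fix for 'uniform': equal weights, except donors with a
        # nan distance, which get weight 0 as in the 'distance' case
        weight_matrix = np.ones_like(donors_dist)
        weight_matrix[np.isnan(donors_dist)] = 0.0
    donors = np.ma.array(donor_values, mask=np.isnan(donor_values))
    return np.ma.average(donors, axis=1, weights=weight_matrix).data


# Distances from X_test[0] to the two donors, and their values in column 1
donors_dist = np.array([[np.sqrt(2), np.nan]])
donor_values = np.array([[1.0, 2.0]])
print(fixed_average(donors_dist, donor_values, "uniform"))   # [1.]
print(fixed_average(donors_dist, donor_values, "distance"))  # [1.]
```

If every selected donor of a receiver had a nan distance, all weights would be 0 and the weighted average would be undefined, so a real patch would need to make sure that case is handled before this step.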
Steps/Code to Reproduce
import numpy as np
from sklearn.impute import KNNImputer
X_train = [[1, 1], [np.nan, 2]]
X_test = [[0, np.nan]]
knn_uniform = KNNImputer(n_neighbors=2, weights='uniform').fit(X_train)
print(knn_uniform.transform(X_test))
knn_distance = KNNImputer(n_neighbors=2, weights='distance').fit(X_train)
print(knn_distance.transform(X_test))
Expected Results
[[0, 1]] # uniform
[[0, 1]] # distance
Actual Results
[[0, 1.5]] # uniform
[[0, 1]] # distance
Versions
System:
python: 3.9.16 | packaged by conda-forge | (main, Feb 1 2023, 21:38:11) [Clang 14.0.6 ]
executable: /Users/xxf/miniconda3/envs/sklearn-env/bin/python
machine: macOS-14.5-arm64-arm-64bit
Python dependencies:
sklearn: 1.6.dev0
pip: 23.2.1
setuptools: 68.0.0
numpy: 1.26.4
scipy: 1.13.0
Cython: 3.0.8
pandas: 2.1.0
matplotlib: 3.7.2
joblib: 1.3.0
threadpoolctl: 3.5.0
Built with OpenMP: True
threadpoolctl info:
user_api: blas
internal_api: openblas
num_threads: 8
prefix: libopenblas
filepath: /Users/xxf/miniconda3/envs/sklearn-env/lib/libopenblas.0.dylib
version: 0.3.23
threading_layer: openmp
architecture: VORTEX

user_api: openmp
internal_api: openmp
num_threads: 8
prefix: libomp
filepath: /Users/xxf/miniconda3/envs/sklearn-env/lib/libomp.dylib
version: None