Samples with nan distance are included in the computation of mean in KNNImputer for uniform weights #29079

Closed
@xuefeng-xu


Describe the bug

The toy dataset and the distance computed by nan_euclidean_distances are as follows.

import numpy as np
from sklearn.metrics.pairwise import nan_euclidean_distances
X_train = [[1, 1], [np.nan, 2]]
X_test = [[0, np.nan]]
print(nan_euclidean_distances(X_test, X_train)) # [[1.41421356, nan]]
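For readers wondering where the nan comes from: nan_euclidean_distances only uses coordinates that are present in both samples and rescales the squared distance by (total coordinates / present coordinates). When two samples share no observed coordinate, the distance is nan. The following is a minimal hand-written sketch of that rule for a single pair of samples; the helper name nan_euclidean_pair is hypothetical and is not part of scikit-learn.

```python
import numpy as np


def nan_euclidean_pair(x, y):
    """Hypothetical helper: nan-aware Euclidean distance between two 1-D samples."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    # Coordinates observed in both samples.
    present = ~(np.isnan(x) | np.isnan(y))
    if not present.any():
        # No shared coordinate: the distance is undefined.
        return np.nan
    squared = np.sum((x[present] - y[present]) ** 2)
    # Rescale by total / present coordinates to compensate for the missing ones.
    return np.sqrt(x.size / present.sum() * squared)


d_first = nan_euclidean_pair([0, np.nan], [1, 1])        # one shared coordinate
d_second = nan_euclidean_pair([0, np.nan], [np.nan, 2])  # no shared coordinate
print(d_first, d_second)
```

The first distance is sqrt(2 / 1 * 1), roughly 1.414, and the second is nan, matching the output above.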

When weights is set to 'uniform', the second sample in X_train is included in the mean even though its distance to the test sample is nan.
However, when weights is set to 'distance', the second sample in X_train is excluded.

This is because weight_matrix is only built when weights is set to 'distance', and in that case the entries for samples with nan distance are set to 0:

weight_matrix = _get_weights(donors_dist, self.weights)
# fill nans with zeros
if weight_matrix is not None:
    weight_matrix[np.isnan(weight_matrix)] = 0.0

To tackle this, we could also give samples with nan distance a weight of 0 when weights is set to 'uniform':

if weight_matrix is None:
    weight_matrix = np.ones_like(donors_dist)
    weight_matrix[np.isnan(donors_dist)] = 0.0
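To illustrate the effect of this change outside of KNNImputer, here is a minimal NumPy sketch using the toy data from this issue. It is a simplified stand-in for the averaging step, not the actual scikit-learn implementation; the variable donor_values is an assumption introduced for illustration and holds the second-feature values of the two training samples.

```python
import numpy as np

# Second-feature values of the two training samples (the potential donors).
donor_values = np.array([[1.0, 2.0]])
# Distances from the test sample to each donor, as computed above.
donors_dist = np.array([[np.sqrt(2.0), np.nan]])

# Current 'uniform' behaviour: no weights, so both donors count equally.
unmasked_mean = np.average(donor_values, axis=1)

# Proposed behaviour: unit weights, zeroed where the distance is nan.
weight_matrix = np.ones_like(donors_dist)
weight_matrix[np.isnan(donors_dist)] = 0.0
masked_mean = np.average(donor_values, axis=1, weights=weight_matrix)

print(unmasked_mean)  # [1.5]
print(masked_mean)    # [1.]
```

With the nan-distance donor given zero weight, the uniform result becomes 1, which is the same value the 'distance' weighting produces.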

Steps/Code to Reproduce

import numpy as np
from sklearn.impute import KNNImputer
X_train = [[1, 1], [np.nan, 2]]
X_test = [[0, np.nan]]

knn_uniform = KNNImputer(n_neighbors=2, weights='uniform').fit(X_train)
print(knn_uniform.transform(X_test))

knn_distance = KNNImputer(n_neighbors=2, weights='distance').fit(X_train)
print(knn_distance.transform(X_test))

Expected Results

[[0, 1]] # uniform
[[0, 1]] # distance

Actual Results

[[0, 1.5]] # uniform
[[0, 1]]   # distance

Versions

System:
    python: 3.9.16 | packaged by conda-forge | (main, Feb  1 2023, 21:38:11)  [Clang 14.0.6 ]
executable: /Users/xxf/miniconda3/envs/sklearn-env/bin/python
   machine: macOS-14.5-arm64-arm-64bit

Python dependencies:
      sklearn: 1.6.dev0
          pip: 23.2.1
   setuptools: 68.0.0
        numpy: 1.26.4
        scipy: 1.13.0
       Cython: 3.0.8
       pandas: 2.1.0
   matplotlib: 3.7.2
       joblib: 1.3.0
threadpoolctl: 3.5.0

Built with OpenMP: True

threadpoolctl info:
       user_api: blas
   internal_api: openblas
    num_threads: 8
         prefix: libopenblas
       filepath: /Users/xxf/miniconda3/envs/sklearn-env/lib/libopenblas.0.dylib
        version: 0.3.23
threading_layer: openmp
   architecture: VORTEX

       user_api: openmp
   internal_api: openmp
    num_threads: 8
         prefix: libomp
       filepath: /Users/xxf/miniconda3/envs/sklearn-env/lib/libomp.dylib
        version: None
