MemoryError in KNNImputer with california housing #15604

Closed
glemaitre opened this issue Nov 12, 2019 · 22 comments · Fixed by #16397

@glemaitre
Member

glemaitre commented Nov 12, 2019

I was doing a simple example with california housing and the KNNImputer blew up in my face:

import numpy as np
import pandas as pd
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.impute import KNNImputer
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import make_pipeline

calhousing = fetch_california_housing()

X = pd.DataFrame(calhousing.data, columns=calhousing.feature_names)
y = pd.Series(calhousing.target, name='house_value')

rng = np.random.RandomState(42)

density = 4  # roughly one in 4 values will be NaN

mask = rng.randint(density, size=X.shape) == 0
X_na = X.copy()
X_na.values[mask] = np.nan
X_na.head()

X_train_na, X_test_na, y_train_na, y_test_na = train_test_split(
    X_na[y<4.9], y[y<4.9], test_size=1000, random_state=0)

model = make_pipeline(
    StandardScaler(),
    KNNImputer(add_indicator=True),
    LinearRegression()
)
model.fit(X_train_na, y_train_na).score(X_test_na, y_test_na)
---------------------------------------------------------------------------
MemoryError                               Traceback (most recent call last)
<ipython-input-71-ad8b65bc77f2> in <module>
      4     LinearRegression()
      5 )
----> 6 model.fit(X_train_na, y_train_na).score(X_test_na, y_test_na)

~/Documents/packages/scikit-learn/sklearn/pipeline.py in fit(self, X, y, **fit_params)
    346             This estimator
    347         """
--> 348         Xt, fit_params = self._fit(X, y, **fit_params)
    349         with _print_elapsed_time('Pipeline',
    350                                  self._log_message(len(self.steps) - 1)):

~/Documents/packages/scikit-learn/sklearn/pipeline.py in _fit(self, X, y, **fit_params)
    311                 message_clsname='Pipeline',
    312                 message=self._log_message(step_idx),
--> 313                 **fit_params_steps[name])
    314             # Replace the transformer of the step with the fitted
    315             # transformer. This is necessary when loading the transformer

~/miniconda3/envs/dev/lib/python3.7/site-packages/joblib/memory.py in __call__(self, *args, **kwargs)
    353 
    354     def __call__(self, *args, **kwargs):
--> 355         return self.func(*args, **kwargs)
    356 
    357     def call_and_shelve(self, *args, **kwargs):

~/Documents/packages/scikit-learn/sklearn/pipeline.py in _fit_transform_one(transformer, X, y, weight, message_clsname, message, **fit_params)
    724     with _print_elapsed_time(message_clsname, message):
    725         if hasattr(transformer, 'fit_transform'):
--> 726             res = transformer.fit_transform(X, y, **fit_params)
    727         else:
    728             res = transformer.fit(X, y, **fit_params).transform(X)

~/Documents/packages/scikit-learn/sklearn/base.py in fit_transform(self, X, y, **fit_params)
    566         else:
    567             # fit method of arity 2 (supervised transformation)
--> 568             return self.fit(X, y, **fit_params).transform(X)
    569 
    570 

~/Documents/packages/scikit-learn/sklearn/impute/_knn.py in transform(self, X)
    230                                   metric=self.metric,
    231                                   missing_values=self.missing_values,
--> 232                                   force_all_finite=force_all_finite)
    233 
    234         # Maps from indices from X to indices in dist matrix

~/Documents/packages/scikit-learn/sklearn/metrics/pairwise.py in pairwise_distances(X, Y, metric, n_jobs, force_all_finite, **kwds)
   1742         func = partial(distance.cdist, metric=metric, **kwds)
   1743 
-> 1744     return _parallel_pairwise(X, Y, func, n_jobs, **kwds)
   1745 
   1746 

~/Documents/packages/scikit-learn/sklearn/metrics/pairwise.py in _parallel_pairwise(X, Y, func, n_jobs, **kwds)
   1341 
   1342     if effective_n_jobs(n_jobs) == 1:
-> 1343         return func(X, Y, **kwds)
   1344 
   1345     # enforce a threading backend to prevent data communication overhead

~/Documents/packages/scikit-learn/sklearn/metrics/pairwise.py in nan_euclidean_distances(X, Y, squared, missing_values, copy)
    409     present_coords_cnt = np.dot(1 - missing_X, 1 - missing_Y.T)
    410     present_mask = (present_coords_cnt != 0)
--> 411     distances[present_mask] *= (X.shape[1] / present_coords_cnt[present_mask])
    412 
    413     if X is Y:

MemoryError: Unable to allocate array with shape (311408488,) and data type float64
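
For scale, just the temporary array that failed to allocate here is already sizeable (a quick check based only on what the traceback shows):

print(311_408_488 * 8 / 1024**3)  # ≈ 2.3 GiB for one float64 temporary array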
@glemaitre
Member Author

ping @thomasjpfan any idea? (it is a bit late here to debug right now)

@wangyexiang

Maybe your dataset is too big.

@glemaitre
Member Author

#15615 will save some memory. However, I was wondering if we could have a memory-efficient KNNImputer. It is the first time that I'm looking at the code so maybe I did not get everything straight away.

It seems that we compute the pairwise distance matrix between X and X[missing_row]:

# Pairwise distances between receivers and fitted samples
dist = pairwise_distances(X[row_missing_idx, :], self._fit_X,
                          metric=self.metric,
                          missing_values=self.missing_values,
                          force_all_finite=force_all_finite)

If X is large, then we will get a large dist matrix. I was thinking that we could maybe optimize this by:

  • Allowing the use of a subsampled X;
  • Computing the distance only between X and one row at a time (and by chunk if required); a rough sketch follows below. It would require writing some Cython since we would end up with nested for loops.
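
A minimal sketch of that chunked variant, assuming the default nan_euclidean metric and an illustrative chunk size (this is just the idea, not the actual implementation):

import numpy as np
from sklearn.metrics.pairwise import nan_euclidean_distances

def chunked_nan_distances(X_missing, X_fit, chunk_size=1000):
    """Yield (start, distance block) pairs instead of one huge matrix."""
    n_rows = X_missing.shape[0]
    for start in range(0, n_rows, chunk_size):
        stop = min(start + chunk_size, n_rows)
        # Only a (chunk_size, n_fit_samples) block lives in memory at a time;
        # the caller imputes these receiver rows, then discards the block.
        yield start, nan_euclidean_distances(X_missing[start:stop], X_fit)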

ping @thomasjpfan @jnothman

@jnothman
Member

I'm not sure what your proposals mean.

We only calculate distances for rows with at least one value missing.

Yes, we can probably improve memory efficiency by considering one target sample at a time, rather than the column-at-a-time approach here.

But I think we are best off seeing how this goes in practice before optimising it too much. I think the problem can be trivially chunked by receiving row (a rough sketch of this follows below).

Are you still having trouble with MemoryError?
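
For reference, one way to express that "chunk by receiving row" idea with the existing pairwise_distances_chunked helper; this is only an illustrative sketch, not necessarily what the eventual patch does:

import numpy as np
from sklearn.metrics import pairwise_distances_chunked

X_missing = np.array([[1.0, np.nan], [np.nan, 2.0]])
X_fit = np.array([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]])

# Each yielded block respects a memory budget (working_memory, in MiB) instead
# of materialising the full (n_receivers, n_fit_samples) matrix at once.
for dist_chunk in pairwise_distances_chunked(
        X_missing, X_fit, metric='nan_euclidean',
        force_all_finite='allow-nan', working_memory=64):
    pass  # impute the receiver rows covered by this chunk, then discard it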

@glemaitre
Member Author

are you still having trouble with MemoryError?

Nope. I only see a memory peak of about 4 GB for the above example. However, X only takes < 2 MB originally.

@glemaitre
Member Author

Let's close this issue.

As @jnothman mentioned, we can wait for feedback to see if we need more optimization using chunking.

@zoj613
Contributor

zoj613 commented Jan 23, 2020

I am getting the same issue on a real-life dataset of fairly medium size (i.e. 100k rows and 100 columns). Some of the columns are sparse. So when trying to impute about 20-30 of them, the KNNImputer consumes about 150 GB of memory on an AWS instance and just runs forever without finishing. After some time the memory usage drops without anything happening afterwards. Is there a remedy for this? Both SimpleImputer and IterativeImputer finish very quickly on the same dataset.

@glemaitre
Member Author

Could you open a new issue mentioning this information? We could then think about improving the imputer.

@zoj613
Contributor

zoj613 commented Jan 23, 2020

Could you open a new issue mentioning this information? We could then think about improving the imputer.

Is that worth doing given that I can't share the data, so I don't have an example to reproduce the issue?

@rth
Member

rth commented Jan 23, 2020

Yes, please. Memory usage shouldn't depend too much on the actual data. If you can produce a smaller synthetic dataset with a comparable amount of sparsity and NAs to impute, that would be helpful. Ideally a smaller dataset for which KNNImputer would take a few GB (as opposed to the 100s of GB in your example), so that we can run it locally on a laptop; a rough sketch of such a reproducer is below.
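
Something along these lines, for instance (the sizes and missing rate are illustrative assumptions, not anyone's real data):

import numpy as np
from sklearn.impute import KNNImputer

rng = np.random.RandomState(0)
X = rng.standard_normal((20_000, 100))
X[rng.random_sample(X.shape) < 0.1] = np.nan  # ~10% missing values

# Nearly every row has at least one NaN, so the distance matrix is roughly
# 20k x 20k float64 ≈ 3.2 GB: big enough to profile, small enough for a laptop.
KNNImputer().fit_transform(X)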

@mosari

mosari commented Feb 3, 2020

Same problem here with 100k rows and 100 features. The problem arises from metrics\pairwise.py.
distances = - 2 * safe_sparse_dot(X, Y.T, dense_output=True)
MemoryError: Unable to allocate 38.7 GiB for an array with shape (51896, 100000) and data type float64
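
That size corresponds exactly to one dense float64 distance matrix between the ~52k rows being imputed and all 100k samples:

print(51896 * 100000 * 8 / 1024**3)  # ≈ 38.7 GiB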

@rth rth reopened this Feb 3, 2020
@rth
Member

rth commented Feb 3, 2020

Thanks, the fact that we can't apply KNNImputer on 100k samples is indeed problematic.

@rth rth added the Bug label Feb 3, 2020
@jnothman
Member

jnothman commented Feb 3, 2020 via email

@rth rth added Performance and removed Bug labels Feb 3, 2020
@B-Yassine

Same problem here with a (59972, 11) shape dataset

@glemaitre
Member Author

glemaitre commented Feb 3, 2020 via email

@jnothman
Member

jnothman commented Feb 3, 2020 via email

@B-Yassine

But the problem is how to define these chunks, since the final result for a specific value depends on data points which may not be in the defined chunk.
Maybe one could perform a simple clustering and work on those clusters ("chunks"), so that only one cluster at a time has to be pushed into memory while performing the KNN imputation (a rough sketch of this idea is below).
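
A hypothetical sketch of that clustering workaround; this is not how scikit-learn addresses the issue, and KMeans, the cluster count, and the mean pre-fill are all assumptions chosen for illustration:

import numpy as np
from sklearn.cluster import KMeans
from sklearn.impute import KNNImputer, SimpleImputer

def cluster_then_knn_impute(X, n_clusters=10, random_state=0):
    # Cluster a cheaply mean-imputed copy so that KMeans can handle the NaNs.
    X_filled = SimpleImputer(strategy='mean').fit_transform(X)
    labels = KMeans(n_clusters=n_clusters,
                    random_state=random_state).fit_predict(X_filled)

    X_out = np.asarray(X, dtype=float).copy()
    for k in range(n_clusters):
        idx = np.flatnonzero(labels == k)
        # Each KNNImputer only ever sees one cluster, bounding its distance
        # matrix. Assumes no column is entirely NaN within a cluster.
        X_out[idx] = KNNImputer().fit_transform(X_out[idx])
    return X_out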

@jnothman
Member

jnothman commented Feb 4, 2020 via email

@felipeeeantunes

Same problem here.
MemoryError: Unable to allocate 1.17 TiB for an array with shape (382559, 420370) and data type float64

@thomasjpfan thomasjpfan self-assigned this Feb 5, 2020
@ajing

ajing commented Feb 6, 2020

I have the same problem.
Unable to allocate 62.7 GiB for an array with shape (91686, 91713) and data type float64

@jnothman
Member

jnothman commented Feb 6, 2020

@thomasjpfan I didn't see you self-assign this. Maybe even core devs should make use of "take". I think I have a patch.

@jnothman jnothman reopened this Feb 6, 2020
jnothman added a commit to jnothman/scikit-learn that referenced this issue Feb 6, 2020
Fixes scikit-learn#15604

This is more computationally expensive than the previous implementation,
but should reduce memory costs substantially in common use cases.
@thomasjpfan thomasjpfan removed their assignment Feb 6, 2020
@jnothman jnothman self-assigned this Feb 7, 2020
thomasjpfan pushed a commit to thomasjpfan/scikit-learn that referenced this issue Feb 22, 2020
Fixes scikit-learn#15604

This is more computationally expensive than the previous implementation,
but should reduce memory costs substantially in common use cases.
panpiort8 pushed a commit to panpiort8/scikit-learn that referenced this issue Mar 3, 2020
Fixes scikit-learn#15604

This is more computationally expensive than the previous implementation,
but should reduce memory costs substantially in common use cases.
@IslamAlam

Same issue here as well, for MDS:

from sklearn.manifold import MDS

MemoryError: Unable to allocate 7.46 TiB for an array with shape (1012455, 1012455) and data type float64
