Credit goes to github.com


Permutation Importance fails if dataset is large enough #15810


Closed
andersbogsnes opened this issue Dec 6, 2019 · 7 comments · Fixed by #15933

Comments

@andersbogsnes

andersbogsnes commented Dec 6, 2019

Description

When permutation_importance is called with a large enough pandas DataFrame and parallel workers (n_jobs=-1 in the example below), joblib hands the data to the workers as a read-only memory map. permutation_importance then raises a ValueError when it tries to assign a shuffled column back into the DataFrame.

The error does not occur when passing a similarly sized NumPy array.

In previous, similar implementations, we fixed the bug by passing max_nbytes=None to the Parallel constructor, which disables joblib's automatic memmapping of large arrays, though I don't know what the broader consequences of that are.
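To make the size threshold concrete, here is a deliberately simplified, hypothetical model of what a joblib worker receives. It is not joblib's code: the helper `as_seen_by_worker` is invented for illustration, and the 1 MiB value assumes joblib's default `max_nbytes='1M'` means 1024 ** 2 bytes. joblib actually dumps large arrays to a temporary folder and the workers open them as memory maps with the default `mmap_mode='r'`; the model only reproduces the resulting writeable flag.

```python
import numpy as np

# Assumed default: joblib's max_nbytes='1M', taken here as 1 MiB.
DEFAULT_MAX_NBYTES = 1024 ** 2


def as_seen_by_worker(X, max_nbytes=DEFAULT_MAX_NBYTES):
    """Hypothetical model (not joblib's implementation) of a worker's input.

    Arrays above the threshold are shared as a read-only view, mimicking a
    memmap opened with mmap_mode='r'; anything else arrives as a private,
    writable copy. max_nbytes=None disables the threshold entirely.
    """
    if max_nbytes is not None and X.nbytes > max_nbytes:
        shared = X.view()
        shared.setflags(write=False)  # workers cannot write to shared data
        return shared
    return X.copy()


# Same shape as the reproducer: 150 iris rows repeated 1000 times, 4 float64 features.
X_large = np.zeros((150 * 1000, 4))
print(X_large.nbytes)  # 4800000 bytes, well above 1 MiB
print(as_seen_by_worker(X_large).flags.writeable)  # False
print(as_seen_by_worker(X_large, max_nbytes=None).flags.writeable)  # True
```

The likely cost of max_nbytes=None is that each dispatched task ships its own pickled copy of the data instead of sharing one read-only buffer, so memory use and transfer time can grow with the number of tasks.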

Steps/Code to Reproduce

import pandas as pd
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.inspection import permutation_importance

data = load_iris()

# Make 150,000 samples
df = pd.DataFrame(data=data.data.repeat(1000, axis=0))
y = data.target.repeat(1000)

clf = RandomForestClassifier()
clf.fit(df, y)

r = permutation_importance(clf, df, y, n_jobs=-1)

Expected Results

We expect no exception to be raised.

Actual Results

_RemoteTraceback                          Traceback (most recent call last)
_RemoteTraceback: 
"""
Traceback (most recent call last):
  File "/home/anders/.pyenv/versions/3.7.4/envs/ml_tooling_env/lib/python3.7/site-packages/joblib/externals/loky/process_executor.py", line 418, in _process_worker
    r = call_item()
  File "/home/anders/.pyenv/versions/3.7.4/envs/ml_tooling_env/lib/python3.7/site-packages/joblib/externals/loky/process_executor.py", line 272, in __call__
    return self.fn(*self.args, **self.kwargs)
  File "/home/anders/.pyenv/versions/3.7.4/envs/ml_tooling_env/lib/python3.7/site-packages/joblib/_parallel_backends.py", line 600, in __call__
    return self.func(*args, **kwargs)
  File "/home/anders/.pyenv/versions/3.7.4/envs/ml_tooling_env/lib/python3.7/site-packages/joblib/parallel.py", line 256, in __call__
    for func, args, kwargs in self.items]
  File "/home/anders/.pyenv/versions/3.7.4/envs/ml_tooling_env/lib/python3.7/site-packages/joblib/parallel.py", line 256, in <listcomp>
    for func, args, kwargs in self.items]
  File "/home/anders/.pyenv/versions/3.7.4/envs/ml_tooling_env/lib/python3.7/site-packages/sklearn/inspection/_permutation_importance.py", line 37, in _calculate_permutation_scores
    _safe_column_setting(X, col_idx, temp)
  File "/home/anders/.pyenv/versions/3.7.4/envs/ml_tooling_env/lib/python3.7/site-packages/sklearn/inspection/_permutation_importance.py", line 15, in _safe_column_setting
    X.iloc[:, col_idx] = values
  File "/home/anders/.pyenv/versions/3.7.4/envs/ml_tooling_env/lib/python3.7/site-packages/pandas/core/indexing.py", line 205, in __setitem__
    self._setitem_with_indexer(indexer, value)
  File "/home/anders/.pyenv/versions/3.7.4/envs/ml_tooling_env/lib/python3.7/site-packages/pandas/core/indexing.py", line 576, in _setitem_with_indexer
    self.obj[item_labels[indexer[info_axis]]] = value
  File "/home/anders/.pyenv/versions/3.7.4/envs/ml_tooling_env/lib/python3.7/site-packages/pandas/core/frame.py", line 3487, in __setitem__
    self._set_item(key, value)
  File "/home/anders/.pyenv/versions/3.7.4/envs/ml_tooling_env/lib/python3.7/site-packages/pandas/core/frame.py", line 3565, in _set_item
    NDFrame._set_item(self, key, value)
  File "/home/anders/.pyenv/versions/3.7.4/envs/ml_tooling_env/lib/python3.7/site-packages/pandas/core/generic.py", line 3381, in _set_item
    self._data.set(key, value)
  File "/home/anders/.pyenv/versions/3.7.4/envs/ml_tooling_env/lib/python3.7/site-packages/pandas/core/internals/managers.py", line 1090, in set
    blk.set(blk_locs, value_getitem(val_locs))
  File "/home/anders/.pyenv/versions/3.7.4/envs/ml_tooling_env/lib/python3.7/site-packages/pandas/core/internals/blocks.py", line 380, in set
    self.values[locs] = values
ValueError: assignment destination is read-only
"""
The above exception was the direct cause of the following exception:

ValueError                                Traceback (most recent call last)
<ipython-input-2-22d9b46d8800> in <module>
     13 clf.fit(df, y)
     14 
---> 15 r = permutation_importance(clf, df, y, n_jobs=-1)

~/.pyenv/versions/3.7.4/envs/ml_tooling_env/lib/python3.7/site-packages/sklearn/inspection/_permutation_importance.py in permutation_importance(estimator, X, y, scoring, n_repeats, n_jobs, random_state)
    119     scores = Parallel(n_jobs=n_jobs)(delayed(_calculate_permutation_scores)(
    120         estimator, X, y, col_idx, random_state, n_repeats, scorer
--> 121     ) for col_idx in range(X.shape[1]))
    122 
    123     importances = baseline_score - np.array(scores)

~/.pyenv/versions/3.7.4/envs/ml_tooling_env/lib/python3.7/site-packages/joblib/parallel.py in __call__(self, iterable)
   1014 
   1015             with self._backend.retrieval_context():
-> 1016                 self.retrieve()
   1017             # Make sure that we get a last message telling us we are done
   1018             elapsed_time = time.time() - self._start_time

~/.pyenv/versions/3.7.4/envs/ml_tooling_env/lib/python3.7/site-packages/joblib/parallel.py in retrieve(self)
    906             try:
    907                 if getattr(self._backend, 'supports_timeout', False):
--> 908                     self._output.extend(job.get(timeout=self.timeout))
    909                 else:
    910                     self._output.extend(job.get())

~/.pyenv/versions/3.7.4/envs/ml_tooling_env/lib/python3.7/site-packages/joblib/_parallel_backends.py in wrap_future_result(future, timeout)
    552         AsyncResults.get from multiprocessing."""
    553         try:
--> 554             return future.result(timeout=timeout)
    555         except LokyTimeoutError:
    556             raise TimeoutError()

~/.pyenv/versions/3.7.4/lib/python3.7/concurrent/futures/_base.py in result(self, timeout)
    433                 raise CancelledError()
    434             elif self._state == FINISHED:
--> 435                 return self.__get_result()
    436             else:
    437                 raise TimeoutError()

~/.pyenv/versions/3.7.4/lib/python3.7/concurrent/futures/_base.py in __get_result(self)
    382     def __get_result(self):
    383         if self._exception:
--> 384             raise self._exception
    385         else:
    386             return self._result

ValueError: assignment destination is read-only
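The final error is NumPy's standard message for writing into an array whose writeable flag is off, which is what a read-only memmap looks like to the worker. The sketch below reproduces the same `ValueError` without joblib or pandas by clearing the flag by hand; `shuffle_column` is a hypothetical stand-in for the in-place column assignment seen in `_safe_column_setting`, not scikit-learn's code.

```python
import numpy as np


def shuffle_column(X, col_idx, rng):
    """Hypothetical stand-in: permute one column and write it back in place."""
    X[:, col_idx] = rng.permutation(X[:, col_idx])


rng = np.random.RandomState(0)
X = np.arange(12, dtype=float).reshape(4, 3)
X.setflags(write=False)  # stands in for the read-only memmap a worker receives

message = None
try:
    shuffle_column(X, 0, rng)
except ValueError as exc:
    message = str(exc)
print(message)  # assignment destination is read-only
```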

Versions

System:
python: 3.7.4 (default, Oct 14 2019, 12:42:45) [GCC 7.4.0]
executable: /home/anders/.pyenv/versions/3.7.4/envs/ml_tooling_env/bin/python3.7
machine: Linux-5.0.0-37-generic-x86_64-with-debian-buster-sid

Python dependencies:
pip: 19.3.1
setuptools: 40.8.0
sklearn: 0.22
numpy: 1.17.4
scipy: 1.3.3
Cython: None
pandas: 0.25.3
matplotlib: 3.1.2
joblib: 0.14.0

Built with OpenMP: True

rth added the Bug label Dec 6, 2019
rth added this to the 0.22.1 milestone Dec 6, 2019
@rth
Member

rth commented Dec 6, 2019

Thanks for reporting this issue, @andersbogsnes!

@Henrilin28

Do we have a fix for this?

@andersbogsnes
Author

As I mentioned in the report, one potential fix is to turn off the memmapping by setting max_nbytes to None.

This would fix it, but I don't know what the wider-ranging consequences are.

Alternatively, the implementation would have to assign to a copy of the DataFrame on each pass, which sounds expensive memory-wise.

I'm not sure why NumPy arrays and DataFrames behave differently, though; that could be a clue.

Finally, if there were some way for the user to reach into Parallel and turn off memmapping, that would work too. Unfortunately, I haven't found one in my cursory read of the joblib docs: the parallel_backend context manager does not let the user control that part.
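The copy-based alternative from the comment above can be sketched like this, assuming each parallel task receives the full, possibly read-only, array plus one column index. `permuted_scores` and `score_fn` are hypothetical names, not scikit-learn's API, and the toy `score_fn` stands in for scoring a fitted estimator. Each task makes one writable copy up front and reuses it for every repeat, so the extra memory is roughly one copy of X per task running at once, not one per repeat.

```python
import numpy as np


def permuted_scores(X, col_idx, n_repeats, score_fn, seed=0):
    """Hypothetical task: score X with column col_idx shuffled n_repeats times.

    X may be read-only (e.g. a joblib memmap); only the private copy is written.
    """
    rng = np.random.RandomState(seed)
    X_permuted = X.copy()  # ndarray.copy() always returns a writable array
    scores = np.empty(n_repeats)
    for i in range(n_repeats):
        # Read the original column from X, write the shuffled one into the copy.
        X_permuted[:, col_idx] = rng.permutation(X[:, col_idx])
        scores[i] = score_fn(X_permuted)
    return scores


X = np.arange(12, dtype=float).reshape(4, 3)
X.setflags(write=False)  # simulate the read-only input a worker may receive

# Toy score: the first row's value in column 0, which changes as the column is shuffled.
scores = permuted_scores(X, col_idx=0, n_repeats=5, score_fn=lambda Z: Z[0, 0])
print(scores.shape)  # (5,)
```

Unlike max_nbytes=None, this keeps the shared read-only input and only pays for a private copy inside each task.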

@jnothman
Member

Happy to see a patch that sets max_nbytes=None (turning off memmapping) for now.

@shivamgargsya
Contributor

Would like to take this up.

@andersbogsnes
Author

Let me know; otherwise, I'd be happy to contribute a patch.

@Henrilin28

I tried to turn off the memmapping by setting max_nbytes to None, but it didn't work.
