BUG ensure that parallel/sequential give the same permutation importances #15933
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
```diff
@@ -4,41 +4,36 @@
 from joblib import delayed

 from ..metrics import check_scoring
+from ..utils import Bunch
 from ..utils import check_random_state
 from ..utils import check_array
-from ..utils import Bunch
-
-
-def _safe_column_setting(X, col_idx, values):
-    """Set column on X using `col_idx`"""
-    if hasattr(X, "iloc"):
-        X.iloc[:, col_idx] = values
-    else:
-        X[:, col_idx] = values
-
-
-def _safe_column_indexing(X, col_idx):
-    """Return column from X using `col_idx`"""
-    if hasattr(X, "iloc"):
-        return X.iloc[:, col_idx].values
-    else:
-        return X[:, col_idx]


 def _calculate_permutation_scores(estimator, X, y, col_idx, random_state,
                                   n_repeats, scorer):
     """Calculate score when `col_idx` is permuted."""
-    original_feature = _safe_column_indexing(X, col_idx).copy()
-    temp = original_feature.copy()
     random_state = check_random_state(random_state)

+    # Work on a copy of X to to ensure thread-safety in case of threading based
+    # parallelism. Furthermore, making a copy is also useful when the joblib
+    # backend is 'loky' (default) or the old 'multiprocessing': in those cases,
+    # if X is large it will be automatically be backed by a readonly memory map
+    # (memmap). X.copy() on the other hand is always guaranteed to return a
+    # writable data-structure whose columns can be shuffled inplace.
+    X_permuted = X.copy()
     scores = np.zeros(n_repeats)
+    shuffling_idx = np.arange(X.shape[0])
     for n_round in range(n_repeats):
-        random_state.shuffle(temp)
-        _safe_column_setting(X, col_idx, temp)
-        feature_score = scorer(estimator, X, y)
+        random_state.shuffle(shuffling_idx)
+        if hasattr(X_permuted, "iloc"):
+            col = X_permuted.iloc[shuffling_idx, col_idx]
+            col.index = X_permuted.index
+            X_permuted.iloc[:, col_idx] = col
+        else:
+            X_permuted[:, col_idx] = X_permuted[shuffling_idx, col_idx]
+        feature_score = scorer(estimator, X_permuted, y)
         scores[n_round] = feature_score

-    _safe_column_setting(X, col_idx, original_feature)
     return scores
```

Review thread on `X_permuted = X.copy()`:

> Note for the reviewers: the fact that we always make a copy here also fixes the issue with read-only memmaps as reported in #15810.

> Just by chance, do you happen to know of a way to avoid re-allocating the full dataframe, and instead make a view for most columns except the one we want to change in place? For instance, if one makes a slice of a dataframe and then tries to modify a column, pandas will raise a warning about a view being modified, but I'm not sure whether it will actually change the original dataframe in place in this case.

> It depends on the internal block structure of the dataframe, but this is considered private API and is likely to change in future versions of pandas. I would rather stay safe for now.

> In the future, I can see this being possible when pandas switches to using a columnar data structure to hold its data, as detailed in the pandas roadmap.

Review thread on `col = X_permuted.iloc[shuffling_idx, col_idx]`:

> Speaking with @jorisvandenbossche, this is the case where one should use …

> But …

> No, it will keep the dtype info of the …
>
> ```python
> In [6]: s = pd.Series(pd.Categorical(['a', 'b']))
>
> In [7]: s
> Out[7]:
> 0    a
> 1    b
> dtype: category
> Categories (2, object): [a, b]
>
> In [8]: s.values
> Out[8]:
> [a, b]
> Categories (2, object): [a, b]
> ```
>
> That's why you need to use …

> Alright, good point. Feel free to submit a new PR to simplify the code then :)

> The … To solve this mix, in more recent versions of pandas there is …

thomasjpfan marked this conversation as resolved.
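The review discussion above hinges on the difference between `.values` and `.to_numpy()` for pandas extension dtypes: `.values` can return a pandas extension array (e.g. a `Categorical`) rather than a plain NumPy array. A small demonstration, assuming pandas and NumPy are installed:

```python
import numpy as np
import pandas as pd

s = pd.Series(pd.Categorical(['a', 'b']))

# `.values` preserves the extension type: a pandas Categorical,
# not a plain ndarray.
print(type(s.values).__name__)      # Categorical

# `.to_numpy()` always materializes a plain NumPy array
# (object dtype here, categories dropped).
print(type(s.to_numpy()).__name__)  # ndarray
print(s.to_numpy())
```

This is why code that must hand a real `ndarray` to downstream NumPy routines should call `.to_numpy()` rather than rely on `.values`.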
```diff
@@ -104,20 +99,22 @@ def permutation_importance(estimator, X, y, scoring=None, n_repeats=5,
     .. [BRE] L. Breiman, "Random Forests", Machine Learning, 45(1), 5-32,
        2001. https://doi.org/10.1023/A:1010933404324
     """
-    if hasattr(X, "iloc"):
-        X = X.copy()  # Dataframe
-    else:
-        X = check_array(X, force_all_finite='allow-nan', dtype=np.object,
-                        copy=True)
+    if not hasattr(X, "iloc"):
+        X = check_array(X, force_all_finite='allow-nan', dtype=None)

+    # Precompute random seed from the random state to be used
+    # to get a fresh independent RandomState instance for each
+    # parallel call to _calculate_permutation_scores, irrespective of
+    # the fact that variables are shared or not depending on the active
+    # joblib backend (sequential, thread-based or process-based).
     random_state = check_random_state(random_state)
-    scorer = check_scoring(estimator, scoring=scoring)
+    random_seed = random_state.randint(np.iinfo(np.int32).max + 1)

+    scorer = check_scoring(estimator, scoring=scoring)
     baseline_score = scorer(estimator, X, y)
-    scores = np.zeros((X.shape[1], n_repeats))

     scores = Parallel(n_jobs=n_jobs)(delayed(_calculate_permutation_scores)(
-        estimator, X, y, col_idx, random_state, n_repeats, scorer
+        estimator, X, y, col_idx, random_seed, n_repeats, scorer
     ) for col_idx in range(X.shape[1]))

     importances = baseline_score - np.array(scores)
```

glemaitre marked this conversation as resolved.
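The precomputed-seed pattern in this hunk is what makes parallel and sequential runs agree: every worker receives the same plain integer seed and builds its own fresh RNG from it, instead of sharing one mutable RNG object whose state would depend on scheduling. A minimal standalone sketch of the idea (stdlib only; `permuted_order` is a hypothetical helper, not scikit-learn API):

```python
import random
from concurrent.futures import ThreadPoolExecutor

def permuted_order(seed, n=5):
    # Build a fresh RNG from the shared integer seed inside each call.
    # Sharing one RNG *object* across workers would make results depend
    # on the execution backend and on scheduling order.
    rng = random.Random(seed)
    idx = list(range(n))
    rng.shuffle(idx)
    return idx

seed = 42
sequential = [permuted_order(seed) for _ in range(3)]
with ThreadPoolExecutor(max_workers=3) as ex:
    parallel = list(ex.map(lambda _: permuted_order(seed), range(3)))
print(sequential == parallel)  # True
```

Because each call derives its randomness only from the integer `seed`, the outputs are identical whether the calls run in one thread or many, mirroring why the PR passes `random_seed` rather than `random_state` to the workers.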
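Putting the patched pieces together, here is a hypothetical, sequential-only condensation of the algorithm the diff implements, for plain NumPy input (names are illustrative, not the scikit-learn API):

```python
import numpy as np

def permutation_importance_sketch(score_fn, X, y, n_repeats=5, seed=0):
    # Mirror of the patched logic: score on a writable copy of X and
    # shuffle one column at a time via an index permutation; importance
    # is the drop from the baseline score.
    rng = np.random.RandomState(seed)
    baseline = score_fn(X, y)
    importances = np.zeros((X.shape[1], n_repeats))
    shuffling_idx = np.arange(X.shape[0])
    for col in range(X.shape[1]):
        X_permuted = X.copy()
        for n_round in range(n_repeats):
            rng.shuffle(shuffling_idx)
            X_permuted[:, col] = X_permuted[shuffling_idx, col]
            importances[col, n_round] = baseline - score_fn(X_permuted, y)
    return importances

# Toy score: how well column 0 alone "predicts" y (column 1 is ignored).
score = lambda X, y: -np.mean((X[:, 0] - y) ** 2)
X = np.arange(20, dtype=float).reshape(10, 2)
y = X[:, 0].copy()
imp = permutation_importance_sketch(score, X, y)
print(imp.shape)  # (2, 5)
```

Shuffling the informative column 0 degrades the score (positive importance), while shuffling the ignored column 1 leaves it untouched (importance 0), which is exactly the signal permutation importance is meant to expose.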