[MRG] Allow nan/inf in feature selection #11635
File: sklearn/feature_selection/tests/test_from_model.py (file name inferred from the imports and test names)

```diff
@@ -10,10 +10,28 @@
 from sklearn.linear_model import LogisticRegression, SGDClassifier, Lasso
 from sklearn.svm import LinearSVC
 from sklearn.feature_selection import SelectFromModel
-from sklearn.ensemble import RandomForestClassifier
+from sklearn.experimental import enable_hist_gradient_boosting  # noqa
+from sklearn.ensemble import (RandomForestClassifier,
+                              HistGradientBoostingClassifier)
 from sklearn.linear_model import PassiveAggressiveClassifier
+from sklearn.base import BaseEstimator
+
+
+class NaNTag(BaseEstimator):
+    def _more_tags(self):
+        return {'allow_nan': True}
+
+
+class NoNaNTag(BaseEstimator):
+    def _more_tags(self):
+        return {'allow_nan': False}
+
+
+class NaNTagRandomForest(RandomForestClassifier):
+    def _more_tags(self):
+        return {'allow_nan': True}
+
+
 iris = datasets.load_iris()
 data, y = iris.data, iris.target
 rng = np.random.RandomState(0)
```
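The three helper classes in this hunk exercise scikit-learn's private estimator-tags mechanism: `_get_tags()` merges the library's default tags with whatever each class's `_more_tags()` returns, so a one-line override is enough to advertise NaN support. A minimal sketch, assuming the pre-1.0 private tag API this PR targets:

```python
from sklearn.base import BaseEstimator


class NaNTag(BaseEstimator):
    def _more_tags(self):
        # Override the default allow_nan=False tag.
        return {'allow_nan': True}


# _get_tags() collects _more_tags() from every class in the MRO on top
# of the defaults, so tag-aware validation code sees the override.
assert NaNTag()._get_tags()['allow_nan'] is True
```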
Later in the same file, three new tests are added:

```diff
@@ -320,3 +338,40 @@ def test_threshold_without_refitting():
     # Set a higher threshold to filter out more features.
     model.threshold = "1.0 * mean"
     assert X_transform.shape[1] > model.transform(data).shape[1]
+
+
+def test_fit_accepts_nan_inf():
+    # Test that fit doesn't check for np.inf and np.nan values.
+    clf = HistGradientBoostingClassifier(random_state=0)
+
+    model = SelectFromModel(estimator=clf)
+
+    nan_data = data.copy()
+    nan_data[0] = np.NaN
+    nan_data[1] = np.Inf
+
+    model.fit(nan_data, y)
+
+
+def test_transform_accepts_nan_inf():
+    # Test that transform doesn't check for np.inf and np.nan values.
+    clf = NaNTagRandomForest(n_estimators=100, random_state=0)
+    nan_data = data.copy()
+
+    model = SelectFromModel(estimator=clf)
+    model.fit(nan_data, y)
+
+    nan_data[0] = np.NaN
+    nan_data[1] = np.Inf
+
+    model.transform(nan_data)
+
+
+def test_allow_nan_tag_comes_from_estimator():
+    allow_nan_est = NaNTag()
+    model = SelectFromModel(estimator=allow_nan_est)
+    assert model._get_tags()['allow_nan'] is True
+
+    no_nan_est = NoNaNTag()
+    model = SelectFromModel(estimator=no_nan_est)
+    assert model._get_tags()['allow_nan'] is False
```

Review comment on `test_transform_accepts_nan_inf`: can you also check for
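`test_allow_nan_tag_comes_from_estimator` pins down the key design point: `SelectFromModel` should not hard-code a NaN policy but forward the wrapped estimator's `allow_nan` tag. A hedged sketch of that delegation pattern — `TagDelegatingSelector` is a toy illustration, not the PR's actual implementation:

```python
from sklearn.base import BaseEstimator


class TagDelegatingSelector(BaseEstimator):
    """Toy wrapper that mirrors its inner estimator's NaN tolerance."""

    def __init__(self, estimator):
        self.estimator = estimator

    def _more_tags(self):
        # Advertise NaN support if and only if the wrapped estimator does.
        return {'allow_nan': self.estimator._get_tags().get('allow_nan', False)}
```

With this in place, `TagDelegatingSelector(NaNTag())._get_tags()['allow_nan']` is True and `TagDelegatingSelector(NoNaNTag())._get_tags()['allow_nan']` is False, which is exactly what the test asserts of `SelectFromModel`.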
File: sklearn/feature_selection/tests/test_variance_threshold.py (file name inferred from the test names)

```diff
@@ -46,3 +46,15 @@ def test_zero_variance_floating_point_error():
     msg = "No feature in X meets the variance threshold 0.00000"
     with pytest.raises(ValueError, match=msg):
         VarianceThreshold().fit(X)
+
+
+def test_variance_nan():
+    arr = np.array(data, dtype=np.float64)
+    # add single NaN and feature should still be included
+    arr[0, 0] = np.NaN
+    # make all values in feature NaN and feature should be rejected
+    arr[:, 1] = np.NaN
+
+    for X in [arr, csr_matrix(arr), csc_matrix(arr), bsr_matrix(arr)]:
+        sel = VarianceThreshold().fit(X)
+        assert_array_equal([0, 3, 4], sel.get_support(indices=True))
```
Review thread on `test_variance_nan`:

- Reviewer: Should ideally also check that an all-NaN feature is treated as 0 variance...
- Reviewer: I don't mind which, as long as the feature is rejected. It is probably simpler to set it to 0 at fit time.
- Author: Okay, I left it as NaN since that doesn't require any changes to fit and would also allow you to differentiate between an all-NaN feature and one that truly has 0 variance, if desired. But I updated the test to ensure that the all-NaN feature is rejected.
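The outcome the author describes falls out of NumPy semantics: the variance of an all-NaN column computed with `np.nanvar` is NaN (NumPy also emits a RuntimeWarning for the empty slice), and NaN compares False against any threshold, so such a feature can never pass `variance > threshold`. A quick illustration in plain NumPy (not `VarianceThreshold` internals):

```python
import numpy as np

X = np.array([[1.0, np.nan],
              [2.0, np.nan],
              [3.0, np.nan]])

# NaN-ignoring per-feature variance; the all-NaN column yields NaN.
variances = np.nanvar(X, axis=0)   # approx. [0.667, nan]

# nan > 0.0 is False, so the all-NaN feature is rejected.
support = variances > 0.0          # [ True, False]
```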
Review thread on the tag handling:

- Reviewer: this logic now won't work for univariate :(
- Author: Ah yeah, I removed the univariate test, but I wasn't sure what to do for that case. I guess the only way you could really use it would be to subclass and override the tags.
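A sketch of the subclass-and-override workaround the author mentions, applied to a univariate selector such as `SelectKBest`. Both `NaNTolerantSelectKBest` and `nanmean_score` are hypothetical names for illustration; overriding the tag relaxes tag-aware validation (e.g. in `transform`), but depending on the scikit-learn version, `fit` may still validate strictly and need overriding as well:

```python
import numpy as np
from sklearn.feature_selection import SelectKBest


def nanmean_score(X, y):
    # Hypothetical NaN-aware score function for binary y: rank features
    # by the absolute difference of their per-class NaN-ignoring means.
    m0 = np.nanmean(X[y == 0], axis=0)
    m1 = np.nanmean(X[y == 1], axis=0)
    return np.abs(m0 - m1)


class NaNTolerantSelectKBest(SelectKBest):
    def _more_tags(self):
        # Advertise NaN support so tag-aware checks allow NaN through.
        return {'allow_nan': True}


# Usage: NaNTolerantSelectKBest(score_func=nanmean_score, k=2)
```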