
[MRG] Allow nan/inf in feature selection #11635


Merged

Conversation

adpeters
Contributor

@adpeters adpeters commented Jul 19, 2018

Reference Issues/PRs

Closes #10821
Closes #10985

What does this implement/fix? Explain your changes.

Allows for NaN/Inf values in RFE/RFECV.fit method as well as SelectorMixin.transform. This affects all feature selection estimators that inherit from SelectorMixin (except for univariate selectors), which includes those in sklearn.feature_selection.variance_threshold and sklearn.linear_model.randomized_l1.

The RFE/RFECV.fit method does not need to check y, as any checks should be done by the estimator when it runs its own fit, so I changed check_X_y to check_array with just X, and allowed NaN/Inf values.

For SelectorMixin.transform, the method itself does not require finite input, so we should let any inheritors do that check themselves if they need it.
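As a minimal illustration of the validation behavior this relies on (a sketch, not the exact PR diff): passing force_all_finite=False to check_array lets NaN/Inf values through validation.

```python
import numpy as np
from sklearn.utils import check_array

# With the default (force_all_finite=True) this input would raise a
# ValueError; force_all_finite=False accepts NaN and Inf.
X = np.array([[1.0, np.nan],
              [np.inf, 2.0]])
X_checked = check_array(X, force_all_finite=False)
```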

Any other comments?

@jnothman
Member

You have test failures. Also, we are working towards a release, which this will not be included in, so please ping for review after 0.20 is released.

@georgewambold

georgewambold commented Jan 7, 2019

Is anyone working on this issue? I'm also running into problems with check_X_y in RFE and I think this fix would be great.

@jnothman jnothman left a comment
Member

This is looking pretty good. Please add a test that univariate selection accepts NaN at transform time. We could (and should) also use force_all_finite=False in univariate_selection._BaseFilter.fit, since the underlying scoring functions check it. SelectFromModel should already be lenient in fit.

You could consider creating a feature_selection/tests/test_common.py file that checks this for all the feature selection estimators (although we would have to use nanvar; mean_variance_axis should already exclude NaNs).

It would be really wonderful to have all feature selectors be insensitive to missing values and this is only a few lines of code away.
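For reference, the difference between np.var and np.nanvar that makes the suggestion above work:

```python
import numpy as np

# np.var propagates NaN; np.nanvar excludes NaN entries, which is what a
# NaN-tolerant per-feature variance computation needs.
X = np.array([[1.0, np.nan],
              [2.0, 3.0],
              [3.0, 5.0]])
var = np.var(X, axis=0)        # second feature's variance is nan
nanvar = np.nanvar(X, axis=0)  # NaN excluded: variance of [3.0, 5.0]
```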

rfe = RFE(estimator=clf)
rfe.fit(X, y)
rfe.transform(X)

Member

Two blank lines

rfecv = RFECV(estimator=clf)
rfecv.fit(X, y)
rfecv.transform(X)

Member

No blank lines

rfe.fit(X, y)
rfe.transform(X)

def test_rfecv_allow_nan_inf_in_x():
Member

use a loop or pytest.mark.parametrize rather than duplicating code

Contributor Author

Added pytest.mark.parametrize.

@adpeters
Contributor Author

adpeters commented Jan 8, 2019

Okay I made most of the changes mentioned above: fixed pep8 errors, added tests for univariate fit and transform and SelectFromModel transform, and updated univariate fit to allow nan/inf as well. I also updated estimator_checks to not perform the nan/inf check on SelectorMixin objects.

There's probably a more logical/efficient way to do all the tests, but this is what I've got for now.

@jnothman
Member

jnothman commented Jan 8, 2019

Please merge in an updated master; Circle CI should not be running on Python 2 anymore.

@@ -103,7 +104,7 @@ def _yield_non_meta_checks(name, estimator):
     # cross-decomposition's "transform" returns X and Y
     yield check_pipeline_consistency

-    if name not in ALLOW_NAN:
+    if name not in ALLOW_NAN and not isinstance(estimator, SelectorMixin):
Member

Hmm... Not sure if we'd rather this. There still may be selectors (not in our library) that should error on NaN/inf, and many of these selectors with their default parameters still should/will error on NaN/inf.

In the (near) future, estimator tags (#8022) will solve this.

Contributor Author

Ahh okay. I'm running into an issue with adding the univariate estimators to the ALLOW_NAN list because they fail check_estimators_pickle. The issue is that while _BaseFilter allows NaN/inf in fit and transform, the default scorer (and all scorers thus far) don't allow them, so when fit is called it errors on NaN. In my tests I use a dummy scorer that avoids this, but check_estimators_pickle is a generic check that applies to all estimators so I'm not sure what the best way to avoid this is without changing the default scorer for _BaseFilter.
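A minimal reproduction of the scorer behavior described above: the default scorer validates its own input, so it rejects NaN even when the surrounding fit would let it through.

```python
import numpy as np
from sklearn.feature_selection import f_classif

# f_classif runs its own input validation, so NaN raises a ValueError
# regardless of how lenient _BaseFilter.fit is.
X = np.array([[1.0, np.nan],
              [2.0, 3.0],
              [3.0, 4.0],
              [4.0, 5.0]])
y = np.array([0, 0, 1, 1])
try:
    f_classif(X, y)
    raised = False
except ValueError:
    raised = True
```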

Member

Why do they fail the pickle test with default parameters? For default parameters, all the feature selectors except VarianceThreshold should still raise an error on nan/inf, so you should not otherwise need to modify estimator_checks, should you?

Contributor Author

If you add them to ALLOW_NAN then they fail the pickle test because they don’t allow nan/inf with the default scorer. However if you don’t add them to ALLOW_NAN then check_estimators_nan_inf errors because the transform method does not check for nan/inf.

Member

Thanks for explaining.
I'm trying to work out if there's a better solution than creating ALLOW_NAN_TRANSFORM. Basically, at the moment we need to classify an estimator (and its parameter setting) as allowing or disallowing NaNs. Alternatively we could weaken the pickle test to fit on non-NaN data in the case that fitting with NaNs raises an error. (The ability to fit on NaN data there is not the point of the check. It's the interpretation of NaNs in transform/predict if NaNs are pickled with the model parameters/attributes that is the concern of that check.) Then we would be redefining ALLOW_NAN to mean "allows NaN in transform/predict, and may also allow NaN in fit". Your thoughts?

Member

Sorry this is harder than I thought!

Contributor Author

Thanks for helping work through it. I think relaxing the pickle test makes sense because it's not supposed to be directly checking NaN/Inf handling in the estimators. There are really two issues here:

  1. a discrepancy in whether an estimator allows NaN/Inf between its different methods (e.g. VarianceThreshold currently allows in transform but not fit), and
  2. the default parameter instance of an estimator not exhibiting its allowance of NaN/Inf.

Your suggestion of ALLOW_NAN_TRANSFORM would solve 1 and I could try to implement it if we decide it's the right way to go. But 2 would still not be solved since these estimators would belong in the ALLOW_NAN_FIT category but still error in the pickle check with NaN. So I think weakening the pickle check is the best option without changing the default scorer. I've updated my code to include this new version of check_estimators_pickle.
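The idea of the relaxed pickle check could be sketched roughly like this (a sketch of the intent, not the actual estimator_checks code): try fitting with NaNs, fall back to clean data if the estimator refuses, then verify the pickle round-trip.

```python
import pickle
import numpy as np
from sklearn.feature_selection import SelectKBest

rng = np.random.RandomState(0)
X = rng.rand(20, 4)
y = rng.randint(0, 2, 20)
X_nan = X.copy()
X_nan[0, 0] = np.nan

selector = SelectKBest(k=2)
# The check only cares about pickling fitted state, not NaN support in fit,
# so fall back to clean data if fitting with NaN raises.
try:
    selector.fit(X_nan, y)
except ValueError:
    selector.fit(X, y)

restored = pickle.loads(pickle.dumps(selector))
X_t = restored.transform(X)
```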

@georgewambold

Thanks @jnothman for the quick comments and @adpeters for the changes, and thank you both for your time!

@jnothman jnothman left a comment
Member

I think VarianceThreshold should be updated as well. All it should require is use of nanvar as far as I can see (and testing in both the sparse and dense cases).

I'm still not certain that the change to the pickle test is the best thing to do, but it's okay.

Are there other changes you hope to make before making this MRG?

When the scope of the PR is finalised, please add an entry to the change log at doc/whats_new/v0.21.rst. Like the other entries there, please reference this pull request with :issue: and credit yourself (and other contributors if applicable) with :user:
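The VarianceThreshold change suggested above should behave roughly like this sketch (assuming nanvar-based per-feature variances): fit tolerates NaN, the constant feature is dropped, and the varying feature with a NaN is kept.

```python
import numpy as np
from sklearn.feature_selection import VarianceThreshold

# First feature is constant (variance 0); second varies but contains a NaN.
# With np.nanvar the NaN is ignored when computing variances, so only the
# second feature survives the default threshold.
X = np.array([[0.0, 1.0],
              [0.0, np.nan],
              [0.0, 3.0]])
X_sel = VarianceThreshold().fit_transform(X)
```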

@jnothman
Member

jnothman commented Jan 10, 2019 via email

@jnothman jnothman mentioned this pull request Jan 17, 2019
@adpeters
Contributor Author

Okay I updated VarianceThreshold, added tests, and added to Notes documentation and whats_new.
I don't really like the change to the pickle test either, but I'm not sure of a better way right now. The other option would be relaxing the default scorer, f_classif, to allow NaN/Inf just by removing rows that contain them, but I don't think that's how we want functions to handle those cases.

I think this is all I can think of for this PR right now. Should I change it to MRG?

@jnothman
Member

jnothman commented Jan 18, 2019 via email

@jnothman jnothman left a comment
Member

Thanks for continuing on this. I wonder if we should pull univariate into a separate PR and merge the rest :|

-        X = check_array(X, dtype=None, accept_sparse='csr')
+        tags = self._get_tags()
+        X = check_array(X, dtype=None, accept_sparse='csr',
+                        force_all_finite=not tags.get('allow_nan', True))
Member

this logic now won't work for univariate :(

Contributor Author

Ah yeah I removed the test for the univariate but I wasn't sure what to do for that case. I guess the only way you could really use it would be to subclass and override the tags.

@@ -361,6 +362,9 @@ def fit(self, X, y):
     def _check_params(self, X, y):
         pass

+    def _more_tags(self):
+        return {'allow_nan': False}
Member

This should probably have a FIXME or similar to say that this should depend on the score func.

@adpeters
Contributor Author

Yeah, if we're not quite sure how we want it to work, we could leave that part out and keep the univariate selectors as not allowing nan/inf until it gets figured out.

@jnothman
Member

We'd like to release very soon. If you could strip this back to leave out univariate, and then propose that in another PR with a question on what to do about tags, I think that would be the most pragmatic solution.

@jnothman jnothman left a comment
Member

Thank you. Can we get another reviewer before release??

@amueller
Member

amueller commented Nov 2, 2019

I'm slightly confused (but also severely jet-lagged so that might be it).
The transform method of a feature selector always supports NaN, right? It's only fit that depends on the base estimator, right?

@jnothman
Member

jnothman commented Nov 2, 2019 via email

@amueller
Member

amueller commented Nov 2, 2019

hm... but do we want to restrict this because of estimator checks? I'm not sure if we check nan in transform. Would the "right" fix be to add separate tags for fit and transform?
I was just very confused because the test is for NaN in transform, not in fit, which seemed very strange to me.

@@ -320,3 +336,25 @@ def test_threshold_without_refitting():
     # Set a higher threshold to filter out more features.
     model.threshold = "1.0 * mean"
     assert X_transform.shape[1] > model.transform(data).shape[1]


+def test_transform_accepts_nan_inf():
Member

can you also check for fit, maybe use HistGradientBoostingClassifier?

@amueller amueller left a comment
Member

happy to go with this solution for the release, adding a test for SelectFromModel with NaN in fit would be nice, though.

@adpeters
Contributor Author

adpeters commented Nov 4, 2019

Thanks for the reviews. I added a test for NaN and Inf in SelectFromModel.fit using HistGradientBoostingClassifier as suggested. We still have to use a different classifier for the transform test since HistGradientBoostingClassifier does not generate feature importance metrics that SelectFromModel relies on to transform (coef_ or feature_importances_), but I think it's still an effective test of NaN/Inf in fit.

@jnothman jnothman merged commit 70b0dde into scikit-learn:master Nov 5, 2019
@jnothman
Member

jnothman commented Nov 5, 2019

Thanks @adpeters! It's been very nice working with you.

@adpeters
Contributor Author

adpeters commented Nov 5, 2019

@jnothman thanks for all your help on this! Glad I could contribute a little.

Successfully merging this pull request may close these issues:

- Is there any reason for SelectFromModel.transform to use force_all_finite=True in check_array?
- Unnecessary call to check_X_y in RFE