diff --git a/doc/whats_new/v0.22.rst b/doc/whats_new/v0.22.rst
index 39235625093bc..c7cf6a144bec3 100644
--- a/doc/whats_new/v0.22.rst
+++ b/doc/whats_new/v0.22.rst
@@ -69,7 +69,6 @@ Changelog
     :pr:`123456` by :user:`Joe Bloggs <joeongithub>`.
     where 123456 is the *pull request* number, not the issue number.
 
-
 :mod:`sklearn.base`
 ...................
 
@@ -311,6 +310,25 @@ Changelog
 :mod:`sklearn.feature_selection`
 ................................
 
+- |Enhancement| Updated the following :mod:`feature_selection` estimators to allow
+  NaN/Inf values in ``transform`` and ``fit``:
+  :class:`feature_selection.RFE`, :class:`feature_selection.RFECV`,
+  :class:`feature_selection.SelectFromModel`,
+  and :class:`feature_selection.VarianceThreshold`. Note that if the underlying
+  estimator of the feature selector does not allow NaN/Inf then it will still
+  error, but the feature selectors themselves no longer enforce this
+  restriction unnecessarily. :issue:`11635` by :user:`Alec Peters <adpeters>`.
+
+- |Enhancement| Updated univariate :mod:`feature_selection` estimators to allow
+  NaN/Inf values in ``transform`` and ``fit``. This includes
+  :class:`feature_selection.GenericUnivariateSelect`,
+  :class:`feature_selection.SelectFdr`, :class:`feature_selection.SelectFpr`,
+  :class:`feature_selection.SelectFwe`, :class:`feature_selection.SelectKBest`,
+  :class:`feature_selection.SelectPercentile`. Note that if the underlying
+  score function of the feature selector does not allow NaN/Inf then it will still
+  error, but the feature selectors themselves no longer enforce this
+  restriction unnecessarily. :pr:`15434` by :user:`Alec Peters <adpeters>`.
+
 - |Fix| Fixed a bug where :class:`feature_selection.VarianceThreshold` with
   `threshold=0` did not remove constant features due to numerical instability,
   by using range rather than variance in this case.
diff --git a/sklearn/feature_selection/_base.py b/sklearn/feature_selection/_base.py
index bcd9834189f60..20a54c41a358b 100644
--- a/sklearn/feature_selection/_base.py
+++ b/sklearn/feature_selection/_base.py
@@ -71,7 +71,9 @@ def transform(self, X):
         X_r : array of shape [n_samples, n_selected_features]
             The input samples with only the selected features.
         """
-        X = check_array(X, dtype=None, accept_sparse='csr')
+        tags = self._get_tags()
+        X = check_array(X, dtype=None, accept_sparse='csr',
+                        force_all_finite=not tags.get('allow_nan', True))
         mask = self.get_support()
         if not mask.any():
             warn("No features were selected: either the data is"
diff --git a/sklearn/feature_selection/_from_model.py b/sklearn/feature_selection/_from_model.py
index 3e324fbec5535..674127f06acd7 100644
--- a/sklearn/feature_selection/_from_model.py
+++ b/sklearn/feature_selection/_from_model.py
@@ -131,6 +131,10 @@ class SelectFromModel(MetaEstimatorMixin, SelectorMixin, BaseEstimator):
     threshold_ : float
         The threshold value used for feature selection.
 
+    Notes
+    -----
+    Allows NaN/Inf in the input if the underlying estimator does as well.
+
     Examples
     --------
     >>> from sklearn.feature_selection import SelectFromModel
@@ -249,3 +253,7 @@ def partial_fit(self, X, y=None, **fit_params):
         self.estimator_ = clone(self.estimator)
         self.estimator_.partial_fit(X, y, **fit_params)
         return self
+
+    def _more_tags(self):
+        estimator_tags = self.estimator._get_tags()
+        return {'allow_nan': estimator_tags.get('allow_nan', True)}
diff --git a/sklearn/feature_selection/_rfe.py b/sklearn/feature_selection/_rfe.py
index 86362f27ef181..a204ac3742ca5 100644
--- a/sklearn/feature_selection/_rfe.py
+++ b/sklearn/feature_selection/_rfe.py
@@ -103,6 +103,10 @@ class RFE(SelectorMixin, MetaEstimatorMixin, BaseEstimator):
     >>> selector.ranking_
     array([1, 1, 1, 1, 1, 6, 4, 3, 2, 5])
 
+    Notes
+    -----
+    Allows NaN/Inf in the input if the underlying estimator does as well.
+
     See also
     --------
     RFECV : Recursive feature elimination with built-in cross-validated
@@ -150,7 +154,8 @@ def _fit(self, X, y, step_score=None):
         # and is used when implementing RFECV
         # self.scores_ will not be calculated when calling _fit through fit
 
-        X, y = check_X_y(X, y, "csc", ensure_min_features=2)
+        X, y = check_X_y(X, y, "csc", ensure_min_features=2,
+                         force_all_finite=False)
         # Initialization
         n_features = X.shape[1]
         if self.n_features_to_select is None:
@@ -326,7 +331,7 @@ def predict_log_proba(self, X):
         return self.estimator_.predict_log_proba(self.transform(X))
 
     def _more_tags(self):
-        return {'poor_score': True}
+        return {'poor_score': True, 'allow_nan': True}
 
 
 class RFECV(RFE):
@@ -421,6 +426,8 @@ class RFECV(RFE):
     ``ceil((n_features - min_features_to_select) / step) + 1``,
     where step is the number of features removed at each iteration.
 
+    Allows NaN/Inf in the input if the underlying estimator does as well.
+
     Examples
     --------
     The following example shows how to retrieve the a-priori not known 5
@@ -479,7 +486,8 @@ def fit(self, X, y, groups=None):
             train/test set. Only used in conjunction with a "Group" :term:`cv`
             instance (e.g., :class:`~sklearn.model_selection.GroupKFold`).
         """
-        X, y = check_X_y(X, y, "csr", ensure_min_features=2)
+        X, y = check_X_y(X, y, "csr", ensure_min_features=2,
+                         force_all_finite=False)
 
         # Initialization
         cv = check_cv(self.cv, y, is_classifier(self.estimator))
diff --git a/sklearn/feature_selection/_univariate_selection.py b/sklearn/feature_selection/_univariate_selection.py
index 21990bb3a8167..78111f5526c61 100644
--- a/sklearn/feature_selection/_univariate_selection.py
+++ b/sklearn/feature_selection/_univariate_selection.py
@@ -338,7 +338,9 @@ def fit(self, X, y):
         -------
         self : object
         """
-        X, y = check_X_y(X, y, ['csr', 'csc'], multi_output=True)
+        tags = self._get_tags()
+        X, y = check_X_y(X, y, ['csr', 'csc'], multi_output=True,
+                         force_all_finite=not tags.get('allow_nan', True))
 
         if not callable(self.score_func):
             raise TypeError("The score function should be a callable, %s (%s) "
@@ -361,6 +363,10 @@ def fit(self, X, y):
         return self
 
     def _check_params(self, X, y):
         pass
 
+    # FIXME: how do we determine the tags when it depends on the underlying
+    # score_func, which does not have tags?
+    def _more_tags(self):
+        return {'allow_nan': False}
+
 
 ######################################################################
 # Specific filters
@@ -405,6 +411,8 @@ class SelectPercentile(_BaseFilter):
     Ties between features with equal scores will be broken in an unspecified
     way.
 
+    Allows NaN/Inf in the input if the underlying score_func does as well.
+
     See also
     --------
     f_classif: ANOVA F-value between label/feature for classification tasks.
@@ -489,6 +497,8 @@ class SelectKBest(_BaseFilter):
     Ties between features with equal scores will be broken in an unspecified
     way.
 
+    Allows NaN/Inf in the input if the underlying score_func does as well.
+
     See also
     --------
     f_classif: ANOVA F-value between label/feature for classification tasks.
@@ -568,6 +578,10 @@ class SelectFpr(_BaseFilter):
     >>> X_new.shape
     (569, 16)
 
+    Notes
+    -----
+    Allows NaN/Inf in the input if the underlying score_func does as well.
+
     See also
     --------
     f_classif: ANOVA F-value between label/feature for classification tasks.
@@ -634,6 +648,10 @@ class SelectFdr(_BaseFilter):
     ----------
     https://en.wikipedia.org/wiki/False_discovery_rate
 
+    Notes
+    -----
+    Allows NaN/Inf in the input if the underlying score_func does as well.
+
     See also
     --------
     f_classif: ANOVA F-value between label/feature for classification tasks.
@@ -699,6 +717,10 @@ class SelectFwe(_BaseFilter):
     pvalues_ : array-like of shape (n_features,)
         p-values of feature scores.
 
+    Notes
+    -----
+    Allows NaN/Inf in the input if the underlying score_func does as well.
+
     See also
     --------
     f_classif: ANOVA F-value between label/feature for classification tasks.
@@ -765,6 +787,10 @@ class GenericUnivariateSelect(_BaseFilter):
     >>> X_new.shape
     (569, 20)
 
+    Notes
+    -----
+    Allows NaN/Inf in the input if the underlying score_func does as well.
+
     See also
     --------
     f_classif: ANOVA F-value between label/feature for classification tasks.
diff --git a/sklearn/feature_selection/_variance_threshold.py b/sklearn/feature_selection/_variance_threshold.py
index 15576fe31025c..4f9d720b762b9 100644
--- a/sklearn/feature_selection/_variance_threshold.py
+++ b/sklearn/feature_selection/_variance_threshold.py
@@ -29,6 +29,10 @@ class VarianceThreshold(SelectorMixin, BaseEstimator):
     variances_ : array, shape (n_features,)
         Variances of individual features.
 
+    Notes
+    -----
+    Allows NaN in the input.
+
     Examples
     --------
     The following dataset has integer features, two of which are the same
@@ -61,7 +65,8 @@ def fit(self, X, y=None):
         -------
         self
         """
-        X = check_array(X, ('csr', 'csc'), dtype=np.float64)
+        X = check_array(X, ('csr', 'csc'), dtype=np.float64,
+                        force_all_finite='allow-nan')
 
         if hasattr(X, "toarray"):   # sparse matrix
             _, self.variances_ = mean_variance_axis(X, axis=0)
@@ -69,16 +74,18 @@ def fit(self, X, y=None):
                 mins, maxes = min_max_axis(X, axis=0)
                 peak_to_peaks = maxes - mins
         else:
-            self.variances_ = np.var(X, axis=0)
+            self.variances_ = np.nanvar(X, axis=0)
             if self.threshold == 0:
                 peak_to_peaks = np.ptp(X, axis=0)
 
         if self.threshold == 0:
             # Use peak-to-peak to avoid numeric precision issues
             # for constant features
-            self.variances_ = np.minimum(self.variances_, peak_to_peaks)
+            compare_arr = np.array([self.variances_, peak_to_peaks])
+            self.variances_ = np.nanmin(compare_arr, axis=0)
 
-        if np.all(self.variances_ <= self.threshold):
+        if np.all(~np.isfinite(self.variances_) |
+                  (self.variances_ <= self.threshold)):
             msg = "No feature in X meets the variance threshold {0:.5f}"
             if X.shape[0] == 1:
                 msg += " (X contains only one sample)"
@@ -90,3 +97,6 @@ def _get_support_mask(self):
         check_is_fitted(self)
 
         return self.variances_ > self.threshold
+
+    def _more_tags(self):
+        return {'allow_nan': True}
diff --git a/sklearn/feature_selection/tests/test_feature_select.py b/sklearn/feature_selection/tests/test_feature_select.py
index abb11fdc7b8da..ff24467b389fd 100644
--- a/sklearn/feature_selection/tests/test_feature_select.py
+++ b/sklearn/feature_selection/tests/test_feature_select.py
@@ -24,6 +24,12 @@
 
 ##############################################################################
+
+# dummy scorer to test other functionality
+def dummy_score(X, y):
+    return X[0], X[0]
+
+
 # Test the score functions
 
 
 def test_f_oneway_vs_scipy_stats():
@@ -471,7 +477,6 @@ def test_selectkbest_tiebreaking():
     # Prior to 0.11, SelectKBest would return more features than requested.
     Xs = [[0, 1, 1], [0, 0, 1], [1, 0, 0], [1, 1, 0]]
     y = [1]
-    dummy_score = lambda X, y: (X[0], X[0])
     for X in Xs:
         sel = SelectKBest(dummy_score, k=1)
         X1 = ignore_warnings(sel.fit_transform)([X], y)
@@ -488,7 +493,6 @@ def test_selectpercentile_tiebreaking():
     # Test if SelectPercentile selects the right n_features in case of ties.
     Xs = [[0, 1, 1], [0, 0, 1], [1, 0, 0], [1, 1, 0]]
     y = [1]
-    dummy_score = lambda X, y: (X[0], X[0])
     for X in Xs:
         sel = SelectPercentile(dummy_score, percentile=34)
         X1 = ignore_warnings(sel.fit_transform)([X], y)
@@ -667,3 +671,14 @@ def test_mutual_info_regression():
     gtruth = np.zeros(10)
     gtruth[:2] = 1
     assert_array_equal(support, gtruth)
+
+
+def test_univariate_nan_inf_allowed_in_fit():
+    X, y = make_regression(n_samples=100, n_features=10, n_informative=2,
+                           shuffle=False, random_state=0, noise=10)
+
+    univariate_filter = GenericUnivariateSelect(dummy_score,
+                                                mode='percentile')
+    X[0] = np.NaN
+    X[1] = np.Inf
+    univariate_filter.fit(X, y)
diff --git a/sklearn/feature_selection/tests/test_from_model.py b/sklearn/feature_selection/tests/test_from_model.py
index a1f6a9d970117..40fafe4429896 100644
--- a/sklearn/feature_selection/tests/test_from_model.py
+++ b/sklearn/feature_selection/tests/test_from_model.py
@@ -14,6 +14,22 @@
 from sklearn.linear_model import PassiveAggressiveClassifier
 from sklearn.base import BaseEstimator
 
+
+class NaNTag(BaseEstimator):
+    def _more_tags(self):
+        return {'allow_nan': True}
+
+
+class NoNaNTag(BaseEstimator):
+    def _more_tags(self):
+        return {'allow_nan': False}
+
+
+class NaNTagRandomForest(RandomForestClassifier):
+    def _more_tags(self):
+        return {'allow_nan': True}
+
+
 iris = datasets.load_iris()
 data, y = iris.data, iris.target
 rng = np.random.RandomState(0)
@@ -320,3 +336,25 @@ def test_threshold_without_refitting():
     # Set a higher threshold to filter out more features.
     model.threshold = "1.0 * mean"
     assert X_transform.shape[1] > model.transform(data).shape[1]
+
+
+def test_transform_accepts_nan_inf():
+    # Test that transform doesn't check for np.inf and np.nan values.
+    clf = NaNTagRandomForest(n_estimators=100, random_state=0)
+
+    model = SelectFromModel(estimator=clf)
+    model.fit(data, y)
+
+    data[0] = np.NaN
+    data[1] = np.Inf
+    model.transform(data)
+
+
+def test_allow_nan_tag_comes_from_estimator():
+    allow_nan_est = NaNTag()
+    model = SelectFromModel(estimator=allow_nan_est)
+    assert model._get_tags()['allow_nan'] is True
+
+    no_nan_est = NoNaNTag()
+    model = SelectFromModel(estimator=no_nan_est)
+    assert model._get_tags()['allow_nan'] is False
diff --git a/sklearn/feature_selection/tests/test_rfe.py b/sklearn/feature_selection/tests/test_rfe.py
index 724c749ee636b..a86999d2b1e44 100644
--- a/sklearn/feature_selection/tests/test_rfe.py
+++ b/sklearn/feature_selection/tests/test_rfe.py
@@ -2,6 +2,7 @@
 Testing Recursive feature elimination
 """
 
+import pytest
 import numpy as np
 from numpy.testing import assert_array_almost_equal, assert_array_equal
 from scipy import sparse
@@ -369,3 +370,25 @@ def test_rfe_cv_groups():
     )
     est_groups.fit(X, y, groups=groups)
     assert est_groups.n_features_ > 0
+
+
+@pytest.mark.parametrize("cv", [
+    None,
+    5
+])
+def test_rfe_allow_nan_inf_in_x(cv):
+    iris = load_iris()
+    X = iris.data
+    y = iris.target
+
+    # add nan and inf value to X
+    X[0][0] = np.NaN
+    X[0][1] = np.Inf
+
+    clf = MockClassifier()
+    if cv is not None:
+        rfe = RFECV(estimator=clf, cv=cv)
+    else:
+        rfe = RFE(estimator=clf)
+    rfe.fit(X, y)
+    rfe.transform(X)
diff --git a/sklearn/feature_selection/tests/test_variance_threshold.py b/sklearn/feature_selection/tests/test_variance_threshold.py
index 9dc7effd3d1a5..77d9c9445bc71 100644
--- a/sklearn/feature_selection/tests/test_variance_threshold.py
+++ b/sklearn/feature_selection/tests/test_variance_threshold.py
@@ -46,3 +46,15 @@ def test_zero_variance_floating_point_error():
     msg = "No feature in X meets the variance threshold 0.00000"
     with pytest.raises(ValueError, match=msg):
         VarianceThreshold().fit(X)
+
+
+def test_variance_nan():
+    arr = np.array(data, dtype=np.float64)
+    # add single NaN and feature should still be included
+    arr[0, 0] = np.NaN
+    # make all values in feature NaN and feature should be rejected
+    arr[:, 1] = np.NaN
+
+    for X in [arr, csr_matrix(arr), csc_matrix(arr), bsr_matrix(arr)]:
+        sel = VarianceThreshold().fit(X)
+        assert_array_equal([0, 3, 4], sel.get_support(indices=True))
diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py
index 30c668237b371..f3e3b997bf678 100644
--- a/sklearn/utils/estimator_checks.py
+++ b/sklearn/utils/estimator_checks.py
@@ -1475,7 +1475,18 @@ def check_estimators_pickle(name, estimator_orig):
     y = _enforce_estimator_tags_y(estimator, y)
 
     set_random_state(estimator)
-    estimator.fit(X, y)
+
+    try:
+        estimator.fit(X, y)
+    except ValueError as e:
+        if 'inf' not in repr(e) and 'NaN' not in repr(e):
+            raise e
+        else:
+            # Some feature selection estimators don't allow nan/inf with
+            # their default parameters, even though they are allowed in
+            # general. Remove the nan in these cases.
+            X = np.nan_to_num(X)
+            estimator.fit(X, y)
 
     result = dict()
     for method in check_methods:
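
A minimal usage sketch (not part of the patch) of the behaviour the _variance_threshold.py changes above enable; it mirrors the new test_variance_nan test, so the expected support mask follows directly from the diff:

    import numpy as np
    from sklearn.feature_selection import VarianceThreshold

    X = np.array([[1.0,    np.nan, 7.0],
                  [2.0,    np.nan, 7.0],
                  [np.nan, np.nan, 7.0]])

    # Feature 0 still varies once NaNs are ignored (np.nanvar), so it is kept;
    # feature 1 is all-NaN and feature 2 is constant, so both are dropped.
    selector = VarianceThreshold().fit(X)
    print(selector.get_support())  # expected: [ True False False]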