diff --git a/doc/whats_new/v0.23.rst b/doc/whats_new/v0.23.rst index 750336a81e109..32716cb9eb694 100644 --- a/doc/whats_new/v0.23.rst +++ b/doc/whats_new/v0.23.rst @@ -124,6 +124,11 @@ Changelog for datasets with large vocabularies combined with ``min_df`` or ``max_df``. :pr:`15834` by :user:`Santiago M. Mola <smola>`. + +- |Enhancement| Added support for multioutput data in + :class:`feature_selection.RFE` and :class:`feature_selection.RFECV`. + :pr:`16103` by :user:`Divyaprabha M <divyaprabha123>`. + :mod:`sklearn.gaussian_process` ............................... diff --git a/sklearn/feature_selection/_rfe.py b/sklearn/feature_selection/_rfe.py index 12e99175c9d61..91312c7dc80f9 100644 --- a/sklearn/feature_selection/_rfe.py +++ b/sklearn/feature_selection/_rfe.py @@ -156,7 +156,8 @@ def _fit(self, X, y, step_score=None): tags = self._get_tags() X, y = check_X_y(X, y, "csc", ensure_min_features=2, - force_all_finite=not tags.get('allow_nan', True)) + force_all_finite=not tags.get('allow_nan', True), + multi_output=True) # Initialization n_features = X.shape[1] if self.n_features_to_select is None: @@ -489,8 +490,10 @@ def fit(self, X, y, groups=None): train/test set. Only used in conjunction with a "Group" :term:`cv` instance (e.g., :class:`~sklearn.model_selection.GroupKFold`). 
""" - X, y = check_X_y(X, y, "csr", ensure_min_features=2, - force_all_finite=False) + tags = self._get_tags() + X, y = check_X_y(X, y, "csc", ensure_min_features=2, + force_all_finite=not tags.get('allow_nan', True), + multi_output=True) # Initialization cv = check_cv(self.cv, y, is_classifier(self.estimator)) diff --git a/sklearn/feature_selection/tests/test_rfe.py b/sklearn/feature_selection/tests/test_rfe.py index ccd3c0a1b0e83..654675e677a11 100644 --- a/sklearn/feature_selection/tests/test_rfe.py +++ b/sklearn/feature_selection/tests/test_rfe.py @@ -395,3 +395,15 @@ def test_rfe_allow_nan_inf_in_x(cv): rfe = RFE(estimator=clf) rfe.fit(X, y) rfe.transform(X) + + +@pytest.mark.parametrize('ClsRFE', [ + RFE, + RFECV + ]) +def test_multioutput(ClsRFE): + X = np.random.normal(size=(10, 3)) + y = np.random.randint(2, size=(10, 2)) + clf = RandomForestClassifier(n_estimators=5) + rfe_test = ClsRFE(clf) + rfe_test.fit(X, y)