diff --git a/doc/whats_new/v1.0.rst b/doc/whats_new/v1.0.rst
index 96bb2ddfa8f7d..d0174cc8857d6 100644
--- a/doc/whats_new/v1.0.rst
+++ b/doc/whats_new/v1.0.rst
@@ -263,6 +263,11 @@ Changelog
 :pr:`18649` by `Leandro Hermida ` and
 `Rodion Martynov `.
 
+- |Fix| The `fit` method of the successive halving parameter search
+  (:class:`model_selection.HalvingGridSearchCV`, and
+  :class:`model_selection.HalvingRandomSearchCV`) now correctly handles the
+  `groups` parameter. :pr:`19847` by :user:`Xiaoyu Chai `.
+
 :mod:`sklearn.naive_bayes`
 ..........................
 
diff --git a/sklearn/model_selection/_search_successive_halving.py b/sklearn/model_selection/_search_successive_halving.py
index b522ce7fbda41..d27cd3b1823ca 100644
--- a/sklearn/model_selection/_search_successive_halving.py
+++ b/sklearn/model_selection/_search_successive_halving.py
@@ -210,7 +210,7 @@ def fit(self, X, y=None, groups=None, **fit_params):
 
         self._n_samples_orig = _num_samples(X)
 
-        super().fit(X, y=y, groups=None, **fit_params)
+        super().fit(X, y=y, groups=groups, **fit_params)
 
         # Set best_score_: BaseSearchCV does not set it, as refit is a callable
         self.best_score_ = (
diff --git a/sklearn/model_selection/tests/test_successive_halving.py b/sklearn/model_selection/tests/test_successive_halving.py
index 2c55f6aa6cd85..6660b35a934ba 100644
--- a/sklearn/model_selection/tests/test_successive_halving.py
+++ b/sklearn/model_selection/tests/test_successive_halving.py
@@ -7,9 +7,16 @@
 from sklearn.datasets import make_classification
 from sklearn.dummy import DummyClassifier
 from sklearn.experimental import enable_halving_search_cv  # noqa
+from sklearn.model_selection import StratifiedKFold
+from sklearn.model_selection import StratifiedShuffleSplit
+from sklearn.model_selection import LeaveOneGroupOut
+from sklearn.model_selection import LeavePGroupsOut
+from sklearn.model_selection import GroupKFold
+from sklearn.model_selection import GroupShuffleSplit
 from sklearn.model_selection import HalvingGridSearchCV
 from sklearn.model_selection import HalvingRandomSearchCV
 from sklearn.model_selection import KFold, ShuffleSplit
+from sklearn.svm import LinearSVC
 from sklearn.model_selection._search_successive_halving import (
     _SubsampleMetaSplitter, _top_k, _refit_callable)
 
@@ -562,3 +569,32 @@ def set_params(self, **params):
 
     assert (cv_results_df['params'] == passed_params).all()
     assert (cv_results_df['n_resources'] == passed_n_samples).all()
+
+
+@pytest.mark.parametrize('Est', (HalvingGridSearchCV, HalvingRandomSearchCV))
+def test_groups_support(Est):
+    # Check if ValueError (when groups is None) propagates to
+    # HalvingGridSearchCV and HalvingRandomSearchCV
+    # And also check if groups is correctly passed to the cv object
+    rng = np.random.RandomState(0)
+
+    X, y = make_classification(n_samples=50, n_classes=2, random_state=0)
+    groups = rng.randint(0, 3, 50)
+
+    clf = LinearSVC(random_state=0)
+    grid = {'C': [1]}
+
+    group_cvs = [LeaveOneGroupOut(), LeavePGroupsOut(2),
+                 GroupKFold(n_splits=3), GroupShuffleSplit(random_state=0)]
+    error_msg = "The 'groups' parameter should not be None."
+    for cv in group_cvs:
+        gs = Est(clf, grid, cv=cv)
+        with pytest.raises(ValueError, match=error_msg):
+            gs.fit(X, y)
+        gs.fit(X, y, groups=groups)
+
+    non_group_cvs = [StratifiedKFold(), StratifiedShuffleSplit(random_state=0)]
+    for cv in non_group_cvs:
+        gs = Est(clf, grid, cv=cv)
+        # Should not raise an error
+        gs.fit(X, y)