Thanks to visit codestin.com
Credit goes to github.com

Skip to content

FIX Add NaN handling to selection of best parameters for HalvingGridSearchCV #24539

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Nov 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion doc/whats_new/v1.2.rst
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ Changelog
- |Enhancement| :func:`datasets.dump_svmlight_file` is now accelerated with a
Cython implementation, providing 2-4x speedups.
:pr:`23127` by :user:`Meekail Zain <micky774>`

- |Enhancement| Path-like objects, such as those created with pathlib are now
allowed as paths in :func:`datasets.load_svmlight_file` and
:func:`datasets.load_svmlight_files`.
Expand Down Expand Up @@ -476,6 +476,11 @@ Changelog
nan score is correctly set to the maximum possible rank, rather than
`np.iinfo(np.int32).min`. :pr:`24141` by :user:`Loïc Estève <lesteve>`.

- |Fix| In both :class:`model_selection.HalvingGridSearchCV` and
:class:`model_selection.HalvingRandomSearchCV` parameter
combinations with a NaN score now share the lowest rank.
:pr:`24539` by :user:`Tim Head <betatim>`.

- |Fix| For :class:`model_selection.GridSearchCV` and
:class:`model_selection.RandomizedSearchCV` ranks corresponding to nan
scores will all be set to the maximum possible rank.
Expand Down
20 changes: 18 additions & 2 deletions sklearn/model_selection/_search_successive_halving.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,11 @@ def _top_k(results, k, itr):
for a in (results["iter"], results["mean_test_score"], results["params"])
)
iter_indices = np.flatnonzero(iteration == itr)
sorted_indices = np.argsort(mean_test_score[iter_indices])
scores = mean_test_score[iter_indices]
# argsort() places NaNs at the end of the array so we move NaNs to the
# front of the array so the last `k` items are the those with the
# highest scores.
sorted_indices = np.roll(np.argsort(scores), np.count_nonzero(np.isnan(scores)))
return np.array(params[iter_indices][sorted_indices[-k:]])


Expand Down Expand Up @@ -216,7 +220,15 @@ def _select_best_index(refit, refit_metric, results):
"""
last_iter = np.max(results["iter"])
last_iter_indices = np.flatnonzero(results["iter"] == last_iter)
best_idx = np.argmax(results["mean_test_score"][last_iter_indices])

test_scores = results["mean_test_score"][last_iter_indices]
# If all scores are NaNs there is no way to pick between them,
# so we (arbitrarily) declare the zero'th entry the best one
if np.isnan(test_scores).all():
best_idx = 0
else:
best_idx = np.nanargmax(test_scores)

return last_iter_indices[best_idx]

def fit(self, X, y=None, groups=None, **fit_params):
Expand Down Expand Up @@ -655,6 +667,8 @@ class HalvingGridSearchCV(BaseSuccessiveHalving):
The parameters selected are those that maximize the score of the held-out
data, according to the scoring parameter.

All parameter combinations scored with a NaN will share the lowest rank.

Examples
--------

Expand Down Expand Up @@ -992,6 +1006,8 @@ class HalvingRandomSearchCV(BaseSuccessiveHalving):
The parameters selected are those that maximize the score of the held-out
data, according to the scoring parameter.

All parameter combinations scored with a NaN will share the lowest rank.

Examples
--------

Expand Down
69 changes: 69 additions & 0 deletions sklearn/model_selection/tests/test_successive_halving.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,75 @@ def get_params(self, deep=False):
return params


class SometimesFailClassifier(DummyClassifier):
def __init__(
self,
strategy="stratified",
random_state=None,
constant=None,
n_estimators=10,
fail_fit=False,
fail_predict=False,
a=0,
):
self.fail_fit = fail_fit
self.fail_predict = fail_predict
self.n_estimators = n_estimators
self.a = a

super().__init__(
strategy=strategy, random_state=random_state, constant=constant
)

def fit(self, X, y):
if self.fail_fit:
raise Exception("fitting failed")
return super().fit(X, y)

def predict(self, X):
if self.fail_predict:
raise Exception("predict failed")
return super().predict(X)


@pytest.mark.filterwarnings("ignore::sklearn.exceptions.FitFailedWarning")
@pytest.mark.filterwarnings("ignore:Scoring failed:UserWarning")
@pytest.mark.filterwarnings("ignore:One or more of the:UserWarning")
@pytest.mark.parametrize("HalvingSearch", (HalvingGridSearchCV, HalvingRandomSearchCV))
@pytest.mark.parametrize("fail_at", ("fit", "predict"))
def test_nan_handling(HalvingSearch, fail_at):
"""Check the selection of the best scores in presence of failure represented by
NaN values."""
n_samples = 1_000
X, y = make_classification(n_samples=n_samples, random_state=0)

search = HalvingSearch(
SometimesFailClassifier(),
{f"fail_{fail_at}": [False, True], "a": range(3)},
resource="n_estimators",
max_resources=6,
min_resources=1,
factor=2,
)

search.fit(X, y)

# estimators that failed during fit/predict should always rank lower
# than ones where the fit/predict succeeded
assert not search.best_params_[f"fail_{fail_at}"]
scores = search.cv_results_["mean_test_score"]
ranks = search.cv_results_["rank_test_score"]

# some scores should be NaN
assert np.isnan(scores).any()

unique_nan_ranks = np.unique(ranks[np.isnan(scores)])
# all NaN scores should have the same rank
assert unique_nan_ranks.shape[0] == 1
# NaNs should have the lowest rank
assert (unique_nan_ranks[0] >= ranks).all()


@pytest.mark.parametrize("Est", (HalvingGridSearchCV, HalvingRandomSearchCV))
@pytest.mark.parametrize(
"aggressive_elimination,"
Expand Down