ENH HTML repr show best estimator in *SearchCV when refit=True #28722

Merged 8 commits on Apr 11, 2024
13 changes: 11 additions & 2 deletions doc/whats_new/v1.5.rst
@@ -347,19 +347,28 @@ Changelog

 - |Enhancement| :term:`CV splitters <CV splitter>` that ignore the group parameter now
   raise a warning when groups are passed in to :term:`split`. :pr:`28210` by
   `Thomas Fan`_.

 - |Fix| the ``cv_results_`` attribute (of :class:`model_selection.GridSearchCV`) now
   returns masked arrays of the appropriate NumPy dtype, as opposed to always returning
   dtype ``object``. :pr:`28352` by :user:`Marco Gorelli<MarcoGorelli>`.

 - |Fix| :func:`sklearn.model_selection.train_test_split` works with Array API inputs.
   Previously indexing was not handled correctly leading to exceptions when using strict
   implementations of the Array API like CuPy.
-  :pr:`28407` by `Tim Head <betatim>`.
+  :pr:`28407` by :user:`Tim Head <betatim>`.
+
+- |Enhancement| The HTML diagram representation of
+  :class:`~model_selection.GridSearchCV`,
+  :class:`~model_selection.RandomizedSearchCV`,
+  :class:`~model_selection.HalvingGridSearchCV`, and
+  :class:`~model_selection.HalvingRandomSearchCV` will show the best estimator when
+  `refit=True`. :pr:`28722` by :user:`Yao Xiao <Charlie-XIAO>` and `Thomas Fan`_.
+
 :mod:`sklearn.multioutput`
 ..........................

-- |Enhancement| `chain_method` parameter added to `:class:`multioutput.ClassifierChain`.
+- |Enhancement| `chain_method` parameter added to :class:`multioutput.ClassifierChain`.
   :pr:`27700` by :user:`Lucy Liu <lucyleeow>`.

 :mod:`sklearn.neighbors`
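The behaviour described in the new changelog entry can be tried out with a short script. This is a minimal sketch, assuming scikit-learn 1.5 or later is installed. The data, pipeline, and parameter grid mirror the test added in this PR, while the variable names are illustrative. It calls the public `sklearn.utils.estimator_html_repr` helper instead of the private `_repr_html_` method.

```python
from sklearn.datasets import make_classification
from sklearn.dummy import DummyClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.utils import estimator_html_repr

X, y = make_classification(random_state=42)
pipeline = Pipeline([("scale", StandardScaler()), ("clf", DummyClassifier())])
param_grid = {"clf": [DummyClassifier(), LogisticRegression()]}

# refit=False: no best_estimator_ is stored after fit, so the diagram
# can only describe the pipeline that was passed in.
search_no_refit = GridSearchCV(pipeline, param_grid=param_grid, refit=False)
search_no_refit.fit(X, y)
html_no_refit = estimator_html_repr(search_no_refit)

# refit=True: best_estimator_ exists after fit. With this PR applied
# (assumed: scikit-learn >= 1.5) the diagram is built from it.
search_refit = GridSearchCV(pipeline, param_grid=param_grid, refit=True)
search_refit.fit(X, y)
html_refit = estimator_html_repr(search_refit)

print(hasattr(search_no_refit, "best_estimator_"))  # False
print(type(search_refit.best_estimator_.named_steps["clf"]).__name__)
```

In a notebook, displaying `search_refit` directly renders the same diagram. The HTML strings are kept in variables here so they can be inspected or written to a file.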
14 changes: 14 additions & 0 deletions sklearn/model_selection/_search.py
@@ -33,6 +33,7 @@
     get_scorer_names,
 )
 from ..utils import Bunch, check_random_state
+from ..utils._estimator_html_repr import _VisualBlock
 from ..utils._param_validation import HasMethods, Interval, StrOptions
 from ..utils._tags import _safe_tags
 from ..utils.metadata_routing import (
@@ -1153,6 +1154,19 @@ def get_metadata_routing(self):
         )
         return router

+    def _sk_visual_block_(self):
+        if hasattr(self, "best_estimator_"):
+            key, estimator = "best_estimator_", self.best_estimator_
+        else:
+            key, estimator = "estimator", self.estimator
+
+        return _VisualBlock(
+            "parallel",
+            [estimator],
+            names=[f"{key}: {estimator.__class__.__name__}"],
+            name_details=[str(estimator)],
+        )
+

 class GridSearchCV(BaseSearchCV):
     """Exhaustive search over specified parameter values for an estimator.
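The fallback used by `_sk_visual_block_` can be illustrated without scikit-learn. The sketch below is a toy, not library code: `ToySearch`, `Dummy`, `Logit`, and `display_label` are hypothetical names, and "fitting" simply pretends that the last candidate won the search. It only shows how the presence of `best_estimator_` decides which object and label are displayed.

```python
class Dummy:
    """Hypothetical placeholder estimator."""


class Logit:
    """Hypothetical placeholder estimator."""


class ToySearch:
    """Toy stand-in for a *SearchCV object (assumption, not scikit-learn)."""

    def __init__(self, estimator, refit=True):
        self.estimator = estimator
        self.refit = refit

    def fit(self, candidates):
        # Pretend the last candidate is the best one. Only a refitting
        # search stores it, which is what the hasattr check relies on.
        if self.refit:
            self.best_estimator_ = candidates[-1]
        return self

    def display_label(self):
        # Same branch as in the diff: prefer the refitted best estimator,
        # otherwise fall back to the estimator passed at construction.
        if hasattr(self, "best_estimator_"):
            key, estimator = "best_estimator_", self.best_estimator_
        else:
            key, estimator = "estimator", self.estimator
        return f"{key}: {estimator.__class__.__name__}"


unfitted = ToySearch(Dummy())
no_refit = ToySearch(Dummy(), refit=False).fit([Dummy(), Logit()])
refit = ToySearch(Dummy(), refit=True).fit([Dummy(), Logit()])

print(unfitted.display_label())  # estimator: Dummy
print(no_refit.display_label())  # estimator: Dummy
print(refit.display_label())     # best_estimator_: Logit
```

Checking for the attribute, instead of reading the `refit` flag, means an unfitted search and a search fitted with `refit=False` take the same path.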
32 changes: 32 additions & 0 deletions sklearn/model_selection/tests/test_search.py
@@ -13,20 +13,23 @@
 import pytest
 from scipy.stats import bernoulli, expon, uniform

+from sklearn import config_context
 from sklearn.base import BaseEstimator, ClassifierMixin, is_classifier
 from sklearn.cluster import KMeans
 from sklearn.datasets import (
     make_blobs,
     make_classification,
     make_multilabel_classification,
 )
+from sklearn.dummy import DummyClassifier
 from sklearn.ensemble import HistGradientBoostingClassifier
 from sklearn.exceptions import FitFailedWarning
 from sklearn.experimental import enable_halving_search_cv  # noqa
 from sklearn.feature_extraction.text import TfidfVectorizer
 from sklearn.impute import SimpleImputer
 from sklearn.linear_model import (
     LinearRegression,
+    LogisticRegression,
     Ridge,
     SGDClassifier,
 )
@@ -60,6 +63,7 @@
 from sklearn.naive_bayes import ComplementNB
 from sklearn.neighbors import KernelDensity, KNeighborsClassifier, LocalOutlierFactor
 from sklearn.pipeline import Pipeline
+from sklearn.preprocessing import StandardScaler
 from sklearn.svm import SVC, LinearSVC
 from sklearn.tests.metadata_routing_common import (
     ConsumingScorer,
@@ -2523,6 +2527,34 @@ def test_search_with_2d_array():
     np.testing.assert_array_equal(result.data, expected_data)


+def test_search_html_repr():
+    """Test different HTML representations for GridSearchCV."""
+    X, y = make_classification(random_state=42)
+
+    pipeline = Pipeline([("scale", StandardScaler()), ("clf", DummyClassifier())])
+    param_grid = {"clf": [DummyClassifier(), LogisticRegression()]}
+
+    # Unfitted shows the original pipeline
+    search_cv = GridSearchCV(pipeline, param_grid=param_grid, refit=False)
+    with config_context(display="diagram"):
+        repr_html = search_cv._repr_html_()
+        assert "<pre>DummyClassifier()</pre>" in repr_html
+
+    # Fitted with `refit=False` shows the original pipeline
+    search_cv.fit(X, y)
+    with config_context(display="diagram"):
+        repr_html = search_cv._repr_html_()
+        assert "<pre>DummyClassifier()</pre>" in repr_html
+
+    # Fitted with `refit=True` shows the best estimator
+    search_cv = GridSearchCV(pipeline, param_grid=param_grid, refit=True)
+    search_cv.fit(X, y)
+    with config_context(display="diagram"):
+        repr_html = search_cv._repr_html_()
+        assert "<pre>DummyClassifier()</pre>" not in repr_html
+        assert "<pre>LogisticRegression()</pre>" in repr_html
+
+
 # Metadata Routing Tests
 # ======================
