Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit bf61674

Browse files
ENH HTML repr show best estimator in *SearchCV when refit=True (#28722)
Co-authored-by: Thomas J. Fan <[email protected]>
1 parent c799133 commit bf61674

File tree

3 files changed

+57
-2
lines changed

3 files changed

+57
-2
lines changed

doc/whats_new/v1.5.rst

+11-2
Original file line numberDiff line numberDiff line change
@@ -347,19 +347,28 @@ Changelog
347347

348348
- |Enhancement| :term:`CV splitters <CV splitter>` that ignore the group parameter now
349349
raise a warning when groups are passed in to :term:`split`. :pr:`28210` by
350+
`Thomas Fan`_.
351+
350352
- |Fix| the ``cv_results_`` attribute (of :class:`model_selection.GridSearchCV`) now
351353
returns masked arrays of the appropriate NumPy dtype, as opposed to always returning
352354
dtype ``object``. :pr:`28352` by :user:`Marco Gorelli<MarcoGorelli>`.
353355

354356
- |Fix| :func:`sklearn.model_selection.train_test_split` works with Array API inputs.
355357
Previously indexing was not handled correctly leading to exceptions when using strict
356358
implementations of the Array API like CuPy.
357-
:pr:`28407` by `Tim Head <betatim>`.
359+
:pr:`28407` by :user:`Tim Head <betatim>`.
360+
361+
- |Enhancement| The HTML diagram representation of
362+
:class:`~model_selection.GridSearchCV`,
363+
:class:`~model_selection.RandomizedSearchCV`,
364+
:class:`~model_selection.HalvingGridSearchCV`, and
365+
:class:`~model_selection.HalvingRandomSearchCV` will show the best estimator when
366+
`refit=True`. :pr:`28722` by :user:`Yao Xiao <Charlie-XIAO>` and `Thomas Fan`_.
358367

359368
:mod:`sklearn.multioutput`
360369
..........................
361370

362-
- |Enhancement| `chain_method` parameter added to `:class:`multioutput.ClassifierChain`.
371+
- |Enhancement| `chain_method` parameter added to :class:`multioutput.ClassifierChain`.
363372
:pr:`27700` by :user:`Lucy Liu <lucyleeow>`.
364373

365374
:mod:`sklearn.neighbors`

sklearn/model_selection/_search.py

+14
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
get_scorer_names,
3434
)
3535
from ..utils import Bunch, check_random_state
36+
from ..utils._estimator_html_repr import _VisualBlock
3637
from ..utils._param_validation import HasMethods, Interval, StrOptions
3738
from ..utils._tags import _safe_tags
3839
from ..utils.metadata_routing import (
@@ -1153,6 +1154,19 @@ def get_metadata_routing(self):
11531154
)
11541155
return router
11551156

1157+
def _sk_visual_block_(self):
    """Build the visual block used for the HTML diagram of this search.

    The block wraps a single estimator: ``best_estimator_`` when that
    attribute exists on the instance, otherwise the ``estimator`` attribute.
    The block name is ``"<attribute>: <estimator class name>"`` and the
    details are the estimator's ``str()`` form.
    """
    # Pick the attribute to display; the fitted best estimator wins if present.
    label = "best_estimator_" if hasattr(self, "best_estimator_") else "estimator"
    shown = getattr(self, label)

    return _VisualBlock(
        "parallel",
        [shown],
        names=[f"{label}: {shown.__class__.__name__}"],
        name_details=[str(shown)],
    )
1169+
11561170

11571171
class GridSearchCV(BaseSearchCV):
11581172
"""Exhaustive search over specified parameter values for an estimator.

sklearn/model_selection/tests/test_search.py

+32
Original file line numberDiff line numberDiff line change
@@ -13,20 +13,23 @@
1313
import pytest
1414
from scipy.stats import bernoulli, expon, uniform
1515

16+
from sklearn import config_context
1617
from sklearn.base import BaseEstimator, ClassifierMixin, is_classifier
1718
from sklearn.cluster import KMeans
1819
from sklearn.datasets import (
1920
make_blobs,
2021
make_classification,
2122
make_multilabel_classification,
2223
)
24+
from sklearn.dummy import DummyClassifier
2325
from sklearn.ensemble import HistGradientBoostingClassifier
2426
from sklearn.exceptions import FitFailedWarning
2527
from sklearn.experimental import enable_halving_search_cv # noqa
2628
from sklearn.feature_extraction.text import TfidfVectorizer
2729
from sklearn.impute import SimpleImputer
2830
from sklearn.linear_model import (
2931
LinearRegression,
32+
LogisticRegression,
3033
Ridge,
3134
SGDClassifier,
3235
)
@@ -60,6 +63,7 @@
6063
from sklearn.naive_bayes import ComplementNB
6164
from sklearn.neighbors import KernelDensity, KNeighborsClassifier, LocalOutlierFactor
6265
from sklearn.pipeline import Pipeline
66+
from sklearn.preprocessing import StandardScaler
6367
from sklearn.svm import SVC, LinearSVC
6468
from sklearn.tests.metadata_routing_common import (
6569
ConsumingScorer,
@@ -2523,6 +2527,34 @@ def test_search_with_2d_array():
25232527
np.testing.assert_array_equal(result.data, expected_data)
25242528

25252529

2530+
def test_search_html_repr():
    """Check which estimator the HTML diagram of GridSearchCV displays."""
    dummy_tag = "<pre>DummyClassifier()</pre>"
    logreg_tag = "<pre>LogisticRegression()</pre>"

    def render(search):
        # Render the HTML representation with the diagram display enabled.
        with config_context(display="diagram"):
            return search._repr_html_()

    X, y = make_classification(random_state=42)

    pipeline = Pipeline([("scale", StandardScaler()), ("clf", DummyClassifier())])
    param_grid = {"clf": [DummyClassifier(), LogisticRegression()]}

    # An unfitted search displays the original pipeline.
    search_cv = GridSearchCV(pipeline, param_grid=param_grid, refit=False)
    assert dummy_tag in render(search_cv)

    # Fitting with `refit=False` still displays the original pipeline.
    search_cv.fit(X, y)
    assert dummy_tag in render(search_cv)

    # Fitting with `refit=True` displays the best estimator instead.
    search_cv = GridSearchCV(pipeline, param_grid=param_grid, refit=True)
    search_cv.fit(X, y)
    html = render(search_cv)
    assert dummy_tag not in html
    assert logreg_tag in html
2556+
2557+
25262558
# Metadata Routing Tests
25272559
# ======================
25282560

0 commit comments

Comments
 (0)