diff --git a/doc/whats_new/v1.5.rst b/doc/whats_new/v1.5.rst index 1a8c50e408a0b..01f0384af5c1d 100644 --- a/doc/whats_new/v1.5.rst +++ b/doc/whats_new/v1.5.rst @@ -347,6 +347,8 @@ Changelog - |Enhancement| :term:`CV splitters ` that ignores the group parameter now raises a warning when groups are passed in to :term:`split`. :pr:`28210` by + `Thomas Fan`_. + - |Fix| the ``cv_results_`` attribute (of :class:`model_selection.GridSearchCV`) now returns masked arrays of the appropriate NumPy dtype, as opposed to always returning dtype ``object``. :pr:`28352` by :user:`Marco Gorelli`. @@ -354,12 +356,19 @@ Changelog - |Fix| :func:`sklearn.model_selection.train_test_score` works with Array API inputs. Previously indexing was not handled correctly leading to exceptions when using strict implementations of the Array API like CuPY. - :pr:`28407` by `Tim Head `. + :pr:`28407` by :user:`Tim Head `. + +- |Enhancement| The HTML diagram representation of + :class:`~model_selection.GridSearchCV`, + :class:`~model_selection.RandomizedSearchCV`, + :class:`~model_selection.HalvingGridSearchCV`, and + :class:`~model_selection.HalvingRandomSearchCV` will show the best estimator when + `refit=True`. :pr:`28722` by :user:`Yao Xiao ` and `Thomas Fan`_. :mod:`sklearn.multioutput` .......................... -- |Enhancement| `chain_method` parameter added to `:class:`multioutput.ClassifierChain`. +- |Enhancement| `chain_method` parameter added to :class:`multioutput.ClassifierChain`. :pr:`27700` by :user:`Lucy Liu `. :mod:`sklearn.neighbors` diff --git a/sklearn/model_selection/_search.py b/sklearn/model_selection/_search.py index 9b9072f1491a2..42fde09c16bce 100644 --- a/sklearn/model_selection/_search.py +++ b/sklearn/model_selection/_search.py @@ -33,6 +33,7 @@ get_scorer_names, ) from ..utils import Bunch, check_random_state +from ..utils._estimator_html_repr import _VisualBlock from ..utils._param_validation import HasMethods, Interval, StrOptions from ..utils._tags import _safe_tags from ..utils.metadata_routing import ( @@ -1153,6 +1154,19 @@ def get_metadata_routing(self): ) return router + def _sk_visual_block_(self): + if hasattr(self, "best_estimator_"): + key, estimator = "best_estimator_", self.best_estimator_ + else: + key, estimator = "estimator", self.estimator + + return _VisualBlock( + "parallel", + [estimator], + names=[f"{key}: {estimator.__class__.__name__}"], + name_details=[str(estimator)], + ) + class GridSearchCV(BaseSearchCV): """Exhaustive search over specified parameter values for an estimator. diff --git a/sklearn/model_selection/tests/test_search.py b/sklearn/model_selection/tests/test_search.py index 1ff4520034ff0..1a9230259d22e 100644 --- a/sklearn/model_selection/tests/test_search.py +++ b/sklearn/model_selection/tests/test_search.py @@ -13,6 +13,7 @@ import pytest from scipy.stats import bernoulli, expon, uniform +from sklearn import config_context from sklearn.base import BaseEstimator, ClassifierMixin, is_classifier from sklearn.cluster import KMeans from sklearn.datasets import ( @@ -20,6 +21,7 @@ make_classification, make_multilabel_classification, ) +from sklearn.dummy import DummyClassifier from sklearn.ensemble import HistGradientBoostingClassifier from sklearn.exceptions import FitFailedWarning from sklearn.experimental import enable_halving_search_cv # noqa @@ -27,6 +29,7 @@ from sklearn.impute import SimpleImputer from sklearn.linear_model import ( LinearRegression, + LogisticRegression, Ridge, SGDClassifier, ) @@ -60,6 +63,7 @@ from sklearn.naive_bayes import ComplementNB from sklearn.neighbors import KernelDensity, KNeighborsClassifier, LocalOutlierFactor from sklearn.pipeline import Pipeline +from sklearn.preprocessing import StandardScaler from sklearn.svm import SVC, LinearSVC from sklearn.tests.metadata_routing_common import ( ConsumingScorer, @@ -2523,6 +2527,34 @@ def test_search_with_2d_array(): np.testing.assert_array_equal(result.data, expected_data) +def test_search_html_repr(): + """Test different HTML representations for GridSearchCV.""" + X, y = make_classification(random_state=42) + + pipeline = Pipeline([("scale", StandardScaler()), ("clf", DummyClassifier())]) + param_grid = {"clf": [DummyClassifier(), LogisticRegression()]} + + # Unfitted shows the original pipeline + search_cv = GridSearchCV(pipeline, param_grid=param_grid, refit=False) + with config_context(display="diagram"): + repr_html = search_cv._repr_html_() + assert "
DummyClassifier()
" in repr_html + + # Fitted with `refit=False` shows the original pipeline + search_cv.fit(X, y) + with config_context(display="diagram"): + repr_html = search_cv._repr_html_() + assert "
DummyClassifier()
" in repr_html + + # Fitted with `refit=True` shows the best estimator + search_cv = GridSearchCV(pipeline, param_grid=param_grid, refit=True) + search_cv.fit(X, y) + with config_context(display="diagram"): + repr_html = search_cv._repr_html_() + assert "
DummyClassifier()
" not in repr_html + assert "
LogisticRegression()
" in repr_html + + # Metadata Routing Tests # ======================