|
13 | 13 | import pytest
|
14 | 14 | from scipy.stats import bernoulli, expon, uniform
|
15 | 15 |
|
| 16 | +from sklearn import config_context |
16 | 17 | from sklearn.base import BaseEstimator, ClassifierMixin, is_classifier
|
17 | 18 | from sklearn.cluster import KMeans
|
18 | 19 | from sklearn.datasets import (
|
19 | 20 | make_blobs,
|
20 | 21 | make_classification,
|
21 | 22 | make_multilabel_classification,
|
22 | 23 | )
|
| 24 | +from sklearn.dummy import DummyClassifier |
23 | 25 | from sklearn.ensemble import HistGradientBoostingClassifier
|
24 | 26 | from sklearn.exceptions import FitFailedWarning
|
25 | 27 | from sklearn.experimental import enable_halving_search_cv # noqa
|
26 | 28 | from sklearn.feature_extraction.text import TfidfVectorizer
|
27 | 29 | from sklearn.impute import SimpleImputer
|
28 | 30 | from sklearn.linear_model import (
|
29 | 31 | LinearRegression,
|
| 32 | + LogisticRegression, |
30 | 33 | Ridge,
|
31 | 34 | SGDClassifier,
|
32 | 35 | )
|
|
60 | 63 | from sklearn.naive_bayes import ComplementNB
|
61 | 64 | from sklearn.neighbors import KernelDensity, KNeighborsClassifier, LocalOutlierFactor
|
62 | 65 | from sklearn.pipeline import Pipeline
|
| 66 | +from sklearn.preprocessing import StandardScaler |
63 | 67 | from sklearn.svm import SVC, LinearSVC
|
64 | 68 | from sklearn.tests.metadata_routing_common import (
|
65 | 69 | ConsumingScorer,
|
@@ -2523,6 +2527,34 @@ def test_search_with_2d_array():
|
2523 | 2527 | np.testing.assert_array_equal(result.data, expected_data)
|
2524 | 2528 |
|
2525 | 2529 |
|
| 2530 | +def test_search_html_repr(): |
| 2531 | + """Test different HTML representations for GridSearchCV.""" |
| 2532 | + X, y = make_classification(random_state=42) |
| 2533 | + |
| 2534 | + pipeline = Pipeline([("scale", StandardScaler()), ("clf", DummyClassifier())]) |
| 2535 | + param_grid = {"clf": [DummyClassifier(), LogisticRegression()]} |
| 2536 | + |
| 2537 | + # Unfitted shows the original pipeline |
| 2538 | + search_cv = GridSearchCV(pipeline, param_grid=param_grid, refit=False) |
| 2539 | + with config_context(display="diagram"): |
| 2540 | + repr_html = search_cv._repr_html_() |
| 2541 | + assert "<pre>DummyClassifier()</pre>" in repr_html |
| 2542 | + |
| 2543 | + # Fitted with `refit=False` shows the original pipeline |
| 2544 | + search_cv.fit(X, y) |
| 2545 | + with config_context(display="diagram"): |
| 2546 | + repr_html = search_cv._repr_html_() |
| 2547 | + assert "<pre>DummyClassifier()</pre>" in repr_html |
| 2548 | + |
| 2549 | + # Fitted with `refit=True` shows the best estimator |
| 2550 | + search_cv = GridSearchCV(pipeline, param_grid=param_grid, refit=True) |
| 2551 | + search_cv.fit(X, y) |
| 2552 | + with config_context(display="diagram"): |
| 2553 | + repr_html = search_cv._repr_html_() |
| 2554 | + assert "<pre>DummyClassifier()</pre>" not in repr_html |
| 2555 | + assert "<pre>LogisticRegression()</pre>" in repr_html |
| 2556 | + |
| 2557 | + |
2526 | 2558 | # Metadata Routing Tests
|
2527 | 2559 | # ======================
|
2528 | 2560 |
|
|
0 commit comments