diff --git a/doc/whats_new.rst b/doc/whats_new.rst index a4b775ec66d0a..6998ac63cf791 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -19,6 +19,11 @@ New features Enhancements ............ + - Added ``classes_`` attribute to :class:`model_selection.GridSearchCV` + that matches the ``classes_`` attribute of ``best_estimator_``. (`#7661 + `_) by `Alyssa + Batula`_ and `Dylan Werner-Meier`_. + - The ``min_weight_fraction_leaf`` constraint in tree construction is now more efficient, taking a fast path to declare a node a leaf if its weight is less than 2 * the minimum. Note that the constructed tree will be diff --git a/sklearn/grid_search.py b/sklearn/grid_search.py index 0de08ee9e89f0..f49d7e0485fa5 100644 --- a/sklearn/grid_search.py +++ b/sklearn/grid_search.py @@ -387,6 +387,10 @@ def __init__(self, estimator, scoring=None, def _estimator_type(self): return self.estimator._estimator_type + @property + def classes_(self): + return self.best_estimator_.classes_ + def score(self, X, y=None): """Returns the score on the given data, if the estimator has been refit. @@ -688,7 +692,7 @@ class GridSearchCV(BaseSearchCV): - An iterable yielding train/test splits. For integer/None inputs, if the estimator is a classifier and ``y`` is - either binary or multiclass, + either binary or multiclass, :class:`sklearn.model_selection.StratifiedKFold` is used. In all other cases, :class:`sklearn.model_selection.KFold` is used. @@ -900,7 +904,7 @@ class RandomizedSearchCV(BaseSearchCV): - An iterable yielding train/test splits. For integer/None inputs, if the estimator is a classifier and ``y`` is - either binary or multiclass, + either binary or multiclass, :class:`sklearn.model_selection.StratifiedKFold` is used. In all other cases, :class:`sklearn.model_selection.KFold` is used. diff --git a/sklearn/tests/test_grid_search.py b/sklearn/tests/test_grid_search.py index 13b9086310595..e6c2e18538163 100644 --- a/sklearn/tests/test_grid_search.py +++ b/sklearn/tests/test_grid_search.py @@ -42,6 +42,7 @@ from sklearn.metrics import f1_score from sklearn.metrics import make_scorer from sklearn.metrics import roc_auc_score +from sklearn.linear_model import Ridge from sklearn.exceptions import ChangedBehaviorWarning from sklearn.exceptions import FitFailedWarning @@ -785,3 +786,20 @@ def test_parameters_sampler_replacement(): sampler = ParameterSampler(params_distribution, n_iter=7) samples = list(sampler) assert_equal(len(samples), 7) + + +def test_classes__property(): + # Test that classes_ property matches best_esimator_.classes_ + X = np.arange(100).reshape(10, 10) + y = np.array([0] * 5 + [1] * 5) + Cs = [.1, 1, 10] + + grid_search = GridSearchCV(LinearSVC(random_state=0), {'C': Cs}) + grid_search.fit(X, y) + assert_array_equal(grid_search.best_estimator_.classes_, + grid_search.classes_) + + # Test that regressors do not have a classes_ attribute + grid_search = GridSearchCV(Ridge(), {'alpha': [1.0, 2.0]}) + grid_search.fit(X, y) + assert_false(hasattr(grid_search, 'classes_'))