[MRG+1] Add classes_ parameter to hyperparameter CV classes #8295
In ``BaseSearchCV`` (superclass of ``GridSearchCV`` and ``RandomizedSearchCV``), add a ``classes_`` property which surfaces the ``classes_`` attribute of the ``best_estimator_``. Other parts of the scikit-learn code (e.g. ``cross_val_predict``) as well as users expect this property to be present on fitted classifiers.
I don't know why Travis claims to have failed. It looks like all of the tests succeeded.
Travis is broken atm: travis-ci/travis-ci#7264
Whoops that we failed to port across #7661 to
Thanks!
@@ -914,7 +915,7 @@ def test_cross_val_predict_sparse_prediction():
     assert_array_almost_equal(preds_sparse, preds)


-def test_cross_val_predict_with_method():
+def run_cross_val_predict_with_method(est):
We often name test helpers check_*.
grid_search = GridSearchCV(clf, {'foo_param': [1, 2, 3]})

grid_search.fit(X, y)
assert_array_equal(grid_search.classes_, np.unique(y))
Please add a test to check that hasattr(gscv, 'classes_') is false given a regressor?
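The requested check can be illustrated with a minimal, library-free sketch. The `ToyClassifier`, `ToyRegressor` and `ToySearch` classes below are hypothetical stand-ins written for this example, not scikit-learn code; they only assume that the search wrapper exposes `classes_` as a property that reads from `best_estimator_`.

```python
class ToyClassifier:
    """Hypothetical classifier: records the sorted class labels on fit."""

    def fit(self, X, y):
        self.classes_ = sorted(set(y))
        return self


class ToyRegressor:
    """Hypothetical regressor: has no notion of classes."""

    def fit(self, X, y):
        self.mean_ = sum(y) / len(y)
        return self


class ToySearch:
    """Hypothetical stand-in for a search wrapper (not BaseSearchCV)."""

    def __init__(self, estimator):
        self.estimator = estimator

    def fit(self, X, y):
        # A real search would try several parameter settings; this toy
        # simply refits the single estimator it was given.
        self.best_estimator_ = self.estimator.fit(X, y)
        return self

    @property
    def classes_(self):
        # If the inner estimator has no classes_, the AttributeError
        # propagates, so hasattr(search, "classes_") evaluates to False.
        return self.best_estimator_.classes_


X = [[0], [1], [2], [3]]
clf_search = ToySearch(ToyClassifier()).fit(X, [0, 1, 0, 1])
reg_search = ToySearch(ToyRegressor()).fit(X, [0.5, 1.5, 2.5, 3.5])

print(hasattr(clf_search, "classes_"))
print(hasattr(reg_search, "classes_"))
```

Because `hasattr` treats an `AttributeError` raised inside a property as "attribute missing", a wrapper built this way reports `classes_` only when the wrapped estimator is a classifier.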
LGTM, thanks. And thanks for catching our omission.
Codecov Report
@@ Coverage Diff @@
## master #8295 +/- ##
==========================================
+ Coverage 94.74% 94.74% +<.01%
==========================================
Files 342 342
Lines 60711 60735 +24
==========================================
+ Hits 57519 57543 +24
Misses 3192 3192
Continue to review full report at Codecov.
@jnothman, I merged master, since GitHub reported a conflict in
Usually I'd wait for another reviewer, just to be sure I didn't miss anything. Could you add a note in what's new to say this was a bug fix because we failed to implement the change properly for 0.18?
Okay. The change in #7661 was for v0.19; it hasn't been in a release yet. The text in the existing What's New entry describes what this PR does, so I added the PR number rather than making a new entry.
doc/whats_new.rst
@@ -68,7 +68,8 @@ Enhancements

- Added ``classes_`` attribute to :class:`model_selection.GridSearchCV`
that matches the ``classes_`` attribute of ``best_estimator_``. :issue:`7661`
by :user:`Alyssa Batula <abatula>` and :user:`Dylan Werner-Meier <unautre>`.
and :issue:`8295` by :user:`Alyssa Batula <abatula>`,
Maybe add the classes grid_search.GridSearchCV, grid_search.RandomizedSearchCV and model_selection.RandomizedSearchCV to make this entry more exact.
Done.
LGTM, thanks a lot, merging!
Thank you!
Reference Issue
Closes #8290.
What does this implement/fix? Explain your changes.
In ``BaseSearchCV`` (superclass of ``GridSearchCV`` and ``RandomizedSearchCV``), add a ``classes_`` property which surfaces the ``classes_`` attribute of the ``best_estimator_``. Other parts of the scikit-learn code (e.g. ``sklearn.model_selection.cross_val_predict``) as well as users expect this property to be present on fitted classifiers.
Any other comments?
When making this PR, I realized that there had already been a nearly identical PR, #7661. That PR added the ``classes_`` property in the deprecated ``sklearn.grid_search`` module. This PR adds it to the ``sklearn.model_selection`` module. The only difference is that I've added a call, with the intent of raising an error which is clearer than an ``AttributeError`` referencing ``best_estimator_``.
The What's New file claims that this change has already happened, so there doesn't seem to be anything to change there.
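The intent described above, an error that names `classes_` rather than an `AttributeError` about `best_estimator_`, can be sketched roughly as follows. This is an illustration under stated assumptions, not the code from this PR: `ToySearch`, `ToyClassifier`, `ToyNotFittedError` and the `_check_fitted` helper are all hypothetical names invented for the example.

```python
class ToyNotFittedError(ValueError, AttributeError):
    """Hypothetical error type raised when fit has not been called yet."""


class ToyClassifier:
    """Hypothetical classifier: records the sorted class labels on fit."""

    def fit(self, X, y):
        self.classes_ = sorted(set(y))
        return self


class ToySearch:
    """Hypothetical stand-in for a search wrapper (not BaseSearchCV)."""

    def __init__(self, estimator):
        self.estimator = estimator

    def fit(self, X, y):
        self.best_estimator_ = self.estimator.fit(X, y)
        return self

    def _check_fitted(self, name):
        # Hypothetical helper: fail early with a message about the
        # attribute the caller asked for, not about an internal one.
        if not hasattr(self, "best_estimator_"):
            raise ToyNotFittedError(
                f"This ToySearch instance is not fitted yet; "
                f"call fit before accessing {name!r}."
            )

    @property
    def classes_(self):
        self._check_fitted("classes_")
        return self.best_estimator_.classes_


search = ToySearch(ToyClassifier())
try:
    search.classes_
except ToyNotFittedError as exc:
    message = str(exc)
    print(message)

search.fit([[0], [1], [2]], ["a", "b", "a"])
print(search.classes_)
```

Without the explicit check, the unfitted case would surface as a generic complaint that the object has no `best_estimator_`, which says nothing about the attribute the caller actually requested.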