[MRG] Standardize sample weights validation in DummyClassifier #15510
Conversation
Looks good pending CI evaluation.
```diff
@@ -141,7 +142,10 @@ def fit(self, X, y, sample_weight=None):

         self.n_outputs_ = y.shape[1]

-        check_consistent_length(X, y, sample_weight)
+        check_consistent_length(X, y)
```
I think we might still want it here, not strong opinion, though...
even though it's checked as part of the added validation?
+1 to avoid redundant checks (even if it doesn't cost much). `_check_sample_weight` should yield better error messages.
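For context, a minimal sketch of the helper's behavior, assuming scikit-learn >= 0.22 where `_check_sample_weight` lives in `sklearn.utils.validation` (illustration only, not part of this diff):

```python
import numpy as np
from sklearn.utils.validation import _check_sample_weight

X = np.zeros((3, 2))  # three samples

# None is expanded to unit weights, one per sample
print(_check_sample_weight(None, X))  # [1. 1. 1.]

# a length mismatch raises a single, informative ValueError
try:
    _check_sample_weight(np.ones(4), X)
except ValueError as exc:
    print(exc)  # e.g. "sample_weight.shape == (4,), expected (3,)!"
```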
Minor comment below, otherwise LGTM. Thanks @fbchow!
sklearn/tests/test_dummy.py (outdated):

```python
clf = DummyClassifier().fit(X, y, sample_weight)
assert_array_almost_equal(clf.class_prior_, [0.2 / 1.2, 1. / 1.2])

sample_weight = np.random.rand(3, 1)
```
For anything below this line, I'm not sure there is much sense in adding sample_weight tests for each estimator (unless they have some special behavior there), as they would be quite redundant; instead they should be enforced by the common tests in sklearn/utils/estimator_checks.py (and for a number of the points below they may already be). So I would remove the test below. The check specific to DummyClassifier above is quite nice though.
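For reference, the DummyClassifier-specific check reads roughly like this; X, y, and sample_weight are assumed values chosen only so the class-wise weight sums of 0.2 and 1.0 reproduce the asserted prior, not the actual test data:

```python
import numpy as np
from numpy.testing import assert_array_almost_equal
from sklearn.dummy import DummyClassifier

X = [[0], [1], [2]]              # assumed inputs
y = [0, 1, 1]
sample_weight = [0.2, 0.4, 0.6]  # class 0 sums to 0.2, class 1 to 1.0

clf = DummyClassifier().fit(X, y, sample_weight)
# class_prior_ holds the weighted class frequencies
assert_array_almost_equal(clf.class_prior_, [0.2 / 1.2, 1.0 / 1.2])
```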
Thanks for the review!
We removed the test below.
Something went wrong with the merge of master here. The diff should only show the changed code...
(cherry picked from commit 95929b9)
I'm Fanny's pair from the workshop. I'm not sure what happened with the merge from master. I have just branched off of master again and cherry-picked the 2 commits for ease of fixing everything. I can either push to a new branch and create a new PR or force push my current branch to overwrite the changes. Do the maintainers have a preference?
Co-authored-by: Sallie Walecka <[email protected]> (cherry picked from commit e6bced8)
In this case I think force-pushing is fine if all the comments are addressed.
Force-pushed from dfa0c46 to caa7871.
Looks like CI is failing due to some warning coming from the test. However, the test is identical to the test case above it (apart from stratified sampling), so I just ended up deleting it. Pushing changes now.
@amueller everything should be good now.
```python
check_consistent_length(X, y)

if sample_weight is not None:
    sample_weight = _check_sample_weight(sample_weight, X)
```
Actually it looks like several PRs were done for this estimator as this was added on the line below in #15505. Please remove the above 3 lines.
Otherwise LGTM.
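For clarity, a sketch of what the surviving validation amounts to once the redundant lines are removed; the `_validate` wrapper is hypothetical, used only to make the fragment self-contained:

```python
from sklearn.utils.validation import _check_sample_weight, check_consistent_length

def _validate(X, y, sample_weight=None):
    # sketch of the validation left in DummyClassifier.fit (not verbatim source)
    check_consistent_length(X, y)
    # _check_sample_weight returns unit weights when given None,
    # so no `if sample_weight is not None` guard is needed
    return _check_sample_weight(sample_weight, X)
```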
Is it worth merging at this point? Looks like the only changes left would be one linting issue and removing sample weight from check_consistent_length. Does it make more sense to close the PR instead?
@cmarmo wondering if there still needs to be work done on this one?
Merging, thanks for contributing! This indeed leaves only minor changes, and most were merged in another PR.
(scikit-learn#15510) Co-authored-by: Sallie Walecka <[email protected]>
Reference Issues/PRs

Partially addresses #15358 for DummyClassifier.

What does this implement/fix? Explain your changes.

Replaces custom validation logic with the standardized method `utils.validation._check_sample_weight` (relatively newly introduced).

Any other comments?
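To illustrate the pattern being standardized, a hedged before/after sketch; the `fit_before` body is an assumed example of ad-hoc handling, not the exact lines this PR removed:

```python
import numpy as np
from sklearn.utils.validation import _check_sample_weight, check_consistent_length

def fit_before(X, y, sample_weight=None):
    # ad-hoc validation, duplicated across estimators (assumed example)
    if sample_weight is not None:
        sample_weight = np.asarray(sample_weight)
        check_consistent_length(X, y, sample_weight)
    ...

def fit_after(X, y, sample_weight=None):
    # one shared helper validates shape and dtype, and fills in
    # unit weights when sample_weight is None
    check_consistent_length(X, y)
    sample_weight = _check_sample_weight(sample_weight, X)
    ...
```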