-
-
Notifications
You must be signed in to change notification settings - Fork 26.6k
[MRG+1] fix plot_partial_dependence not taking target into account when multiclass #14393
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MRG+1] fix plot_partial_dependence not taking target into account when multiclass #14393
Conversation
|
cc @NicolasHug |
glemaitre
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Otherwise LGTM
If |
|
You can add a test for it in fact (in another PR). Quickly the test should be something like: from sklearn.datasets import fetch_openml
from sklearn.tree import DecisionTreeClassifier
from sklearn.inspection import plot_partial_depedence
iris = fetch_openml('iris', as_frame=True, version=1)
df, y = iris.data, iris.target.to_numpy()
clf = DecisionTreeClassifier().fit(df, y)
assert dtype(clf.classes_) == 'object'
# check that the pdp with str and int give the same results
# pick-up the last class
# implement the assert as in this PR
plot_partial_dependence(clf, df, [0], target='Iris-viriginica')
plot_partial_dependence(clf, df, [0], target=2) |
Actually I asked because there is already a test about that
|
Oh perfect then, so no need for an additional test ;) |
NicolasHug
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Small comments but LGTM anywway.
Thanks for the fix @GuillemGSubies !
| # check that the pd plots are the same for 0 and "setosa" | ||
| assert all(axs[0].lines[0]._y == axs2[0].lines[0]._y) | ||
| # check that the pd plots are different for another target | ||
| clf = GradientBoostingClassifier(n_estimators=10, random_state=1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think you can remove a few lines, namely the clf definition and fitting, as well as the grid_resolution.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, I will. I did not change those because I did not know if it had to be with some standard you use when testing.
NicolasHug
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Small comments but LGTM anywway.
Thanks for the fix @GuillemGSubies !
Co-Authored-By: Nicolas Hug <[email protected]>
…ithub.com/GuillemGSubies/scikit-learn into bugfix_partial_dependence_plot_multiclass
| # check that the pd plots are the same for 0 and "setosa" | ||
| assert all(axs[0].lines[0]._y == axs2[0].lines[0]._y) | ||
| # check that the pd plots are different for another target | ||
| clf.fit(iris.data, iris.target) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you can remove this line too ;)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks like I shouldn't have removed it. That means that if I train using the targets as strings, I cannot pass an int to plot_partial_dependence
Don't know if that is the expected behavior or not
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh OK, my bad, I didn't realize it was fit on something different before
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reverted
|
Thanks @GuillemGSubies |
Reference Issues/PRs
Fixes #14301
What does this implement/fix? Explain your changes.
I just took out an else so
target_idxdoes not get overwritten.Any other comments?
I didn't know what was the optimal way to test it. Right now I check the y axis and make sure that they are not the same (the bug made them equals all the time).
Also, I have a question: Here it should be int or str, shouldn't it?
scikit-learn/sklearn/inspection/partial_dependence.py
Line 404 in c0c5313