-
-
Notifications
You must be signed in to change notification settings - Fork 26.4k
VIZ, ENH Adding True/False labels in arrows coming from root in plot_tree to match export_tree
#28552
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Signed-off-by: Adam Li <[email protected]>
import numpy as np
from matplotlib import pyplot as plt
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn import tree
iris = load_iris()
X = iris.data
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
clf = DecisionTreeClassifier(max_leaf_nodes=3, random_state=0)
clf.fit(X_train, y_train)
fig, ax = plt.subplots(figsize=(12, 12))
tree.plot_tree(
clf,
feature_names=iris.feature_names,
class_names=iris.target_names,
filled=True,
ax=ax,
)
plt.show() |
Signed-off-by: Adam Li <[email protected]>
Signed-off-by: Adam Li <[email protected]>
sklearn/tree/_export.py
Outdated
| text_pos, | ||
| ha="center", | ||
| va="center", | ||
| fontsize=10, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does this have to be self.fontsize if it is not None?
How does this annotation scale when the figsize is large, (i.e. figsize=(20,20))
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I did a refactoring of the kwargs, so that way it shouldn't be hardcoded here. I included in the PR description the image from various figure sizes.
Signed-off-by: Adam Li <[email protected]>
Signed-off-by: Adam Li <[email protected]>
Charlie-XIAO
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The rendering looks nice 😊 Here are some minor comments that probably makes the code cleaner.
Co-authored-by: Yao Xiao <[email protected]>
Signed-off-by: Adam Li <[email protected]>
|
Thanks for the review @thomasjpfan and @Charlie-XIAO ! I addressed your comments |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! Here are some additional suggestions regarding the positions of the True/False labels when I read through the diff again. I'm not sure if my suggested solution would work, so you may need to try it out :)
Also it seems that Codecov is complaining because we do not have a test that uses the fontsize parameter. It is not originally caused by this PR but maybe you can add a test to cover that.
Signed-off-by: Adam Li <[email protected]>
Signed-off-by: Adam Li <[email protected]>
Added a unit-test for fantasize |
Co-authored-by: Yao Xiao <[email protected]>
adam2392
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the review @Charlie-XIAO ! Now the figsize=(20,8) works well.
Signed-off-by: Adam Li <[email protected]>
Charlie-XIAO
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @adam2392! The plots look good to me now :)
| if node.parent.left() == node: | ||
| label_text, label_ha = ("True ", "right") | ||
| else: | ||
| label_text, label_ha = (" False", "left") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For other reviewers' reference, the spaces before and after the True/False label are to create some offset from the arrow on top of adjusting horizontal alignment. I'm not sure if there is a nicer way to do this; note that we are also creating padding by using spaces in annotation text e.g. this line so maybe this is fine?
This comment was marked as resolved.
This comment was marked as resolved.
Sorry, something went wrong.
This comment was marked as resolved.
This comment was marked as resolved.
Sorry, something went wrong.
This comment was marked as resolved.
This comment was marked as resolved.
Sorry, something went wrong.
This comment was marked as resolved.
This comment was marked as resolved.
Sorry, something went wrong.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@Charlie-XIAO Currently, it looks like displaying the labels in the tree is "hardcoded". What do you think about providing this feature of displaying the labels for the first split in a decision tree with a parameter for plot_tree() in case anyone does not like to display them? Maybe as a param called display_labels like in sklearn.metrics.ConfusionMatrixDisplay
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That seems reasonable to me, but I'll wait to see what the dev team thinks if it's adding complexity to the API.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
WDYT @thomasjpfan and @Charlie-XIAO ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For now, I'm okay with the current implementation. Adding true and false labels is already improvement.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
BTW resolves #17184
sklearn/tree/_export.py
Outdated
| xycoords="axes fraction", | ||
| bbox=self.bbox_args.copy(), | ||
| arrowprops=self.arrow_args.copy(), | ||
| **non_box_kwargs, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If kwargs applies to annotations with a bounding box, then it's strange to have non_box_kwargs here.
If they share some kwargs, then I prefer this naming:
| **non_box_kwargs, | |
| **common_box_kwargs, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I renamed to common_kwargs since the annotations true/false technically don't have a bounding box.
| if node.parent.left() == node: | ||
| label_text, label_ha = ("True ", "right") | ||
| else: | ||
| label_text, label_ha = (" False", "left") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For now, I'm okay with the current implementation. Adding true and false labels is already improvement.
thomasjpfan
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Reference Issues/PRs
Fixes: #16153
Fixes: #17184
What does this implement/fix? Explain your changes.
True/Falselabels on top of the arrows that go from root node to left/right childAny other comments?
(figsize=(10,10)):

(20, 20):

(5, 10):
