-
-
Notifications
You must be signed in to change notification settings - Fork 26.3k
Fix FEATURE_THRESHOLD initialization in trees #32259
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. Weβll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Fix FEATURE_THRESHOLD initialization in trees #32259
Conversation
eb15f6b
to
b3efc12
Compare
b3efc12
to
76e630e
Compare
76e630e
to
2a3f7ec
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall LGTM. Good catch!
If you have some bandwidth to detail the use-case that relied on this "ignore almost constant features" behavior, I would be happy. But that's just for my curiosity ^^
# Mitigate precision differences between 32 bit and 64 bit | ||
FEATURE_THRESHOLD = 1e-7 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've been working on sklearn/tree/*
quite a lot lately, but this comment has remained a mystery to me. It seems you rely on this behavior, so maybe you can detail a bit more what's the purpose of "mitigating precision differences between 32 bit and 64 bit"?
(100% optional though)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @cakedev0 , I am also unsure of the purpose of this threshold. Actually, the test that failed on our side was based on randomly generated fake data. I don't believe we have features with such low min/max values. So, I think we also don't rely on this behavior.
sklearn/tree/tests/test_tree.py
Outdated
|
||
def test_almost_constant_feature(): | ||
random_state = check_random_state(0) | ||
X = random_state.rand(10, 20) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
X = random_state.rand(10, 20) | |
X = random_state.rand(10, 2) |
I think you just need 2 features for this test to work. It would make it clearer IMO.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Makes sense, let me push a commit for this . Also I will add an assertion that the other feature has an importance higher than 0.
Side note on force pushing: I like doing it as well but it seems to mess with links from notifications. Which means people get a notification, click on the link in it and then end up "in the middle of nowhere". So we recommend that people don't force push. The PR gets merged via squashing, so an "ugly" history doesn't matter so much. |
@betatim understood. Sorry for the noise! Will keep in mind for future contributions. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for finding this. I think it would be good to get the eyes of a cython guru on this as well as other reviewers
Thanks a lot @sercant for tracking this and fixing it π! I am definitely not a Cython expert so maybe @adam2392 in case you have some spare bandwith and insights into why a variable initialization in pxd doesn't do anything? Here is what I double-checked:
I tweaked the test to be parametrized on the tree class and add a comment about the origin of the From a quick git grep, we likely use this pattern in other places in .pxd files, so I guess this would need to be looked at in more details π (at least in two places according to the regexp below).
|
def test_almost_constant_feature(): | ||
# Non regression test for | ||
# https://github.com/scikit-learn/scikit-learn/pull/32259 | ||
# Make sure that almost constant features are discarded. | ||
random_state = check_random_state(0) | ||
X = random_state.rand(10, 2) | ||
X[:, 0] *= 1e-7 # almost constant feature | ||
y = random_state.randint(0, 2, (10,)) | ||
for _, TreeEstimator in ALL_TREES.items(): | ||
est = TreeEstimator(random_state=0) | ||
est.fit(X, y) | ||
# the almost constant feature should not be used | ||
assert est.feature_importances_[0] == 0 | ||
# other feature should be used | ||
assert est.feature_importances_[1] > 0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@sercant for some reason, I can not push to your PR branch (maybe you have unticked the box "allow edits by maintainers"?), so doing this as a suggestion instead that you will need to accept. Those are the changes to the test I had in mind (parametrize + comment to explain where 1e-7
comes from):
def test_almost_constant_feature(): | |
# Non regression test for | |
# https://github.com/scikit-learn/scikit-learn/pull/32259 | |
# Make sure that almost constant features are discarded. | |
random_state = check_random_state(0) | |
X = random_state.rand(10, 2) | |
X[:, 0] *= 1e-7 # almost constant feature | |
y = random_state.randint(0, 2, (10,)) | |
for _, TreeEstimator in ALL_TREES.items(): | |
est = TreeEstimator(random_state=0) | |
est.fit(X, y) | |
# the almost constant feature should not be used | |
assert est.feature_importances_[0] == 0 | |
# other feature should be used | |
assert est.feature_importances_[1] > 0 | |
@pytest.mark.parametrize("tree_cls", ALL_TREES.values()) | |
def test_almost_constant_feature(tree_cls): | |
# Non regression test for | |
# https://github.com/scikit-learn/scikit-learn/pull/32259 | |
# Make sure that almost constant features are discarded. | |
random_state = check_random_state(0) | |
X = random_state.rand(10, 2) | |
# FEATURE_TRESHOLD=1e-7 is defined in sklearn/tree/_partitioner.pyx but not | |
# accessible from Python | |
feature_threshold = 1e-7 | |
X[:, 0] *= feature_threshold # almost constant feature | |
y = random_state.randint(0, 2, (10,)) | |
est = tree_cls(random_state=0) | |
est.fit(X, y) | |
# the almost constant feature should not be used | |
assert est.feature_importances_[0] == 0 | |
# other feature should be used | |
assert est.feature_importances_[1] > 0 |
|
||
# Mitigate precision differences between 32 bit and 64 bit | ||
cdef float32_t FEATURE_THRESHOLD = 1e-7 | ||
# Note: Has to be initialized in pyx file, not in the pxd file |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am not too sure this comment is really needed. We are probably not going to add a similar comment in each Cython module-level variable declaration, but at the same time the behaviour is very suprising (full disclosure: I am definitely not a Cython expert).
(nice idea the git grep + regexp ^^) I took a quick look: The first one is probably worth openning an issue, the code in The second one is ok, it's in a comment: # safe_realloc(&p, n) resizes the allocation of p to n * sizeof(*p) bytes or
# raises a MemoryError. It never calls free, since that's __dealloc__'s job.
# cdef float32_t *p = NULL
# safe_realloc(&p, n)
# is equivalent to p = malloc(n * sizeof(*p)) with error checking. |
@cakedev0 a PR would be more than welcome to fix the Note there may be a few more slightly different regex
Some of them are inside
The value is set in the |
Reference Issues/PRs
What does this implement/fix? Explain your changes.
I noticed one of our tests failing after upgrading from 1.5 to 1.6 and above. I traced the issue to the tree implementation change in #29458. The initialization of
cdef
constant cannot be made in the pxd file. This resulted inFEATURE_THRESHOLD
to be initialized to0.0
instead of1e-7
. This PR fixes that by moving the initialization to thepyx
file.Any other comments?
It's my first time contributing to scikit-learn, so please let me know if anything is missing.