FIX Fix DecisionTree* partitioning with missing values present
#32351
Conversation
We need to examine one more value to make sure we look at all the non-missing values.
Note: My PR #32119 fixes the issue (even if it's not its primary goal), so it might make sense to do a 1-line fix PR first ^^
Ah ok. I had a quick look at the linked PRs from the issue and didn't realise that there was already a fix :-/
Do you have a nice unit test for this issue? Fitting a bunch of trees and checking that all of them get the right impurities?
See PR #32193 😄
I personally love tests with many random inputs. Toy examples are mostly useful for debugging IMO. It's still good to have them as tests, as that makes it easy to debug regressions, but they're really not enough to build confidence.
Unfortunately, such randomized tests are often costly to run, so we need to strike a balance. Ideally, individual tests should last no longer than 1 s on a usual CI runner. Slow test suites make it painful to run the tests locally or to wait for the CI to complete on a PR. It would further make the release process more horrible than it already is, given the number of Python version and architecture combinations that we support.
What about adding a function to the Cython module for this? Being able to directly instantiate the partitioner would have been useful during debugging as well, and we could more easily test it. What are the downsides? I think having the class itself as a `cdef` class (with `cdef` methods) is what makes it hard to exercise directly from Python. Asking Cursor for a sketch of how to do this gives the following:

# Sketch only: assumes module-level `import numpy as np`, the usual Cython
# typedefs (float32_t, float64_t, intp_t, uint8_t), and DensePartitioner
# being available in the same module.
def _py_test_partition_samples_final(
    const float32_t[:, :] X,
    intp_t[::1] samples,
    intp_t feature,
    float64_t threshold,
    intp_t n_missing,
    const uint8_t[::1] missing_values_in_feature_mask,
):
    """Test helper to directly test partition_samples_final.

    This function is used for testing the partitioning logic when missing
    values are present. It creates a DensePartitioner and calls
    partition_samples_final.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        The input data.
    samples : array-like of shape (n_samples,)
        Sample indices to partition (modified in-place).
    feature : int
        The feature index to partition on.
    threshold : float
        The threshold value for partitioning.
    n_missing : int
        The number of missing values in this feature.
    missing_values_in_feature_mask : array-like of shape (n_features,)
        Mask indicating which features have missing values.

    Returns
    -------
    None
        The samples array is modified in-place.
    """
    cdef float32_t[::1] feature_values = np.empty(len(samples), dtype=np.float32)
    cdef DensePartitioner partitioner = DensePartitioner(
        X, samples, feature_values, missing_values_in_feature_mask
    )
    partitioner.init_node_split(0, len(samples))
    partitioner.partition_samples_final(
        best_pos=0,  # Not used in the actual partitioning logic
        best_threshold=threshold,
        best_feature=feature,
        best_n_missing=n_missing,
    )
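For illustration, a direct test built on such a helper might look roughly like the sketch below. Everything here is hypothetical: the wrapper name, the calling convention, and the checked property are assumptions layered on the sketch above, not an existing scikit-learn API.

```python
import numpy as np

# Hypothetical: assumes the Cython sketch above were compiled into the tree
# extension module and importable, which is not currently the case, e.g.:
# from sklearn.tree._partitioner import _py_test_partition_samples_final

def check_missing_values_end_up_last(partition_final):
    # One feature with two missing values; `samples` starts out in an order
    # that a previously evaluated feature could have left it in.
    X = np.array([[0.0], [np.nan], [2.0], [4.0], [np.nan]], dtype=np.float32)
    samples = np.arange(X.shape[0], dtype=np.intp)
    mask = np.ones(1, dtype=np.uint8)

    partition_final(X, samples, 0, 1.0, 2, mask)

    # The regression discussed here left a non-missing value (4.0) parked
    # among the missing ones; after a correct partition the last `n_missing`
    # entries of `samples` should all point at NaNs.
    assert np.isnan(X[samples[-2:], 0]).all()
```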
I have one drawback in mind, but that depends on how you use it. Assuming you're using this wrapper to write the non-regression test for your fix, here is what could happen: someone refactors the inner workings of trees, can't really translate your test to the new design, so they remove it (or rewrite a weaker/incorrect version), and reintroduce the regression. Not super likely, but it still makes me prefer testing the top-level API when I can.
By repeating the fit and checking the computed impurity we should catch the bug. The bug should happen in ~50% of fits, so ten fits are very likely to trigger it (the chance of all ten missing it is about 0.5^10 ≈ 0.1%).
I went with trying ten different trees (fit with different seeds). I think this is a good compromise between catching the bug and not spending too much compute time on it.
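A rough sketch of what such a multi-seed non-regression test could look like (the dataset, the estimator, and the exact consistency check are illustrative assumptions, not the test that was actually added):

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

def test_tree_impurities_consistent_with_missing_values():
    rng = np.random.RandomState(0)
    X = rng.rand(100, 2)
    X[rng.rand(100) < 0.3, 1] = np.nan  # missing values in one feature
    y = rng.rand(100)

    # The faulty partitioning only triggered for some feature orderings
    # (roughly half of the fits), so fit several trees with different seeds.
    for seed in range(10):
        reg = DecisionTreeRegressor(max_depth=3, random_state=seed).fit(X, y)
        tree = reg.tree_

        # Recompute each leaf's MSE impurity from the training samples that
        # are routed to it at prediction time and compare it with the value
        # recorded during fitting; a mispartitioned sample should make the
        # two disagree for at least one leaf.
        leaf_for_sample = reg.apply(X)
        for leaf in np.unique(leaf_for_sample):
            y_leaf = y[leaf_for_sample == leaf]
            assert np.isclose(tree.impurity[leaf], np.var(y_leaf), atol=1e-7)
```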
cakedev0
left a comment
LGTM.
Nit: I would rather go with 20 trees, as with ten the chance of missing the bug (1/2^10) is still 0.1%, but maybe I'm too cautious 😅
ogrisel
left a comment
LGTM.
I updated with the suggested change.
Todo:
- check that the new test fails on `main` and passes on this PR

Reference Issues/PRs

#32178
What does this implement/fix? Explain your changes.
This is a simple fix, but it took a while to find. The problem occurs only if you first look at a different feature and then consider the feature that has missing values. This explains (I think) why it happens roughly half the time when you run the reproducer snippet from issue #32178.
The way to tell that something went wrong is that the impurity is not `[0.25 0. 0. ]` but `[0.25 0.1875 0.1875]`. If you uncomment the `print` statements (in the first commit of this PR) and run the snippet for a random state that is broken (`i=2` on my computer), the last printout of the samples shows that `4.0` is in amongst the missing values. Counting how many times we "swapped" values / where `p` ends is what makes me think we need to either use `<=` for the condition or `intp_t partition_end = end - best_n_missing + 1`.
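To make the off-by-one concrete, here is a small, simplified model of just the "move the missing values to the end of the range" part of the scan (plain Python, not the actual Cython code; the threshold handling is left out). The `examine_one_more` flag plays the role of the `<=` / `+ 1` change described above:

```python
import math

def move_missing_to_end(values, n_missing, examine_one_more):
    """Toy model of the final partition scan: swap any NaN found in the
    leading slots with the value at the current end of the range."""
    values = list(values)
    end = len(values) - 1                       # index of the last slot
    limit = end - n_missing + (1 if examine_one_more else 0)
    p = 0
    while p < limit:
        if math.isnan(values[p]):
            values[p], values[end] = values[end], values[p]
            end -= 1
        else:
            p += 1
    return values

# A previously evaluated feature left the values ordered so that a NaN sits
# in the last "non-missing" slot and a real value (4.0) sits in the tail.
vals = [1.0, 2.0, float("nan"), 4.0, float("nan")]
print(move_missing_to_end(vals, n_missing=2, examine_one_more=False))
# [1.0, 2.0, nan, 4.0, nan] -> 4.0 is left in amongst the missing values
print(move_missing_to_end(vals, n_missing=2, examine_one_more=True))
# [1.0, 2.0, 4.0, nan, nan] -> examining one more slot puts the NaNs last
```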
Any other comments?

This was an interesting puzzle. I used the snippet from the issue that @cakedev0 posted and then asked Cursor to explain to me how it could be that this happens. It fairly quickly figured out that it must be something to do with which feature is considered first (because we permute the feature order). It then hinted at a problem with samples not being sorted properly. But the fixes it suggested were to add a special case to `partition_samples_final` for when the threshold was `np.inf`. This kind of fixed the problem, but not completely. It then suggested sorting the values before calling the function. At this point I decided that Cursor was doing what it often does: come up with "overfitted" solutions (a specific solution for just the one problem you asked it about). Having had it point to sorting being a problem, and knowing that the bug only happens half the time, I went back to the old school "add print statements" approach :D, combined with a hunch that an off-by-one error was likely given how much gymnastics happens to calculate all the start and end positions.

Overall a fun experience. Thanks for making that initial snippet @cakedev0, it was a good minimal reproducer!