
Conversation

betatim (Member) commented Oct 3, 2025

Todo:

  • clean up commit
  • add a test that fails on main and passes on this PR

Reference Issues/PRs

#32178

What does this implement/fix? Explain your changes.

This is a simple fix, but it took a while to find. The problem occurs only if you first look at a different feature and then consider the feature that has missing values. This explains (I think) why it happens roughly half the time when you run the following snippet:

import numpy as np
from sklearn.tree import DecisionTreeRegressor

X = np.vstack([
    [0, 0, 0, 0, 1, 2, 3, 4],
    [1, 2, 1, 2, 1, 2, 1, 2]
]).swapaxes(0, 1).astype(float)
y = [0, 0, 0, 0, 1, 1, 1, 1]
X[X == 0] = np.nan
for i in range(50):
    tree = DecisionTreeRegressor(max_depth=1, random_state=i, min_impurity_decrease=0.002).fit(X, y)
    print(tree.tree_.impurity)

The way to tell that something went wrong is that the impurity is [0.25 0.1875 0.1875] instead of the expected [0.25 0. 0.]. If you uncomment the print statements (in the first commit of this PR), you will see the following output when running the above snippet with a random state that triggers the bug (i=2 on my computer).

before partitioning best_n_missing=4 partition_end=3 start=0 end=7 best_threshold=inf
X[i, best_feature]=1.0 i=4
X[i, best_feature]=nan i=0
X[i, best_feature]=3.0 i=6
X[i, best_feature]=nan i=2
X[i, best_feature]=nan i=3
X[i, best_feature]=nan i=1
X[i, best_feature]=4.0 i=7
X[i, best_feature]=2.0 i=5
---- p=0 samples[p]=4 current_value=1.0 end=7
X[i, best_feature]=1.0 i=4
X[i, best_feature]=nan i=0
X[i, best_feature]=3.0 i=6
X[i, best_feature]=nan i=2
X[i, best_feature]=nan i=3
X[i, best_feature]=nan i=1
X[i, best_feature]=4.0 i=7
X[i, best_feature]=2.0 i=5
---- p=1 samples[p]=0 current_value=nan end=7
X[i, best_feature]=1.0 i=4
X[i, best_feature]=nan i=0
X[i, best_feature]=3.0 i=6
X[i, best_feature]=nan i=2
X[i, best_feature]=nan i=3
X[i, best_feature]=nan i=1
X[i, best_feature]=4.0 i=7
X[i, best_feature]=2.0 i=5
---- p=2 samples[p]=6 current_value=3.0 end=6
X[i, best_feature]=1.0 i=4
X[i, best_feature]=2.0 i=5
X[i, best_feature]=3.0 i=6
X[i, best_feature]=nan i=2
X[i, best_feature]=nan i=3
X[i, best_feature]=nan i=1
X[i, best_feature]=4.0 i=7
X[i, best_feature]=nan i=0
after partitioning
X[i, best_feature]=1.0 i=4
X[i, best_feature]=2.0 i=5
X[i, best_feature]=3.0 i=6
X[i, best_feature]=nan i=2
X[i, best_feature]=nan i=3
X[i, best_feature]=nan i=1
X[i, best_feature]=4.0 i=7
X[i, best_feature]=nan i=0
[0.25   0.1875 0.1875]

If you look at the last printout of the samples, you see that 4.0 is in amongst the missing values. Counting how many times we "swapped" values and where p ends is what makes me think we need to either use <= for the condition or intp_t partition_end = end - best_n_missing + 1.
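
To make the suspected off-by-one concrete, here is a small pure-Python model of the "move missing values to the end" step (a simplified sketch, not the actual Cython code in _partitioner.pyx; the threshold comparison is omitted because best_threshold is inf here, and the function name is made up). With the exclusive bound it reproduces the broken ordering from the debug output above; examining one more position fixes it:

import numpy as np

def push_missing_to_end(values, n_missing, examine_one_more):
    # Illustrative only: swap NaNs to the right end of `values`, scanning from
    # the left, the way the description above sketches it. `examine_one_more`
    # toggles between the current (exclusive) stop condition and the proposed
    # fix that looks at one more position.
    values = list(values)
    end = len(values) - 1                 # index of the last sample in the node
    partition_end = end - n_missing       # first slot reserved for missing values
    p = 0
    while (p <= partition_end) if examine_one_more else (p < partition_end):
        if np.isnan(values[p]):
            values[p], values[end] = values[end], values[p]
            end -= 1
        else:
            p += 1
    return values

# Node values in the order printed "before partitioning" above.
node = [1.0, np.nan, 3.0, np.nan, np.nan, np.nan, 4.0, 2.0]
print(push_missing_to_end(node, 4, examine_one_more=False))
# [1.0, 2.0, 3.0, nan, nan, nan, 4.0, nan]  <- 4.0 stranded amongst the NaNs
print(push_missing_to_end(node, 4, examine_one_more=True))
# [1.0, 2.0, 3.0, 4.0, nan, nan, nan, nan]  <- all non-missing values examined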

Any other comments?

This was an interesting puzzle. I used the snippet from the issue that @cakedev0 posted and then asked Cursor to explain how this could happen. It fairly quickly figured out that it must have something to do with which feature is considered first (because we permute the feature order). It then hinted at a problem with samples not being sorted properly. But the fixes it suggested were to add a special case to partition_samples_final for when the threshold was np.inf. This kind of fixed the problem, but not completely. It then suggested sorting the values before calling the function. At this point I decided that Cursor was doing what it often does: come up with "overfitted" solutions (a specific solution for just the one problem you asked it about). Having had it point at sorting being a problem, and knowing the bug only happens half the time, I went back to the old-school "add print statements" approach :D combined with a hunch that an off-by-one error was likely, given how much gymnastics happens to calculate all the start and end positions.

Overall a fun experience. Thanks for making that initial snippet, @cakedev0, it was a good minimal reproducer!

We need to examine one more value to make sure we look at all the non-missing values.
github-actions bot commented Oct 3, 2025

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: ef7d10c.

cakedev0 (Contributor) commented Oct 3, 2025

Note: My PR #32119 fixes the issue (even if it's not its primary goal, so it might make sense to do a 1-line fix PR first ^^)

betatim (Member, Author) commented Oct 3, 2025

Ah ok. I had a quick look at the linked PRs from the issue and didn't realise that there was already a fix :-/

betatim (Member, Author) commented Oct 3, 2025

Do you have a nice unittest for this issue? Fitting a bunch of trees and checking that all of them get the right impurity/have enough splits seems like a bit of a roundabout way of testing this :-/

cakedev0 (Contributor) commented Oct 3, 2025

Do you have a nice unittest for this issue?

See PR #32193 😄

Fitting a bunch of trees and checking that all of them get the right impurity/have enough splits seems like a bit of a roundabout way of testing this

I personally love tests with many random inputs. Toy examples should only be used for debugging IMO. It's still good to have them as tests, as it makes it easy to debug regressions. But it's really not enough to build confidence.

@betatim betatim marked this pull request as ready for review October 3, 2025 12:19
@betatim betatim requested a review from thomasjpfan October 3, 2025 13:35
ogrisel (Member) commented Oct 7, 2025

I personally love tests with many random inputs. Toy examples should only be used for debugging IMO. It's still good to have them as tests, as it makes it easy to debug regressions. But it's really not enough to build confidence.

Unfortunately, such randomized tests are often costly to run, so we need to strike a balance. Ideally, individual tests should last no longer than 1 s on a usual CI runner.

Slow test suites make it painful to run the tests locally or wait for the CI to complete on a PR. It would further make the release process more horrible than it is, given the number of Python versions and architecture combinations that we support.

betatim (Member, Author) commented Oct 14, 2025

What about adding a function to _partitioner.pyx that instantiates a DensePartitioner so that we can access it directly, instead of via a decision tree? We already have _py_sort to expose sort for testing purposes.

Being able to directly instantiate the partitioner would have been useful during debugging as well, and we could test it more easily. What are the downsides? I think having the class itself as cdef has performance benefits(?), but I can't think of a downside in terms of performance. The only downside I can think of is more code :-/ WDYT @ogrisel?

Asking Cursor for a sketch of how to do this gives the following:

def _py_test_partition_samples_final(
    const float32_t[:, :] X,
    intp_t[::1] samples,
    intp_t feature,
    float64_t threshold,
    intp_t n_missing,
    const uint8_t[::1] missing_values_in_feature_mask,
):
    """Test helper to directly test partition_samples_final.
    
    This function is used for testing the partitioning logic when missing
    values are present. It creates a DensePartitioner and calls
    partition_samples_final.
    
    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        The input data
    samples : array-like of shape (n_samples,)
        Sample indices to partition (modified in-place)
    feature : int
        The feature index to partition on
    threshold : float
        The threshold value for partitioning
    n_missing : int
        The number of missing values in this feature
    missing_values_in_feature_mask : array-like of shape (n_features,)
        Mask indicating which features have missing values
        
    Returns
    -------
    None
        The samples array is modified in-place
    """
    cdef float32_t[::1] feature_values = np.empty(len(samples), dtype=np.float32)
    cdef DensePartitioner partitioner = DensePartitioner(
        X, samples, feature_values, missing_values_in_feature_mask
    )
    
    partitioner.init_node_split(0, len(samples))
    partitioner.partition_samples_final(
        best_pos=0,  # Not used in the actual partitioning logic
        best_threshold=threshold,
        best_feature=feature,
        best_n_missing=n_missing,
    )
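
For illustration, a hypothetical test using such a helper could look like the following (this assumes the wrapper above were actually added to sklearn.tree._partitioner, which is not the case; the helper name and its exact behaviour are assumptions based on the sketch):

import numpy as np
# Hypothetical import: _py_test_partition_samples_final does not exist in
# scikit-learn; it refers to the sketch above.
from sklearn.tree._partitioner import _py_test_partition_samples_final

X = np.array([[np.nan], [1.0], [np.nan], [2.0]], dtype=np.float32)
samples = np.arange(4, dtype=np.intp)
mask = np.array([1], dtype=np.uint8)

# With a threshold of +inf every non-missing value goes to the left child, so
# we would expect the two missing samples to end up at the right end of `samples`.
_py_test_partition_samples_final(X, samples, 0, np.inf, 2, mask)
assert np.all(np.isnan(X[samples[-2:], 0]))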

cakedev0 (Contributor) commented

I have one drawback in mind, but that depends on how you use it.

Assuming you're using this wrapper to write the non-regression test for your fix, here is what could happen: someone refactors the inner workings of trees, can't really translate your test to the new design, so they remove it (or rewrite a weaker/incorrect version) and introduce the regression. Not super likely, but it still makes me prefer testing the top-level API when I can.
And here I think it's not too hard. As I already said in #32351 (comment), I think the test you wrote would be a pretty good test with just a few simple changes.

By repeating the fit and checking the computed impurity we should catch
the bug. The bug should happen in ~50% of fits, so ten fits are likely
to trigger it.
betatim (Member, Author) commented Oct 28, 2025

I went with trying ten different trees (i.e. ten different seeds). I think this is a good compromise between catching the bug and not spending too much compute time on it.
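
For reference, a sketch of what such a non-regression test could look like, built from the reproducer in the description (the test actually added in this PR may differ in its details):

import numpy as np
from sklearn.tree import DecisionTreeRegressor

def test_impurity_with_missing_values_first_feature():
    # Sketch based on the reproducer above: feature 0 is all-NaN for class 0
    # and 1..4 for class 1, so a correct split on it separates y perfectly.
    X = np.array(
        [[np.nan, 1], [np.nan, 2], [np.nan, 1], [np.nan, 2],
         [1, 1], [2, 2], [3, 1], [4, 2]],
        dtype=float,
    )
    y = [0, 0, 0, 0, 1, 1, 1, 1]
    # The bug only triggered for some feature orderings (~50% of seeds), so
    # try ten seeds: the chance of all of them missing it is about 1/2**10.
    for seed in range(10):
        tree = DecisionTreeRegressor(
            max_depth=1, random_state=seed, min_impurity_decrease=0.002
        ).fit(X, y)
        # Both children of the root should be pure.
        assert np.allclose(tree.tree_.impurity, [0.25, 0.0, 0.0])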

@betatim betatim requested a review from ogrisel October 28, 2025 12:53
cakedev0 (Contributor) left a comment

LGTM.

Nit: I would rather go with 20 trees, as with 10 the chance of missing the bug (1/2^10) is still about 0.1%, but maybe I'm too cautious 😅

ogrisel (Member) left a comment

LGTM.

@ogrisel ogrisel enabled auto-merge (squash) November 6, 2025 10:17
ogrisel (Member) commented Nov 6, 2025

I updated with main and marked the PR for auto-merge.

@ogrisel ogrisel merged commit 8aceace into scikit-learn:main Nov 6, 2025
38 checks passed
@betatim betatim deleted the fix-impurity-calculation-with-missing branch November 7, 2025 18:45