
Conversation

@JohnStott
Contributor

K is not being updated properly in certain situations where we have non-uniform sample weights. This occurs during a removal/pop from, or push onto, the WeightedMedianCalculator. The proposed fix solves this problem by identifying the exact index at which a new sample is added or an old sample is removed, and applying additional logic to update K correctly.
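
For context, here is a minimal sketch of the invariant that the incrementally maintained state has to satisfy. The names k, sum_w_0_k and the tie handling below are only my illustration of the idea, not the exact internals; the real calculator keeps these values up to date on every push/pop rather than recomputing them from scratch:

import numpy as np

def naive_weighted_median(values, weights):
    # Reference computation: sort the samples, then find the smallest k such
    # that the weight of the first k samples reaches half of the total weight.
    order = np.argsort(values)
    values = np.asarray(values, dtype=float)[order]
    weights = np.asarray(weights, dtype=float)[order]
    total = weights.sum()
    sum_w_0_k = 0.0
    k = 0
    while sum_w_0_k < total / 2.0:
        sum_w_0_k += weights[k]
        k += 1
    if sum_w_0_k == total / 2.0:
        # The weight is split exactly in half: take the midpoint of the two
        # neighbouring values (one common convention).
        median = (values[k - 1] + values[k]) / 2.0
    else:
        median = values[k - 1]
    return k, sum_w_0_k, median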

Reference Issues/PRs

Fixes #10725 (BUG Median not always being calculated correctly for DecisionTrees in the WeightedMedianCalculator)

@JohnStott
Contributor Author

More detail...

1) Proof that there is an existing problem with the median calculations.

In addition to my original bug illustration (#10725), I have added a branch to my GitHub fork that makes testing for this issue much easier: https://github.com/JohnStott/scikit-learn/tree/median_issue_example

The median_issue_example branch is the existing version of sklearn (without this pull request's fix). I have created a new function, verify_state(), in the WeightedMedianCalculator class (tree_utils.pyx). It is called each time a sample is popped, removed or pushed, and simply checks that the median, k and sum_w_0_k are correct. It does this by looping over the whole of the current node's data (remember, the existing median, k and sum_w_0_k are calculated incrementally from memory state for efficiency). If a discrepancy is found, it raises an exception. This function is purely for testing and shouldn't be used in production, hence the separate branch.

2) Proof the fix solves this problem.
I have another branch that is identical to this pull request but which also contains the verify_state() code: https://github.com/JohnStott/scikit-learn/tree/median_fix_debug. This means we can test that this new version does in fact solve the issue.

3) Unit test(s) to ensure median calculation integrity.
I have created a "brute force / naive" script that stresses the median calculation (see my next post). Running it on my median_issue_example and median_fix_debug branches should highlight any problems.

For the production code, I have modified sklearn/tree/tests/test_tree.py with an additional test. With the existing implementation, this test data produces an incorrect median and thus a different tree from the fixed version. I am not sure whether this test is sufficient, or whether my comments there are ideal, so it would be good to discuss/review this further.
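
For reference, here is a minimal sketch of the kind of property such a test can check (this is not the exact test added to test_tree.py; the data and expected value are purely illustrative). With constant features the tree cannot split, so the single leaf should predict the weighted median of y:

import numpy as np
from sklearn.tree import DecisionTreeRegressor

def test_mae_leaf_value_is_weighted_median():
    # Constant X forces a single leaf; its value should be the weighted
    # median of y under sample_weight.
    X = np.ones((5, 1))
    y = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
    sample_weight = np.array([1.0, 1.0, 5.0, 1.0, 1.0])

    tree = DecisionTreeRegressor(criterion="mae", random_state=0)
    tree.fit(X, y, sample_weight=sample_weight)

    # Half of the total weight (4.5) is first reached at y == 3, so the
    # weighted median is 3 under any usual tie-handling convention.
    np.testing.assert_allclose(tree.predict(X), 3.0)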

I will also run some benchmark tests to show that the fix doesn't hinder performance and, more importantly, that the fix works, i.e. the MAE should be <= that of the older version.

Other considerations: both the original version of MAE and the fixed version sometimes SILENTLY fail to calculate the median correctly when negative sample weights are included (this can be observed by removing the surrounding np.abs(..) call in the brute force script when creating sample_weight). Garbage in, garbage out. I believe a fix for this would require more checking logic and could make the process less efficient time-wise. Is it worth the effort / loss of efficiency for such an edge case? We could throw an error in this situation, like the one thrown when the sample weights sum to 0. Perhaps this should be raised as a separate issue to encourage further discussion?
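
As a rough illustration of the error-raising option (a sketch only, not part of this PR; where such a check would actually live is up for discussion):

import numpy as np

def check_nonnegative_sample_weight(sample_weight):
    # Reject negative weights up front rather than silently computing a
    # meaningless weighted median further down the line.
    sample_weight = np.asarray(sample_weight, dtype=float)
    if np.any(sample_weight < 0):
        raise ValueError("sample_weight contains negative values; "
                         "the MAE criterion assumes non-negative weights")
    return sample_weight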

@JohnStott
Contributor Author

JohnStott commented Jul 21, 2018

Brute force / naive test script. This simply creates lots of random data sets and pushes them through a DecisionTreeRegressor. Note that this test script places added emphasis on datasets with duplicate values, as per the issue above.

import numpy as np
from sklearn.tree import DecisionTreeRegressor
from sklearn.metrics import mean_absolute_error
import time

mae = 0
maeSeconds = 0

#change these:
trials = 100
x_distincts = 10
y_distincts = 10
sample_weight_distincts = 10

#for n_samples in range(40000,40001):
#for n_samples in range(33,34):
for n_samples in range(2,50):
#for n_samples in range(2, 50000, 1000):
    seed = 0
    medianExceptionRaised = False
    while seed < trials:
        seed += 1
        np.random.seed(seed)

        # Randomly select X data structure:
        rand_X_type = int(np.ceil(np.random.rand(1) * 2)[0])
        # 1 to 50 variables:
        rand_X_vars = int(np.ceil(np.random.rand(1) * 50)[0])
        
        if rand_X_type == 1:
            # random norm:
            train_X = np.random.randn(n_samples, rand_X_vars)
        else:
            # test where x has mostly duplicate values!
            train_X = np.ceil((
                              np.random.rand(n_samples, rand_X_vars) 
                              * x_distincts) - np.floor(x_distincts / 2))

        # Randomly select y data structure:
        rand_y_type = int(np.ceil(np.random.rand(1) * 2)[0])
        # 1 to 10 y outputs per sample:
        rand_y_vars = int(np.ceil(np.random.rand(1) * 10)[0])

        if rand_y_type == 1:
            # random norm:
            train_y = np.random.randn(n_samples, rand_y_vars)
        else:
            # test where y has mostly duplicate values!
            train_y = np.ceil((
                              np.random.rand(n_samples, rand_y_vars) 
                              * y_distincts) - np.floor(y_distincts / 2))

        # Randomly select sample_weight data structure:
        rand_wt_type = int(np.ceil(np.random.rand(1) * 2)[0])

        if rand_wt_type == 1:
            # random norm:
            # Ensure sum of sample_weights are not smaller than or equal to 0:
            sample_weight = 0
            while np.sum(sample_weight) <= 0:
                sample_weight = np.abs(np.ravel(np.random.randn(n_samples, 1)))
        else:
            # 1 to 5 distinct wts:
            # Ensure sum of sample_weights are not smaller than or equal to 0:
            sample_weight = 0
            while np.sum(sample_weight) <= 0:
                sample_weight = np.abs(np.ravel(np.ceil((
                    np.random.rand(n_samples, 1) * sample_weight_distincts) - np.floor(sample_weight_distincts / 2)
                    )))

        start = time.time()
        
        wineTree = DecisionTreeRegressor(
            criterion='mae',
            random_state=seed)
        try:
            wineTree.fit(train_X, train_y, sample_weight=sample_weight)
        except ValueError as err:
            print ("n_samples: " + str(n_samples))
            print ("seed: " + str(seed))
            raise

        prediction = wineTree.predict(train_X)
        mae += mean_absolute_error(train_y, prediction,
                                    sample_weight=sample_weight)

        
        end = time.time()
        maeSeconds += (end - start)

    print ("n_samples: " + str(n_samples))

print ("Seconds Elapsed: " + str(maeSeconds))
print ("Total MAE: " + str(mae))   



@JohnStott
Contributor Author

JohnStott commented Jul 22, 2018

Running the above on the various branches produces the following results:

Branch                 Seconds Elapsed     Total MAE
master (Original)      28.5206606388092    495.060674631426
median_issue_example   Fails!              Fails!
median_fix             28.547548532486     10.3636367106389
median_fix_debug       234.1753885746      10.3636367106389

(Note: I only ran the tests once, as I was just looking to quickly check that everything is as expected.)

It is encouraging to see that the additional logic hasn't affected the time efficiency of the calculations (i.e., 28.52 seconds versus 28.55). It is also good to see that the MAE for median_fix versus median_fix_debug is the same (i.e., I haven't accidentally introduced error in the two slightly different implementations).

It is interesting to see the difference in MAE between the fixed and original versions. The above script loops around 5000 times, so some exaggeration of the difference is to be expected(?).

The median_fix_debug branch is expected to take a lot longer since it also calculates the naive versions of the median, k, and sum_w_0_k (NB: I could make this much more efficient by using a single loop in verify_state(), but it is just for demonstration purposes).

@JohnStott
Contributor Author

As a final test, I have used the Boston dataset to compare the existing implementation ("master (Original)") versus the newly fixed median version ("median_fix"). Note that here I am using a RandomForestRegressor; this is a good test of sample_weight handling and of duplicate y values, since the bootstrapping mechanism works by simply adding to or removing from a sample's weight in order to simulate sampling with replacement.
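
To spell out that equivalence (a sketch of the idea, not the forest's actual code): a bootstrap draw can be expressed as per-sample counts which are then used as sample weights, so a duplicated row simply gets a larger weight and an omitted row gets weight 0:

import numpy as np

rng = np.random.RandomState(0)
n_samples = 8

# Sampling with replacement, expressed as per-sample counts...
indices = rng.randint(0, n_samples, n_samples)
counts = np.bincount(indices, minlength=n_samples)

# ...which can be handed to fit() as sample_weight instead of actually
# duplicating/removing rows.
print(counts)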

Here is the test script:

import time
from sklearn.datasets import load_boston
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_absolute_error

dataset = load_boston()
X_full, y_full = dataset.data, dataset.target

x = 1
total_mae = 0
total_seconds = 0

noOfTrees = 1000
while x <= 10: 
    estimator = RandomForestRegressor(random_state=x, n_estimators=noOfTrees,
                                      criterion="mae")

    start = time.time()
    estimator.fit(X_full, y_full)
    seconds = time.time() - start
    
    prediction = estimator.predict(X_full)
    mae = mean_absolute_error(y_full, prediction)

    print ("Loop #", str(x))
    print ("Seconds elapsed:", str(seconds))
    print ("MAE: ", str(mae))
    print ("")

    total_mae += mae
    total_seconds += seconds
    x += 1

print ("")
print ("Total Seconds elapsed:", str(total_seconds))
print ("Total MAE: ", str(total_mae))

@JohnStott
Contributor Author

The results:

     master (Original)              median_fix
     Test MAE   Seconds elapsed     Test MAE   Seconds elapsed
1    0.82701    22.50109            0.81779    22.69898
2    0.82477    22.48553            0.81234    22.45428
3    0.82858    22.48545            0.81862    22.43860
4    0.82266    22.43859            0.81243    22.40735
5    0.81545    22.46982            0.80810    22.39163
6    0.82332    22.42289            0.81008    22.36039
7    0.82873    22.46974            0.81871    22.39168
8    0.82093    22.45420            0.81791    22.36042
9    0.82839    22.39163            0.81739    22.31357
10   0.81844    22.43855            0.81175    22.32927
Sum  8.23828    224.55748           8.14513    224.14616

@JohnStott
Contributor Author

It's nice to see (with this dataset at least) that the fixed median model fits the training data better.

@JohnStott
Contributor Author

Review comment on these new lines:

push_index = self.samples.push(data, weight)
if push_index == -1:
    return -1

I added this to replicate what was previously being returned. However, we should only get a -1 when an exception occurs, i.e. a MemoryError, so in hindsight I think I can remove this check, since an exception in self.samples.push should terminate immediately...? I am not 100% sure though, being new to Cython.

@jnothman
Member

jnothman commented Jul 22, 2018 via email

@JohnStott
Contributor Author

Thanks @jnothman, no problem. Have to cut off somewhere. Will keep an eye on the release and remind later. Cheers.

Base automatically changed from master to main January 22, 2021 10:50
@jjerphan
Member

Hi @JohnStott, this looks like an interesting contribution!

Are you still interested in working on this PR? 🙂

@JohnStott
Contributor Author

JohnStott commented Jun 10, 2021 via email

@cakedev0
Contributor

This PR should probably be closed since it was superseded by #32100, which has been merged.

@lesteve
Member

lesteve commented Nov 17, 2025

Thanks, closing this one then!
