
ENH add subsample to HGBT #28063


Open
wants to merge 10 commits into main

Conversation

lorentzenchr
Member

Reference Issues/PRs

Fixes #16062 (#27139 is already merged).

What does this implement/fix? Explain your changes.

Add a subsample parameter to HistGradientBoostingClassifier and HistGradientBoostingRegressor, similar to subsample in the older GradientBoostingClassifier.
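
A rough sketch of how the proposed option might look from the user side, assuming the keyword ends up being called subsample as in the existing GradientBoostingRegressor (naming and defaults are still under review in this PR):

# Hypothetical usage; `subsample` on HGBT is the parameter proposed in this PR
# and is not part of a released scikit-learn version.
from sklearn.datasets import make_regression
from sklearn.ensemble import HistGradientBoostingRegressor

X, y = make_regression(n_samples=10_000, n_features=20, random_state=0)

# Each boosting iteration would fit its tree on a random ~50% row subsample,
# in the spirit of stochastic gradient boosting (Friedman, 2002).
model = HistGradientBoostingRegressor(subsample=0.5, random_state=0)
model.fit(X, y)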

Any other comments?

While the implementation is rather easy, suggestions for good tests are welcome.


github-actions bot commented Jan 4, 2024

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: 7fb2f72.

@ogrisel
Member

ogrisel commented Jan 8, 2024

Can you adapt or extend the existing stochastic gradient boosting example:

https://scikit-learn.org/stable/auto_examples/ensemble/plot_gradient_boosting_regularization.html

to check that similar results can be obtained with the HGBDT counterpart?

While the implementation is rather easy, suggestions for good tests are welcome.

If the above works as expected, we can probably turn a simplified version into a non-regression test (making assertions on the test loss values, without the plots), while keeping a reference from the test to the example to explain what is being tested (and also referencing the Friedman 2002 paper).
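
As an illustration of that suggestion, a rough sketch of what such a non-regression test could look like, assuming the parameter proposed in this PR is called subsample; the dataset, sizes, and loss tolerance are placeholders, not values from the PR:

from sklearn.datasets import make_hastie_10_2
from sklearn.ensemble import HistGradientBoostingClassifier
from sklearn.metrics import log_loss
from sklearn.model_selection import train_test_split


def test_subsampling_acts_as_regularization():
    # Same spirit as plot_gradient_boosting_regularization.html and
    # Friedman (2002): subsampling should not hurt, and typically helps,
    # the held-out loss.
    X, y = make_hastie_10_2(n_samples=4000, random_state=0)
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

    losses = {}
    for subsample in (1.0, 0.5):
        clf = HistGradientBoostingClassifier(
            max_iter=200,
            learning_rate=0.1,
            subsample=subsample,  # parameter proposed in this PR
            random_state=0,
        )
        clf.fit(X_train, y_train)
        losses[subsample] = log_loss(y_test, clf.predict_proba(X_test))

    # Placeholder tolerance; the real test would assert on tuned values.
    assert losses[0.5] <= 1.05 * losses[1.0]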

Member

ogrisel left a comment


In addition to what I suggested above for testing, maybe you can parametrize existing HGBDT tests whose result should be invariant to whether subsampling is enabled or not by adding:

@pytest.mark.parametrize("subsample", [0.5, 1.0])

whenever appropriate, as is done in sklearn/ensemble/tests/test_gradient_boosting.py.
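
For instance, an existing shape/behaviour test could be parametrized along these lines (a sketch only; the test name is made up and subsample is the parameter proposed in this PR):

import pytest

from sklearn.datasets import make_regression
from sklearn.ensemble import HistGradientBoostingRegressor


@pytest.mark.parametrize("subsample", [0.5, 1.0])
def test_fit_predict_shapes(subsample):
    # Basic fitted behaviour should be the same whether or not stochastic
    # subsampling is enabled.
    X, y = make_regression(n_samples=500, n_features=5, random_state=0)
    model = HistGradientBoostingRegressor(subsample=subsample, random_state=0)
    model.fit(X, y)
    assert model.predict(X).shape == (X.shape[0],)
    assert model.n_iter_ <= model.max_iter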

# Do out of bag if required
if do_oob:
    self._bagging_subsample_rng.shuffle(sample_mask)
    sample_weight_train = sample_weight_train_original * sample_mask
Member


Using null sample_weight to mathematically simulate subsampling is simple to implement and does not require allocating any extra memory, but it does not produce the 2x computational speed-up expected from a typical subsample=0.5.

Have you considered row-wise fancy indexing of the training data instead?
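
A minimal sketch of the fancy-indexing alternative being suggested here; variable names are illustrative, not the PR's actual internals:

import numpy as np

rng = np.random.default_rng(0)
X_binned = rng.integers(0, 255, size=(10_000, 20), dtype=np.uint8)  # stand-in for binned data
gradients = rng.normal(size=10_000).astype(np.float32)

subsample = 0.5
n_samples = X_binned.shape[0]
indices = rng.choice(n_samples, size=int(subsample * n_samples), replace=False)

# Physical row-wise copy: costs extra memory for the subsample, but the tree
# grown at this iteration only sees half the rows, hence the ~2x speed-up.
X_sub = X_binned[indices]
grad_sub = gradients[indices]

# The zero-sample-weight approach instead keeps all rows and only nullifies
# their contribution, so the per-iteration cost stays roughly unchanged.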

@ogrisel
Member

ogrisel commented Jan 9, 2024

The comments from my first review pass were marked as resolved without any pushed code change to address them or any discussion.

@lorentzenchr
Member Author

lorentzenchr commented Jan 9, 2024

The comments from my first review pass were marked as resolved without any pushed code change to address them or any discussion.

Sorry, I forgot to push. Will do as soon as time permits.

@lorentzenchr
Member Author

One problem with the current way of just setting sample weights to zero is that the sample counts in the histograms are wrong, and min_samples_leaf does not work as it should.

Another solution is to pass a sample mask everywhere, which is quite a massive change that I'm hesitant to implement.
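
A tiny illustration of the first point (assumed behaviour for the sake of the argument, not the actual HGBT internals): with zeroed sample weights every row still reaches the histograms, so a count-based min_samples_leaf check sees the full row count even though only about half the samples effectively contribute.

import numpy as np

rng = np.random.default_rng(0)
n_samples = 1_000
sample_mask = rng.random(n_samples) < 0.5        # ~50% "subsample"
sample_weight = sample_mask.astype(np.float64)   # masked-out rows get weight 0

print("rows counted by the histograms:", n_samples)                   # 1000
print("samples that actually contribute:", int(sample_weight.sum()))  # ~500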

@ogrisel
Member

ogrisel commented Jan 25, 2024

Another solution is to pass a sample mask everywhere, which is quite a massive change that I'm hesitant to implement.

Why not just use fancy indexing to physically resample the training set? Is it to avoid the memory copy?

Development

Successfully merging this pull request may close these issues.

Add subsample and max_features parameters to HistGradientBoostingRegressor
2 participants