
ENH reuse parent histograms as one of the child's histogram #27865


Conversation

lorentzenchr
Member

Reference Issues/PRs

None

What does this implement/fix? Explain your changes.

This PR reuses the parent node's histogram in the histogram subtraction trick in HGBT (as does LightGBM). This avoids a new memory allocation for one of the child nodes and also makes the histogram subtraction a tiny bit faster. (However, the histogram subtraction is only a fraction of the overall fit time, so there is basically no effect on fit time.)
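
For readers unfamiliar with the trick, here is a minimal NumPy sketch of the idea, not the actual Cython code in this PR. It assumes a single feature, a gradient-only histogram, and a hypothetical helper `build_histogram`; the split into 300 and 700 samples is arbitrary.

```python
import numpy as np


def build_histogram(binned_feature, gradients, n_bins):
    """Sum gradients per bin for one feature (toy, single-feature histogram)."""
    return np.bincount(binned_feature, weights=gradients, minlength=n_bins)


rng = np.random.default_rng(0)
n_samples, n_bins = 1_000, 16
binned = rng.integers(0, n_bins, size=n_samples)
gradients = rng.normal(size=n_samples)

# Hypothetical split: the left child receives the smaller share of samples.
left_mask = np.arange(n_samples) < 300

parent_hist = build_histogram(binned, gradients, n_bins)
left_hist = build_histogram(binned[left_mask], gradients[left_mask], n_bins)

# Reference result, computed the slow way for comparison only.
right_expected = build_histogram(binned[~left_mask], gradients[~left_mask], n_bins)

# Subtraction trick with buffer reuse: write (parent - left) into the parent's
# own memory, so the larger child needs no new allocation.
right_hist = np.subtract(parent_hist, left_hist, out=parent_hist)

assert right_hist is parent_hist
assert np.allclose(right_hist, right_expected)
```

The price of the reuse is that `parent_hist` no longer holds the parent's values after the subtraction, so anything that still needs them has to run first.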

Any other comments?

github-actions bot commented Nov 28, 2023

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: 6ac7413. Link to the linter CI: here

Member

@thomasjpfan thomasjpfan left a comment

Thanks for the PR! How much runtime improvement do you get with this PR?

@@ -618,9 +618,8 @@ def split_next(self):
if child.is_leaf:
    del child.histograms

# Release memory used by histograms as they are no longer needed for
# internal nodes once children histograms have been computed.
del node.histograms
Member

This was included in e325f16 because of a memory issue. To be safe, can you rerun the benchmark in #18334 (comment) to make sure there are no regressions?

@lorentzenchr
Member Author

lorentzenchr commented Nov 28, 2023

Main
Run 1 | Run 2 | Run 3 in same ipython instance

Fit 100 trees in 7.730 s, (12700 total leaves) | 7.693 s           | 7.636 s
Time spent computing histograms: 3.950s        | 3.928s            | 3.818s
Time spent finding best splits:  2.086s        | 2.061s            | 2.164s
Time spent applying splits:      0.233s        | 0.236s            | 0.246s
Time spent predicting:           0.008s        | 0.008s            | 0.008s
281.06, 104.00 MB                              | 295.38, 117.08 MB | 325.07, 141.31 MB

This PR

Fit 100 trees in 7.917 s, (12700 total leaves) | 7.764 s
Time spent computing histograms: 3.963s        | 3.872s
Time spent finding best splits:  2.123s        | 2.064s
Time spent applying splits:      0.245s        | 0.239s
Time spent predicting:           0.008s        | 0.007s
336.08, 159.11 MB                              | 349.80, 171.91 MB

Wow, this seems to bring back the cyclic memory references. So the current state of the PR is worse than main. But note the large variation, even on the main branch.

Taken from #18334 (comment)

from sklearn.datasets import make_classification
# No longer required since scikit-learn 1.0; kept as in the original benchmark.
from sklearn.experimental import enable_hist_gradient_boosting  # noqa: F401
from sklearn.ensemble import HistGradientBoostingClassifier
from memory_profiler import memory_usage

X, y = make_classification(n_classes=2,
                           n_samples=10_000,
                           n_features=400,
                           random_state=0)

hgb = HistGradientBoostingClassifier(
    max_iter=100,
    max_leaf_nodes=127,
    learning_rate=.1,
    random_state=0,
    verbose=1,
)

mems = memory_usage((hgb.fit, (X, y)))
print(f"{max(mems):.2f}, {max(mems) - min(mems):.2f} MB")

@lorentzenchr
Member Author

lorentzenchr commented Nov 28, 2023

I fixed the cyclic memory references again in d242a6d. Now, I get:
Run 1 | 2 | 3 in same ipython instance

Fit 100 trees in 6.721 s, (12700 total leaves) | 6.909 s            | 6.917 s
Time spent computing histograms: 3.092s        | 3.213s             | 3.230s
Time spent finding best splits:  2.041s        | 2.121s             | 2.147s
Time spent applying splits:      0.259s        | 0.234s             | 0.236s
Time spent predicting:           0.007s        | 0.008s             | 0.008s
286.75, 110.00 MB                              | 315.38, 137.90 MB  | 323.48, 145.31 MB

Results show a large variation. Runtime seems improved by roughly 10%, but memory usage seems, on average, a bit worse than main.

@lorentzenchr
Member Author

Interesting: If only the lines

mems = memory_usage((hgb.fit, (X, y)))
print(f"{max(mems):.2f}, {max(mems) - min(mems):.2f} MB")

are run again in the same ipython instance, I get (run 1 executes the full script, runs 2 to 4 execute only these 2 lines):

Main

Run | total time [s] | time histograms [s] | max memory [MB] | max - min memory [MB]
1   | 7.788          | 3.911               | 295.03          | 118.02
2   | 7.702          | 3.906               | 271.74          | 100.22
3   | 7.753          | 3.940               | 283.8           | 109.86
4   | 7.707          | 3.904               | 276.03          | 103.54

PR

Run | total time [s] | time histograms [s] | max memory [MB] | max - min memory [MB]
1   | 7.263          | 3.426               | 276.61          | 99.85
2   | 7.067          | 3.358               | 286.92          | 115.39
3   | 7.234          | 3.416               | 278.11          | 105.56
4   | 6.997          | 3.337               | 291.90          | 119.30

Conclusion: This PR is a clear improvement. It would be nice to better understand some gc behavior.
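
One possible way to see how the cyclic references mentioned earlier in this thread interact with Python's garbage collector is the sketch below. The `Node` class is hypothetical and is not scikit-learn's `TreeNode`; a `bytearray` stands in for a histogram buffer, and the behavior shown relies on CPython's reference counting.

```python
import gc
import weakref


class Node:
    """Hypothetical tree node, not scikit-learn's TreeNode."""

    def __init__(self, parent=None):
        self.parent = parent
        self.child = None
        self.histograms = bytearray(1_000_000)  # stand-in for a histogram buffer


def make_pair(break_cycle):
    """Build a parent/child pair and return only a weak reference to the parent."""
    parent = Node()
    parent.child = Node(parent=parent)  # parent <-> child reference cycle
    if break_cycle:
        parent.child.parent = None  # drop the back-reference
    return weakref.ref(parent)


gc.disable()  # keep the cyclic collector from running on its own
try:
    cyclic = make_pair(break_cycle=False)
    acyclic = make_pair(break_cycle=True)
    # Without a cycle, reference counting frees the nodes immediately;
    # with a cycle, they stay alive until the cyclic collector runs.
    alive_before_gc = (cyclic() is not None, acyclic() is not None)
    gc.collect()
    alive_after_gc = (cyclic() is not None, acyclic() is not None)
finally:
    gc.enable()

print(alive_before_gc, alive_after_gc)
```

Since the moment at which the collector runs is not fixed relative to the fit, buffers held in cycles may be released at different points from run to run, which might account for part of the variation in peak memory reported above.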

Member

@thomasjpfan thomasjpfan left a comment

This adds a little bit of complexity, but it still looks manageable. LGTM!

@thomasjpfan thomasjpfan added the "Waiting for Second Reviewer" label Nov 30, 2023
Member

@jjerphan jjerphan left a comment

LGTM.

I think most (all?) implementations of malloc (called by np.empty) now reuse blocks of memory between allocations and deallocations, so the overhead might only come from numpy's wrappers.

As a dilettante, I just have one comment regarding the potential extension of some context that might now qualify for nogil.
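
The remark above about allocation overhead could be checked with a small micro-benchmark along these lines. This is only a sketch: the function names are hypothetical, 400 features matches the benchmark in this thread while 256 bins is an assumption, and no particular outcome is implied since timings depend on the platform and allocator.

```python
import timeit

import numpy as np

n_features, n_bins = 400, 256  # 256 bins is an assumption
parent = np.ones((n_features, n_bins))
small = np.full((n_features, n_bins), 0.25)


def subtract_new_buffer():
    out = np.empty_like(parent)  # fresh allocation on every call
    np.subtract(parent, small, out=out)
    return out


scratch = parent.copy()


def subtract_reused_buffer():
    np.subtract(parent, small, out=scratch)  # write into an existing buffer
    return scratch


t_new = timeit.timeit(subtract_new_buffer, number=200)
t_reused = timeit.timeit(subtract_reused_buffer, number=200)
print(f"new buffer: {t_new:.4f} s, reused buffer: {t_reused:.4f} s")
```

Both functions produce the same values, so any difference in the printed timings reflects allocation and wrapper overhead rather than the subtraction itself.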

@lorentzenchr lorentzenchr added this to the 1.4 milestone Dec 3, 2023
@jjerphan jjerphan merged commit 7b9f794 into scikit-learn:main Dec 3, 2023
@lorentzenchr lorentzenchr deleted the hgbt_reuse_parent_hist_in_subtract_histogram branch December 4, 2023 09:52
Labels
cython, module:ensemble, Waiting for Second Reviewer