feat: Accelerate discrete inv_cdf search #51
This is amazing!

I'll do some more tests in a larger example 👍 While doing that, I found a small detail regarding the shape handling which should be solved I think. I'll remove draft mode and provide some performance numbers once finished.

@riga I solved the shape "issue" and did some smaller polishing. If the CI is green, I'll merge this 👍 Next step is then to make use of it for the StatErrors, but this will be a different PR. Thanks again! I could confirm locally that this drastically improves performance for many bins 👍
This PR improves the search for inverse cdf values, mainly in `PoissonDiscrete.inv_cdf`.

Background
While implementing the Poisson-modelled stat errors, I noticed that the conversion of nuisance (parameter) values to the pdf space via `inv_cdf` took far too long (on the order of 30s for a single histogram with 53 bins). After digging deeper, I could identify two separate reasons for this:

- `cond_fn` for `jax.lax.while_loop` is meant to return a global stopping decision rather than one per element of the input arrays (thus the `jnp.any`; thinking about the XLA lowering, this is the only way to go). However, this causes the `body_fn` to be processed for each array element, even though this element might already be solved. In a setup with O(100) elements where the number of loop iterations is mainly driven by a single element (say, one element needs 1000k iterations while all others could be done in O(10)), this causes an overhead of 98%.
- `PoissonDiscrete.inv_cdf` performs a one-sided search starting at 0. However, for large

Changes

At a high level, I added a generalized `discrete_inv_cdf_search` helper that is now used by `PoissonDiscrete.inv_cdf`, plus doc strings and test cases.

- `jax.lax.while_loop` (which then performs only as many iterations as needed per element), and then reshaping the resulting k values to the desired output shape.

So far, performance looks good, but I'm planning more large-scale tests with the full stats error handling.
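To make the first Background point and the per-element idea under Changes more concrete, here is a minimal toy sketch. It is not the code from this PR: `batched_search`, `scalar_search` and `per_element_search` are hypothetical names, the search is a plain one-sided Poisson search starting at 0, and using `jax.lax.map` over the flattened inputs is only an assumption about one possible way to give each element its own `jax.lax.while_loop`.

```python
import jax
import jax.numpy as jnp


def batched_search(lam, u):
    """One shared while_loop over all elements (illustrates the overhead)."""

    def cond_fn(state):
        _, _, cdf, _ = state
        # while_loop needs a single scalar decision, hence the jnp.any
        return jnp.any(cdf < u)

    def body_fn(state):
        k, pmf, cdf, n_iter = state
        active = cdf < u
        # computed for every element, including those already solved
        pmf_next = pmf * lam / (k + 1)
        k = jnp.where(active, k + 1, k)
        pmf = jnp.where(active, pmf_next, pmf)
        cdf = jnp.where(active, cdf + pmf_next, cdf)
        return k, pmf, cdf, n_iter + 1

    p0 = jnp.exp(-lam)  # Poisson pmf at k = 0
    init = (jnp.zeros_like(lam), p0, p0, jnp.int32(0))
    k, _, _, n_iter = jax.lax.while_loop(cond_fn, body_fn, init)
    return k, n_iter


def scalar_search(lam, u):
    """Same search for a single (lam, u) pair."""

    def cond_fn(state):
        _, _, cdf = state
        return cdf < u

    def body_fn(state):
        k, pmf, cdf = state
        pmf = pmf * lam / (k + 1)
        return k + 1, pmf, cdf + pmf

    p0 = jnp.exp(-lam)
    k, _, _ = jax.lax.while_loop(cond_fn, body_fn, (jnp.zeros_like(lam), p0, p0))
    return k


def per_element_search(lam, u):
    """Flatten, run one while_loop per element, reshape back (assumed approach)."""
    shape = lam.shape
    flat = (lam.ravel(), u.ravel())
    k_flat = jax.lax.map(lambda xs: scalar_search(xs[0], xs[1]), flat)
    return k_flat.reshape(shape)


lam = jnp.array([1.0, 2.0, 50.0])
u = jnp.full_like(lam, 0.5)  # search for the median of each distribution

k_batched, n_iter = batched_search(lam, u)
k_per_element = per_element_search(lam, u)
```

In this toy setup the shared loop runs as many iterations as the slowest element needs, and every iteration touches all elements, while the per-element variant stops each search as soon as that element is solved. Both variants return the same k values.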