@riga
Collaborator

@riga riga commented May 10, 2025

This PR improves the search for inverse cdf values, mainly in PoissonDiscrete.inv_cdf.

Background

While implementing the Poisson-modelled stat errors I noticed that the conversion of nuisance (parameter) values to the pdf space via inv_cdf took far too long (on the order of 30 s for a single histogram with 53 bins). After digging deeper I could identify two separate reasons for this:

  1. The cond_fn for jax.lax.while_loop is meant to return a global stopping decision rather than one per element of the input arrays (hence the jnp.any; given the XLA lowering, this is the only way to go). However, this causes the body_fn to be processed for every array element, even if that element is already solved. In a setup with O(100) elements where the number of loop iterations is mainly driven by a single element (say, one element needs 1000 iterations while all others could be done in O(10)), this causes an overhead of 98%.
  2. The current implementation of PoissonDiscrete.inv_cdf performs a one-sided search starting at 0. However, for large $\lambda$ values, and thus wide distributions, the algorithm could profit from a sensible starting value to avoid expensive cdf evaluations in regions far off the target value.
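The 98% figure in point 1 can be checked with a quick count. This is only an illustrative sketch in plain Python: the iteration counts (one element needing 1000 iterations, 99 elements needing 10 each) are assumed example numbers, not measurements from the PR.

```python
# Hypothetical per-element iteration counts: one slow element, 99 fast ones.
needed = [1000] + [10] * 99

# With a single global stopping decision, the loop runs until the slowest
# element is done and the body is evaluated for every element each time.
batched_evals = max(needed) * len(needed)

# If every element stopped as soon as it was solved, only these
# body evaluations would be required.
useful_evals = sum(needed)

# Fraction of body evaluations spent on already-solved elements.
overhead = 1 - useful_evals / batched_evals
print(f"{batched_evals=} {useful_evals=} overhead={overhead:.1%}")
```

Under these assumptions, roughly 98 out of every 100 body evaluations are spent on elements that are already solved.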

Changes

On a high level, I added a generalized discrete_inv_cdf_search helper that is now used by PoissonDiscrete.inv_cdf, plus docstrings and test cases.

  • Issue 1 is solved by flattening the input, passing it through a vmapped version of the search function based on jax.lax.while_loop (which then performs only as many iterations as needed per element), and then reshaping the resulting k values to the desired output shape.
  • The algorithm itself now accepts starting values, which in the case of Poisson distributions are taken from a normal approximation ($\lambda + \text{normal ppf}(x) \cdot \sqrt{\lambda}$). The larger $\lambda$, the more important a good starting value becomes, but fortunately the normal approximation also gets better 🙂 The search is also capable of stepping towards smaller values. With that, the number of necessary Poisson cdf evaluations should never exceed 2.
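The two changes above can be sketched roughly as follows. This is a minimal illustration, not the code merged in this PR: the names `_search_scalar` and `poisson_inv_cdf` are hypothetical, the search steps by one in either direction, and it assumes `x` lies strictly between 0 and 1. It uses `jax.scipy.stats.poisson.cdf` and `jax.scipy.stats.norm.ppf` for the distribution functions.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm, poisson


def _search_scalar(x, lam, k_start):
    """Smallest integer k >= 0 with poisson.cdf(k, lam) >= x (scalar inputs)."""

    def cdf(k):
        return poisson.cdf(k, lam)

    # Step down while k - 1 would still satisfy cdf(k - 1) >= x.
    k = jax.lax.while_loop(
        lambda k: (k > 0) & (cdf(k - 1) >= x),
        lambda k: k - 1,
        k_start,
    )
    # Step up while the current k does not yet reach x.
    return jax.lax.while_loop(lambda k: cdf(k) < x, lambda k: k + 1, k)


def poisson_inv_cdf(x, lam):
    """Hypothetical wrapper: flatten, search per element via vmap, reshape."""
    x, lam = jnp.broadcast_arrays(
        jnp.asarray(x, dtype=jnp.float32), jnp.asarray(lam, dtype=jnp.float32)
    )
    shape = x.shape
    x_flat, lam_flat = x.reshape(-1), lam.reshape(-1)

    # Starting value from the normal approximation, clipped to k >= 0.
    k_start = jnp.floor(lam_flat + norm.ppf(x_flat) * jnp.sqrt(lam_flat))
    k_start = jnp.maximum(k_start, 0.0).astype(jnp.int32)

    k = jax.vmap(_search_scalar)(x_flat, lam_flat, k_start)
    return k.reshape(shape)


# Example call with a 2x2 input; the output keeps the input shape.
x = jnp.array([[0.5, 0.9], [0.5, 0.1]])
lam = jnp.array([[3.0, 3.0], [100.0, 100.0]])
k = poisson_inv_cdf(x, lam)
print(k)
```

The search is written for a single element and only then mapped over the flattened input, so the shape handling stays in the wrapper and the stopping conditions are expressed per element.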

So far, performance looks good, but I'm planning more large-scale tests with the full stats error handling.

@riga riga requested a review from pfackeldey May 10, 2025 15:19
@riga riga added the enhancement New feature or request label May 10, 2025
@pfackeldey
Owner

This is amazing!
Do you have numbers on how much speedup this gives for low, medium, and high lambda? My hope initially was that in StatErrors we only use PoissonDiscrete when n_eff (lambda) is smallish.
Anyway, I'm excited to get this in asap!

@riga
Collaborator Author

riga commented May 10, 2025

Do you have numbers on how much speedup this gives for low, medium, and high lambda?

I'll do some more tests in a larger example 👍

While doing that, I found a small detail regarding the shape handling that I think should be solved first. I'll remove draft mode and provide some performance numbers once finished.

@riga riga marked this pull request as draft May 10, 2025 17:34
@pfackeldey
Owner

@riga I solved the shape "issue" and did some smaller polishing. If the CI is green, I'll merge this 👍

Next step is then to make use of it for the StatErrors, but this will be a different PR.

Thanks again! I could confirm locally that this drastically improves performance for many bins 👍

@pfackeldey pfackeldey marked this pull request as ready for review August 14, 2025 14:28
@pfackeldey pfackeldey merged commit e0497d5 into pfackeldey:main Aug 14, 2025
5 checks passed
