From b7f57b9bc0f7d58be5d6e108b100f330c1b46323 Mon Sep 17 00:00:00 2001 From: "Marcel R." Date: Sat, 10 May 2025 16:56:27 +0200 Subject: [PATCH 1/6] Add discrete_inv_cdf_search, use in PoissonDiscrete. --- src/evermore/pdf.py | 177 ++++++++++++++++++++++++++++++++++++++------ tests/test_pdf.py | 34 ++++++++- 2 files changed, 186 insertions(+), 25 deletions(-) diff --git a/src/evermore/pdf.py b/src/evermore/pdf.py index 938d88b..9251a31 100644 --- a/src/evermore/pdf.py +++ b/src/evermore/pdf.py @@ -8,7 +8,7 @@ import jax.numpy as jnp from jax._src.random import Shape from jax.scipy.special import digamma, gammaln, xlogy -from jaxtyping import Array, PRNGKeyArray +from jaxtyping import Array, ArrayLike, PRNGKeyArray from evermore.util import atleast_1d_float_array from evermore.visualization import SupportsTreescope @@ -22,6 +22,10 @@ ] +# alias for rounding literals +DiscreteRounding = tp.Literal["floor", "ceil", "closest"] + + def __dir__(): return __all__ @@ -82,41 +86,44 @@ class PoissonBase(PDF): class PoissonDiscrete(PoissonBase): """ - Poisson distribution with discrete support. Float inputs are floored to the nearest integer. - See https://root.cern.ch/doc/master/RooPoisson_8cxx_source.html#l00057 for reference. + Poisson distribution with discrete support. Float inputs are floored to the nearest integer, + similar to the behavior implemented in other libraries like SciPy or RooFit. 
""" def log_prob(self, x: Array, normalize: bool = True) -> Array: - x = jnp.floor(x) + # explicit rounding + k = jnp.floor(x) - unnormalized = jax.scipy.stats.poisson.logpmf(x, self.lamb) + # plain evaluation of the pmf + unnormalized = jax.scipy.stats.poisson.logpmf(k, self.lamb) if not normalize: return unnormalized - logpdf_max = jax.scipy.stats.poisson.logpmf(x, x) + # when normalizing, divide (subtract in log space) by maximum over k range + logpdf_max = jax.scipy.stats.poisson.logpmf(k, k) return unnormalized - logpdf_max def cdf(self, x: Array) -> Array: + # no need to round x to k, already done by cdf library function return jax.scipy.stats.poisson.cdf(x, self.lamb) - def inv_cdf(self, x: Array) -> Array: - # perform an iterative search - # see: https://num.pyro.ai/en/stable/tutorials/truncated_distributions.html?highlight=poisson%20inverse#5.3-Example:-Left-truncated-Poisson - def cond_fn(val): - _, cdf = val - return jnp.any(cdf < x) - - def body_fn(val): - n, cdf = val - n_new = jnp.where(cdf < x, n + 1, n) - return n_new, jax.scipy.stats.poisson.cdf(n_new, self.lamb) - - start_n = jnp.zeros_like(x, dtype=jnp.result_type(int)) - start_cdf = jnp.zeros_like(x, dtype=jnp.result_type(float)) - n, _ = jax.lax.while_loop(cond_fn, body_fn, (start_n, start_cdf)) - - # since we check for cdf < value, n will always refer to the next value - return jnp.clip(n - 1, min=0) + def inv_cdf(self, x: Array, rounding: DiscreteRounding = "floor") -> Array: + # define starting point for search from normal approximation + def start_fn(x): + return jnp.floor( + self.lamb + jax.scipy.stats.norm.ppf(x) * jnp.sqrt(self.lamb) + ) + + # define the cdf function + def cdf_fn(k): + return jax.scipy.stats.poisson.cdf(k, self.lamb) + + return discrete_inv_cdf_search( + x, + cdf_fn=cdf_fn, + start_fn=start_fn, + rounding=rounding, + ) def sample(self, key: PRNGKeyArray, shape: Shape | None = None) -> Array: # jax.random.poisson does not accept empty tuple shape @@ -136,10 +143,13 @@ def 
_log_prob(x, lamb): x = jnp.array(x, jnp.result_type(float)) return xlogy(x, lamb) - lamb - gammaln(x + 1) + # plain evaluation of the pdf unnormalized = _log_prob(x, lamb) if not normalize: return unnormalized + # when normalizing, divide (subtract in log space) by maximum over a range + # that depends on whether the mode is shifted args = (self.lamb, lamb) if shift_mode else (x, x) logpdf_max = _log_prob(*args) return unnormalized - logpdf_max @@ -155,3 +165,122 @@ def inv_cdf(self, x: Array) -> Array: def sample(self, key: PRNGKeyArray, shape: Shape | None = None) -> Array: msg = f"{self.__class__.__name__} does not support sampling, use PoissonDiscrete instead" raise Exception(msg) + + +def discrete_inv_cdf_search( + x: Array, + cdf_fn: tp.Callable[[ArrayLike], ArrayLike], + start_fn: tp.Callable[[ArrayLike], ArrayLike], + rounding: DiscreteRounding, +) -> Array: + """ + Computes the inverse CDF (percent point function) at integral values *x* for a discrete CDF + distribution *cdf* using an iterative search strategy. The search starts at values provided by + *start_fn* and progresses in integer steps towards the target values. + + .. code-block:: python + + # this example mimics the PoissonDiscrete.inv_cdf implementation + + import jax + import jax.numpy as jnp + import evermore as evm + + # parameter of the poisson distribution + lamb = 5.0 + + # the normal approximation is a good starting point + def start_fn(x): + return jnp.floor(lamb + jax.scipy.stats.norm.ppf(x) * jnp.sqrt(lamb)) + + + # define the cdf function + def cdf_fn(k): + return jax.scipy.stats.poisson.cdf(k, lamb) + + + k = discrete_inv_cdf_search(jnp.array(0.9), cdf_fn, start_fn, "floor") + # -> 7.0 + + Args: + x (Array): Integral values to compute the inverse CDF for. + cdf_fn (Callable): A callable representing the discrete CDF function. It is called with a + single argument and supposed to return the CDF value for that argument. 
+ start_fn (Callable): A callable that provides a starting point for the search. It is called + with a reshaped representation of *x*. + rounding (DiscreteRounding): One of "floor", "ceil" or "closest". + + Returns: + Array: The computed inverse CDF values in the same shape as *x*. + """ + # flatten input + x_shape = x.shape + x = jnp.reshape(x, (-1, 1)) + + # store masks for injecting exact values for known edge cases later one + # inject 0 for x == 0 + zero_mask = x == 0.0 + # inject inf for x == 1 + inf_mask = x == 1.0 + # inject nan for ~(0 < x < 1) or non-finite values + nan_mask = (x < 0.0) | (x > 1.0) | ~jnp.isfinite(x) + + # setup stopping condition and iteration body for the iterative search + # note: functions are defined for scalar values and then vmap'd, with results being reshaped + def cond_fn(val): + stop = val[-1] + return ~jnp.any(stop) + + def body_fn(val): + k, target_itg, prev_itg, stop = val + # compute the current integral + itg = cdf_fn(k) + # special case: itg is the exact solution + stop = jnp.logical_or(stop, itg == target_itg) + # if no previous integral is available or if we have not yet "cornered" the target value + # with the current and previous integrals, make a step in the right direction + make_step = ( + (prev_itg < 0) + | ((prev_itg < itg) & (itg < target_itg)) + | ((target_itg < itg) & (itg < prev_itg)) + ) + step = jnp.where(~stop & make_step, jnp.sign(target_itg - itg), 0) + k += step + # if target_itg is between the computed integrals we can now find the correct k + # note: k might be subject to a shift by +1 or -1, depending on the stride and rounding + k_found = ~stop & ~make_step + if rounding == "floor": + k_shift = jnp.where(k_found & (itg > target_itg), -1, 0) + elif rounding == "ceil": + k_shift = jnp.where(k_found & (prev_itg > target_itg), 1, 0) + else: # "closest" + k_shift = jnp.where( + k_found & (abs(itg - target_itg) > abs(prev_itg - target_itg)), + jnp.sign(prev_itg - itg), + 0, + ) + k += k_shift + # update 
the stop flag and end + stop = jnp.logical_or(stop, k_found) + return (k, target_itg, itg, stop) + + def search(start_k, target_itg, stop): + prev_itg = -jnp.ones_like(target_itg) + val = (start_k, target_itg, prev_itg, stop) + return jax.lax.while_loop(cond_fn, body_fn, val)[0] + + # vamp + vsearch = jax.vmap(search, in_axes=(0, 0, 0)) + + # define starting point and stop flag (eagerly skipping edge cases), then search + start_k = start_fn(x) + stop = zero_mask | inf_mask | nan_mask + k = vsearch(start_k, x, stop) + + # inject known values for edge cases + k = jnp.where(zero_mask, 0.0, k) + k = jnp.where(inf_mask, jnp.inf, k) + k = jnp.where(nan_mask, jnp.nan, k) + + # reshape to input shape + return jnp.reshape(k, x_shape) diff --git a/tests/test_pdf.py b/tests/test_pdf.py index e025478..a61d4a9 100644 --- a/tests/test_pdf.py +++ b/tests/test_pdf.py @@ -1,9 +1,16 @@ from __future__ import annotations +import jax import jax.numpy as jnp +import numpy as np import pytest -from evermore.pdf import Normal, PoissonContinuous, PoissonDiscrete +from evermore.pdf import ( + Normal, + PoissonContinuous, + PoissonDiscrete, + discrete_inv_cdf_search, +) def test_Normal(): @@ -22,3 +29,28 @@ def test_PoissonContinuous(): pdf = PoissonContinuous(lamb=jnp.array(10)) assert pdf.log_prob(jnp.array(5.0)) == pytest.approx(-1.5342636) + + +def test_discrete_inv_cdf_search(): + lamb = 5.0 + + def start_fn(x): + return jnp.floor(lamb + jax.scipy.stats.norm.ppf(x) * jnp.sqrt(lamb)) + + def cdf_fn(k): + return jax.scipy.stats.poisson.cdf(k, lamb) + + # test correct algorithmic behavior + assert discrete_inv_cdf_search(jnp.array([0.9]), cdf_fn, start_fn, "floor") == 7 + assert discrete_inv_cdf_search(jnp.array([0.9]), cdf_fn, start_fn, "ceil") == 8 + assert discrete_inv_cdf_search(jnp.array([0.9]), cdf_fn, start_fn, "closest") == 8 + + # test individual solutions in vmapped mode plus preservation + k = discrete_inv_cdf_search(jnp.array([0.9, 0.95, 0.99]), cdf_fn, start_fn, "floor") + 
np.testing.assert_allclose(k, jnp.array([7.0, 8.0, 10.0])) + k = discrete_inv_cdf_search(jnp.array([0.9, 0.95, 0.99]), cdf_fn, start_fn, "ceil") + np.testing.assert_allclose(k, jnp.array([8.0, 9.0, 11.0])) + k = discrete_inv_cdf_search( + jnp.array([0.9, 0.95, 0.99]), cdf_fn, start_fn, "closest" + ) + np.testing.assert_allclose(k, jnp.array([8.0, 8.0, 10.0])) From f9b6f619069b7e2e6cf8791998715dedfd1a26e4 Mon Sep 17 00:00:00 2001 From: "Marcel R." Date: Sat, 10 May 2025 17:19:14 +0200 Subject: [PATCH 2/6] Fix "vamp" typo in vmap comment. --- src/evermore/pdf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/evermore/pdf.py b/src/evermore/pdf.py index 9251a31..c13a99b 100644 --- a/src/evermore/pdf.py +++ b/src/evermore/pdf.py @@ -269,7 +269,7 @@ def search(start_k, target_itg, stop): val = (start_k, target_itg, prev_itg, stop) return jax.lax.while_loop(cond_fn, body_fn, val)[0] - # vamp + # vmap vsearch = jax.vmap(search, in_axes=(0, 0, 0)) # define starting point and stop flag (eagerly skipping edge cases), then search From ec147a4a53178d1cecaf78cfe7d46d0e14776993 Mon Sep 17 00:00:00 2001 From: "Marcel R." Date: Sat, 10 May 2025 17:29:47 +0200 Subject: [PATCH 3/6] Fix typos in comments of pdf.py and test_pdf.py.
--- src/evermore/pdf.py | 2 +- tests/test_pdf.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/evermore/pdf.py b/src/evermore/pdf.py index c13a99b..2c71aec 100644 --- a/src/evermore/pdf.py +++ b/src/evermore/pdf.py @@ -217,7 +217,7 @@ def cdf_fn(k): x_shape = x.shape x = jnp.reshape(x, (-1, 1)) - # store masks for injecting exact values for known edge cases later one + # store masks for injecting exact values for known edge cases later on # inject 0 for x == 0 zero_mask = x == 0.0 # inject inf for x == 1 diff --git a/tests/test_pdf.py b/tests/test_pdf.py index a61d4a9..c74d213 100644 --- a/tests/test_pdf.py +++ b/tests/test_pdf.py @@ -45,7 +45,7 @@ def cdf_fn(k): assert discrete_inv_cdf_search(jnp.array([0.9]), cdf_fn, start_fn, "ceil") == 8 assert discrete_inv_cdf_search(jnp.array([0.9]), cdf_fn, start_fn, "closest") == 8 - # test individual solutions in vmapped mode plus preservation + # test individual solutions in vmapped mode plus shape preservation k = discrete_inv_cdf_search(jnp.array([0.9, 0.95, 0.99]), cdf_fn, start_fn, "floor") np.testing.assert_allclose(k, jnp.array([7.0, 8.0, 10.0])) k = discrete_inv_cdf_search(jnp.array([0.9, 0.95, 0.99]), cdf_fn, start_fn, "ceil") From 0dd417456874a79a08a6c0d5f2a944a7e20107bd Mon Sep 17 00:00:00 2001 From: "Marcel R." Date: Sat, 10 May 2025 17:36:48 +0200 Subject: [PATCH 4/6] Prefer | over logical_or. 
--- src/evermore/pdf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/evermore/pdf.py b/src/evermore/pdf.py index 2c71aec..c133871 100644 --- a/src/evermore/pdf.py +++ b/src/evermore/pdf.py @@ -236,7 +236,7 @@ def body_fn(val): # compute the current integral itg = cdf_fn(k) # special case: itg is the exact solution - stop = jnp.logical_or(stop, itg == target_itg) + stop |= itg == target_itg # if no previous integral is available or if we have not yet "cornered" the target value # with the current and previous integrals, make a step in the right direction make_step = ( @@ -261,7 +261,7 @@ def body_fn(val): ) k += k_shift # update the stop flag and end - stop = jnp.logical_or(stop, k_found) + stop |= k_found return (k, target_itg, itg, stop) def search(start_k, target_itg, stop): From a4ab401472cab32731f87b062ee30c59b25ce0d4 Mon Sep 17 00:00:00 2001 From: "Marcel R." Date: Sat, 10 May 2025 19:30:54 +0200 Subject: [PATCH 5/6] Add rounding choice check. --- src/evermore/pdf.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/evermore/pdf.py b/src/evermore/pdf.py index c133871..0e40887 100644 --- a/src/evermore/pdf.py +++ b/src/evermore/pdf.py @@ -213,6 +213,12 @@ def cdf_fn(k): Returns: Array: The computed inverse CDF values in the same shape as *x*. 
""" + # check rounding + known_roundings = {"floor", "ceil", "closest"} + if rounding not in known_roundings: + msg = f"unknown rounding '{rounding}', expected one of {', '.join(known_roundings)}" + raise ValueError(msg) + # flatten input x_shape = x.shape x = jnp.reshape(x, (-1, 1)) From 0502617e86bd83b4db1af1b2e015d9f9c8affeb1 Mon Sep 17 00:00:00 2001 From: pfackeldey Date: Thu, 14 Aug 2025 10:26:13 -0400 Subject: [PATCH 6/6] avoid shape manipulation & polish --- src/evermore/pdf.py | 60 ++++++++++++++++++++++----------------------- 1 file changed, 29 insertions(+), 31 deletions(-) diff --git a/src/evermore/pdf.py b/src/evermore/pdf.py index 76c63ee..67abd6c 100644 --- a/src/evermore/pdf.py +++ b/src/evermore/pdf.py @@ -24,10 +24,6 @@ ] -# alias for rounding literals -DiscreteRounding = Literal["floor", "ceil", "closest"] - - def __dir__(): return __all__ @@ -171,6 +167,11 @@ def sample( raise Exception(msg) +# alias for rounding literals +DiscreteRounding = Literal["floor", "ceil", "closest"] +known_roundings = frozenset(DiscreteRounding.__args__) # type: ignore[attr-defined] + + def discrete_inv_cdf_search( x: V, cdf_fn: Callable[[V], V], @@ -207,7 +208,7 @@ def cdf_fn(k): # -> 7.0 Args: - x (Array): Integral values to compute the inverse CDF for. + x (V): Integral values to compute the inverse CDF for. cdf_fn (Callable): A callable representing the discrete CDF function. It is called with a single argument and supposed to return the CDF value for that argument. start_fn (Callable): A callable that provides a starting point for the search. It is called @@ -215,18 +216,8 @@ def cdf_fn(k): rounding (DiscreteRounding): One of "floor", "ceil" or "closest". Returns: - Array: The computed inverse CDF values in the same shape as *x*. + V: The computed inverse CDF values in the same shape as *x*. 
""" - # check rounding - known_roundings = {"floor", "ceil", "closest"} - if rounding not in known_roundings: - msg = f"unknown rounding '{rounding}', expected one of {', '.join(known_roundings)}" - raise ValueError(msg) - - # flatten input - x_shape = x.shape - x = jnp.reshape(x, (-1, 1)) - # store masks for injecting exact values for known edge cases later on # inject 0 for x == 0 zero_mask = x == 0.0 @@ -238,7 +229,7 @@ def cdf_fn(k): # setup stopping condition and iteration body for the iterative search # note: functions are defined for scalar values and then vmap'd, with results being reshaped def cond_fn(val): - stop = val[-1] + *_, stop = val return ~jnp.any(stop) def body_fn(val): @@ -259,16 +250,23 @@ def body_fn(val): # if target_itg is between the computed integrals we can now find the correct k # note: k might be subject to a shift by +1 or -1, depending on the stride and rounding k_found = ~stop & ~make_step - if rounding == "floor": - k_shift = jnp.where(k_found & (itg > target_itg), -1, 0) - elif rounding == "ceil": - k_shift = jnp.where(k_found & (prev_itg > target_itg), 1, 0) - else: # "closest" - k_shift = jnp.where( - k_found & (abs(itg - target_itg) > abs(prev_itg - target_itg)), - jnp.sign(prev_itg - itg), - 0, - ) + + # we're using python >=3.11 :) + match rounding: + case "floor": + k_shift = jnp.where(k_found & (itg > target_itg), -1, 0) + case "ceil": + k_shift = jnp.where(k_found & (prev_itg > target_itg), 1, 0) + case "closest": + k_shift = jnp.where( + k_found & (abs(itg - target_itg) > abs(prev_itg - target_itg)), + jnp.sign(prev_itg - itg), + 0, + ) + case _: + msg = f"unknown rounding '{rounding}' mode, expected one of {', '.join(known_roundings)}" # type: ignore[unreachable] + raise ValueError(msg) + k += k_shift # update the stop flag and end stop |= k_found @@ -279,8 +277,9 @@ def search(start_k, target_itg, stop): val = (start_k, target_itg, prev_itg, stop) return jax.lax.while_loop(cond_fn, body_fn, val)[0] - # vmap - vsearch = 
jax.vmap(search, in_axes=(0, 0, 0)) + # jnp.vectorize is auto-vmapping over all axes of its arguments, + # see: https://docs.jax.dev/en/latest/_autosummary/jax.numpy.vectorize.html#jax.numpy.vectorize + vsearch = jnp.vectorize(search) # define starting point and stop flag (eagerly skipping edge cases), then search start_k = start_fn(x) @@ -292,5 +291,4 @@ def search(start_k, target_itg, stop): k = jnp.where(inf_mask, jnp.inf, k) k = jnp.where(nan_mask, jnp.nan, k) - # reshape to input shape - return jnp.reshape(k, x_shape) + return k # noqa: RET504