Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 157 additions & 24 deletions src/evermore/pdf.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import annotations

import abc
from typing import Generic, Protocol, runtime_checkable
from collections.abc import Callable
from typing import Generic, Literal, Protocol, runtime_checkable

import equinox as eqx
import jax
Expand Down Expand Up @@ -80,45 +81,47 @@ class PoissonBase(AbstractPDF[V]):

class PoissonDiscrete(PoissonBase[V]):
"""
Poisson distribution with discrete support. Float inputs are rounded down to the next lower
integer, similar to the behavior implemented in other libraries like SciPy or RooFit.
See https://root.cern.ch/doc/master/RooPoisson_8cxx_source.html#l00057 for reference.
"""

def log_prob(
    self,
    x: V,
    normalize: bool = True,
) -> V:
    """
    Evaluate the logarithm of the Poisson pmf at the floored value of *x*.

    Args:
        x (V): Value(s) to evaluate; floats are floored before the pmf is evaluated.
        normalize (bool): If True (default), subtract the log pmf evaluated with the rate set
            to the floored value itself, so that the result is 0 when the floored *x* equals
            ``self.lamb``. If False, return the plain log pmf.

    Returns:
        V: The (optionally normalized) log probability.
    """
    # the pmf is only defined on integers, so floats are floored exactly once here
    k = jnp.floor(x)

    # plain evaluation of the pmf
    unnormalized = jax.scipy.stats.poisson.logpmf(k, self.lamb)
    if not normalize:
        return unnormalized

    # when normalizing, divide (subtract in log space) by the pmf at k evaluated with rate k,
    # which is the largest value the pmf at k can take for any rate
    logpdf_max = jax.scipy.stats.poisson.logpmf(k, k)
    return unnormalized - logpdf_max

def cdf(self, x: V) -> V:
    """Return the Poisson cumulative distribution function at *x* for rate ``self.lamb``."""
    poisson = jax.scipy.stats.poisson
    # x is passed through unrounded: the library routine floors it internally
    return poisson.cdf(x, self.lamb)

def inv_cdf(self, x: V, rounding: DiscreteRounding = "floor") -> V:
    """
    Inverse CDF (percent point function) of the discrete Poisson distribution.

    The result is obtained with the integer-step search in ``discrete_inv_cdf_search``, seeded
    with the normal approximation of the Poisson distribution.

    Args:
        x (V): CDF value(s) to invert.
        rounding (DiscreteRounding): One of "floor", "ceil" or "closest"; forwarded unchanged
            to ``discrete_inv_cdf_search``. Defaults to "floor".

    Returns:
        V: The value(s) returned by ``discrete_inv_cdf_search`` for *x*.
    """

    # define starting point for search from normal approximation, clamped to the support
    # (k >= 0): for small x the approximation turns negative, where the Poisson CDF is
    # constantly zero and offers the search no slope to follow
    def start_fn(x: V) -> V:
        approx = self.lamb + jax.scipy.stats.norm.ppf(x) * jnp.sqrt(self.lamb)
        return jnp.maximum(jnp.floor(approx), 0.0)

    # define the cdf function
    def cdf_fn(k: V) -> V:
        return jax.scipy.stats.poisson.cdf(k, self.lamb)

    return discrete_inv_cdf_search(
        x,
        cdf_fn=cdf_fn,
        start_fn=start_fn,
        rounding=rounding,
    )

def sample(self, key: PRNGKeyArray, shape: Shape) -> Float[Array, ...]:
    """Draw Poisson-distributed samples of the given *shape* with rate ``self.lamb``."""
    rate = self.lamb
    return jax.random.poisson(key, rate, shape=shape)
Expand All @@ -138,10 +141,13 @@ def _log_prob(x, lamb):
x = jnp.array(x, jnp.result_type(float))
return xlogy(x, lamb) - lamb - gammaln(x + 1)

# plain evaluation of the pdf
unnormalized = _log_prob(x, lamb)
if not normalize:
return unnormalized

# when normalizing, divide (subtract in log space) by maximum over a range
# that depends on whether the mode is shifted
args = (self.lamb, lamb) if shift_mode else (x, x)
logpdf_max = _log_prob(*args)
return unnormalized - logpdf_max
Expand All @@ -159,3 +165,130 @@ def sample(
) -> Float[Array, ...]:
msg = f"{self.__class__.__name__} does not support sampling, use PoissonDiscrete instead"
raise Exception(msg)


# alias for rounding literals: the closed set of rounding modes accepted by
# discrete_inv_cdf_search and by inv_cdf methods that forward to it
DiscreteRounding = Literal["floor", "ceil", "closest"]
# runtime set of the literal's members, used to list the valid modes in error messages
known_roundings = frozenset(DiscreteRounding.__args__) # type: ignore[attr-defined]


def discrete_inv_cdf_search(
x: V,
cdf_fn: Callable[[V], V],
start_fn: Callable[[V], V],
rounding: DiscreteRounding,
) -> V:
"""
Computes the inverse CDF (percent point function) at integral values *x* for a discrete CDF
distribution *cdf* using an iterative search strategy. The search starts at values provided by
*start_fn* and progresses in integer steps towards the target values.

.. code-block:: python

# this example mimics the PoissonDiscrete.inv_cdf implementation

import jax
import jax.numpy as jnp
import evermore as evm

# parameter of the poisson distribution
lamb = 5.0

# the normal approximation is a good starting point
def start_fn(x):
return jnp.floor(lamb + jax.scipy.stats.norm.ppf(x) * jnp.sqrt(lamb))


# define the cdf function
def cdf_fn(k):
return jax.scipy.stats.poisson.cdf(k, lamb)


k = discrete_inv_cdf_search(jnp.array(0.9), cdf_fn, start_fn, "floor")
# -> 7.0

Args:
x (V): Integral values to compute the inverse CDF for.
cdf_fn (Callable): A callable representing the discrete CDF function. It is called with a
single argument and supposed to return the CDF value for that argument.
start_fn (Callable): A callable that provides a starting point for the search. It is called
with a reshaped representation of *x*.
rounding (DiscreteRounding): One of "floor", "ceil" or "closest".

Returns:
V: The computed inverse CDF values in the same shape as *x*.
"""
# store masks for injecting exact values for known edge cases later on
# inject 0 for x == 0
zero_mask = x == 0.0
# inject inf for x == 1
inf_mask = x == 1.0
# inject nan for ~(0 < x < 1) or non-finite values
nan_mask = (x < 0.0) | (x > 1.0) | ~jnp.isfinite(x)

# setup stopping condition and iteration body for the iterative search
# note: functions are defined for scalar values and then vmap'd, with results being reshaped
def cond_fn(val):
*_, stop = val
return ~jnp.any(stop)

def body_fn(val):
k, target_itg, prev_itg, stop = val
# compute the current integral
itg = cdf_fn(k)
# special case: itg is the exact solution
stop |= itg == target_itg
# if no previous integral is available or if we have not yet "cornered" the target value
# with the current and previous integrals, make a step in the right direction
make_step = (
(prev_itg < 0)
| ((prev_itg < itg) & (itg < target_itg))
| ((target_itg < itg) & (itg < prev_itg))
)
step = jnp.where(~stop & make_step, jnp.sign(target_itg - itg), 0)
k += step
# if target_itg is between the computed integrals we can now find the correct k
# note: k might be subject to a shift by +1 or -1, depending on the stride and rounding
k_found = ~stop & ~make_step

# we're using python >=3.11 :)
match rounding:
case "floor":
k_shift = jnp.where(k_found & (itg > target_itg), -1, 0)
case "ceil":
k_shift = jnp.where(k_found & (prev_itg > target_itg), 1, 0)
case "closest":
k_shift = jnp.where(
k_found & (abs(itg - target_itg) > abs(prev_itg - target_itg)),
jnp.sign(prev_itg - itg),
0,
)
case _:
msg = f"unknown rounding '{rounding}' mode, expected one of {', '.join(known_roundings)}" # type: ignore[unreachable]
raise ValueError(msg)

k += k_shift
# update the stop flag and end
stop |= k_found
return (k, target_itg, itg, stop)

def search(start_k, target_itg, stop):
prev_itg = -jnp.ones_like(target_itg)
val = (start_k, target_itg, prev_itg, stop)
return jax.lax.while_loop(cond_fn, body_fn, val)[0]

# jnp.vectorize is auto-vmapping over all axes of its arguments,
# see: https://docs.jax.dev/en/latest/_autosummary/jax.numpy.vectorize.html#jax.numpy.vectorize
vsearch = jnp.vectorize(search)

# define starting point and stop flag (eagerly skipping edge cases), then search
start_k = start_fn(x)
stop = zero_mask | inf_mask | nan_mask
k = vsearch(start_k, x, stop)

# inject known values for edge cases
k = jnp.where(zero_mask, 0.0, k)
k = jnp.where(inf_mask, jnp.inf, k)
k = jnp.where(nan_mask, jnp.nan, k)

return k # noqa: RET504
34 changes: 33 additions & 1 deletion tests/test_pdf.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
from __future__ import annotations

import jax
import jax.numpy as jnp
import numpy as np
import pytest
from jaxtyping import Float, Scalar

from evermore.pdf import Normal, PoissonContinuous, PoissonDiscrete
from evermore.pdf import (
Normal,
PoissonContinuous,
PoissonDiscrete,
discrete_inv_cdf_search,
)


def test_Normal():
Expand All @@ -23,3 +30,28 @@ def test_PoissonContinuous():
pdf: PoissonContinuous[Float[Scalar, ""]] = PoissonContinuous(lamb=jnp.array(10))

assert pdf.log_prob(jnp.array(5.0)) == pytest.approx(-1.5342636)


def test_discrete_inv_cdf_search():
    """Check all rounding modes of the discrete inverse-CDF search against a Poisson(5) CDF."""
    lamb = 5.0

    def cdf_fn(k):
        return jax.scipy.stats.poisson.cdf(k, lamb)

    def start_fn(x):
        # normal approximation of the Poisson quantile as search seed
        offset = jax.scipy.stats.norm.ppf(x) * jnp.sqrt(lamb)
        return jnp.floor(lamb + offset)

    # test correct algorithmic behavior
    single = jnp.array([0.9])
    for mode, expected in (("floor", 7), ("ceil", 8), ("closest", 8)):
        assert discrete_inv_cdf_search(single, cdf_fn, start_fn, mode) == expected

    # test individual solutions in vmapped mode plus shape preservation
    batch = jnp.array([0.9, 0.95, 0.99])
    expectations = {
        "floor": [7.0, 8.0, 10.0],
        "ceil": [8.0, 9.0, 11.0],
        "closest": [8.0, 8.0, 10.0],
    }
    for mode, expected in expectations.items():
        k = discrete_inv_cdf_search(batch, cdf_fn, start_fn, mode)
        np.testing.assert_allclose(k, jnp.array(expected))