
Problem conditioning on vmap Poisson random variables #96

@SamuelBrand1


Hi everyone,

I'm finding oryx a really clean approach to implementing a PPL. However, I'm confused about conditional sampling.

Poisson process with exp random walk intensity

To get a feel for the structure of oryx, I'm trying to sample from a probabilistic program that represents:

  1. A hierarchical random walk (that is, a random walk whose parameters are themselves random variables).
  2. An Exp transform of the random walk, which gives the intensity of a Poisson process. The Poisson counts are observed.
  3. Inference with NUTS from blackjax.
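Written out as equations, my reading of the program below is the following. Here s is step_scale; μ0, σ0 and τ stand for init_prior_loc, init_prior_scale and step_scale_prior; and the random-walk line assumes TFP's convention that a Chain applies its bijectors right to left (Cumsum first, then Scale, then Shift):

```latex
\begin{aligned}
\text{init} &\sim \mathcal{N}(\mu_0, \sigma_0^2), \qquad s \sim \mathrm{HalfNormal}(\tau), \\
z_t &\sim \mathcal{N}(0, 1), \qquad t = 1, \dots, n, \\
x_t &= \text{init} + s \sum_{u=1}^{t} z_u, \\
\lambda_t &= \exp(x_t), \\
y_t \mid \lambda_t &\sim \mathrm{Poisson}(\lambda_t), \qquad \text{independently over } t.
\end{aligned}
```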

Code

Dependencies

import jax
# jax.config.update("jax_enable_x64", True)

import oryx.core.ppl as ppl
import oryx.bijectors as bijectors
import oryx.distributions as tfd
import blackjax


import jax.numpy as jnp
import jax.random as random

from functools import partial
import matplotlib.pyplot as plt

Probabilistic program
Note that I've implemented the observation link as a vmap over the intensity, to represent the conditional independence of the observations given the intensity.
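The conditional independence that the vmap is meant to encode implies that, given the intensity, the observation log-likelihood is a sum of per-time-step Poisson terms. A plain-NumPy sketch of that sum, using made-up rates and counts (not data from this issue) and a hand-written poisson_logpmf helper (hypothetical, not an oryx or TFP function); I have not checked how oryx's joint_log_prob reduces a vmapped site, so this only shows the quantity one would expect it to contribute:

```python
import math

import numpy as np


def poisson_logpmf(count, rate):
    # log P(Y = count) for Y ~ Poisson(rate); count is a non-negative integer
    return count * math.log(rate) - rate - math.lgamma(count + 1)


# Made-up intensities and counts for three time steps (not data from the issue)
rates = np.array([0.5, 2.0, 7.5])
counts = np.array([0, 3, 6])

# One term per time step, analogous to the scalar "poi" site evaluated under vmap
per_step = [poisson_logpmf(int(c), float(r)) for c, r in zip(counts, rates)]

# Conditional independence given the intensity: the joint log-likelihood is the sum
joint_loglik = sum(per_step)

# The same quantity computed over the whole vector at once
log_factorials = np.array([math.lgamma(int(c) + 1) for c in counts])
vectorised_loglik = float(np.sum(counts * np.log(rates) - rates - log_factorials))
```

Both forms agree, which is what a correct treatment of the vmapped "poi" site would need to reproduce.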

@partial(jax.jit, static_argnames=["n"])
def hierarchical_random_walk_dist(n, init, step_scale):
    rw_transformation = bijectors.Chain([bijectors.Shift(init), bijectors.Scale(step_scale), bijectors.Cumsum()])
    return tfd.TransformedDistribution(tfd.MultivariateNormalDiag(jnp.zeros(n), jnp.ones(n)), rw_transformation)

def poisson_process(key, n, init_prior_loc, init_prior_scale, step_scale_prior):
    key_poi, key_intensity, key_init, key_step = random.split(key, 4)
    init = ppl.random_variable(tfd.Normal(init_prior_loc, init_prior_scale), name = "init")(key_init)
    step_scale = ppl.random_variable(tfd.HalfNormal(step_scale_prior), name = "step_scale")(key_step)
    intensity = ppl.random_variable(tfd.TransformedDistribution(
        hierarchical_random_walk_dist(n, init, step_scale),
        bijectors.Exp()),
        name = "intensity")(key_intensity)
    poi_keys = random.split(key_poi, n)
    poi = jax.vmap(lambda ky, x: ppl.random_variable(tfd.Poisson(x), name = "poi")(ky))(poi_keys, intensity)
    return poi
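As I read the bijector chain in hierarchical_random_walk_dist (assuming TFP's convention that Chain applies its list right to left), the latent log-intensity is init + step_scale * cumsum(z) with z standard normal, and the outer Exp makes it a positive intensity. A NumPy sketch of that forward map and its inverse; random_walk_forward and random_walk_inverse are hypothetical helpers, not oryx APIs:

```python
import numpy as np


def random_walk_forward(z, init, step_scale):
    # Mirrors Chain([Shift(init), Scale(step_scale), Cumsum()]): TFP's Chain applies
    # its bijectors right to left, so Cumsum acts first, then Scale, then Shift.
    return init + step_scale * np.cumsum(z)


def random_walk_inverse(x, init, step_scale):
    # Undo Shift and Scale, then take first differences to undo Cumsum.
    y = (x - init) / step_scale
    return np.diff(y, prepend=0.0)


rng = np.random.default_rng(0)
z = rng.standard_normal(5)  # i.i.d. N(0, 1) increments, like MultivariateNormalDiag(zeros, ones)
init, step_scale = 0.1, 0.5

log_intensity = random_walk_forward(z, init, step_scale)
intensity = np.exp(log_intensity)  # the outer Exp bijector keeps the intensity positive

# Each step of the log-intensity moves by step_scale times one increment
steps = np.diff(log_intensity)
```

Note that the first value is init + step_scale * z[0] rather than init itself, since the Cumsum includes the first increment.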

Sample some data from the model

sampler = ppl.joint_sample(poisson_process)
key_rn = random.PRNGKey(1234)
true_params = sampler(key_rn, n=100, init_prior_loc=0.0, init_prior_scale=0.05, step_scale_prior=0.25)

plt.plot(true_params['intensity'])
plt.scatter(range(len(true_params['poi'])), true_params['poi'], color='red')
plt.xlabel('time')
plt.ylabel('Intensity')
plt.title('Intensity Random Variable')
plt.show()

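[Figure: Intensity Random Variable, showing the sampled intensity (line) and the simulated Poisson counts (red points) over time]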

This looks reasonable.

Inference

I split the observed data out from the rest of the parameters:

true_data = true_params.pop('poi')
true_params

{'init': Array(0.01620068, dtype=float32),
'intensity': Array([2.21524286e+00, 1.02979910e+00, 3.47778827e-01, 1.59668833e-01,
1.43431634e-01, 1.89372957e-01, 1.48013055e-01, 8.88996869e-02,
1.40973523e-01, 1.33820206e-01, 6.49942383e-02, 4.85427566e-02,
1.31719252e-02, 1.61807686e-02, 1.93237811e-02, 7.30752666e-03,
3.75548634e-03, 3.23897717e-03, 4.46824962e-03, 3.59713589e-03,
3.48433293e-03, 4.54167370e-03, 8.35305359e-03, 7.45324651e-03,
1.69865470e-02, 2.82925181e-03, 4.80814092e-03, 5.73506765e-03,
1.29247606e-02, 2.23501474e-02, 2.60949116e-02, 2.78504174e-02,
2.95239929e-02, 3.01535334e-02, 3.17600109e-02, 5.96645549e-02,
1.58876508e-01, 4.15319920e-01, 2.83102959e-01, 3.94434422e-01,
6.21528685e-01, 9.56910610e-01, 4.71480668e-01, 3.51778269e-01,
3.12051624e-01, 3.87135684e-01, 4.41913813e-01, 8.34466696e-01,
1.12293482e+00, 9.62291718e-01, 6.46639347e-01, 1.22468376e+00,
1.33461881e+00, 9.76860523e-01, 1.60133433e+00, 4.31086159e+00,
3.78359699e+00, 4.50091076e+00, 7.61642456e+00, 9.94997692e+00,
1.83034401e+01, 1.86841354e+01, 1.93865471e+01, 4.10644569e+01,
5.11959839e+01, 3.74023285e+01, 1.46664228e+01, 2.73789101e+01,
5.09101982e+01, 2.61694183e+01, 2.90790100e+01, 1.15916996e+01,
1.44228182e+01, 8.16150761e+00, 1.21826038e+01, 8.52718925e+00,
8.82525539e+00, 1.47077036e+01, 1.31940975e+01, 8.21146393e+00,
5.06118011e+00, 3.73972368e+00, 8.66150951e+00, 8.86765766e+00,
1.82184372e+01, 2.03960953e+01, 1.50705147e+01, 3.58565903e+01,
3.94253273e+01, 1.51045656e+01, 1.46200066e+01, 1.30218935e+01,
8.60846615e+00, 6.86474085e+00, 5.52572966e+00, 6.91005898e+00,
4.49717140e+00, 2.01037908e+00, 2.75382376e+00, 3.28753996e+00], dtype=float32),
'step_scale': Array(0.5200044, dtype=float32)}

Then I follow the usual blackjax approach to sampling, based on their example of using oryx.
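For reference, here is the joint log density that logdensity_fn hands to blackjax, derived by hand from the program. This assumes joint_log_prob sums the log-densities of the named random variables, with each transformed distribution contributing its change-of-variables term; s is step_scale, y is true_data, and μ0, σ0, τ stand for init_prior_loc, init_prior_scale and step_scale_prior:

```latex
\begin{aligned}
\log p(\text{init}, s, \lambda, y)
  = {} & \log \mathcal{N}(\text{init}; \mu_0, \sigma_0^2) + \log \mathrm{HalfNormal}(s; \tau) \\
  & + \sum_{t=1}^{n} \log \mathcal{N}(z_t; 0, 1) - n \log s - \sum_{t=1}^{n} \log \lambda_t \\
  & + \sum_{t=1}^{n} \bigl( y_t \log \lambda_t - \lambda_t - \log y_t! \bigr), \\
\text{where} \quad z_1 = {} & \frac{\log \lambda_1 - \text{init}}{s}, \qquad
  z_t = \frac{\log \lambda_t - \log \lambda_{t-1}}{s} \quad (t \ge 2).
\end{aligned}
```

The -n log s term is the Jacobian of the Scale bijector (Cumsum has unit Jacobian determinant) and the -Σ log λ_t term is the Jacobian of the Exp bijector. The expression is only finite for s > 0 and λ_t > 0, whereas NUTS, as used here, moves over all real values of (init, s, λ_1, ..., λ_n).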

def logdensity_fn(params):
    theta = dict(params, poi = true_data)
    return ppl.joint_log_prob(poisson_process)(theta, n=100, init_prior_loc=0.0, init_prior_scale=0.05, step_scale_prior=0.25)

ll = logdensity_fn(true_params)
# Array(-50.80758, dtype=float32)

# Warmup
inference_key = jax.random.PRNGKey(12)
rng_key, warmup_key = jax.random.split(inference_key)
adapt = blackjax.window_adaptation(blackjax.nuts, logdensity_fn)
(last_state, parameters), _ = adapt.run(warmup_key, true_params, 1000)
kernel = blackjax.nuts(logdensity_fn, **parameters).step

# Sampling

def inference_loop(rng_key, kernel, initial_state, num_samples):
    def one_step(state, rng_key):
        state, info = kernel(rng_key, state)
        return state, (state, info)

    keys = jax.random.split(rng_key, num_samples)
    _, (states, infos) = jax.lax.scan(one_step, initial_state, keys)

    return states, infos

rng_key, sample_key = jax.random.split(rng_key)
states, infos = inference_loop(sample_key, kernel, last_state, 2000)
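inference_loop relies on jax.lax.scan, whose reference semantics (as given in the JAX documentation) are the plain Python loop below. The point for reading the output is that per-step outputs are stacked along a new leading axis. A NumPy sketch with a hypothetical toy_step standing in for one NUTS step (it is not a sampler) and toy sizes instead of n=100 and 2000 samples:

```python
import numpy as np


def scan_reference(f, init, xs):
    # Reference semantics of jax.lax.scan as described in the JAX docs:
    # thread a carry through f and stack the per-step outputs along a new axis 0.
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, np.stack(ys)


n, num_samples = 4, 6  # toy sizes; the issue uses n=100 and 2000 samples


def toy_step(position, step_input):
    # Hypothetical stand-in for one NUTS step; it is NOT a sampler.
    new_position = position + step_input
    return new_position, new_position  # (carry, per-step output)


final_position, trace = scan_reference(toy_step, np.zeros(n), np.ones((num_samples, n)))
```

In the real loop the per-step output is the pytree (state, info), and lax.scan stacks every leaf the same way, so states.position['intensity'] should have shape (2000, 100) and states.position['intensity'][i] is the i-th post-warmup draw.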

Issues

  1. The main issue is that this silently fails to sample from the posterior (or I'm misunderstanding the structure of the samples):
plt.figure(figsize=(12, 6))
plt.plot(true_params['intensity'], label='True Intensity', color='blue')
for i in range(100):  # Plot the first 100 sampled intensities
    plt.plot(states.position['intensity'][i], alpha=0.5)
plt.xlabel('Time')
plt.ylabel('Intensity')
plt.title('True Intensity vs Sampled Intensities')
plt.legend()
plt.show()

[Figure: True Intensity vs Sampled Intensities]

  2. Warnings about truncation to float32, e.g.:

UserWarning: Explicitly requested dtype <class 'jax.numpy.float64'> requested in zeros is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more.
minval = minval + np.zeros([1] * final_rank, dtype=dtype)

This suggests that the underlying Poisson distribution is struggling, but...

  3. Errors when enabling float64: if double precision is allowed (jax_enable_x64), the model fails at joint_sample with the error:

TypeError: Tensor conversion requested dtype <class 'numpy.float32'> for array with dtype float64: Traced<ShapedArray(float64[100])>with

Steps forward

I don't have much JAX/oryx experience, so it would be great if someone could point out whether I've made a glaring error, or whether there is an issue with joint_log_prob in combination with Poisson, or with the way I've implemented the Poisson link.
