Hi everyone,
I'm finding oryx a really clean approach to implementing a PPL. However, I'm confused about conditional sampling.
Poisson process with exp random walk intensity
As an attempt to get into the structure of oryx I'm trying to sample from a probabilistic program which represents:
- A hierarchical random walk (that is, a random walk whose parameters are themselves random variables).
- A further `Exp` transform on the random walk, which gives the intensity of a Poisson process. The Poisson counts are observed.
- Inference done with NUTS from `blackjax`.
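For reference, this is the generative model the program below is meant to encode, reading `bijectors.Chain` right to left (`Cumsum` first, then `Scale`, then `Shift`). The symbols are introduced here for illustration only: $\mu_0$, $\sigma_0$ and $\tau$ stand for `init_prior_loc`, `init_prior_scale` and `step_scale_prior`, $s$ for `step_scale`, $\lambda_t$ for `intensity` and $y_t$ for `poi`.

$$
\begin{aligned}
\text{init} &\sim \mathcal{N}(\mu_0, \sigma_0), \qquad s \sim \operatorname{HalfNormal}(\tau), \\
z_t &\sim \mathcal{N}(0, 1), \qquad t = 1, \dots, n, \\
x_t &= \text{init} + s \sum_{i=1}^{t} z_i, \\
\lambda_t &= \exp(x_t), \\
y_t \mid \lambda_t &\sim \operatorname{Poisson}(\lambda_t) \quad \text{independently for each } t.
\end{aligned}
$$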
Code
Dependencies
```python
import jax
# jax.config.update("jax_enable_x64", True)
import oryx.core.ppl as ppl
import oryx.bijectors as bijectors
import oryx.distributions as tfd
import blackjax
import jax.numpy as jnp
import jax.random as random
from functools import partial
import matplotlib.pyplot as plt
```
Prob Program
Note that I've implemented the link as a `vmap` over `intensity`, representing the conditional independence of the observations.
```python
@partial(jax.jit, static_argnames=["n"])
def hierarchical_random_walk_dist(n, init, step_scale):
    # Chain applies its bijectors right to left: x = init + step_scale * cumsum(z)
    rw_transformation = bijectors.Chain([bijectors.Shift(init), bijectors.Scale(step_scale), bijectors.Cumsum()])
    return tfd.TransformedDistribution(tfd.MultivariateNormalDiag(jnp.zeros(n), jnp.ones(n)), rw_transformation)


def poisson_process(key, n, init_prior_loc, init_prior_scale, step_scale_prior):
    key_poi, key_intensity, key_init, key_step = random.split(key, 4)
    init = ppl.random_variable(tfd.Normal(init_prior_loc, init_prior_scale), name="init")(key_init)
    step_scale = ppl.random_variable(tfd.HalfNormal(step_scale_prior), name="step_scale")(key_step)
    # intensity = exp(random walk), so it is strictly positive
    intensity = ppl.random_variable(
        tfd.TransformedDistribution(hierarchical_random_walk_dist(n, init, step_scale), bijectors.Exp()),
        name="intensity")(key_intensity)
    # one Poisson observation per time step, conditionally independent given the intensity
    poi_keys = random.split(key_poi, n)
    poi = jax.vmap(lambda ky, x: ppl.random_variable(tfd.Poisson(x), name="poi")(ky))(poi_keys, intensity)
    return poi
```
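Because the counts are conditionally independent given `intensity`, the log-likelihood of the whole `poi` vector should equal the sum of per-observation Poisson log-pmfs. The sketch below uses `jax.scipy.stats.poisson` instead of oryx/TFP (an assumption made so it runs without those libraries; the function names are hypothetical) to check that the vmapped form, the vectorised form and the closed form $y \log\lambda - \lambda - \log y!$ agree.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln
from jax.scipy.stats import poisson


def loglik_vmapped(counts, intensity):
    # one Poisson log-pmf per observation, mapped over the time axis (mirrors the vmap in the program)
    return jax.vmap(poisson.logpmf)(counts, intensity).sum()


def loglik_closed_form(counts, intensity):
    # sum_t [y_t * log(lambda_t) - lambda_t - log(y_t!)]
    return jnp.sum(counts * jnp.log(intensity) - intensity - gammaln(counts + 1.0))


counts = jnp.array([0.0, 1.0, 3.0, 7.0])
intensity = jnp.array([0.5, 1.0, 2.0, 6.5])
print(loglik_vmapped(counts, intensity), loglik_closed_form(counts, intensity))
```

If the three agree, the `vmap` link itself is a faithful encoding of the conditional independence; whether oryx's `joint_log_prob` sums the batched `poi` terms the same way is a separate question.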
Sample some data from the model
```python
sampler = ppl.joint_sample(poisson_process)
key_rn = random.PRNGKey(1234)
true_params = sampler(key_rn, n=100, init_prior_loc=0.0, init_prior_scale=0.05, step_scale_prior=0.25)
plt.plot(true_params['intensity'])
plt.scatter(range(len(true_params['poi'])), true_params['poi'], color='red')
plt.xlabel('time')
plt.ylabel('Intensity')
plt.title('Intensity Random Variable')
plt.show()
```

This looks reasonable.
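To cross-check what `joint_sample` returns, the same generative process can be simulated directly with `jax.random`. The sketch below (`simulate_poisson_rw` is a hypothetical name, and it bypasses oryx/TFP entirely) returns a dict with the same keys as `true_params`, which makes it easy to compare shapes and scales with the oryx output.

```python
import jax
import jax.numpy as jnp


def simulate_poisson_rw(key, n, init_prior_loc, init_prior_scale, step_scale_prior):
    """Forward-simulate the exp-random-walk Poisson model with plain jax.random calls."""
    k_init, k_step, k_walk, k_obs = jax.random.split(key, 4)
    # init ~ Normal(init_prior_loc, init_prior_scale)
    init = init_prior_loc + init_prior_scale * jax.random.normal(k_init)
    # step_scale ~ HalfNormal(step_scale_prior): absolute value of a scaled standard normal
    step_scale = step_scale_prior * jnp.abs(jax.random.normal(k_step))
    # log-intensity random walk: init + step_scale * cumsum(z), z ~ N(0, I)
    log_intensity = init + step_scale * jnp.cumsum(jax.random.normal(k_walk, (n,)))
    intensity = jnp.exp(log_intensity)
    # conditionally independent Poisson counts given the intensity
    poi = jax.random.poisson(k_obs, intensity)
    return {"init": init, "step_scale": step_scale, "intensity": intensity, "poi": poi}


sim = simulate_poisson_rw(jax.random.PRNGKey(0), 100, 0.0, 0.05, 0.25)
print({k: v.shape for k, v in sim.items()})
```

The draws will not match `true_params` numerically (the key-splitting order differs from oryx's), but the shapes, the positivity of `intensity` and the rough scale of `poi` should.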
Inference
I split out the observed data from the rest of the parameters:
```python
true_data = true_params.pop('poi')
true_params
```

```
{'init': Array(0.01620068, dtype=float32),
 'intensity': Array([2.21524286e+00, 1.02979910e+00, 3.47778827e-01, 1.59668833e-01,
       1.43431634e-01, 1.89372957e-01, 1.48013055e-01, 8.88996869e-02,
       1.40973523e-01, 1.33820206e-01, 6.49942383e-02, 4.85427566e-02,
       1.31719252e-02, 1.61807686e-02, 1.93237811e-02, 7.30752666e-03,
       3.75548634e-03, 3.23897717e-03, 4.46824962e-03, 3.59713589e-03,
       3.48433293e-03, 4.54167370e-03, 8.35305359e-03, 7.45324651e-03,
       1.69865470e-02, 2.82925181e-03, 4.80814092e-03, 5.73506765e-03,
       1.29247606e-02, 2.23501474e-02, 2.60949116e-02, 2.78504174e-02,
       2.95239929e-02, 3.01535334e-02, 3.17600109e-02, 5.96645549e-02,
       1.58876508e-01, 4.15319920e-01, 2.83102959e-01, 3.94434422e-01,
       6.21528685e-01, 9.56910610e-01, 4.71480668e-01, 3.51778269e-01,
       3.12051624e-01, 3.87135684e-01, 4.41913813e-01, 8.34466696e-01,
       1.12293482e+00, 9.62291718e-01, 6.46639347e-01, 1.22468376e+00,
       1.33461881e+00, 9.76860523e-01, 1.60133433e+00, 4.31086159e+00,
       3.78359699e+00, 4.50091076e+00, 7.61642456e+00, 9.94997692e+00,
       1.83034401e+01, 1.86841354e+01, 1.93865471e+01, 4.10644569e+01,
       5.11959839e+01, 3.74023285e+01, 1.46664228e+01, 2.73789101e+01,
       5.09101982e+01, 2.61694183e+01, 2.90790100e+01, 1.15916996e+01,
       1.44228182e+01, 8.16150761e+00, 1.21826038e+01, 8.52718925e+00,
       8.82525539e+00, 1.47077036e+01, 1.31940975e+01, 8.21146393e+00,
       5.06118011e+00, 3.73972368e+00, 8.66150951e+00, 8.86765766e+00,
       1.82184372e+01, 2.03960953e+01, 1.50705147e+01, 3.58565903e+01,
       3.94253273e+01, 1.51045656e+01, 1.46200066e+01, 1.30218935e+01,
       8.60846615e+00, 6.86474085e+00, 5.52572966e+00, 6.91005898e+00,
       4.49717140e+00, 2.01037908e+00, 2.75382376e+00, 3.28753996e+00], dtype=float32),
 'step_scale': Array(0.5200044, dtype=float32)}
```
Then I use the usual `blackjax` approach to sampling (based on their example of using `oryx`):
```python
def logdensity_fn(params):
    theta = dict(params, poi=true_data)
    return ppl.joint_log_prob(poisson_process)(theta, n=100, init_prior_loc=0.0, init_prior_scale=0.05, step_scale_prior=0.25)


ll = logdensity_fn(true_params)
# Array(-50.80758, dtype=float32)
```
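One thing worth checking here: `blackjax.nuts` proposes positions anywhere in $\mathbb{R}^d$ and does not apply any constraining transform, while, as far as I can tell, `ppl.joint_log_prob` scores `intensity` and `step_scale` on their original positive-only scale. A proposal with a negative component then gets a non-finite log density, which may contribute to the silent failure (an assumption, not a confirmed diagnosis). A common workaround is to sample in unconstrained coordinates. The sketch below writes the same model's log density by hand with `jax.scipy.stats` (not oryx) in terms of hypothetical parameters `log_step_scale` and `log_intensity`. No Jacobian term is needed for `log_intensity` because the random walk is already defined on the log scale; the `+ log_step_scale` term is the Jacobian of `step_scale = exp(log_step_scale)`.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm, poisson


def unconstrained_logdensity(params, counts, init_prior_loc=0.0,
                             init_prior_scale=0.05, step_scale_prior=0.25):
    """Log joint density in unconstrained coordinates (hypothetical reparametrisation).

    params["init"]           : real, as in the original program
    params["log_step_scale"] : log(step_scale), real-valued
    params["log_intensity"]  : log(intensity), real-valued vector of length n
    """
    init = params["init"]
    log_step_scale = params["log_step_scale"]
    log_intensity = params["log_intensity"]
    step_scale = jnp.exp(log_step_scale)

    # init ~ Normal(init_prior_loc, init_prior_scale)
    lp = norm.logpdf(init, init_prior_loc, init_prior_scale)
    # step_scale ~ HalfNormal(step_scale_prior); "+ log_step_scale" is log|d step_scale / d log_step_scale|
    lp += jnp.log(2.0) + norm.logpdf(step_scale, 0.0, step_scale_prior) + log_step_scale
    # random walk on log-intensity: x_1 - init and x_t - x_{t-1} are N(0, step_scale) increments
    increments = jnp.diff(jnp.concatenate([jnp.atleast_1d(init), log_intensity]))
    lp += norm.logpdf(increments, 0.0, step_scale).sum()
    # conditionally independent Poisson observations with intensity = exp(log_intensity)
    lp += poisson.logpmf(counts, jnp.exp(log_intensity)).sum()
    return lp
```

With blackjax, this could stand in for `logdensity_fn` via `partial(unconstrained_logdensity, counts=true_data)`, starting from `{'init': true_params['init'], 'log_step_scale': jnp.log(true_params['step_scale']), 'log_intensity': jnp.log(true_params['intensity'])}` and exponentiating the `log_intensity` draws before plotting. Within oryx, the analogous change would be to make the un-exponentiated random walk the named random variable and apply `jnp.exp` afterwards, though `step_scale` would still need its own transform.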
```python
# Warmup
inference_key = jax.random.PRNGKey(12)
rng_key, warmup_key = jax.random.split(inference_key)
adapt = blackjax.window_adaptation(blackjax.nuts, logdensity_fn)
(last_state, parameters), _ = adapt.run(warmup_key, true_params, 1000)
kernel = blackjax.nuts(logdensity_fn, **parameters).step


# Sampling
def inference_loop(rng_key, kernel, initial_state, num_samples):
    def one_step(state, rng_key):
        state, info = kernel(rng_key, state)
        return state, (state, info)

    keys = jax.random.split(rng_key, num_samples)
    _, (states, infos) = jax.lax.scan(one_step, initial_state, keys)
    return states, infos


rng_key, sample_key = jax.random.split(rng_key)
states, infos = inference_loop(sample_key, kernel, last_state, 2000)
```
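Because the failure is silent, a few cheap checks on the draws can help separate a stuck or broken chain from a misreading of the sample structure. `states.position['intensity']` should have one row per draw (shape `(2000, 100)` here), so `[i]` picks the i-th posterior path. The helpers below are hypothetical (plain `jax.numpy`, not part of blackjax): `chain_summary` reports how often consecutive draws changed, whether any draw left the positive support, and whether everything is finite; `pointwise_band` gives a pointwise median and 90% band across draws.

```python
import jax.numpy as jnp


def chain_summary(draws):
    """Quick health checks for one parameter's draws, shape (num_samples, ...)."""
    draws = jnp.asarray(draws)
    flat = draws.reshape(draws.shape[0], -1)
    # a step "moved" if any component differs from the previous draw
    moved = jnp.any(flat[1:] != flat[:-1], axis=1)
    return {
        "fraction_moved": float(moved.mean()),
        "fraction_nonpositive": float(jnp.mean(flat <= 0)),
        "all_finite": bool(jnp.all(jnp.isfinite(flat))),
    }


def pointwise_band(draws, lower=0.05, upper=0.95):
    """Pointwise lower quantile, median and upper quantile over draws of shape (num_samples, n)."""
    qs = jnp.quantile(jnp.asarray(draws), jnp.array([lower, 0.5, upper]), axis=0)
    return qs[0], qs[1], qs[2]
```

For example, `chain_summary(states.position['step_scale'])` and `lo, med, hi = pointwise_band(states.position['intensity'])`, with `plt.fill_between(range(100), lo, hi, alpha=0.3)` drawn next to `true_params['intensity']`. In recent blackjax versions the NUTS `infos` should also carry `is_divergent` and `acceptance_rate`, whose means are worth printing.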
Issues
- The main issue is that this silently fails to sample from the posterior (or I'm not understanding the sample structure):
```python
plt.figure(figsize=(12, 6))
plt.plot(true_params['intensity'], label='True Intensity', color='blue')
for i in range(100):  # plot the first 100 sampled intensity paths
    plt.plot(states.position['intensity'][i], alpha=0.5)
plt.xlabel('Time')
plt.ylabel('Intensity')
plt.title('True Intensity vs Sampled Intensities')
plt.legend()
plt.show()
```

- Warnings about `float32` conversion, e.g.:
```
UserWarning: Explicitly requested dtype <class 'jax.numpy.float64'> requested in zeros is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more.
  minval = minval + np.zeros([1] * final_rank, dtype=dtype)
```
This suggests that the underlying Poisson distribution is struggling, but...
- Errors when enabling `float64`: if double precision is allowed, the model fails at `joint_sample` with the error:
```
TypeError: Tensor conversion requested dtype <class 'numpy.float32'> for array with dtype float64: Traced<ShapedArray(float64[100])>with
```
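The `UserWarning` text is emitted by JAX whenever some code (here apparently a TFP sampling path, judging by the `minval` line) explicitly asks for `float64` while `jax_enable_x64` is off; the request is silently truncated to `float32`. That alone does not indicate a problem with the Poisson distribution. The hypothetical helper below shows what a `float64` request becomes under the current setting, using only `jax.dtypes.canonicalize_dtype`. With x64 enabled both values become `float64`, so any component that still insists on `float32` could plausibly produce a mismatch like the `TypeError` above; that link is an assumption, not a diagnosis.

```python
import jax
import jax.numpy as jnp
import numpy as np


def effective_float_dtypes():
    """Report what a float64 request becomes under the current jax_enable_x64 setting."""
    requested = np.dtype(np.float64)
    # canonicalize_dtype maps float64 -> float32 when x64 is disabled
    effective = jax.dtypes.canonicalize_dtype(requested)
    # default dtype JAX uses for a freshly created float array
    default = jnp.zeros(3).dtype
    return requested, effective, default


print(effective_float_dtypes())
```

The JAX documentation recommends setting `jax_enable_x64` at startup, so if double precision is wanted, `jax.config.update("jax_enable_x64", True)` is safest as the first statement after `import jax`, as in the commented-out line of the imports.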
Steps forward
I don't have a huge amount of JAX/oryx experience, so it would be great if someone could point out whether I've made a glaring error, or whether there is some kind of issue with `joint_log_prob` in combination with `Poisson`, or with the way I've implemented the Poisson link.