I originally found this issue when trying to use a CARMA kernel in NumPyro HMC: I receive only NaNs when evaluating the gradient of the log probability with respect to the kernel parameters.
Things I've tried to resolve or narrow down the issue:
- adding increasingly larger `diag` values to the GP
- many different values for the CARMA parameters
- double vs. single precision
- different CARMA(p, q) models other than (1, 0); they still produce NaNs
- other quasiseparable kernels (these work as they should)
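For reference, the "add `diag`" idea is the standard jitter trick. Here is a hedged, pure-NumPy stand-in (not tinygp code) illustrating what it does: a small constant on the covariance diagonal regularizes an ill-conditioned kernel matrix so its Cholesky factorization stays finite. In tinygp this corresponds, as far as I understand, to the `diag` argument of `GaussianProcess`.

```python
import numpy as np

# Illustration of the "add diag" (jitter) mitigation on a covariance matrix.
# An RBF kernel over closely spaced points is numerically near-singular;
# adding a small multiple of the identity bounds the smallest eigenvalue
# away from zero so the Cholesky factorization is well behaved.
x = np.linspace(0.0, 1.0, 50)
K = np.exp(-0.5 * (x[:, None] - x[None, :]) ** 2)  # RBF kernel matrix
K_jitter = K + 1e-6 * np.eye(len(x))               # jitter on the diagonal

L = np.linalg.cholesky(K_jitter)                   # finite lower-triangular factor
```

In my case, no amount of jitter changed the NaN gradients, which suggests the problem is not simple ill-conditioning of the covariance.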
Below is a minimal reproducible example that doesn't involve NumPyro.
```python
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

from tinygp import GaussianProcess, kernels


# CARMA(1, 0)
def build_gp_drw(params, x):
    kernel = kernels.quasisep.CARMA(params["alpha"], params["beta"])
    gp = GaussianProcess(kernel, x)
    return gp


x = jnp.linspace(1, 100)
y = jnp.sin(x) + 1e-2 * jax.random.normal(jax.random.key(5), x.shape)
params = {"alpha": jnp.array([0.01]), "beta": jnp.array([0.1])}
drw_gp = build_gp_drw(params, x)


@jax.jit
def loss(params):
    gp = build_gp_drw(params, x)
    return -gp.log_probability(y)
```

The loss itself evaluates to a finite value, but its gradient is NaN:

```python
>>> loss(params)
Array(194.17045899, dtype=float64)
>>> jax.grad(loss)(params)
{'alpha': Array([nan], dtype=float64), 'beta': Array([nan], dtype=float64)}
```
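For context on how a finite forward value can still produce a NaN gradient, here is a minimal pure-JAX sketch. This is not tinygp's actual code, just a well-known JAX failure mode with the same symptoms: `jax.grad` differentiates both branches of `jnp.where`, so a branch whose derivative is infinite at the evaluation point poisons the gradient even when that branch is not selected.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def f(x):
    # Forward pass is fine: the sqrt branch is only taken for x > 0.
    return jnp.where(x > 0.0, jnp.sqrt(x), 0.0)


value = f(0.0)            # 0.0 -- the unselected sqrt branch does no harm here
grad = jax.grad(f)(0.0)   # nan -- d/dx sqrt(x) is inf at 0, and 0 * inf = nan
```

Setting `jax.config.update("jax_debug_nans", True)` before running the repro above can help localize which primitive first produces the NaN inside the kernel's gradient.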