
Gradient of CARMA log probability wrt kernel parameters produces NaNs #228

@davecwright3


I originally found this issue when trying to use a CARMA kernel in NumPyro HMC: evaluating the gradient of the log probability with respect to the kernel parameters returns only NaNs.
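One way to check whether the NaNs come from the derivative rules rather than from the loss itself is to compare `jax.grad` with central finite differences. The helper below is a hypothetical sketch (`compare_grads`, `toy_loss` and `toy_params` are illustrative names, not part of tinygp or NumPyro). It flattens the parameter pytree with `jax.flatten_util.ravel_pytree` and is demonstrated on a smooth toy loss, not on the CARMA model.

```python
import jax

# Finite differences need float64 to be meaningful; enable it before creating arrays.
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def compare_grads(loss_fn, params, eps=1e-6):
    """Return (autodiff, central finite-difference) gradients as flat arrays.

    Both arrays follow the leaf order used by ravel_pytree.
    """
    flat, unravel = ravel_pytree(params)
    autodiff, _ = ravel_pytree(jax.grad(loss_fn)(params))

    def f(v):
        return loss_fn(unravel(v))

    # One central difference per flattened parameter:
    # (f(p + eps * e_i) - f(p - eps * e_i)) / (2 * eps)
    basis = jnp.eye(flat.size, dtype=flat.dtype)
    finite_diff = jnp.array(
        [(f(flat + eps * e) - f(flat - eps * e)) / (2 * eps) for e in basis]
    )
    return autodiff, finite_diff


# Hypothetical smooth toy loss, used only to exercise the helper.
toy_params = {"a": jnp.array([1.0, 2.0]), "b": jnp.array([0.5])}
toy_loss = lambda p: jnp.sum(p["a"] ** 2) + 3.0 * jnp.sum(p["b"])
print(compare_grads(toy_loss, toy_params))
```

Applied to the reproducer as `compare_grads(loss, params)`, a well-behaved finite-difference estimate next to a NaN autodiff gradient would suggest the problem lies in how the CARMA likelihood is differentiated rather than in the likelihood values.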

Things I've tried to resolve or narrow down the issue:

  • adding increasingly large diag values to the GP
  • many different values for the CARMA parameters
  • double vs. single precision
  • CARMA(p, q) models other than (1,0); they still produce NaNs
  • other quasiseparable kernels (these work as they should)
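Diagonal jitter and extra precision target ill-conditioning, so their failure hints (without proving) that the NaN may arise in a derivative rule. A well-known JAX pattern that produces NaN gradients from a perfectly finite loss is a `jnp.where` whose unselected branch has an infinite derivative: the zero cotangent is multiplied by `inf`, giving `nan`. Whether tinygp's CARMA code contains this pattern is an assumption not established by this report; `naive_sqrt` and `safe_sqrt` are hypothetical toy functions that only illustrate the mechanism and the usual "double where" workaround.

```python
import jax
import jax.numpy as jnp


def naive_sqrt(x):
    # Forward pass is fine at x == 0, but the backward pass multiplies the zero
    # cotangent from the unselected branch by d/dx sqrt(x) at 0, which is inf -> nan.
    return jnp.where(x > 0, jnp.sqrt(x), 0.0)


def safe_sqrt(x):
    # "Double where": give sqrt a harmless input wherever its output is discarded,
    # so its derivative stays finite and the masked gradient is exactly zero.
    safe_x = jnp.where(x > 0, x, 1.0)
    return jnp.where(x > 0, jnp.sqrt(safe_x), 0.0)


print(jax.grad(naive_sqrt)(0.0))  # nan
print(jax.grad(safe_sqrt)(0.0))   # 0.0
```

Neither added jitter nor float64 changes the toy result, since `0 * inf` is `nan` at any precision.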

Below is a minimal reproducible example that doesn't involve NumPyro.

```python
import jax

# Enable float64 before any arrays are created.
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
from tinygp import GaussianProcess, kernels


# CARMA(1,0), i.e. a damped random walk (DRW)
def build_gp_drw(params, x):
    kernel = kernels.quasisep.CARMA(params["alpha"], params["beta"])
    gp = GaussianProcess(kernel, x)
    return gp


# 50 evenly spaced points (the jnp.linspace default) with a noisy sine signal
x = jnp.linspace(1, 100)
y = jnp.sin(x) + 1e-2 * jax.random.normal(jax.random.key(5), x.shape)

params = {"alpha": jnp.array([0.01]), "beta": jnp.array([0.1])}
drw_gp = build_gp_drw(params, x)


@jax.jit
def loss(params):
    # Negative log marginal likelihood of y under the CARMA GP
    gp = build_gp_drw(params, x)
    return -gp.log_probability(y)
```

The loss evaluates to a finite value, but its gradient is all NaNs:

```python
>>> loss(params)
Array(194.17045899, dtype=float64)
>>> jax.grad(loss)(params)
{'alpha': Array([nan], dtype=float64), 'beta': Array([nan], dtype=float64)}
```
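To locate the operation that first produces a NaN, JAX's `jax_debug_nans` flag makes it raise a `FloatingPointError` naming the offending primitive (for jitted code it re-runs the function without `jit` to pinpoint it). The helper below is a hypothetical sketch, `first_nan_op` is an illustrative name, and it assumes the flag was off beforehand (its default) when restoring it. It is demonstrated on a toy function, not on tinygp.

```python
import jax
import jax.numpy as jnp


def first_nan_op(fn, *args):
    """Run jax.grad(fn)(*args) with NaN checking enabled.

    Returns the FloatingPointError message (which names the primitive that
    produced the first NaN), or None if no NaN was encountered.
    """
    jax.config.update("jax_debug_nans", True)
    try:
        jax.grad(fn)(*args)
        return None
    except FloatingPointError as err:
        return str(err)
    finally:
        # Assumption: the flag was False before this call.
        jax.config.update("jax_debug_nans", False)


# Toy function whose gradient at 0 is NaN (masked sqrt branch).
print(first_nan_op(lambda x: jnp.where(x > 0, jnp.sqrt(x), 0.0), 0.0))
```

Passing `loss` and `params` from the reproducer should, in principle, name the primitive inside the CARMA code path that first yields a NaN.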
