Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Support for Numpyro Discrete Enumeration #195

@LSZ2001

Description

@LSZ2001

Hi @dfm,

I am building a Gaussian process model that contains discrete variables. Here is a simplified version:

import matplotlib.pyplot as plt
import jax
from jax import random, lax
import jax.numpy as jnp
import numpyro
from numpyro import handlers
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, HMC
import numpy as onp
import tinygp
from tinygp import GaussianProcess, kernels, transforms

def model(X_train=None, y_train=None, X_test=None):
    with numpyro.plate("dimensions", X_train.shape[1]) as d:
        ls = numpyro.sample("ls", dist.Gamma(3, 0.5))
    with numpyro.plate("dimensions_comb", 2**X_train.shape[1]-1) as m:
        pis = numpyro.sample("pi", dist.Bernoulli(0.5), infer={'enumerate': 'parallel'}) # Induces sparsity on the kernel structure
#     pis = jnp.array([0,1,1.0])
    
    kernel1 = transforms.Subspace(0, kernels.ExpSquared(ls[0]))
    kernel2 = transforms.Subspace(1, kernels.ExpSquared(ls[1]))
    kernel = pis[0]*kernel1 + pis[1]*kernel2 + pis[2]*kernel1*kernel2
    gp = GaussianProcess(kernel, X_train, diag=0.1)
    
    with jax.ensure_compile_time_eval():
        with numpyro.plate("data", X_train.shape[0]):
            numpyro.sample("gp", gp.numpyro_dist(), obs=y_train)
    if y_train is not None:
        with numpyro.plate("data", X_test.shape[0]):
            numpyro.deterministic("f", gp.condition(y_train, X_test).gp.loc)
    
# Data creation
onp.random.seed(0)
N = 100
X_train = (jnp.array(onp.random.uniform(size=(N,2)))-0.5)*10
y_train = 0.5*X_train[:,0] + 1*X_train[:,1] + 2*X_train[:,0]*X_train[:,1]
numpyro.render_model(model, model_args=(X_train,y_train,X_train), render_distributions=True, render_params=True)

# GP fitting
rng_key = random.PRNGKey(0)
num_chains = 1
hmc = MCMC(NUTS(model), num_samples=1000, num_warmup=1000, num_chains=num_chains)
hmc.run(rng_key, X_train, y_train, X_train)
hmc.print_summary(exclude_deterministic=True)  
hmc_samples = hmc.get_samples()
plt.errorbar(y_train, jnp.mean(hmc_samples["f"],axis=0), jnp.std(hmc_samples["f"],axis=0), color="k", fmt = '.',)

Running the code gives me the following error, as enumeration adds array dimensions to the kernel hyperparameters.

ValueError: The value of a constant kernel must be a scalar

Is there a workaround that would allow TinyGP kernels to be compatible with Numpyro enumeration? Thank you very much!

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions