-
Notifications
You must be signed in to change notification settings - Fork 32
Open
Description
Hi @dfm,
I am building a Gaussian process model that contains discrete variables. Here is a simplified version:
import matplotlib.pyplot as plt
import jax
from jax import random, lax
import jax.numpy as jnp
import numpyro
from numpyro import handlers
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, HMC
import numpy as onp
import tinygp
from tinygp import GaussianProcess, kernels, transforms
def model(X_train=None, y_train=None, X_test=None):
with numpyro.plate("dimensions", X_train.shape[1]) as d:
ls = numpyro.sample("ls", dist.Gamma(3, 0.5))
with numpyro.plate("dimensions_comb", 2**X_train.shape[1]-1) as m:
pis = numpyro.sample("pi", dist.Bernoulli(0.5), infer={'enumerate': 'parallel'}) # Induces sparsity on the kernel structure
# pis = jnp.array([0,1,1.0])
kernel1 = transforms.Subspace(0, kernels.ExpSquared(ls[0]))
kernel2 = transforms.Subspace(1, kernels.ExpSquared(ls[1]))
kernel = pis[0]*kernel1 + pis[1]*kernel2 + pis[2]*kernel1*kernel2
gp = GaussianProcess(kernel, X_train, diag=0.1)
with jax.ensure_compile_time_eval():
with numpyro.plate("data", X_train.shape[0]):
numpyro.sample("gp", gp.numpyro_dist(), obs=y_train)
if y_train is not None:
with numpyro.plate("data", X_test.shape[0]):
numpyro.deterministic("f", gp.condition(y_train, X_test).gp.loc)
# Data creation
onp.random.seed(0)
N = 100
X_train = (jnp.array(onp.random.uniform(size=(N,2)))-0.5)*10
y_train = 0.5*X_train[:,0] + 1*X_train[:,1] + 2*X_train[:,0]*X_train[:,1]
numpyro.render_model(model, model_args=(X_train,y_train,X_train), render_distributions=True, render_params=True)
# GP fitting
rng_key = random.PRNGKey(0)
num_chains = 1
hmc = MCMC(NUTS(model), num_samples=1000, num_warmup=1000, num_chains=num_chains)
hmc.run(rng_key, X_train, y_train, X_train)
hmc.print_summary(exclude_deterministic=True)
hmc_samples = hmc.get_samples()
plt.errorbar(y_train, jnp.mean(hmc_samples["f"],axis=0), jnp.std(hmc_samples["f"],axis=0), color="k", fmt = '.',)
Running the code gives me the following error, as enumeration adds array dimensions to the kernel hyperparameters.
ValueError: The value of a constant kernel must be a scalar
Is there a workaround that would allow TinyGP kernels to be compatible with Numpyro enumeration? Thank you very much!
Metadata
Metadata
Assignees
Labels
No labels