Hello,
First, thanks for developing such an amazing package.
I am a newbie to oryx and have been playing around with its functionality.
Perhaps naively, I have been attempting to evaluate log_probs of blocked programs, as below:
```python
from jax.random import split
from oryx.core import ppl
import tensorflow_probability.substrates.jax.distributions as tfd

def latent_normal(key):
    z_key, x_key = split(key)
    z = ppl.random_variable(tfd.Normal(0, 1), name="z")(z_key)
    return ppl.random_variable(tfd.Normal(z, 1e-1), name="x")(x_key)

blocked = ppl.block(latent_normal, names=["z"])
ppl.joint_log_prob(blocked)({"x": 10})
```
However, it raises:

```
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[2], line 12
      8     return ppl.random_variable(tfd.Normal(z,1e-1),name="x")(x_key)
     11 blocked=ppl.block(latent_normal,names=["z"])
---> 12 ppl.joint_log_prob(blocked)({"x":10})

File /opt/conda/lib/python3.11/site-packages/oryx/core/interpreters/log_prob.py:71, in log_prob.<locals>.wrapped(sample, *args, **kwargs)
     67 flat_incells = [
     68   InverseAndILDJ.unknown(trace_util.get_shaped_aval(dummy_seed))
     69 ] + [InverseAndILDJ.new(val) for val in flat_inargs]
     70 flat_outcells = [InverseAndILDJ.new(a) for a in flat_outargs]
---> 71 return log_prob_jaxpr(jaxpr.jaxpr, constcells, flat_incells, flat_outcells)

File /opt/conda/lib/python3.11/site-packages/oryx/core/interpreters/log_prob.py:128, in log_prob_jaxpr(jaxpr, constcells, flat_incells, flat_outcells)
    118 _, final_log_prob = propagate.propagate(
    119     InverseAndILDJ,
    120     log_prob_rules,
   (...)
    125     reducer=reducer,
    126     initial_state=0.)
    127 if final_log_prob is failed_log_prob:
--> 128   raise ValueError('Cannot compute log_prob of function.')
    129 return final_log_prob

ValueError: Cannot compute log_prob of function.
```
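For reference, a sketch of the value this call would have to produce, assuming `block` is meant to hide `z` so that only `x` is scored: the result would then be the marginal log-density log p(x) = log ∫ p(z) p(x | z) dz, which requires integrating `z` out rather than reading it off a trace. For this linear-Gaussian model the integral has a closed form, x ~ Normal(0, sqrt(1² + 0.1²)). The plain-Python snippet below uses no oryx or TFP calls; `normal_log_pdf` and `marginal_log_prob_x` are hypothetical helpers written only for illustration.

```python
import math

def normal_log_pdf(x, loc, scale):
    # Log-density of a univariate Normal(loc, scale) evaluated at x.
    return -0.5 * ((x - loc) / scale) ** 2 - math.log(scale) - 0.5 * math.log(2 * math.pi)

def marginal_log_prob_x(x, prior_scale=1.0, noise_scale=1e-1):
    # Hypothetical helper for the model above:
    #   z ~ Normal(0, prior_scale), x | z ~ Normal(z, noise_scale)
    # Integrating z out gives x ~ Normal(0, sqrt(prior_scale**2 + noise_scale**2)).
    return normal_log_pdf(x, 0.0, math.sqrt(prior_scale**2 + noise_scale**2))

print(marginal_log_prob_x(10.0))
```

A closed form only exists because the model is linear-Gaussian; in general the integral over `z` has no analytic solution, which is why scoring `x` alone is harder than scoring the joint `{"z": ..., "x": ...}`.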
Am I missing something?
Thanks again
Very Best
Giovanni