I'm trying to get the log-prob of an expit-transformed, logistic-distributed variable, but it always returns zero:
import jax.random
from oryx.core.ppl import random_variable, log_prob
from jax.scipy.special import expit
import oryx.distributions as tfd
def simple_sample(key):
    a = random_variable(tfd.Logistic(0., 1.))(key)
    return expit(a)
x = simple_sample(jax.random.PRNGKey(0))
print(x) # 0.41845703
print(log_prob(simple_sample)(0.5)) # 0.0
print(log_prob(simple_sample)(x)) # 0.0
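As a sanity check independent of oryx, the density of expit(a) with a ~ Logistic(0, 1) can be computed by hand with the change-of-variables formula, using only `jax.scipy`. The inverse of expit is logit, and d(logit)/dy = 1 / (y (1 - y)). Because expit is exactly the CDF of Logistic(0, 1), the two terms should cancel, which would make 0.0 the expected log-density on (0, 1). This is a minimal sketch; the helper name `expit_logistic_log_prob` is hypothetical and not part of oryx.

```python
import jax.numpy as jnp
from jax.scipy.special import logit
from jax.scipy.stats import logistic


def expit_logistic_log_prob(y):
    """Hypothetical helper: log-density of y = expit(a), a ~ Logistic(0, 1)."""
    a = logit(y)  # invert the transform
    # log |d logit(y) / dy| = -log(y) - log(1 - y)
    log_det = -jnp.log(y) - jnp.log1p(-y)
    return logistic.logpdf(a) + log_det


for y in [0.5, 0.41845703, 0.9]:
    # Each value should be approximately 0, up to float32 rounding.
    print(y, expit_logistic_log_prob(y))
```

If this matches, the zeros may reflect that expit(Logistic(0, 1)) is Uniform(0, 1) rather than a bug in `log_prob`.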
Versions:
jax-0.4.25
oryx-0.2.6
(Both an exp-transformed logistic variable and an expit-transformed normal variable seem to work, so there is something special about this combination.)
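For comparison, the two combinations that "work" can be checked by hand the same way with `jax.scipy`. This sketch assumes the same standard parameters (Normal(0, 1) and Logistic(0, 1)); the helper names are hypothetical. Unlike expit applied to a Logistic(0, 1) variable, neither transform here is the distribution's own CDF, so the Jacobian term does not cancel the base density and the log-density is generally nonzero.

```python
import jax.numpy as jnp
from jax.scipy.special import logit
from jax.scipy.stats import logistic, norm


def expit_normal_log_prob(y):
    """Hypothetical helper: log-density of y = expit(a), a ~ Normal(0, 1)."""
    a = logit(y)
    # log |d logit(y) / dy| = -log(y) - log(1 - y)
    return norm.logpdf(a) - jnp.log(y) - jnp.log1p(-y)


def exp_logistic_log_prob(y):
    """Hypothetical helper: log-density of y = exp(a), a ~ Logistic(0, 1)."""
    a = jnp.log(y)
    # log |d log(y) / dy| = -log(y)
    return logistic.logpdf(a) - jnp.log(y)


print(expit_normal_log_prob(0.5))  # nonzero
print(exp_logistic_log_prob(0.5))  # nonzero
```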