How to replicate:
import oryx
import jax.numpy as jnp
oryx.bijectors.IteratedSigmoidCentered().forward(jnp.array([0., 0., 0.]))
Error message:
TypeError: abstract_eval_fun() missing 1 required keyword-only argument: 'debug_info'
I was wondering whether I'm doing something wrong here or whether this is a bug in the bijector.
Thanks!