Open
Labels: question (User queries)
Description
How can I specify the exact shape of each leaf of a pytree, say a dict? And, going further, the tree structure of the pytree itself (which keys it has and how they nest)?
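For example, the function `f` takes a dict as input:

```python
import jax.numpy as jnp
from beartype import beartype
from jaxtyping import Array, Float, jaxtyped

# Desired (hypothetical) annotation: `Dict{...}` is not valid Python,
# it only expresses the intent of a per-key shape spec.
@jaxtyped(typechecker=beartype)
def f(state: Dict{'x': Float[Array, 'b 10'], 'y': Float[Array, 'b 1']}):
    ...

# Example valid input
valid_state = {
    'x': jnp.ones((3, 10)),  # b=3
    'y': jnp.zeros((3, 1)),  # b=3, consistent
}

# Example invalid input (wrong shape for 'x')
invalid_state = {
    'x': jnp.ones((3, 99)),  # shape is not 'b 10'
    'y': jnp.zeros((3, 1)),
}

f(valid_state)
try:
    f(invalid_state)
except Exception as e:
    print(f"\nError with invalid_state:\n{e}")
```

(The above snippet does not work: `Dict{...}` is not valid Python syntax.)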
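I'd like the type checker to check every leaf's shape as well as the tree structure.
`PyTree[Float[Array, 'b ...']]` is helpful but not fine-grained enough: it applies one shape pattern to every leaf, so it cannot express that `'x'` must be `b 10` while `'y'` must be `b 1`.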
I think this feature would be quite natural to have; for example, in RL, a JAX environment's `step` function takes a complex state. There may be a way or a workaround, but I couldn't find one. Sorry if I've missed something obvious.
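One possible stopgap is to validate the state by hand against a spec that mirrors its structure. The sketch below is not part of jaxtyping or beartype: `check_tree`, `ShapeError`, and the spec format (a nested dict/list mirroring the state, with jaxtyping-style dimension strings such as `'b 10'` at the leaves) are hypothetical names chosen for illustration. It checks that the keys match exactly, that every leaf has its declared shape, and that named dimensions like `b` bind to the same size across all leaves. It supports only plain names, integers, and a single `...`, a small subset of jaxtyping's dimension syntax, and it ignores dtypes. It reads only `.shape`, so it should work with `jnp` arrays; the demo uses `SimpleNamespace` stand-ins to stay dependency-free.

```python
from types import SimpleNamespace


class ShapeError(TypeError):
    """Raised when a leaf's shape or the tree structure does not match the spec."""


def _match_dims(shape, spec, bindings, path):
    """Match one leaf shape against a dimension string such as 'b 10' or 'b ...'."""
    tokens = spec.split()
    if "..." in tokens:
        # '...' absorbs any number of dims; only one '...' is supported in this sketch.
        i = tokens.index("...")
        head, tail = tokens[:i], tokens[i + 1:]
        if len(shape) < len(head) + len(tail):
            raise ShapeError(f"{path}: shape {shape} too short for {spec!r}")
        pairs = list(zip(head, shape[:len(head)]))
        pairs += list(zip(tail, shape[len(shape) - len(tail):]))
    else:
        if len(shape) != len(tokens):
            raise ShapeError(f"{path}: shape {shape} does not match {spec!r}")
        pairs = list(zip(tokens, shape))
    for tok, size in pairs:
        if tok.isdigit():
            # Fixed-size dimension, e.g. the 10 in 'b 10'.
            if int(tok) != size:
                raise ShapeError(f"{path}: expected size {tok} for {spec!r}, got shape {shape}")
        else:
            # Named dimension: the first leaf binds it, later leaves must agree.
            bound = bindings.setdefault(tok, size)
            if bound != size:
                raise ShapeError(f"{path}: dim {tok!r} is {size}, but was bound to {bound} earlier")


def check_tree(tree, spec, bindings=None, path="state"):
    """Check tree structure and per-leaf shapes; return the named-dim bindings."""
    if bindings is None:
        bindings = {}
    if isinstance(spec, dict):
        # Structure check: same container type and exactly the same keys.
        if not isinstance(tree, dict) or set(tree) != set(spec):
            got = sorted(tree, key=str) if isinstance(tree, dict) else type(tree).__name__
            raise ShapeError(f"{path}: expected keys {sorted(spec, key=str)}, got {got}")
        for key, sub_spec in spec.items():
            check_tree(tree[key], sub_spec, bindings, f"{path}[{key!r}]")
    elif isinstance(spec, (list, tuple)):
        if not isinstance(tree, (list, tuple)) or len(tree) != len(spec):
            raise ShapeError(f"{path}: expected a sequence of length {len(spec)}")
        for i, (sub_tree, sub_spec) in enumerate(zip(tree, spec)):
            check_tree(sub_tree, sub_spec, bindings, f"{path}[{i}]")
    elif isinstance(spec, str):
        # Leaf: anything with a .shape attribute (e.g. a jnp array).
        shape = getattr(tree, "shape", None)
        if shape is None:
            raise ShapeError(f"{path}: expected an array leaf, got {type(tree).__name__}")
        _match_dims(tuple(shape), spec, bindings, path)
    else:
        raise TypeError(f"unsupported spec at {path}: {spec!r}")
    return bindings


# Demo with stand-ins for jnp arrays (only .shape is read).
def fake(*shape):
    return SimpleNamespace(shape=shape)


spec = {"x": "b 10", "y": "b 1"}
print(check_tree({"x": fake(3, 10), "y": fake(3, 1)}, spec))  # {'b': 3}
try:
    check_tree({"x": fake(3, 99), "y": fake(3, 1)}, spec)
except ShapeError as e:
    print(e)  # state['x']: expected size 10 for 'b 10', got shape (3, 99)

# One 'b ...' pattern for every leaf (roughly what PyTree[Float[Array, 'b ...']]
# expresses) lets the bad 'x' leaf through.
print(check_tree({"x": fake(3, 99), "y": fake(3, 1)}, {"x": "b ...", "y": "b ..."}))  # {'b': 3}
```

In practice one would call something like `check_tree(state, spec)` at the top of `f` or an environment's `step`. Since the check runs in Python on static shapes, it should also work under `jax.jit` tracing, where shapes are known at trace time, but that is an assumption worth verifying.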