Description
Hi @keraJLi,
I'm experimenting with PPO+AMP algorithms by extending your PPO class. One thing I'm currently stuck on is managing two different sets of observations, one for PPO and one for AMP.
My idea was to return a pytree observation from my custom environment, i.e. a dict like:
obs = {
    "ppo": jax.Array,
    "amp": jax.Array,
}
The observation space would then be of type spaces.Dict. Of course, I would need to override a large part of the library, since the gymnax environment interface defines observations as plain jax.Arrays.
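A minimal sketch of the dict-observation idea, assuming a toy environment whose state is a single step counter. The names `make_obs`, `reset_env`, `step_env`, and `step`, as well as the observation shapes, are hypothetical and not taken from gymnax or this library. Since a dict of arrays is a pytree, the auto-reset selection can be applied leaf by leaf with `jax.tree_util.tree_map`:

```python
import jax
import jax.numpy as jnp


def make_obs(state):
    # Hypothetical observation function: two views of the same state.
    s = state.astype(jnp.float32)
    return {"ppo": jnp.full((3,), s), "amp": jnp.full((5,), s)}


def reset_env(key):
    # Toy reset: the state is just a step counter starting at zero.
    state = jnp.zeros((), dtype=jnp.int32)
    return make_obs(state), state


def step_env(key, state, action):
    # Toy dynamics: the episode ends after two steps.
    state = state + 1
    done = state >= 2
    reward = jnp.ones(())
    return make_obs(state), state, reward, done, {}


@jax.jit
def step(key, state, action):
    # Step and reset unconditionally, then pick per leaf depending on `done`.
    obs_st, state_st, reward, done, info = step_env(key, state, action)
    obs_re, state_re = reset_env(key)
    select = lambda re, st: jnp.where(done, re, st)
    obs = jax.tree_util.tree_map(select, obs_re, obs_st)
    state = jax.tree_util.tree_map(select, state_re, state_st)
    return obs, state, reward, done, info
```

With this pattern both entries of the dict are reset together, so the `"ppo"` and `"amp"` observations always describe the same state.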
Another thing I tried was to put the AMP observation (inappropriately) into the additional info dict returned from the env_step function. Unlike in the Gymnasium API, however, this dict is not updated during the auto-reset, so the two observation types diverge: after a reset, the PPO observation describes the new episode while the AMP observation still describes the old one.
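If the AMP observation stays in the info dict, one possible workaround is to recompute that entry from the state that is finally returned after the auto-reset selection. This is only a sketch under assumptions: `amp_obs_fn`, `refresh_info`, and the `"amp_obs"` key are hypothetical names, and it presumes the AMP observation can be computed from the environment state alone.

```python
import jax.numpy as jnp


def amp_obs_fn(state):
    # Hypothetical AMP observation derived from a scalar state.
    return jnp.stack([state, 2.0 * state])


def refresh_info(info, state_after_reset):
    # Overwrite the AMP entry using the state actually returned by the step,
    # so it matches the (possibly reset) main observation.
    return {**info, "amp_obs": amp_obs_fn(state_after_reset)}


# Example: the episode just ended, so the returned state is the reset state.
state_st = jnp.asarray(3.0)          # state produced by the raw step
state_re = jnp.asarray(0.0)          # state produced by the reset
done = jnp.asarray(True)

stale_info = {"amp_obs": amp_obs_fn(state_st)}   # not updated by the reset
state = jnp.where(done, state_re, state_st)
fresh_info = refresh_info(stale_info, state)     # consistent with `state`
```

The extra call costs one more evaluation of the AMP observation function per step, but it keeps the environment's observation type a plain array.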
How would you approach this use case?