Labels: bug (Something isn't working)
Description
Removing the memory buffer from the JAX PPO agent (CartPole) causes errors because `self._current_next_states` is `None`: the attribute is never updated when `self.memory` is `None`.
Proposed fix: set the next-states attribute unconditionally, before the `self.memory is not None` check, so it is updated whether or not a memory is attached.
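To illustrate the failure mode outside skrl, here is a minimal, hypothetical sketch. `BuggyAgent` and `FixedAgent` are made-up stand-ins, not skrl classes, and a plain list plays the role of the memory; only the attribute names `memory` and `_current_next_states` mirror the snippets in this issue.

```python
from typing import Any, Optional


class BuggyAgent:
    """Hypothetical stand-in: next states are only tracked when a memory exists."""

    def __init__(self, memory: Optional[list] = None) -> None:
        self.memory = memory
        self._current_next_states: Any = None

    def record_transition(self, next_states: Any) -> None:
        if self.memory is not None:
            # only reached when a memory is attached, so the attribute stays None otherwise
            self._current_next_states = next_states
            self.memory.append(next_states)


class FixedAgent(BuggyAgent):
    """Same hypothetical stand-in, but next states are tracked regardless of the memory."""

    def record_transition(self, next_states: Any) -> None:
        # always keep the latest next states, even when no memory is attached
        self._current_next_states = next_states
        if self.memory is not None:
            self.memory.append(next_states)


if __name__ == "__main__":
    for cls in (BuggyAgent, FixedAgent):
        agent = cls(memory=None)
        agent.record_transition(next_states=[0.1, 0.2, 0.3, 0.4])
        print(cls.__name__, agent._current_next_states)
    # BuggyAgent None
    # FixedAgent [0.1, 0.2, 0.3, 0.4]
```

With `memory=None`, only the fixed variant ends up with a usable `_current_next_states`; any later code that reads that attribute would see `None` in the buggy variant.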
Previously:
```python
def record_transition(...):
    super().record_transition(
        states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps
    )

    if self.memory is not None:
        self._current_next_states = next_states

        # reward shaping
        if self._rewards_shaper is not None:
            rewards = self._rewards_shaper(rewards, timestep, timesteps)
        ...
```

Fixed:
```python
def record_transition(...):
    super().record_transition(
        states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps
    )

    # track the latest next states regardless of whether a memory is attached
    self._current_next_states = next_states

    if self.memory is not None:
        # reward shaping
        if self._rewards_shaper is not None:
            rewards = self._rewards_shaper(rewards, timestep, timesteps)
        ...
```

What skrl version are you using?
1.4.1
What ML framework/library version are you using?
Jax 0.4.13, Flax 0.7.2, Optax 0.1.8
Additional system information
Linux