PPO without Memory Buffer (Jax) #293

@ShimBoi

Description

Running PPO (CartPole) with JAX without a memory buffer fails because self._current_next_states is None. The attribute is only assigned inside the "if self.memory is not None" branch, so it is never updated when self.memory is None.

Proposed fix: assign the next-states attribute unconditionally, regardless of whether a memory is set.

Previously:

def record_transition(...):
    super().record_transition(
        states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps
    )

    if self.memory is not None:
        self._current_next_states = next_states

        # reward shaping
        if self._rewards_shaper is not None:
            rewards = self._rewards_shaper(rewards, timestep, timesteps)
...

Fixed:

def record_transition(...):
    super().record_transition(
        states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps
    )

    self._current_next_states = next_states

    if self.memory is not None:
        # reward shaping
        if self._rewards_shaper is not None:
            rewards = self._rewards_shaper(rewards, timestep, timesteps)
...
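To make the difference between the two orderings concrete, here is a minimal, framework-free sketch. SketchAgent, its patched flag, the update method, and the run helper are hypothetical stand-ins invented for illustration; they are not skrl's classes or API, and the update method only assumes that some later step reads the stored next states.

```python
from typing import Any, Optional


class SketchAgent:
    """Hypothetical stand-in for the agent; this is not skrl's actual class."""

    def __init__(self, memory: Optional[list], patched: bool) -> None:
        self.memory = memory
        self.patched = patched
        self._current_next_states: Any = None

    def record_transition(self, states: Any, next_states: Any) -> None:
        if self.patched:
            # Fixed ordering: always remember the latest next states.
            self._current_next_states = next_states
        if self.memory is not None:
            if not self.patched:
                # Original ordering: only assigned when a memory exists.
                self._current_next_states = next_states
            self.memory.append((states, next_states))

    def update(self) -> Any:
        # Assumed stand-in for a later step that reads the stored next states.
        if self._current_next_states is None:
            raise RuntimeError("_current_next_states was never set")
        return self._current_next_states


def run(memory: Optional[list], patched: bool) -> Any:
    """Record one transition, then read the stored next states back."""
    agent = SketchAgent(memory, patched)
    agent.record_transition([0.0], [1.0])
    return agent.update()
```

In this sketch, run(None, patched=False) raises because the attribute is never assigned, while run(None, patched=True) returns the recorded next states. With a memory present, both orderings behave the same.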

What skrl version are you using?

1.4.1

What ML framework/library version are you using?

Jax 0.4.13, Flax 0.7.2, Optax 0.1.8

Additional system information

Linux

Labels: bug
