Thanks to visit codestin.com
Credit goes to github.com

Skip to content
/ sbx Public
forked from araffin/sbx

SBX: Stable Baselines Jax (SB3 + Jax)

License

Chaiinik/sbx

 
 

Repository files navigation

CI codestyle

Stable Baselines Jax (SB3 + Jax = SBX)

Proof of concept version of Stable-Baselines3 in Jax.

Implemented algorithms:

Install using pip

For the latest master version:

pip install git+https://github.com/araffin/sbx

or:

pip install sbx-rl

Example

import gymnasium as gym

from sbx import TQC, DroQ, SAC, PPO, DQN, TD3, DDPG

env = gym.make("Pendulum-v1")

model = TQC("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10_000, progress_bar=True)

vec_env = model.get_env()
obs = vec_env.reset()
for i in range(1000):
    action, _states = model.predict(obs, deterministic=True)
    obs, reward, done, info = vec_env.step(action)
    vec_env.render()

vec_env.close()

Using SBX with the RL Zoo

Since SBX shares the SB3 API, it is compatible with the RL Zoo, you just need to override the algorithm mapping:

import rl_zoo3
import rl_zoo3.train
from rl_zoo3.train import train
from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, DroQ

rl_zoo3.ALGOS["ddpg"] = DDPG
rl_zoo3.ALGOS["dqn"] = DQN
rl_zoo3.ALGOS["droq"] = DroQ
rl_zoo3.ALGOS["sac"] = SAC
rl_zoo3.ALGOS["ppo"] = PPO
rl_zoo3.ALGOS["td3"] = TD3
rl_zoo3.ALGOS["tqc"] = TQC
rl_zoo3.train.ALGOS = rl_zoo3.ALGOS
rl_zoo3.exp_manager.ALGOS = rl_zoo3.ALGOS

if __name__ == "__main__":
    train()

Then you can run this script as you would with the RL Zoo:

python train.py --algo sac --env HalfCheetah-v4 -params train_freq:4 gradient_steps:4 -P

The same goes for the enjoy script:

import rl_zoo3
import rl_zoo3.enjoy
from rl_zoo3.enjoy import enjoy
from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, DroQ

rl_zoo3.ALGOS["ddpg"] = DDPG
rl_zoo3.ALGOS["dqn"] = DQN
rl_zoo3.ALGOS["droq"] = DroQ
rl_zoo3.ALGOS["sac"] = SAC
rl_zoo3.ALGOS["ppo"] = PPO
rl_zoo3.ALGOS["td3"] = TD3
rl_zoo3.ALGOS["tqc"] = TQC
rl_zoo3.enjoy.ALGOS = rl_zoo3.ALGOS
rl_zoo3.exp_manager.ALGOS = rl_zoo3.ALGOS

if __name__ == "__main__":
    enjoy()

Citing the Project

To cite this repository in publications:

@article{stable-baselines3,
  author  = {Antonin Raffin and Ashley Hill and Adam Gleave and Anssi Kanervisto and Maximilian Ernestus and Noah Dormann},
  title   = {Stable-Baselines3: Reliable Reinforcement Learning Implementations},
  journal = {Journal of Machine Learning Research},
  year    = {2021},
  volume  = {22},
  number  = {268},
  pages   = {1-8},
  url     = {http://jmlr.org/papers/v22/20-1364.html}
}

Maintainers

Stable-Baselines3 is currently maintained by Ashley Hill (aka @hill-a), Antonin Raffin (aka @araffin), Maximilian Ernestus (aka @ernestum), Adam Gleave (@AdamGleave), Anssi Kanervisto (@Miffyli) and Quentin Gallouédec (@qgallouedec).

Important Note: We do not do technical support, nor consulting and don't answer personal questions per email. Please post your question on the RL Discord, Reddit or Stack Overflow in that case.

How To Contribute

To any interested in making the baselines better, there is still some documentation that needs to be done. If you want to contribute, please read CONTRIBUTING.md guide first.

About

SBX: Stable Baselines Jax (SB3 + Jax)

Resources

License

Code of conduct

Contributing

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages

  • Python 99.2%
  • Other 0.8%