OpenRLHF is a high-performance RLHF framework built on Ray, DeepSpeed and HF Transformers:
- Simple and easy to use: OpenRLHF is one of the simplest high-performance RLHF libraries currently available, and is seamlessly compatible with Huggingface models and datasets.
- High performance: RLHF training spends 80% of its time on the sample generation stage. Because OpenRLHF can use a large inference batch size with Ray, sample packing, and vLLM generation acceleration, its performance is 3~4x+ that of Optimized DeepSpeedChat with Hybrid Engine.
- Distributed RLHF: OpenRLHF distributes the Actor, Reward, Reference, and Critic models onto separate GPUs using Ray, while placing the Adam optimizer on the CPU. This enables full-scale fine-tuning of 70B+ models with multiple A100 80G GPUs and vLLM, and of 7B models across multiple 24GB RTX 4090 GPUs.
- PPO Implementation Optimization: We integrated the implementation tricks for PPO to improve the training stability, referencing Zhihu and the Notion blog.
More details are in Slides | Technical Report | Documents
- [2024/12] We analyzed PPO, REINFORCE, GRPO, and RLOO in the Notion blog post.
- Distributed PPO and REINFORCE/RLOO implementations based on Ray.
- Full RLHF fine-tuning support for models with over 70 billion parameters.
- Integration with vLLM for accelerated generation in RLHF tasks (--vllm_num_engines).
- Support for multiple reward models (--reward_pretrain model1,model2...) and remote reward models (--remote_rm_url).
- Implementation of DPO (Direct Preference Optimization)/IPO/cDPO and Kahneman-Tversky Optimization (KTO).
- Support for Iterative DPO (GitHub: Online-RLHF).
- Support for Rejection Sampling.
- Implementation of Conditional SFT (arXiv:2308.12050).
- Support for Knowledge Distillation (Microsoft: minillm).
- Integration of Process Reward Model (PRM).
- Packing of training samples for SFT, DPO, RM, PRM, and PPO (--packing_samples).
- Implementation of RingAttention (--ring_attn_size, --ring_head_stride).
- Support for Mixture of Experts (MoE) (--aux_loss_coef).
- Integration of FlashAttention2 (--flash_attn).
- Support for QLoRA (--load_in_4bit) and LoRA (--lora_rank, --target_modules).
- Compatibility with HuggingFace's tokenizer.apply_chat_template for datasets (--apply_chat_template and --input_key).
- Logging support with Wandb (--use_wandb) and TensorBoard (--use_tensorboard).
- Checkpoint recovery functionality (--load_checkpoint and --save_steps).
- Provided multi-node training scripts, such as DPO and Ray PPO.
| Feature | OpenRLHF | DSChat | CAIChat | TRL | 
|---|---|---|---|---|
| 70B+ Full Tuning with 16 A100-80GB | ✅ | ❌ | ❌ | ❌ | 
| 7B Full Tuning with 4 RTX4090 | ✅ | ❌ | ❌ | ❌ | 
| 34B DPO Full Tuning with 8 A100-80GB | ✅ | ❌ | ❌ | ❌ | 
| Inference Engine in PPO | ✅ | ✅ | ❌ | ❌ | 
| PPO Implementation Tricks | ✅ | ❌ | ❌ | ✅ | 
| Support QLoRA | ✅ | ❌ | ❌ | ✅ | 
| Support Mixtral 8*7b | ✅ | ❌ | ❌ | ❌ | 
| Support Unmerged Actor-Critic | ✅ | ✅ | ✅ | ❌ | 
| Support Multiple Reward Models | ✅ | ❌ | ❌ | ❌ | 
| Support Huggingface Models | ✅ | ✅ | ✅ | ✅ | 
| Easy-to-use | ✅ | ❌ (HybridEngine bugs) | ✅ | ✅ | 
To use OpenRLHF, first launch the docker container (Recommended) and pip install openrlhf inside the docker container:
# Launch the docker container
docker run --runtime=nvidia -it --rm --shm-size="10g" --cap-add=SYS_ADMIN -v $PWD:/openrlhf nvcr.io/nvidia/pytorch:24.07-py3 bash
sudo pip uninstall xgboost transformer_engine flash_attn -y
# pip install
pip install openrlhf
# If you want to use vLLM acceleration (Install vLLM 0.6.4.post1)
pip install openrlhf[vllm]
# latest vLLM is also supported
pip install openrlhf[vllm_latest]
# pip install the latest version
pip install git+https://github.com/OpenRLHF/OpenRLHF.git
# Or git clone
git clone https://github.com/OpenRLHF/OpenRLHF.git
cd OpenRLHF
pip install -e .
Note
We recommend using vLLM 0.6.4+ (NCCL weight synchronization is supported only for multi-node training) or vLLM 0.4.2 (--vllm_sync_backend nccl), as other versions currently require synchronizing weights via Gloo (--vllm_sync_backend gloo).
We also provide Dockerfiles for vLLM and a one-click installation script for Nvidia-Docker.
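The version rule in the note above can be read as a small decision function. The sketch below is one reading of that note, not OpenRLHF code: `parse_version` and `choose_vllm_sync_backend` are hypothetical names, and the parsing only handles version strings that start with `X.Y.Z`, such as `0.6.4.post1`.

```python
def parse_version(version: str) -> tuple:
    """Return the leading numeric (major, minor, patch) parts of a version string."""
    parts = []
    for piece in version.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


def choose_vllm_sync_backend(vllm_version: str, multi_node: bool) -> str:
    """Hypothetical helper: pick a value for --vllm_sync_backend."""
    version = parse_version(vllm_version)
    if version == (0, 4, 2):
        return "nccl"
    # Assumed reading of the note: on 0.6.4+ NCCL sync applies to multi-node runs only.
    if version >= (0, 6, 4) and multi_node:
        return "nccl"
    return "gloo"


print(choose_vllm_sync_backend("0.6.4.post1", multi_node=True))
```

Treat the result as a starting point and confirm the behavior against the vLLM version you actually install.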
OpenRLHF provides multiple data processing methods in its dataset classes. For example, in the Prompt Dataset:
def preprocess_data(data, input_template=None, input_key="input", apply_chat_template=None) -> str:
    if apply_chat_template:
        prompt = apply_chat_template(data[input_key], tokenize=False, add_generation_prompt=True)
    else:
        prompt = data[input_key]
        if input_template:
            prompt = input_template.format(prompt)
    return prompt
- We can use --input_key to specify the JSON key name of the input datasets --prompt_data {name or path} (PPO) or --dataset {name or path}, and use --apply_chat_template to utilize the chat_template from the Huggingface Tokenizer.
- If you don't want to use --apply_chat_template, you can use --input_template instead, or preprocess the datasets offline in advance.
- OpenRLHF also supports mixing multiple datasets using --prompt_data_probs 0.1,0.4,0.5 (PPO) or --dataset_probs 0.1,0.4,0.5.
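To make the two prompt-building paths concrete, here is a minimal sketch that calls `preprocess_data` once with an input template and once with a chat template. The function body repeats the snippet above; `toy_chat_template` is a hypothetical stand-in for `tokenizer.apply_chat_template` so the example runs without downloading a tokenizer, and the sample records are made up.

```python
def preprocess_data(data, input_template=None, input_key="input", apply_chat_template=None) -> str:
    # Same logic as the Prompt Dataset snippet above.
    if apply_chat_template:
        prompt = apply_chat_template(data[input_key], tokenize=False, add_generation_prompt=True)
    else:
        prompt = data[input_key]
        if input_template:
            prompt = input_template.format(prompt)
    return prompt


def toy_chat_template(messages, tokenize=False, add_generation_prompt=False) -> str:
    # Hypothetical stand-in for tokenizer.apply_chat_template (not a real model template).
    text = "".join(f"<{m['role']}>{m['content']}\n" for m in messages)
    if add_generation_prompt:
        text += "<assistant>"
    return text


# Path 1: plain string field formatted with --input_template.
plain = preprocess_data(
    {"question": "What is RLHF?"},
    input_template="User: {}\nAssistant: ",
    input_key="question",
)

# Path 2: list of chat messages rendered with --apply_chat_template.
chat = preprocess_data(
    {"context_messages": [{"role": "user", "content": "Hi"}]},
    input_key="context_messages",
    apply_chat_template=toy_chat_template,
)

print(repr(plain))
print(repr(chat))
```

Note that when a chat template is supplied, `input_template` is ignored, which matches the branch order in the function.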
How Chat Templating Works:
dataset = [{"input_key": [
  {"role": "user", "content": "Hello, how are you?"},
  {"role": "assistant", "content": "I'm doing great. How can I help you today?"},
  {"role": "user", "content": "I'd like to show off how chat templating works!"},
]}]
tokenizer.apply_chat_template(dataset[0]["input_key"], tokenize=False)
"<s>[INST] Hello, how are you? [/INST]I'm doing great. How can I help you today?</s> [INST] I'd like to show off how chat templating works! [/INST]"
How to specify training and test datasets?
You can specify it using the data_type@data_dir format. For example, the dataset can be set as --dataset json@./data.
data
├── test.jsonl
└── train.jsonl
Note
By default, we use train and test as splits to distinguish training and testing datasets from Huggingface.
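As an illustration of the data_type@data_dir format, the sketch below writes a local `data` directory with `train.jsonl` and `test.jsonl` and splits a `json@...` spec into its two parts. `parse_dataset_spec` and `write_jsonl` are hypothetical helpers, not OpenRLHF functions, and the `input`/`output` keys in the sample rows are assumptions chosen for the example.

```python
import json
import tempfile
from pathlib import Path


def parse_dataset_spec(spec: str):
    """Split 'data_type@data_dir' into (data_type, data_dir); hypothetical helper."""
    if "@" not in spec:
        return None, spec  # plain dataset name or path
    data_type, data_dir = spec.split("@", 1)
    return data_type, data_dir


def write_jsonl(path: Path, rows) -> None:
    """Write one JSON object per line."""
    with path.open("w", encoding="utf-8") as f:
        for row in rows:
            f.write(json.dumps(row) + "\n")


with tempfile.TemporaryDirectory() as tmp:
    data_dir = Path(tmp) / "data"
    data_dir.mkdir()
    # File names follow the default train/test split names mentioned above.
    write_jsonl(data_dir / "train.jsonl", [{"input": "Hello", "output": "Hi there"}])
    write_jsonl(data_dir / "test.jsonl", [{"input": "Bye", "output": "Goodbye"}])

    data_type, parsed_dir = parse_dataset_spec(f"json@{data_dir}")
    split_files = sorted(p.name for p in Path(parsed_dir).iterdir())

print(data_type, split_files)
```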
The JSON key options depend on the specific datasets. See Reward Dataset and SFT Dataset.
OpenRLHF's model checkpoints are fully compatible with HuggingFace models. You can specify the model name or path using --pretrain {name or path}, --reward_pretrain {name or path}, and --critic_pretrain {name or path}. We have provided some pre-trained checkpoints and datasets on HuggingFace OpenRLHF.
Then you can use the startup scripts we provide in the examples/scripts directory, or start the training using the following commands.
deepspeed --module openrlhf.cli.train_sft \
   --max_len 4096 \
   --dataset Open-Orca/OpenOrca \
   --input_key question \
   --output_key response \
   --input_template $'User: {}\nAssistant: ' \
   --train_batch_size 256 \
   --micro_train_batch_size 2 \
   --max_samples 500000 \
   --pretrain meta-llama/Meta-Llama-3-8B \
   --save_path ./checkpoint/llama3-8b-sft \
   --save_steps -1 \
   --logging_steps 1 \
   --eval_steps -1 \
   --zero_stage 2 \
   --max_epochs 1 \
   --packing_samples \
   --bf16 \
   --flash_attn \
   --learning_rate 5e-6 \
   --gradient_checkpointing \
   --use_wandb {wandb_token}
# Support HF tokenizer.apply_chat_template
# --apply_chat_template 
# --input_key {JSON Key}
# --tokenizer_chat_template {HF Chat Template}
# Can also be used for continued pre-training
# --pretrain_mode
Note
OpenRLHF SFT/DPO/RewardModel/PPO trainers support --packing_samples based on --flash_attn
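The SFT command above sets both --train_batch_size and --micro_train_batch_size. Under DeepSpeed's usual convention, the global batch size equals the per-GPU micro batch size times the number of GPUs times the gradient accumulation steps. The sketch below works through that arithmetic; the GPU count of 8 is an assumption (the command does not state it), the helper name is hypothetical, and the calculation ignores ring attention.

```python
def gradient_accumulation_steps(train_batch_size: int, micro_train_batch_size: int, num_gpus: int) -> int:
    """Derive accumulation steps from global batch = micro batch * GPUs * steps."""
    samples_per_step = micro_train_batch_size * num_gpus
    if train_batch_size % samples_per_step != 0:
        raise ValueError("train_batch_size must be divisible by micro_train_batch_size * num_gpus")
    return train_batch_size // samples_per_step


# Values from the SFT command: --train_batch_size 256, --micro_train_batch_size 2.
# num_gpus=8 is an assumed single-node setup.
sft_steps = gradient_accumulation_steps(256, 2, num_gpus=8)
print(sft_steps)
```

Raising --micro_train_batch_size lowers the number of accumulation steps for the same global batch, at the cost of more memory per GPU.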
deepspeed --module openrlhf.cli.train_rm \
   --save_path ./checkpoint/llama3-8b-rm \
   --save_steps -1 \
   --logging_steps 1 \
   --eval_steps -1 \
   --train_batch_size 256 \
   --micro_train_batch_size 1 \
   --pretrain OpenRLHF/Llama-3-8b-sft-mixture \
   --bf16 \
   --max_epochs 1 \
   --max_len 8192 \
   --zero_stage 3 \
   --learning_rate 9e-6 \
   --dataset OpenRLHF/preference_dataset_mixture2_and_safe_pku \
   --apply_chat_template \
   --chosen_key chosen \
   --rejected_key rejected \
   --flash_attn \
   --packing_samples \
   --gradient_checkpointing \
   --use_wandb {wandb_token}
It is recommended to set the --value_prefix_head option of the Reward Model to score, so that we can load the model using AutoModelForSequenceClassification:
reward_model = AutoModelForSequenceClassification.from_pretrained(
              reward_model_path,
              num_labels=1,
              torch_dtype=torch.bfloat16,
              attn_implementation="flash_attention_2",
              use_cache=False,
          )
inputs = ...  # left-padded input tokens
hidden_states = reward_model.model(*inputs).last_hidden_state
reward = reward_model.score(hidden_states)[:, -1]
deepspeed --module openrlhf.cli.train_ppo \
  --pretrain OpenRLHF/Llama-3-8b-sft-mixture \
  --reward_pretrain OpenRLHF/Llama-3-8b-rm-mixture \
  --save_path ./checkpoint/llama-3-8b-rlhf \
  --save_steps -1 \
  --logging_steps 1 \
  --eval_steps -1 \
  --micro_train_batch_size 2 \
  --train_batch_size 128 \
  --micro_rollout_batch_size 4 \
  --rollout_batch_size 1024 \
  --max_epochs 1 \
  --prompt_max_len 1024 \
  --generate_max_len 1024 \
  --zero_stage 2 \
  --bf16 \
  --actor_learning_rate 5e-7 \
  --critic_learning_rate 9e-6 \
  --init_kl_coef 0.01 \
  --prompt_data OpenRLHF/prompt-collection-v0.1 \
  --input_key context_messages \
  --apply_chat_template \
  --max_samples 100000 \
  --normalize_reward \
  --adam_offload \
  --flash_attn \
  --gradient_checkpointing \
  --use_wandb {wandb_token}
# Support remote reward model (HTTP)
# --remote_rm_url http://localhost:5000/get_reward
To improve RLHF training speed or support 70B models, we can use PPO with Ray and vLLM acceleration:
# launch the master node of ray in container
ray start --head --node-ip-address 0.0.0.0 --num-gpus 8
# if you want to launch ray on more nodes, use
ray start --address {MASTER-NODE-ADDRESS}:6379  --num-gpus 8
ray job submit --address="http://127.0.0.1:8265" \
  --runtime-env-json='{"working_dir": "/openrlhf"}' \
  -- python3 -m openrlhf.cli.train_ppo_ray \
  --ref_num_nodes 1 \
  --ref_num_gpus_per_node 2 \
  --reward_num_nodes 1 \
  --reward_num_gpus_per_node 2 \
  --critic_num_nodes 1 \
  --critic_num_gpus_per_node 2 \
  --actor_num_nodes 1 \
  --actor_num_gpus_per_node 2 \
  --vllm_num_engines 2 \
  --vllm_tensor_parallel_size 2 \
  --colocate_critic_reward \
  --colocate_actor_ref \
  --pretrain OpenRLHF/Llama-3-8b-sft-mixture \
  --reward_pretrain OpenRLHF/Llama-3-8b-rm-mixture \
  --save_path /openrlhf/examples/checkpoint/llama3-8b-rlhf \
  --micro_train_batch_size 8 \
  --train_batch_size 128 \
  --micro_rollout_batch_size 16 \
  --rollout_batch_size 1024 \
  --max_samples 100000 \
  --max_epochs 1 \
  --prompt_max_len 1024 \
  --generate_max_len 1024 \
  --zero_stage 3 \
  --bf16 \
  --actor_learning_rate 5e-7 \
  --critic_learning_rate 9e-6 \
  --init_kl_coef 0.01 \
  --prompt_data OpenRLHF/prompt-collection-v0.1 \
  --input_key context_messages \
  --apply_chat_template \
  --normalize_reward \
  --packing_samples \
  --adam_offload \
  --flash_attn \
  --gradient_checkpointing \
  --use_wandb {wandb_token}
# --vllm_sync_backend nccl (Only for multi-nodes with vLLM 0.6.4+ or vLLM 0.4.2)
# Support remote reward model (HTTP)
# --remote_rm_url http://localhost:5000/get_reward
# Support REINFORCE | RLOO
# --advantage_estimator reinforce | rloo
# Support N samples
# --n_samples_per_prompt 4
Note
If --vllm_num_engines is not set, the vLLM engine is not used.
You can also use setup_commands to let Ray automatically deploy the environment, such as --runtime-env-json='{"setup_commands": ["pip install openrlhf[vllm]"]}'.
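Hand-writing the JSON for --runtime-env-json is easy to get wrong because of the nested quotes. The sketch below builds the argument from a Python dict and quotes it for the shell. It uses only the two keys that appear in this README (`working_dir` and `setup_commands`); combining them in one runtime environment is an assumption made for illustration.

```python
import json
import shlex

# Keys and values taken from the examples in this README.
runtime_env = {
    "working_dir": "/openrlhf",
    "setup_commands": ["pip install openrlhf[vllm]"],
}

# json.dumps produces valid JSON; shlex.quote makes it safe to paste into a shell command.
runtime_env_arg = "--runtime-env-json=" + shlex.quote(json.dumps(runtime_env))
print(runtime_env_arg)
```

The printed string can be placed directly after `ray job submit --address=...` in the command shown earlier.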
Note
If you encounter an error related to index out of range when DeepSpeed sets up the GPU devices, you can try setting the environment variable RAY_EXPERIMENTAL_NOSET_*_VISIBLE_DEVICES as a workaround.
# For NVIDIA GPUs:
export RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES=1
The launch scripts and documents for supported algorithms are in examples/scripts and Documents - Usage.
We optimized DSChat's performance to the greatest extent possible by employing techniques such as enabling Adam offload, along with reward model (RM) and reference model (Ref) offload to increase the micro-batch size during the inference stage and avoid out-of-memory issues. We even fixed some bugs in DSChat to enable the Hybrid Engine (HE) for LLaMA2. The average time (seconds) it took to train 1024 prompts with 1 PPO epoch using the Optimized DSChat and OpenRLHF:
| Size | NVIDIA A800-80GB GPUs | Optimized DSChat (with Hybrid Engine) | OpenRLHF | Speedup | 
|---|---|---|---|---|
| 7B | 16 | 855.09 | 471.11 | 1.82x | 
| 13B | 32 | 1528.93 | 608.93 | 2.5x | 
| 34B | 32 | 3634.98 | 1526.4 | 2.4x | 
| 70B | 32 | 10407.0 | 4488.53 | 2.3x | 
Note
The data is outdated; please refer to the performance tuning section for re-testing.
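The Speedup column is the ratio of the Optimized DSChat time to the OpenRLHF time. The short sketch below recomputes it from the timings in the table, which is a quick way to check new measurements when re-testing.

```python
# (Optimized DSChat seconds, OpenRLHF seconds) per model size, copied from the table above.
timings = {
    "7B": (855.09, 471.11),
    "13B": (1528.93, 608.93),
    "34B": (3634.98, 1526.4),
    "70B": (10407.0, 4488.53),
}

# Speedup = baseline time / OpenRLHF time.
speedups = {size: dschat / openrlhf for size, (dschat, openrlhf) in timings.items()}

for size, speedup in speedups.items():
    print(f"{size}: {speedup:.2f}x")
```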
To achieve optimal performance, we recommend allocating more nodes to the vLLM Engine. For example, for a 70B model with 32 A100 GPUs, it is advised to allocate 16 A100 GPUs to the vLLM Engine, 8 GPUs to the Actor model, and the remaining 8 GPUs to the Critic model. Additionally, enable the --colocate_critic_reward and --colocate_actor_ref options to merge nodes. Finally, make --micro_rollout_batch_size as large as possible and the TP size of the vLLM engine as small as possible. During the training phase, a larger --micro_train_batch_size is better, and --packing_samples should be enabled. When there are enough GPUs, please disable --adam_offload. For multi-node RLHF, please use --vllm_sync_backend nccl with vLLM 0.6.4+.
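The GPU budget implied by these flags can be checked with a little arithmetic. The sketch below is a simplified accounting model, not OpenRLHF code: `total_gpus` is a hypothetical helper, it assumes colocated models share the same GPUs, and it assumes vLLM engines run on their own GPUs (engines times tensor-parallel size). The 4 engines with TP 4 split for the 70B case is an illustrative choice, since the text only fixes the total of 16.

```python
def total_gpus(actor, ref, critic, reward, vllm_engines, vllm_tp,
               colocate_actor_ref=False, colocate_critic_reward=False) -> int:
    """Count GPUs needed, assuming colocated model pairs share one set of GPUs."""
    actor_ref = max(actor, ref) if colocate_actor_ref else actor + ref
    critic_reward = max(critic, reward) if colocate_critic_reward else critic + reward
    return actor_ref + critic_reward + vllm_engines * vllm_tp


# Ray PPO command shown earlier: 2 GPUs per model, 2 vLLM engines with TP 2, both colocate flags.
small = total_gpus(2, 2, 2, 2, vllm_engines=2, vllm_tp=2,
                   colocate_actor_ref=True, colocate_critic_reward=True)

# 70B recommendation: 8 GPUs for Actor (+Ref), 8 for Critic (+Reward), 16 for vLLM.
large = total_gpus(8, 8, 8, 8, vllm_engines=4, vllm_tp=4,
                   colocate_actor_ref=True, colocate_critic_reward=True)

print(small, large)
```

Under these assumptions the first configuration fits the 8 GPUs started with `ray start --num-gpus 8`, and the second uses all 32 GPUs.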
- ByteDance
- Tencent
- Alibaba
- Baidu
- China Telecom
- Vivo
- Allen AI
- NexusFlow
- Jülich Supercomputing Centre (JSC)
- Berkeley Starling Team
- M-A-P
- ...
How to Join?
- Email us at [email protected] or join the GitHub Organization. Please include the following details:
  - Your name
  - Your GitHub username
  - Your areas of interest
  - Your skills and experience related to NLP and/or AI
- You can also join us through the official GitHub OpenRLHF project page. Just create an issue about your interest in contributing and we will get back to you.
What can you do?
- Join the team and participate in the development of the OpenRLHF project.
- Contribute to the project by submitting pull requests.
- Help improve documentation, fix bugs, or create new features.
- Share the project and help us grow the community.
Your sponsorship can help us maintain and improve OpenRLHF. If you find this project useful, please consider sponsoring us. You can sponsor us on Open Collective.
A big thank you to all our contributors! If you want to contribute, feel free to make a pull request or create an issue.
We would like to express our gratitude to the following projects and organizations for their contributions to the field of AI and NLP:
Our project would also like to thank ColossalChat and DeepSpeedChat. In the early stages of the project, we referred to their code design.
(2024/7) Our GitHub organization has changed from OpenLLMAI to OpenRLHF.
@article{hu2024openrlhf,
  title={OpenRLHF: An Easy-to-use, Scalable and High-performance RLHF Framework},
  author={Jian Hu and Xibin Wu and Zilin Zhu and Xianyu and Weixun Wang and Dehao Zhang and Yu Cao},
  journal={arXiv preprint arXiv:2405.11143},
  year={2024}
}
OpenRLHF © 2024 OpenRLHF. All Rights Reserved.