Welcome to the UniRL project, where we supercharge reinforcement learning with joint diffusion-model and language-model experts, all in a user-friendly, lightweight codebase! 🎉 Ready to jump in? 🧙‍♂️
Our mission is to enable seamless joint training of language models and diffusion models for reinforcement learning.
UniRL supports RL training on a diverse range of powerful pretrained models:
- Language models: Qwen, Qwen-VL
- Diffusion models: Flux-dev, Flux-Kontext, Stable Diffusion, SANA
- Unified models: MetaQuery, Blip3o
- More importantly, joint RL training on disjoint language models and diffusion models, such as Qwen and Flux
- Reward models available: aesthetic, deqa, editreward, geneval, hps, image_reward, paddle_ocr, pickscore, unifiedreward, ...
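With several reward services available, RL training typically collapses their scores into a single scalar reward per sample. A minimal sketch of such weighted aggregation (the function name and weighting scheme are illustrative assumptions, not UniRL's actual API):

```python
def aggregate_rewards(scores: dict[str, float],
                      weights: dict[str, float]) -> float:
    """Combine per-service reward scores into one scalar via a weighted mean.

    `scores` maps a reward-service name (e.g. "aesthetic", "hps") to its
    score for a sample; `weights` assigns each service a relative weight.
    Services missing from `weights` contribute nothing.
    """
    total_weight = sum(weights.get(name, 0.0) for name in scores)
    if total_weight == 0.0:
        return 0.0
    weighted = sum(scores[name] * weights.get(name, 0.0) for name in scores)
    return weighted / total_weight

# Example: equal weighting of two services.
# aggregate_rewards({"aesthetic": 0.8, "hps": 0.6},
#                   {"aesthetic": 1.0, "hps": 1.0})  # ≈ 0.7
```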
- [2025-10-20] ⭐️: Released the initial code and report of the UniRL project.
Follow these steps to get started:
- Create the Conda environment 🌍:
  ```bash
  conda env create -f environment.yml
  ```
- Install CLIP 🖼️:
  ```bash
  pip install git+https://github.com/openai/CLIP.git
  ```
- Install Diffusers 🎨:
  ```bash
  git clone https://github.com/huggingface/diffusers.git
  cd diffusers
  pip install -e .
  cd ..
  ```
- Install Flash Attention ⚡️:
  ```bash
  pip install flash-attn==2.7.4.post1 --no-build-isolation
  ```
- Activate the environment 🚀:
  ```bash
  conda activate unirl
  ```
- Install the reward services 😊:
  ```bash
  cd rewards_services/api_services
  # Install the environment of each API service independently,
  # following its README file. Take the aesthetic scorer as an example:
  cd aesthetic_scorer_service
  conda create -n aes python=3.10 -y
  conda activate aes
  pip install -r requirements.txt
  bash run.sh
  ```
After launching all the reward services, set `NODE_ADDR` (in `unirl/reward_evaluator/reward_evaluator.py` and `unirl/trainer/grpo_pmatters_trainer.py`) to the address of the reward-service machine.
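A malformed address is easier to catch before training starts than mid-run. A small sketch for sanity-checking such an address (the `host:port` format and the helper name are assumptions; check the two files above for the exact form UniRL expects):

```python
from urllib.parse import urlparse

def check_node_addr(addr: str) -> tuple[str, int]:
    """Split a 'host:port' (or 'http://host:port') reward-service address,
    failing fast on malformed input instead of mid-training."""
    # urlparse only splits out host/port when a netloc prefix is present.
    parsed = urlparse(addr if "//" in addr else f"//{addr}")
    if not parsed.hostname or parsed.port is None:
        raise ValueError(f"NODE_ADDR must look like host:port, got {addr!r}")
    return parsed.hostname, parsed.port

# check_node_addr("10.0.0.5:8000")  ->  ("10.0.0.5", 8000)
```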
We patched the TorchCheckpointEngine in DeepSpeed to address a bug that was preventing training resumption. Our fix? Adding weights_only=False to the torch.load function in the load method. This ensures you can resume training without a hitch! 💪
The modified file is located at:
`[PATH TO MINICONDA]/miniconda3/envs/unirl/lib/python3.11/site-packages/deepspeed/runtime/checkpoint_engine/torch_checkpoint_engine.py`
Here's the updated code:
```python
class TorchCheckpointEngine(CheckpointEngine):
    ...

    def load(self, path: str, map_location=None):
        logger.info(f"[Torch] Loading checkpoint from {path}...")
        # partition = torch.load(path, map_location=map_location)
        partition = torch.load(path, map_location=map_location, weights_only=False)
        logger.info(f"[Torch] Loaded checkpoint from {path}.")
        return partition
```
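If editing files under site-packages is undesirable (e.g. in a shared environment), the same fix can be applied as a runtime monkey-patch before training starts. A sketch, assuming the wrapper is installed before DeepSpeed calls `torch.load` (`force_weights_only_false` is our name, not a DeepSpeed or PyTorch API):

```python
import functools

def force_weights_only_false(load_fn):
    """Wrap a torch.load-style callable so that weights_only defaults to
    False, matching the one-line edit shown above. An explicitly passed
    weights_only value is left untouched."""
    @functools.wraps(load_fn)
    def wrapper(*args, **kwargs):
        kwargs.setdefault("weights_only", False)  # only fills in the default
        return load_fn(*args, **kwargs)
    return wrapper

# Usage, before DeepSpeed resumes from a checkpoint:
#   import torch
#   torch.load = force_weights_only_false(torch.load)
```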
The UniRL project encompasses the following key components:
- Pretraining of Unified Understanding and Generation Models: We develop and pretrain models that integrate multimodal understanding and generative capabilities, enabling robust feature representation and high-quality content generation across diverse tasks.🧠
- Reinforcement Learning on Unified Models: We implement reinforcement learning algorithms tailored for unified understanding and generation models, optimizing policies for enhanced decision-making and performance in complex environments.⚙️
- Joint Reinforcement Learning with Pretrained Models: UniRL supports joint reinforcement learning with pretrained large language models (e.g., Qwen, QwenVL) and diffusion models (e.g., Stable Diffusion 3, FLUX, FLUX-Kontext, SANA), facilitating advanced, multimodal policy training and evaluation.✨
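As a concrete example of the RL side, GRPO-style methods (see the related works below) sample a group of outputs per prompt and normalize rewards within the group. A minimal sketch of that group-relative advantage, following the standard GRPO formulation rather than UniRL's exact implementation:

```python
import statistics

def group_relative_advantages(rewards: list[float],
                              eps: float = 1e-6) -> list[float]:
    """Normalize one group's rewards to zero mean and unit std, the
    group-relative advantage used by GRPO-style objectives.

    `eps` guards against division by zero when all rewards are equal.
    """
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)
    return [(r - mean) / (std + eps) for r in rewards]

# e.g. group_relative_advantages([1.0, 0.0, 1.0, 0.0])
# gives symmetric values near +1 and -1 that sum to zero.
```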
- Set up the environment as outlined above.
- Activate the `unirl` environment:
  ```bash
  conda activate unirl
  ```
- Run your experiments and enjoy seamless training! 😎
For pretraining the base unified understanding and generation model, run:
```bash
bash scripts/train/pretrain/train.sh
```
which will automatically download the model weights and datasets.
For RL training, run:
```bash
bash scripts/train/rl/train_blip3o_[setting_name].sh
```
Please set the setting name and `WANDB_KEY` correctly in the `train_blip3o_[setting_name].sh` file.

For image generation inference:
```bash
python -m scripts.inference.inference_blip3o_t2i
```

For image editing inference:
```bash
python -m scripts.inference.inference_blip3o_i2i
```
Got ideas to make UniRL even better? Submit a pull request or open an issue to join the fun! We love community contributions. 🌈
This project is licensed under the Apache-2.0 License. See the LICENSE file for details.
Check out these amazing related works to explore more in reinforcement learning and generative models:
- GRPO: DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models
- ReMax: ReMax: A Simple, Effective, and Efficient Reinforcement Learning Method for Aligning Large Language Models
- RLOO: Back to Basics: Revisiting REINFORCE Style Optimization for Learning from Human Feedback in LLMs
- Flow-GRPO: Flow-GRPO: Training Flow Matching Models via Online RL
- DanceGRPO: DanceGRPO: Unleashing GRPO on Visual Generation
We thank the following libraries for providing a robust foundation for our code: