This repository contains the implementation of "Online Reward-Weighted Fine-Tuning of Flow Matching with Wasserstein Regularization," a method for fine-tuning flow-based generative models using reinforcement learning.
ORW-CFM-W2 is a novel reinforcement learning approach for fine-tuning continuous flow-based generative models to align with arbitrary user-defined reward functions. Unlike previous methods that require filtered datasets or gradients of rewards, our method enables optimization with arbitrary reward functions while preventing policy collapse through Wasserstein-2 distance regularization.
Our approach integrates reinforcement learning into the flow matching framework through three key components:
- Online Reward-Weighting: Guides the model to prioritize high-reward regions in the data manifold
- Wasserstein-2 Regularization: Prevents policy collapse and maintains diversity
- Tractable W2 Distance Bound: Enables efficient computation of the W2 distance in flow matching models (a schematic form is given below)
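The third component is what makes the regularizer practical: rather than computing the W2 distance between the fine-tuned and reference data distributions directly, it is bounded by the expected discrepancy of the two vector fields along the flow. Schematically (a paraphrase; see the paper for the exact statement and constants):

$$W_2^2\big(p_1^{\text{ft}},\, p_1^{\text{ref}}\big) \;\le\; C \int_0^1 \mathbb{E}\Big[\big\| v_{\theta_{\text{ft}}}(t, x_t) - v_{\theta_{\text{ref}}}(t, x_t) \big\|^2\Big]\, dt,$$

which is the quantity the $\alpha$-weighted term in the loss below estimates.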
The loss function is defined as:

$$\mathcal{L}(\theta_{\text{ft}}) = \mathbb{E}_{t,\, x_1,\, x_t}\Big[ w(x_1)\, \big\| v_{\theta_{\text{ft}}}(t, x_t) - u_t(x_t \mid x_1) \big\|^2 \Big] + \alpha\, \mathbb{E}_{t,\, x_1,\, x_t}\Big[ \big\| v_{\theta_{\text{ft}}}(t, x_t) - v_{\theta_{\text{ref}}}(t, x_t) \big\|^2 \Big]$$

Where:

- $w(x_1) \propto r(x_1)$ is the weighting function, proportional to the reward
- $v_{\theta_{\text{ft}}}$ is the fine-tuned model's vector field
- $v_{\theta_{\text{ref}}}$ is the reference (pre-trained) model's vector field
- $u_t(x|x_1)$ is the true conditional vector field
- $\alpha$ is the regularization coefficient that controls the trade-off between reward and diversity (a PyTorch sketch of this objective follows the list)
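For concreteness, here is a minimal PyTorch sketch of how both terms can be estimated on a batch. The function name, tensor shapes, and the softmax-based weighting are our assumptions for illustration, not the repository's exact code:

```python
import torch

def orw_cfm_w2_loss(v_ft, v_ref, x_t, t, u_t, rewards, alpha=1.0, tau=0.5):
    """Sketch of the ORW-CFM-W2 objective on one batch (illustrative only).

    v_ft, v_ref: callables mapping (t, x_t) to a predicted vector field
    x_t: interpolants of shape (B, ...); t: timesteps of shape (B,)
    u_t: true conditional vector field u_t(x_t | x_1), same shape as x_t
    rewards: r(x_1) per sample, shape (B,)
    """
    # Reward weights w(x_1) ∝ exp(tau * r(x_1)), normalized over the batch
    # (an assumed parameterization; larger tau sharpens the weights).
    w = torch.softmax(tau * rewards, dim=0)

    pred = v_ft(t, x_t)

    # Reward-weighted flow matching term: per-sample MSE against u_t, weighted by w.
    fm_per_sample = ((pred - u_t) ** 2).flatten(start_dim=1).mean(dim=1)
    fm_loss = (w * fm_per_sample).sum()

    # W2 regularizer: expected squared gap to the frozen reference vector field.
    with torch.no_grad():
        ref = v_ref(t, x_t)
    w2_loss = ((pred - ref) ** 2).mean()

    return fm_loss + alpha * w2_loss
```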
The core implementation is in the ORWCFMTrainer class, which handles:
- Initialization of models (network model, last policy, reference model)
- Sampling from the current policy
- Computing rewards for samples
- Computing the loss with both FM and W2 components
- Updating the model parameters
- Periodically updating the sampling policy (one epoch is outlined in the sketch below)
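Conceptually, one epoch of training looks like the following sketch. The method and argument names here are placeholders for illustration, not the class's exact API:

```python
import torch

def train_epoch(model, ref_model, sampling_policy, reward_fn, compute_loss,
                optimizer, batch_size, steps_per_epoch, grad_clip=1.0):
    """Illustrative outline of one ORW-CFM-W2 epoch (not the exact internals)."""
    for _ in range(steps_per_epoch):
        # 1. Sample x_1 from the current sampling policy by integrating its ODE.
        x1 = sampling_policy.sample(batch_size)

        # 2. Score the samples with the user-defined reward function.
        rewards = reward_fn(x1)

        # 3. Reward-weighted FM loss plus the alpha-weighted W2 term
        #    against the frozen reference model.
        loss = compute_loss(model, ref_model, x1, rewards)

        # 4. Gradient step with clipping.
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()

    # 5. Periodically refresh the sampling policy with the current weights.
    sampling_policy.load_state_dict(model.state_dict())
```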
```python
import torch

from torchcfm.conditional_flow_matching import ExactOptimalTransportConditionalFlowMatcher
from torchcfm.models.unet.unet import UNetModelWrapper

from orwcfm import ORWCFMTrainer
# Define configuration
config = {
'learning_rate': 2e-4,
'warmup_steps': 5000,
'w2_coefficient': 1.0, # alpha for W2 regularization; we recommend alpha >= 1.0
'temperature': 0.5, # tau parameter for reward weighting
'grad_clip': 1.0,
'batch_size': 128,
'text_prompts': ["An image of dog", "Not an image of dog"],
'use_wandb': True,
'wandb_project': 'flow-matching',
'run_name': 'orw-cfm-w2',
'savedir': './results',
'ref_path': './pretrained/model.pt' # Path to pre-trained model
}
# Initialize model
model = UNetModelWrapper(
dim=(3, 32, 32),
num_res_blocks=2,
num_channels=128,
channel_mult=[1, 2, 2, 2],
num_heads=4,
num_head_channels=64,
attention_resolutions="16",
dropout=0.1
)
# Initialize device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Initialize trainer
trainer = ORWCFMTrainer(model, config, device)
# Load pre-trained model
trainer.load_pretrained(config['ref_path'])
# Train model
trainer.train(
num_epochs=1000,
steps_per_epoch=100
)
# Save checkpoint
trainer.save_checkpoint('./checkpoints/orw_cfm_w2.pt')
```

Key hyperparameters:

- w2_coefficient (alpha): Controls the strength of the W2 regularization. Higher values prioritize staying close to the reference model, leading to more diverse outputs; lower values prioritize reward maximization.
- temperature (tau): Controls the sharpness of the reward weighting. Higher values lead to more aggressive focusing on high-reward regions (see the numerical demo below).
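As a quick numerical illustration of the temperature's effect, assuming the exponential weighting $w \propto \exp(\tau \cdot r)$ used in the loss sketch above (our assumption, not necessarily the repository's exact formula):

```python
import torch

def reward_weights(rewards: torch.Tensor, tau: float) -> torch.Tensor:
    # Assumed weighting w ∝ exp(tau * r), normalized over the batch;
    # larger tau concentrates weight on high-reward samples.
    return torch.softmax(tau * rewards, dim=0)

rewards = torch.tensor([0.1, 0.5, 2.0])
print(reward_weights(rewards, tau=0.5))  # mild focusing, weights stay spread out
print(reward_weights(rewards, tau=5.0))  # nearly all weight on the 2.0 sample
```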
Our method provides the following theoretical guarantees:
- Convergence Behavior: The data distribution after $N$ epochs evolves according to
  $$q^N_{\theta}(x_1) \propto w(x_1)\, q^{N-1}_{\theta}(x_1)\, \exp\big(-\beta D^{N-1}(x_1)\big),$$
  where $D^{N-1}(x_1)$ measures the discrepancy between the current and reference models.
- Limiting Behavior: Without regularization ($\alpha = 0$), the model converges to a delta distribution centered at the maximum-reward point (a one-line unrolling follows this list).
- Reward-Diversity Trade-off: W2 regularization enables a controllable trade-off between reward maximization and diversity preservation.
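To see where the limiting behavior comes from, unroll the convergence recursion with the regularizer switched off: with $\alpha = 0$ the discrepancy penalty vanishes (taking $\beta = 0$), so

$$q^N_{\theta}(x_1) \propto w(x_1)^N\, q^0_{\theta}(x_1) \xrightarrow{\; N \to \infty \;} \delta\big(x_1 - x^*\big), \qquad x^* = \arg\max_{x_1} r(x_1),$$

since repeated multiplication by $w(x_1) \propto r(x_1)$ concentrates all probability mass on the highest-reward point in the support of $q^0_{\theta}$.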
If you find this code useful for your research, please consider citing our paper:
```bibtex
@inproceedings{
fan2025online,
title={Online Reward-Weighted Fine-Tuning of Flow Matching with Wasserstein Regularization},
author={Jiajun Fan and Shuaike Shen and Chaoran Cheng and Yuxin Chen and Chumeng Liang and Ge Liu},
booktitle={The Thirteenth International Conference on Learning Representations},
year={2025},
url={https://openreview.net/forum?id=2IoFFexvuw}
}
```

Requirements:

- PyTorch
- TorchCFM
- wandb (optional, for logging)
- tqdm
For experiments on datasets like CIFAR-10 or MNIST, we recommend starting from pre-trained flow matching models from the TorchCFM repository. You can train one with their example script:
```bash
git clone https://github.com/atong01/conditional-flow-matching
cd conditional-flow-matching
pip install -e .
python examples/images/cifar10/train_cifar10.py
```
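Once training finishes, point `ref_path` in the ORW-CFM-W2 config at the checkpoint the script produces (the path below is illustrative; the actual filename depends on the script's arguments and save directory):

```python
# Illustrative path; check the script's save directory for the actual filename.
config['ref_path'] = './results/otcfm_cifar10_weights_step_400000.pt'
trainer.load_pretrained(config['ref_path'])
```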