A reinforcement learning system that improves synthetic data generation by training a generator (Llama-3.1-8B) using PPO, guided by classifier performance on real data and distributional quality metrics.
This project implements a closed-loop system where:
- A generator produces synthetic data using in-context learning (ICL)
- A classifier is trained on the synthetic data
- The classifier's performance on real (golden) data, combined with distributional quality metrics, forms a reward signal (a sketch of this combination follows the list)
- The generator is optimized using PPO to maximize this reward
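As a concrete illustration, a minimal sketch of how such a reward could be assembled is shown below. The weights and sign conventions (reward rises as the golden loss falls, as inter-class distance grows, and as intra-class distance shrinks) are assumptions for illustration, not the project's exact formula.

```python
# Minimal sketch of the reward signal; weights and sign conventions are
# illustrative assumptions, not the project's exact formula.

def compute_reward(golden_loss: float,
                   inter_class_dist: float,
                   intra_class_dist: float,
                   w_loss: float = 1.0,
                   w_inter: float = 0.5,
                   w_intra: float = 0.5) -> float:
    """Higher is better: low loss on real data, well-separated classes,
    and tight clusters within each class."""
    return -w_loss * golden_loss + w_inter * inter_class_dist - w_intra * intra_class_dist


# Example: golden loss 0.42, inter-class distance 1.8, intra-class distance 0.6.
print(compute_reward(0.42, 1.8, 0.6))  # 0.18
```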
The system follows this pipeline:
- Generate synthetic data using the generator with ICL examples
- Train a RoBERTa classifier on the synthetic data
- Evaluate the classifier on golden (real) data to get the golden loss
- Compute inter-class and intra-class distances of the generated data
- Combine these metrics into a reward signal
- Use PPO to align the generator based on the reward
- Repeat until convergence (the loop skeleton below shows this control flow)
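Stripped of the model code, the control flow might look like the skeleton below. Every helper here is a dummy stub so the loop runs end to end; in the project these roles are played by the modules under src/, and the PPO step would go through an RL library rather than a no-op.

```python
# Skeleton of the outer training loop. Every helper is a dummy stub so the
# control flow runs end to end; the real implementations live under src/.
import random

def generate_synthetic(icl_examples, n=8):
    # Stand-in for the Llama-3.1-8B generator prompted with ICL examples.
    return [{"text": f"sample {i}", "label": random.choice([0, 1])} for i in range(n)]

def train_classifier_and_eval(synthetic, golden):
    # Stand-in for fine-tuning RoBERTa on synthetic data and measuring golden loss.
    return random.uniform(0.2, 1.0)

def class_distances(synthetic):
    # Stand-in for embedding-based inter-/intra-class distances.
    return random.uniform(0.5, 2.0), random.uniform(0.1, 1.0)

def ppo_update(reward):
    # Stand-in for the PPO step that nudges the generator toward higher reward.
    pass

prev_reward, tol = None, 1e-3
for epoch in range(50):
    synthetic = generate_synthetic(icl_examples=[])
    golden_loss = train_classifier_and_eval(synthetic, golden=[])
    inter_d, intra_d = class_distances(synthetic)
    reward = -golden_loss + 0.5 * inter_d - 0.5 * intra_d  # see the reward sketch above
    ppo_update(reward)
    if prev_reward is not None and abs(reward - prev_reward) < tol:
        break  # crude convergence criterion on the reward plateau
    prev_reward = reward
```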
Install dependencies:

    pip install -r requirements.txt

Train:

    python scripts/train.py --config config/config.yaml

Evaluate:

    python scripts/evaluate.py --config config/config.yaml --checkpoint outputs/checkpoint_epoch_10

Edit config/config.yaml to customize (an example layout is sketched after this list):
- Model parameters
- Training hyperparameters
- Reward weights
- Convergence criteria
- Data paths
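For orientation, the snippet below prints one possible layout for config/config.yaml. Every key name and value here is an assumption made for illustration; the shipped config/config.yaml is the source of truth.

```python
# Prints one possible layout for config/config.yaml. All keys and values are
# illustrative assumptions; consult the shipped config file for the real schema.
import yaml

example_config = {
    "generator": {"model_name": "meta-llama/Llama-3.1-8B", "max_new_tokens": 256},
    "classifier": {"model_name": "roberta-base", "epochs": 3, "learning_rate": 2e-5},
    "reward": {"golden_loss_weight": 1.0, "inter_class_weight": 0.5, "intra_class_weight": 0.5},
    "training": {"ppo_epochs": 10, "batch_size": 8, "convergence_tolerance": 1e-3},
    "data": {"golden_path": "data/golden.jsonl", "output_dir": "outputs/"},
}
print(yaml.safe_dump(example_config, sort_keys=False))
```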
genalign/
├── config/ # Configuration files
├── src/ # Source code modules
│ ├── data/ # Data loading and sampling
│ ├── generator/ # Llama-3.1-8B generator
│ ├── classifier/ # RoBERTa classifier
│ ├── metrics/ # Distance computation
│ ├── reward/ # Reward computation
│ ├── rl/ # PPO training
│ └── utils/ # Utilities
├── scripts/ # Training and evaluation scripts
└── outputs/ # Model checkpoints and logs
- Python 3.8+
- CUDA-compatible GPU (recommended)
- 16GB+ RAM
- 50GB+ disk space for model cache
MIT License