By Subham Sekhar Sahoo, Justin Deschenaux, Aaron Gokaslan, Guanghan Wang, Justin Chiu, Volodymyr Kuleshov
We unlock few-step generation in discrete diffusion language models via the underlying Gaussian diffusion.
In this repo, we release:
- The DUO framework
- Baseline implementations [Examples]:
  - Autoregressive Model.
  - MDLM: Sahoo et al., "Simple and Effective Masked Diffusion Language Models", NeurIPS 2024.
  - SEDD (absorb): Lou et al., "Score Entropy Based Discrete Diffusion", ICML 2024.
  - D3PM (absorb): Austin et al., "Structured Denoising Diffusion Models in Discrete State-Spaces", NeurIPS 2021.
To get started, create a conda environment containing the required dependencies.
```bash
conda create -n duo python=3.12
conda activate duo
conda install nvidia/label/cuda-12.4.0::cuda-toolkit
pip install -r requirements.txt
pip install flash_attn==2.7.4.post1
```

Curriculum Learning (Sec. 4.1) and Discrete Consistency Distillation (Sec. 4.2) require mapping Gaussian to discrete diffusion parameters via the Diffusion Transformation operator (Sec. 3), which involves computing an integral that depends only on the tokenizer's vocabulary size. To avoid slowing down training, we pre-compute and cache this integral. Cached operators for bert-base-uncased (LM1B) and gpt2 (OWT) are in `integral/`. For other tokenizers, run:
```bash
python utils.py --vocab_size=N
```

where `N` is the vocabulary size of the tokenizer.
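If you're unsure of a tokenizer's vocabulary size, one way to look it up is via the Hugging Face `transformers` tokenizer API (a minimal sketch; `bert-base-uncased` is used purely as an illustration, since its operator is already cached):

```bash
# Print the tokenizer's vocabulary size (30522 for bert-base-uncased),
# then cache the Diffusion Transformation operator for that size.
python -c "from transformers import AutoTokenizer; print(len(AutoTokenizer.from_pretrained('bert-base-uncased')))"
python utils.py --vocab_size=30522
```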
The checkpoints for the DUO models (distilled/undistilled) trained on OpenWebText for 1M training steps are available on:
- Huggingface🤗.
- Google Drive folder (use these for finetuning, as the HF checkpoints can't be finetuned).
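If you want a local copy of the HF weights (for example, for offline runs), a command along these lines should work, assuming a reasonably recent `huggingface_hub`; note that `main.py` can also take the repo id directly via `eval.checkpoint_path`, as in the sampling command further below (the local directory name here is just an illustrative choice):

```bash
# Snapshot the distilled DUO checkpoint into a local directory.
huggingface-cli download s-sahoo/duo-distilled --local-dir checkpoints/duo-distilled
```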
Run `mkdir watch_folder` to create a directory to store Slurm logs, and then run any script in `scripts/` as a Slurm job:

```bash
sbatch scripts/ABC_XYZ.sh
```

To train DUO, use the following scripts:
- LM1B
  - w/ sentence packing (same as in D3PM)
    - Training script: `scripts/train_lm1b_duo_sentencepacking.sh`
    - Wandb run
  - w/o sentence packing (same as in MDLM, SEDD)
    - Training script: `scripts/train_lm1b_duo.sh`
    - Wandb run
- OWT
  - Training script: `scripts/train_owt_duo.sh`
Curriculum Learning increases memory consumption. For faster training on OWT, one may consider a two-stage approach:
- Stage 1: Curriculum Learning for 500K steps
  - Use `scripts/train_owt_duo.sh` with the following modifications (see the sketch after this list):
    - Reduced batch size (`loader.batch_size=32` on an 80 GB GPU)
    - `trainer.max_steps=500000`
- Stage 2: Finetuning the stage 1 checkpoint for 500K more steps
  - Training script: `scripts/train_owt_duo_finetune.sh`
  - Features a larger batch size (`loader.batch_size=64` on an 80 GB GPU) than stage 1.
  - Wandb run: this run resumes training from a stage 1 checkpoint. Although trained for 1M steps, the results reported in the paper correspond to the checkpoint at 500K steps.
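A minimal sketch of the two-stage launch (assuming you set the overrides above inside the scripts; how they are passed to `main.py` depends on the script contents):

```bash
# Stage 1: curriculum learning for 500K steps.
# Edit scripts/train_owt_duo.sh so that the python main.py call uses
#   loader.batch_size=32      # fits on an 80 GB GPU with curriculum learning
#   trainer.max_steps=500000
sbatch scripts/train_owt_duo.sh

# Stage 2: finetune the stage 1 checkpoint for another 500K steps
# with the larger batch size (loader.batch_size=64).
sbatch scripts/train_owt_duo_finetune.sh
```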
Control the batch size per GPU using the argument loader.batch_size. If loader.batch_size * num_gpus < loader.global_batch_size, PyTorch Lightning resorts to gradient accumulation.
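For example (hypothetical numbers): with `loader.global_batch_size=512`, `loader.batch_size=32`, and 8 GPUs, 32 × 8 = 256 < 512, so Lightning would accumulate gradients over 512 / 256 = 2 micro-batches per optimizer step.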
To distill a model using Discrete Consistency Distillation (Alg. 1 in the paper), use `scripts/distil_owt.sh`.
To compute perplexity on the validation set of OWT, use `scripts/eval_owt_duo.sh`; for zero-shot perplexities, use `scripts/zero_shot_duo.sh`.
To generate samples from a pre-trained model, use one of the following commands. Set
- `sampling.noise_removal=greedy` to use the "Greedy-tail sampler" (equivalent to nucleus sampling in AR models; see Sec. 4.2 in the paper).
- `sampling.noise_removal=ancestral` for standard ancestral sampling. This produces more diverse samples (higher entropy) but with worse generative perplexity.
We have released the distilled model s-sahoo/duo-distilled and the un-distilled model s-sahoo/duo on Huggingface🤗. To sample from an HF model, run the following command:
```bash
python main.py \
  mode=sample_eval \
  loader.batch_size=2 \
  loader.eval_batch_size=8 \
  data=openwebtext-split \
  algo=duo_base \
  algo.backbone=hf_dit \
  eval.checkpoint_path=s-sahoo/duo-distilled \
  sampling.steps=8 \
  sampling.num_sample_batches=1 \
  sampling.noise_removal=greedy \
  +wandb.offline=true
```

We’ve also released checkpoints for the distilled model (duo-distilled.ckpt) and the un-distilled model (duo.ckpt) trained on OWT in this Google Drive folder. Download them and use the command in `scripts/gen_ppl_owt_duo.sh`, specifying the paths correctly.
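For comparison with the HF sampling command above, here is the same command with the ancestral sampler; only flags already documented in this README change, and the step count of 32 is an illustrative choice:

```bash
python main.py \
  mode=sample_eval \
  loader.batch_size=2 \
  loader.eval_batch_size=8 \
  data=openwebtext-split \
  algo=duo_base \
  algo.backbone=hf_dit \
  eval.checkpoint_path=s-sahoo/duo-distilled \
  sampling.steps=32 \
  sampling.num_sample_batches=1 \
  sampling.noise_removal=ancestral \
  +wandb.offline=true
```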
We release the checkpoints for the baselines: SEDD, MDLM and AR trained on OpenWebText in this Google Drive folder. Download the checkpoints: ar.ckpt, mdlm.ckpt, sedd.ckpt and specify the paths appropriately in the respective shell scripts:
- `scripts/eval_owt_*.sh` for computing validation perplexity on OWT.
- `scripts/gen_ppl_*.sh` for generating text samples and evaluating them.
- `scripts/zero_shot_*.sh` for computing zero-shot perplexities.
- `scripts/train_*.sh` for training the models.
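For instance, to evaluate the MDLM baseline you would set the checkpoint path inside the corresponding script and submit it; the script name below is hypothetical, matching the `scripts/eval_owt_*.sh` pattern:

```bash
# Hypothetical example: edit the script so that it points at the downloaded
# checkpoint (e.g. eval.checkpoint_path=/path/to/mdlm.ckpt), then submit:
sbatch scripts/eval_owt_mdlm.sh
```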
This repository was built off of MDLM's GitHub repository. Cite our paper using:
```bibtex
@inproceedings{
sahoo2025the,
title={The Diffusion Duality},
author={Subham Sekhar Sahoo and Justin Deschenaux and Aaron Gokaslan and Guanghan Wang and Justin T Chiu and Volodymyr Kuleshov},
booktitle={Forty-second International Conference on Machine Learning},
year={2025},
url={https://openreview.net/forum?id=9P9Y8FOSOk}
}
```