Anchored Diffusion Language Model


Anchored Diffusion Language Model (ADLM) introduces an alternative noising schedule for masked diffusion language models (DLMs). Rather than masking tokens uniformly at random, ADLM keeps important anchor tokens visible longer during the forward process, so that, equivalently, they are unmasked early during denoising. This anchored two-stage framework yields better likelihood estimates and higher-quality text generation. In practice, ADLM:

  • improves test perplexity by up to 25.4% on LM1B and OpenWebText relative to prior DLMs,
  • narrows the gap to autoregressive baselines and is the first DLM to surpass them in MAUVE score,
  • generalizes to seven zero-shot benchmarks (LAMBADA, PTB, Wikitext-2/103, LM1B, AG News, PubMed, arXiv).

We also derive the Anchored Negative Evidence Lower Bound (ANELBO), establishing theoretical gains in sample complexity and likelihood modeling.
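
To make the anchored schedule concrete, here is a minimal, self-contained sketch of the core idea (an illustration, not the repository's implementation): each token gets an importance score, and high-importance anchor tokens are masked at a reduced rate, so at any intermediate noise level they are more likely to remain visible. All names below (anchored_forward_mask, anchor_slowdown, MASK_ID) are illustrative.

import torch

MASK_ID = 0  # illustrative mask-token id

def anchored_forward_mask(tokens, importance, t, anchor_slowdown=0.3):
    """Toy anchored forward process (illustration, not the repo's code).

    tokens:          (B, L) token ids
    importance:      (B, L) scores in [0, 1]; high values mark anchors
    t:               noise level in [0, 1]; t=1 masks (almost) everything
    anchor_slowdown: how much the masking rate is reduced for anchors
    """
    # Anchors (high importance) get a smaller masking probability,
    # so they stay visible longer as t grows.
    mask_prob = t * (1.0 - anchor_slowdown * importance)
    mask = torch.rand_like(importance) < mask_prob
    return tokens.masked_fill(mask, MASK_ID), mask

# Token 99 is marked as an anchor, so it tends to survive longer.
tokens = torch.tensor([[17, 42, 99, 7]])
importance = torch.tensor([[0.1, 0.2, 1.0, 0.1]])
noisy, mask = anchored_forward_mask(tokens, importance, t=0.8)

Reversing this process recovers the two-stage behavior described above: anchors are unmasked first, and the remaining tokens are denoised conditioned on them.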

🔔 News

  • [2025.10.13] ADLM pretraining checkpoints are now live.
  • [2025.10.13] ADLM codebase has been open-sourced.
  • [2025.09.18] ADLM was accepted to NeurIPS 2025 🏆
  • [2025.05.24] Paper posted on arXiv.

Contents

  • Quickstart ⚡
  • Training 🏋️
  • Checkpoints 💾
  • Evaluation 🎯
  • Inference notebook 🧪
  • Baselines 🆚
  • Acknowledgements 🙏
  • Citation 📝

Quickstart ⚡

Create the environment and install the FlashAttention dependency:

conda env create -f requirements.yml
conda activate adlm
pip install flash-attn==2.6.3
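
flash-attn builds can be sensitive to the CUDA/PyTorch combination, so a quick import check (a small sanity sketch, not a script shipped with the repo) is worth running before training:

# Sanity-check the environment (illustrative; not part of the repo).
import torch
import flash_attn

print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("flash-attn:", flash_attn.__version__)  # expect 2.6.3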

Set up output directories for checkpoints and logs:

mkdir -p outputs
mkdir -p watch_folder
mkdir -p ckpts

The project configuration lives under configs/. Zero-shot evaluation datasets can be swapped by selecting the appropriate YAML in configs/data/.


Training 🏋️

Launch ADLM training (Slurm job script):

sbatch scripts/adlm.sh

For quick local debugging, you can launch a single-node run with the following command (feel free to trim overrides as needed):

torchrun \
  adlm_main.py \
  loader.global_batch_size=64 \
  model=small \
  data=openwebtext-split \
  wandb.name=adlm \
  parameterization=subs \
  model.length=1024 \
  eval.compute_generative_perplexity=True \
  sampling.steps=1000 \
  checkpointing.save_dir="outputs/adlm/" \
  trainer.num_nodes=1 \
  trainer.val_check_interval=10000 \
  trainer.log_every_n_steps=1000 \
  trainer.max_steps=1_000_000 \
  checkpointing.resume_from_ckpt=True \
  time_conditioning=False \
  enable_anchor_loss=True \
  base_scaling_factor1=3e-3 \
  base_scaling_factor2=1 \
  threshold=5

Checkpoints 💾

Pretrained ADLM checkpoints are released (see the News section above). Place downloaded checkpoints under ckpts/; the evaluation commands below expect, for example, ckpts/adlm-medium.ckpt.

Evaluation 🎯

MAUVE score, Generative Text Perplexity (Gen PPL), and Entropy 📈

The following job runs sample generation and computes MAUVE, entropy, and generative perplexity:

sbatch scripts/adlm_eval.sh

For quick local debugging without Slurm, you can reuse the same evaluation pipeline with fewer sampling steps (128) and fewer sample batches (10):

checkpoint_path=ckpts/adlm-medium.ckpt
sampling_steps=128
SEED=5
NUM_SAMPLE_BATCHES=10
generated_seqs_path=outputs/adlm_samples-${NUM_SAMPLE_BATCHES}_seed-${SEED}_steps-${sampling_steps}.json

python -u -m adlm_main \
  mode=sample_eval \
  data=openwebtext-split \
  model=small \
  parameterization=subs \
  backbone=dit \
  model.length=1024 \
  eval.checkpoint_path="${checkpoint_path}" \
  loader.batch_size=1 \
  loader.eval_batch_size=1 \
  eval.perplexity_batch_size=1 \
  sampling.steps=${sampling_steps} \
  sampling.num_sample_batches=${NUM_SAMPLE_BATCHES} \
  sampling.generated_seqs_path="${generated_seqs_path}" \
  sampling.sampler=remdm-loop \
  sampling.nucleus_p=0.9 \
  sampling.eta=0.02 \
  sampling.t_on=0.55 \
  sampling.t_off=0.05 \
  sampling.alpha_on=0.9 \
  seed=${SEED} \
  T=0 \
  time_conditioning=false \
  +wandb.offline=true \
  hydra.run.dir="${PWD}/outputs/adlm-${NUM_SAMPLE_BATCHES}-${SEED}"

Following standard practice, we evaluate on 5,000 generated samples. If a Slurm time limit interrupts a run with sampling.num_sample_batches=5000, rerun with a different seed and fewer batches, then merge the resulting sample files as sketched below.
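
The sketch below merges partial sample files before scoring. It assumes each sampling.generated_seqs_path file holds a flat JSON list of generated strings (verify against the schema the repo actually writes), uses the mauve-text package (pip install mauve-text), and the human-reference path is hypothetical:

# Merge partial sample files from interrupted runs, then score with MAUVE.
# Assumes each JSON file is a flat list of generated strings; check the
# schema the repo actually writes before relying on this.
import glob
import json

import mauve  # pip install mauve-text

generated = []
for path in sorted(glob.glob("outputs/adlm_samples-*_seed-*_steps-128.json")):
    with open(path) as f:
        generated.extend(json.load(f))
print(f"merged {len(generated)} samples")

with open("data/human_references.json") as f:  # hypothetical reference file
    references = json.load(f)

out = mauve.compute_mauve(p_text=references, q_text=generated, device_id=0)
print("MAUVE:", out.mauve)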

Zero-shot Perplexity 🎯

Run zero-shot perplexity evaluation (Wikitext-2 by default):

sbatch scripts/adlm_zero_shot_eval.sh

Other zero-shot tasks (LAMBADA, PTB, Wikitext-103, LM1B, AG News, PubMed, arXiv) are supported via the dataset configs in configs/data/. Update the data= override in the script or use Hydra overrides, e.g. data=ptb.
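
To sweep several tasks without editing the script each time, a small driver like the following can submit one job per dataset. It assumes scripts/adlm_zero_shot_eval.sh forwards extra arguments to the Python entry point as Hydra overrides; only data=ptb is confirmed above, so match the other config names against the YAML files in configs/data/ before running:

# Hypothetical sweep over zero-shot benchmarks via Hydra overrides.
# Dataset config names (besides ptb) are guesses; match them to the
# YAML filenames in configs/data/.
import subprocess

datasets = ["lambada", "ptb", "wikitext2", "wikitext103",
            "lm1b", "ag_news", "pubmed", "arxiv"]

for name in datasets:
    # Assumes the Slurm script forwards "$@" to the eval command.
    subprocess.run(
        ["sbatch", "scripts/adlm_zero_shot_eval.sh", f"data={name}"],
        check=True,
    )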


Inference notebook 🧪

Generate samples interactively and compute metrics at a smaller scale:

jupyter notebook notebooks/adlm_inference.ipynb

The notebook accepts the same Hydra overrides and checkpoints as the shell scripts.
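
If you want the same override mechanism in your own notebook cells, Hydra's compose API builds the config programmatically. A minimal sketch (config_path and config_name are assumptions; check the notebook for the exact values used):

# Compose the project config with overrides from inside a notebook.
# config_path and config_name are assumptions; see the repo's configs/.
from hydra import compose, initialize

with initialize(version_base=None, config_path="../configs"):
    cfg = compose(
        config_name="config",
        overrides=[
            "mode=sample_eval",
            "model=small",
            "eval.checkpoint_path=ckpts/adlm-medium.ckpt",
            "sampling.steps=128",
        ],
    )

print(cfg.sampling.steps)  # 128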


Baselines 🆚

Submit baseline jobs to reproduce paper results:

# Autoregressive baseline
sbatch scripts/ar.sh

# Masked Diffusion Language Model
sbatch scripts/mdlm.sh

# Remasking Masked Diffusion Language Model
sbatch scripts/remdm-loop.sh

Acknowledgements 🙏

This codebase is built on top of ReMDM, which in turn builds on MDLM and SEDD.


Citation 📝

@inproceedings{rout2025anchored,
  title={Anchored Diffusion Language Model},
  author={Rout, Litu and Caramanis, Constantine and Shakkottai, Sanjay},
  booktitle={Advances in Neural Information Processing Systems (NeurIPS)},
  year={2025}
}