This repository contains PyTorch/GPU and TorchXLA/TPU implementations of our paper: Diffusion Transformers with Representation Autoencoders. For the JAX/TPU implementation, please refer to diffuse_nnx.
Diffusion Transformers with Representation Autoencoders
Boyang Zheng, Nanye Ma, Shengbang Tong, Saining Xie
New York University
We present Representation Autoencoders (RAE), a class of autoencoders that pair pretrained, frozen representation encoders such as DINOv2 and SigLIP2 with trained ViT decoders. RAE serves as Stage 1 of a two-stage training pipeline for high-fidelity image synthesis: a Stage 2 diffusion model is trained in the latent space of a pretrained RAE to generate images.
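The sketch below illustrates the Stage 1 encoding path under some assumptions (DINOv2-B/14 loaded from torch.hub, a 224×224 input so the patch grid is 16×16, standard ImageNet normalization); it is not the repository's API, and the decoder call at the end is a placeholder:

```python
# Sketch of the RAE encoding path (assumptions: DINOv2-B/14 from torch.hub,
# 224x224 input so the patch grid matches latent_size [768, 16, 16]).
import torch
import torchvision.transforms as T
from PIL import Image

encoder = torch.hub.load("facebookresearch/dinov2", "dinov2_vitb14").eval()
for p in encoder.parameters():            # the representation encoder stays frozen
    p.requires_grad_(False)

preprocess = T.Compose([
    T.Resize(256), T.CenterCrop(224), T.ToTensor(),
    T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
x = preprocess(Image.open("assets/pixabay_cat.png").convert("RGB")).unsqueeze(0)

with torch.no_grad():
    tokens = encoder.forward_features(x)["x_norm_patchtokens"]   # (1, 256, 768)
latent = tokens.transpose(1, 2).reshape(1, 768, 16, 16)          # (C, H, W) latent

# Stage 1 trains a ViT decoder to invert this mapping, e.g. recon = decoder(latent);
# Stage 2 trains a diffusion transformer directly in this latent space.
```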
This repository contains:
PyTorch/GPU:
- A PyTorch implementation of RAE and pretrained weights.
- PyTorch implementations of LightningDiT and DiTDH, along with pretrained weights.
- Training and sampling scripts for the two-stage RAE+DiT pipeline.
TorchXLA/TPU:
- A TPU implementation of RAE and pretrained weights.
- Sampling scripts for RAE and DiTDH on TPU.
- Create the environment and install via uv:
  conda create -n rae python=3.10 -y
  conda activate rae
  pip install uv
  # Install PyTorch 2.2.0 with CUDA 12.1
  uv pip install torch==2.2.0 torchvision==0.17.0 torchaudio --index-url https://download.pytorch.org/whl/cu121
  # Install other dependencies
  uv pip install timm==0.9.16 accelerate==0.23.0 torchdiffeq==0.2.5 wandb
  uv pip install "numpy<2" transformers einops omegaconf
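To confirm the environment picked up the pinned versions, a quick check (nothing here is repository-specific):

```python
# Verify the pinned versions and CUDA availability after installation.
import torch, torchvision, timm, numpy

print(torch.__version__)          # expected: 2.2.0+cu121
print(torchvision.__version__)    # expected: 0.17.0
print(timm.__version__)           # expected: 0.9.16
print(numpy.__version__)          # expected: < 2
print(torch.cuda.is_available())  # should be True on a CUDA 12.1 machine
```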
We release three kinds of models: RAE decoders, DiTDH diffusion transformers, and statistics for latent normalization. To download all models at once:
cd RAE
pip install huggingface_hub
hf download nyu-visionx/RAE-collections \
--local-dir models
To download specific models, run:
hf download nyu-visionx/RAE-collections \
<remote_model_path> \
--local-dir models
- Download ImageNet-1k.
- Point the Stage 1 and Stage 2 scripts to the training split via --data-path (expected layout sketched below).
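The scripts expect the usual ImageNet directory layout with one subfolder per class. Assuming an ImageFolder-style loader (an assumption about the data pipeline, not something stated by the scripts), the path passed to --data-path can be sanity-checked like this:

```python
# Sanity-check the --data-path layout: one subdirectory per class, images inside,
# e.g. <imagenet_train_split>/n01440764/n01440764_10026.JPEG
from torchvision.datasets import ImageFolder
import torchvision.transforms as T

dataset = ImageFolder("<imagenet_train_split>",   # replace with your training split path
                      transform=T.Compose([T.Resize(256), T.CenterCrop(256), T.ToTensor()]))
print(len(dataset), "images across", len(dataset.classes), "classes")  # expect 1,281,167 / 1000
```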
All training and sampling entrypoints are driven by OmegaConf YAML files. A single config describes the Stage 1 autoencoder, the Stage 2 diffusion model, and the solver used during training or inference. A minimal example looks like:
stage_1:
  target: stage1.RAE
  params: { ... }
  ckpt: <path_to_ckpt>
stage_2:
  target: stage2.models.DDT.DiTwDDTHead
  params: { ... }
  ckpt: <path_to_ckpt>
transport:
  params:
    path_type: Linear
    prediction: velocity
    ...
sampler:
  mode: ODE
  params:
    num_steps: 50
    ...
guidance:
  method: cfg/autoguidance
  scale: 1.0
  ...
misc:
  latent_size: [768, 16, 16]
  num_classes: 1000
training:
  ...
- stage_1 instantiates the frozen encoder and trainable decoder. For Stage 1 training you can point to an existing checkpoint via stage_1.ckpt or start from pretrained_decoder_path.
- stage_2 defines the diffusion transformer. During sampling you must provide ckpt; during training you typically omit it so weights initialise randomly.
- transport, sampler, and guidance select the forward/backward SDE/ODE integrator and the optional classifier-free or autoguidance schedule.
- misc collects shapes, class counts, and scaling constants used by both stages.
- training contains defaults that the training scripts consume (epochs, learning rate, EMA decay, gradient accumulation, etc.).
Stage 1 training configs additionally include a top-level gan block that
configures the discriminator architecture and the LPIPS/GAN loss schedule.
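The target/params pairs above follow the common pattern of naming a class plus its constructor arguments. A minimal sketch of loading such a config and instantiating one block with OmegaConf is shown below; the instantiate_from_config helper and the config path are illustrative and may not match what the repo actually ships:

```python
# Load a YAML config and instantiate a `target` class with its `params`.
# Run from the repo root so the target module path resolves.
import importlib
from omegaconf import OmegaConf

def instantiate_from_config(cfg):
    """Illustrative helper: resolve `target` to a class and pass `params` as kwargs."""
    module_name, cls_name = cfg["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_name), cls_name)
    return cls(**cfg.get("params", {}))

config = OmegaConf.load("configs/stage2/training/<config>.yaml")  # placeholder path
stage_2 = instantiate_from_config(config.stage_2)
print(type(stage_2))
```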
We release decoders for DINOv2-B, SigLIP-B, and MAE-B at configs/stage1/pretrained/.
There is also a training config for a ViT-XL decoder on DINOv2-B: configs/stage1/training/DINOv2-B_decXL.yaml.
We release our best model, DiTDH-XL, together with its guidance model; sampling configs for both are at configs/stage2/sampling/.
We also provide training configs for DiTDH at configs/stage2/training/.
src/train_stage1.py fine-tunes the ViT decoder while keeping the
representation encoder frozen. Launch it with PyTorch DDP (single or multi-GPU):
torchrun --standalone --nproc_per_node=N \
src/train_stage1.py \
--config <config> \
--data-path <imagenet_train_split> \
--results-dir results/stage1 \
--image-size 256 --precision bf16/fp32 \
--ckpt <optional_ckpt>
where N refers to the number of available GPUs, and --ckpt resumes training from an existing checkpoint.
Logging. To enable wandb, first set WANDB_KEY, ENTITY, and PROJECT as environment variables:
export WANDB_KEY="key"
export ENTITY="entity name"
export PROJECT="project name"Then in training command add the --wandb flag
Use src/stage1_sample.py to encode/decode a single image:
python src/stage1_sample.py \
--config <config> \
--image assets/pixabay_cat.png
For batched reconstructions and .npz export, run the DDP variant:
torchrun --standalone --nproc_per_node=N \
src/stage1_sample_ddp.py \
--config <config> \
--data-path <imagenet_val_split> \
--sample-dir recon_samples \
--image-size 256
The script writes per-image PNGs as well as a packed .npz suitable for FID.
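To spot-check the packed .npz, something like the following works, assuming the ADM-style convention of storing images as a uint8 array under the arr_0 key (adjust the path and key to whatever the script actually writes):

```python
# Inspect the packed reconstruction .npz and write the first image back to disk.
import numpy as np
from PIL import Image

data = np.load("recon_samples/<npz_file>")   # placeholder; use the file the DDP script wrote
images = data["arr_0"]                       # key name is an assumption (ADM-style layout)
print(images.shape, images.dtype)            # expected: (N, 256, 256, 3), uint8
Image.fromarray(images[0]).save("recon_check.png")
```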
src/train.py trains the Stage 2 diffusion transformer using PyTorch DDP. Edit one of the configs under configs/stage2/training/ and launch:
torchrun --standalone --nnodes=1 --nproc_per_node=N \
src/train.py \
--config <training_config> \
--data-path <imagenet_train_split> \
--results-dir results/stage2 \
--precision bf16
src/sample.py uses the same config schema to draw a small batch of images on a
single device and saves them to sample.png:
python src/sample.py \
--config <sample_config> \
--seed 42
src/sample_ddp.py parallelises sampling across GPUs, producing PNGs and an
FID-ready .npz:
torchrun --standalone --nnodes=1 --nproc_per_node=N \
src/sample_ddp.py \
--config <sample_config> \
--sample-dir samples \
--precision bf16 \
--label-sampling equal
--label-sampling {equal,random}: equal uses exactly 50 images per class for FID-50k; random samples labels uniformly at random. Using equal gives consistently lower FID than random, by around 0.1. We use equal by default.
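The two modes differ only in how the 50k class labels are drawn; a small illustration (not the script's actual code):

```python
# Illustration of --label-sampling: `equal` draws exactly 50 samples per class
# for FID-50k, `random` draws labels uniformly at random.
import torch

num_classes, num_samples = 1000, 50_000

equal_labels = torch.arange(num_classes).repeat_interleave(num_samples // num_classes)
random_labels = torch.randint(num_classes, (num_samples,))

print(torch.bincount(equal_labels, minlength=num_classes).unique())        # tensor([50])
print(torch.bincount(random_labels, minlength=num_classes).float().std())  # nonzero spread
```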
Autoguidance and classifier-free guidance are controlled via the config’s
guidance block.
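Both methods combine two predictions with the same linear rule at each sampling step: classifier-free guidance uses the model's own unconditional prediction as the guide, while autoguidance uses a separate, weaker guidance model. A hedged sketch of the combination rule (the solver plumbing in the repo is more involved than this):

```python
# Guidance as a linear combination of two predictions:
#   cfg:          guide = model(x, t, y_null)      (same model, unconditional branch)
#   autoguidance: guide = weaker_model(x, t, y)    (a separate, weaker guidance model)
# Both then use: v = guide + scale * (v_cond - guide)
import torch

def apply_guidance(v_cond: torch.Tensor, v_guide: torch.Tensor, scale: float) -> torch.Tensor:
    """Extrapolate from the guiding prediction toward the conditional one."""
    return v_guide + scale * (v_cond - v_guide)

# scale = 1.0 recovers the conditional prediction; scale > 1.0 strengthens guidance.
```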
Use the ADM evaluation suite to score generated samples:
- Clone the repo:
  git clone https://github.com/openai/guided-diffusion.git
  cd guided-diffusion/evaluations
- Create an environment and install dependencies:
  conda create -n adm-fid python=3.10
  conda activate adm-fid
  pip install 'tensorflow[and-cuda]'==2.19 scipy requests tqdm
- Download ImageNet statistics (256×256 shown here):
  wget https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/VIRTUAL_imagenet256_labeled.npz
- Evaluate:
  python evaluator.py VIRTUAL_imagenet256_labeled.npz /path/to/samples.npz
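Before launching the evaluator, it can save time to verify the sample batch matches what the ADM suite expects: a uint8 array of shape (N, 256, 256, 3), conventionally stored under arr_0 (the key name is an assumption based on the ADM reference batches):

```python
# Pre-flight check of the FID-ready .npz produced by src/sample_ddp.py.
import numpy as np

batch = np.load("/path/to/samples.npz")
images = batch["arr_0"]                      # key follows the ADM reference-batch convention
assert images.dtype == np.uint8
assert images.ndim == 4 and images.shape[1:] == (256, 256, 3)
print("samples:", images.shape[0])           # 50,000 for FID-50k
```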
See the XLA branch for TPU support.
This code is built upon the following repositories:
- SiT - for diffusion implementation and training codebase.
- DDT - for some of the DiTDH implementation.
- LightningDiT - for the LightningDiT implementation.
- MAE - for the ViT decoder architecture.