Official PyTorch Implementation of "Diffusion Transformers with Representation Autoencoders"

Diffusion Transformers with Representation Autoencoders (RAE)
Official PyTorch Implementation

This repository contains PyTorch/GPU and TorchXLA/TPU implementations of our paper: Diffusion Transformers with Representation Autoencoders. For the JAX/TPU implementation, please refer to diffuse_nnx.

Diffusion Transformers with Representation Autoencoders
Boyang Zheng, Nanye Ma, Shengbang Tong, Saining Xie
New York University

We present Representation Autoencoders (RAE), a class of autoencoders that utilize pretrained, frozen representation encoders such as DINOv2 and SigLIP2 as encoders with trained ViT decoders. RAE can be used in a two-stage training pipeline for high-fidelity image synthesis, where a Stage 2 diffusion model is trained on the latent space of a pretrained RAE to generate images.

This repository contains:

PyTorch/GPU:

  • A PyTorch implementation of RAE and pretrained weights.
  • A PyTorch implementation of LightningDiT, DiTDH and pretrained weights.
  • Training and sampling scripts for the two-stage RAE+DiT pipeline.

TorchXLA/TPU:

  • A TPU implementation of RAE and pretrained weights.
  • Sampling of RAE and DiTDH on TPU.

Environment

Dependency Setup

  1. Create the environment and install dependencies with uv:
    conda create -n rae python=3.10 -y
    conda activate rae
    pip install uv
    
    # Install PyTorch 2.2.0 with CUDA 12.1
    uv pip install torch==2.2.0 torchvision==0.17.0 torchaudio --index-url https://download.pytorch.org/whl/cu121
    
    # Install other dependencies
    uv pip install timm==0.9.16 accelerate==0.23.0 torchdiffeq==0.2.5 wandb
    uv pip install "numpy<2" transformers einops omegaconf

Data & Model Preparation

Download Pre-trained Models

We release three kinds of models: RAE decoders, DiTDH diffusion transformers, and stats for latent normalization. To download all models at once:

cd RAE
pip install huggingface_hub
hf download nyu-visionx/RAE-collections \
  --local-dir models 

To download specific models, run:

hf download nyu-visionx/RAE-collections \
  <remote_model_path> \
  --local-dir models 

Prepare Dataset

  1. Download ImageNet-1k.
  2. Point Stage 1 and Stage 2 scripts to the training split via --data-path.

Config-based Initialization

All training and sampling entrypoints are driven by OmegaConf YAML files. A single config describes the Stage 1 autoencoder, the Stage 2 diffusion model, and the solver used during training or inference. A minimal example looks like:

stage_1:
   target: stage1.RAE
   params: { ... }
   ckpt: <path_to_ckpt>  

stage_2:
   target: stage2.models.DDT.DiTwDDTHead
   params: { ... }
   ckpt: <path_to_ckpt>  

transport:
   params:
      path_type: Linear
      prediction: velocity
      ...
sampler:
   mode: ODE
   params:
      num_steps: 50
      ...
guidance:
   method: cfg/autoguidance
   scale: 1.0
   ...
misc:
   latent_size: [768, 16, 16]
   num_classes: 1000
training:
   ...
  • stage_1 instantiates the frozen encoder and trainable decoder. For Stage 1 training you can point to an existing checkpoint via stage_1.ckpt or start from pretrained_decoder_path.
  • stage_2 defines the diffusion transformer. During sampling you must provide ckpt; during training you typically omit it so weights initialise randomly.
  • transport, sampler, and guidance select the forward/backward SDE/ODE integrator and optional classifier-free or autoguidance schedule.
  • misc collects shapes, class counts, and scaling constants used by both stages.
  • training contains defaults that the training scripts consume (epochs, learning rate, EMA decay, gradient accumulation, etc.).
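
A common way to consume `target`/`params` pairs like the ones above is to import the dotted path and call it with the parameters as keyword arguments. The sketch below shows that pattern on a plain dictionary. `instantiate_from_config` is a hypothetical helper name and may not match what this repository does internally; checkpoint loading via `ckpt` is left out. The demo target is a standard-library class so the snippet runs without the repository installed.

```python
import importlib
from typing import Any, Mapping


def instantiate_from_config(block: Mapping[str, Any]) -> Any:
    """Build the object named by block["target"], passing block["params"] as kwargs.

    Hypothetical helper: the repository's own loader may differ.
    """
    module_path, _, attr_name = block["target"].rpartition(".")
    if not module_path:
        raise ValueError(f"target must be a dotted path, got {block['target']!r}")
    cls = getattr(importlib.import_module(module_path), attr_name)
    return cls(**dict(block.get("params") or {}))


# Stand-in for a config block; a real one would name e.g. "stage1.RAE".
demo_block = {
    "target": "fractions.Fraction",
    "params": {"numerator": 3, "denominator": 4},
}
demo_obj = instantiate_from_config(demo_block)
print(demo_obj)
```

When the config is loaded with OmegaConf, a block such as `cfg.stage_1` can be turned into a plain dictionary with `OmegaConf.to_container(cfg.stage_1, resolve=True)` before being passed to a helper like this.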

Stage 1 training configs additionally include a top-level gan block that configures the discriminator architecture and the LPIPS/GAN loss schedule.

Provided Configs:

Stage1

We release decoders for DINOv2-B, SigLIP-B, and MAE-B at configs/stage1/pretrained/.

There is also a training config for a ViT-XL decoder on DINOv2-B: configs/stage1/training/DINOv2-B_decXL.yaml

Stage2

We release our best model, DiTDH-XL, and its guidance model at both $256\times 256$ and $512\times 512$, at configs/stage2/sampling/.

We also provide training configs for DiTDH at configs/stage2/training/.

Stage 1: Representation Autoencoder

Train the decoder

src/train_stage1.py fine-tunes the ViT decoder while keeping the representation encoder frozen. Launch it with PyTorch DDP (single or multi-GPU):

torchrun --standalone --nproc_per_node=N \
  src/train_stage1.py \
  --config <config> \
  --data-path <imagenet_train_split> \
  --results-dir results/stage1 \
  --image-size 256 --precision bf16/fp32 \
  --ckpt <optional_ckpt>

where N is the number of GPUs available, --precision takes either bf16 or fp32, and --ckpt resumes from an existing checkpoint.

Logging. To enable wandb, first set WANDB_KEY, ENTITY, and PROJECT as environment variables:

export WANDB_KEY="key"
export ENTITY="entity name"
export PROJECT="project name"

Then add the --wandb flag to the training command.
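
A missing variable is easy to overlook before a long run. The following is a small pre-flight check, written as a sketch: the helper name `missing_wandb_vars` is hypothetical and not part of this repository, and it only assumes the three variable names listed above.

```python
import os
from typing import List, Mapping

# Variable names taken from the logging instructions above.
REQUIRED_WANDB_VARS = ("WANDB_KEY", "ENTITY", "PROJECT")


def missing_wandb_vars(env: Mapping[str, str] = os.environ) -> List[str]:
    """Return the required variable names that are unset or empty."""
    return [name for name in REQUIRED_WANDB_VARS if not env.get(name)]


missing = missing_wandb_vars()
if missing:
    print("Set these before passing --wandb:", ", ".join(missing))
else:
    print("wandb environment looks complete")
```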

Sampling/Reconstruction

Use src/stage1_sample.py to encode/decode a single image:

python src/stage1_sample.py \
  --config <config> \
  --image assets/pixabay_cat.png

For batched reconstructions and .npz export, run the DDP variant:

torchrun --standalone --nproc_per_node=N \
  src/stage1_sample_ddp.py \
  --config <config> \
  --data-path <imagenet_val_split> \
  --sample-dir recon_samples \
  --image-size 256

The script writes per-image PNGs as well as a packed .npz suitable for FID.

Stage 2: Latent Diffusion Transformer

Training

src/train.py trains the Stage 2 diffusion transformer using PyTorch DDP. Edit one of the configs under configs/stage2/training/ and launch:

torchrun --standalone --nnodes=1 --nproc_per_node=N \
  src/train.py \
  --config <training_config> \
  --data-path <imagenet_train_split> \
  --results-dir results/stage2 \
  --precision bf16

Sampling

src/sample.py uses the same config schema to draw a small batch of images on a single device and saves them to sample.png:

python src/sample.py \
  --config <sample_config> \
  --seed 42
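
The example config selects a linear path with velocity prediction, integrated as an ODE in 50 steps. As a rough illustration of what that combination means (this is not the repository's sampler, which operates on latent tensors with a trained network), the sketch below assumes the common convention x_t = (1 - t) * noise + t * data, whose velocity is data - noise, and integrates it with fixed-step Euler. `euler_ode_sample` and `oracle_velocity` are hypothetical names; the latter stands in for the trained model.

```python
from typing import Callable, List

Vector = List[float]


def euler_ode_sample(
    velocity_fn: Callable[[Vector, float], Vector],
    x_start: Vector,
    num_steps: int = 50,
) -> Vector:
    """Integrate dx/dt = velocity_fn(x, t) from t=0 to t=1 with fixed-step Euler."""
    x = list(x_start)
    dt = 1.0 / num_steps
    for step in range(num_steps):
        t = step * dt
        v = velocity_fn(x, t)
        x = [xi + dt * vi for xi, vi in zip(x, v)]
    return x


# Toy setup: one noise sample and the data point it should be carried to.
noise = [0.5, -1.0, 2.0]
data = [1.0, 0.0, -1.0]


def oracle_velocity(x: Vector, t: float) -> Vector:
    """Hypothetical stand-in for the network: the exact linear-path velocity."""
    return [d - n for d, n in zip(data, noise)]


sample = euler_ode_sample(oracle_velocity, noise, num_steps=50)
print([round(s, 6) for s in sample])
```

With the exact velocity the path is a straight line, so Euler recovers the data point up to floating-point error. A learned velocity field is only approximately straight, which is why the number of steps matters in practice.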

Distributed sampling for evaluation

src/sample_ddp.py parallelises sampling across GPUs, producing PNGs and an FID-ready .npz:

torchrun --standalone --nnodes=1 --nproc_per_node=N \
  src/sample_ddp.py \
  --config <sample_config> \
  --sample-dir samples \
  --precision bf16 \
  --label-sampling equal

--label-sampling {equal,random}: equal generates exactly 50 images per class for FID-50k; random draws each label uniformly at random. equal consistently gives a lower FID than random, by around 0.1. We use equal by default.
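
The difference between the two modes can be sketched in a few lines. The function names below are hypothetical and the code is only an illustration of the two labelling strategies, not the logic in src/sample_ddp.py.

```python
import random
from collections import Counter
from typing import List


def equal_labels(num_samples: int, num_classes: int) -> List[int]:
    """Assign the same number of samples to every class."""
    if num_samples % num_classes != 0:
        raise ValueError("num_samples must be divisible by num_classes")
    per_class = num_samples // num_classes
    return [c for c in range(num_classes) for _ in range(per_class)]


def random_labels(num_samples: int, num_classes: int, seed: int = 0) -> List[int]:
    """Draw each label independently and uniformly at random."""
    rng = random.Random(seed)
    return [rng.randrange(num_classes) for _ in range(num_samples)]


# FID-50k on ImageNet-1k: 50,000 samples over 1,000 classes.
equal_counts = Counter(equal_labels(50_000, 1000))
random_counts = Counter(random_labels(50_000, 1000))
print("equal  min/max per class:", min(equal_counts.values()), max(equal_counts.values()))
print("random min/max per class:", min(random_counts.values()), max(random_counts.values()))
```

With equal, every class contributes exactly 50 images; with random, per-class counts fluctuate around 50, so the class balance of the generated set differs slightly from run to run.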

Autoguidance and classifier-free guidance are controlled via the config’s guidance block.
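
Both methods are commonly written as the same extrapolation, differing only in where the base prediction comes from. The sketch below assumes that formulation and a scale convention in which 1.0 means no guidance; the repository's exact convention and any schedule over time are not shown here, and `apply_guidance` is a hypothetical name.

```python
from typing import List, Sequence


def apply_guidance(main: Sequence[float], base: Sequence[float], scale: float) -> List[float]:
    """Extrapolate from the base prediction toward (and past) the main one.

    Classifier-free guidance: base is the unconditional prediction of the same model.
    Autoguidance: base is the prediction of a weaker guidance model.
    """
    return [b + scale * (m - b) for m, b in zip(main, base)]


v_main = [1.0, 2.0, -0.5]  # e.g. class-conditional velocity from the main model
v_base = [0.5, 1.0, 0.0]   # e.g. unconditional or guidance-model velocity

unguided = apply_guidance(v_main, v_base, scale=1.0)  # reduces to v_main
guided = apply_guidance(v_main, v_base, scale=2.0)    # pushed further from v_base
print(unguided)
print(guided)
```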

Evaluation

ADM Suite FID setup

Use the ADM evaluation suite to score generated samples:

  1. Clone the repo:

    git clone https://github.com/openai/guided-diffusion.git
    cd guided-diffusion/evaluation
  2. Create an environment and install dependencies:

    conda create -n adm-fid python=3.10
    conda activate adm-fid
    pip install 'tensorflow[and-cuda]'==2.19 scipy requests tqdm
  3. Download ImageNet statistics (256×256 shown here):

    wget https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/VIRTUAL_imagenet256_labeled.npz
  4. Evaluate:

    python evaluator.py VIRTUAL_imagenet256_labeled.npz /path/to/samples.npz

TorchXLA / TPU support

See the XLA branch for TPU support.

Acknowledgement

This code is built upon the following repositories:

  • SiT - for diffusion implementation and training codebase.
  • DDT - for some of the DiTDH implementation.
  • LightningDiT - for the PyTorch Lightning based DiT implementation.
  • MAE - for the ViT decoder architecture.
