Thanks to visit codestin.com
Credit goes to github.com

Skip to content
forked from bytetriper/RAE

Official PyTorch Implementation of "Diffusion Transformers with Representation Autoencoders"

License

Notifications You must be signed in to change notification settings

vukrosic/RAE-learn

 
 

Repository files navigation

Diffusion Transformers with Representation Autoencoders (RAE)
Official PyTorch Implementation

This repository contains PyTorch/GPU and TorchXLA/TPU implementations of our paper: Diffusion Transformers with Representation Autoencoders. For JAX/TPU implementation, please refer to diffuse_nnx

Diffusion Transformers with Representation Autoencoders
Boyang Zheng, Nanye Ma, Shengbang Tong, Saining Xie
New York University

We present Representation Autoencoders (RAE), a class of autoencoders that utilize pretrained, frozen representation encoders such as DINOv2 and SigLIP2 as encoders with trained ViT decoders. RAE can be used in a two-stage training pipeline for high-fidelity image synthesis, where a Stage 2 diffusion model is trained on the latent space of a pretrained RAE to generate images.

This repository contains:

PyTorch/GPU:

  • A PyTorch implementation of RAE and pretrained weights.
  • A PyTorch implementation of LightningDiT, DiTDH and pretrained weights.
  • Training and sampling scripts for the two-stage RAE+DiT pipeline.

TorchXLA/TPU:

  • A TPU implementation of RAE and pretrained weights.
  • Sampling of RAE and DiTDH on TPU.

Environment

Dependency Setup

  1. Create environment and install via uv:
    conda create -n rae python=3.10 -y
    conda activate rae
    pip install uv
    
    # Install PyTorch 2.2.0 with CUDA 12.1
    uv pip install torch==2.2.0 torchvision==0.17.0 torchaudio --index-url https://download.pytorch.org/whl/cu121
    
    # Install other dependencies
    uv pip install timm==0.9.16 accelerate==0.23.0 torchdiffeq==0.2.5 wandb
    uv pip install "numpy<2" transformers einops omegaconf

Data & Model Preparation

Download Pre-trained Models

We release three kind of models: RAE decoders, DiTDH diffusion transformers and stats for latent normalization. To download all models at once:

cd RAE
pip install huggingface_hub
hf download nyu-visionx/RAE-collections \
  --local-dir models 

To download specific models, run:

hf download nyu-visionx/RAE-collections \
  <remote_model_path> \
  --local-dir models 

Prepare Dataset

  1. Download ImageNet-1k.
  2. Point Stage 1 and Stage 2 scripts to the training split via --data-path.

Config-based Initialization

All training and sampling entrypoints are driven by OmegaConf YAML files. A single config describes the Stage 1 autoencoder, the Stage 2 diffusion model, and the solver used during training or inference. A minimal example looks like:

stage_1:
   target: stage1.RAE
   params: { ... }
   ckpt: <path_to_ckpt>  

stage_2:
   target: stage2.models.DDT.DiTwDDTHead
   params: { ... }
   ckpt: <path_to_ckpt>  

transport:
   params:
      path_type: Linear
      prediction: velocity
      ...
sampler:
   mode: ODE
   params:
      num_steps: 50
      ...
guidance:
   method: cfg/autoguidance
   scale: 1.0
   ...
misc:
   latent_size: [768, 16, 16]
   num_classes: 1000
training:
   ...
  • stage_1 instantiates the frozen encoder and trainable decoder. For Stage 1 training you can point to an existing checkpoint via stage_1.ckpt or start from pretrained_decoder_path.
  • stage_2 defines the diffusion transformer. During sampling you must provide ckpt; during training you typically omit it so weights initialise randomly.
  • transport, sampler, and guidance select the forward/backward SDE/ODE integrator and optional classifier-free or autoguidance schedule.
  • misc collects shapes, class counts, and scaling constants used by both stages.
  • training contains defaults that the training scripts consume (epochs, learning rate, EMA decay, gradient accumulation, etc.).

Stage 1 training configs additionally include a top-level gan block that configures the discriminator architecture and the LPIPS/GAN loss schedule.

Provided Configs:

Stage1

We release decoders for DINOv2-B, SigLIP-B, MAE-B, at configs/stage1/pretrained/.

There is also a training script for training a ViT-XL decoder on DINOv2-B: configs/stage1/training/DINOv2-B_decXL.yaml

Stage2

We release our best model, DiTDH-XL and it's guidance model on both $256\times 256$ and $512\times 512$, at configs/stage2/sampling/.

We also provide training configs for DiTDH at configs/stage2/training/.

Stage 1: Representation Autoencoder

Train the decoder

src/train_stage1.py fine-tunes the ViT decoder while keeping the representation encoder frozen. Launch it with PyTorch DDP (single or multi-GPU):

torchrun --standalone --nproc_per_node=N \
  src/train_stage1.py \
  --config <config> \
  --data-path <imagenet_train_split> \
  --results-dir results/stage1 \
  --image-size 256 --precision bf16/fp32 \
  --ckpt <optional_ckpt> \

where N refers to the number of GPU cards available, and --ckpt resumes from an existing checkpoint.

Logging. To enable wandb, firstly set WANDB_KEY, ENTITY, and PROJECT as environment variables:

export WANDB_KEY="key"
export ENTITY="entity name"
export PROJECT="project name"

Then in training command add the --wandb flag

Sampling/Reconstruction

Use src/stage1_sample.py to encode/decode a single image:

python src/stage1_sample.py \
  --config <config> \
  --image assets/pixabay_cat.png \

For batched reconstructions and .npz export, run the DDP variant:

torchrun --standalone --nproc_per_node=N \
  src/stage1_sample_ddp.py \
  --config <config> \
  --data-path <imagenet_val_split> \
  --sample-dir recon_samples \
  --image-size 256

The script writes per-image PNGs as well as a packed .npz suitable for FID.

Stage 2: Latent Diffusion Transformer

Training

src/train.py trains the Stage 2 diffusion transformer using PyTorch DDP. Edit one of the configs under configs/training/ and launch:

torchrun --standalone --nnodes=1 --nproc_per_node=N \
  src/train.py \
  --config <training_config> \
  --data-path <imagenet_train_split> \
  --results-dir results/stage2 \
  --precision bf16

Sampling

src/sample.py uses the same config schema to draw a small batch of images on a single device and saves them to sample.png:

python src/sample.py \
  --config <sample_config> \
  --seed 42

Distributed sampling for evaluation

src/sample_ddp.py parallelises sampling across GPUs, producing PNGs and an FID-ready .npz:

torchrun --standalone --nnodes=1 --nproc_per_node=N \
  src/sample_ddp.py \
  --config <sample_config> \
  --sample-dir samples \
  --precision bf16 \
  --label-sampling equal

--label-sampling {equal,random}: equal uses exactly 50 images per class for FID-50k; random uniformly samples labels. Using equal brings consistently lower FID than random by around 0.1. We use equal by default.

Autoguidance and classifier-free guidance are controlled via the config’s guidance block.

Evaluation

ADM Suite FID setup

Use the ADM evaluation suite to score generated samples:

  1. Clone the repo:

    git clone https://github.com/openai/guided-diffusion.git
    cd guided-diffusion/evaluation
  2. Create an environment and install dependencies:

    conda create -n adm-fid python=3.10
    conda activate adm-fid
    pip install 'tensorflow[and-cuda]'==2.19 scipy requests tqdm
  3. Download ImageNet statistics (256×256 shown here):

    wget https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/VIRTUAL_imagenet256_labeled.npz
  4. Evaluate:

    python evaluator.py VIRTUAL_imagenet256_labeled.npz /path/to/samples.npz

TorchXLA / TPU support

See XLA branch for TPU support.

Acknowledgement

This code is built upon the following repositories:

  • SiT - for diffusion implementation and training codebase.
  • DDT - for some of the DiTDH implementation.
  • LightningDiT - for the PyTorch Lightning based DiT implementation.
  • MAE - for the ViT decoder architecture.

About

Official PyTorch Implementation of "Diffusion Transformers with Representation Autoencoders"

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages

  • Python 100.0%