Amineharrabi/RQ-VAE-Unet-Image-Generation-Model

RQ-Diffusion - Residual Quantized Autoregressive Image Generation

A high-quality image generation system combining Residual Vector Quantization (RQ-VAE) with a Unet for sharp, detailed image synthesis. This project implements a two-stage approach: coarse generation via RQ-VAE followed by detail enhancement through a residual prediction network.

Watch the YouTube video:

https://www.youtube.com/watch?v=4PEAPLvfZFM

Project Overview

This project implements an image generation pipeline that:

  • Encodes images into compact discrete codes using Residual Vector Quantization
  • Generates coarse images from quantized representations
  • Refines outputs with a dedicated detail prediction network
  • Synthesizes images at 256×256 resolution

Key Innovation: Two-Stage Generation

  1. Stage 1 (RQ-VAE): Efficient encoding/decoding with residual quantization
  2. Stage 2 (Refiner): Neural network predicts missing high-frequency details
  3. Result: Sharp, high-quality images from compact representations

🔧 Features

  • Residual Vector Quantization: four residual quantization levels for compact discrete codes
  • Multi-Scale Architecture: 5-level encoder-decoder with residual blocks
  • Perceptual Loss: LPIPS loss for photorealistic outputs
  • Gradient Accumulation: Train with limited VRAM (effective batch size 4 for the RQ-VAE, 8 for the refiner)
  • Automatic Checkpointing: Save to Google Drive for persistence
  • Mixed Precision Training: FP16 for faster training and lower memory usage
  • Progressive Refinement: Separate training stages for the RQ-VAE and the refiner

Requirements

Hardware

  • Minimum: GPU with 8GB+ VRAM (T4, P100)
  • Recommended: GPU with 16GB+ VRAM (V100, A100)
  • Storage: ~15GB free disk space (dataset + checkpoints)

Software

  • Python 3.8+
  • CUDA 11.0+
  • Google Colab (recommended) or local setup
  • Google Drive (for checkpoint persistence)

Installation

Option 1: Google Colab (Recommended)

  1. Open in Colab:

    • Upload the notebook or create a new one
    • Runtime → Change runtime type → GPU (T4 or better)
  2. Mount Google Drive:

    from google.colab import drive
    drive.mount('/content/drive')
  3. Install Dependencies:

    pip install -q torch torchvision torchaudio
    pip install -q einops timm
    pip install -q accelerate safetensors
    pip install -q pillow tqdm wandb
    pip install -q lpips  # For perceptual loss

Option 2: Local Installation

# Clone repository
git clone <your-repo-url>
cd rq-diffusion

# Create virtual environment
python -m venv venv
source venv/bin/activate  # On Windows: venv\Scripts\activate

# Install dependencies
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
pip install einops timm accelerate safetensors pillow tqdm wandb lpips

📊 Dataset Preparation

LAPIS Dataset Acquisition

This project uses the LAPIS dataset by Anne-Sofie Maerten et al.

Dataset Information

  • Source: LAPIS GitHub Repository
  • Size: Varies by subset
  • Format: Password-protected ZIP file
  • Content: High-quality product images

Download Instructions

  1. Get Dataset URL and Password: both are provided in the LAPIS GitHub repository.

  2. Configure Download (in the script):

    # Set the download URL and Zip password
    url = "YOUR_DATASET_URL_HERE"  # From LAPIS GitHub
    password = "YOUR_PASSWORD_HERE"  # From LAPIS GitHub
  3. Automatic Download and Extraction:

    # The script will automatically:
    # 1. Download the ZIP file to /content/data/train.zip
    # 2. Extract it to /content/data/train
    # 3. Verify the extraction

Manual Download (Alternative)

If automatic download fails:

# 1. Download manually from LAPIS repository
wget -O /content/data/train.zip "YOUR_DATASET_URL"

# 2. Extract with password
unzip -P "YOUR_PASSWORD" /content/data/train.zip -d /content/data/train

# 3. Verify
ls -lh /content/data/train

Dataset Structure

After extraction, your dataset should look like:

/content/data/train/
├── image_001.jpg
├── image_002.jpg
├── image_003.png
├── ...
└── image_XXXX.jpg

Custom Dataset

To use your own dataset:

  1. Prepare Images:

    your-dataset/
    ├── image1.jpg
    ├── image2.png
    └── ...
    
  2. Update Dataset Path:

    DATASET_PATH = '/path/to/your-dataset'
  3. Requirements:

    • Images in common formats (JPG, PNG, JPEG)
    • Minimum 1000 images recommended
    • Images will be resized to 256×256

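Before training, it can help to confirm that the dataset folder actually contains enough readable image files. A minimal standard-library sketch; `find_images`, `check_dataset` and `MIN_IMAGES` are hypothetical names (not part of the project script), and the extension set mirrors the JPG/PNG/JPEG formats listed above.

```python
from pathlib import Path

# Hypothetical helpers: extensions match the formats listed in this README
IMAGE_EXTS = {".jpg", ".jpeg", ".png"}
MIN_IMAGES = 1000  # hypothetical threshold taken from the "1000 images" recommendation


def find_images(root):
    """Return a sorted list of image files under `root`, searched recursively."""
    root = Path(root)
    if not root.is_dir():
        raise FileNotFoundError(f"Dataset directory not found: {root}")
    return sorted(p for p in root.rglob("*")
                  if p.is_file() and p.suffix.lower() in IMAGE_EXTS)


def check_dataset(root, min_images=MIN_IMAGES):
    """Print a short summary and return True if the folder holds enough images."""
    images = find_images(root)
    print(f"Found {len(images)} images in {root}")
    if len(images) < min_images:
        print(f"Warning: fewer than {min_images} images; quality may suffer")
    return len(images) >= min_images
```

For example, `check_dataset('/content/data/train')` after extraction. The extension check is case-insensitive, so files such as `IMG_01.JPG` are counted too.
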
Usage

Quick Start (Full Pipeline)

# Run the entire script
python rq_diffusion_3__1__2_.py

Step-by-Step Usage

1. Setup Project Structure

import os
project_root = '/content/drive/MyDrive/rq-ar-project'
os.makedirs(project_root, exist_ok=True)
os.makedirs(f'{project_root}/checkpoints', exist_ok=True)
os.makedirs(f'{project_root}/outputs', exist_ok=True)
os.makedirs('/content/data', exist_ok=True)

2. Download and Prepare Dataset

import zipfile

# Set your credentials (both come from the LAPIS GitHub repository)
url = "YOUR_LAPIS_URL"
password = "YOUR_PASSWORD"
zip_path = '/content/data/train.zip'

# Download and extract (download_file is the download helper from the project script)
download_file(url, zip_path)
with zipfile.ZipFile(zip_path) as zf:
    zf.extractall(path='/content/data/train', pwd=password.encode('utf-8'))

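The snippet above calls `download_file`, the download helper from the project script. To run the snippet on its own, a minimal standard-library stand-in could look like the sketch below. It is hypothetical: the script's own helper may behave differently (for example, by showing a progress bar).

```python
import os
import shutil
import urllib.request


def download_file(url, dest_path, chunk_size=1 << 20):
    """Hypothetical stand-in: stream `url` to `dest_path` and return the byte count."""
    os.makedirs(os.path.dirname(dest_path) or ".", exist_ok=True)
    tmp_path = dest_path + ".part"  # partial downloads stay visibly incomplete
    with urllib.request.urlopen(url) as response, open(tmp_path, "wb") as out:
        shutil.copyfileobj(response, out, chunk_size)  # copy in 1 MiB chunks
    os.replace(tmp_path, dest_path)  # rename only once the download finished
    return os.path.getsize(dest_path)
```

Writing to a `.part` file first means an interrupted Colab session never leaves a truncated `train.zip` that looks complete.
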
3. Train RQ-VAE (Stage 1)

# Initialize model
rqvae = RQVAE(
    codebook_size=16384,    # 2^14 codes per quantizer
    embed_dim=256,          # Embedding dimension
    num_quantizers=4        # Residual depth
).cuda()

# Train
rqvae = train_rqvae_model(
    rqvae,
    dataloader,
    num_epochs=55,
    lr=1e-4,
    checkpoint_dir=f'{project_root}/checkpoints'
)

Training Time: ~6-8 hours on T4 GPU (55 epochs)

4. Train Refiner (Stage 2)

# Initialize refiner
refiner = RefinementNetwork(in_channels=3, out_channels=3).cuda()

# Train
refiner = train_refinement_model(
    refiner,
    dataloader,
    num_epochs=30,
    lr=1e-4,
    perceptual_every=4,
    grad_accum_steps=8,
    checkpoint_dir=f'{project_root}/checkpoints'
)

Training Time: ~3-4 hours on T4 GPU (30 epochs)

5. Generate Sharp Images

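import torch

# Load trained weights from the checkpoint folder created in step 1
ckpt_dir = f'{project_root}/checkpoints'
rqvae.load_state_dict(torch.load(f'{ckpt_dir}/rqvae_epoch_55.pt', map_location='cuda')['model_state_dict'])
refiner.load_state_dict(torch.load(f'{ckpt_dir}/refiner_latest.pt', map_location='cuda')['model_state_dict'])
rqvae.eval()
refiner.eval()

# Generate: input_image is a (1, 3, 256, 256) CUDA tensor scaled to [-1, 1]
with torch.no_grad():
    codes = rqvae.encode(input_image)            # discrete residual codes
    coarse = rqvae.decode_from_codes(codes)      # Stage 1: coarse reconstruction
    residual = refiner(coarse)                   # Stage 2: predicted high-frequency detail
    refined = (coarse + residual).clamp(-1, 1)   # final sharpened image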

Process Single Image

from PIL import Image

# Process any image
refined_output = process_single_image(
    image_path='/path/to/image.jpg',
    save_path=f'{project_root}/outputs/sharp_output.png'
)

Model Architecture

RQ-VAE (Residual Quantized Variational Autoencoder)

Encoder

  • Input: 256×256×3 RGB image
  • Architecture:
    • Initial conv: 3 → 64 channels
    • 5 downsampling blocks (64 → 64 → 128 → 128 → 256 → 256)
    • Resolution reduction: 256×256 → 8×8
    • Output: 256×8×8 feature maps

Residual Quantizer

  • Codebook Size: 16,384 codes per quantizer
  • Embedding Dim: 256
  • Num Quantizers: 4 (residual depth)
  • Process:
    1. Quantize features to the nearest codebook entry
    2. Compute the residual error (features minus the chosen code)
    3. Quantize that residual with the next quantizer, repeating until all 4 levels are used
  • Output: 4×8×8 discrete code indices

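The quantization loop above can be illustrated with a small NumPy sketch. It assumes one separate codebook per depth and plain nearest-neighbour (L2) lookup; the project's quantizer may differ (shared codebooks, EMA updates, straight-through gradients), and the toy sizes are for illustration only.

```python
import numpy as np


def residual_quantize(z, codebooks):
    """Residual vector quantization (hypothetical sketch, one codebook per depth).

    z:         (N, D) encoder features, e.g. N = 8*8 positions, D = embed_dim
    codebooks: list of (K, D) arrays, one per quantization depth
    Returns (codes, z_hat): (depth, N) integer indices and the (N, D) sum of the
    selected code vectors.
    """
    residual = z.astype(np.float64)          # working copy of what is left to encode
    z_hat = np.zeros_like(residual)
    codes = []
    for codebook in codebooks:
        # Squared L2 distances ||r||^2 - 2 r.c + ||c||^2, shape (N, K)
        dists = ((residual ** 2).sum(axis=1, keepdims=True)
                 - 2.0 * residual @ codebook.T
                 + (codebook ** 2).sum(axis=1))
        idx = dists.argmin(axis=1)           # 1. nearest codebook entry per position
        chosen = codebook[idx]
        z_hat += chosen                      # running approximation of z
        residual -= chosen                   # 2. residual error for the next depth
        codes.append(idx)                    # 3. next iteration quantizes that residual
    return np.stack(codes), z_hat


def decode_codes(codes, codebooks):
    """Rebuild quantized features by summing one code vector per depth."""
    return sum(cb[idx] for cb, idx in zip(codebooks, codes))


rng = np.random.default_rng(0)
D, K, depth, N = 16, 64, 4, 8 * 8   # toy sizes; the README's config uses D=256, K=16384, depth 4
codebooks = [rng.normal(scale=1.0 / (d + 1), size=(K, D)) for d in range(depth)]
z = rng.normal(size=(N, D))
codes, z_hat = residual_quantize(z, codebooks)
print(codes.shape)  # (4, 64): depth x positions, a 4x8x8 code map once reshaped
```

Because each depth only encodes what the previous depths missed, decoding is a sum of one vector per level, which is what lets a decoder rebuild features from the index map alone.
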
Decoder

  • Input: 256×8×8 quantized features
  • Architecture:
    • 5 upsampling blocks (256 → 256 → 128 → 128 → 64 → 64)
    • Resolution increase: 8×8 → 256×256
    • Final conv: 64 → 3 channels
    • Tanh activation
  • Output: 256×256×3 RGB image

Refinement Network

Architecture

  • Input: 3-channel coarse image from RQ-VAE
  • Structure:
    • 6 residual blocks with group normalization
    • Skip connections for gradient flow
    • Feature dimension: 128
  • Output: 3-channel residual (high-frequency details)
  • Final Output: Coarse + Residual = Sharp Image

Loss Functions

  • L1 Loss: Pixel-wise reconstruction (weight: 1.0)
  • Perceptual Loss (LPIPS): Feature-space similarity (weight: 0.1)
  • Combined: Balances sharpness and perceptual quality

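The loss weighting above, together with the `perceptual_every` option from the refiner config, can be sketched in NumPy. The LPIPS term is passed in as a callable so the sketch runs without the `lpips` package; `refine` and `refiner_loss` are hypothetical names, and skipping the LPIPS term entirely on the other steps is an assumption about how `perceptual_every` is applied.

```python
import numpy as np

L1_WEIGHT = 1.0          # "l1_weight" in the refiner config
PERCEPTUAL_WEIGHT = 0.1  # "perceptual_weight"
PERCEPTUAL_EVERY = 4     # "perceptual_every"


def refine(coarse, residual):
    """Final output = coarse image + predicted residual, kept in [-1, 1]."""
    return np.clip(coarse + residual, -1.0, 1.0)


def refiner_loss(refined, target, step, perceptual_fn):
    """L1 + weighted LPIPS; LPIPS is only evaluated every PERCEPTUAL_EVERY steps.

    perceptual_fn is a hypothetical stand-in for an LPIPS model returning a float.
    """
    loss = L1_WEIGHT * np.abs(refined - target).mean()
    if step % PERCEPTUAL_EVERY == 0:
        loss += PERCEPTUAL_WEIGHT * perceptual_fn(refined, target)
    return loss
```

Evaluating LPIPS on one step in four keeps most of its benefit while avoiding a full feature-network forward pass on every step.
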
Training Configuration

RQ-VAE Training

CONFIG = {
    # Model
    "codebook_size": 16384,        # Codes per quantizer
    "embed_dim": 256,              # Embedding dimension
    "num_quantizers": 4,           # Residual depth

    # Training
    "num_epochs": 55,              # Total epochs
    "batch_size": 1,               # Per-GPU batch size
    "learning_rate": 1e-4,         # Initial LR
    "warmup_epochs": 5,            # LR warmup period

    # Loss weights
    "recon_weight": 1.0,           # Reconstruction loss
    "commitment_weight": 0.25,     # Commitment loss
    "perceptual_weight": 0.1,      # LPIPS loss

    # Optimization
    "grad_accum_steps": 4,         # Gradient accumulation
    "save_every": 5,               # Checkpoint frequency
}

Refiner Training

CONFIG = {
    # Training
    "num_epochs": 30,              # Total epochs
    "batch_size": 1,               # Per-GPU batch size
    "learning_rate": 1e-4,         # Initial LR

    # Loss weights
    "l1_weight": 1.0,              # L1 reconstruction
    "perceptual_weight": 0.1,      # LPIPS perceptual
    "perceptual_every": 4,         # Compute LPIPS every N steps

    # Optimization
    "grad_accum_steps": 8,         # Effective batch: 8
    "save_every": 1,               # Save every epoch
}

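Gradient accumulation is what turns `batch_size: 1` into an effective batch of `batch_size × grad_accum_steps` (1 × 4 = 4 for the RQ-VAE, 1 × 8 = 8 for the refiner). The NumPy sketch below checks the idea on a toy least-squares problem: accumulating the scaled gradients of 8 single-sample micro-batches gives the same update as one batch of 8. It is an illustration only, not the project's training loop.

```python
import numpy as np


def grad_mse(w, x, y):
    """Gradient of mean((x @ w - y)^2) with respect to w."""
    return 2.0 * x.T @ (x @ w - y) / len(y)


rng = np.random.default_rng(0)
x = rng.normal(size=(8, 3))
y = rng.normal(size=8)
w = np.zeros(3)
lr, grad_accum_steps = 0.1, 8

# One real batch of 8 samples
full_batch_update = w - lr * grad_mse(w, x, y)

# Eight micro-batches of size 1; each gradient is scaled by 1/grad_accum_steps
# (like `loss / grad_accum_steps` before `backward()` in PyTorch) and summed
accumulated = np.zeros_like(w)
for i in range(grad_accum_steps):
    accumulated += grad_mse(w, x[i:i + 1], y[i:i + 1]) / grad_accum_steps
accum_update = w - lr * accumulated  # optimizer step only after all micro-batches

print(np.allclose(full_batch_update, accum_update))  # True
```

The equivalence holds when the loss is a per-sample mean and no layer mixes statistics across the batch; group normalization, which the refiner uses, normalizes each sample independently, so it does not break this.
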
Performance Metrics

Memory Usage

Stage            | VRAM (FP16) | Batch Size | Grad Accum
RQ-VAE Training  | ~12.5 GB    | 1          | 4
Refiner Training | ~6.2 GB     | 1          | 8
Inference        | ~2.5 GB     | 1          | -

Model Sizes

Component | Parameters | Size (FP32) | Size (FP16)
Encoder   | ~15M       | ~60 MB      | ~30 MB
Decoder   | ~15M       | ~60 MB      | ~30 MB
Quantizer | ~67M       | ~268 MB     | ~134 MB
Refiner   | ~8M        | ~32 MB      | ~16 MB
Total     | ~105M      | ~420 MB     | ~210 MB

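The size columns follow from bytes per parameter (4 for FP32, 2 for FP16). The sketch also estimates codebook parameters from the quantizer settings, assuming each of the `num_quantizers` depths has its own plain `codebook_size × embed_dim` table; if the project's quantizer shares codebooks or stores extra buffers (for example EMA statistics), its count will differ. Both helpers are hypothetical.

```python
def size_mb(num_params, bytes_per_param):
    """Storage size in megabytes (10^6 bytes) for a parameter count."""
    return num_params * bytes_per_param / 1e6


def codebook_params(codebook_size, embed_dim, num_quantizers, shared=False):
    """Parameters in the codebooks, assuming plain embedding tables (hypothetical)."""
    tables = 1 if shared else num_quantizers
    return tables * codebook_size * embed_dim


total = 105e6
print(size_mb(total, 4), size_mb(total, 2))   # 420.0 210.0
print(codebook_params(16384, 256, 4))         # 16777216 with 4 separate tables
```
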
Training Time (16 GB VRAM)

Stage   | Epochs | GPU | Time
RQ-VAE  | 55     | T4  | ~15-18 hours
Refiner | 30     | T4  | ~6-8 hours
Total   | -      | T4  | ~21-26 hours

Quality Metrics (Example)

MSE (Coarse):  0.012450
MSE (Refined): 0.008320
Improvement:   33.17%

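The improvement figure is the relative MSE reduction, (0.012450 − 0.008320) / 0.012450 ≈ 33.17%. The evaluation code in this README computes PSNR as 10·log10(4 / MSE) (images in [-1, 1], so the squared value range is 2² = 4), so the same numbers can also be read as a PSNR gain. A small sketch, assuming both MSE values were measured on [-1, 1] images; the helper names are hypothetical.

```python
import math


def relative_improvement(mse_before, mse_after):
    """Percentage reduction in MSE."""
    return 100.0 * (mse_before - mse_after) / mse_before


def psnr_from_mse(mse, data_range=2.0):
    """PSNR in dB; data_range=2.0 assumes images scaled to [-1, 1]."""
    return 10.0 * math.log10(data_range ** 2 / mse)


mse_coarse, mse_refined = 0.012450, 0.008320
print(f"Improvement: {relative_improvement(mse_coarse, mse_refined):.2f}%")  # 33.17%
print(f"PSNR gain:   {psnr_from_mse(mse_refined) - psnr_from_mse(mse_coarse):.2f} dB")
```

The PSNR gain depends only on the ratio of the two MSE values, so it does not change with the assumed data range.
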
Troubleshooting

Common Issues

1. CUDA Out of Memory

Symptoms:

RuntimeError: CUDA out of memory

Solutions:

# Reduce batch size (already at 1)
# Increase gradient accumulation
CONFIG["grad_accum_steps"] = 16

# Reduce perceptual loss frequency
perceptual_every = 8

# Use smaller model
CONFIG["embed_dim"] = 128
CONFIG["num_quantizers"] = 2

2. Dataset Download Fails

Solutions:

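# Check URL and password (repr() reveals stray whitespace or quotes)
print(f"URL: {url!r}")
print(f"Password: {password!r}")

# Try manual download
# Download from browser, then upload to Colab

# Check file integrity (a tiny file is usually an HTML error page, not the ZIP)
import os
zip_path = '/content/data/train.zip'
print(f"File size: {os.path.getsize(zip_path) / 1e9:.2f} GB")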

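Printing the file size only tells you that something was downloaded. A stricter check opens the archive and lets `zipfile` verify every member's CRC. A minimal standard-library sketch; `check_zip` is a hypothetical helper, and note that Python's `zipfile` only decrypts the traditional ZipCrypto scheme, not AES-encrypted archives.

```python
import os
import zipfile


def check_zip(zip_path, password=None):
    """Hypothetical helper: return (ok, message) for a possibly encrypted ZIP file."""
    if not os.path.exists(zip_path):
        return False, f"{zip_path} does not exist"
    size_gb = os.path.getsize(zip_path) / 1e9
    if not zipfile.is_zipfile(zip_path):
        # Often an HTML error page saved under a .zip name
        return False, f"{zip_path} ({size_gb:.2f} GB) is not a ZIP archive"
    try:
        with zipfile.ZipFile(zip_path) as zf:
            if password is not None:
                zf.setpassword(password.encode("utf-8"))
            bad_member = zf.testzip()  # reads every member and checks its CRC
            count = len(zf.namelist())
    except (RuntimeError, NotImplementedError, zipfile.BadZipFile) as exc:
        # RuntimeError: wrong password; NotImplementedError: unsupported method (e.g. AES)
        return False, f"Could not read archive: {exc}"
    if bad_member is not None:
        return False, f"Corrupted member: {bad_member}"
    return True, f"{count} files, {size_gb:.2f} GB, all CRCs OK"
```

For example, `print(check_zip('/content/data/train.zip', password))` before extracting.
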
3. Google Drive Disconnects

Solutions:

# Re-mount periodically
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

# Use local checkpoints for speed
checkpoint_dir = '/content/checkpoints'

# Copy to Drive periodically
import shutil
shutil.copy('/content/checkpoints/model.pt',
            '/content/drive/MyDrive/rq-ar-project/checkpoints/')

4. Training Instability

Symptoms:

  • Loss becomes NaN
  • Gradients explode
  • Model outputs artifacts

Solutions:

# Lower learning rate
lr = 5e-5

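# Enable gradient clipping (already enabled); with mixed precision,
# unscale the gradients first so the 1.0 threshold applies to their true values
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)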

# Check data normalization
# Images should be in range [-1, 1]

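The decoder ends in a Tanh, so model inputs and targets live in [-1, 1]. A NumPy sketch of the round trip between 8-bit images and that range; the helper names are hypothetical and the project's own transforms may differ (torchvision's `ToTensor()` followed by `Normalize(mean=0.5, std=0.5)` produces the same mapping).

```python
import numpy as np


def to_model_range(img_uint8):
    """Hypothetical helper: HWC uint8 image in [0, 255] -> CHW float32 in [-1, 1]."""
    x = img_uint8.astype(np.float32) / 255.0      # [0, 1]
    return (x * 2.0 - 1.0).transpose(2, 0, 1)     # [-1, 1], channels first


def to_uint8(x):
    """Hypothetical helper: CHW array in [-1, 1] (e.g. Tanh output) -> HWC uint8."""
    x = np.clip(x, -1.0, 1.0)
    return np.round((x + 1.0) * 127.5).astype(np.uint8).transpose(1, 2, 0)
```

If `x.min()` or `x.max()` of a training batch falls outside [-1, 1], the data pipeline is not producing the range the Tanh decoder can represent.
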
5. Slow Training

Solutions:

# Enable mixed precision (already enabled)
scaler = torch.cuda.amp.GradScaler()  # recent PyTorch versions: torch.amp.GradScaler('cuda')

# Reduce perceptual loss frequency
perceptual_every = 8

# Use DataLoader workers
from torch.utils.data import DataLoader
dataloader = DataLoader(dataset, num_workers=2, pin_memory=True)

# Enable cudnn benchmark
torch.backends.cudnn.benchmark = True

Error Messages Reference

Error                                      | Cause                                         | Fix
ZipFile: Bad password                      | Wrong dataset password                        | Check LAPIS GitHub for the correct password
FileNotFoundError: train/                  | Dataset not extracted                         | Verify extraction path and contents
RuntimeError: Expected all tensors on cuda | Model/data device mismatch                    | Call .cuda() on both model and data
KeyError: 'model_state_dict'               | Checkpoint corrupted or saved without that key | Re-train or use a different checkpoint
ValueError: too many values to unpack      | Wrong tensor shape                            | Check input dimensions (should be 256×256)

Advanced Configuration

Custom Image Size

# Note: Requires architecture changes for other sizes
# Current: 256×256 → 8×8 latent (32x downsampling)

# For 512×512:
# Keeping the 5 downsampling levels gives a 16×16 latent (4× more codes per image)
# Adding a sixth downsampling level to the encoder/decoder keeps the latent at 8×8

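Each downsampling level halves the spatial size, so the latent side length is `image_size / 2**num_levels`. A small sketch of that arithmetic, also counting discrete codes per image (latent positions × `num_quantizers`); `latent_side` and `num_codes` are illustrative helpers, not project functions.

```python
def latent_side(image_size, num_levels=5):
    """Spatial size of the latent grid after `num_levels` stride-2 downsamplings."""
    factor = 2 ** num_levels
    if image_size % factor:
        raise ValueError(f"{image_size} is not divisible by {factor}")
    return image_size // factor


def num_codes(image_size, num_levels=5, num_quantizers=4):
    """Discrete code indices per image: latent positions times residual depth."""
    side = latent_side(image_size, num_levels)
    return side * side * num_quantizers


print(latent_side(256))                  # 8   (current setup: 32x downsampling)
print(latent_side(512))                  # 16  (same five levels)
print(latent_side(512, num_levels=6))    # 8   (one extra level)
print(num_codes(256))                    # 256 codes = 4 x 8 x 8
```
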
Different Codebook Settings

# Smaller memory footprint
CONFIG = {
    "codebook_size": 8192,      # Fewer codes
    "num_quantizers": 2,        # Less depth
    "embed_dim": 128,           # Smaller embeddings
}

# Higher quality
CONFIG = {
    "codebook_size": 32768,     # More codes
    "num_quantizers": 8,        # More depth
    "embed_dim": 512,           # Larger embeddings
}

Training Resumption

# Resume RQ-VAE training
rqvae = train_rqvae_model(
    rqvae,
    dataloader,
    num_epochs=100,
    resume_from='latest'  # or 'epoch_55'
)

# Resume Refiner training
refiner = train_refinement_model(
    refiner,
    dataloader,
    num_epochs=50,
    resume_from='latest'
)

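`resume_from='latest'` relies on the checkpoint naming shown under Project Structure (`rqvae_latest.pt`, `rqvae_epoch_55.pt`, ...). If you need to locate the newest numbered checkpoint yourself, for example after copying files between Drive and local disk, a standard-library sketch follows; `find_latest_checkpoint` is hypothetical and not part of the project script.

```python
import re
from pathlib import Path


def find_latest_checkpoint(checkpoint_dir, prefix="rqvae"):
    """Hypothetical helper: (epoch, path) of the highest `<prefix>_epoch_<N>.pt`, or None."""
    pattern = re.compile(rf"^{re.escape(prefix)}_epoch_(\d+)\.pt$")
    best = None
    for path in Path(checkpoint_dir).glob(f"{prefix}_epoch_*.pt"):
        match = pattern.match(path.name)
        if match:
            epoch = int(match.group(1))  # compare numerically, not as strings
            if best is None or epoch > best[0]:
                best = (epoch, path)
    return best
```

Comparing epochs as integers matters once training passes 99 epochs: a plain string sort would rank `epoch_55` above `epoch_100`.
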
🔬 Model Evaluation

Visualize Results

# Compare original vs coarse vs refined
rqvae.eval()
refiner.eval()
with torch.no_grad():
    # With batch_size=1 the loader yields one image per batch; build the loader
    # with batch_size=8 to fill all 8 columns
    sample_images = next(iter(dataloader))[:8].cuda()

    codes = rqvae.encode(sample_images)
    coarse = rqvae.decode_from_codes(codes)
    residual = refiner(coarse)
    refined = (coarse + residual).clamp(-1, 1)

    # Plot side-by-side
    import matplotlib.pyplot as plt
    fig, axes = plt.subplots(3, 8, figsize=(24, 9))
    # ... plotting code ...

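The plotting step is elided above. One way to fill it is to tile the three rows (original, coarse, refined) into a single image. A NumPy sketch, assuming each input is an array of shape (N, 3, H, W) in [-1, 1] (call `.cpu().numpy()` on the tensors first); `make_comparison_grid` is a hypothetical helper.

```python
import numpy as np


def make_comparison_grid(rows, pad=2):
    """Hypothetical helper: tile (N, 3, H, W) arrays in [-1, 1] into one HWC uint8 image.

    Each entry of `rows` becomes one grid row (e.g. original, coarse, refined).
    """
    n, _, h, w = rows[0].shape
    grid = np.full((len(rows) * (h + pad) + pad, n * (w + pad) + pad, 3), 255, np.uint8)
    for r, batch in enumerate(rows):
        imgs = np.round((np.clip(batch, -1, 1) + 1) * 127.5).astype(np.uint8)
        for c in range(n):
            top, left = pad + r * (h + pad), pad + c * (w + pad)
            grid[top:top + h, left:left + w] = imgs[c].transpose(1, 2, 0)
    return grid
```

Display the result with `plt.imshow(grid)` followed by `plt.axis('off')`; a single image avoids managing 24 separate axes.
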
Calculate Metrics

import torch
import torch.nn.functional as F

# MSE
mse_coarse = F.mse_loss(coarse, sample_images)
mse_refined = F.mse_loss(refined, sample_images)

# PSNR for images in [-1, 1]: the value range is 2, so MAX² = 4
psnr_coarse = 10 * torch.log10(4 / mse_coarse)
psnr_refined = 10 * torch.log10(4 / mse_refined)

# Perceptual
import lpips
lpips_fn = lpips.LPIPS(net='alex').cuda()
lpips_score = lpips_fn(refined, sample_images).mean()

print(f"PSNR (Coarse):  {psnr_coarse:.2f} dB")
print(f"PSNR (Refined): {psnr_refined:.2f} dB")
print(f"LPIPS:          {lpips_score:.4f}")

Training Curves

# View saved history
import json
with open(f'{project_root}/checkpoints/training_history.json', 'r') as f:
    history = json.load(f)

# Plot losses
import matplotlib.pyplot as plt
plt.figure(figsize=(15, 5))

plt.subplot(1, 3, 1)
plt.plot(history['epochs'], history['recon_loss'])
plt.title('Reconstruction Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')

plt.subplot(1, 3, 2)
plt.plot(history['epochs'], history['perceptual_loss'])
plt.title('Perceptual Loss')

plt.subplot(1, 3, 3)
plt.plot(history['epochs'], history['total_loss'])
plt.title('Total Loss')

plt.tight_layout()
plt.show()

Contributing

Contributions welcome! Areas for improvement:

  • Architecture: Support for larger images (512×512, 1024×1024)
  • Quantization: Experiment with different codebook strategies
  • Training: Multi-GPU distributed training
  • Inference: Faster decoding methods
  • Applications: Conditional generation, style transfer
  • Optimization: Model compression, pruning, distillation

License

This project is provided for research and educational purposes. Please cite the original papers listed under References.

🙏 Acknowledgments

  • LAPIS Dataset: Anne-Sofie Maerten et al. for the high-quality dataset
  • PyTorch Team: For the deep learning framework
  • Google Colab: For free GPU access
  • LPIPS: For perceptual loss implementation
  • Research Community: For open-source implementations and papers

References

Papers

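  1. VQ-VAE: van den Oord et al., "Neural Discrete Representation Learning" (NeurIPS 2017)
  2. VQ-VAE-2: Razavi et al., "Generating Diverse High-Fidelity Images with VQ-VAE-2" (NeurIPS 2019)
  3. RQ-VAE: Lee et al., "Autoregressive Image Generation using Residual Quantization" (CVPR 2022)
  4. LPIPS: Zhang et al., "The Unreasonable Effectiveness of Deep Features as a Perceptual Metric" (CVPR 2018)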

💡 Tips for Best Results

  1. Dataset Quality: Use diverse, high-quality images (min 1000 images)
  2. Training Duration:
    • RQ-VAE: 50-100 epochs for convergence
    • Refiner: 30-50 epochs sufficient
  3. Learning Rate: Start with 1e-4, reduce if training unstable
  4. Codebook Size: Larger codebooks can improve quality but use more memory (16K codes is the default here)
  5. Perceptual Loss: Essential for photorealism, but expensive to compute
  6. Gradient Accumulation: Enables larger effective batch sizes on limited VRAM
  7. Checkpointing: Save frequently, since Google Drive can disconnect
  8. Resolution: The architecture is designed for 256×256 inputs

🔍 Project Structure

rq-ar-project/
├── checkpoints/               # Model checkpoints (Google Drive)
│   ├── rqvae_epoch_55.pt
│   ├── rqvae_latest.pt
│   ├── refiner_epoch_30.pt
│   ├── refiner_latest.pt
│   └── training_history.json
├── outputs/                   # Generated images
│   ├── final_comparison.png
│   └── sharp_output.png
└── data/                      # Training data (local Colab disk at /content/data, not on Drive)
    └── train/
        ├── image_001.jpg
        ├── image_002.jpg
        └── ...

📧 Support

For issues, questions, or suggestions:

  • Open an issue on GitHub
  • Check troubleshooting section
  • Review closed issues for solutions

Happy Generating!
