Imagen is a text-to-image generative diffusion model that operates in pixel-space. This repository contains the necessary tools and scripts for performantly training Imagen from base model to its superresolution models in JAX on GPUs.
Prompts:
- A raccoon wearing a hat and leather jacket in front of a backyard window. There are raindrops on the window
- A blue colored pizza
- a highly detailed digital painting of a portal in a mystic forest with many beautiful trees. A person is standing in front of the portal.
For maximum flexibility and low disk requirements, this repo supports a distributed architecture for text embedding in diffusion model training. Upon launching training, it spawns LLM inference servers that compute text embeddings online, with no latency hit to training. It does this by creating several inference clients in the diffusion model trainer's dataloaders, which send embedding requests to the inference servers. These servers are based on NVIDIA PyTriton, so all requests are executed in batches. Currently, the inference server supports T5x LLMs, but it can be swapped for any other backend (it doesn't even have to be JAX!) since the diffusion model trainer's client simply makes PyTriton (HTTP) calls.
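The client/server split described above can be illustrated with a toy, in-process stand-in. This sketch does not use PyTriton or T5x: `ToyEmbeddingServer`, `submit`, and `fake_embed` are hypothetical names, and the "embedding" is just the prompt length. It only shows the pattern of clients enqueueing single requests while the server groups whatever is pending into one batched call.

```python
import queue
import threading
from concurrent.futures import Future


class ToyEmbeddingServer:
    """In-process stand-in for an embedding server that batches requests."""

    def __init__(self, embed_fn, max_batch=8):
        self._embed_fn = embed_fn      # maps list[str] -> list of embeddings
        self._max_batch = max_batch
        self._requests = queue.Queue()
        threading.Thread(target=self._serve, daemon=True).start()

    def submit(self, text):
        """Client side: enqueue one request and return a Future for its result."""
        future = Future()
        self._requests.put((text, future))
        return future

    def _serve(self):
        while True:
            # Block for the first request, then drain whatever else is queued.
            batch = [self._requests.get()]
            while len(batch) < self._max_batch:
                try:
                    batch.append(self._requests.get_nowait())
                except queue.Empty:
                    break
            embeddings = self._embed_fn([text for text, _ in batch])
            for (_, future), embedding in zip(batch, embeddings):
                future.set_result(embedding)


def fake_embed(texts):
    # Hypothetical placeholder for the LLM: one "embedding" per prompt.
    return [[float(len(t))] for t in texts]


server = ToyEmbeddingServer(fake_embed)
futures = [server.submit(p) for p in ["A blue colored pizza", "a portal"]]
embeddings = [f.result(timeout=5) for f in futures]
```

Because the trainer side only holds futures, the embedding backend can be replaced without touching the dataloader code, which is the property the paragraph above attributes to the PyTriton-based design.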
We provide scripts to run interactively or on SLURM.
We provide a fully built and ready-to-use container here: ghcr.io/nvidia/t5x:imagen-2023-10-02.
We do not currently support custom-built container workflows, but we are actively working on this; stay tuned for updates! Imagen will also be available in our T5x container in future releases.
This model accepts webdataset-format datasets for training. For reference, we have an ImageNet webdataset example here. (NOTE: Imagen is not directly compatible with ImageNet.) For Imagen training with a compatible dataset, you can find or create your own webdataset (with image and text modalities).
Once you have your webdataset, update the dataset configs {base, sr1, sr2} with the paths to your dataset(s) under MIXTURE_OR_TASK_NAME.
The 'img-txt-ds' configs assume a webdataset with a text and an image modality. The images are in jpg format and the text is raw text in a '.txt' file. Currently, the configs are set up to do resolution-based filtering, scale-preserved square random cropping, and low-resolution image generation for SR model training. This can be changed (e.g. if you want your text in .json format or want to do additional processing) in the dataset configuration files {base, sr1, sr2}.
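As an illustration of the layout these configs expect, the sketch below writes a tiny shard using only Python's standard `tarfile` module. It relies on the webdataset convention that files sharing a basename form one sample and the extension names the modality. The function names, the shard name, and the placeholder image bytes are made up for this example; a real pipeline would write actual JPEG data.

```python
import io
import os
import tarfile
import tempfile


def write_shard(path, samples):
    """Write (key, jpeg_bytes, caption) samples into one webdataset-style tar.

    Files that share a basename (the key) form one sample: '<key>.jpg' holds
    the image and '<key>.txt' holds the raw caption text.
    """
    with tarfile.open(path, "w") as tar:
        for key, jpeg_bytes, caption in samples:
            members = (
                (f"{key}.jpg", jpeg_bytes),
                (f"{key}.txt", caption.encode("utf-8")),
            )
            for name, payload in members:
                info = tarfile.TarInfo(name=name)
                info.size = len(payload)
                tar.addfile(info, io.BytesIO(payload))


def list_members(path):
    """Return the member names of a shard in storage order."""
    with tarfile.open(path, "r") as tar:
        return tar.getnames()


# Hypothetical shard name and placeholder bytes (NOT a decodable image).
shard_path = os.path.join(tempfile.mkdtemp(), "shard-000000.tar")
placeholder_jpeg = b"\xff\xd8\xff\xd9"
write_shard(shard_path, [("000000", placeholder_jpeg, "A blue colored pizza")])
members = list_members(shard_path)
```

Listing the members of an existing shard this way is a quick check that each sample carries both a `.jpg` and a `.txt` entry before pointing the dataset configs at it.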
You will need to acquire the LLM checkpoint for T5 (for multimodal training) from T5x here. All models use the T5 1.1 format T5-xxl by default. Once you have the checkpoint, place it at rosetta/projects/inference_serving/checkpoints/checkpoint_1000000_t5_1_1_xxl (appending _{size} to the checkpoint folder name). NOTE: We're working on adding TransformerEngine support to the inference server, but for now, please run with the DISABLE_TE=True environment variable (the example scripts include this).
Note: this should only be done with single-node jobs.
CONTAINER=ghcr.io/nvidia/t5x:imagen-2023-10-02
docker run --rm --gpus=all -it --net=host --ipc=host -v ${PWD}:/opt/rosetta -v ${DATASET_PATH}:/mnt/datasets --privileged $CONTAINER bash
Pretraining can be done on multiple GPUs within 1 host with scripts/singlenode_inf_train.sh. This will build an Imagen model with the Adam optimizer and relevant parameters. It will also launch the relevant LLM inference servers.
#### Pretraining (interactive: already inside container) with example args
bash rosetta/projects/imagen/scripts/singlenode_inf_train.sh {DATASET NAME} {MODEL NAME} {PRECISION} {NUM GPUS} {BSIZE/GPU} {LOGDIR} {MODEL DIR} {NUM LLM INFERENCE GPUS} {INFERENCE SERVER LLM SIZE}
#### Pretraining (non-interactive)
docker run --rm --gpus=all --net=host --ipc=host -v ${DATASET_PATH}:/mnt/datasets $CONTAINER bash rosetta/projects/imagen/scripts/singlenode_inf_train.sh {args from above}
For a SLURM+pyxis cluster, the scripts/example_slurm_inf_train.sub file provides an example SLURM submit file (edit with your details), which calls scripts/multinode_train.sh and scripts/specialized_run.py to execute training.
All commands below assume you are in $ROSETTA_DIR=/opt/rosetta and have the scripts and slurm scripts locally.
Arguments are passed as follows:
sbatch -N {NODE_CT} rosetta/projects/imagen/scripts/example_slurm_inf_train.sub \
{DATASET NAME} {MODEL NAME} {PRECISION} {NUM GPUS / NODE} {BSIZE/GPU} {MODEL DIR} {NUM LLM INFERENCE GPUS} {INFERENCE SERVER LLM SIZE}
All parameters can be found in the relevant script.
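When choosing the node count, it can help to check how the arguments combine. The example submissions in this section are consistent with the inference GPUs being taken out of the same allocation and the remaining GPUs each training on {BSIZE/GPU} samples, which reproduces the global batch size of 2048 quoted with the benchmarks. That reading is inferred from the numbers rather than stated by the scripts, and the helper name below is made up for illustration.

```python
def training_layout(nodes, gpus_per_node, bsize_per_gpu, num_inference_gpus):
    """Return (training GPUs, global batch size) for a submission.

    Assumes the LLM inference GPUs are carved out of the same node
    allocation; this is an inference from the examples, not a documented rule.
    """
    total_gpus = nodes * gpus_per_node
    train_gpus = total_gpus - num_inference_gpus
    if train_gpus <= 0:
        raise ValueError("no GPUs left for training after inference servers")
    return train_gpus, train_gpus * bsize_per_gpu


# Values taken from the example submissions in this section.
base_500m = training_layout(nodes=14, gpus_per_node=8, bsize_per_gpu=32,
                            num_inference_gpus=48)
base_2b = training_layout(nodes=20, gpus_per_node=8, bsize_per_gpu=16,
                          num_inference_gpus=32)
```

Both examples leave a global batch size of 2048, with 64 and 128 training GPUs respectively.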
The examples below assume nodes with 8 A100/H100 80GB GPUs each.
sbatch -N 14 rosetta/projects/imagen/scripts/example_slurm_inf_train.sub \
{DATASET} imagen_base_500M bfloat16 8 32 runs/imagen-base 48 xxl
sbatch -N 20 rosetta/projects/imagen/scripts/example_slurm_inf_train.sub \
{DATASET} imagen_base_2B bfloat16 8 16 runs/imagen-base 32 xxl
sbatch -N 14 rosetta/projects/imagen/scripts/example_slurm_inf_train.sub \
{DATASET} imagen_sr1_efficientunet_600M bfloat16 8 32 runs/imagen-sr1 48 xxl
sbatch -N 14 rosetta/projects/imagen/scripts/example_slurm_inf_train.sub \
{DATASET} imagen_sr2_efficientunet_600M bfloat16 8 32 runs/imagen-sr2 48 xxl
You can find example sampling scripts that use the 500M base model and EfficientUnet SR models in scripts. Prompts should be specified as in example.
Defaults to imagen_256_sample.gin config (can be adjusted in script)
CUDA_VISIBLE_DEVICES=<DEVICES> CFG=5.0 BASE_PATH=<BASE_CKPT> SR1_PATH=<SR1_CKPT> PROMPT_TEXT_FILES=<FILE> ./rosetta/projects/imagen/scripts/sample_imagen_256.sh
Defaults to imagen_1024_sample.gin config (can be adjusted in script).
CUDA_VISIBLE_DEVICES=<DEVICES> CFG=5.0 BASE_PATH=<BASE_CKPT> SR1_PATH=<SR1_CKPT> SR2_PATH=<SR2_CKPT> PROMPT_TEXT_FILES=<FILE> ./rosetta/projects/imagen/scripts/sample_imagen_1024.sh
Global Batch size = 2048. We assume 2.5B Training examples in these calculations. LLM Inference server nodes are not included in these numbers.
| size | GPU | Precision | #GPUs | BS / GPU | Images/Sec | Im/Sec/GPU | Est. Walltime (hr) | GPU-days | Config |
|---|---|---|---|---|---|---|---|---|---|
| Imagen-base-500M | A100-80G-SXM | BF16 | 8 | 64 | 858 | 107.0 | 809 | 269 | cfg |
| Imagen-base-500M | A100-80G-SXM | BF16 | 32 | 64 | 3056 | 95.5 | 227 | 303 | cfg |
| Imagen-base-2B | A100-80G-SXM | BF16 | 8 | 16 | 219 | 27.4 | 3170 | 1057 | cfg |
| Imagen-base-2B | A100-80G-SXM | BF16 | 32 | 16 | 795 | 24.8 | 873 | 1164 | cfg |
| Imagen-base-2B | A100-80G-SXM | BF16 | 128 | 16 | 2934 | 22.9 | 236 | 1258 | cfg |
| Imagen-SR1-600M-EffUNet | A100-80G-SXM | BF16 | 8 | 64 | 674 | 84.3 | 1030 | 343 | cfg |
| Imagen-SR1-600M-EffUNet | A100-80G-SXM | BF16 | 32 | 64 | 2529 | 79.1 | 274 | 365 | cfg |
| Imagen-SR2-600M-EffUNet | A100-80G-SXM | BF16 | 8 | 64 | 678 | 84.8 | 1024 | 341 | cfg |
| Imagen-SR2-600M-EffUNet | A100-80G-SXM | BF16 | 32 | 64 | 2601 | 81.3 | 267 | 356 | cfg |
| Imagen-SR1-430M-UNet | A100-80G-SXM | BF16 | 8 | 16 | 194 | 24.3 | 3580 | 1193 | cfg |
Imagen-SR1-430M-UNet is not currently supported (coming soon!). You can use the sr1-efficient-unet instead.
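The walltime and GPU-day columns appear to follow directly from the measured throughput and the stated 2.5B training examples. A minimal sketch of that arithmetic is shown here; the function name is made up, and the table's own rounding may differ slightly from these raw values.

```python
TRAINING_EXAMPLES = 2.5e9  # training examples assumed in the table


def estimate_cost(images_per_sec, num_gpus, examples=TRAINING_EXAMPLES):
    """Return (walltime in hours, GPU-days) for one pass over `examples`."""
    hours = examples / images_per_sec / 3600
    gpu_days = hours * num_gpus / 24
    return hours, gpu_days


# First table row: Imagen-base-500M on 8 GPUs at 858 images/sec.
hours, gpu_days = estimate_cost(images_per_sec=858, num_gpus=8)
```

For that row, the result is a little over 809 hours and just under 270 GPU-days, in line with the table entries. The same function can be used to project cost for a throughput you measure on your own cluster.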
Imagen base 500M + Efficient SR1 (600M):
| cfg | FID-30K (256x256) |
|---|---|
| 2 | 11.30 |
| 3 | 10.23 |
| 4 | 11.33 |
| 6 | 12.34 |
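The `cfg` column, like the `CFG` variable in the sampling commands, presumably refers to the classifier-free guidance weight; the text does not define it, so treat that reading as an assumption. Under it, each denoising step combines a text-conditioned and an unconditioned noise prediction as sketched below. The function and argument names are illustrative only and are not taken from this repository, and plain Python lists stand in for image-shaped arrays.

```python
def classifier_free_guidance(eps_cond, eps_uncond, weight):
    """Blend conditional and unconditional noise predictions.

    weight = 0 returns the unconditional prediction, weight = 1 the
    conditional one, and weight > 1 pushes further toward the text condition.
    """
    return [u + weight * (c - u) for c, u in zip(eps_cond, eps_uncond)]


# Toy two-element "predictions" standing in for full image tensors.
eps_cond = [1.0, 2.0]
eps_uncond = [0.0, 1.0]
guided = classifier_free_guidance(eps_cond, eps_uncond, weight=5.0)
```

Larger weights generally trade sample diversity for closer adherence to the prompt, which may be why FID in the table is best at an intermediate value rather than at the largest one.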
- The nightly images currently cannot run Imagen because they lack a patch that still needs refactoring. This will be released soon!


