Official Implementation of Denoising Diffusion Bridge Models.
To install all packages in this codebase along with their dependencies, run
```
pip install torch==2.1.0+cu121 torchvision==0.16.0+cu121 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu121
pip install packaging ninja
conda install -c conda-forge mpi4py openmpi
pip install -e .
```

We provide pretrained checkpoints via our Huggingface repo here. It includes models trained on two image-to-image datasets using Variance-Preserving (VP) schedules:
- DDBM on Edges2Handbags (VP): ddbm_e2h_vp_ema.pt
- DDBM on DIODE (VP): ddbm_diode_vp_ema.pt
For Edges2Handbags, please follow the instructions from here. For DIODE, please download the appropriate dataset from here.
We provide bash files train_ddbm.sh and sample_ddbm.sh for model training and sampling.
Simply set the variables `DATASET_NAME` and `SCHEDULE_TYPE`:
- `DATASET_NAME` specifies which dataset to use. We only support `e2h` for Edges2Handbags and `diode` for DIODE. For each dataset, make sure to set the respective `DATA_DIR` variable in `args.sh` to your dataset path.
- `SCHEDULE_TYPE` denotes the noise schedule type. Only `ve` and `vp` are recommended; `ve_simple` and `vp_simple` are their naive baselines.
To train, run
```
bash train_ddbm.sh $DATASET_NAME $SCHEDULE_TYPE

# To resume, set CKPT to your checkpoint, or it will automatically resume from your last checkpoint based on your experiment name.
bash train_ddbm.sh $DATASET_NAME $SCHEDULE_TYPE $CKPT
```
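For example, to train on Edges2Handbags with the VP schedule (assuming the `e2h` `DATA_DIR` is already set in `args.sh`), an invocation would look like the following; the checkpoint path in the resume line is only a placeholder:

```
# Train DDBM on Edges2Handbags with the VP noise schedule.
bash train_ddbm.sh e2h vp

# Resume from an earlier run (placeholder checkpoint path).
bash train_ddbm.sh e2h vp /path/to/your/checkpoint.pt
```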
For inference, additional variables need to be set:
- `MODEL_PATH` is the checkpoint to be evaluated.
- `CHURN_STEP_RATIO` is the ratio of steps used for the stochastic Euler step (see the paper for details). The recommended default is `0.33`; lower values generally degrade performance. For better value settings, please refer to the paper.
- `GUIDANCE` is the `w` parameter specified in the paper. The recommended default is `1` for VP schedules, and anything less than `1` produces significantly worse results. For VE schedules, this value (ranging from `0` to `1`) does not affect generation much; for better value settings, please refer to the paper.
- `SPLIT` denotes which split to use for testing. Only `train` and `test` are supported.

To sample, run
```
bash sample_ddbm.sh $DATASET_NAME $SCHEDULE_TYPE $MODEL_PATH $CHURN_STEP_RATIO $GUIDANCE $SPLIT
```
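For example, sampling from the Edges2Handbags VP checkpoint on the test split with the recommended defaults might look like this (the checkpoint path is illustrative; point it at wherever you saved `ddbm_e2h_vp_ema.pt`):

```
# Illustrative checkpoint location; adjust to your own path.
bash sample_ddbm.sh e2h vp ./checkpoints/ddbm_e2h_vp_ema.pt 0.33 1 test
```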
This script will aggregate all samples into an `.npz` file in your experiment folder, ready for quantitative evaluation.
One can evaluate samples with `evaluations/evaluator.py`. We also provide the reference statistics in our Huggingface repo:
- Reference stats for Edges2Handbags: e2h_ref_stats.npz.
- Reference stats for DIODE: diode_ref_stats.npz.
To evaluate, set `REF_PATH` to the path of your reference stats and `SAMPLE_PATH` to the path of your generated `.npz` file. You can additionally specify the metric to use via `--metric`; we only support `fid` and `lpips`.
```
python evaluations/evaluator.py $REF_PATH $SAMPLE_PATH --metric $YOUR_METRIC
```
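As a concrete, purely illustrative example, assuming the reference stats were downloaded into the current directory and your samples were aggregated into an `.npz` file in your experiment folder:

```
# Illustrative paths; substitute your own reference stats and generated .npz file.
python evaluations/evaluator.py e2h_ref_stats.npz ./workdir/e2h_vp_exp/samples.npz --metric fid
```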
We noticed that on some machines `mpiexec` errors out with
```
--------------------------------------------------------------------------
MPI_INIT has failed because at least one MPI process is unreachable
from another. This *usually* means that an underlying communication
plugin -- such as a BTL or an MTL -- has either not loaded or not
allowed itself to be used. Your MPI job will now abort.
You may wish to try to narrow down the problem;
* Check the output of ompi_info to see which BTL/MTL plugins are
  available.
* Run your application with MPI_THREAD_SINGLE.
* Set the MCA parameter btl_base_verbose to 100 (or mtl_base_verbose,
  if using MTL-based communications) to see exactly which
  communication plugins were considered and/or discarded.
--------------------------------------------------------------------------
```
In this case, you can try adding `--mca btl vader,self` to the `mpiexec` command, before the `python` invocation.
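For instance, if your training or sampling script launches via `mpiexec`, the flag goes right after `mpiexec` and before the `python` part of the command. The process count and script name below are placeholders for whatever the bash script already runs:

```
# Placeholder launch command: keep the process count and python invocation
# your train/sample script already uses, and only add the --mca flag.
mpiexec --mca btl vader,self -n 8 python <your_training_or_sampling_script> ...
```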
During evaluation, if you see significantly high LPIPS or MSE scores, this is likely due to a mismatch in ordering between your generated samples and the reference stats. This can happen when the multiprocess gathering of results returns samples in the wrong order. Please make sure the ordering of your generated samples is correct, or regenerate the reference stats yourself.
If you find this method and/or code useful, please consider citing
```
@article{zhou2023denoising,
  title={Denoising diffusion bridge models},
  author={Zhou, Linqi and Lou, Aaron and Khanna, Samar and Ermon, Stefano},
  journal={arXiv preprint arXiv:2309.16948},
  year={2023}
}
```