Paper | Project Page | Notebook Demo | Models
This repo contains the official PyTorch implementation of Navigation World Models, including the training code for the Conditional Diffusion Transformer (CDiT) model. See the project page for additional results.
# Navigation World Models

Amir Bar, Gaoyue "Kathy" Zhou, Danny Tran, Trevor Darrell, Yann LeCun

AI at Meta, UC Berkeley, New York University
## Setup

First, download and set up the repo:

```
git clone https://github.com/facebookresearch/nwm
cd nwm
```

## Data

To download and preprocess data, please follow the steps from NoMaD, specifically:
- Download the datasets
- Change the preprocessing resolution from (160, 120) to (320, 240) for higher resolution (illustrated in the sketch after this list)
- Run `process_bags.py` and `process_recon.py` to save each processed dataset to `path/to/nwm_repo/data/<dataset_name>`
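The resolution change lives inside the NoMaD preprocessing scripts, so the exact variable names there differ, but the effect is simply a larger resize target. A minimal PIL-based sketch of the idea (the `IMAGE_SIZE` constant and `preprocess_frame` helper here are hypothetical):

```python
from PIL import Image

# Hypothetical names; the NoMaD scripts organize this differently.
IMAGE_SIZE = (320, 240)  # raised from the default (160, 120)

def preprocess_frame(src_path: str, dst_path: str) -> None:
    """Resize one extracted frame to the target resolution and save it."""
    img = Image.open(src_path).convert("RGB")
    img.resize(IMAGE_SIZE, Image.LANCZOS).save(dst_path)
```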
For SACSoN/HuRoN, we used a private version that contains higher-resolution images. Please contact the dataset authors for access (we're unable to distribute it).
Finally, you should have the following structure:
```
nwm/data
├── <dataset_name>
│   ├── <name_of_traj1>
│   │   ├── 0.jpg
│   │   ├── 1.jpg
│   │   ├── ...
│   │   ├── T_1.jpg
│   │   └── traj_data.pkl
│   ├── <name_of_traj2>
│   │   ├── 0.jpg
│   │   ├── 1.jpg
│   │   ├── ...
│   │   ├── T_2.jpg
│   │   └── traj_data.pkl
│   ...
│   └── <name_of_trajN>
│       ├── 0.jpg
│       ├── 1.jpg
│       ├── ...
│       ├── T_N.jpg
│       └── traj_data.pkl
```
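After preprocessing, a quick way to sanity-check this layout is to walk the tree and confirm each trajectory folder holds frames plus a `traj_data.pkl`. A minimal sketch (it makes no assumptions about the pickle's internal fields):

```python
import pickle
from pathlib import Path

DATA_ROOT = Path("data")  # relative to the repo root

for dataset in sorted(p for p in DATA_ROOT.iterdir() if p.is_dir()):
    for traj in sorted(p for p in dataset.iterdir() if p.is_dir()):
        frames = sorted(traj.glob("*.jpg"))
        pkl_file = traj / "traj_data.pkl"
        assert frames, f"no frames in {traj}"
        assert pkl_file.exists(), f"missing traj_data.pkl in {traj}"
        with open(pkl_file, "rb") as f:
            pickle.load(f)  # contents vary by dataset; just check it loads
        print(f"{dataset.name}/{traj.name}: {len(frames)} frames")
```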
## Environment

Create the environment and install dependencies:

```
mamba create -n nwm python=3.10
mamba activate nwm
pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu126
mamba install ffmpeg
pip3 install decord einops evo transformers diffusers tqdm timm notebook dreamsim torcheval lpips ipywidgets
```
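After installing, a quick import check can catch a broken setup early (on a CPU-only machine the last line prints `False`):

```python
import torch
import decord  # video decoding; listed in the dependencies above

print(torch.__version__)
print(torch.cuda.is_available())
```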
## Training

Using torchrun:

```
export NUM_NODES=8
export HOST_NODE_ADDR=<HOST_ADDR>
export CURR_NODE_RANK=<NODE_RANK>

torchrun \
    --nnodes=${NUM_NODES} \
    --nproc-per-node=8 \
    --node-rank=${CURR_NODE_RANK} \
    --rdzv-backend=c10d \
    --rdzv-endpoint=${HOST_NODE_ADDR}:29500 \
    train.py --config config/nwm_cdit_xl.yaml --ckpt-every 2000 --eval-every 10000 --bfloat16 1 --epochs 300 --torch-compile 0
```

Or using submitit and SLURM (8 machines of 8 GPUs):
```
python submitit_train_cw.py --nodes 8 --partition <partition_name> --qos <qos> --config config/nwm_cdit_xl.yaml --ckpt-every 2000 --eval-every 10000 --bfloat16 1 --epochs 300 --torch-compile 0
```

Or locally on one GPU for debugging:
```
python train.py --config config/nwm_cdit_xl.yaml --ckpt-every 2000 --eval-every 10000 --bfloat16 1 --epochs 300 --torch-compile 0
```

Note: `torch.compile` can make training roughly 40% faster. However, it might lead to instabilities and inconsistent behavior across different PyTorch versions. Use carefully.
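For reference, gating compilation behind a flag, as the `--torch-compile` option above suggests, typically looks like the sketch below; this is a generic pattern, not the repo's exact code:

```python
import torch

def maybe_compile(model: torch.nn.Module, use_compile: bool) -> torch.nn.Module:
    """Optionally wrap the model with torch.compile (PyTorch >= 2.0)."""
    if use_compile:
        # fullgraph=False allows graph breaks instead of hard errors.
        return torch.compile(model, fullgraph=False)
    return model
```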
## Evaluation

To use a pretrained CDiT/XL model:

- Download a pretrained model from Hugging Face (a minimal download sketch follows this list)
- Place the checkpoint in `./logs/nwm_cdit_xl/checkpoints`
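A minimal download sketch using `huggingface_hub`; the repo id and filename below are placeholders, not confirmed names, so take the actual values from the Models link at the top of this README:

```python
from huggingface_hub import hf_hub_download

# Placeholder repo id and filename: substitute the actual values
# from the "Models" link above.
ckpt_path = hf_hub_download(
    repo_id="<hf_org>/<nwm_model_repo>",
    filename="<checkpoint_name>.pt",
    local_dir="logs/nwm_cdit_xl/checkpoints",
)
print(ckpt_path)
```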
Set a directory to save evaluation results:

```
export RESULTS_FOLDER=/path/to/res_folder/
```
First, save the ground-truth data (`--gt 1`):

```
python isolated_nwm_infer.py \
    --exp config/nwm_cdit_xl.yaml \
    --datasets recon,scand,sacson,tartan_drive \
    --batch_size 96 \
    --num_workers 12 \
    --eval_type time \
    --output_dir ${RESULTS_FOLDER} \
    --gt 1
```

Then generate model predictions from a checkpoint:

```
python isolated_nwm_infer.py \
    --exp config/nwm_cdit_xl.yaml \
    --ckp 0100000 \
    --datasets <dataset_name> \
    --batch_size 64 \
    --num_workers 12 \
    --eval_type time \
    --output_dir ${RESULTS_FOLDER}
```

Finally, compute metrics against the cached ground truth:

```
python isolated_nwm_eval.py \
    --datasets <dataset_name> \
    --gt_dir ${RESULTS_FOLDER}/gt \
    --exp_dir ${RESULTS_FOLDER}/nwm_cdit_xl \
    --eval_types time
```

Results are saved in `${RESULTS_FOLDER}/nwm_cdit_xl/<dataset_name>`.
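The dependency list includes `lpips` and `dreamsim`, so the evaluation presumably compares predicted frames against ground truth with perceptual metrics, among others. A standalone sketch of computing LPIPS between two frames (not the repo's evaluation code; random tensors stand in for real images):

```python
import lpips
import torch

loss_fn = lpips.LPIPS(net="alex")  # perceptual distance, lower is better

# LPIPS expects (N, 3, H, W) tensors scaled to [-1, 1].
pred = torch.rand(1, 3, 240, 320) * 2 - 1
gt = torch.rand(1, 3, 240, 320) * 2 - 1
print(loss_fn(pred, gt).item())
```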
For rollout evaluation, first save the ground-truth rollouts at the desired frame rates:

```
python isolated_nwm_infer.py \
    --exp config/nwm_cdit_xl.yaml \
    --datasets recon,scand,sacson,tartan_drive \
    --batch_size 96 \
    --num_workers 12 \
    --eval_type rollout \
    --output_dir ${RESULTS_FOLDER} \
    --gt 1 \
    --rollout_fps_values 1,4
```

Then generate model rollouts from a checkpoint:

```
python isolated_nwm_infer.py \
    --exp config/nwm_cdit_xl.yaml \
    --ckp 0100000 \
    --datasets <dataset_name> \
    --batch_size 64 \
    --num_workers 12 \
    --eval_type rollout \
    --output_dir ${RESULTS_FOLDER} \
    --rollout_fps_values 1,4
```

Finally, compute rollout metrics:

```
python isolated_nwm_eval.py \
    --datasets recon \
    --gt_dir ${RESULTS_FOLDER}/gt \
    --exp_dir ${RESULTS_FOLDER}/nwm_cdit_xl \
    --eval_types rollout
```

Results are saved in `${RESULTS_FOLDER}/nwm_cdit_xl/<dataset_name>`.
## Planning

Using 1-step Cross-Entropy Method (CEM) planning on 8 GPUs (sampling 120 trajectories):

```
torchrun --nproc-per-node=8 planning_eval.py \
    --exp config/nwm_cdit_xl.yaml \
    --datasets recon \
    --rollout_stride 1 \
    --batch_size 1 \
    --num_samples 120 \
    --topk 5 \
    --num_workers 12 \
    --output_dir ${RESULTS_FOLDER} \
    --save_preds \
    --ckp 0100000 \
    --opt_steps 1 \
    --num_repeat_eval 3
```

Results are saved in `${RESULTS_FOLDER}/nwm_cdit_xl/<dataset_name>`.
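For intuition: the Cross-Entropy Method repeatedly samples candidate action sequences, scores them (here, by how well the world model's imagined outcome matches the goal), and refits the sampling distribution to the top-k candidates. A generic, self-contained sketch with a user-supplied score function (not the repo's implementation):

```python
import numpy as np

def cem_plan(score_fn, horizon, action_dim, num_samples=120, topk=5,
             opt_steps=1, seed=0):
    """Generic Cross-Entropy Method over action sequences.

    score_fn maps an array of shape (num_samples, horizon, action_dim)
    to one score per sample (higher is better), e.g. the negative distance
    between a world model's predicted final state and the goal.
    """
    rng = np.random.default_rng(seed)
    mu = np.zeros((horizon, action_dim))
    sigma = np.ones((horizon, action_dim))
    for _ in range(opt_steps):
        samples = rng.normal(mu, sigma, size=(num_samples, horizon, action_dim))
        scores = score_fn(samples)
        elites = samples[np.argsort(scores)[-topk:]]  # top-k candidates
        mu, sigma = elites.mean(axis=0), elites.std(axis=0) + 1e-6
    return mu  # planned action sequence
```

With `--opt_steps 1` as above, this reduces to sampling once and averaging the top-k trajectories.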
## Citation

```
@article{bar2024navigation,
    title={Navigation world models},
    author={Bar, Amir and Zhou, Gaoyue and Tran, Danny and Darrell, Trevor and LeCun, Yann},
    journal={arXiv preprint arXiv:2412.03572},
    year={2024}
}
```

## Acknowledgements

We thank Noriaki Hirose for his help with the HuRoN dataset and for sharing his insights, and Manan Tomar, David Fan, Sonia Joseph, Angjoo Kanazawa, Ethan Weber, Nicolas Ballas, and the anonymous reviewers for their helpful discussions and feedback.
## License

The code and model weights are licensed under Creative Commons Attribution-NonCommercial 4.0 International. See LICENSE.txt for details.