Jasmine is a production-ready JAX-based world modeling codebase. It currently implements the high-level architecture of Genie: Generative Interactive Environments (Bruce et al., 2024) with MaskGIT (Chang et al., 2022), as well as an autoregressive (causal) baseline. A diffusion baseline is coming soon.
Jasmine scales from single hosts to hundreds of xPUs thanks to XLA and strives to be an easily hackable, batteries-included foundation for world modeling research.
- Asynchronous & distributed checkpointing thanks to `orbax.checkpoint`
  - Hardware topologies can be mixed and matched (e.g. train on four nodes, load the checkpoint on a single node)
- Optimized dataloading thanks to Grain
  - Dataloading scales with the number of processes (i.e. nodes/xPUs)
- Checkpointing of model weights, optimizer and dataloader states
- Full reproducibility with identical training curves (thanks to seeded dataloading and training, and JAX's approach to pseudo-random numbers)
- Automatic checkpoint deletion/retention according to the specified retention policy thanks to `orbax.checkpoint.CheckpointManager`
- Mixed-precision training using `bfloat16` (`int8` training is on the roadmap via `aqt`)
- FlashAttention thanks to cuDNN SDPA
- Frame-level KV-cache resets for accelerated spatiotemporal attention in the causal baseline (still in PR)
- Activation checkpointing (even onto host memory if desired)
- DDP (switching to FSDP requires changing a single line of code)
- WSD learning rate schedule
  - No need to retrain from scratch if you want to train for longer
- Index-shuffling during dataloading
- Google-native stack
  - https://github.com/google/orbax for checkpointing
  - https://github.com/google/grain for dataloading
  - https://github.com/google-deepmind/dm_pix for image manipulation
  - https://github.com/google/array_record as the data format
- Easy model inspection thanks to `treescope`
- Modularized training script for easy inspection using notebooks (demo notebook)
- Easy model surgery thanks to the new `flax.nnx` API
- Shape suffixes throughout the repository
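The reproducibility guarantee above rests on JAX's explicit handling of pseudo-random state: a PRNG key is an ordinary array value derived from a seed, so the same seed always yields the same random stream. A minimal sketch in plain JAX (illustrative, not Jasmine-specific code):

```python
import jax

# JAX PRNG state is an explicit key value, not hidden global state.
key = jax.random.key(42)
step_key, key = jax.random.split(key)
noise_a = jax.random.normal(step_key, (2, 3))

# Re-deriving the key from the same seed reproduces identical values,
# which is what makes seeded training runs bit-for-bit repeatable.
step_key2, _ = jax.random.split(jax.random.key(42))
noise_b = jax.random.normal(step_key2, (2, 3))
assert (noise_a == noise_b).all()
```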
Jasmine requires python 3.10, jax 0.6.2, and flax 0.10.7. To install the requirements, run:

```bash
pip install -r requirements.txt
pre-commit install
```
You can either download our preprocessed dataset from Hugging Face or preprocess OpenAI's VPT dataset manually.
The easiest way to get started is to download our preprocessed dataset from Hugging Face. This script handles downloading and extracting it:

```bash
bash input_pipeline/download/download_array_records.sh
```
If you prefer to use the raw VPT dataset from OpenAI and preprocess it yourself, follow these steps:
- **Download index files**: This will download the initial index file:

  ```bash
  bash input_pipeline/download/openai/download_index_files.sh
  ```

- **Download from all index files**: This may take a long time depending on your bandwidth:

  ```bash
  python input_pipeline/download/openai/download_videos.py --index_file_path data/open_ai_index_files/all_7xx_Apr_6.json
  python input_pipeline/download/openai/download_videos.py --index_file_path data/open_ai_index_files/all_8xx_Jun_29.json
  python input_pipeline/download/openai/download_videos.py --index_file_path data/open_ai_index_files/all_9xx_Jun_29.json
  python input_pipeline/download/openai/download_videos.py --index_file_path data/open_ai_index_files/all_10xx_Jun_29.json
  ```

- **Preprocess videos into ArrayRecords**: For efficient distributed training, convert the raw videos into the ArrayRecord format (make sure to have `ffmpeg` installed on your machine):

  ```bash
  python input_pipeline/preprocess/video_to_array_records.py
  ```
Note: This is a large dataset and may take considerable time and storage to download and process.
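Once the data is in ArrayRecord form, the dataloader can shuffle record indices deterministically and shard them across processes. The following stdlib-only sketch shows the idea behind seeded, sharded index-shuffling (illustrative only; the function and its seed mixing are not Grain's actual API):

```python
import random

def shard_indices(num_records, process_index, process_count, seed, epoch):
    # Illustrative only, not Grain's API: every process derives the same
    # seeded global permutation, then keeps a disjoint slice of it, so
    # dataloading scales with the number of processes yet stays reproducible.
    indices = list(range(num_records))
    # Naive seed mixing for illustration: a new permutation per epoch.
    random.Random(seed * 100003 + epoch).shuffle(indices)
    return indices[process_index::process_count]

# Two processes together cover every record exactly once.
shard0 = shard_indices(10, process_index=0, process_count=2, seed=42, epoch=0)
shard1 = shard_indices(10, process_index=1, process_count=2, seed=42, epoch=0)
assert sorted(shard0 + shard1) == list(range(10))
```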
Genie has three components: a video tokenizer, a latent action model, and a dynamics model. Each of these components is trained separately; however, the dynamics model requires a pre-trained video tokenizer (and latent action model).
To train the video tokenizer, run:

```bash
python train_tokenizer.py --ckpt_dir <path>
```

To train the latent action model, run:

```bash
python train_lam.py --ckpt_dir <path>
```

Once the tokenizer and LAM are trained, the dynamics model can be trained with:

```bash
python train_dynamics.py --tokenizer_checkpoint <path> --lam_checkpoint <path>
```
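The dependency between the three stages can be sketched as follows. All class and method names here are illustrative stand-ins, not Jasmine's actual API; the stubs only show which data each stage produces and consumes:

```python
# Illustrative stand-ins for the Genie components; the real models are
# neural networks, these stubs only mimic the data dependencies.
class Tokenizer:
    def encode(self, frames):
        # Real tokenizer: compress each frame into discrete token ids.
        return [hash(f) % 1024 for f in frames]

class LatentActionModel:
    def infer_actions(self, frames):
        # Real LAM: infer a latent action between consecutive frames.
        return list(zip(frames, frames[1:]))

def dynamics_inputs(frames, tokenizer, lam):
    # The dynamics model predicts the next frame's tokens from past
    # tokens and inferred latent actions, which is why it needs both
    # pre-trained components before its own training can start.
    return tokenizer.encode(frames), lam.infer_actions(frames)

tokens, actions = dynamics_inputs(["f0", "f1", "f2"], Tokenizer(), LatentActionModel())
assert len(tokens) == 3 and len(actions) == 2
```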
Logging with `wandb` is supported. To enable logging, set the `WANDB_API_KEY` environment variable or run:

```bash
wandb login
```

Training can then be logged by setting the `--log` flag:

```bash
python train_tokenizer.py --log --entity <wandb-entity> --project <wandb-project>
```
Jasmine was built by Mihir Mahajan, Alfred Nguyen and Franz Srambical, but started as a fork of Jafar, built by Matthew Jackson and Timon Willi.
If you use Jasmine in your work, please cite us, Jafar, and the original Genie paper as follows:
```bibtex
@article{mahajan2025jasmine,
  title={Jasmine: A simple, performant and scalable JAX-based world modeling codebase},
  author={Mihir Mahajan and Alfred Nguyen and Franz Srambical and Stefan Bauer},
  journal={p(doom) blog},
  year={2025},
  url={https://pdoom.org/jasmine.html},
  note={https://pdoom.org/blog.html}
}

@inproceedings{willi2024jafar,
  title={Jafar: An Open-Source Genie Reimplemention in Jax},
  author={Timon Willi and Matthew Thomas Jackson and Jakob Nicolaus Foerster},
  booktitle={First Workshop on Controllable Video Generation @ ICML 2024},
  year={2024},
  url={https://openreview.net/forum?id=ZZGaQHs9Jb}
}

@inproceedings{bruce2024genie,
  title={Genie: Generative Interactive Environments},
  author={Jake Bruce and Michael D Dennis and Ashley Edwards and Jack Parker-Holder and Yuge Shi and Edward Hughes and Matthew Lai and Aditi Mavalankar and Richie Steigerwald and Chris Apps and Yusuf Aytar and Sarah Maria Elisabeth Bechtle and Feryal Behbahani and Stephanie C.Y. Chan and Nicolas Heess and Lucy Gonzalez and Simon Osindero and Sherjil Ozair and Scott Reed and Jingwei Zhang and Konrad Zolna and Jeff Clune and Nando de Freitas and Satinder Singh and Tim Rockt{\"a}schel},
  booktitle={Forty-first International Conference on Machine Learning},
  year={2024},
  url={https://openreview.net/forum?id=bJbSbJskOS}
}
```