🧞‍♀️ Jasmine: A simple, performant and scalable JAX-based world modeling codebase 🧞‍♀️

Jasmine is a production-ready JAX-based world modeling codebase. It currently implements the high-level architecture of Genie: Generative Interactive Environments (Bruce et al., 2024) with MaskGIT (Chang et al., 2022), as well as an autoregressive (causal) baseline. A diffusion baseline is coming soon.

Jasmine scales from single hosts to hundreds of xPUs thanks to XLA and strives to be an easily hackable, batteries-included foundation for world modeling research.

Overview

  • Asynchronous & distributed checkpointing thanks to orbax.checkpoint
    • Jasmine also supports mixing and matching hardware topologies (e.g. train on four nodes, load the checkpoint on a single node)
  • Optimized dataloading thanks to Grain
    • Dataloading scales with the number of processes (i.e. nodes/xPUs)
  • Checkpointing of model weights, optimizer and dataloader states
  • Full reproducibility with identical training curves (thanks to seeded dataloading and training, and JAX's approach to pseudo-random numbers)
  • Automatic checkpoint deletion/retention according to a specified retention policy thanks to orbax.checkpoint.CheckpointManager (see the sketch after this list)
  • Mixed precision training using bfloat16
    • int8 training is on the roadmap via aqt
  • FlashAttention thanks to cuDNN SDPA
  • Frame-level KV cache resets for accelerated spatiotemporal attention in the causal baseline (currently in an open PR)
  • Activation checkpointing (even onto host memory if desired)
  • DDP (switching to FSDP requires changing a single line of code; see the sketch after this list)
  • WSD learning rate schedule
    • No need to retrain from scratch if you want to train for longer
  • Index-shuffling during dataloading
  • Google-native stack
  • Easy model inspection thanks to treescope
  • Modularized training script for easy inspection using notebooks (demo notebook)
  • Easy model surgery thanks to the new flax.nnx API
  • Shape suffixes throughout the repository
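
As a rough illustration of the retention-policy bullet above, here is a minimal sketch (illustrative values and paths, not Jasmine's actual configuration) of configuring retention via orbax.checkpoint:

import orbax.checkpoint as ocp
options = ocp.CheckpointManagerOptions(max_to_keep=3, save_interval_steps=1000)  # illustrative retention policy
manager = ocp.CheckpointManager("/tmp/checkpoints", options=options)  # checkpoints beyond the newest 3 are deleted automatically

Similarly, for the DDP/FSDP bullet: with a named JAX device mesh, the switch often comes down to the PartitionSpec used for the model parameters (again an illustrative sketch, not Jasmine's actual code):

import jax
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))  # one mesh axis spanning all devices
ddp_param_spec = P()          # DDP: replicate parameters on every device
fsdp_param_spec = P("data")   # FSDP-style: shard parameters along the data axis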

Setup 🧗

Jasmine requires Python 3.10, JAX 0.6.2, and Flax 0.10.7. To install the requirements and set up the pre-commit hooks, run:

pip install -r requirements.txt
pre-commit install
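
As an optional sanity check after installation, you can confirm the pinned versions and that JAX sees your accelerators:

import jax
import flax
print(jax.__version__)   # expected: 0.6.2
print(flax.__version__)  # expected: 0.10.7
print(jax.devices())     # should list your GPUs/TPUs (or CPUs if no accelerator is found)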

Dataset 📂

You can either download our preprocessed dataset from Hugging Face or preprocess OpenAI's VPT dataset manually.

Option 1: Use Preprocessed Dataset (Recommended)

The easiest way to get started is to download our preprocessed dataset from Hugging Face. This script will handle downloading and extracting it:

bash input_pipeline/download/download_array_records.sh

Option 2: Manual Download & Preprocessing of OpenAI's VPT Dataset

If you prefer to use the raw VPT dataset from OpenAI and preprocess it yourself, follow these steps:

  1. Download index files: This will download the index files used in the next step:

    bash input_pipeline/download/openai/download_index_files.sh
  2. Download videos from all index files: This may take a long time depending on your bandwidth:

    python input_pipeline/download/openai/download_videos.py --index_file_path data/open_ai_index_files/all_7xx_Apr_6.json
    python input_pipeline/download/openai/download_videos.py --index_file_path data/open_ai_index_files/all_8xx_Jun_29.json
    python input_pipeline/download/openai/download_videos.py --index_file_path data/open_ai_index_files/all_9xx_Jun_29.json
    python input_pipeline/download/openai/download_videos.py --index_file_path data/open_ai_index_files/all_10xx_Jun_29.json
  3. Preprocess videos into ArrayRecords: For efficient distributed training, convert the raw videos into the ArrayRecord format (make sure ffmpeg is installed on your machine):

    python input_pipeline/preprocess/video_to_array_records.py

Note: This is a large dataset and may take considerable time and storage to download and process.
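
Whichever option you choose, a quick way to sanity-check the resulting files is to open one with Grain (the path below is a placeholder; point it at wherever the scripts wrote the .array_record files):

import grain.python as grain
source = grain.ArrayRecordDataSource(["data/array_records/shard_00000.array_record"])  # placeholder path
print(len(source))                # number of records in the file
record = source[0]                # raw serialized bytes of the first record
print(type(record), len(record))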

Quick Start 🚀

Genie has three components: a video tokenizer, a latent action model, and a dynamics model. Each component is trained separately; however, the dynamics model requires a pre-trained video tokenizer (and latent action model).

To train the video tokenizer, run:

python train_tokenizer.py --ckpt_dir <path>

To train the latent action model, run:

python train_lam.py --ckpt_dir <path>

Once the tokenizer and LAM are trained, the dynamics model can be trained with:

python train_dynamics.py --tokenizer_checkpoint <path> --lam_checkpoint <path>

Logging with wandb is supported. To enable logging, set the WANDB_API_KEY environment variable (e.g. export WANDB_API_KEY=<your-api-key>) or run:

wandb login

Training can then be logged by passing the --log flag:

python train_tokenizer.py --log --entity <wandb-entity> --project <wandb-project>

Citing 📜

Jasmine was built by Mihir Mahajan, Alfred Nguyen, and Franz Srambical; it started as a fork of Jafar, built by Matthew Jackson and Timon Willi.

If you use Jasmine in your work, please cite us, Jafar, and the original Genie paper as follows:

@article{
    mahajan2025jasmine,
    title={Jasmine: A simple, performant and scalable JAX-based world modeling codebase},
    author={Mihir Mahajan and Alfred Nguyen and Franz Srambical and Stefan Bauer},
    journal = {p(doom) blog},
    year={2025},
    url={https://pdoom.org/jasmine.html},
    note = {https://pdoom.org/blog.html}
}
@inproceedings{
    willi2024jafar,
    title={Jafar: An Open-Source Genie Reimplemention in Jax},
    author={Timon Willi and Matthew Thomas Jackson and Jakob Nicolaus Foerster},
    booktitle={First Workshop on Controllable Video Generation @ ICML 2024},
    year={2024},
    url={https://openreview.net/forum?id=ZZGaQHs9Jb}
}
@inproceedings{
    bruce2024genie,
    title={Genie: Generative Interactive Environments},
    author={Jake Bruce and Michael D Dennis and Ashley Edwards and Jack Parker-Holder and Yuge Shi and Edward Hughes and Matthew Lai and Aditi Mavalankar and Richie Steigerwald and Chris Apps and Yusuf Aytar and Sarah Maria Elisabeth Bechtle and Feryal Behbahani and Stephanie C.Y. Chan and Nicolas Heess and Lucy Gonzalez and Simon Osindero and Sherjil Ozair and Scott Reed and Jingwei Zhang and Konrad Zolna and Jeff Clune and Nando de Freitas and Satinder Singh and Tim Rockt{\"a}schel},
    booktitle={Forty-first International Conference on Machine Learning},
    year={2024},
    url={https://openreview.net/forum?id=bJbSbJskOS}
}
