Picodo: fast Transformer decoder training in JAX/NNX

  • Picodo has only ~360 SLOC
  • can run on GPUs, TPUs, Google Colab, or even locally on a Mac
  • achieves 39% MFU on TPU v6e-1 when training GPT-2 (124M)
  • supports FSDP (Fully Sharded Data Parallel) training (see the sketch after this list)
  • uses TPU flash attention
  • uses the new Flax NNX API
  • uses Hydra for experiment management
  • uses Weights & Biases for experiment tracking
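
FSDP shards the model parameters across devices so that each device stores and updates only a slice of the weights. Here is a minimal sketch of the idea in plain JAX, illustrative only and not Picodo's actual sharding code:

# Illustrative FSDP-style sharding in plain JAX (not Picodo's code):
# a weight matrix is split along its first axis across all devices.
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices(), axis_names=("data",))             # 1-D device mesh
w = jnp.zeros((1024, 1024))                                  # hypothetical weight matrix
w = jax.device_put(w, NamedSharding(mesh, P("data", None)))  # shard rows across devices
print(w.sharding)                                            # inspect the resulting layout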

Training

Open In Colab

Picodo requires a pretokenized dataset for training, in the same format as nanoGPT; this speeds up training and simplifies the codebase. FineWeb / FineWeb-Edu can be downloaded in this format using download_fineweb.py.
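
A nanoGPT-format dataset is a flat binary file of uint16 token IDs. For reference, here is a minimal sketch of sampling a training batch from such a file; the train.bin filename and the get_batch helper are illustrative, not Picodo's actual loader:

# Minimal sketch: sample (input, target) batches from a nanoGPT-style
# token file, i.e. a flat array of uint16 token IDs.
import numpy as np

tokens = np.memmap("train.bin", dtype=np.uint16, mode="r")  # filename is an assumption

def get_batch(rng, batch_size=8, seq_len=1024):
    starts = rng.integers(0, len(tokens) - seq_len - 1, size=batch_size)
    x = np.stack([tokens[s : s + seq_len] for s in starts]).astype(np.int32)
    y = np.stack([tokens[s + 1 : s + 1 + seq_len] for s in starts]).astype(np.int32)
    return x, y  # targets are the inputs shifted by one token

x, y = get_batch(np.random.default_rng(0))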

The simplest way to use this codebase is through the provided Colab notebook, which automatically installs the requirements, downloads the dataset, and starts training a model.

To train a model from the command line, set the config name and any overrides:

python main.py +model=gpt2s +dataset=fw_gpt2 opt.batch_size=8

You can also run train.py directly, which uses the base.yaml config by default.
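
For context, a Hydra entry point along the following lines composes the base config with any command-line overrides. This is a sketch, not Picodo's actual code; in particular, the "configs" directory name is an assumption:

# Sketch of a Hydra entry point that loads base.yaml and applies CLI overrides.
import hydra
from omegaconf import DictConfig, OmegaConf

@hydra.main(version_base=None, config_path="configs", config_name="base")  # path is an assumption
def main(cfg: DictConfig) -> None:
    print(OmegaConf.to_yaml(cfg))  # resolved config, e.g. opt.batch_size=8

if __name__ == "__main__":
    main()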

Inspiration

This repository was originally a fork of deepmind/NanoDO, but it no longer shares any lines of code with the original. Some notable changes:

  • NanoDO has ~1800 SLOC, while Picodo has only ~360 SLOC
  • Picodo uses TPU flash attention
  • Picodo doesn't rely on grain for data loading, so it can run locally on a Mac
  • Picodo uses the new Flax NNX API
  • Picodo uses Hydra and Weights & Biases instead of Google's ConfigDict / TensorBoard
