This is a work-in-progress rewrite of Andrej Karpathy's nanoGPT in JAX/Flax.
One of the goals of this project is to try out jax.experimental.pjit; I'm curious how performance differs across model sizes and distribution configurations.
Currently the GPT-2 124M parameter model reaches 2.906 validation loss after training on OpenWebText with a TPU v3-32 pod slice for 150K steps (about 20 hours).
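For reference, here is a minimal sketch of the pjit pattern mentioned above. It is not this repo's actual training setup; the mesh axes, sharding layout, and shapes are illustrative assumptions.

```python
import numpy as np
import jax
from jax.experimental.pjit import pjit
from jax.sharding import Mesh, PartitionSpec as P

# Hypothetical 2D device mesh: 'data' for batch parallelism, 'model' for
# parameter sharding. Here all available devices land on the 'data' axis.
devices = np.array(jax.devices()).reshape(-1, 1)
mesh = Mesh(devices, axis_names=('data', 'model'))

def matmul(x, w):
    return x @ w

# Shard the batch dim of x over 'data' and the output dim of w over 'model';
# sharded dims must be divisible by the corresponding mesh axis sizes.
sharded_matmul = pjit(
    matmul,
    in_shardings=(P('data', None), P(None, 'model')),
    out_shardings=P('data', 'model'),
)

with mesh:
    x = np.ones((8, 16), np.float32)   # activations: (batch, features)
    w = np.ones((16, 32), np.float32)  # weights: (features, hidden)
    print(sharded_matmul(x, w).shape)  # (8, 32)
```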
Clone gpt-jax

```shell
git clone https://github.com/jenkspt/gpt-jax.git
cd gpt-jax
```

Install python dependencies
```shell
pip install -U pip
pip install tqdm
pip install numpy
pip install tiktoken
pip install datasets
pip install tensorflow
```

Prepare data

```shell
python data/openwebtext/prepare.py
```

This will generate the following files:
- `train_0.tfrecord, train_1.tfrecord ... train_{num_shards}.tfrecord`
- `val_0.tfrecord, val_1.tfrecord ... val_{num_shards}.tfrecord`
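If you want to sanity-check a shard before uploading, a quick sketch (this just reads the raw serialized records without parsing them; the filename is one of the shards above):

```python
import tensorflow as tf

# Each element of a TFRecordDataset is one raw serialized record.
ds = tf.data.TFRecordDataset("train_0.tfrecord")
first = next(iter(ds))
print("first record:", len(first.numpy()), "bytes")
print("total records:", sum(1 for _ in ds))
```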
If you're training on a TPU, you should copy these files to a GCS bucket.
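For example, with gsutil (the bucket name is a placeholder, and this assumes the shards were written next to the prepare script):

```shell
gsutil -m cp data/openwebtext/*.tfrecord gs://{your-bucket}/openwebtext/
```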
- Create TPU v3-32

```shell
ZONE=europe-west4-a
TPU_TYPE=v3-32
VM_NAME=jax-gpt-v3-32
gcloud alpha compute tpus tpu-vm create $VM_NAME \
--zone=$ZONE \
--accelerator-type=$TPU_TYPE \
--version=v2-tf-stable \
--preemptible
```
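You can confirm the slice came up before continuing (uses the same shell variables as above):

```shell
gcloud alpha compute tpus tpu-vm list --zone=$ZONE
```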
- Clone repo and install dependencies

```shell
gcloud alpha compute tpus tpu-vm ssh $VM_NAME --zone=$ZONE \
--worker=all --command="
git clone https://github.com/jenkspt/gpt-jax.git
cd gpt-jax
pip install -U pip
pip install tyro
pip install wandb
pip install -U tensorflow
pip install -U \"jax[tpu]\" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
pip install flax
"- Launch training job
- Launch training job

```shell
EXPERIMENT=gpt2-124m/run_$(date +%Y-%m-%d_%H-%M-%S)
echo $EXPERIMENT
BRANCH=main
gcloud alpha compute tpus tpu-vm ssh $VM_NAME --zone=$ZONE \
--worker=all --command="
export WANDB_API_KEY=$WANDB_API_KEY
export TF_CPP_MIN_LOG_LEVEL=3 # silence annoying TF warnings
export GPT_CONFIG=config/gpt2.yaml # this is the default GPT config for this run
cd gpt-jax
git fetch --all
git checkout $BRANCH
python3 train.py \
--out_dir=gs://{your-bucket}/$EXPERIMENT \
--train_pattern=gs://{your-bucket}/openwebtext/train_??.tfrecord \
--val_pattern=gs://{your-bucket}/openwebtext/val_??.tfrecord \
--wandb.name=$EXPERIMENT \
--wandb.notes=''
"Don't forget to delete the TPU instance when you're done
gcloud alpha compute tpus tpu-vm delete $VM_NAME --zone=$ZONE