End-to-End Test-Time Training for Long Context

Paper | Setup | Replicating Experiments

Abstract

We formulate long-context language modeling as a problem in continual learning rather than architecture design. Under this formulation, we use only a standard architecture – a Transformer with sliding-window attention. However, our model continues learning at test time via next-token prediction on the given context, compressing the context it reads into its weights. In addition, we improve the model's initialization for learning at test time via meta-learning at training time. Overall, our method, a form of Test-Time Training (TTT), is End-to-End (E2E) both at test time (via next-token prediction) and training time (via meta-learning), in contrast to previous forms.
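To make the core idea concrete, here is a toy sketch of test-time training in JAX. This is purely illustrative and is not the paper's model or this repo's code: all names (`init_params`, `next_token_loss`, `ttt_step`) and the tiny linear predictor are invented for the example. It shows a model taking gradient steps on next-token prediction over a given context, so the context is "compressed" into its weights before it is used.

```python
# Toy sketch (not the paper's model): next-token prediction updates
# applied at test time on the given context.
import jax
import jax.numpy as jnp

VOCAB, DIM = 16, 8

def init_params(key):
    k1, k2 = jax.random.split(key)
    return {
        "embed": jax.random.normal(k1, (VOCAB, DIM)) * 0.1,
        "out": jax.random.normal(k2, (DIM, VOCAB)) * 0.1,
    }

def next_token_loss(params, tokens):
    # Predict token t+1 from the embedding of token t.
    h = params["embed"][tokens[:-1]]             # (T-1, DIM)
    logits = h @ params["out"]                   # (T-1, VOCAB)
    logp = jax.nn.log_softmax(logits, axis=-1)
    return -jnp.mean(logp[jnp.arange(tokens.shape[0] - 1), tokens[1:]])

@jax.jit
def ttt_step(params, tokens, lr=0.1):
    # One test-time SGD step on the context's next-token loss,
    # folding information from the context into the weights.
    grads = jax.grad(next_token_loss)(params, tokens)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

params = init_params(jax.random.PRNGKey(0))
context = jnp.array([1, 2, 3, 1, 2, 3, 1, 2])
before = next_token_loss(params, context)
for _ in range(20):
    params = ttt_step(params, context)
after = next_token_loss(params, context)
print(float(after) < float(before))
```

After the test-time updates, the model's loss on the context it has read decreases; the paper's contribution additionally meta-learns the initialization so these test-time updates are effective.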

Setup

This codebase is implemented in JAX and has been tested on GPUs.

Environment setup

We recommend the following GPU system library versions:

  • CUDA Toolkit 12.8.1
  • cuDNN 9.8.0
  • NCCL 2.26.2 (built for CUDA 12.8)

We use uv for Python package management. Install uv with:

curl -LsSf https://astral.sh/uv/install.sh | sh

Dataset Download

Our Llama-3 tokenized datasets are available for download from Google Cloud Storage buckets:

gcloud storage cp -r gs://llama3-dclm-filter-8k/ llama3-dclm-filter-8k
gcloud storage cp -r gs://llama3-books3/ llama3-books3

Note (Requester Pays): These buckets may have Requester Pays enabled. If you encounter a billing or permissions error, follow Google Cloud's Requester Pays documentation.
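If the buckets are Requester Pays, one common fix is to bill the transfer to your own project via the standard `gcloud storage` `--billing-project` flag (the project ID below is a placeholder):

```shell
# Bill the download to your own Google Cloud project (replace my-project-id).
gcloud storage cp -r --billing-project=my-project-id \
  gs://llama3-dclm-filter-8k/ llama3-dclm-filter-8k
```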

Once downloaded, fill in the deploy_paths entries in configs/deploy/interactive.yaml (or configs/deploy/submitit.yaml) so the dataloader can locate the datasets.
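As an illustrative sketch only (the actual key names and nesting may differ; check the existing structure of the file), the entries might look like:

```yaml
# configs/deploy/interactive.yaml (illustrative sketch)
deploy_paths:
  llama3-dclm-filter-8k: /data/llama3-dclm-filter-8k
  llama3-books3: /data/llama3-books3
```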

Replicating Experiments

We use Hydra for configuration management. Configs for each experiment in the paper live under configs/experiment/.

Required Weights & Biases settings

Logging is done with Weights & Biases, and the following fields are required for the launch commands below:

  • training.wandb_entity
  • training.wandb_project
  • training.wandb_key

Run on an interactive node

You can launch an experiment on an interactive node with:

uv run --exact train \
  +deploy=interactive \
  +experiment=125m/pretrain/pretrain-125m-e2e \
  training.wandb_entity=my-entity \
  training.wandb_project=my-project \
  training.wandb_key=my-key

Run multi-node on Slurm (Submitit)

For multi-node jobs on a Slurm cluster, we use Hydra’s Submitit launcher. For example:

uv run --exact train \
  +deploy=submitit \
  hydra.launcher.nodes=4 \
  +experiment=125m/pretrain/pretrain-125m-e2e \
  training.wandb_entity=my-entity \
  training.wandb_project=my-project \
  training.wandb_key=my-key

To configure additional Slurm settings (partition, account, GPUs per node, time limits, etc.), see configs/deploy/submitit.yaml and the Hydra Submitit Launcher docs.
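These settings can also be overridden on the command line. The parameter names below are standard Hydra Submitit launcher options rather than anything specific to this repo, and the values are illustrative for a hypothetical cluster:

```shell
# Illustrative Slurm overrides; adjust to your cluster.
uv run --exact train \
  +deploy=submitit \
  hydra.launcher.nodes=4 \
  hydra.launcher.partition=gpu \
  hydra.launcher.gpus_per_node=8 \
  hydra.launcher.timeout_min=720 \
  +experiment=125m/pretrain/pretrain-125m-e2e \
  training.wandb_entity=my-entity \
  training.wandb_project=my-project \
  training.wandb_key=my-key
```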

Loading a model for extension

To initialize an extension run from a previous experiment, set:

  • training.resume_exp_name=<experiment_name> to point to the experiment you want to resume from, and
  • training.load_part=params to load model parameters from the most recent checkpoint.

On startup, the trainer will automatically locate the latest checkpoint in the experiment directory and restore it before beginning training.
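Putting these overrides together with the interactive launch command, an extension run might look like the following sketch (the experiment config is chosen for illustration):

```shell
uv run --exact train \
  +deploy=interactive \
  +experiment=125m/pretrain/pretrain-125m-e2e \
  training.resume_exp_name=<experiment_name> \
  training.load_part=params \
  training.wandb_entity=my-entity \
  training.wandb_project=my-project \
  training.wandb_key=my-key
```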

About

Official JAX implementation of End-to-End Test-Time Training for Long Context
