Perceiver AR

Perceiver AR is an autoregressive, modality-agnostic architecture that uses cross-attention to map long-range inputs to a small number of latents while maintaining end-to-end causal masking. Perceiver AR can directly attend to over a hundred thousand tokens, enabling practical long-context density estimation without hand-crafted sparsity patterns or memory mechanisms.
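To make the causally masked cross-attention concrete, here is a minimal sketch of the attention mask only, not the code used in this repository. It assumes each latent is aligned with one of the final input positions and may attend only to inputs at or before that position; the function name and the plain-list representation are illustrative choices.

```python
def causal_cross_attention_mask(num_inputs: int, num_latents: int) -> list[list[bool]]:
    """Return a [num_latents][num_inputs] mask; True means 'may attend'."""
    if not 0 < num_latents <= num_inputs:
        raise ValueError("num_latents must be in the range 1..num_inputs")
    # Assumption: latent i is aligned with input position
    # (num_inputs - num_latents + i), i.e. the latents cover the final positions.
    first_latent_position = num_inputs - num_latents
    mask = []
    for i in range(num_latents):
        latent_position = first_latent_position + i
        # A latent sees every input up to and including its own position.
        mask.append([j <= latent_position for j in range(num_inputs)])
    return mask


mask = causal_cross_attention_mask(num_inputs=8, num_latents=3)
for row in mask:
    print("".join("1" if allowed else "0" for allowed in row))
```

Because the mask has only `num_latents` rows, the cost of the cross-attention grows with the number of latents times the input length rather than with the square of the input length.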

For more details, see our ICML 2022 paper: https://arxiv.org/abs/2202.07765

An implementation is also available for T5X/Flaxformer.

Setup

First, install dependencies following these instructions:

  1. Create a virtual env: python3 -m venv ~/.venv/perceiver-ar
  2. Switch to the virtual env: source ~/.venv/perceiver-ar/bin/activate
  3. Follow instructions for installing JAX on your platform: https://github.com/google/jax#installation
  4. Install other dependencies: pip install -r requirements.txt

Training

As an example of the model, a 32-position version of the Copy Task from our paper can be trained using only a local CPU.

PYTHONPATH=.:$PYTHONPATH python perceiver_ar/experiment.py \
  --config=perceiver_ar/experiment.py:random_mirrored_32

By default, checkpoints and events will be saved to /tmp/perceiver_ar.

Training metrics will be periodically written to Tensorboard event files which can be viewed using:

tensorboard --logdir /tmp/perceiver_ar/

During training, use Ctrl+C to save a checkpoint and Ctrl+\ to save a checkpoint and exit.
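For intuition about what the copy task asks of the model, the sketch below builds one mirrored sequence. It is an illustration under stated assumptions, not the repository's data pipeline: it assumes a sequence is a start token, random byte-valued tokens, the same tokens in reverse order, and an end token. The `BOS` and `EOS` ids and the function name are hypothetical.

```python
import random

BOS, EOS = 256, 257  # hypothetical special-token ids, chosen outside the byte range


def make_mirrored_sequence(length: int = 32, vocab_size: int = 256, seed=None) -> list[int]:
    """Build [BOS] + random tokens + the same tokens reversed + [EOS]."""
    if length < 4 or length % 2 != 0:
        raise ValueError("length must be an even number of at least 4")
    rng = random.Random(seed)
    half = (length - 2) // 2  # tokens in the random half, excluding BOS and EOS
    payload = [rng.randrange(vocab_size) for _ in range(half)]
    return [BOS] + payload + payload[::-1] + [EOS]


sequence = make_mirrored_sequence(length=32, seed=0)
print(len(sequence), sequence)
```

Predicting the mirrored half correctly requires attending back to the matching token in the first half, which is why the task probes long-range attention.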

Evaluation

To evaluate the latest saved checkpoint:

CHECKPOINTS="/tmp/perceiver_ar"
LATEST_CHECKPOINT="${CHECKPOINTS}/models/latest/$(ls -tr "${CHECKPOINTS}/models/latest/" | tail -n 1)"
echo "Evaluating ${LATEST_CHECKPOINT}"
PYTHONPATH=.:$PYTHONPATH python perceiver_ar/experiment.py \
  --config=perceiver_ar/experiment.py:random_mirrored_32 \
  --jaxline_mode=eval \
  --config.one_off_evaluate=True \
  --config.restore_path="${LATEST_CHECKPOINT}"

Results will be written to the console and can also be viewed from Tensorboard.
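If you would rather locate the newest checkpoint from Python than from the shell, the sketch below mirrors the `ls -tr ... | tail -n 1` step using modification times. The `models/latest` layout is taken from the commands above; the helper name is hypothetical and is not part of this repository.

```python
from pathlib import Path
from typing import Optional


def latest_checkpoint(checkpoint_root: str) -> Optional[Path]:
    """Return the most recently modified entry under <root>/models/latest, if any."""
    latest_dir = Path(checkpoint_root) / "models" / "latest"
    if not latest_dir.is_dir():
        return None
    entries = list(latest_dir.iterdir())
    if not entries:
        return None
    # Same ordering idea as `ls -tr | tail -n 1`: pick the newest by mtime.
    return max(entries, key=lambda path: path.stat().st_mtime)


print(latest_checkpoint("/tmp/perceiver_ar"))
```

The returned path can then be passed as the value of `--config.restore_path`.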

Inference

To run inference in a local Jupyter notebook:

jupyter notebook

Load inference.ipynb and follow the instructions in the notebook.

Pretrained Copy Task

The notebook also supports loading a pretrained checkpoint for the 131k-position copy task used in our paper. This model is fairly large, so inferring more than a few positions will likely require a large accelerator. The notebook has been tested to run on a GCP TPU VM using a TPU v3-8.

Unit Tests

To run all unit tests:

pytest

Disclaimer

This is not an officially supported Google product.
