This codebase provides an implementation of Ring Attention with Blockwise Transformers. The model is described in the papers Ring Attention with Blockwise Transformers for Near-Infinite Context and Blockwise Parallel Transformer for Large Context Models.
Ring Attention with Blockwise Parallel Transformers enables training on sequences up to 'number of devices' times longer than those possible with the Blockwise Parallel Transformer (BPT) alone. It achieves this by distributing the attention and feedforward computation across multiple devices and overlapping communication with computation. Because attention and the feedforward network are computed blockwise, it is possible to train with context sizes of tens of millions of tokens without adding any communication or computation overhead.
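To make the idea concrete, here is a minimal single-host sketch in JAX of the ring schedule combined with blockwise (online softmax) attention. It is an illustration only, not the code in this package: the function names full_attention and ring_attention_reference are hypothetical, the ring is simulated with a Python loop instead of real device-to-device communication, and it handles a single head without causal masking.

```python
import jax
import jax.numpy as jnp


def full_attention(q, k, v):
    """Reference: standard softmax attention over the whole sequence."""
    scores = q @ k.T / jnp.sqrt(q.shape[-1])
    return jax.nn.softmax(scores, axis=-1) @ v


def ring_attention_reference(q, k, v, num_devices):
    """Simulate the ring schedule on one host (hypothetical helper).

    Each simulated device owns one query block and starts with its own
    key/value block. At every step it attends to the block it currently
    holds, then the key/value blocks move one position around the ring.
    """
    q_blocks = jnp.split(q, num_devices)
    k_blocks = jnp.split(k, num_devices)
    v_blocks = jnp.split(v, num_devices)
    scale = 1.0 / jnp.sqrt(q.shape[-1])

    # Per-device running statistics for a numerically stable softmax.
    numer = [jnp.zeros_like(qb) for qb in q_blocks]
    denom = [jnp.zeros((qb.shape[0], 1)) for qb in q_blocks]
    row_max = [jnp.full((qb.shape[0], 1), -jnp.inf) for qb in q_blocks]

    for _ in range(num_devices):
        for d in range(num_devices):
            scores = q_blocks[d] @ k_blocks[d].T * scale
            new_max = jnp.maximum(row_max[d], scores.max(axis=-1, keepdims=True))
            # Rescale what was accumulated so far to the new running maximum.
            correction = jnp.exp(row_max[d] - new_max)
            probs = jnp.exp(scores - new_max)
            numer[d] = numer[d] * correction + probs @ v_blocks[d]
            denom[d] = denom[d] * correction + probs.sum(axis=-1, keepdims=True)
            row_max[d] = new_max
        # Rotate: device d receives the block previously held by device d - 1.
        k_blocks = k_blocks[-1:] + k_blocks[:-1]
        v_blocks = v_blocks[-1:] + v_blocks[:-1]

    return jnp.concatenate([n / s for n, s in zip(numer, denom)], axis=0)
```

In this sketch each simulated device holds only one key/value block at a time, so per-device memory stays at the blockwise level while the total sequence length grows with the number of devices. Comparing the result against full_attention on random inputs is a quick way to check that the blockwise accumulation is exact up to floating-point error.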
First, install the package:
pip install ringattention
Then, ringattention and blockwise_feedforward can be imported and used as follows:
from ringattention import ringattention, blockwise_feedforward
You can use the ringattention function by wrapping it with shard_map to shard the computation across multiple devices. Here is an example of how to use the ringattention function with sharding:
from functools import partial
import jax
from jax.experimental.shard_map import shard_map
from jax.sharding import PartitionSpec as PS

# LLaMAConfig and self.config come from the surrounding model code, not from ringattention.
ring_attention_sharded = shard_map(
    partial(
        ringattention,
        axis_name="sp",  # mesh axis that forms the ring
        float32_logits=True,
        cache_idx=None,
        blockwise_kwargs=dict(
            causal_block_size=1,
            deterministic=True,
            dropout_rng=None,
            attn_pdrop=0.0,
            query_chunk_size=512,
            key_chunk_size=512,
            policy=jax.checkpoint_policies.nothing_saveable,
            dtype=jax.numpy.float32,
            precision=None,
            prevent_cse=True,
        ),
    ),
    mesh=LLaMAConfig.get_jax_mesh(self.config.mesh_dim),
    in_specs=(
        PS(("dp", "fsdp"), "sp", "tp", None),  # xq
        PS(("dp", "fsdp"), "sp", "tp", None),  # xk
        PS(("dp", "fsdp"), "sp", "tp", None),  # xv
        PS(("dp", "fsdp"), None, None, None),  # attention_bias
        PS(("dp", "fsdp"), None),  # segment_ids
    ),
    out_specs=PS(("dp", "fsdp"), "sp", "tp", None),
    check_rep=False,
)
attn_output = ring_attention_sharded(xq, xk, xv, attention_bias, segment_ids)

Explanation of the arguments:
- query_chunk_size and key_chunk_size are the chunk sizes for the query and key. Choose the largest values that fit in memory to speed up the computation.
- policy is the checkpoint policy for the attention weights. Use jax.checkpoint_policies.nothing_saveable to enable checkpointing.
- causal_block_size is the block size for block-causal attention. causal_block_size=1 is equivalent to causal attention.
- cache_idx is the cache index for inference. If cache_idx is not None, the attention weights will be cached and reused in the next inference step.
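The relationship between causal_block_size and ordinary causal attention can be illustrated with a small mask-building sketch. The helper block_causal_mask below is hypothetical and assumes the usual block-causal convention (a query attends to all keys in its own block and in earlier blocks); the package may build its mask differently internally.

```python
import jax.numpy as jnp


def block_causal_mask(seq_len, causal_block_size):
    """Boolean mask where entry [i, j] is True if query i may attend to key j.

    Assumed convention: positions are grouped into consecutive blocks of
    `causal_block_size`; a query sees every key in its own block and in all
    earlier blocks.
    """
    block_ids = jnp.arange(seq_len) // causal_block_size
    return block_ids[:, None] >= block_ids[None, :]


# Standard causal mask: query i sees keys 0..i.
positions = jnp.arange(6)
causal = positions[:, None] >= positions[None, :]

mask_b1 = block_causal_mask(6, causal_block_size=1)  # same as `causal`
mask_b2 = block_causal_mask(6, causal_block_size=2)  # blocks of two positions
```

With a block size of 1 every position is its own block, so the mask reduces to the standard lower-triangular causal mask. With a block size of 2, positions inside the same block can see each other in both directions, while later blocks remain hidden.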
RingAttention is used in the Large World Model (LWM) for vision-language training on million-length sequences. A complete example of using RingAttention and Blockwise Transformers can be found in the LWM codebase.
If you find our work relevant to your research, please cite:
@article{liu2023blockwise,
title={Blockwise Parallel Transformer for Large Context Models},
author={Liu, Hao and Abbeel, Pieter},
journal={Advances in neural information processing systems},
year={2023}
}
@article{liu2023ring,
title={Ring Attention with Blockwise Transformers for Near-Infinite Context},
author={Liu, Hao and Zaharia, Matei and Abbeel, Pieter},
journal={arXiv preprint arXiv:2310.01889},
year={2023}
}