
Fix misc blog things for flex-attention #1704

Merged (9 commits) on Aug 8, 2024
`_posts/2024-08-07-flexattention.md`: 24 additions & 6 deletions
---
layout: blog_detail
title: "FlexAttention: The Flexibility of PyTorch with the Performance of FlashAttention"
author: "Team PyTorch: Horace He, Driss Guessous, Yanbo Liang, Joy Dong"
---

![a cartoon chart flexing his muscles](/assets/images/flexattention/fg1.jpg){:style="width:100%"}
### ALiBi Bias

![alibi bias](/assets/images/flexattention/fg6.png){:style="max-width:600px; display:block; margin-left: auto; margin-right: auto; width:100%"}
<p style="text-align: center;"><em>Source: <a href="https://arxiv.org/abs/2108.12409">Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation</a></em></p>

ALiBi was introduced in [Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation](https://arxiv.org/abs/2108.12409), and claims to have beneficial properties for length extrapolation at inference. Notably, MosaicML has pointed to [“lack of kernel support”](https://twitter.com/jefrankle/status/1804567458092605736) as the main reason why they eventually switched from ALiBi to rotary embeddings.
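
As a rough sketch of what this looks like as a `score_mod` (our illustration rather than the post's exact listing; the head count is assumed and the slope formula follows the ALiBi paper), the per-head slopes become a captured tensor that we index inside the function:

```py
import torch

H = 16  # assumed number of heads
# Per-head slopes from the ALiBi paper: a geometric sequence in (0, 1).
# In practice, keep this tensor on the same device as Q/K/V.
alibi_slopes = torch.tensor([2 ** (-8 * (i + 1) / H) for i in range(H)])

def alibi(score, b, h, q_idx, kv_idx):
    # Older keys (smaller kv_idx) receive a larger negative bias, scaled per head.
    bias = alibi_slopes[h] * (kv_idx - q_idx)
    return score + bias
```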


### Soft-capping

Soft-capping is a technique used in [Gemma2](https://huggingface.co/blog/gemma2#soft-capping-and-attention-implementations) and Grok-1 that prevents logits from growing excessively large. In FlexAttention, it looks like:

```py
softcap = 20
def soft_cap(score, b, h, q_idx, kv_idx):
    score = score / softcap
    score = torch.tanh(score)
    score = score * softcap
    return score
```

Note that we also automatically generate the backwards pass from the forwards pass here. Also, although this implementation is semantically correct, we likely want to use a tanh approximation in this case for performance reasons. See [attention-gym](https://github.com/pytorch-labs/attention-gym/blob/main/attn_gym/mods/softcapping.py) for more details.

### Causal Mask
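
As a minimal sketch (ours, not necessarily the post's exact listing), causality can be expressed purely as a `score_mod` that pushes disallowed positions to negative infinity so that they vanish after the softmax:

```py
import torch

def causal(score, b, h, q_idx, kv_idx):
    # A query may attend to itself and to earlier key positions only;
    # -inf turns into a zero attention weight after the softmax.
    return torch.where(q_idx >= kv_idx, score, -float("inf"))
```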


## Mask Mods

To take advantage of sparsity from masking, we need to do some more work. Specifically, by passing a `mask_mod` to [`create_block_mask`](https://github.com/pytorch/pytorch/blob/e49c0acc396e89baf8c6450e1fa0571d4ce2d4ed/torch/nn/attention/flex_attention.py#L594), we can create a `BlockMask`. FlexAttention can then use `BlockMask` to take advantage of the sparsity\!

The signature of `mask_mod` is very similar to `score_mod` \- just without the `score`. In particular, it takes `(b, h, q_idx, kv_idx)` and returns a boolean indicating whether that query position is allowed to attend to that key position.
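
As a small end-to-end sketch (the shapes and variable names here are our own illustration, not the post's code), the causal condition becomes a `mask_mod`, `create_block_mask` turns it into a `BlockMask`, and the result is passed to `flex_attention`:

```py
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def causal_mask(b, h, q_idx, kv_idx):
    # True means this (query, key) position participates in attention.
    return q_idx >= kv_idx

B, H, S, D = 4, 16, 2048, 64  # assumed shapes
q, k, v = (torch.randn(B, H, S, D, device="cuda", dtype=torch.float16) for _ in range(3))

# B=None and H=None broadcast the same mask over batch and heads.
block_mask = create_block_mask(causal_mask, B=None, H=None, Q_LEN=S, KV_LEN=S)
out = flex_attention(q, k, v, block_mask=block_mask)
```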

### Sliding Window \+ Causal

![Sliding Window Causal diagrams](/assets/images/flexattention/fg8.png){:style="width:100%"}
<p style="text-align: center;"><em>Source: <a href="https://arxiv.org/abs/2310.06825">Mistral 7B</a></em></p>


Popularized by [Mistral](https://arxiv.org/abs/2310.06825), sliding window attention (also known as local attention) takes advantage of the intuition that the most recent tokens are the most useful. In particular, it allows the query token to only attend to, say, the 1024 most recent tokens. This is often used together with causal attention.
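
A minimal sketch of this as a single `mask_mod`, with the causal and window conditions written out by hand (the window size here is an assumption):

```py
SLIDING_WINDOW = 1024  # assumed window size

def sliding_window_causal(b, h, q_idx, kv_idx):
    causal = q_idx >= kv_idx
    in_window = q_idx - kv_idx <= SLIDING_WINDOW
    return causal & in_window
```
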
### PrefixLM

![PrefixLM diagram](/assets/images/flexattention/fg10.png){:style="max-width:600px; display:block; margin-left: auto; margin-right: auto; width:100%"}
<p style="text-align: center;"><em>Source: <a href="https://arxiv.org/abs/2407.07726">PaliGemma: A versatile 3B VLM for transfer</a></em></p>

The T5 architecture, proposed in [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683), describes an attention variant that performs full bidirectional attention on a “prefix”, and causal attention on the rest. We again compose two mask functions to accomplish this, one for causal masking and one that is based off of the prefix length.
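
A sketch of that composition, written here as one `mask_mod` (`prefix_length` is an assumed per-sequence tensor holding the length of the bidirectional prefix):

```py
# prefix_length: assumed tensor of shape [B] with the bidirectional prefix length per sequence

def prefix_lm_causal(b, h, q_idx, kv_idx):
    causal = q_idx >= kv_idx               # causal attention everywhere...
    in_prefix = kv_idx < prefix_length[b]  # ...plus full attention into the prefix
    return causal | in_prefix
```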

### Document Masking

Through `BlockMask`, we can support document masking \- where several variable-length documents are packed into one sequence and tokens should only attend within their own document \- efficiently in FlexAttention as well\!
```py
document_id: [SEQ_LEN]

def document_masking(b, h, q_idx, kv_idx):
    return document_id[q_idx] == document_id[kv_idx]
```

And that’s it\! In this case, we see that we end up with a block-diagonal mask.
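
For completeness, here is one way (our sketch, with made-up document lengths) to build `document_id` from packed lengths and turn `document_masking` into a `BlockMask` via `create_block_mask`:

```py
import torch
from torch.nn.attention.flex_attention import create_block_mask

lengths = torch.tensor([512, 256, 256])  # assumed lengths of the packed documents
SEQ_LEN = int(lengths.sum())             # 1024
# e.g. [0, 0, ..., 0, 1, ..., 1, 2, ..., 2]; keep it on the device the mask is built on
document_id = torch.repeat_interleave(torch.arange(len(lengths)), lengths).to("cuda")

def document_masking(b, h, q_idx, kv_idx):
    return document_id[q_idx] == document_id[kv_idx]

block_mask = create_block_mask(document_masking, B=None, H=None, Q_LEN=SEQ_LEN, KV_LEN=SEQ_LEN)
```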

### Performance

Generally speaking, FlexAttention is nearly as performant as a handwritten Triton kernel, which is unsurprising, as we heavily leverage a handwritten Triton kernel. However, due to its generality, we do incur a small performance penalty. For example, we must incur some additional latency to determine which block to compute next. In some cases, we provide some kernel options that can affect the performance of the kernel while changing its behavior. They can be found here: [performance knobs](https://github.com/pytorch/pytorch/blob/ee09d066d35d7e17cf7e9479c0b8bfc70cffc264/torch/_inductor/kernel/flex_attention.py#L146-L155)

As a case study, let's explore how the knobs affect the performance of causal attention. We will compare the performance of the Triton kernel versus FlashAttention v2 on an A100. The script can be found [here](https://github.com/pytorch/pytorch/blob/main/benchmarks/transformer/score_mod.py).


### Limitations and Future Work

- FlexAttention is currently available in PyTorch nightly releases; we plan to release it as a prototype feature in 2.5.0.
- We did not cover how to use FlexAttention for inference here (or how to implement PagedAttention) \- we will cover those in a later post.
- We are working to improve the performance of FlexAttention to match FlashAttention3 on H100 GPUs.
- FlexAttention requires that all sequence lengths be a multiple of 128 \- this will be addressed soon.
- We plan on adding GQA support soon \- for now, you can just replicate the kv heads (see the sketch below).
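
A quick sketch of that workaround (shapes here are assumptions, not from the post): repeat each kv head until K and V have as many heads as Q.

```py
import torch

B, Hq, Hkv, S, D = 2, 8, 2, 2048, 64  # assumed: 8 query heads, 2 kv heads
q = torch.randn(B, Hq, S, D)
k = torch.randn(B, Hkv, S, D)
v = torch.randn(B, Hkv, S, D)

# Replicate each kv head Hq // Hkv times along the head dimension.
k = k.repeat_interleave(Hq // Hkv, dim=1)
v = v.repeat_interleave(Hq // Hkv, dim=1)
```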


### Acknowledgements

We want to highlight some prior work (and people) that have inspired FlexAttention.

- Tri Dao's work on FlashAttention
- Francisco Massa and the Xformers team for BlockSparseAttention in Triton
- The Jax team's work on SplashAttention
- Philippe Tillet and Keren Zhou for helping us with Triton
- Ali Hassani for discussions on neighborhood attention
- Everybody who's complained about attention kernels not supporting their favorite attention variant :)
Binary file modified: `assets/images/flexattention/fg9.png`