This repository provides the official implementation of SageAttention and SageAttention2.
SageAttention: Accurate 8-Bit Attention for Plug-and-play Inference Acceleration
Paper: https://arxiv.org/abs/2410.02367
Jintao Zhang, Jia Wei, Haofeng Huang, Pengle Zhang, Jun Zhu, Jianfei Chen
SageAttention2: Efficient Attention with Thorough Outlier Smoothing and Per-thread INT4 Quantization
Paper: https://arxiv.org/abs/2411.10958
Jintao Zhang, Haofeng Huang, Pengle Zhang, Jia Wei, Jun Zhu, Jianfei Chen
- Optimized kernels for Ampere, Ada and Hopper GPUs.
- INT8 quantization and smoothing for $QK^\top$ with support for varying granularities.
- FP8 quantization for $PV$.
- Two-level accumulation strategy for $PV$ to improve accuracy in FP8 MMA and WGMMA.
- Support for `torch.compile` with non-cudagraphs mode and distributed inference.
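The smoothing idea for $QK^\top$ can be illustrated with a toy pure-Python sketch. As described in the SageAttention paper, K is smoothed by subtracting its mean over the token dimension. Every logit in a row of $QK^\top$ then shifts by the same constant $q \cdot \bar{k}$, which softmax ignores, while the values INT8 has to represent become smaller. This is not the CUDA/Triton kernel: the single symmetric scale per matrix (the real kernels use finer granularities) and all helper names are assumptions made for illustration.

```python
import math

def softmax(row):
    """Numerically stable softmax over a list of floats."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    s = sum(exps)
    return [e / s for e in exps]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def smooth_k(K):
    """Subtract the per-channel mean of K, taken over the token (row) dimension."""
    n, d = len(K), len(K[0])
    mean = [sum(K[j][c] for j in range(n)) / n for c in range(d)]
    return [[K[j][c] - mean[c] for c in range(d)] for j in range(n)], mean

def quantize_int8(mat):
    """Symmetric INT8 quantization with one scale for the whole matrix (toy granularity)."""
    amax = max(abs(x) for row in mat for x in row)
    scale = amax / 127 if amax > 0 else 1.0
    codes = [[max(-127, min(127, round(x / scale))) for x in row] for row in mat]
    return codes, scale

def attn_probs_int8(Q, K):
    """Approximate softmax(Q K^T / sqrt(d)) using INT8 Q and smoothed INT8 K."""
    d = len(Q[0])
    K_s, _ = smooth_k(K)
    q_codes, q_scale = quantize_int8(Q)
    k_codes, k_scale = quantize_int8(K_s)
    probs = []
    for qi in q_codes:
        # Integer dot products, dequantized by the product of the two scales.
        logits = [dot(qi, kj) * q_scale * k_scale / math.sqrt(d) for kj in k_codes]
        probs.append(softmax(logits))
    return probs
```

In the test data below, K has a large shared offset in its first channel. Smoothing removes it, so the INT8 scale is set by the small residuals instead.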
🚀 SageAttention achieves significant speedups on most GPUs without compromising accuracy across all models in a plug-and-play way.
- [2025-02-25]: 🔥 We release SpargeAttn, a sparse attention method based on SageAttention2 that can accelerate any model without training.
- [2025-02-15]: 🔥 The compilation code is updated to support RTX5090! On RTX5090, SageAttention reaches 560T, 2.7x faster than FlashAttention2!
- [2025-01-28]: 🔥⚡SageAttention is now available on Hopper GPUs (H100, H800, H20)! It matches the speed of FlashAttention3-FP8 but offers much better accuracy!
| FlashAttention2 | FlashAttention3 | FlashAttention3-FP8 | SageAttention |
|---|---|---|---|
| 25'34'' | 17'32'' | 12'14'' | 12'07'' |
Results for CogVideoX1.5-5B on an NVIDIA H20 GPU (time in minutes'seconds'')
- [2025-01-24]: 🎉SageAttention is accepted by ICLR 2025!
- [2024-12-20]: 🔥Update the SageAttention2 Paper.
- [2024-12-20]: 🔥Release SageAttention 2.0.1 Beta! In this version, we introduce a new feature: per-thread quantization, which offers finer granularity while maintaining hardware efficiency.
- [2024-11-21]: 🔥SageAttention 2.0.0 beta is released! SageAttention now has measured speedups on L20, L40, A100, A800, A6000, RTX3090, and RTX4090.
- [2024-11-12]: Support for `sageattn_varlen` is available now.
- [2024-11-11]: Support for different sequence lengths between `q` and `k,v`, `(batch_size, head_num, seq_len, head_dim)` or `(batch_size, seq_len, head_num, head_dim)` input shapes, and `group-query attention` is available now.
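The H20 table in the news above can be turned into relative speedups. The sketch below assumes its entries are wall-clock times written as minutes'seconds''; the parsing helper is ours, not part of the repository.

```python
def to_seconds(t):
    """Parse a duration written as 25'34'' (minutes'seconds'') into seconds."""
    minutes, rest = t.split("'", 1)
    seconds = rest.rstrip("'")
    return int(minutes) * 60 + int(seconds)

# Times copied from the CogVideoX1.5-5B / H20 table.
times = {
    "FlashAttention2": "25'34''",
    "FlashAttention3": "17'32''",
    "FlashAttention3-FP8": "12'14''",
    "SageAttention": "12'07''",
}

sage = to_seconds(times["SageAttention"])
speedups = {name: to_seconds(t) / sage for name, t in times.items()}
for name, s in speedups.items():
    print(f"SageAttention vs {name}: {s:.2f}x")
```

Under that reading, SageAttention is about 2.11x faster than FlashAttention2, 1.45x faster than FlashAttention3, and roughly on par (1.01x) with FlashAttention3-FP8.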
- `python>=3.9`, `torch>=2.3.0`, `triton>=3.0.0`
- CUDA:
  - `>=12.8` for Blackwell
  - `>=12.4` for fp8 support on Ada
  - `>=12.3` for fp8 support on Hopper
  - `>=12.0` for Ampere
- `flash-attn` for benchmarking
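A small helper can check an environment against these minimum versions before building. This is a hedged sketch: the dictionary keys (such as `ada_fp8`) are hypothetical labels for the requirement lines, and the parser keeps only the leading numeric components of a version string.

```python
import re

# Minimum versions transcribed from the requirement list.
REQUIREMENTS = {"python": "3.9", "torch": "2.3.0", "triton": "3.0.0"}
MIN_CUDA = {
    "blackwell": "12.8",
    "ada_fp8": "12.4",
    "hopper_fp8": "12.3",
    "ampere": "12.0",
}

def parse_version(text):
    """Turn '2.5.1+cu124' or '12.4' into a tuple of ints, dropping local/pre-release suffixes."""
    core = text.split("+", 1)[0]
    match = re.match(r"\d+(?:\.\d+)*", core)
    if match is None:
        raise ValueError(f"unrecognized version string: {text!r}")
    return tuple(int(part) for part in match.group(0).split("."))

def meets(installed, minimum):
    """True if the installed version is at least the minimum."""
    a, b = parse_version(installed), parse_version(minimum)
    # Pad with zeros so that '12.4' compares equal to '12.4.0'.
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b
```

In practice you would feed it `torch.__version__` and `torch.version.cuda` (the latter is `None` on CPU-only builds), e.g. `meets(torch.version.cuda, MIN_CUDA["ada_fp8"])`.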
For the stable Triton-only version, refer to SageAttention-1 and install using pip:
```bash
pip install sageattention==1.0.6
```
To use SageAttention 2.1.1, please compile from source:
```bash
git clone https://github.com/thu-ml/SageAttention.git
cd SageAttention
python setup.py install  # or pip install -e .
```
To benchmark the speed against FlashAttention3, please compile FlashAttention3 from source:
```bash
git clone https://github.com/Dao-AILab/flash-attention.git --recursive
cd flash-attention
git checkout b7d29fb3b79f0b78b1c369a52aaa6628dabfb0d7 # 2.7.2 release
cd hopper
python setup.py install
```
```python
from sageattention import sageattn
attn_output = sageattn(q, k, v, tensor_layout="HND", is_causal=False)
```

`q, k, v` are FP16/BF16 tensors with shape `(batch_size, head_num, seq_len, head_dim)` when using the default `tensor_layout="HND"`. For shape `(batch_size, seq_len, head_num, head_dim)`, set `tensor_layout="NHD"`. `is_causal` determines whether a causal mask is applied.
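When validating outputs, it helps to have a plain reference for what `sageattn` approximates. The sketch below is a naive single-head attention in pure Python, not part of the `sageattention` package. It assumes a softmax scale of $1/\sqrt{\text{head\_dim}}$ and, for `is_causal=True`, equal query and key lengths with query $i$ attending to keys $0..i$. For real tensors, `torch.nn.functional.scaled_dot_product_attention` on the same inputs is the usual comparison point.

```python
import math

def reference_attention(q, k, v, is_causal=False):
    """Naive softmax(q k^T / sqrt(d)) v for one head; q, k, v are [seq_len][head_dim] lists."""
    d = len(q[0])
    scale = 1.0 / math.sqrt(d)
    out = []
    for i, qi in enumerate(q):
        # With a causal mask, query i only sees keys 0..i (assumes len(q) == len(k)).
        keys = range(i + 1) if is_causal else range(len(k))
        logits = [scale * sum(a * b for a, b in zip(qi, k[j])) for j in keys]
        m = max(logits)  # subtract the row max for numerical stability
        w = [math.exp(x - m) for x in logits]
        s = sum(w)
        out.append([sum(w[t] * v[j][c] for t, j in enumerate(keys)) / s
                    for c in range(len(v[0]))])
    return out
```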
- `sageattn`: Automatically selects the optimal kernel based on the GPU to achieve a good performance-accuracy trade-off.
- `sageattn_qk_int8_pv_fp16_triton`: INT8 quantization for $QK^\top$ and FP16 for $PV$ using the Triton backend.
- `sageattn_qk_int8_pv_fp16_cuda`: INT8 quantization for $QK^\top$ and FP16 for $PV$ using the CUDA backend.
- `sageattn_qk_int8_pv_fp8_cuda`: INT8 quantization for $QK^\top$ and FP8 for $PV$ using the CUDA backend.
- `sageattn_qk_int8_pv_fp8_cuda_sm90`: INT8 quantization for $QK^\top$ and FP8 for $PV$ using the CUDA backend, specifically optimized for Hopper GPUs.
- `sageattn_varlen`: INT8 quantization for $QK^\top$ and FP16 for $PV$ using the Triton backend, with support for varying sequence lengths within the same batch.
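Variable-length kernels in the FlashAttention family usually take all sequences of a batch packed along the token axis, plus cumulative sequence offsets. The sketch below shows that bookkeeping in pure Python. It assumes `sageattn_varlen` follows the same convention; the names `cu_seqlens` and `max_seqlen` mirror that convention, so check the function's docstring for its exact arguments.

```python
from itertools import accumulate

def build_cu_seqlens(seq_lens):
    """Cumulative offsets [0, l0, l0+l1, ...] for sequences packed along one axis."""
    return [0, *accumulate(seq_lens)]

def unpack(packed_rows, cu_seqlens):
    """Split a packed [total_tokens, ...] list back into per-sequence chunks."""
    return [packed_rows[s:e] for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:])]

seq_lens = [3, 5, 2]             # three sequences of different lengths in one batch
cu = build_cu_seqlens(seq_lens)  # offsets [0, 3, 8, 10]
max_seqlen = max(seq_lens)       # longest sequence, 5
```

Packing avoids padding every sequence to `max_seqlen`: sequence `i` occupies rows `cu[i]` to `cu[i + 1]` of the packed tensor.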
For the best speed and accuracy on custom devices and models, we strongly recommend referring to this file for detailed guidance.
Note: Support for different sequence lengths between `q` and `k,v` and `group-query attention` is available.
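Under group-query attention, several query heads share one K/V head, so `k` and `v` carry fewer heads than `q`. In the standard contiguous grouping (the one produced by expanding K/V with `repeat_interleave` along the head axis), query head $h$ reads K/V head $h \,/\!/\, (H_q / H_{kv})$. A minimal sketch, assuming SageAttention uses this usual grouping; the helper name is hypothetical.

```python
def kv_head_for(q_head, num_q_heads, num_kv_heads):
    """Index of the K/V head shared by a given query head under grouped-query attention."""
    if num_q_heads % num_kv_heads != 0:
        raise ValueError("num_q_heads must be a multiple of num_kv_heads")
    group_size = num_q_heads // num_kv_heads  # query heads per K/V head
    return q_head // group_size
```

With 8 query heads and 2 K/V heads, heads 0 to 3 read K/V head 0 and heads 4 to 7 read K/V head 1; with equal head counts the mapping reduces to ordinary multi-head attention.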
`sageattn` can easily replace `scaled_dot_product_attention`. Taking CogVideoX as an example, add the following code and run:
```diff
import torch.nn.functional as F
+ from sageattention import sageattn
+ F.scaled_dot_product_attention = sageattn
```
Specifically, run:

```bash
cd example
python cogvideox-2b.py --compile --attention_type sage
```

You can get a lossless video in `./example` faster than with `python cogvideox-2b.py --compile`. More examples and guidance can be found under the `example/` directory.
Note: Not all models work with `F.scaled_dot_product_attention = sageattn`. Strictly speaking, you should replace the original attention by modifying the `Attention` class of the target model. For image and video models, we suggest replacing only the attention in the DiT (see `example/mochi.py` for details).
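One common reason a blanket replacement breaks is that some call sites pass arguments the `sageattn(q, k, v, tensor_layout=..., is_causal=...)` call does not take, such as `attn_mask` or `dropout_p`. Below is a hedged sketch of a guarded patch that reroutes only plain calls and otherwise defers to the original function. The wrapper and its fallback conditions are our assumptions, not part of `sageattention`, and they do not replace the per-model approach recommended above.

```python
def make_guarded_sdpa(original_sdpa, fast_attn):
    """Wrap an SDPA-style function so only calls fast_attn can handle are rerouted.

    Calls that use a mask, dropout, a custom scale, or any extra keyword
    fall back to original_sdpa unchanged.
    """
    def sdpa(query, key, value, attn_mask=None, dropout_p=0.0,
             is_causal=False, scale=None, **kwargs):
        if attn_mask is None and dropout_p == 0.0 and scale is None and not kwargs:
            # SDPA inputs are (batch, heads, seq_len, head_dim), i.e. the "HND" layout.
            return fast_attn(query, key, value, tensor_layout="HND", is_causal=is_causal)
        return original_sdpa(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p,
                             is_causal=is_causal, scale=scale, **kwargs)
    return sdpa
```

Usage would look like `F.scaled_dot_product_attention = make_guarded_sdpa(F.scaled_dot_product_attention, sageattn)`, after importing `torch.nn.functional as F` and `sageattn`.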
We provide a benchmarking script to compare the speed of different kernels including SageAttention, FlashAttention2 and FlashAttention3. Please refer to the benchmark/ directory for more details.
8+8 means the kernel with INT8 quantization for $QK^\top$ and FP8 quantization for $PV$. 8+16 uses INT8 quantization for $QK^\top$ and FP16 with an FP16 accumulator for $PV$.
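Accumulator precision matters for long sequences. The toy experiment below uses the half-precision `'e'` format of Python's `struct` module as a stand-in for FP16 registers. It compares a running sum kept entirely in FP16 with a two-level scheme in the spirit of the one described above: short low-precision partial sums promoted into a higher-precision (here float64) total. The block size of 32 and the uniform inputs are arbitrary choices, and this does not model the actual FP8 MMA/WGMMA accumulators.

```python
import struct

def to_fp16(x):
    """Round a Python float to the nearest IEEE half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

def accumulate_fp16(values):
    """Running sum where the accumulator itself is stored in FP16."""
    acc = 0.0
    for x in values:
        acc = to_fp16(acc + to_fp16(x))
    return acc

def accumulate_two_level(values, block=32):
    """FP16 partial sums over short blocks, added into a float64 total."""
    total = 0.0
    for start in range(0, len(values), block):
        total += accumulate_fp16(values[start:start + block])
    return total

values = [0.01] * 4096
exact = sum(to_fp16(x) for x in values)  # float64 sum of the FP16-rounded inputs
```

The pure FP16 sum stalls once the accumulator's spacing exceeds twice the addend (here at 32.0, against an exact 40.97), while the two-level sum stays within a few hundredths.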
Note: The TOPS results refer only to the Attention Kernel, excluding the quantization and smoothing.
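Kernel throughput in TOPS is usually derived from the operation count of the two matrix products. This sketch assumes the conventional $4 \cdot b \cdot h \cdot n^2 \cdot d$ count used by FlashAttention-style benchmarks, halved for causal masks; check the scripts in `benchmark/` for the exact formula they use.

```python
def attention_tops(batch, heads, seq_len, head_dim, seconds, is_causal=False):
    """Conventional attention throughput in tera-operations per second."""
    # Q K^T and P V each cost n*n*d multiply-adds, counted as 2 ops apiece.
    ops = 4 * batch * heads * seq_len * seq_len * head_dim
    if is_causal:
        ops /= 2  # roughly half of the score matrix is masked out
    return ops / seconds / 1e12
```

For example, one head with `seq_len=1024` and `head_dim=64` finishing in 1 ms corresponds to about 0.268 TOPS.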
If you use this code or find our work valuable, please cite:
```bibtex
@inproceedings{zhang2025sageattention,
  title={SageAttention: Accurate 8-Bit Attention for Plug-and-play Inference Acceleration},
  author={Zhang, Jintao and Wei, Jia and Zhang, Pengle and Zhu, Jun and Chen, Jianfei},
  booktitle={International Conference on Learning Representations (ICLR)},
  year={2025}
}

@misc{zhang2024sageattention2,
  title={SageAttention2: Efficient Attention with Thorough Outlier Smoothing and Per-thread INT4 Quantization},
  author={Jintao Zhang and Haofeng Huang and Pengle Zhang and Jia Wei and Jun Zhu and Jianfei Chen},
  year={2024},
  eprint={2411.10958},
  archivePrefix={arXiv},
  primaryClass={cs.LG},
  url={https://arxiv.org/abs/2411.10958},
}
```