Tags: Dao-AILab/flash-attention
fix (#2481) Co-authored-by: wangziheng <[email protected]>
Allow compact block sparse index tensors (#2417)

* Allow compact block sparse index tensors

Relax validation in block_sparsity.py to allow idx.shape[3] <= expected_n_blocks instead of requiring exact equality. FA4 only accesses indices 0..cnt-1 per query tile, so the index tensor's last dimension does not need to be as large as ceil(seqlen_k / block_size_n). This enables memory-efficient compact index tensors that avoid O(N^2) memory at long sequence lengths (e.g., 1M+ tokens for sparse attention / NSA workloads).

Changes:
- _check_and_expand_block: accept compact n-block dimension and expand only the batch/head/m-block dimensions
- infer_block_sparse_expected_shapes: change strict equality check to upper-bound check (error only when n-blocks exceeds expected, not when smaller)

Backward compatible: existing code that passes full-sized tensors is unaffected.

* Add test for compact block sparse index tensors

Verify that truncating block sparse index tensors to idx.shape[3] = max(cnt) (instead of the full ceil(seqlen_k / block_size_n)) produces bit-identical output to full-sized tensors. This validates the relaxed validation from the previous commit.
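The idea behind the relaxed check can be illustrated with a small pure-Python sketch. It is not the code in block_sparsity.py: the helper names (`check_n_block_dim`, `compact_indices`, `gather_active`) are hypothetical, and the index tensor is reduced to a single m-block dimension (a list of rows) rather than the real 4-D layout. It only shows why truncating the last dimension to max(cnt) loses nothing when entries 0..cnt-1 are the only ones read.

```python
import math


def check_n_block_dim(idx_n_blocks, seqlen_k, block_size_n):
    """Hypothetical upper-bound check: reject only an index dimension that is
    larger than ceil(seqlen_k / block_size_n); smaller (compact) is allowed."""
    expected = math.ceil(seqlen_k / block_size_n)
    if idx_n_blocks > expected:
        raise ValueError(
            f"n-block dimension {idx_n_blocks} exceeds expected {expected}"
        )
    return expected


def compact_indices(cnt, idx):
    """Truncate every row of idx to max(cnt) entries (illustrative only)."""
    width = max(cnt)
    return [row[:width] for row in idx]


def gather_active(cnt, idx):
    """Return, per query tile, only the entries 0..cnt-1 that would be read."""
    return [row[:c] for c, row in zip(cnt, idx)]


# Assumed toy problem: seqlen_k = 1024, block_size_n = 128 -> 8 n-blocks.
cnt = [2, 3, 1]                      # active key blocks per query tile
full = [
    [0, 5] + [-1] * 6,               # -1 marks unused padding slots
    [1, 2, 7] + [-1] * 5,
    [3] + [-1] * 7,
]
compact = compact_indices(cnt, full)  # last dimension shrinks from 8 to 3

check_n_block_dim(len(compact[0]), seqlen_k=1024, block_size_n=128)
same = gather_active(cnt, full) == gather_active(cnt, compact)
```

In this toy case the full layout stores 8 slots per query tile and the compact one stores 3, while the entries that are actually read are identical in both.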