fix(triton/decode): honor real paged KV block stride (support non-contiguous cache)#3498
Open
lorri-rao wants to merge 2 commits into
micmelesse
previously approved these changes
Jun 3, 2026
Contributor
micmelesse
left a comment
LGTM. @lorri-rao, can you add a test so that we can pin this behavior?
…tiguous cache)

attention_forward_decode_triton_impl's split-K kernel computed the paged block base address as `physical_block * BLOCK_SIZE_K * stride_kn`, i.e. it reconstructed the per-block stride as block_size * slot_stride and thus assumed the paged KV cache is contiguous.

This is wrong for callers whose key_cache / value_cache are non-contiguous views of the paged buffer. In particular, vLLM hybrid attention+mamba models store the KV cache block-major with K/V interleaved per block ([num_blocks, 2, block_size, num_kv_heads, head_size]); the attention halves obtained via kv_cache.unbind(0) are non-contiguous with block stride 2 * block_size * num_kv_heads * head_size, i.e. 2x what the kernel assumed. The kernel then reads the wrong block memory (straddling K and V) and produces garbage attention output.

Thread the real block stride (k_cache.stride(0) / v_cache.stride(0)) into the kernel and use it for the K/V block base instead of BLOCK_SIZE_K * stride_kn. This is a no-op for contiguous caches (stride(0) == BLOCK_SIZE_K * stride_kn) and makes the kernel correct for any regular paged layout, contiguous or interleaved.

Validated on an interleaved (stride(0)=8192) paged cache: max abs error vs a brute-force reference dropped from ~6.3 to bf16 noise (~0.09); contiguous caches are unchanged.
Add test_flash_attn_kvcache_noncontiguous_paged, which exercises the
split-K decode path against a paged KV cache stored block-major with K/V
interleaved per block ([num_blocks, 2, block_size, nheads_k, d]) -- the
layout vLLM hybrid attention+mamba models use. The per-component K/V
caches are non-contiguous views whose block stride is
2 * block_size * nheads_k * d.
The test compares against a dense reference gathered from the same paged
buffer (3x pytorch-baseline tolerance, matching test_flash_attn_kvcache)
and additionally asserts the non-contiguous result is bit-equivalent to
the same data in a contiguous paged cache. This pins the kernel to honor
the real k/v_cache.stride(0) instead of assuming block_size * slot_stride.
Verified it fails on the pre-fix kernel (block-stride assumed contiguous)
and passes after the fix, across mha/gqa, causal on/off, block_size
{16,256}, d {64,128}.
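The addressing property this test pins can be isolated in a few lines of plain Python. The sketch below is not the test from this PR: it skips attention entirely, uses tiny made-up sizes, and the `gather` helper and `block_table` values are hypothetical. It only shows that gathering K rows through a block table gives the same result from the interleaved buffer (using its real block stride) as from a contiguous copy, and a different result when the contiguous stride is wrongly applied to the interleaved buffer.

```python
# Tiny illustrative sizes; `row` stands in for nheads_k * d flattened.
num_blocks, block_size, row = 6, 4, 3
slot_stride = row

# Interleaved storage [num_blocks, 2, block_size, row], flattened.
# Distinct integers make any misread detectable.
interleaved = list(range(num_blocks * 2 * block_size * row))
k_block_stride = 2 * block_size * slot_stride  # stride(0) of the K view

# Contiguous copy of the K halves: [num_blocks, block_size, row].
contiguous_k = []
for b in range(num_blocks):
    start = b * k_block_stride
    contiguous_k.extend(interleaved[start:start + block_size * row])

def gather(buf, block_table, seqlen, block_stride):
    """Collect one row per token by walking the block table (hypothetical helper)."""
    out = []
    for t in range(seqlen):
        physical_block = block_table[t // block_size]
        base = physical_block * block_stride + (t % block_size) * slot_stride
        out.append(buf[base:base + row])
    return out

block_table = [4, 1, 3]  # logical block -> physical block, made up
seqlen = 10

from_view = gather(interleaved, block_table, seqlen, k_block_stride)
from_copy = gather(contiguous_k, block_table, seqlen, block_size * slot_stride)
misread = gather(interleaved, block_table, seqlen, block_size * slot_stride)

print(from_view == from_copy, misread == from_view)
```

The first comparison corresponds to the test's equivalence check between the two layouts; the second corresponds to the pre-fix behavior the test is meant to catch.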
d293d61 to
757b162
Author
Thanks for looking into this. I just added a test, @micmelesse.
attention_forward_decode_triton_impl's split-K kernel computed the paged block base address as `physical_block * BLOCK_SIZE_K * stride_kn`, i.e. it reconstructed the per-block stride as block_size * slot_stride and thus assumed the paged KV cache is contiguous.
This is wrong for callers whose key_cache / value_cache are non-contiguous views of the paged buffer. In particular, vLLM hybrid attention+mamba models store the KV cache block-major with K/V interleaved per block ([num_blocks, 2, block_size, num_kv_heads, head_size]); the attention halves obtained via kv_cache.unbind(0) are non-contiguous with block stride 2 * block_size * num_kv_heads * head_size, i.e. 2x what the kernel assumed. The kernel then reads the wrong block memory (straddling K and V) and produces garbage attention output.
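The factor of two can be checked with plain stride arithmetic. This is a minimal sketch in pure Python rather than PyTorch. The sizes are illustrative and not taken from the PR, and `contiguous_strides` is a helper introduced here for the example.

```python
from math import prod

def contiguous_strides(shape):
    """Element strides of a row-major (C-contiguous) array with this shape."""
    return tuple(prod(shape[i + 1:]) for i in range(len(shape)))

# Illustrative sizes only (assumed for this example).
num_blocks, block_size, num_kv_heads, head_size = 8, 16, 4, 64

# Physical storage: block-major, with K and V interleaved per block.
physical = (num_blocks, 2, block_size, num_kv_heads, head_size)
s_block, s_kv, s_slot, s_head, s_dim = contiguous_strides(physical)

# Selecting the K half drops the size-2 axis but keeps the other strides,
# so the view's block stride is still the stride of the full buffer.
k_view_strides = (s_block, s_slot, s_head, s_dim)

assumed_block_stride = block_size * s_slot  # what the old kernel reconstructed
real_block_stride = k_view_strides[0]       # what k_cache.stride(0) would report

print(assumed_block_stride, real_block_stride)  # prints: 4096 8192
```

With these sizes the slot stride is 256 elements, so the reconstructed block stride is 4096 while the view's actual block stride is 8192.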
Thread the real block stride (k_cache.stride(0) / v_cache.stride(0)) into the kernel and use it for the K/V block base instead of BLOCK_SIZE_K * stride_kn. This is a no-op for contiguous caches (stride(0) == BLOCK_SIZE_K * stride_kn) and makes the kernel correct for any regular paged layout, contiguous or interleaved.
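To make the "no-op for contiguous caches" claim concrete, the two base-address formulas can be compared side by side. The sketch below is plain Python, not the Triton kernel. The function names `block_base_old`, `block_base_new`, and `region` are made up for illustration, the sizes are assumed, and offsets are measured in elements from the start of the K view.

```python
def block_base_old(physical_block, block_size_k, stride_kn):
    # Pre-fix formula: rebuilds the block stride as block_size * slot_stride.
    return physical_block * block_size_k * stride_kn

def block_base_new(physical_block, stride_block):
    # Post-fix formula: uses the block stride the tensor actually reports.
    return physical_block * stride_block

# Illustrative sizes only (assumed for this example).
block_size_k, num_kv_heads, head_size = 16, 4, 64
stride_kn = num_kv_heads * head_size               # slot stride: 256

contiguous_stride = block_size_k * stride_kn       # 4096
interleaved_stride = 2 * block_size_k * stride_kn  # 8192

# Contiguous cache: both formulas agree for every block.
agree = all(
    block_base_old(b, block_size_k, stride_kn) == block_base_new(b, contiguous_stride)
    for b in range(8)
)

def region(offset):
    """Which block and which half (K or V) an offset falls into, interleaved layout."""
    blk, rem = divmod(offset, interleaved_stride)
    return blk, "K" if rem < block_size_k * stride_kn else "V"

# Interleaved cache, physical block 1: the two formulas diverge.
old_hit = region(block_base_old(1, block_size_k, stride_kn))
new_hit = region(block_base_new(1, interleaved_stride))

print(agree, old_hit, new_hit)  # prints: True (0, 'V') (1, 'K')
```

Under these assumptions the old formula, asked for K of block 1, starts reading in the V half of block 0, while the stride-based formula lands on K of block 1.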
Validated on an interleaved (stride(0)=8192) paged cache: max abs error vs a brute-force reference dropped from ~6.3 to bf16 noise (~0.09); contiguous caches are unchanged.
Internal ticket:
ROCM-25540