Thanks to visit codestin.com
Credit goes to github.com

Skip to content

fix(triton/decode): honor real paged KV block stride (support non-contiguous cache)#3498

Open
lorri-rao wants to merge 2 commits into
ROCm:mainfrom
Z-Y00:fix/paged-decode-noncontiguous-block-stride
Open

fix(triton/decode): honor real paged KV block stride (support non-contiguous cache)#3498
lorri-rao wants to merge 2 commits into
ROCm:mainfrom
Z-Y00:fix/paged-decode-noncontiguous-block-stride

Conversation

@lorri-rao
Copy link
Copy Markdown

@lorri-rao lorri-rao commented Jun 2, 2026

attention_forward_decode_triton_impl's split-K kernel computed the paged block base address as physical_block * BLOCK_SIZE_K * stride_kn, i.e. it reconstructed the per-block stride as block_size * slot_stride and thus assumed the paged KV cache is contiguous.

This is wrong for callers whose key_cache / value_cache are non-contiguous views of the paged buffer. In particular, vLLM hybrid attention+mamba models store the KV cache block-major with K/V interleaved per block ([num_blocks, 2, block_size, num_kv_heads, head_size]); the attention halves obtained via kv_cache.unbind(0) are non-contiguous with block stride 2 * block_size * num_kv_heads * head_size, i.e. 2x what the kernel assumed. The kernel then reads the wrong block memory (straddling K and V) and produces garbage attention output.

Thread the real block stride (k_cache.stride(0) / v_cache.stride(0)) into the kernel and use it for the K/V block base instead of BLOCK_SIZE_K * stride_kn. This is a no-op for contiguous caches (stride(0) == BLOCK_SIZE_K * stride_kn) and makes the kernel correct for any regular paged layout, contiguous or interleaved.

Validated on an interleaved (stride(0)=8192) paged cache: max abs error vs a brute-force reference dropped from ~6.3 to bf16 noise (~0.09); contiguous caches are unchanged.

Internal ticket:
ROCM-25540

@lorri-rao lorri-rao requested a review from a team June 2, 2026 22:55
@github-actions
Copy link
Copy Markdown
Contributor

github-actions Bot commented Jun 2, 2026

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests on MI35X (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

Label Tests
ci:triton-300x Run an additional Triton test job on MI300X in PRs; main branch always runs both MI35X and MI300X
ci:sglang SGLang integration tests: DeepSeek-R1-MXFP4 accuracy, Qwen 3.5 accuracy
ci:atom ATOM benchmark: DeepSeek-R1-0528, GPT-OSS-120B
ci:atom_full ATOM accuracy suite for PR and main models from ATOM models_accuracy.json
ci:vllm vLLM benchmark: GPT-OSS-120B, DeepSeek-R1-0528, Kimi-K2.5
ci:all All standard extended tests (excludes ci:atom_full)

Only add ci:atom_full for FlyDSL or Triton upgrades.
Add labels via the sidebar or gh pr edit 3498 --add-label <label>

micmelesse
micmelesse previously approved these changes Jun 3, 2026
Copy link
Copy Markdown
Contributor

@micmelesse micmelesse left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. @lorri-rao Can you add a test so that we can pin this behavior.

Z-Y00 added 2 commits June 3, 2026 21:36
…tiguous cache)

attention_forward_decode_triton_impl's split-K kernel computed the paged block
base address as `physical_block * BLOCK_SIZE_K * stride_kn`, i.e. it reconstructed
the per-block stride as block_size * slot_stride and thus assumed the paged KV
cache is contiguous.

This is wrong for callers whose key_cache / value_cache are non-contiguous views
of the paged buffer. In particular, vLLM hybrid attention+mamba models store the
KV cache block-major with K/V interleaved per block ([num_blocks, 2, block_size,
num_kv_heads, head_size]); the attention halves obtained via kv_cache.unbind(0)
are non-contiguous with block stride 2 * block_size * num_kv_heads * head_size,
i.e. 2x what the kernel assumed. The kernel then reads the wrong block memory
(straddling K and V) and produces garbage attention output.

Thread the real block stride (k_cache.stride(0) / v_cache.stride(0)) into the
kernel and use it for the K/V block base instead of BLOCK_SIZE_K * stride_kn.
This is a no-op for contiguous caches (stride(0) == BLOCK_SIZE_K * stride_kn) and
makes the kernel correct for any regular paged layout, contiguous or interleaved.

Validated on an interleaved (stride(0)=8192) paged cache: max abs error vs a
brute-force reference dropped from ~6.3 to bf16 noise (~0.09); contiguous caches
are unchanged.
Add test_flash_attn_kvcache_noncontiguous_paged, which exercises the
split-K decode path against a paged KV cache stored block-major with K/V
interleaved per block ([num_blocks, 2, block_size, nheads_k, d]) -- the
layout vLLM hybrid attention+mamba models use. The per-component K/V
caches are non-contiguous views whose block stride is
2 * block_size * nheads_k * d.

The test compares against a dense reference gathered from the same paged
buffer (3x pytorch-baseline tolerance, matching test_flash_attn_kvcache)
and additionally asserts the non-contiguous result is bit-equivalent to
the same data in a contiguous paged cache. This pins the kernel to honor
the real k/v_cache.stride(0) instead of assuming block_size * slot_stride.

Verified it fails on the pre-fix kernel (block-stride assumed contiguous)
and passes after the fix, across mha/gqa, causal on/off, block_size
{16,256}, d {64,128}.
@lorri-rao lorri-rao force-pushed the fix/paged-decode-noncontiguous-block-stride branch from d293d61 to 757b162 Compare June 3, 2026 21:37
@lorri-rao
Copy link
Copy Markdown
Author

LGTM. @lorri-rao Can you add a test so that we can pin this behavior.

Thanks for looking into this, I just added a test @micmelesse

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants