Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit d293d61

Browse files
committed
test(triton/decode): pin non-contiguous paged KV block stride
Add test_flash_attn_kvcache_noncontiguous_paged, which exercises the split-K decode path against a paged KV cache stored block-major with K/V interleaved per block ([num_blocks, 2, block_size, nheads_k, d]) -- the layout vLLM hybrid attention+mamba models use. The per-component K/V caches are non-contiguous views whose block stride is 2 * block_size * nheads_k * d. The test compares against a dense reference gathered from the same paged buffer (3x pytorch-baseline tolerance, matching test_flash_attn_kvcache) and additionally asserts the non-contiguous result is bit-equivalent to the same data in a contiguous paged cache. This pins the kernel to honor the real k/v_cache.stride(0) instead of assuming block_size * slot_stride. Verified it fails on the pre-fix kernel (block-stride assumed contiguous) and passes after the fix, across mha/gqa, causal on/off, block_size {16,256}, d {64,128}.
1 parent 166c192 commit d293d61

1 file changed

Lines changed: 182 additions & 0 deletions

File tree

op_tests/triton_tests/attention/test_flash_attn_kvcache.py

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,49 @@ def _generate_block_kvcache(
157157
return k_cache, v_cache, block_table, k_cache_paged, v_cache_paged, num_blocks
158158

159159

160+
def _generate_interleaved_block_kvcache(
161+
seqlen_k, paged_kv_block_size, batch_size, nheads_k, d, device, dtype
162+
):
163+
"""Paged KV cache stored block-major with K and V interleaved per block
164+
([num_blocks, 2, block_size, nheads_k, d]) -- the layout vLLM hybrid
165+
attention+mamba models use.
166+
167+
The per-component K/V caches are *non-contiguous* views of the backing
168+
buffer: their block stride is ``2 * block_size * nheads_k * d`` (twice what a
169+
contiguous ``[num_blocks, block_size, nheads_k, d]`` cache would have). This
170+
is the case the split-K decode kernel must honor by reading the real
171+
``stride(0)`` instead of assuming ``block_size * slot_stride``.
172+
"""
173+
num_blocks = math.ceil(seqlen_k / paged_kv_block_size) * batch_size * 3
174+
kv_cache_paged = torch.randn(
175+
num_blocks,
176+
2,
177+
paged_kv_block_size,
178+
nheads_k,
179+
d,
180+
device=device,
181+
dtype=dtype,
182+
)
183+
k_cache_paged = kv_cache_paged[:, 0] # non-contiguous view, stride(0) = 2*...
184+
v_cache_paged = kv_cache_paged[:, 1]
185+
block_table = rearrange(
186+
torch.randperm(num_blocks, dtype=torch.int32, device=device),
187+
"(b nblocks) -> b nblocks",
188+
b=batch_size,
189+
)
190+
k_cache = rearrange(
191+
k_cache_paged[block_table.to(dtype=torch.long).flatten()],
192+
"(b nblocks) block_size ... -> b (nblocks block_size) ...",
193+
b=batch_size,
194+
)[:, :seqlen_k]
195+
v_cache = rearrange(
196+
v_cache_paged[block_table.to(dtype=torch.long).flatten()],
197+
"(b nblocks) block_size ... -> b (nblocks block_size) ...",
198+
b=batch_size,
199+
)[:, :seqlen_k]
200+
return k_cache, v_cache, block_table, k_cache_paged, v_cache_paged, num_blocks
201+
202+
160203
@pytest.mark.parametrize("mha_type", ["mha", "gqa"])
161204
@pytest.mark.parametrize("new_kv", [False, True])
162205
@pytest.mark.parametrize("causal", [False, True])
@@ -330,6 +373,145 @@ def test_flash_attn_kvcache(
330373
)
331374

332375

376+
@pytest.mark.parametrize("mha_type", ["mha", "gqa"])
377+
@pytest.mark.parametrize("causal", [False, True])
378+
@pytest.mark.parametrize("paged_kv_block_size", [16, 256])
379+
@pytest.mark.parametrize(
380+
"seqlen_q,seqlen_k",
381+
[
382+
(1, 339),
383+
(3, 1024),
384+
(17, 156),
385+
],
386+
)
387+
@pytest.mark.parametrize("d", [64, 128])
388+
def test_flash_attn_kvcache_noncontiguous_paged(
389+
seqlen_q,
390+
seqlen_k,
391+
d,
392+
paged_kv_block_size,
393+
causal,
394+
mha_type,
395+
):
396+
"""Paged decode against a non-contiguous (K/V-interleaved) paged cache.
397+
398+
Regression guard for the split-K decode kernel previously hard-coding the
399+
paged block stride as ``block_size * slot_stride`` (contiguous-only). With an
400+
interleaved ``[num_blocks, 2, block_size, nheads_k, d]`` cache the real block
401+
stride is ``2 * block_size * nheads_k * d``; before the fix the kernel read
402+
K/V-straddling block memory and produced garbage attention. The dense
403+
reference is gathered from the same paged buffer, so this pins correct
404+
numerics for any regular paged layout, contiguous or interleaved.
405+
"""
406+
dtype = torch.bfloat16
407+
device = "cuda"
408+
torch.random.manual_seed(SEED)
409+
torch.cuda.manual_seed(SEED)
410+
batch_size = 2
411+
nheads = 6
412+
nheads_k = nheads if mha_type == "mha" else 3
413+
assert nheads % nheads_k == 0
414+
415+
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)
416+
417+
(
418+
k_cache,
419+
v_cache,
420+
block_table,
421+
k_cache_paged,
422+
v_cache_paged,
423+
num_blocks,
424+
) = _generate_interleaved_block_kvcache(
425+
seqlen_k, paged_kv_block_size, batch_size, nheads_k, d, device, dtype
426+
)
427+
428+
# Sanity: the views must really be non-contiguous (block stride == 2x), else
429+
# this test would silently degrade to the contiguous case and stop guarding
430+
# the fix.
431+
assert not k_cache_paged.is_contiguous()
432+
assert not v_cache_paged.is_contiguous()
433+
assert k_cache_paged.stride(0) == 2 * paged_kv_block_size * nheads_k * d
434+
435+
cache_seqlens = torch.randint(
436+
1, seqlen_k + 1, (batch_size,), dtype=torch.int32, device=device
437+
)
438+
439+
arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s")
440+
key_padding_mask = arange < rearrange(cache_seqlens, "b -> b 1")
441+
442+
k_cache_rep = repeat(k_cache, "b s h d -> b s (h g) d", g=nheads // nheads_k)
443+
v_cache_rep = repeat(v_cache, "b s h d -> b s (h g) d", g=nheads // nheads_k)
444+
445+
out = flash_attn_with_kvcache(
446+
q,
447+
k_cache_paged,
448+
v_cache_paged,
449+
cache_seqlens=cache_seqlens,
450+
page_table=block_table,
451+
causal=causal,
452+
)
453+
torch.cuda.synchronize()
454+
if isinstance(out, tuple):
455+
out = out[0]
456+
out = out.to(dtype)
457+
458+
out_ref, _ = attention_ref(
459+
q,
460+
k_cache_rep,
461+
v_cache_rep,
462+
None,
463+
key_padding_mask,
464+
None,
465+
0.0,
466+
None,
467+
causal=causal,
468+
window_size=(-1, -1),
469+
)
470+
out_pt, _ = attention_ref(
471+
q,
472+
k_cache_rep,
473+
v_cache_rep,
474+
None,
475+
key_padding_mask,
476+
None,
477+
0.0,
478+
None,
479+
causal=causal,
480+
window_size=(-1, -1),
481+
upcast=False,
482+
reorder_ops=True,
483+
)
484+
485+
pt_max_diff = (out_pt - out_ref).abs().max().item()
486+
our_max_diff = (out - out_ref).abs().max().item()
487+
mult = 3
488+
assert our_max_diff <= mult * pt_max_diff + 1e-5, (
489+
f"Non-contiguous paged output max diff {our_max_diff:.6e} exceeds "
490+
f"{mult}x Pytorch baseline diff {pt_max_diff:.6e} + 1e-5"
491+
)
492+
493+
# The exact same data laid out in a *contiguous* paged cache must produce the
494+
# same result -- this directly pins the kernel to the real block stride
495+
# rather than the contiguous-only assumption.
496+
out_contig = flash_attn_with_kvcache(
497+
q,
498+
k_cache_paged.contiguous(),
499+
v_cache_paged.contiguous(),
500+
cache_seqlens=cache_seqlens,
501+
page_table=block_table,
502+
causal=causal,
503+
)
504+
torch.cuda.synchronize()
505+
if isinstance(out_contig, tuple):
506+
out_contig = out_contig[0]
507+
out_contig = out_contig.to(dtype)
508+
contig_diff = (out - out_contig).abs().max().item()
509+
assert contig_diff < 1e-5, (
510+
f"Non-contiguous vs contiguous paged cache differ by {contig_diff:.6e} "
511+
"(> 1e-5): the kernel is not honoring the real paged block stride"
512+
)
513+
514+
333515
# torch.compile tests
334516
@pytest.mark.parametrize("new_kv", [False, True])
335517
@pytest.mark.parametrize("causal", [True, False])

0 commit comments

Comments
 (0)