@@ -157,6 +157,49 @@ def _generate_block_kvcache(
157157 return k_cache , v_cache , block_table , k_cache_paged , v_cache_paged , num_blocks
158158
159159
def _generate_interleaved_block_kvcache(
    seqlen_k, paged_kv_block_size, batch_size, nheads_k, d, device, dtype
):
    """Build a random paged KV cache whose K and V share one backing buffer.

    The backing tensor has shape ``[num_blocks, 2, block_size, nheads_k, d]``
    (the block-major, K/V-interleaved layout vLLM hybrid attention+mamba
    models use). The returned paged K and V tensors are views selecting index
    0 and 1 of dim 1, so they are *non-contiguous*: their ``stride(0)`` is
    ``2 * block_size * nheads_k * d``, twice that of a contiguous
    ``[num_blocks, block_size, nheads_k, d]`` cache. A paged kernel has to read
    that real ``stride(0)`` rather than assume ``block_size * slot_stride``.

    Returns:
        ``(k_cache, v_cache, block_table, k_cache_paged, v_cache_paged,
        num_blocks)`` where ``k_cache`` / ``v_cache`` are the dense
        per-sequence caches gathered through ``block_table`` and truncated to
        ``seqlen_k`` along dim 1, and ``block_table`` is a
        ``[batch_size, nblocks]`` int32 random permutation of all block ids.
    """
    # Allocate three times the blocks each sequence strictly needs.
    blocks_needed_per_seq = math.ceil(seqlen_k / paged_kv_block_size)
    num_blocks = blocks_needed_per_seq * batch_size * 3

    backing_shape = (num_blocks, 2, paged_kv_block_size, nheads_k, d)
    kv_cache_paged = torch.randn(*backing_shape, device=device, dtype=dtype)
    # Views into the shared buffer (no copy), hence stride(0) = 2 * block numel.
    k_cache_paged, v_cache_paged = kv_cache_paged.unbind(dim=1)

    shuffled_block_ids = torch.randperm(
        num_blocks, dtype=torch.int32, device=device
    )
    block_table = rearrange(
        shuffled_block_ids, "(b nblocks) -> b nblocks", b=batch_size
    )
    # Indexing needs int64; K and V are gathered with the same index.
    gather_index = block_table.to(dtype=torch.long).flatten()

    def _gather_dense(paged):
        # Pull each sequence's blocks in table order, then drop the padding
        # slots beyond seqlen_k.
        per_sequence = rearrange(
            paged[gather_index],
            "(b nblocks) block_size ... -> b (nblocks block_size) ...",
            b=batch_size,
        )
        return per_sequence[:, :seqlen_k]

    k_cache = _gather_dense(k_cache_paged)
    v_cache = _gather_dense(v_cache_paged)
    return k_cache, v_cache, block_table, k_cache_paged, v_cache_paged, num_blocks
201+
202+
160203@pytest .mark .parametrize ("mha_type" , ["mha" , "gqa" ])
161204@pytest .mark .parametrize ("new_kv" , [False , True ])
162205@pytest .mark .parametrize ("causal" , [False , True ])
@@ -330,6 +373,145 @@ def test_flash_attn_kvcache(
330373 )
331374
332375
@pytest.mark.parametrize("mha_type", ["mha", "gqa"])
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("paged_kv_block_size", [16, 256])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 339),
        (3, 1024),
        (17, 156),
    ],
)
@pytest.mark.parametrize("d", [64, 128])
def test_flash_attn_kvcache_noncontiguous_paged(
    seqlen_q,
    seqlen_k,
    d,
    paged_kv_block_size,
    causal,
    mha_type,
):
    """Check paged decode on a K/V-interleaved (non-contiguous) paged cache.

    Regression guard for the split-K decode kernel, which previously
    hard-coded the paged block stride as ``block_size * slot_stride`` (valid
    only for contiguous caches). For an interleaved
    ``[num_blocks, 2, block_size, nheads_k, d]`` cache the real block stride is
    ``2 * block_size * nheads_k * d``; before the fix the kernel read
    K/V-straddling block memory and produced garbage attention.

    Two things are checked:

    1. the kernel output matches a dense reference gathered from the same
       paged buffer, within 3x the error of a same-precision PyTorch baseline;
    2. the kernel output on the non-contiguous views matches its output on
       ``.contiguous()`` copies of the same data.
    """
    dtype = torch.bfloat16
    device = "cuda"
    torch.random.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    batch_size, nheads = 2, 6
    nheads_k = 3 if mha_type != "mha" else nheads
    assert nheads % nheads_k == 0
    heads_per_kv_head = nheads // nheads_k

    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)

    generated = _generate_interleaved_block_kvcache(
        seqlen_k, paged_kv_block_size, batch_size, nheads_k, d, device, dtype
    )
    k_cache, v_cache, block_table, k_cache_paged, v_cache_paged, num_blocks = generated

    # Sanity: the views must really be non-contiguous (block stride == 2x),
    # otherwise this test silently degrades to the contiguous case and stops
    # guarding the fix.
    interleaved_block_stride = 2 * paged_kv_block_size * nheads_k * d
    assert not k_cache_paged.is_contiguous()
    assert not v_cache_paged.is_contiguous()
    assert k_cache_paged.stride(0) == interleaved_block_stride

    cache_seqlens = torch.randint(
        1, seqlen_k + 1, (batch_size,), dtype=torch.int32, device=device
    )

    # True where a key position lies inside that sequence's valid cache length.
    key_positions = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s")
    key_padding_mask = key_positions < rearrange(cache_seqlens, "b -> b 1")

    # Expand KV heads so the dense reference sees one KV head per query head.
    k_cache_rep = repeat(k_cache, "b s h d -> b s (h g) d", g=heads_per_kv_head)
    v_cache_rep = repeat(v_cache, "b s h d -> b s (h g) d", g=heads_per_kv_head)

    def _paged_decode(k_paged, v_paged):
        # Run the kernel under test and normalise its result to one tensor.
        result = flash_attn_with_kvcache(
            q,
            k_paged,
            v_paged,
            cache_seqlens=cache_seqlens,
            page_table=block_table,
            causal=causal,
        )
        torch.cuda.synchronize()
        if isinstance(result, tuple):
            result = result[0]
        return result.to(dtype)

    out = _paged_decode(k_cache_paged, v_cache_paged)

    # Both reference runs share everything except the precision/ordering flags.
    ref_args = (q, k_cache_rep, v_cache_rep, None, key_padding_mask, None, 0.0, None)
    ref_kwargs = dict(causal=causal, window_size=(-1, -1))
    out_ref, _ = attention_ref(*ref_args, **ref_kwargs)
    out_pt, _ = attention_ref(
        *ref_args, **ref_kwargs, upcast=False, reorder_ops=True
    )

    pt_max_diff = (out_pt - out_ref).abs().max().item()
    our_max_diff = (out - out_ref).abs().max().item()
    mult = 3
    assert our_max_diff <= mult * pt_max_diff + 1e-5, (
        f"Non-contiguous paged output max diff {our_max_diff:.6e} exceeds "
        f"{mult} x Pytorch baseline diff {pt_max_diff:.6e} + 1e-5"
    )

    # The exact same data laid out in a *contiguous* paged cache must produce
    # the same result -- this directly pins the kernel to the real block
    # stride rather than the contiguous-only assumption.
    out_contig = _paged_decode(
        k_cache_paged.contiguous(), v_cache_paged.contiguous()
    )
    contig_diff = (out - out_contig).abs().max().item()
    assert contig_diff < 1e-5, (
        f"Non-contiguous vs contiguous paged cache differ by {contig_diff:.6e} "
        "(> 1e-5): the kernel is not honoring the real paged block stride"
    )
514+
333515# torch.compile tests
334516@pytest .mark .parametrize ("new_kv" , [False , True ])
335517@pytest .mark .parametrize ("causal" , [True , False ])
0 commit comments