[CUDA] Use cuDNN SDPA for decoding when using fixed-size KV cache #3113

Merged: zcbenz merged 1 commit into ml-explore:main from zcbenz:cuda-sdpa-sliced on Feb 10, 2026
Conversation

@zcbenz (Collaborator) commented Feb 8, 2026

The KV cache in mlx-lm has a fixed size, and the keys/values passed to fast.sdpa are slices of it. Using this information we can create cuDNN graphs with a fixed sequence length and use padding masks to set the actual sequence lengths.

This makes decoding (T_q == 1) 30%~100% faster for large sequence lengths. For small sequence lengths the cuDNN SDPA is still faster, but the overhead of creating cuDNN graphs would eliminate the advantage, so we fall back to the vector SDPA.

This approach, however, does not support custom array masks. We could add more options to reduce the use of array masks, but cuDNN does not support left padding masks, so integrating with BatchKVCache would require extra effort.
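The fixed-size cache idea can be sketched in Python, with numpy standing in for mlx arrays. The class name, API, and the 256-token growth step below are assumptions for illustration, not mlx-lm's actual code; the point is that update() returns a slice of a contiguous fixed-capacity buffer, which is what lets one cuDNN graph (built for the full capacity, with a padding mask for the live length) be reused across decode steps.

```python
import numpy as np

STEP = 256  # assumed growth step for the cache capacity

class FixedSizeKVCache:
    """Hypothetical sketch of a fixed-size KV cache that hands out slices."""

    def __init__(self, n_heads, head_dim, dtype=np.float32):
        self.keys = None      # backing buffer: (1, n_heads, capacity, head_dim)
        self.offset = 0       # number of tokens actually written
        self.n_heads = n_heads
        self.head_dim = head_dim
        self.dtype = dtype

    def update(self, new_keys):
        t_new = new_keys.shape[2]
        needed = self.offset + t_new
        # Grow the backing buffer in multiples of STEP so the sequence
        # capacity stays fixed between (rare) reallocations.
        if self.keys is None or needed > self.keys.shape[2]:
            capacity = ((needed + STEP - 1) // STEP) * STEP
            grown = np.zeros((1, self.n_heads, capacity, self.head_dim), self.dtype)
            if self.keys is not None:
                grown[:, :, : self.offset] = self.keys[:, :, : self.offset]
            self.keys = grown
        self.keys[:, :, self.offset : needed] = new_keys
        self.offset = needed
        # Return a *slice* (a view) of the fixed-capacity buffer: the graph
        # can be built once for `capacity` and padding-masked to `offset`.
        return self.keys[:, :, : self.offset]

cache = FixedSizeKVCache(n_heads=2, head_dim=4)
k = cache.update(np.ones((1, 2, 3, 4), np.float32))
# k covers 3 tokens but is a view into the 256-token backing buffer.
```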

Comment on lines +78 to +89
auto is_slice = [](const array& kv) {
  // When called during graph building the strides are not available.
if (kv.status() != array::evaluated) {
return (kv.has_primitive() && typeid(kv.primitive()) == typeid(Slice)) ||
(kv.shape(2) % kv_cache_step == 0);
}
// Get pre-sliced sequence length from strides, and check if the buffer
// belongs to a contiguous kv cache.
int64_t T_kv = kv.strides(1) / kv.strides(2);
if (kv.size() / kv.shape(2) * T_kv != kv.buffer_size() / kv.itemsize()) {
return false;
}
Member:
If we assume that the cuDNN case is a strict subset of what the vector fallback can handle, then this is not really necessary. We could just use this function during low-level dispatch (and actually use the strides, which is a lot more robust).

@zcbenz (Collaborator, Author):

Strictly speaking cuDNN supports more head dims, but yeah, it should be more robust to just make it a subset of the vector SDPA. I have updated the code.
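The strides check in the diff above can be illustrated with a numpy stand-in. Here presliced_len is a hypothetical helper, not MLX code; mlx strides are in elements while numpy's are in bytes, hence the itemsize division. For a contiguous (B, H, T, D) cache sliced along the sequence axis, the head-axis stride still encodes the unsliced capacity T_kv, and comparing against the buffer size confirms the slice comes from one contiguous KV cache.

```python
import numpy as np

def presliced_len(kv_slice, full_buffer):
    """Recover the pre-slice sequence capacity from strides, and check
    that the slice belongs to a contiguous KV cache buffer."""
    elem = kv_slice.itemsize
    # strides(1) / strides(2), converted from bytes to element units.
    t_kv = (kv_slice.strides[1] // elem) // (kv_slice.strides[2] // elem)
    # Mirrors: kv.size() / kv.shape(2) * T_kv == kv.buffer_size() / kv.itemsize()
    contiguous = (kv_slice.size // kv_slice.shape[2]) * t_kv == full_buffer.size
    return t_kv, contiguous

buf = np.zeros((1, 8, 256, 64), np.float32)   # fixed-capacity cache buffer
sliced = buf[:, :, :100]                      # what would reach fast.sdpa
t_kv, ok = presliced_len(sliced, buf)
# t_kv recovers the 256-token capacity even though the slice holds 100.
```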

@awni (Member) commented Feb 9, 2026

Very nice improvement. On B200 it's more than 2x speedup for long context:

mlx_lm.benchmark --model Qwen/Qwen3-4B-Instruct-2507 --prompt 64000 --g 512 -n 3

Pre: 72.998 Tok/s
Post: 208.822 Tok/s

I wonder if we should defensively increase the default cache size for the forward op? If you do long generations it will start to thrash after about 16k tokens which is not that many. Maybe we should make it 256 or something?

@zcbenz (Collaborator, Author) commented Feb 9, 2026

Yeah for decoding we definitely need a larger cache size.
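The thrashing awni describes can be sketched with a toy LRU graph cache. The keying by padded KV length, the 256-token step, the eviction policy, and the capacities below are all assumptions for illustration, not the actual MLX graph cache: with a 256-token step, a 16k-token generation touches 64 distinct padded lengths, so a cache much smaller than that evicts graphs that a subsequent generation will need again.

```python
from collections import OrderedDict

class GraphCache:
    """Toy LRU cache of compiled graphs, keyed by padded KV length."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.graphs = OrderedDict()
        self.builds = 0  # counts expensive graph constructions

    def get(self, padded_len):
        if padded_len in self.graphs:
            self.graphs.move_to_end(padded_len)   # mark as recently used
            return self.graphs[padded_len]
        self.builds += 1                          # expensive: build a new graph
        if len(self.graphs) >= self.capacity:
            self.graphs.popitem(last=False)       # evict least recently used
        self.graphs[padded_len] = object()
        return self.graphs[padded_len]

STEP = 256  # assumed KV cache growth step

def decode(cache, n_tokens):
    # Each decode step looks up the graph for the current padded capacity.
    for t in range(1, n_tokens + 1):
        padded = ((t + STEP - 1) // STEP) * STEP
        cache.get(padded)

small = GraphCache(capacity=32)
decode(small, 16_384)   # 16k tokens -> 64 distinct padded lengths
large = GraphCache(capacity=256)
decode(large, 16_384)
```

Running two 16k-token generations back to back in this model, the 32-entry cache ends up rebuilding every graph on the second pass, while the 256-entry cache rebuilds none; that headroom is the kind of benefit a larger default would buy.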

@zcbenz zcbenz merged commit 54bb3ee into ml-explore:main Feb 10, 2026
16 checks passed
@zcbenz zcbenz deleted the cuda-sdpa-sliced branch February 10, 2026 00:15
@awni awni mentioned this pull request Feb 10, 2026
