[CUDA] Use cuDNN SDPA for decoding when using fixed-size KV cache #3113
zcbenz merged 1 commit into ml-explore:main from
Conversation
```cpp
auto is_slice = [](const array& kv) {
  // When called during graph building the strides are not available.
  if (kv.status() != array::evaluated) {
    return (kv.has_primitive() && typeid(kv.primitive()) == typeid(Slice)) ||
           (kv.shape(2) % kv_cache_step == 0);
  }
  // Get the pre-sliced sequence length from the strides, and check if the
  // buffer belongs to a contiguous kv cache.
  int64_t T_kv = kv.strides(1) / kv.strides(2);
  if (kv.size() / kv.shape(2) * T_kv != kv.buffer_size() / kv.itemsize()) {
    return false;
  }
```
If we assume that the cuDNN case is a strict subset of what the vector fallback can handle, then this is not really necessary. We could just use this function during low-level dispatch (and actually use the strides, which is a lot more robust).
Strictly speaking cuDNN supports more head dims, but yeah, it should be more robust to just make it a subset of the vector SDPA. I have updated the code.
force-pushed from 0af3abe to 6b0ff04
Very nice improvement. On B200 it's more than a 2x speedup for long context: Pre: 72.998 Tok/s

I wonder if we should defensively increase the default cache size for the forward op? If you do long generations it will start to thrash after about 16k tokens, which is not that many. Maybe we should make it 256 or something?
Yeah, for decoding we definitely need a larger cache size.
force-pushed from 6b0ff04 to 59dcdc1
The kv cache in mlx-lm has a fixed size and the keys/values passed to `fast.sdpa` are slices; using this information we can create cuDNN graphs with a fixed sequence length and use padding masks to set the actual sequence lengths.

This makes decoding (`T_q == 1`) 30%~100% faster for large sequence lengths. For small sequence lengths the cuDNN SDPA is still faster, but the overhead of creating cuDNN graphs would eliminate the advantage, so we fall back to vector SDPA.

This approach however does not support custom array masks. We can add more options to reduce the use of array masks, but cuDNN does not support left padding masks, so integrating with `BatchKVCache` would require extra effort.