CUDA: refactor FA support/selection code #15454
Conversation
Not crashing is definitely correct, but I agree that it is usually not good to enable FA when it requires a fallback to the CPU. I think we can add a new FA mode called "auto" or similar that behaves as follows: FA would only be enabled when it is supported, and trying to use V-cache quantization when FA is not supported would result in an error rather than a fallback to the CPU with terrible performance.
Sounds good. So for the CLI interface, something like …?
Essentially yes, but the logic can be more granular than that. Flash attention can be enabled at a per-layer level, so if you have device A that supports FA and device B that doesn't, you can still use FA in the layers offloaded to device A.
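A rough sketch of what that per-layer decision could look like under an "auto" mode; `fa_mode` and `layer_uses_fa` are hypothetical names for illustration, not the actual llama.cpp/ggml API:

```cpp
// Hypothetical sketch, not the actual llama.cpp/ggml API: fa_mode and
// layer_uses_fa are illustrative names only.
enum class fa_mode { off, on, autodetect };

// Decide whether flash attention should be used for one layer, given whether
// the device that layer is offloaded to supports FA.
static bool layer_uses_fa(fa_mode mode, bool device_supports_fa) {
    switch (mode) {
        case fa_mode::off:
            return false;
        case fa_mode::on:
            // With a quantized V cache, "unsupported" becomes an error
            // instead of a silent fallback to the CPU.
            return device_supports_fa;
        case fa_mode::autodetect:
            // Enable FA only for layers whose device actually supports it.
            return device_supports_fa;
    }
    return false;
}
```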
Can I just double-check: K-cache quantization can still be used independently of FA on/off, and only V-cache quantization requires FA, correct?
Yes.
I should also clarify: when I was talking about a "functional change" I meant it in the ggml backend context. llama.cpp still checks whether FA is enabled when the V cache is quantized and raises an error if it's not.
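For reference, the llama.cpp-side check being described amounts to something like the following sketch; `validate_cache_types` is a hypothetical helper, not the real llama.cpp code:

```cpp
#include "ggml.h"

// Sketch of the check described above: a quantized V cache is only allowed
// when flash attention is enabled.
static bool validate_cache_types(enum ggml_type type_v, bool flash_attn) {
    if (ggml_is_quantized(type_v) && !flash_attn) {
        // llama.cpp raises an error here instead of silently falling back.
        return false;
    }
    return true;
}
```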
I noticed that the rocWMMA selection code seems to have been removed in this PR, and I don't see it referenced anywhere else. Is that intentional?
rocWMMA support is handled via …
Worth a tweet, a documentation note, or something similar on what llama-server should set for -fa once this is complete.
This PR refactors and deduplicates the CUDA code for determining which kernel to run. One of the possible return values of `ggml_cuda_get_best_fattn_kernel` is that there is no suitable kernel, and this is reused for determining whether the CUDA backend can run the ggml op at all. This fixes issues with e.g. Stories 260k, which crashed due to unexpected head sizes; with the new code, all head sizes that are not explicitly listed are treated as unsupported.

One functional change vs. master is that trying to use a quantized KV cache without `GGML_CUDA_FA_ALL_QUANTS` no longer results in a crash but instead falls back to the CPU code. This made the code simpler, but I'm not 100% sure whether it's the right thing to do in terms of usability.
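For readers skimming the diff, the selection pattern described above boils down to something like this sketch. Only the name `ggml_cuda_get_best_fattn_kernel` comes from the PR description; the enum values, the simplified signature, and the listed head sizes are placeholders rather than the real selection logic:

```cpp
#include <cstdint>

// Illustrative sketch only: the enum values, the simplified signature, and the
// head-size list are placeholders, not the real selection logic from this PR.
enum best_fattn_kernel {
    BEST_FATTN_KERNEL_NONE, // no suitable kernel -> the op is unsupported
    BEST_FATTN_KERNEL_VEC,
    BEST_FATTN_KERNEL_TILE,
    BEST_FATTN_KERNEL_MMA,
};

static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(int64_t head_size /*, device, shapes, ... */) {
    switch (head_size) {
        // Head sizes that are not explicitly listed fall through to NONE,
        // which is what fixes the crash on unexpected head sizes.
        case 64: case 80: case 96: case 112: case 128: case 256:
            return BEST_FATTN_KERNEL_MMA; // the real code also weighs arch, batch size, etc.
        default:
            return BEST_FATTN_KERNEL_NONE;
    }
}

// The same return value doubles as the backend's "can we run this op?" answer
// (cuda_backend_supports_fattn is a hypothetical name).
static bool cuda_backend_supports_fattn(int64_t head_size) {
    return ggml_cuda_get_best_fattn_kernel(head_size) != BEST_FATTN_KERNEL_NONE;
}
```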