[Pipelines] Keep Gemma 4 NVFP4 attention activations in the compute dtype #4
Draft
msaelices wants to merge 1 commit into
…type

BEGIN_PUBLIC
[Pipelines] Keep Gemma 4 NVFP4 attention activations in the compute dtype

For NVFP4 Gemma 4 checkpoints the attention layer is built with dtype set to the packed uint8 weight dtype, and that dtype was being reused for the QK-norm RMSNorm scales and the flash-attention output. The QK norms are bf16 scale vectors in the checkpoint (so loading failed with an expected=uint8 / actual=bfloat16 dtype error), and the flash-attention kernel requires Q/K/V/output to share a dtype (the uint8 output tripped its constraint on pre-Blackwell GPUs).

Thread an unquantized_dtype through Gemma4Attention and use it for the QK/V norms and the flash-attention output dtype, while the q/k/v/o projections keep the quantized weight dtype. The parameter defaults to None and falls back to the existing dtype, so non-quantized attention is unchanged.

Validated end to end on an L40S (pre-Blackwell): the NVFP4 Gemma 4 31B model loads, compiles, and generates coherent text.
END_PUBLIC

Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>
Claude-Session: https://claude.ai/code/session_01LAekHUXqk6v5nLkA7R8Vxx
Summary
Stacked on modular#6668. For NVFP4 Gemma 4 checkpoints the attention layer is built with `dtype` set to the packed uint8 weight dtype, and that dtype was reused for the QK-norm RMSNorm scales and the flash-attention output. Two failures result:

- Checkpoint loading fails: `weight 'layers.0.self_attn.q_norm.weight' had different dtype (expected=uint8, actual=bfloat16)`.
- The flash-attention kernel requires `Q == K == V == output` dtype; a uint8 output trips `mha.mojo` `constraint failed` on pre-Blackwell GPUs.

Thread an `unquantized_dtype` through `Gemma4Attention` and use it for the QK/V norms and the flash-attention `output_dtype`; the q/k/v/o projections keep the quantized weight dtype. The param defaults to `None` and falls back to the existing `dtype`, so non-quantized attention is unchanged (no regression risk for the bf16 path).

This is a pre-existing Gemma 4 NVFP4 defect, independent of the compressed-tensors work (modular#6699). It also affects the modelopt checkpoints and the native (Blackwell) NVFP4 path, since the `q_norm` load failure is hardware-independent.

Validation

End to end on an L40S (pre-Blackwell, SM 8.9): `RedHatAI/gemma-4-31B-it-NVFP4` loads, compiles, and generates coherent text ("The capital of France is Paris.") at ~28 tok/s, 9.1 GiB KV cache.

Notes for reviewers

- Targets `nvfp4-pre-blackwell-fallback` ([Kernels][Pipelines] NVFP4 fallback for pre-Blackwell NVIDIA GPUs via fused dequant-GEMV, modular/modular#6668) so the stack is testable on the only hardware where it was exercised; retarget to `main` once modular/modular#6668 lands.
- Changes are confined to `gemma4/layers/attention.py` + `gemma4/gemma4.py` (21 insertions).

🤖 Generated with Claude Code
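For readers who want the dtype fallback in concrete terms, here is a minimal, framework-free sketch of the pattern described in the summary. It is not the code in this PR: `DType`, `AttentionDtypes`, `resolve_attention_dtypes`, and `check_checkpoint_dtype` are hypothetical stand-ins, and only the `dtype` / `unquantized_dtype` parameter names and the error wording are taken from the description.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Optional


class DType(Enum):
    """Hypothetical stand-in for a framework dtype enum."""
    uint8 = "uint8"
    bfloat16 = "bfloat16"


@dataclass(frozen=True)
class AttentionDtypes:
    """Dtypes used by the different parts of one attention layer."""
    projection: DType  # q/k/v/o projection weights (may be packed/quantized)
    norm: DType        # QK/V RMSNorm scale vectors
    output: DType      # flash-attention output


def resolve_attention_dtypes(
    dtype: DType, unquantized_dtype: Optional[DType] = None
) -> AttentionDtypes:
    """Split the layer dtype into a weight dtype and a compute dtype.

    When unquantized_dtype is None, everything falls back to dtype,
    which keeps the non-quantized path unchanged.
    """
    compute = unquantized_dtype if unquantized_dtype is not None else dtype
    return AttentionDtypes(projection=dtype, norm=compute, output=compute)


def check_checkpoint_dtype(name: str, expected: DType, actual: DType) -> None:
    """Hypothetical loader check that mirrors the error quoted in the summary."""
    if expected != actual:
        raise ValueError(
            f"weight '{name}' had different dtype "
            f"(expected={expected.value}, actual={actual.value})"
        )


# Quantized layer: projections stay uint8, norms and output use bf16.
quantized = resolve_attention_dtypes(DType.uint8, DType.bfloat16)
check_checkpoint_dtype(
    "layers.0.self_attn.q_norm.weight", quantized.norm, DType.bfloat16
)

# Non-quantized layer: no unquantized_dtype passed, so nothing changes.
plain = resolve_attention_dtypes(DType.bfloat16)
```

Keeping the new parameter optional means existing callers do not need to change; only the quantized construction path has to pass a compute dtype.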