Thanks to visit codestin.com
Credit goes to github.com

Skip to content

[Pipelines] Keep Gemma 4 NVFP4 attention activations in the compute dtype#4

Draft
msaelices wants to merge 1 commit into
nvfp4-pre-blackwell-fallbackfrom
gemma4-nvfp4-attn-activation-dtype
Draft

[Pipelines] Keep Gemma 4 NVFP4 attention activations in the compute dtype#4
msaelices wants to merge 1 commit into
nvfp4-pre-blackwell-fallbackfrom
gemma4-nvfp4-attn-activation-dtype

Conversation

@msaelices

Copy link
Copy Markdown
Owner

Summary

Stacked on modular#6668. For NVFP4 Gemma 4 checkpoints the attention layer is built with dtype set to the packed uint8 weight dtype, and that dtype was reused for the QK-norm RMSNorm scales and the flash-attention output. Two failures result:

  • Load: the QK norms are bf16 scale vectors in the checkpoint → weight 'layers.0.self_attn.q_norm.weight' had different dtype (expected=uint8, actual=bfloat16).
  • Compile: the flash-attention kernel requires Q == K == V == output dtype; a uint8 output trips mha.mojo constraint failed on pre-Blackwell GPUs.

Thread an unquantized_dtype through Gemma4Attention and use it for the QK/V norms and the flash-attention output_dtype; the q/k/v/o projections keep the quantized weight dtype. The param defaults to None and falls back to the existing dtype, so non-quantized attention is unchanged (no regression risk for the bf16 path).

This is a pre-existing Gemma 4 NVFP4 defect, independent of the compressed-tensors work (modular#6699) — it also affects the modelopt checkpoints and the native (Blackwell) NVFP4 path, since the q_norm load failure is hardware-independent.

Validation

End to end on an L40S (pre-Blackwell, SM 8.9): RedHatAI/gemma-4-31B-it-NVFP4 loads, compiles, and generates coherent text — "The capital of France is Paris." at ~28 tok/s, 9.1 GiB KV cache.

Notes for reviewers

🤖 Generated with Claude Code

…type

BEGIN_PUBLIC
[Pipelines] Keep Gemma 4 NVFP4 attention activations in the compute dtype

For NVFP4 Gemma 4 checkpoints the attention layer is built with dtype set to the
packed uint8 weight dtype, and that dtype was being reused for the QK-norm
RMSNorm scales and the flash-attention output. The QK norms are bf16 scale
vectors in the checkpoint (so loading failed with an expected=uint8 /
actual=bfloat16 dtype error), and the flash-attention kernel requires
Q/K/V/output to share a dtype (the uint8 output tripped its constraint on
pre-Blackwell GPUs).

Thread an unquantized_dtype through Gemma4Attention and use it for the QK/V
norms and the flash-attention output dtype, while the q/k/v/o projections keep
the quantized weight dtype. The parameter defaults to None and falls back to the
existing dtype, so non-quantized attention is unchanged.

Validated end to end on an L40S (pre-Blackwell): the NVFP4 Gemma 4 31B model
loads, compiles, and generates coherent text.
END_PUBLIC

Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>
Claude-Session: https://claude.ai/code/session_01LAekHUXqk6v5nLkA7R8Vxx
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant