[Kernels][Pipelines] NVFP4 fallback for pre-Blackwell NVIDIA GPUs via fused dequant-GEMV#6668
[Kernels][Pipelines] NVFP4 fallback for pre-Blackwell NVIDIA GPUs via fused dequant-GEMV#6668msaelices wants to merge 11 commits into
Conversation
|
Perf follow-up: profiling the fallback on an A10G found the software LUT was the bottleneck — |
|
Implemented the Marlin-style fused dequant-GEMV (4831155): one warp per output column decodes packed E2M1 in registers and dot-products against BF16 activations — the dequantized weight is never materialized, so per-token DRAM traffic drops ~9x and the per-forward transient buffers disappear. A10G, gemma-4 12B NVFP4, bs=1: per-linear 2.7 ms -> 0.133 ms through the graph executor (20x); end-to-end decode 0.38 -> 4.98 tok/s cumulative with the E2M1 LUT fix (vLLM nightly's Marlin on the same checkpoint: ~16 tok/s eager — remaining gap is the QKV path, M_TILE prefill scaling, and per-op dispatch). Correctness: scalar CPU reference in |
|
QKV now also routes through the fused GEMV (split + Final numbers on the A10G (gemma-4 12B NVFP4, bs=1, warm decode): 31.5 tok/s, vs 0.38 at the start of this work (83x) and ~16 tok/s for vLLM nightly's Marlin on the same checkpoint/GPU. 58-token prefill + 150-token generation sustains 29.6 tok/s. Output stays token-identical to the dequant path. Earlier per-request numbers in this thread (~5 tok/s) were first-request measurements that included device graph capture warmup. |
There was a problem hiding this comment.
Pull request overview
Note
Copilot was unable to run its full agentic suite in this review.
Adds a pre-Blackwell (pre-SM100) NVFP4 execution path by introducing fused NVFP4 dequant-GEMV and dequant kernels, then wiring them into dense and MoE inference to enable serving NVFP4 checkpoints on older NVIDIA GPUs.
Changes:
- Introduces NVFP4 fused dequant-GEMV (
mo.gemv.nvfp4) and NVFP4 dequant (mo.dequant.nvfp4) kernels + graph-compiler registrations. - Routes NVFP4 matmul / fused QKV paths to the new fallback on pre-SM100 NVIDIA GPUs, including explicit KV-cache stores.
- Adds an MoE NVFP4 “dequant-to-BF16 then BF16 grouped matmul” fallback strategy and GPU smoke tests for the new kernels.
Reviewed changes
Copilot reviewed 12 out of 12 changed files in this pull request and generated 10 comments.
Show a summary per file
| File | Description |
|---|---|
| max/python/max/nn/quant_ops.py | Adds pre-SM100 NVIDIA NVFP4 fallback paths for float4 matmul and fused QKV, including KV-cache writes. |
| max/python/max/nn/moe/quant_strategy.py | Adds Nvfp4DequantStrategy that dequantizes NVFP4 expert weights to BF16 and uses ragged grouped matmul. |
| max/python/max/nn/moe/moe_fp8.py | Selects the NVFP4 dequant fallback strategy on pre-SM100 NVIDIA; blocks EP NVFP4 on those GPUs. |
| max/python/max/nn/moe/init.py | Exports the new NVFP4 dequant strategy. |
| max/python/max/nn/kernels.py | Adds Python wrappers for nvfp4_gemv, nvfp4_dequant, and a pre-SM100 NVIDIA detection helper. |
| max/kernels/src/linalg/nvfp4_gemv.mojo | Implements the fused NVFP4 dequant-GEMV kernel. |
| max/kernels/src/linalg/mxfp4_dequant.mojo | Generalizes FP4 dequant into a shared kernel and adds NVFP4 dequant entrypoint. |
| max/kernels/src/linalg/fp4_utils.mojo | Reworks FP4 E2M1 decode to avoid SIMD spills and dramatically speed up dequant. |
| max/kernels/src/graph_compiler/builtin_kernels/quantization.mojo | Registers mo.dequant.nvfp4 and mo.gemv.nvfp4 builtins for graph execution. |
| max/kernels/test/gpu/linalg/test_nvfp4_gemv_smoke.mojo | Adds correctness + quick throughput smoke test for NVFP4 dequant-GEMV. |
| max/kernels/test/gpu/linalg/test_nvfp4_dequant_smoke.mojo | Adds smoke test for NVFP4 dequant with multiple dtypes/shapes. |
| max/kernels/test/gpu/linalg/BUILD.bazel | Wires the new GPU tests into Bazel with platform constraints. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| def _is_pre_sm100_nvidia_gpu() -> bool: | ||
| """Checks for an NVIDIA GPU without native FP4 matmul (pre-Blackwell). | ||
|
|
||
| Deliberately False on AMD and Apple GPUs: the NVFP4 dequant fallback is | ||
| only validated on NVIDIA, so other vendors keep their previous behavior | ||
| (erroring at the Blackwell-only kernels). | ||
| """ | ||
| try: | ||
| arch = accelerator_architecture_name() | ||
| except Exception: | ||
| return False | ||
| return arch.startswith("sm_") and not arch.startswith("sm_10") |
| def nvfp4_dequant( | ||
| packed_weights: TensorValue, | ||
| scales: TensorValue, | ||
| out_type: DType = DType.bfloat16, | ||
| ) -> TensorValue: | ||
| """Dequantizes NVFP4 packed weights to BF16 or FP8 on GPU. |
| def _dequant_weight_nvfp4( | ||
| weight: TensorValue, | ||
| weight_scale: TensorValue, | ||
| weight_scale_2: TensorValue, | ||
| scales_pre_interleaved: bool, | ||
| ) -> TensorValue: |
| if scales_pre_interleaved: | ||
| raise ValueError( | ||
| "NVFP4 checkpoints with pre-interleaved (TCGEN 5D) scales are" | ||
| " not supported on pre-Blackwell GPUs: the dequant fallback" | ||
| " needs the flat [N, K//16] scale layout." | ||
| ) | ||
| scales_f32 = weight_scale.to(weight.device).cast( | ||
| DType.float32 | ||
| ) * weight_scale_2.to(weight.device) | ||
| return nvfp4_dequant(weight, scales_f32, out_type=DType.bfloat16) |
| Returns: | ||
| The output tensor in bf16. | ||
| """ | ||
| if _is_pre_sm100_nvidia_gpu(): |
| if scales_pre_interleaved: | ||
| raise ValueError( | ||
| "NVFP4 checkpoints with pre-interleaved (TCGEN 5D) scales" | ||
| " are not supported on pre-Blackwell GPUs" | ||
| ) | ||
| scales_f32 = weight_scale.to(weight.device).cast( | ||
| DType.float32 | ||
| ) * weight_scale_2.to(weight.device) |
| # Pre-multiply the per-tensor scale into the block scales: | ||
| # [E, 1, 1] * [E, N, K//16] -> float32 [E, N, K//16]. | ||
| scales_f32 = weight_scales.cast(DType.float32) * ops.reshape( | ||
| expert_scales.to(weight_scales.device).cast(DType.float32), | ||
| [expert_scales.shape[0], 1, 1], | ||
| ) | ||
| # Activations are BF16 on this path by construction; self.dtype is | ||
| # the packed storage dtype (uint8) and must not be used here. | ||
| dequanted = nvfp4_dequant(weight, scales_f32, out_type=DType.bfloat16) | ||
|
|
||
| return grouped_matmul_ragged( | ||
| hidden, | ||
| dequanted, | ||
| expert_start, | ||
| expert_ids, | ||
| usage_stats, | ||
| ) |
| if usage_stats.device.is_gpu(): | ||
| usage_stats = usage_stats.to(DeviceRef.CPU()) |
| # Correctness + quick throughput check for the fused NVFP4 dequant-GEMV. | ||
| # ===----------------------------------------------------------------------=== # | ||
|
|
||
| from std.math import ceildiv |
| from std.time import perf_counter_ns | ||
|
|
||
| from std.gpu.host import DeviceContext | ||
| from layout import Coord, TileTensor, row_major |
…GPUs Serve NVFP4 checkpoints on GPUs without native FP4 matmul support (sm_8x/sm_90) by dequantizing weights to BF16 with a software-LUT kernel and routing through the regular BF16 matmuls, mirroring the existing MXFP4-on-H100 dequant approach. Fixes modular#6667. Kernels: - Generalize the MXFP4 dequant kernel into a format-agnostic _dequant_fp4 (block-scale dtype and SF_VECTOR_SIZE were already parameters) and add a dequant_nvfp4 entry point: E4M3 or pre-multiplied float32 scales, SF_VECTOR_SIZE=16. The E2M1 decode is a software LUT, so it runs on any GPU. Register as mo.dequant.nvfp4. - Smoke test (CPU reference vs GPU) covering both scale dtypes, several shapes and an FP8 output; unlisted in _EXTRA_CONSTRAINTS so it runs on any GPU. Python (active only when not on SM100+): - kernels.nvfp4_dequant() wrapper (rank 2/3, flat [.., K//16] scales). - quant_ops: _matmul_float4 and the NVFP4 fused-QKV case dequantize the weight (weight_scale_2 pre-multiplied into float32 scales) and use the unquantized matmul / fused_qkv_ragged_matmul. Pre-interleaved (TCGEN 5D) scale checkpoints are rejected with a clear error. - MoE: new Nvfp4DequantStrategy — activations stay BF16 (no dynamic quantization), expert weights dequantized per matmul, BF16 grouped_matmul_ragged. Selected in MoEQuantized._strategy(); expert epilogue scales carry only weight_scale_2 (no activation input-scale factor, since activations are never quantized on this path). Blackwell behavior is unchanged: every fallback is gated on not _is_sm10x_gpu(). END_PUBLIC Assisted-by: AI Signed-off-by: Manuel Saelices <[email protected]>
Two fixes found while validating end to end on an A10G (sm_86):
- _dequant_weight_nvfp4: move weight_scale_2 (a CPU rank-0 weight) to the
weight's device before multiplying into the block scales ("Input values
must be on the same device").
- Nvfp4DequantStrategy.grouped_matmul: dequantize to bfloat16 explicitly;
self.dtype is the packed storage dtype (uint8) on this path and tripped
the kernel's output-dtype constraint at graph compile time.
Validated: nvidia/Llama-3.1-8B-Instruct-NVFP4 (dense, modelopt) serves on
the A10G and generates coherent text. nvidia/Gemma-4-26B-A4B-NVFP4 (MoE)
compiles and loads weights through the fallback, but 18.8 GiB of packed
weights exhaust the 24 GiB card during setup (allocator reports 18.73 GiB
in use, largest free block 359 MB) -- a checkpoint-size limit, not a
fallback defect.
Assisted-by: AI
Signed-off-by: Manuel Saelices <[email protected]>
Assisted-by: AI Signed-off-by: Manuel Saelices <[email protected]>
- moe/__init__.py: restore isort order in the import block and __all__ (ruff I001/RUF022 with the CI-pinned ruff). - moe_fp8.py: raise an explicit NotImplementedError when selecting the NVFP4 dequant strategy with an active expert-parallel batch manager -- EP dispatch quantizes activations to FP4, so the fallback would otherwise die later in grouped-matmul shape validation with a confusing error. - Smoke test: vary block scales per index so wrong scale indexing cannot pass unnoticed (previously uniform scales masked that class of bug); re-run on the A10G, all cases still PASS with max_err=0.0. - BUILD.bazel: constrain the smoke test to nvidia_gpu|amd_gpu (mirrors the mxfp4 smoke test) instead of enrolling untested Apple CI. Assisted-by: AI Signed-off-by: Manuel Saelices <[email protected]>
BEGIN_PUBLIC [Kernels][GPU] Branchless E2M1 decode in cast_uint_to_fp4e2m1 The software LUT indexed a SIMD[float32, 16] table with a runtime index, which spills the vector to GPU local memory on every element. That made the FP4 dequant kernels local-memory bound: 12 GB/s effective on an A10G (sm=100%, dram=99% from spill traffic). Replacing the lookup with branchless bit arithmetic (exp/mantissa decode + select) makes the kernel memory-bound as intended: 422 GB/s on the same device, a 35x kernel speedup, 2.5x end-to-end on the runtime-dequant serving path. Verified with test_nvfp4_dequant_smoke (bit-exact vs CPU reference). END_PUBLIC Assisted-by: AI Signed-off-by: Manuel Saelices <[email protected]>
BEGIN_PUBLIC
[Kernels][GPU] Fused NVFP4 dequant-GEMV for the pre-Blackwell fallback
Marlin-style fused kernel: one warp per output column decodes packed
E2M1 in registers (branchless, no LUT) and dot-products against the
BF16 activations — the dequantized weight is never materialized in
global memory. Per-token DRAM traffic drops from ~4.5 B/element
(dequantize-then-matmul: bf16 write + re-read + packed read) to the
packed ~0.5 B/element, and no per-forward transient buffers exist.
Activation rows are tiled (M_TILE=4) so a P-token prefill re-reads the
packed weight ceil(P/4) times; decode (M=1) is the optimized case.
_matmul_float4's pre-Blackwell path (o_proj and dense MLP projections)
now routes through the fused op. The QKV path keeps dequantize-then-
matmul for now (fused_qkv_ragged_matmul owns the KV-cache write).
Measured on an NVIDIA A10G (gemma-4 12B NVFP4, bs=1):
- kernel: 0.126 ms for [15360x3840] at M=1 (350 GB/s of packed reads)
- per-linear through the graph executor: 2.7 ms -> 0.133 ms (20x)
- end-to-end decode: 0.94 -> 4.98 tok/s (5.3x; 13x vs the original
fallback). Output is token-identical to the dequant path.
Correctness verified against a scalar CPU reference for M={1,3,5},
including non-multiple-of-warp N (test_nvfp4_gemv_smoke).
END_PUBLIC
Assisted-by: AI
Signed-off-by: Manuel Saelices <[email protected]>
BEGIN_PUBLIC [Kernels] Use CamelCase names for the new NVFP4 op structs Rename Struct_dequant_nvfp4 -> NVFP4Dequant and Struct_gemv_nvfp4 -> NVFP4Gemv. The Struct_ prefix is the legacy machine-generated style; the hand-written registrations in quantization.mojo use CamelCase. Pre-existing Struct_* registrations are left untouched. END_PUBLIC Assisted-by: AI Signed-off-by: Manuel Saelices <[email protected]>
BEGIN_PUBLIC [Kernels] Route the pre-Blackwell QKV projection through the fused GEMV The QKV path was the last consumer of dequantize-then-matmul on the pre-Blackwell NVFP4 fallback. Compute QKV with the fused dequant-GEMV, then split and store K/V into the paged cache via store_k_cache_ragged and store_v_cache_ragged (fused_qkv_ragged_matmul owns that write on the unquantized path, which is why this stays decomposed). K-equals-V layers (Gemma 4 full attention packs only Q and K rows in wqkv) store the K projection into both caches. With this no BF16 copy of any quantized weight is ever materialized. End-to-end on an NVIDIA A10G (gemma-4 12B NVFP4, bs=1), warm decode: 31.5 tok/s vs 0.38 at the start of the fallback work (83x), with token-identical output. For reference, vLLM nightly's Marlin path serves the same checkpoint at ~16 tok/s eager on the same GPU. END_PUBLIC Assisted-by: AI Signed-off-by: Manuel Saelices <[email protected]>
… fallback BEGIN_PUBLIC [Kernels][Pipelines] Address adversarial-review findings on the fused fallback - Gate the NVFP4 fallback on pre-SM100 *NVIDIA* explicitly (_is_pre_sm100_nvidia_gpu): _is_sm10x_gpu() is False on AMD/Apple, which silently routed those vendors onto an unvalidated dequant path; they now keep their previous behavior. - moe_fp8: fix a comment describing the Blackwell scale-epilogue mechanism on the fallback branch (the per-tensor scale is folded into the dequant block scales; the BF16 grouped matmul has no epilogue). - Soften the dequant docstrings: raw float8_e4m3fn block scales are accepted but only exercised by tests today (the modelopt parser always provides a per-tensor scale). END_PUBLIC Assisted-by: AI Signed-off-by: Manuel Saelices <[email protected]>
BEGIN_PUBLIC [Kernels] Address cycle-2 adversarial-review findings on the fused GEMV - nvfp4_gemv: all threads now enter the PDL region (the early return for tail-block threads moved inside it) and the dead pdl_level parameter is gone (it was accepted but never forwarded). - quant_ops: drop _dequant_weight_nvfp4 (no callers since the fused reroute), refresh _matmul_float4's docstring, and validate the wqkv layout explicitly instead of silently aliasing K as V for anything that is not exactly Q+K+V or Q+K rows. - Stale software-LUT wording updated after the branchless decode. - mojo format on the two new kernel files; smoke test now covers >1 K-chunk per lane (K=2048) and the production K=3840. END_PUBLIC Assisted-by: AI Signed-off-by: Manuel Saelices <[email protected]>
BEGIN_PUBLIC [Kernels] Address cycle-3 adversarial-review findings - Replace the NVFP4 branch of quantized_fused_qkv_matmul with a clear error: no in-tree model reaches it (callers assert on input_scale first, and NVFP4 QKV routes through StackedLinear/_matmul_float4, which is already fused). _dequant_weight_nvfp4 went unused with it. - Add the Apache license header to test_nvfp4_gemv_smoke.mojo. END_PUBLIC Assisted-by: AI Signed-off-by: Manuel Saelices <[email protected]>
1b7eb1d to
c781b87
Compare
…4 12B) and serve them text-only (#88920) [External] [Pipelines] Load public gemma4_unified checkpoints (Gemma 4 12B) and serve them text-only ## Summary Makes the public `gemma4_unified` checkpoints (the Gemma 4 12B line: `google/gemma-4-12b-it` and derived quants) loadable by MAX, serving them **text-only**. Fixes the first two blockers of #6669 and implements the text-only path proposed there. - Register an `AutoConfig` shim for `model_type: "gemma4_unified"` (transformers < 5.5 path) and `gemma4_unified_arch`, a `dataclasses.replace` alias of `gemma4_arch` under the public `Gemma4UnifiedForConditionalGeneration` name — these checkpoints carry the regular Gemma 4 language layout with **no bundled MTP draft weights**, so they belong to `gemma4`, not `unified_mtp_gemma4`. - Text-only serving: the unified checkpoints use a lightweight `vision_embedder` with a different schema than `Gemma4VisionModel`; detect it, set `vision_config=None`, skip building/loading the vision graph, and return 0 vision-cache bytes. - `finalize` now passes the real `model.language_model.` prefixes to `parse_quant_config`, so per-layer quantized/ignored classification stops looking up `model.layers.*` keys that never exist (ignore-listed BF16 attention in modelopt 12B quants was never recognized as ignored). - Honor `quant_config.attn_quantized_layers` per layer instead of assuming attention is always BF16 under NVFP4 (true for nvidia first-party checkpoints, false for the community 12B quants), and drop modelopt `k_scale`/`v_scale` KV-cache scales in the weight adapter. Self-review (adversarial pass) caught and fixed an inference-path crash before any human review: `prepare_initial_token_inputs` read `vision_config.pooling_kernel_size` unconditionally, so text-only serving would have failed on the first prefill; image requests now get a clear rejection and the unified alias declares TEXT-only modalities. ## Verification (NVIDIA A10G 24 GB; bazel-built serve) Each fix was driven by serving real checkpoints; the chain `config → arch resolution → graph build → weight load` now completes for `google/gemma-4-12b-it`-family checkpoints where each stage failed before (`gemma4_unified` not recognized → no matching architecture → vision `AttributeError` → o_proj shape mismatch → unexpected `k_scale`/`v_scale`). **Full text-generation E2E verified on the A10G** (`berkerdooo/gemma-4-12B-it-NVFP4`, attention quantized): correct, coherent EN/ES completions via `/v1/chat/completions`. With #6668's fused dequant-GEMV (no BF16 weight materialization), the 12B sustains **31.5 tok/s warm decode at batch 1 on the 24 GB card** (measured on a branch combining this PR with #6668) — ~2x vLLM nightly's Marlin path on the same checkpoint and GPU — with token-identical output across the fallback's kernel revisions. Note: on a 24 GB card the stock activation-memory estimate (~15 GiB headroom, sized for the vision configs) still blocks startup — the measurement lowered it locally; right-sizing that estimate for text-only unified configs is intentionally left out of this PR pending maintainer guidance. Related: #6669 (blocker 3 — implementing the unified `vision_embedder` — remains open), #6668, #6665. Assisted-by: AI ORIGINAL_AUTHOR=Manuel Saelices <[email protected]> ORIGINAL_USER=@msaelices --------- Co-authored-by: Manuel Saelices <[email protected]> Co-authored-by: Kathy Wu <[email protected]> Closes #6670 MODULAR_ORIG_COMMIT_REV_ID: baef6c1c5c38e4923a55ece1caf8d6d2e42f9a98
Summary
Serve NVFP4 checkpoints on GPUs without native FP4 matmul (pre-Blackwell NVIDIA). Fixes #6667.
The fallback is a Marlin-style fused dequant-GEMV: packed E2M1 is decoded in registers inside the matmul kernel, so no BF16 weight copy is ever materialized and per-token DRAM traffic is the packed bytes (~0.5 B/element vs ~4.5 for dequantize-then-matmul).
nvfp4_gemv/mo.gemv.nvfp4: fused kernel behind_matmul_float4, which covers every NVFP4 linear in-tree (dense/MLP and the StackedLinear QKV projections). The unreachable NVFP4 branch ofquantized_fused_qkv_matmulnow raises a clear error instead of carrying dead fallback code.dequant_nvfp4/mo.dequant.nvfp4: NVFP4 generalization of the MXFP4 dequant kernel, kept for the MoENvfp4DequantStrategy.cast_uint_to_fp4e2m1rewritten branchless: the old LUT's runtime SIMD indexing spilled to local memory on every element (35x slowdown).Performance (NVIDIA A10G, sm_86, 24 GB; gemma-4 12B NVFP4, batch 1)
Measured with device graph capture on (serving config, not part of this PR); first request after startup pays ~17 s of capture warmup. A 58-token prefill plus 150-token generation sustains 29.6 tok/s. 31.7 ms/token is approximately the bandwidth bound for reading the packed weights, i.e. the fused kernel leaves little on the table at batch 1. Model output is token-identical to the dequantize-then-matmul path at temperature 0.
Verification
CPU-reference smoke tests for both kernels (
test_nvfp4_gemv_smoke,test_nvfp4_dequant_smoke);nvidia/Llama-3.1-8B-Instruct-NVFP4andberkerdooo/gemma-4-12B-it-NVFP4(with #6670) serve E2E with coherent output.Out of scope (tracked in #6667): compressed-tensors NVFP4 parsing, Gemma 4 activation-memory headroom, tensor-core GEMM tile for large-batch prefill.
Assisted-by: AI