[Kernels][GPU] NVFP4 GEMM on the multistage qGEMM skeleton (~1.8-2.9x serving vs bespoke)#6708
[Kernels][GPU] NVFP4 GEMM on the multistage qGEMM skeleton (~1.8-2.9x serving vs bespoke)#6708msaelices wants to merge 56 commits into
Conversation
…GPUs Serve NVFP4 checkpoints on GPUs without native FP4 matmul support (sm_8x/sm_90) by dequantizing weights to BF16 with a software-LUT kernel and routing through the regular BF16 matmuls, mirroring the existing MXFP4-on-H100 dequant approach. Fixes modular#6667. Kernels: - Generalize the MXFP4 dequant kernel into a format-agnostic _dequant_fp4 (block-scale dtype and SF_VECTOR_SIZE were already parameters) and add a dequant_nvfp4 entry point: E4M3 or pre-multiplied float32 scales, SF_VECTOR_SIZE=16. The E2M1 decode is a software LUT, so it runs on any GPU. Register as mo.dequant.nvfp4. - Smoke test (CPU reference vs GPU) covering both scale dtypes, several shapes and an FP8 output; unlisted in _EXTRA_CONSTRAINTS so it runs on any GPU. Python (active only when not on SM100+): - kernels.nvfp4_dequant() wrapper (rank 2/3, flat [.., K//16] scales). - quant_ops: _matmul_float4 and the NVFP4 fused-QKV case dequantize the weight (weight_scale_2 pre-multiplied into float32 scales) and use the unquantized matmul / fused_qkv_ragged_matmul. Pre-interleaved (TCGEN 5D) scale checkpoints are rejected with a clear error. - MoE: new Nvfp4DequantStrategy — activations stay BF16 (no dynamic quantization), expert weights dequantized per matmul, BF16 grouped_matmul_ragged. Selected in MoEQuantized._strategy(); expert epilogue scales carry only weight_scale_2 (no activation input-scale factor, since activations are never quantized on this path). Blackwell behavior is unchanged: every fallback is gated on not _is_sm10x_gpu(). END_PUBLIC Assisted-by: AI Signed-off-by: Manuel Saelices <[email protected]>
Two fixes found while validating end to end on an A10G (sm_86):
- _dequant_weight_nvfp4: move weight_scale_2 (a CPU rank-0 weight) to the
weight's device before multiplying into the block scales ("Input values
must be on the same device").
- Nvfp4DequantStrategy.grouped_matmul: dequantize to bfloat16 explicitly;
self.dtype is the packed storage dtype (uint8) on this path and tripped
the kernel's output-dtype constraint at graph compile time.
Validated: nvidia/Llama-3.1-8B-Instruct-NVFP4 (dense, modelopt) serves on
the A10G and generates coherent text. nvidia/Gemma-4-26B-A4B-NVFP4 (MoE)
compiles and loads weights through the fallback, but 18.8 GiB of packed
weights exhaust the 24 GiB card during setup (allocator reports 18.73 GiB
in use, largest free block 359 MB) -- a checkpoint-size limit, not a
fallback defect.
Assisted-by: AI
Signed-off-by: Manuel Saelices <[email protected]>
Assisted-by: AI Signed-off-by: Manuel Saelices <[email protected]>
- moe/__init__.py: restore isort order in the import block and __all__ (ruff I001/RUF022 with the CI-pinned ruff). - moe_fp8.py: raise an explicit NotImplementedError when selecting the NVFP4 dequant strategy with an active expert-parallel batch manager -- EP dispatch quantizes activations to FP4, so the fallback would otherwise die later in grouped-matmul shape validation with a confusing error. - Smoke test: vary block scales per index so wrong scale indexing cannot pass unnoticed (previously uniform scales masked that class of bug); re-run on the A10G, all cases still PASS with max_err=0.0. - BUILD.bazel: constrain the smoke test to nvidia_gpu|amd_gpu (mirrors the mxfp4 smoke test) instead of enrolling untested Apple CI. Assisted-by: AI Signed-off-by: Manuel Saelices <[email protected]>
BEGIN_PUBLIC [Kernels][GPU] Branchless E2M1 decode in cast_uint_to_fp4e2m1 The software LUT indexed a SIMD[float32, 16] table with a runtime index, which spills the vector to GPU local memory on every element. That made the FP4 dequant kernels local-memory bound: 12 GB/s effective on an A10G (sm=100%, dram=99% from spill traffic). Replacing the lookup with branchless bit arithmetic (exp/mantissa decode + select) makes the kernel memory-bound as intended: 422 GB/s on the same device, a 35x kernel speedup, 2.5x end-to-end on the runtime-dequant serving path. Verified with test_nvfp4_dequant_smoke (bit-exact vs CPU reference). END_PUBLIC Assisted-by: AI Signed-off-by: Manuel Saelices <[email protected]>
BEGIN_PUBLIC
[Kernels][GPU] Fused NVFP4 dequant-GEMV for the pre-Blackwell fallback
Marlin-style fused kernel: one warp per output column decodes packed
E2M1 in registers (branchless, no LUT) and dot-products against the
BF16 activations — the dequantized weight is never materialized in
global memory. Per-token DRAM traffic drops from ~4.5 B/element
(dequantize-then-matmul: bf16 write + re-read + packed read) to the
packed ~0.5 B/element, and no per-forward transient buffers exist.
Activation rows are tiled (M_TILE=4) so a P-token prefill re-reads the
packed weight ceil(P/4) times; decode (M=1) is the optimized case.
_matmul_float4's pre-Blackwell path (o_proj and dense MLP projections)
now routes through the fused op. The QKV path keeps dequantize-then-
matmul for now (fused_qkv_ragged_matmul owns the KV-cache write).
Measured on an NVIDIA A10G (gemma-4 12B NVFP4, bs=1):
- kernel: 0.126 ms for [15360x3840] at M=1 (350 GB/s of packed reads)
- per-linear through the graph executor: 2.7 ms -> 0.133 ms (20x)
- end-to-end decode: 0.94 -> 4.98 tok/s (5.3x; 13x vs the original
fallback). Output is token-identical to the dequant path.
Correctness verified against a scalar CPU reference for M={1,3,5},
including non-multiple-of-warp N (test_nvfp4_gemv_smoke).
END_PUBLIC
Assisted-by: AI
Signed-off-by: Manuel Saelices <[email protected]>
BEGIN_PUBLIC [Kernels] Use CamelCase names for the new NVFP4 op structs Rename Struct_dequant_nvfp4 -> NVFP4Dequant and Struct_gemv_nvfp4 -> NVFP4Gemv. The Struct_ prefix is the legacy machine-generated style; the hand-written registrations in quantization.mojo use CamelCase. Pre-existing Struct_* registrations are left untouched. END_PUBLIC Assisted-by: AI Signed-off-by: Manuel Saelices <[email protected]>
BEGIN_PUBLIC [Kernels] Route the pre-Blackwell QKV projection through the fused GEMV The QKV path was the last consumer of dequantize-then-matmul on the pre-Blackwell NVFP4 fallback. Compute QKV with the fused dequant-GEMV, then split and store K/V into the paged cache via store_k_cache_ragged and store_v_cache_ragged (fused_qkv_ragged_matmul owns that write on the unquantized path, which is why this stays decomposed). K-equals-V layers (Gemma 4 full attention packs only Q and K rows in wqkv) store the K projection into both caches. With this no BF16 copy of any quantized weight is ever materialized. End-to-end on an NVIDIA A10G (gemma-4 12B NVFP4, bs=1), warm decode: 31.5 tok/s vs 0.38 at the start of the fallback work (83x), with token-identical output. For reference, vLLM nightly's Marlin path serves the same checkpoint at ~16 tok/s eager on the same GPU. END_PUBLIC Assisted-by: AI Signed-off-by: Manuel Saelices <[email protected]>
… fallback BEGIN_PUBLIC [Kernels][Pipelines] Address adversarial-review findings on the fused fallback - Gate the NVFP4 fallback on pre-SM100 *NVIDIA* explicitly (_is_pre_sm100_nvidia_gpu): _is_sm10x_gpu() is False on AMD/Apple, which silently routed those vendors onto an unvalidated dequant path; they now keep their previous behavior. - moe_fp8: fix a comment describing the Blackwell scale-epilogue mechanism on the fallback branch (the per-tensor scale is folded into the dequant block scales; the BF16 grouped matmul has no epilogue). - Soften the dequant docstrings: raw float8_e4m3fn block scales are accepted but only exercised by tests today (the modelopt parser always provides a per-tensor scale). END_PUBLIC Assisted-by: AI Signed-off-by: Manuel Saelices <[email protected]>
BEGIN_PUBLIC [Kernels] Address cycle-2 adversarial-review findings on the fused GEMV - nvfp4_gemv: all threads now enter the PDL region (the early return for tail-block threads moved inside it) and the dead pdl_level parameter is gone (it was accepted but never forwarded). - quant_ops: drop _dequant_weight_nvfp4 (no callers since the fused reroute), refresh _matmul_float4's docstring, and validate the wqkv layout explicitly instead of silently aliasing K as V for anything that is not exactly Q+K+V or Q+K rows. - Stale software-LUT wording updated after the branchless decode. - mojo format on the two new kernel files; smoke test now covers >1 K-chunk per lane (K=2048) and the production K=3840. END_PUBLIC Assisted-by: AI Signed-off-by: Manuel Saelices <[email protected]>
BEGIN_PUBLIC [Kernels] Address cycle-3 adversarial-review findings - Replace the NVFP4 branch of quantized_fused_qkv_matmul with a clear error: no in-tree model reaches it (callers assert on input_scale first, and NVFP4 QKV routes through StackedLinear/_matmul_float4, which is already fused). _dequant_weight_nvfp4 went unused with it. - Add the Apache license header to test_nvfp4_gemv_smoke.mojo. END_PUBLIC Assisted-by: AI Signed-off-by: Manuel Saelices <[email protected]>
…dims BEGIN_PUBLIC [Pipelines] Scale Gemma 4 activation-memory reservation by batch and dims `Gemma4MemoryPlanner.estimate_activation_memory` reserved a flat per-device activation budget (15 GB with a bf16 KV cache, 30 GB with fp8), independent of batch size or model dimensions. On a low-concurrency single-GPU serve this dwarfs the KV cache and, with NVFP4 weights now runnable pre-Blackwell, keeps a model that physically fits on a 24 GB card from loading. Scale the reservation with the widest per-token buffer (hidden vs MLP intermediate), the tokens processed per forward step, and a safety multiple, mirroring the principled per-step estimate in `Qwen3_5MemoryPlanner`. The result is capped at the previous flat value, so this can only lower the reservation, never raise the OOM risk relative to before; it falls back to the flat value when model dimensions are unavailable. The safety multiple and per-step token count still need calibration against measured peak GPU memory (tracked in MODELS-1544) before the cap is relaxed. Relates to modular#6667. END_PUBLIC Assisted-by: AI Signed-off-by: Manuel Saelices <[email protected]>
BEGIN_PUBLIC [Pipelines] Scaffold compressed-tensors NVFP4 checkpoint parsing Add `_parse_compressed_tensors_float4_config`, dispatched from `_parse_float4_config` when `quant_method == "compressed-tensors"` and the weight group is 4-bit float. It maps the compressed-tensors NVFP4 schema to the same `QuantFormat.NVFP4` `QuantConfig` the modelopt path produces, so checkpoints like `RedHatAI/gemma-4-*-NVFP4` route through the existing NVFP4 kernel (and the pre-Blackwell fallback). WIP: the per-tensor global scale (weight_scale_2), the 128x4 tiled-scales layout, and weight/scale tensor-name reconciliation still need validation against a real checkpoint on GPU. Part of modular#6697; relates to modular#6667. END_PUBLIC Assisted-by: AI Signed-off-by: Manuel Saelices <[email protected]>
…allback # Conflicts: # max/python/max/nn/moe/moe_fp8.py
main removed `Nvfp4Strategy` (folded into `NvMxf4f8Strategy`), but `moe/__init__.py` still imported and re-exported it, an unconditional ImportError on `import max.nn.moe`. Remove the stale import and `__all__` entry; `NvMxf4f8Strategy` is already exported. Assisted-by: AI Signed-off-by: Manuel Saelices <[email protected]>
The docstring claimed the change "can only lower the estimate, never raise the OOM risk relative to before" and that it "mirrors Qwen3_5MemoryPlanner". The cap only rules out extra over-reservation; the scaled value is an uncalibrated heuristic, not a proven activation-peak bound, so it can still under-reserve vs the true peak. Reword to state that honestly and drop the inaccurate Qwen comparison. Assisted-by: AI Signed-off-by: Manuel Saelices <[email protected]>
Cover the behaviours issue modular#6696 asked for: the scaled estimate drops below the old flat value for small batch, is capped at the flat value (and the flat cap still scales with the KV cache dtype), falls back to flat when model dims are missing, scales by device count, and adds the graph-capture headroom. Also documents that the batch term is inert below the prefill floor (the estimate is model-size driven, not batch driven). Assisted-by: AI Signed-off-by: Manuel Saelices <[email protected]>
Condense the estimate_activation_memory docstring; keep the one-line summary and the uncalibrated-heuristic caveat. Assisted-by: AI Signed-off-by: Manuel Saelices <[email protected]>
Remove `grouped_quantize_dynamic_block_scaled_fp4` (quant_strategy.py) and `nvfp4_dequant` (quant_ops.py), unused imports left by the main merge that ruff F401 rejected. Assisted-by: AI Signed-off-by: Manuel Saelices <[email protected]>
Construct the planner via __new__ to skip __init__'s model-config validation (estimate_activation_memory reads only its arguments, never self), and drop the unused transformers dep that pydeps rejected. Assisted-by: AI Signed-off-by: Manuel Saelices <[email protected]>
BEGIN_PUBLIC [Kernels] mojo format NVFP4Gemv.execute signature Drop the spaces around `=` in the `dtype=` keyword arguments of NVFP4Gemv.execute so the file passes the mblack lint check. END_PUBLIC Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]> Claude-Session: https://claude.ai/code/session_01LAekHUXqk6v5nLkA7R8Vxx
BEGIN_PUBLIC [Pipelines] Load compressed-tensors NVFP4 Gemma 4 checkpoints The compressed-tensors NVFP4 export (e.g. RedHatAI/gemma-4-31B-it-NVFP4) stores the same block-scaled E2M1 weights and FP8-E4M3 block scales as the modelopt NVFP4 format, so it already parses to QuantFormat.NVFP4 and runs the same kernel path. It only differs in HF config schema and tensor names, so the config parser alone could not actually load it. Reconcile the compressed-tensors tensor names in the Gemma 4 language weight adapter onto the modelopt names the quantized Linear registers: weight_packed -> weight, weight_global_scale -> weight_scale_2, input_global_scale -> input_scale. The shared per-block weight_scale already matches. The per-tensor global scales ship as shape [1] but the Linear registers them as scalars, so squeeze them to (); modelopt checkpoints arrive scalar already and are left untouched. Verified the names/shapes/dtypes against RedHatAI/gemma-4-31B-it-NVFP4 and nvidia/Gemma-4-31B-IT-NVFP4, and add a no-network unit test covering both the compressed-tensors remap and the modelopt pass-through. END_PUBLIC Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]> Claude-Session: https://claude.ai/code/session_01LAekHUXqk6v5nLkA7R8Vxx
…m' into local-gemma4-nvfp4-integration
… into local-gemma4-nvfp4-integration # Conflicts: # max/tests/integration/architectures/gemma4/BUILD.bazel
…s reciprocal scale Local integration checkpoint (not for upstream as-is): - attention.py/gemma4.py: QK-norms + flash-attn output use unquantized compute dtype instead of self.dtype (uint8 packed). Pre-existing gemma4 NVFP4 bug. - weight_adapters.py: invert compressed-tensors global scales to modelopt convention (reciprocal) + squeeze [1]->(); the modular#6699 numerical fix. Validated: RedHatAI/gemma-4-31B-it-NVFP4 -> 'The capital of France is Paris.' at 28 tok/s on L40S.
…prefill/batch) Despacha por M en _matmul_float4: nvfp4_gemv para M<=32 (decode bs bajo), dequant a bf16 + GEMM tensor-core para M>32 (prefill/decode batched). Umbral via env MAX_NVFP4_GEMM_M_THRESHOLD. Throughput batched (64 conc, L40S): 79.6 -> 234 tok/s (2.9x). Single-stream intacto (~23 tok/s). Salida verificada coherente en ambos paths.
…s batched Kernel fusionado nvfp4_gemm.mojo (decode E2M1 en SMEM + MMA sincrono Ada, sin W bf16 en DRAM) + registro mo.gemm.nvfp4 + binding Python + dispatch en _matmul_float4 (M>32). Optimizado 13.9->63.9 TFLOP/s (decode estrecho a 8 FP4/vec sube ocupancia). End-to-end: 79.6->425 tok/s batched (5.3x). Smoke test pasa. Cuello restante (ncu): ocupancia limitada por SMEM (2 bloques/SM, 36.86KB/block).
…cks/SM BEGIN_PUBLIC [Kernels][GPU] NVFP4 GEMM: fix M<=64 race + single-buffer B for 4 blocks/SM The M<=64 dequant-GEMM path used stage_w=True, cp.async-staging the packed weight bytes through SMEM. That path had a latent write-after-read hazard on the w_smem staging buffer (NS=2 too few stages), making the fused kernel non-deterministically produce wrong results (~50% of runs failed the smoke test with ~0.5 abs error). Decoding straight from DRAM (stage_w=False) is both correct and equally fast here (the W->SMEM->W round-trip was pure SMEM-pipe overhead at small M), and it drops dynamic SMEM 36.86KB -> 32KB, lifting the shared-mem occupancy limit from 2 to 3 blocks/SM. Make the decoded-B SMEM buffer depth a parameter and run the M<=64 path single-buffered (b_pipeline_stages=1). A barrier fences the in-place decode from the just-completed MMA reads. This cuts SMEM to 24.58KB -> 4 blocks/SM (ncu: shared-mem block limit 4, max warps 33%). M=64 throughput 63.5 -> 69.5 TFLOP/s; M=256/512 unchanged. Smoke test passes deterministically. END_PUBLIC Co-Authored-By: Claude Opus 4.8 <[email protected]> Claude-Session: https://claude.ai/code/session_01LAekHUXqk6v5nLkA7R8Vxx
Add TensorCore.load_b_nvfp4: decode packed E2M1 weights directly into the MMA register fragments (Marlin bit-positioning, 2^14 bias folded into the scale), and thread a comptime is_nvfp4 flag through multistage_mma_q -> multistage_qgemm_kernel -> multistage_gemm_q (default False keeps the GGUF int4 path unchanged). New test_nvfp4_skeleton_gemm validates correctness vs an E2M1 dequant reference and benchmarks throughput. On L40S (M=482, Llama shapes) this hits 136-184 TFLOP/s vs the bespoke nvfp4_gemm at 66-112 (~2x), approaching the int4 skeleton's 174-235. The arithmetic E2M1 decode stalled the tensor cores (80-108); the Marlin bit-trick decode is what unlocks the throughput.
…46 TFLOP/s) Adds a group=16/BK=16 timing case. Confirms the real NVFP4 grouping needs the per-16-K-subgroup scale handling at BK=32, not a tiny BK=16 tile (which starves the pipeline: 46 vs 177 TFLOP/s at group=128/BK=32).
group=128 hits 178 TFLOP/s but group<=64 collapses to 17 (group=16/BK=16: 46), independent of pipeline stages. q_smem_usage is identical (~40KB) for group=32 and group=128, so it is not the SMEM fallback -- the cause is the mainloop scale reload firing every tile at fine groups (the GGUF skeleton is tuned for group=128). NVFP4's native group=16 lands on this cliff.
Re-measuring with a manual wall-clock (warmup + N back-to-back launches + one synchronize) shows the NVFP4 skeleton runs at ~188 TFLOP/s for group=128/64/32 and ~285 for group=16/BK=16 -- the earlier 17 TFLOP/s 'cliff' was fixed per-call overhead in ctx.execution_time crushing the fast kernels (ncu already showed the kernel itself is ~200us cold for both groups). NVFP4's native group=16 is therefore viable on the skeleton; no scale-staging rework needed.
Parametrize the test/reference by group_size/BK/BLOCK_K. group=32/BK=32 (one scale group per tile, reload every tile) is bit-exact vs the E2M1 reference at 172 TFLOP/s. BK=16 is invalid (num_k_mmas=1 breaks the mainloop prefetch), so NVFP4's group=16 needs BK>=32 with per-subgroup scale selection (next).
Add num_scale_sub = ceildiv(BK, group_size) and thread per-subgroup scale selection through multistage_mma_q (scales_reg_tiles gains a subgroup column; prologue/reload/prefetch loop over subgroups; load_b_nvfp4 takes a scale_sub index = (mma_tile_coord_k * MMA_K) // group_size). All backward-compatible: num_scale_sub == 1 leaves the GGUF int4 path and group==BK NVFP4 bit-exact (group=32/BK=32 at 178 TFLOP/s, GGUF Q4 test still passes). group=16/BK=32 compiles and no longer NaNs but the subgroup->scale mapping still yields wrong values -- WIP.
Flatten scales_reg_tiles to [num_n_mmas*num_scale_sub, 1] so each subgroup's scales are contiguous (the strided column broke the vectorized copy), and add a kernel-vs-kernel comparison (compare_g16_vs_g32) on identical weights + uniform scales. With the scale read bypassed in load_b_nvfp4, group=16/BK=32 matches group=32 exactly -> the weight decode and pipeline are correct at group<BK; the remaining bug is purely the scale register fill/read for num_scale_sub>1. group=32/BK=32 and the GGUF Q4 test still pass.
The remaining group=16/BK=32 failure is in the scale SMEM staging for num_scale_sub>1: the every-tile reload reads scale stage 2 as 0.0 at k_tile=2 even though the prologue fills stages 0,1,2. A synchronous scale copy did not fix it, so the cause is stage indexing/ordering, not async completion. Weights and pipeline are correct (compare_g16_vs_g32 passes with the scale read bypassed). group=32/BK=32 and GGUF Q4 remain bit-exact.
BEGIN_PUBLIC [Kernels][GPU] Fix NVFP4 group<BK scale staging (group=16 now correct) `num_scales_stages` counted scale *groups*, but the SMEM scale stage packs `num_scale_sub = ceildiv(BK, group_size)` groups into a single [num_scale_sub, BN] tile. For NVFP4 (group=16 < BK=32) this overcounted the ring depth by num_scale_sub, so the mainloop write-ahead `next_unsafe(num_scales_stages-1)` aliased a still-live stage and a scale stage was read stale (0.0) mid-pipeline -- producing wrong results for the real group=16 NVFP4 layout. Track one packed scale stage per in-flight k-tile instead: `ceildiv(num_pipeline_stages-1, max(group_size//BK, 1)) + 1`. This is identical to the previous formula for group >= BK (GGUF Q4 unchanged) and gives the correct depth for group < BK. Applied in the kernel, the host allocation, and `q_smem_usage`. Validated on L40S (sm_89): group=16 and group=32 NVFP4 are bit-exact vs the E2M1 dequant reference at ~167 TFLOP/s; the GGUF Q4 path is unregressed. The skeleton test now drives the reference at ref_block_k == group_size (its supported regime) so the real group=16 layout is covered. END_PUBLIC
… layout
BEGIN_PUBLIC
[Kernels][GPU] Add NVFP4 checkpoint repack to the multistage skeleton layout
`repack_nvfp4_for_sm8x` converts canonical NVFP4 checkpoint tensors --
weights [N, K/2] uint8 (E2M1, 2 codes/byte), FP8-e4m3 block scales [N, K/16],
and a per-tensor float32 global scale -- into the single combined buffer the
`multistage_gemm_q[..., is_nvfp4=True]` skeleton kernel consumes: repacked
4-bit weights in the 64x16 GPTQ tile layout followed by bf16 block scales in
row_major(K/group, N). The 4-bit weight bytes repack identically to int4, so
the weight path reuses pack_Q_tile and the GPTQ tile layout; only the scales
differ (fp8 * global -> bf16 instead of interleaved fp16).
Adds test_nvfp4_repack_e2e.mojo: it builds an INDEPENDENT dense bf16 reference
(canonical E2M1 LUT dequant + dense matmul, no repack) and asserts the full
repack -> skeleton-kernel chain matches it at gemma4-representative shapes
(N,K in {4096x4096, 15360x3840, 3840x15360}, M in {482, 64}), bit-close at
atol=1e-3/rtol=2e-2. Wall-clock microbench shows ~150-166 TFLOP/s at M=482 on
L40S (sm_89) -- ~2.5x the bespoke nvfp4_gemm.
END_PUBLIC
BEGIN_PUBLIC [Kernels][GPU] Register NVFP4 skeleton custom ops (repack + qmatmul) Expose the validated NVFP4-on-multistage-skeleton path to the graph compiler, mirroring the GPTQ repack+matmul op pair: - Host launchers in qmatmul_gpu.mojo: `matmul_gpu_nvfp4` (calls `multistage_gemm_q[..., is_nvfp4=True]` with the validated BK=32 config) and `gpu_nvfp4_repack` (launches `repack_nvfp4_for_sm8x`). - Custom ops in builtin_kernels/quantization.mojo: `repack_nvfp4_g16` (weights + FP8 block scales + global f32 -> combined skeleton buffer; the global scale is read to host once, as the repack is a load-time constant transform) and `qmatmul_nvfp4_g16` (C = A @ dequant(W).T), each with a shape function. This is the graph-side wiring only; routing `_matmul_float4` to it is a separate change gated on serving validation. END_PUBLIC
BEGIN_PUBLIC [Kernels][GPU] Route NVFP4 large-M GEMM to the skeleton path (opt-in) Add Python wrappers `nvfp4_skeleton_repack` / `nvfp4_skeleton_gemm` for the `repack_nvfp4_g16` / `qmatmul_nvfp4_g16` custom ops, and route the pre-Blackwell large-M branch of `_matmul_float4` through them when `MAX_NVFP4_SKELETON_GEMM=1`. The repack folds the FP8 block scale with the per-tensor global scale into bf16 (constant-folded at load), so the GEMM runs on MAX's mature multistage quantized mainloop -- bit-exact and ~2.5x the bespoke nvfp4_gemm at gemma4 shapes on L40S. Default off: the live NVFP4 dispatch is unchanged until the flag is validated token-identical in serving. END_PUBLIC
… test BEGIN_PUBLIC [Kernels][GPU] Add engine-level NVFP4 skeleton-vs-bespoke equivalence test Runs the bespoke `nvfp4_gemm` and the skeleton path (`nvfp4_skeleton_repack` + `nvfp4_skeleton_gemm`) in one graph on identical canonical NVFP4 weights and asserts the outputs match (rtol 2e-2), through the real engine. This validates the full graph wiring the Mojo-level test cannot: custom-op registration, the one-time global-scale host read in the repack op, and the Python wrappers. Verified on L40S (sm_89): 3 passed at shapes (482,4096,4096), (64,4096,4096), (482,15360,3840). END_PUBLIC
BEGIN_PUBLIC [Kernels][GPU] Per-M autotuned configs for the NVFP4 skeleton GEMM `matmul_gpu_nvfp4` used a single large-M tile (block 128x128, warp_k=1), which starves small M: batched decode runs at M=batch, and at M=64 the large tile hit only ~25 TFLOP/s -- worse than the bespoke kernel. Bucket the config by M (all BK=32), autotuned on L40S at gemma4 shapes: M<=64 : block 32x64, warp 32x64, stages 3, warp_k 4 M<=128 : block 64x64, warp 64x64, stages 4, warp_k 4 M<=256 : block 64x128, warp 64x64, stages 4, warp_k 2 M>256 : block 128x128, warp 64x64, stages 4, warp_k 1 (unchanged large-M peak) Small M uses tiny tiles + deep split-K to spread work across SMs (the lever the bespoke int4 path uses); split-K is numerically correct for NVFP4 since the scale staging is independent of the K partition. Measured peaks: M=64 ~100-110 (was ~25, a ~4x gain), M=128 ~140-156, M=256 ~158-170, M>256 ~162-168 TFLOP/s. Engine equivalence test (skeleton vs bespoke) still passes at M=64 and M=482. END_PUBLIC
BEGIN_PUBLIC [Kernels][GPU] Add NVFP4 skeleton GEMM config-sweep tuning bench test_nvfp4_tune.mojo sweeps MatmulConfig candidates (block tile, warp shape, pipeline stages, split-K depth; all BK=32) for the NVFP4 skeleton GEMM at gemma4 shapes, gating each on correctness vs the dense E2M1-dequant reference before reporting wall-clock TFLOP/s. Used to pick the per-M configs in matmul_gpu_nvfp4. Confirms split-K is numerically correct for NVFP4 and that small-M throughput is ~4x higher with tiny tiles + deep split-K. END_PUBLIC
…ench
BEGIN_PUBLIC
[Kernels][GPU] Add NVFP4 real-shape skeleton-vs-bespoke GEMM tuning bench
test_nvfp4_tune_real.mojo sweeps skeleton MatmulConfig candidates against the
bespoke nvfp4_gemm at the actual gemma4-31B decode GEMM shapes (gate/up
N=21504 K=5376, down N=5376 K=21504, q/o/kv) for M in {32,64,128,256}, gating
every TFLOP/s number on correctness vs the dense E2M1-dequant reference.
CAVEAT: this is an L2-WARM throughput microbench (it replays the same weight
many times), so its TFLOP/s overstate cold-weight serving performance. nsys
profiling of real batch-64 decode shows the FP4 proj/FFN GEMM is 79% of the
step and runs at only ~16% of HBM bandwidth on COLD weights -- that cold-weight
bandwidth efficiency (not warm throughput) is the real ~3x gap vs vLLM. A
faithful tuning bench for serving must read cold weights (L2 flush / distinct
weights per call); this harness is kept for warm-throughput comparison only.
END_PUBLIC
…tency) BEGIN_PUBLIC [Kernels][GPU] Add cold-weight NVFP4 GEMM microbench (throughput + latency) test_nvfp4_cold_gemm.mojo measures the FP4 decode GEMM with COLD weights (a round-robin pool of distinct weight buffers >> L2, validated to evict cache), in both throughput (one sync at end) and latency (sync per call) modes, and reports effective HBM bandwidth vs the 864 GB/s peak. This is the faithful tool for the decode regime, where a prior L2-warm microbench was misleading. Finding: at M=64 the skeleton GEMM reaches ~65-76% of HBM bandwidth cold (vs the bespoke nvfp4_gemm at ~21-31%), and latency ~= throughput -- so the FP4 weight streaming is NOT the serving bottleneck. nsys of the skeleton serve shows the real cost is the repack op running every decode step (~52% of GPU time) because the custom op is not constant-folded; the fix is to repack at load time. END_PUBLIC
…ginal weight)
BEGIN_PUBLIC
[Kernels][GPU] Add byte-exact host NVFP4 repack (load-time, frees original weight)
nvfp4_repack_host (max/python/max/nn/_nvfp4_repack_host.py) reproduces the GPU
repack_nvfp4_for_sm8x kernel in pure numpy: canonical NVFP4 (packed uint8
weights + FP8-e4m3 block scales + f32 global) -> the combined skeleton buffer
([repacked 4-bit weights in the 64x16 GPTQ tile layout][bf16 block scales]).
This enables a load-time repack: the combined buffer becomes the sole
registered weight so the original packed weight is never uploaded -- avoiding
the 2x weight memory (OOM) that an in-graph repack hits, since MAX's
weights-registry uploads every declared weight regardless of graph folding.
Strictly validated: test_nvfp4_repack_host_gpu.py runs the actual GPU repack op
via the engine and asserts byte-for-byte equality (np.array_equal) at N/K in
{64x128, 128x256, 256x128, 512x256} -- both the weight bit-shuffle (pack_Q_tile
+ nested tile layout) and the bf16 scale region (incl. the per-64-col N
permutation and RNE f32->bf16 rounding).
END_PUBLIC
…espoke in serving BEGIN_PUBLIC [Kernels][GPU] Wire load-time NVFP4 repack: skeleton GEMM ~1.8-2.9x bespoke in serving Route the pre-Blackwell NVFP4 matmul through the multistage-skeleton GEMM with the weight repacked ONCE at load (behind MAX_NVFP4_SKELETON_GEMM=1, default off). Previously the skeleton repacked the weight in-graph every decode step (nsys: ~52% of decode) because MAX's weights-registry uploads every declared weight, so an in-graph repack kept both the original and repacked weight resident (2x memory / OOM). Now: - gemma4 weight adapter repacks canonical NVFP4 (weights + FP8 block scales + global) into the combined skeleton buffer via the byte-exact nvfp4_repack_host and registers it as the SOLE weight; the original packed weight + scales are never uploaded. - Linear declares the combined-shape weight (no scale weights) for the skeleton NVFP4 case; _matmul_float4 calls qmatmul_nvfp4_g16 directly (no repack op, no scales in the forward graph). StackedLinear/MLP run unfused skeleton GEMMs and concat (combined buffers are not concatenable along N). - head_aware_col_sharding_strategy: single-device fast-path (returns the whole weight) so the combined buffer's column layout isn't sliced by head-aware K math (no-op for canonical weights). - nvfp4_repack_host vectorized (the loop version was infeasible at 31B load); still byte-exact vs the GPU repack kernel. Validated on L40S serving gemma-4-31B-it-NVFP4: original weight freed (no setup OOM), coherent generation matching the bespoke path, and batched throughput ~347 vs ~118 tok/s at conc=16 (~2.9x) and ~815 vs ~447 at conc=64 (~1.8x) -- moving NVFP4 serving from ~33% to ~61% of vLLM. Flag-off path unchanged. Caveat: the combined weight is slightly larger (inline bf16 scales), so batch-64/maxlen-2048 is memory-marginal on a 46GB L40S; reduce batch / max-length / device-memory-utilization for headroom. Load adds a one-time host repack (~7min). END_PUBLIC
…ter load) BEGIN_PUBLIC [Kernels][GPU] Parallelize the load-time NVFP4 host repack (~3.8x faster load) The per-matrix host repack (nvfp4_repack_host) is the dominant load cost for the skeleton NVFP4 path (~1s/matrix x hundreds of layers ~= 7min). Parallelize it across a ThreadPoolExecutor: nvfp4_repack_host is a pure numpy function whose heavy ops release the GIL. The dlpack/Buffer reads are kept in the main thread (WeightData access is not thread-safe); only the materialized-numpy compute runs in workers. Measured ~3.8x on a 16-core host (60.4s -> 15.7s for 56 gemma4-shaped matrices), cutting the one-time skeleton load from ~7.5min to ~2min. Output is byte-identical (same repack function); the byte-exact GPU validation is unchanged. END_PUBLIC
Adversarial review — NVFP4 skeleton portReviewed the skeleton-port net diff only ( Verdict: no confirmed data-corruption bug in the net diff. The five load-bearing claims hold up under scrutiny, with a few precondition/wording/guard items worth fixing before un-drafting. Severity tags: 🔴 fix-before-merge · 🟡 concern · ⚪ nit. Claims verified
Fix before merge
Worth a guard (latent fragility, harmless today)
Note on what the tests actually coverThe active flagged path never calls the in-graph Adversarial review assisted by AI (two independent passes). Findings are code-trace based; the three items tagged "confirm by running" need a GPU run to settle. |
…ing + docstrings BEGIN_PUBLIC [Kernels][GPU] Address NVFP4 skeleton review: shape guards + SMEM sizing + docstrings Adversarial-review follow-ups (no data-corruption bug found; these are guards and clarity): - matmul_gpu_nvfp4: comptime-assert N % 128 == 0 and K % 32 == 0 (BK) so a non-conforming shape fails at compile time instead of OOB-reading the tiles. - repack_nvfp4_for_sm8x: comptime-assert group_size == 16, N % 64 == 0 (repack tile), K % 128 == 0 (WGROUP) -- the divisibility the repack relies on. - q_smem_usage: size the scale SMEM by bfloat16 (the kernel's scales_type) explicitly, not the activation dtype, so an fp8-activation variant would not under-allocate it. - load_b_nvfp4 docstring: correct the stale "assumes BK == group_size" note to document group_size <= BK (the scale_sub subgroup selection). All asserts hold for gemma4 + the test shapes; the repack e2e and engine equivalence tests still pass. END_PUBLIC
Review follow-ups addressed (pushed in
|
BEGIN_PUBLIC [Kernels][GPU] Fix CI: load_b_nvfp4 Args docstring + apply formatter mojodoc failed because TensorCore.load_b_nvfp4 takes arguments but its docstring had no Args section; add it. Also apply the repo formatter (mblack/ruff) to the NVFP4 skeleton sources and tests that were not format-clean (lint failure). END_PUBLIC Assisted-by: AI
…ack import BEGIN_PUBLIC [Kernels][GPU] Fix CI round 2: ruff F841, mypy weight_scale, lazy repack import - Remove unused locals (BN/BK/BK_groups) in the host repack (ruff F841). - quantized_matmul: widen weight_scale to TensorValue | None (the NVFP4 skeleton path folds it away) and assert non-None in the MXFP4/MXFP8/FP8 branches that require it (mypy arg-type). - Import nvfp4_repack_host lazily inside the fuse function so the numpy host-repack module stays off the `max ... --help` startup path (CLI help perf test). END_PUBLIC Assisted-by: AI
…ge-skeleton Sync the NVFP4 skeleton branch with 38 upstream commits. The only conflict is the `.kernels` import block in nn/quant_ops.py: upstream adds `_fused_qkv_ragged_matmul_scaled_mxfp8` (fused MXFP8 QKV) next to this branch's `_is_pre_sm100_nvidia_gpu` import. Resolved by keeping both. Everything else auto-merges (moe_fp8.py, pipelines/weights/quant.py, nn/kernels.py, test/gpu/linalg/BUILD.bazel). No semantic drift: the QuantStrategy protocol widening predates this branch, so Nvfp4DequantStrategy already conforms. Assisted-by: AI
Summary
NVFP4 (E2M1) GEMM built on MAX's mature multistage quantized mainloop
(
multistage_mma_q) instead of the bespokenvfp4_gemm, with the weightrepacked once at load and routed into pre-Blackwell serving behind
MAX_NVFP4_SKELETON_GEMM=1(default off). Draft.Depends on / blocked by (must merge first)
Stacked on the Gemma-4 NVFP4 series; lands after all three, then rebases onto
main(overlap disappears). The new commits here are the skeleton port itself(~21, from
NVFP4 (E2M1) path on the multistage qGEMM skeletononward).Results (L40S / sm_89, gemma-4-31B-it-NVFP4)
Batched serving throughput (tok/s, 256 output tokens):
nvfp4_gemm)vs vLLM (~1340 tok/s batched, same model/GPU, prior external measurement):
pre-Blackwell NVFP4 serving moves from ~33% of vLLM (bespoke) to
~48–61% (this PR).
Kernel-level (microbench, manual wall-clock):
group=16 / group=32 match an E2M1 dequant reference within bf16 rounding (loose tol); GGUF Q4 unregressed.
Load: ~10min → ~2.8min (parallelized host repack).
What changed
TensorCore.load_b_nvfp4(Marlin bit-trick decode) + comptimeis_nvfp4flag (off ⇒ int4 path byte-identical).num_scales_stagesfor group < BK (was counting groups, not packed stages → stale scale at group=16).repack_nvfp4_g16/qmatmul_nvfp4_g16custom ops +_matmul_float4routing; per-M autotuned configs (split-K).Validation
Byte-exact host repack vs the GPU kernel (4 shapes); engine-level skeleton-vs-bespoke
equivalence on identical weights; real 31B serving (original freed, coherent generation
matching bespoke, throughput as above). Flag-off path unchanged.
Caveats
Combined weight is ~2GB larger (inline bf16 scales) ⇒ batch-64/maxlen-2048 is
memory-marginal on 46GB (use
--device-memory-utilization/--max-length). One-time~2.8min host repack at load (disk-caching is a follow-up). Remaining gap to vLLM is the
decode step beyond the FP4 GEMM (attention/KV/scheduler), out of scope.
Test plan
bazel test //max/kernels/test/gpu/quantization:{test_nvfp4_skeleton_gemm,test_nvfp4_repack_e2e}.mojo.test//max/tests/integration/nn:nn_gpu_tests -k nvfp4_skeleton(engine equivalence) and-k repack_host(byte-exact)MAX_NVFP4_SKELETON_GEMM=1 max serve --model RedHatAI/gemma-4-31B-it-NVFP4Assisted by: AI.