[Kernels][GPU] NVFP4 GEMM on the multistage qGEMM skeleton (~1.8-2.9x serving vs bespoke)#6708

Draft
msaelices wants to merge 56 commits into modular:main from msaelices:nvfp4-on-multistage-skeleton

msaelices commented Jun 22, 2026

Contributor

Summary

NVFP4 (E2M1) GEMM built on MAX's mature multistage quantized mainloop
(multistage_mma_q) instead of the bespoke nvfp4_gemm, with the weight
repacked once at load and routed into pre-Blackwell serving behind
MAX_NVFP4_SKELETON_GEMM=1 (default off). Draft.

Depends on / blocked by (must merge first)

Stacked on the Gemma-4 NVFP4 series; this PR lands after all three, then rebases onto
main (at which point the overlap disappears). The new commits here are the skeleton port
itself (~21 commits, from "NVFP4 (E2M1) path on the multistage qGEMM skeleton" onward).

Results (L40S / sm_89, gemma-4-31B-it-NVFP4)

Batched serving throughput (tok/s, 256 output tokens):

| workload | before (bespoke nvfp4_gemm) | after (skeleton + load-time repack) | speedup |
|---|---|---|---|
| 16 concurrent | ~118 | ~347 | ~2.9x |
| 64 concurrent | ~445 | ~648 (maxlen 1024, dmu 0.84) / ~815 (maxlen 2048) | ~1.5–1.8x |

vs vLLM (~1340 tok/s batched, same model/GPU, prior external measurement):
pre-Blackwell NVFP4 serving moves from ~33% of vLLM (bespoke) to
~48–61% (this PR).

Kernel-level (microbench, manual wall-clock):

| metric | before (bespoke) | after (skeleton) |
|---|---|---|
| NVFP4 GEMM, M=482, gemma4 shapes | ~66 TFLOP/s | ~167 TFLOP/s (~2.5x) |
| cold-weight HBM bandwidth, M=64 (decode) | ~21–31% of peak | ~65–76% of peak |

group=16 / group=32 match an E2M1 dequant reference within bf16 rounding (loose tol); GGUF Q4 unregressed.
Load: ~10min → ~2.8min (parallelized host repack).

What changed

  • TensorCore.load_b_nvfp4 (Marlin bit-trick decode) + comptime is_nvfp4 flag (off ⇒ int4 path byte-identical).
  • Fixed num_scales_stages for group < BK (was counting groups, not packed stages → stale scale at group=16).
  • Load-time repack in the weight adapter: combined buffer becomes the sole registered weight (the original isn't uploaded), avoiding the in-graph per-step repack (~52% of decode) and the 2x-weight OOM; byte-exact numpy host replica of the GPU repack.
  • repack_nvfp4_g16 / qmatmul_nvfp4_g16 custom ops + _matmul_float4 routing; per-M autotuned configs (split-K).

Validation

Byte-exact host repack vs the GPU kernel (4 shapes); engine-level skeleton-vs-bespoke
equivalence on identical weights; real 31B serving (original freed, coherent generation
matching bespoke, throughput as above). Flag-off path unchanged.

Caveats

Combined weight is ~2GB larger (inline bf16 scales) ⇒ batch-64/maxlen-2048 is
memory-marginal on 46GB (use --device-memory-utilization / --max-length). One-time
~2.8min host repack at load (disk-caching is a follow-up). Remaining gap to vLLM is the
decode step beyond the FP4 GEMM (attention/KV/scheduler), out of scope.

Test plan

  • bazel test //max/kernels/test/gpu/quantization:{test_nvfp4_skeleton_gemm,test_nvfp4_repack_e2e}.mojo.test
  • //max/tests/integration/nn:nn_gpu_tests -k nvfp4_skeleton (engine equivalence) and -k repack_host (byte-exact)
  • MAX_NVFP4_SKELETON_GEMM=1 max serve --model RedHatAI/gemma-4-31B-it-NVFP4

Assisted by: AI.

msaelices and others added 30 commits June 13, 2026 22:39
…GPUs

Serve NVFP4 checkpoints on GPUs without native FP4 matmul support
(sm_8x/sm_90) by dequantizing weights to BF16 with a software-LUT kernel
and routing through the regular BF16 matmuls, mirroring the existing
MXFP4-on-H100 dequant approach. Fixes modular#6667.

Kernels:
- Generalize the MXFP4 dequant kernel into a format-agnostic _dequant_fp4
  (block-scale dtype and SF_VECTOR_SIZE were already parameters) and add a
  dequant_nvfp4 entry point: E4M3 or pre-multiplied float32 scales,
  SF_VECTOR_SIZE=16. The E2M1 decode is a software LUT, so it runs on any
  GPU. Register as mo.dequant.nvfp4.
- Smoke test (CPU reference vs GPU) covering both scale dtypes, several
  shapes and an FP8 output; unlisted in _EXTRA_CONSTRAINTS so it runs on
  any GPU.

Python (active only when not on SM100+):
- kernels.nvfp4_dequant() wrapper (rank 2/3, flat [.., K//16] scales).
- quant_ops: _matmul_float4 and the NVFP4 fused-QKV case dequantize the
  weight (weight_scale_2 pre-multiplied into float32 scales) and use the
  unquantized matmul / fused_qkv_ragged_matmul. Pre-interleaved (TCGEN 5D)
  scale checkpoints are rejected with a clear error.
- MoE: new Nvfp4DequantStrategy — activations stay BF16 (no dynamic
  quantization), expert weights dequantized per matmul, BF16
  grouped_matmul_ragged. Selected in MoEQuantized._strategy(); expert
  epilogue scales carry only weight_scale_2 (no activation input-scale
  factor, since activations are never quantized on this path).

Blackwell behavior is unchanged: every fallback is gated on
not _is_sm10x_gpu().
END_PUBLIC

Assisted-by: AI
Signed-off-by: Manuel Saelices <[email protected]>
Two fixes found while validating end to end on an A10G (sm_86):

- _dequant_weight_nvfp4: move weight_scale_2 (a CPU rank-0 weight) to the
  weight's device before multiplying into the block scales ("Input values
  must be on the same device").
- Nvfp4DequantStrategy.grouped_matmul: dequantize to bfloat16 explicitly;
  self.dtype is the packed storage dtype (uint8) on this path and tripped
  the kernel's output-dtype constraint at graph compile time.

Validated: nvidia/Llama-3.1-8B-Instruct-NVFP4 (dense, modelopt) serves on
the A10G and generates coherent text. nvidia/Gemma-4-26B-A4B-NVFP4 (MoE)
compiles and loads weights through the fallback, but 18.8 GiB of packed
weights exhaust the 24 GiB card during setup (allocator reports 18.73 GiB
in use, largest free block 359 MB) -- a checkpoint-size limit, not a
fallback defect.

Assisted-by: AI
Signed-off-by: Manuel Saelices <[email protected]>
Assisted-by: AI
Signed-off-by: Manuel Saelices <[email protected]>
- moe/__init__.py: restore isort order in the import block and __all__
  (ruff I001/RUF022 with the CI-pinned ruff).
- moe_fp8.py: raise an explicit NotImplementedError when selecting the
  NVFP4 dequant strategy with an active expert-parallel batch manager --
  EP dispatch quantizes activations to FP4, so the fallback would
  otherwise die later in grouped-matmul shape validation with a
  confusing error.
- Smoke test: vary block scales per index so wrong scale indexing cannot
  pass unnoticed (previously uniform scales masked that class of bug);
  re-run on the A10G, all cases still PASS with max_err=0.0.
- BUILD.bazel: constrain the smoke test to nvidia_gpu|amd_gpu (mirrors
  the mxfp4 smoke test) instead of enrolling untested Apple CI.

Assisted-by: AI
Signed-off-by: Manuel Saelices <[email protected]>
BEGIN_PUBLIC
[Kernels][GPU] Branchless E2M1 decode in cast_uint_to_fp4e2m1

The software LUT indexed a SIMD[float32, 16] table with a runtime
index, which spills the vector to GPU local memory on every element.
That made the FP4 dequant kernels local-memory bound: 12 GB/s effective
on an A10G (sm=100%, dram=99% from spill traffic). Replacing the lookup
with branchless bit arithmetic (exp/mantissa decode + select) makes the
kernel memory-bound as intended: 422 GB/s on the same device, a 35x
kernel speedup, 2.5x end-to-end on the runtime-dequant serving path.

Verified with test_nvfp4_dequant_smoke (bit-exact vs CPU reference).
END_PUBLIC

Assisted-by: AI
Signed-off-by: Manuel Saelices <[email protected]>
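
For illustration only, the branchless decode described in the commit above can be sketched in NumPy. This is not the Mojo `cast_uint_to_fp4e2m1` implementation; the function names are hypothetical, and the sketch assumes the standard E2M1 layout (1 sign bit, 2 exponent bits, 1 mantissa bit). It checks the bit-arithmetic decode against a plain lookup table over all 16 codes.

```python
import numpy as np

# The 8 E2M1 magnitudes, indexed by the low 3 bits (exponent, mantissa).
E2M1_LUT = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=np.float32)


def decode_e2m1_lut(codes: np.ndarray) -> np.ndarray:
    """Table-lookup decode of 4-bit E2M1 codes (values 0..15)."""
    codes = codes.astype(np.uint8)
    sign = np.where(codes & 0x8, -1.0, 1.0).astype(np.float32)
    return sign * E2M1_LUT[codes & 0x7]


def decode_e2m1_branchless(codes: np.ndarray) -> np.ndarray:
    """Decode with bit arithmetic plus one select, no indexed table."""
    codes = codes.astype(np.uint8)
    sign = 1.0 - 2.0 * ((codes >> 3) & 1).astype(np.float32)  # +1 or -1
    exp = ((codes >> 1) & 0x3).astype(np.int32)               # 2-bit exponent
    man = (codes & 0x1).astype(np.float32)                    # 1-bit mantissa
    # Normal values: (1 + m/2) * 2^(exp - 1). Subnormal (exp == 0): m * 0.5.
    normal = (1.0 + 0.5 * man) * np.exp2(exp - 1).astype(np.float32)
    subnormal = 0.5 * man
    return (sign * np.where(exp == 0, subnormal, normal)).astype(np.float32)


all_codes = np.arange(16, dtype=np.uint8)
assert np.array_equal(decode_e2m1_branchless(all_codes), decode_e2m1_lut(all_codes))
```

On a GPU the `np.where` corresponds to a per-lane select, so no lane indexes a register-resident vector with a runtime index.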
BEGIN_PUBLIC
[Kernels][GPU] Fused NVFP4 dequant-GEMV for the pre-Blackwell fallback

Marlin-style fused kernel: one warp per output column decodes packed
E2M1 in registers (branchless, no LUT) and dot-products against the
BF16 activations — the dequantized weight is never materialized in
global memory. Per-token DRAM traffic drops from ~4.5 B/element
(dequantize-then-matmul: bf16 write + re-read + packed read) to the
packed ~0.5 B/element, and no per-forward transient buffers exist.
Activation rows are tiled (M_TILE=4) so a P-token prefill re-reads the
packed weight ceil(P/4) times; decode (M=1) is the optimized case.

_matmul_float4's pre-Blackwell path (o_proj and dense MLP projections)
now routes through the fused op. The QKV path keeps dequantize-then-
matmul for now (fused_qkv_ragged_matmul owns the KV-cache write).

Measured on an NVIDIA A10G (gemma-4 12B NVFP4, bs=1):
- kernel: 0.126 ms for [15360x3840] at M=1 (350 GB/s of packed reads)
- per-linear through the graph executor: 2.7 ms -> 0.133 ms (20x)
- end-to-end decode: 0.94 -> 4.98 tok/s (5.3x; 13x vs the original
  fallback). Output is token-identical to the dequant path.

Correctness verified against a scalar CPU reference for M={1,3,5},
including non-multiple-of-warp N (test_nvfp4_gemv_smoke).
END_PUBLIC

Assisted-by: AI
Signed-off-by: Manuel Saelices <[email protected]>
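
The traffic figures in the commit above follow from simple byte counting. A rough sketch, assuming bf16 is 2 bytes, two E2M1 codes share one byte, and block scales are ignored (as in the "~" figures quoted); the function names are hypothetical:

```python
import math

BF16_BYTES = 2.0         # one bf16 value
PACKED_FP4_BYTES = 0.5   # two E2M1 codes per byte


def dequant_then_matmul_bytes_per_element() -> float:
    """bf16 write of the dequantized weight + bf16 re-read + packed read."""
    return BF16_BYTES + BF16_BYTES + PACKED_FP4_BYTES


def fused_bytes_per_element() -> float:
    """The fused GEMV only reads the packed weight."""
    return PACKED_FP4_BYTES


def fused_weight_passes(num_tokens: int, m_tile: int = 4) -> int:
    """How many times a P-token prefill re-reads the packed weight."""
    return math.ceil(num_tokens / m_tile)


print(dequant_then_matmul_bytes_per_element(), fused_bytes_per_element())
print(fused_weight_passes(1), fused_weight_passes(9))
```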
BEGIN_PUBLIC
[Kernels] Use CamelCase names for the new NVFP4 op structs

Rename Struct_dequant_nvfp4 -> NVFP4Dequant and Struct_gemv_nvfp4 ->
NVFP4Gemv. The Struct_ prefix is the legacy machine-generated style;
the hand-written registrations in quantization.mojo use CamelCase.
Pre-existing Struct_* registrations are left untouched.
END_PUBLIC

Assisted-by: AI
Signed-off-by: Manuel Saelices <[email protected]>
BEGIN_PUBLIC
[Kernels] Route the pre-Blackwell QKV projection through the fused GEMV

The QKV path was the last consumer of dequantize-then-matmul on the
pre-Blackwell NVFP4 fallback. Compute QKV with the fused dequant-GEMV,
then split and store K/V into the paged cache via store_k_cache_ragged
and store_v_cache_ragged (fused_qkv_ragged_matmul owns that write on
the unquantized path, which is why this stays decomposed). K-equals-V
layers (Gemma 4 full attention packs only Q and K rows in wqkv) store
the K projection into both caches.

With this no BF16 copy of any quantized weight is ever materialized.
End-to-end on an NVIDIA A10G (gemma-4 12B NVFP4, bs=1), warm decode:
31.5 tok/s vs 0.38 at the start of the fallback work (83x), with
token-identical output. For reference, vLLM nightly's Marlin path
serves the same checkpoint at ~16 tok/s eager on the same GPU.
END_PUBLIC

Assisted-by: AI
Signed-off-by: Manuel Saelices <[email protected]>
… fallback

BEGIN_PUBLIC
[Kernels][Pipelines] Address adversarial-review findings on the fused fallback

- Gate the NVFP4 fallback on pre-SM100 *NVIDIA* explicitly
  (_is_pre_sm100_nvidia_gpu): _is_sm10x_gpu() is False on AMD/Apple,
  which silently routed those vendors onto an unvalidated dequant path;
  they now keep their previous behavior.
- moe_fp8: fix a comment describing the Blackwell scale-epilogue
  mechanism on the fallback branch (the per-tensor scale is folded into
  the dequant block scales; the BF16 grouped matmul has no epilogue).
- Soften the dequant docstrings: raw float8_e4m3fn block scales are
  accepted but only exercised by tests today (the modelopt parser
  always provides a per-tensor scale).
END_PUBLIC

Assisted-by: AI
Signed-off-by: Manuel Saelices <[email protected]>
BEGIN_PUBLIC
[Kernels] Address cycle-2 adversarial-review findings on the fused GEMV

- nvfp4_gemv: all threads now enter the PDL region (the early return for
  tail-block threads moved inside it) and the dead pdl_level parameter
  is gone (it was accepted but never forwarded).
- quant_ops: drop _dequant_weight_nvfp4 (no callers since the fused
  reroute), refresh _matmul_float4's docstring, and validate the wqkv
  layout explicitly instead of silently aliasing K as V for anything
  that is not exactly Q+K+V or Q+K rows.
- Stale software-LUT wording updated after the branchless decode.
- mojo format on the two new kernel files; smoke test now covers >1
  K-chunk per lane (K=2048) and the production K=3840.
END_PUBLIC

Assisted-by: AI
Signed-off-by: Manuel Saelices <[email protected]>
BEGIN_PUBLIC
[Kernels] Address cycle-3 adversarial-review findings

- Replace the NVFP4 branch of quantized_fused_qkv_matmul with a clear
  error: no in-tree model reaches it (callers assert on input_scale
  first, and NVFP4 QKV routes through StackedLinear/_matmul_float4,
  which is already fused). _dequant_weight_nvfp4 went unused with it.
- Add the Apache license header to test_nvfp4_gemv_smoke.mojo.
END_PUBLIC

Assisted-by: AI
Signed-off-by: Manuel Saelices <[email protected]>
…dims

BEGIN_PUBLIC
[Pipelines] Scale Gemma 4 activation-memory reservation by batch and dims

`Gemma4MemoryPlanner.estimate_activation_memory` reserved a flat per-device
activation budget (15 GB with a bf16 KV cache, 30 GB with fp8), independent
of batch size or model dimensions. On a low-concurrency single-GPU serve
this dwarfs the KV cache and, with NVFP4 weights now runnable pre-Blackwell,
keeps a model that physically fits on a 24 GB card from loading.

Scale the reservation with the widest per-token buffer (hidden vs MLP
intermediate), the tokens processed per forward step, and a safety multiple,
mirroring the principled per-step estimate in `Qwen3_5MemoryPlanner`. The
result is capped at the previous flat value, so this can only lower the
reservation, never raise the OOM risk relative to before; it falls back to
the flat value when model dimensions are unavailable.

The safety multiple and per-step token count still need calibration against
measured peak GPU memory (tracked in MODELS-1544) before the cap is relaxed.

Relates to modular#6667.
END_PUBLIC

Assisted-by: AI
Signed-off-by: Manuel Saelices <[email protected]>
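
A minimal sketch of the capped, scaled reservation described in the commit above. It is not the `Gemma4MemoryPlanner` code: the function name, parameter names, and the default `safety_multiple` are hypothetical placeholders, and the real estimate also accounts for terms not modeled here.

```python
from typing import Optional


def estimate_activation_bytes(
    flat_cap_bytes: int,
    hidden_size: Optional[int],
    intermediate_size: Optional[int],
    tokens_per_step: int,
    bytes_per_value: int = 2,   # assumption: bf16 activations
    safety_multiple: int = 8,   # hypothetical, uncalibrated
) -> int:
    """Scale the reservation by model dims and tokens per step, capped at the flat value."""
    if not hidden_size or not intermediate_size:
        # Model dimensions unavailable: fall back to the previous flat budget.
        return flat_cap_bytes
    widest = max(hidden_size, intermediate_size)  # widest per-token buffer
    scaled = widest * tokens_per_step * bytes_per_value * safety_multiple
    return min(scaled, flat_cap_bytes)


flat = 15 * 1024**3
print(estimate_activation_bytes(flat, 5376, 21504, 8192))
```

The cap only prevents reserving more than before; the scaled value remains a heuristic and can still under-reserve relative to the true activation peak.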
BEGIN_PUBLIC
[Pipelines] Scaffold compressed-tensors NVFP4 checkpoint parsing

Add `_parse_compressed_tensors_float4_config`, dispatched from
`_parse_float4_config` when `quant_method == "compressed-tensors"` and the
weight group is 4-bit float. It maps the compressed-tensors NVFP4 schema to
the same `QuantFormat.NVFP4` `QuantConfig` the modelopt path produces, so
checkpoints like `RedHatAI/gemma-4-*-NVFP4` route through the existing NVFP4
kernel (and the pre-Blackwell fallback).

WIP: the per-tensor global scale (weight_scale_2), the 128x4 tiled-scales
layout, and weight/scale tensor-name reconciliation still need validation
against a real checkpoint on GPU.

Part of modular#6697; relates to modular#6667.
END_PUBLIC

Assisted-by: AI
Signed-off-by: Manuel Saelices <[email protected]>
…allback

# Conflicts:
#	max/python/max/nn/moe/moe_fp8.py
main removed `Nvfp4Strategy` (folded into `NvMxf4f8Strategy`), but
`moe/__init__.py` still imported and re-exported it, an unconditional
ImportError on `import max.nn.moe`. Remove the stale import and `__all__`
entry; `NvMxf4f8Strategy` is already exported.

Assisted-by: AI
Signed-off-by: Manuel Saelices <[email protected]>
The docstring claimed the change "can only lower the estimate, never raise
the OOM risk relative to before" and that it "mirrors Qwen3_5MemoryPlanner".
The cap only rules out extra over-reservation; the scaled value is an
uncalibrated heuristic, not a proven activation-peak bound, so it can still
under-reserve vs the true peak. Reword to state that honestly and drop the
inaccurate Qwen comparison.

Assisted-by: AI
Signed-off-by: Manuel Saelices <[email protected]>
Cover the behaviours issue modular#6696 asked for: the scaled estimate drops below
the old flat value for small batch, is capped at the flat value (and the
flat cap still scales with the KV cache dtype), falls back to flat when
model dims are missing, scales by device count, and adds the graph-capture
headroom. Also documents that the batch term is inert below the prefill
floor (the estimate is model-size driven, not batch driven).

Assisted-by: AI
Signed-off-by: Manuel Saelices <[email protected]>
Condense the estimate_activation_memory docstring; keep the one-line
summary and the uncalibrated-heuristic caveat.

Assisted-by: AI
Signed-off-by: Manuel Saelices <[email protected]>
Remove `grouped_quantize_dynamic_block_scaled_fp4` (quant_strategy.py) and
`nvfp4_dequant` (quant_ops.py), unused imports left by the main merge that
ruff F401 rejected.

Assisted-by: AI
Signed-off-by: Manuel Saelices <[email protected]>
Construct the planner via __new__ to skip __init__'s model-config
validation (estimate_activation_memory reads only its arguments, never
self), and drop the unused transformers dep that pydeps rejected.

Assisted-by: AI
Signed-off-by: Manuel Saelices <[email protected]>
BEGIN_PUBLIC
[Kernels] mojo format NVFP4Gemv.execute signature

Drop the spaces around `=` in the `dtype=` keyword arguments of
NVFP4Gemv.execute so the file passes the mblack lint check.
END_PUBLIC

Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>
Claude-Session: https://claude.ai/code/session_01LAekHUXqk6v5nLkA7R8Vxx
BEGIN_PUBLIC
[Pipelines] Load compressed-tensors NVFP4 Gemma 4 checkpoints

The compressed-tensors NVFP4 export (e.g. RedHatAI/gemma-4-31B-it-NVFP4)
stores the same block-scaled E2M1 weights and FP8-E4M3 block scales as the
modelopt NVFP4 format, so it already parses to QuantFormat.NVFP4 and runs the
same kernel path. It only differs in HF config schema and tensor names, so the
config parser alone could not actually load it.

Reconcile the compressed-tensors tensor names in the Gemma 4 language weight
adapter onto the modelopt names the quantized Linear registers:
weight_packed -> weight, weight_global_scale -> weight_scale_2,
input_global_scale -> input_scale. The shared per-block weight_scale already
matches. The per-tensor global scales ship as shape [1] but the Linear
registers them as scalars, so squeeze them to (); modelopt checkpoints arrive
scalar already and are left untouched.

Verified the names/shapes/dtypes against RedHatAI/gemma-4-31B-it-NVFP4 and
nvidia/Gemma-4-31B-IT-NVFP4, and add a no-network unit test covering both the
compressed-tensors remap and the modelopt pass-through.
END_PUBLIC

Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>
Claude-Session: https://claude.ai/code/session_01LAekHUXqk6v5nLkA7R8Vxx
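
The rename and squeeze described in the commit above can be sketched as follows. This is an illustration in NumPy, not the Gemma 4 weight adapter; `remap_compressed_tensors` is a hypothetical helper, and only the three suffix mappings come from the commit message.

```python
import numpy as np

# compressed-tensors suffix -> modelopt suffix (from the commit message).
_SUFFIX_MAP = {
    "weight_packed": "weight",
    "weight_global_scale": "weight_scale_2",
    "input_global_scale": "input_scale",
}
_GLOBAL_SCALES = {"weight_scale_2", "input_scale"}


def remap_compressed_tensors(state: dict) -> dict:
    """Rename compressed-tensors keys and squeeze [1]-shaped global scales to ()."""
    out = {}
    for name, tensor in state.items():
        prefix, _, suffix = name.rpartition(".")
        new_suffix = _SUFFIX_MAP.get(suffix, suffix)
        renamed = suffix in _SUFFIX_MAP
        if renamed and new_suffix in _GLOBAL_SCALES and tensor.shape == (1,):
            tensor = tensor.reshape(())  # the Linear registers these as scalars
        out[f"{prefix}.{new_suffix}" if prefix else new_suffix] = tensor
    return out


state = {
    "mlp.weight_packed": np.zeros((4, 2), dtype=np.uint8),
    "mlp.weight_global_scale": np.array([2.0], dtype=np.float32),
    "mlp.weight_scale": np.ones((4, 1), dtype=np.float32),
}
print(sorted(remap_compressed_tensors(state)))
```

Keys that already use the modelopt names pass through unchanged. A later commit in this series additionally inverts the global scales (reciprocal) to match the modelopt convention; that step is omitted here.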
… into local-gemma4-nvfp4-integration

# Conflicts:
#	max/tests/integration/architectures/gemma4/BUILD.bazel
…s reciprocal scale

Local integration checkpoint (not for upstream as-is):
- attention.py/gemma4.py: QK-norms + flash-attn output use unquantized compute
  dtype instead of self.dtype (uint8 packed). Pre-existing gemma4 NVFP4 bug.
- weight_adapters.py: invert compressed-tensors global scales to modelopt
  convention (reciprocal) + squeeze [1]->(); the modular#6699 numerical fix.
Validated: RedHatAI/gemma-4-31B-it-NVFP4 -> 'The capital of France is Paris.'
at 28 tok/s on L40S.
…prefill/batch)

Dispatch on M in _matmul_float4: nvfp4_gemv for M<=32 (low-batch decode),
dequant to bf16 + tensor-core GEMM for M>32 (prefill / batched decode).
Threshold set via the MAX_NVFP4_GEMM_M_THRESHOLD env var.
Batched throughput (64 concurrent, L40S): 79.6 -> 234 tok/s (2.9x). Single-stream
unchanged (~23 tok/s). Output verified coherent on both paths.
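
A small sketch of the M-based dispatch in this commit. Only the env var name and the default threshold of 32 come from the commit message; `pick_nvfp4_path` and the returned labels are hypothetical, and the real routing lives inside `_matmul_float4`.

```python
import os
from typing import Mapping, Optional

DEFAULT_M_THRESHOLD = 32  # GEMV for M <= 32, per the commit message


def pick_nvfp4_path(m: int, env: Optional[Mapping[str, str]] = None) -> str:
    """Return which pre-Blackwell NVFP4 path handles an activation with M rows."""
    env = os.environ if env is None else env
    threshold = int(env.get("MAX_NVFP4_GEMM_M_THRESHOLD", DEFAULT_M_THRESHOLD))
    # Small M (low-batch decode) is bandwidth-bound: use the fused GEMV.
    # Larger M (prefill, batched decode) goes to dequant + tensor-core GEMM.
    return "nvfp4_gemv" if m <= threshold else "dequant_bf16_gemm"


print(pick_nvfp4_path(1, {}), pick_nvfp4_path(64, {}))
```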
…s batched

Fused kernel nvfp4_gemm.mojo (E2M1 decode in SMEM + synchronous Ada MMA,
no bf16 W in DRAM) + mo.gemm.nvfp4 registration + Python binding + dispatch in
_matmul_float4 (M>32). Optimized 13.9->63.9 TFLOP/s (narrowing the decode to 8 FP4/vec
raises occupancy). End-to-end: 79.6->425 tok/s batched (5.3x). Smoke test passes.
Remaining bottleneck (ncu): occupancy limited by SMEM (2 blocks/SM, 36.86KB/block).
…cks/SM

BEGIN_PUBLIC
[Kernels][GPU] NVFP4 GEMM: fix M<=64 race + single-buffer B for 4 blocks/SM

The M<=64 dequant-GEMM path used stage_w=True, cp.async-staging the packed
weight bytes through SMEM. That path had a latent write-after-read hazard on
the w_smem staging buffer (NS=2 too few stages), making the fused kernel
non-deterministically produce wrong results (~50% of runs failed the smoke
test with ~0.5 abs error). Decoding straight from DRAM (stage_w=False) is both
correct and equally fast here (the W->SMEM->W round-trip was pure SMEM-pipe
overhead at small M), and it drops dynamic SMEM 36.86KB -> 32KB, lifting the
shared-mem occupancy limit from 2 to 3 blocks/SM.

Make the decoded-B SMEM buffer depth a parameter and run the M<=64 path
single-buffered (b_pipeline_stages=1). A barrier fences the in-place decode
from the just-completed MMA reads. This cuts SMEM to 24.58KB, allowing 4 blocks/SM
(ncu: shared-mem block limit 4, max warps 33%). M=64 throughput 63.5 -> 69.5
TFLOP/s; M=256/512 unchanged. Smoke test passes deterministically.
END_PUBLIC

Co-Authored-By: Claude Opus 4.8 <[email protected]>
Claude-Session: https://claude.ai/code/session_01LAekHUXqk6v5nLkA7R8Vxx
msaelices added 21 commits June 20, 2026 23:00
Add TensorCore.load_b_nvfp4: decode packed E2M1 weights directly into the MMA
register fragments (Marlin bit-positioning, 2^14 bias folded into the scale),
and thread a comptime is_nvfp4 flag through multistage_mma_q ->
multistage_qgemm_kernel -> multistage_gemm_q (default False keeps the GGUF int4
path unchanged). New test_nvfp4_skeleton_gemm validates correctness vs an E2M1
dequant reference and benchmarks throughput.

On L40S (M=482, Llama shapes) this hits 136-184 TFLOP/s vs the bespoke
nvfp4_gemm at 66-112 (~2x), approaching the int4 skeleton's 174-235. The
arithmetic E2M1 decode stalled the tensor cores (80-108); the Marlin bit-trick
decode is what unlocks the throughput.
…46 TFLOP/s)

Adds a group=16/BK=16 timing case. Confirms the real NVFP4 grouping needs the
per-16-K-subgroup scale handling at BK=32, not a tiny BK=16 tile (which starves
the pipeline: 46 vs 177 TFLOP/s at group=128/BK=32).
group=128 hits 178 TFLOP/s but group<=64 collapses to 17 (group=16/BK=16: 46),
independent of pipeline stages. q_smem_usage is identical (~40KB) for group=32
and group=128, so it is not the SMEM fallback -- the cause is the mainloop scale
reload firing every tile at fine groups (the GGUF skeleton is tuned for
group=128). NVFP4's native group=16 lands on this cliff.
Re-measuring with a manual wall-clock (warmup + N back-to-back launches + one
synchronize) shows the NVFP4 skeleton runs at ~188 TFLOP/s for group=128/64/32
and ~285 for group=16/BK=16 -- the earlier 17 TFLOP/s 'cliff' was fixed
per-call overhead in ctx.execution_time crushing the fast kernels (ncu already
showed the kernel itself is ~200us cold for both groups). NVFP4's native
group=16 is therefore viable on the skeleton; no scale-staging rework needed.
Parametrize the test/reference by group_size/BK/BLOCK_K. group=32/BK=32 (one
scale group per tile, reload every tile) is bit-exact vs the E2M1 reference at
172 TFLOP/s. BK=16 is invalid (num_k_mmas=1 breaks the mainloop prefetch), so
NVFP4's group=16 needs BK>=32 with per-subgroup scale selection (next).
Add num_scale_sub = ceildiv(BK, group_size) and thread per-subgroup scale
selection through multistage_mma_q (scales_reg_tiles gains a subgroup column;
prologue/reload/prefetch loop over subgroups; load_b_nvfp4 takes a scale_sub
index = (mma_tile_coord_k * MMA_K) // group_size). All backward-compatible:
num_scale_sub == 1 leaves the GGUF int4 path and group==BK NVFP4 bit-exact
(group=32/BK=32 at 178 TFLOP/s, GGUF Q4 test still passes). group=16/BK=32
compiles and no longer NaNs but the subgroup->scale mapping still yields wrong
values -- WIP.
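
The subgroup arithmetic in this commit can be sketched in Python. The two formulas are taken from the message above; the function names are hypothetical and `MMA_K = 16` is an assumption for illustration, not a value stated in the commit.

```python
def ceildiv(a: int, b: int) -> int:
    """Integer ceiling division."""
    return -(-a // b)


def num_scale_sub(bk: int, group_size: int) -> int:
    """Number of scale groups covered by one BK-wide k-tile."""
    return ceildiv(bk, group_size)


def scale_sub_index(mma_tile_coord_k: int, mma_k: int, group_size: int) -> int:
    """Which scale subgroup a given k-MMA inside the tile should read."""
    return (mma_tile_coord_k * mma_k) // group_size


MMA_K = 16  # assumed MMA K extent
BK = 32
for group in (16, 32, 128):
    subs = [scale_sub_index(k, MMA_K, group) for k in range(BK // MMA_K)]
    print(group, num_scale_sub(BK, group), subs)
```

With these assumed values, group=16 gives two subgroups per tile and each k-MMA reads its own scale, while group >= BK collapses to a single subgroup, which is the backward-compatible case.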
Flatten scales_reg_tiles to [num_n_mmas*num_scale_sub, 1] so each subgroup's
scales are contiguous (the strided column broke the vectorized copy), and add a
kernel-vs-kernel comparison (compare_g16_vs_g32) on identical weights + uniform
scales. With the scale read bypassed in load_b_nvfp4, group=16/BK=32 matches
group=32 exactly -> the weight decode and pipeline are correct at group<BK; the
remaining bug is purely the scale register fill/read for num_scale_sub>1.
group=32/BK=32 and the GGUF Q4 test still pass.
The remaining group=16/BK=32 failure is in the scale SMEM staging for
num_scale_sub>1: the every-tile reload reads scale stage 2 as 0.0 at k_tile=2
even though the prologue fills stages 0,1,2. A synchronous scale copy did not
fix it, so the cause is stage indexing/ordering, not async completion. Weights
and pipeline are correct (compare_g16_vs_g32 passes with the scale read
bypassed). group=32/BK=32 and GGUF Q4 remain bit-exact.
BEGIN_PUBLIC
[Kernels][GPU] Fix NVFP4 group<BK scale staging (group=16 now correct)

`num_scales_stages` counted scale *groups*, but the SMEM scale stage packs
`num_scale_sub = ceildiv(BK, group_size)` groups into a single
[num_scale_sub, BN] tile. For NVFP4 (group=16 < BK=32) this overcounted the
ring depth by num_scale_sub, so the mainloop write-ahead
`next_unsafe(num_scales_stages-1)` aliased a still-live stage and a scale
stage was read stale (0.0) mid-pipeline -- producing wrong results for the
real group=16 NVFP4 layout.

Track one packed scale stage per in-flight k-tile instead:
`ceildiv(num_pipeline_stages-1, max(group_size//BK, 1)) + 1`. This is
identical to the previous formula for group >= BK (GGUF Q4 unchanged) and
gives the correct depth for group < BK. Applied in the kernel, the host
allocation, and `q_smem_usage`.

Validated on L40S (sm_89): group=16 and group=32 NVFP4 are bit-exact vs the
E2M1 dequant reference at ~167 TFLOP/s; the GGUF Q4 path is unregressed. The
skeleton test now drives the reference at ref_block_k == group_size (its
supported regime) so the real group=16 layout is covered.
END_PUBLIC
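
For reference, the corrected ring-depth formula from the commit above, evaluated in Python. The expression is copied from the message; the function name and the example stage count of 4 are only illustrative.

```python
def ceildiv(a: int, b: int) -> int:
    """Integer ceiling division."""
    return -(-a // b)


def num_scales_stages(num_pipeline_stages: int, group_size: int, bk: int) -> int:
    """One packed scale stage per in-flight k-tile (formula from the commit message)."""
    return ceildiv(num_pipeline_stages - 1, max(group_size // bk, 1)) + 1


for group in (16, 32, 128):
    print(group, num_scales_stages(4, group, 32))
```

For group < BK the `max(..., 1)` clamps the divisor, so the depth tracks the number of in-flight k-tiles rather than the number of scale groups packed inside each stage.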
… layout

BEGIN_PUBLIC
[Kernels][GPU] Add NVFP4 checkpoint repack to the multistage skeleton layout

`repack_nvfp4_for_sm8x` converts canonical NVFP4 checkpoint tensors --
weights [N, K/2] uint8 (E2M1, 2 codes/byte), FP8-e4m3 block scales [N, K/16],
and a per-tensor float32 global scale -- into the single combined buffer the
`multistage_gemm_q[..., is_nvfp4=True]` skeleton kernel consumes: repacked
4-bit weights in the 64x16 GPTQ tile layout followed by bf16 block scales in
row_major(K/group, N). The 4-bit weight bytes repack identically to int4, so
the weight path reuses pack_Q_tile and the GPTQ tile layout; only the scales
differ (fp8 * global -> bf16 instead of interleaved fp16).

Adds test_nvfp4_repack_e2e.mojo: it builds an INDEPENDENT dense bf16 reference
(canonical E2M1 LUT dequant + dense matmul, no repack) and asserts the full
repack -> skeleton-kernel chain matches it at gemma4-representative shapes
(N,K in {4096x4096, 15360x3840, 3840x15360}, M in {482, 64}), bit-close at
atol=1e-3/rtol=2e-2. Wall-clock microbench shows ~150-166 TFLOP/s at M=482 on
L40S (sm_89) -- ~2.5x the bespoke nvfp4_gemm.
END_PUBLIC
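
A minimal NumPy sketch of the kind of independent reference the e2e test describes (LUT dequant of the canonical layout, then a dense matmul). It is not the Mojo test: the function names are hypothetical, block scales are assumed to be already decoded from FP8 to float32, everything runs in float32 rather than bf16, and the nibble order (even K index in the low nibble) is an assumption.

```python
import numpy as np

# All 16 E2M1 codes: sign bit is bit 3, magnitudes from the low 3 bits.
E2M1_LUT = np.array(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
     -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0],
    dtype=np.float32,
)


def dequant_nvfp4_reference(packed, block_scales, global_scale, group_size=16):
    """packed: [N, K/2] uint8, block_scales: [N, K/group] float32 -> [N, K] float32."""
    lo = packed & 0x0F  # assumption: even K index lives in the low nibble
    hi = packed >> 4
    codes = np.stack([lo, hi], axis=-1).reshape(packed.shape[0], -1)  # [N, K]
    values = E2M1_LUT[codes]
    # Fold the per-tensor global scale into the per-block scales.
    scales = np.repeat(block_scales, group_size, axis=1) * np.float32(global_scale)
    return values * scales


def reference_gemm(a, packed, block_scales, global_scale, group_size=16):
    """C = A @ dequant(W).T"""
    w = dequant_nvfp4_reference(packed, block_scales, global_scale, group_size)
    return a @ w.T


packed = np.array([[0x21, 0xF7]], dtype=np.uint8)   # codes 1, 2, 7, 15
scales = np.array([[2.0]], dtype=np.float32)
print(dequant_nvfp4_reference(packed, scales, 0.5, group_size=4))
```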
BEGIN_PUBLIC
[Kernels][GPU] Register NVFP4 skeleton custom ops (repack + qmatmul)

Expose the validated NVFP4-on-multistage-skeleton path to the graph compiler,
mirroring the GPTQ repack+matmul op pair:

- Host launchers in qmatmul_gpu.mojo: `matmul_gpu_nvfp4` (calls
  `multistage_gemm_q[..., is_nvfp4=True]` with the validated BK=32 config) and
  `gpu_nvfp4_repack` (launches `repack_nvfp4_for_sm8x`).
- Custom ops in builtin_kernels/quantization.mojo: `repack_nvfp4_g16`
  (weights + FP8 block scales + global f32 -> combined skeleton buffer; the
  global scale is read to host once, as the repack is a load-time constant
  transform) and `qmatmul_nvfp4_g16` (C = A @ dequant(W).T), each with a shape
  function.

This is the graph-side wiring only; routing `_matmul_float4` to it is a
separate change gated on serving validation.
END_PUBLIC
BEGIN_PUBLIC
[Kernels][GPU] Route NVFP4 large-M GEMM to the skeleton path (opt-in)

Add Python wrappers `nvfp4_skeleton_repack` / `nvfp4_skeleton_gemm` for the
`repack_nvfp4_g16` / `qmatmul_nvfp4_g16` custom ops, and route the
pre-Blackwell large-M branch of `_matmul_float4` through them when
`MAX_NVFP4_SKELETON_GEMM=1`. The repack folds the FP8 block scale with the
per-tensor global scale into bf16 (constant-folded at load), so the GEMM runs
on MAX's mature multistage quantized mainloop -- bit-exact and ~2.5x the
bespoke nvfp4_gemm at gemma4 shapes on L40S.

Default off: the live NVFP4 dispatch is unchanged until the flag is validated
token-identical in serving.
END_PUBLIC
… test

BEGIN_PUBLIC
[Kernels][GPU] Add engine-level NVFP4 skeleton-vs-bespoke equivalence test

Runs the bespoke `nvfp4_gemm` and the skeleton path (`nvfp4_skeleton_repack`
+ `nvfp4_skeleton_gemm`) in one graph on identical canonical NVFP4 weights and
asserts the outputs match (rtol 2e-2), through the real engine. This validates
the full graph wiring the Mojo-level test cannot: custom-op registration, the
one-time global-scale host read in the repack op, and the Python wrappers.

Verified on L40S (sm_89): 3 passed at shapes (482,4096,4096), (64,4096,4096),
(482,15360,3840).
END_PUBLIC
BEGIN_PUBLIC
[Kernels][GPU] Per-M autotuned configs for the NVFP4 skeleton GEMM

`matmul_gpu_nvfp4` used a single large-M tile (block 128x128, warp_k=1), which
starves small M: batched decode runs at M=batch, and at M=64 the large tile hit
only ~25 TFLOP/s -- worse than the bespoke kernel. Bucket the config by M (all
BK=32), autotuned on L40S at gemma4 shapes:

  M<=64  : block 32x64,  warp 32x64,  stages 3, warp_k 4
  M<=128 : block 64x64,  warp 64x64,  stages 4, warp_k 4
  M<=256 : block 64x128, warp 64x64,  stages 4, warp_k 2
  M>256  : block 128x128, warp 64x64, stages 4, warp_k 1   (unchanged large-M peak)

Small M uses tiny tiles + deep split-K to spread work across SMs (the lever the
bespoke int4 path uses); split-K is numerically correct for NVFP4 since the
scale staging is independent of the K partition. Measured peaks: M=64 ~100-110
(was ~25, a ~4x gain), M=128 ~140-156, M=256 ~158-170, M>256 ~162-168 TFLOP/s.
Engine equivalence test (skeleton vs bespoke) still passes at M=64 and M=482.
END_PUBLIC
BEGIN_PUBLIC
[Kernels][GPU] Add NVFP4 skeleton GEMM config-sweep tuning bench

test_nvfp4_tune.mojo sweeps MatmulConfig candidates (block tile, warp shape,
pipeline stages, split-K depth; all BK=32) for the NVFP4 skeleton GEMM at
gemma4 shapes, gating each on correctness vs the dense E2M1-dequant reference
before reporting wall-clock TFLOP/s. Used to pick the per-M configs in
matmul_gpu_nvfp4. Confirms split-K is numerically correct for NVFP4 and that
small-M throughput is ~4x higher with tiny tiles + deep split-K.
END_PUBLIC
…ench

BEGIN_PUBLIC
[Kernels][GPU] Add NVFP4 real-shape skeleton-vs-bespoke GEMM tuning bench

test_nvfp4_tune_real.mojo sweeps skeleton MatmulConfig candidates against the
bespoke nvfp4_gemm at the actual gemma4-31B decode GEMM shapes (gate/up
N=21504 K=5376, down N=5376 K=21504, q/o/kv) for M in {32,64,128,256}, gating
every TFLOP/s number on correctness vs the dense E2M1-dequant reference.

CAVEAT: this is an L2-WARM throughput microbench (it replays the same weight
many times), so its TFLOP/s overstate cold-weight serving performance. nsys
profiling of real batch-64 decode shows the FP4 proj/FFN GEMM is 79% of the
step and runs at only ~16% of HBM bandwidth on COLD weights -- that cold-weight
bandwidth efficiency (not warm throughput) is the real ~3x gap vs vLLM. A
faithful tuning bench for serving must read cold weights (L2 flush / distinct
weights per call); this harness is kept for warm-throughput comparison only.
END_PUBLIC
…tency)

BEGIN_PUBLIC
[Kernels][GPU] Add cold-weight NVFP4 GEMM microbench (throughput + latency)

test_nvfp4_cold_gemm.mojo measures the FP4 decode GEMM with COLD weights (a
round-robin pool of distinct weight buffers >> L2, validated to evict cache),
in both throughput (one sync at end) and latency (sync per call) modes, and
reports effective HBM bandwidth vs the 864 GB/s peak. This is the faithful tool
for the decode regime, where a prior L2-warm microbench was misleading.

Finding: at M=64 the skeleton GEMM reaches ~65-76% of HBM bandwidth cold (vs the
bespoke nvfp4_gemm at ~21-31%), and latency ~= throughput -- so the FP4 weight
streaming is NOT the serving bottleneck. nsys of the skeleton serve shows the
real cost is the repack op running every decode step (~52% of GPU time) because
the custom op is not constant-folded; the fix is to repack at load time.
END_PUBLIC
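
The "percent of HBM peak" figures can be reproduced with a calculation along these lines. This is a sketch under stated assumptions, not the accounting in test_nvfp4_cold_gemm.mojo: it counts only the packed weights plus bf16 block scales (activation and output traffic are ignored), and the function names are hypothetical. The 864 GB/s peak is the value quoted in the commit.

```python
PEAK_HBM_GBPS = 864.0  # peak quoted in the commit message


def nvfp4_weight_bytes(n: int, k: int, group_size: int = 16, scale_bytes: int = 2) -> int:
    """Packed 4-bit weights plus one scale per group (assumed bf16, 2 bytes)."""
    packed = n * k // 2
    scales = (n * k // group_size) * scale_bytes
    return packed + scales


def hbm_fraction(bytes_moved: float, seconds: float, peak_gbps: float = PEAK_HBM_GBPS) -> float:
    """Effective bandwidth as a fraction of peak."""
    return (bytes_moved / seconds) / (peak_gbps * 1e9)


# Example with a made-up kernel time of 100 microseconds for a 21504 x 5376 weight.
moved = nvfp4_weight_bytes(21504, 5376)
print(f"{hbm_fraction(moved, 100e-6):.2%}")
```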
…ginal weight)

BEGIN_PUBLIC
[Kernels][GPU] Add byte-exact host NVFP4 repack (load-time, frees original weight)

nvfp4_repack_host (max/python/max/nn/_nvfp4_repack_host.py) reproduces the GPU
repack_nvfp4_for_sm8x kernel in pure numpy: canonical NVFP4 (packed uint8
weights + FP8-e4m3 block scales + f32 global) -> the combined skeleton buffer
([repacked 4-bit weights in the 64x16 GPTQ tile layout][bf16 block scales]).

This enables a load-time repack: the combined buffer becomes the sole
registered weight so the original packed weight is never uploaded -- avoiding
the 2x weight memory (OOM) that an in-graph repack hits, since MAX's
weights-registry uploads every declared weight regardless of graph folding.

Strictly validated: test_nvfp4_repack_host_gpu.py runs the actual GPU repack op
via the engine and asserts byte-for-byte equality (np.array_equal) at N/K in
{64x128, 128x256, 256x128, 512x256} -- both the weight bit-shuffle (pack_Q_tile
+ nested tile layout) and the bf16 scale region (incl. the per-64-col N
permutation and RNE f32->bf16 rounding).
END_PUBLIC

BEGIN_PUBLIC
[Kernels][GPU] Wire load-time NVFP4 repack: skeleton GEMM ~1.8-2.9x bespoke in serving

Route the pre-Blackwell NVFP4 matmul through the multistage-skeleton GEMM with
the weight repacked ONCE at load (behind MAX_NVFP4_SKELETON_GEMM=1, default
off). Previously the skeleton repacked the weight in-graph every decode step
(nsys: ~52% of decode) because MAX's weights-registry uploads every declared
weight, so an in-graph repack kept both the original and repacked weight
resident (2x memory / OOM). Now:

- gemma4 weight adapter repacks canonical NVFP4 (weights + FP8 block scales +
  global) into the combined skeleton buffer via the byte-exact nvfp4_repack_host
  and registers it as the SOLE weight; the original packed weight + scales are
  never uploaded.
- Linear declares the combined-shape weight (no scale weights) for the skeleton
  NVFP4 case; _matmul_float4 calls qmatmul_nvfp4_g16 directly (no repack op, no
  scales in the forward graph). StackedLinear/MLP run unfused skeleton GEMMs and
  concat (combined buffers are not concatenable along N).
- head_aware_col_sharding_strategy: single-device fast-path (returns the whole
  weight) so the combined buffer's column layout isn't sliced by head-aware K
  math (no-op for canonical weights).
- nvfp4_repack_host vectorized (the loop version was infeasible at 31B load);
  still byte-exact vs the GPU repack kernel.

Validated on L40S serving gemma-4-31B-it-NVFP4: original weight freed (no setup
OOM), coherent generation matching the bespoke path, and batched throughput
~347 vs ~118 tok/s at conc=16 (~2.9x) and ~815 vs ~447 at conc=64 (~1.8x) --
moving NVFP4 serving from ~33% to ~61% of vLLM. Flag-off path unchanged.
Caveat: the combined weight is slightly larger (inline bf16 scales), so
batch-64/maxlen-2048 is memory-marginal on a 46GB L40S; reduce batch / max-length
/ device-memory-utilization for headroom. Load adds a one-time host repack (~7min).
END_PUBLIC
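
The "slightly larger" caveat above can be put in rough numbers. A back-of-envelope sketch, assuming the combined buffer holds exactly the packed nibbles plus one bf16 scale per 16-element group, with no padding and the global scale folded into the block scales. The function names are hypothetical and the real buffer may differ.

```python
def canonical_bytes(n: int, k: int, group_size: int = 16) -> int:
    """Canonical NVFP4: packed uint8 weights + 1-byte FP8 block scales + f32 global."""
    return n * k // 2 + n * (k // group_size) + 4


def combined_bytes(n: int, k: int, group_size: int = 16) -> int:
    """Combined skeleton buffer: repacked 4-bit weights + 2-byte bf16 block scales."""
    return n * k // 2 + n * (k // group_size) * 2


def overhead(n: int, k: int) -> float:
    """Relative size increase of the combined buffer over the canonical tensors."""
    return combined_bytes(n, k) / canonical_bytes(n, k) - 1.0


# gate/up shape from the tuning bench (N=21504, K=5376)
print(f"{overhead(21504, 5376):.1%}")
```

Under these assumptions the increase is the same for every group=16 matrix, because the scale bytes double while the nibble bytes stay fixed.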

BEGIN_PUBLIC
[Kernels][GPU] Parallelize the load-time NVFP4 host repack (~3.8x faster load)

The per-matrix host repack (nvfp4_repack_host) is the dominant load cost for the
skeleton NVFP4 path (~1s/matrix x hundreds of layers ~= 7min). Parallelize it
across a ThreadPoolExecutor: nvfp4_repack_host is a pure numpy function whose
heavy ops release the GIL. The dlpack/Buffer reads are kept in the main thread
(WeightData access is not thread-safe); only the materialized-numpy compute runs
in workers. Measured ~3.8x on a 16-core host (60.4s -> 15.7s for 56 gemma4-shaped
matrices), cutting the one-time skeleton load from ~7.5min to ~2min. Output is
byte-identical (same repack function); the byte-exact GPU validation is unchanged.
END_PUBLIC
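
The threading split described in the commit above (reads in the main thread, numpy compute in workers) can be sketched as follows. This is an illustration, not the loader's code: `repack_stub` is a hypothetical stand-in for `nvfp4_repack_host`, and the dict-of-arrays input is an assumed shape for the weight data.

```python
from concurrent.futures import ThreadPoolExecutor

import numpy as np


def repack_stub(packed: np.ndarray, scales: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in for nvfp4_repack_host: pure numpy, no shared state."""
    return np.concatenate([packed.ravel(), scales.ravel().view(np.uint8)])


def repack_all(
    weights: dict[str, tuple[np.ndarray, np.ndarray]], workers: int = 8
) -> dict[str, np.ndarray]:
    # Step 1 (main thread): turn every input into a contiguous numpy array.
    # In a real loader this is where the non-thread-safe reads would happen.
    materialized = {
        name: (np.ascontiguousarray(w), np.ascontiguousarray(s))
        for name, (w, s) in weights.items()
    }
    # Step 2 (workers): only the numpy compute runs in parallel. Each task
    # returns its own output array, so no two tasks write the same memory.
    with ThreadPoolExecutor(max_workers=workers) as pool:
        futures = {
            name: pool.submit(repack_stub, w, s)
            for name, (w, s) in materialized.items()
        }
        return {name: fut.result() for name, fut in futures.items()}
```

Because the same function runs either way, comparing the parallel output against a serial loop with `np.array_equal` is a cheap regression check for the byte-identical claim.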
@msaelices

Contributor Author

Adversarial review — NVFP4 skeleton port

Reviewed the skeleton-port net diff only (726edea48f^..HEAD, ~4.5k lines / 18 files), not the stacked ancestors (#6668/#6698/#6699). Two independent adversarial passes (kernel/op side and Python/host-repack/wiring side), each told to assume every claim is wrong until the code proves it.

Verdict: no confirmed data-corruption bug in the net diff. The five load-bearing claims hold up under scrutiny, with a few precondition/wording/guard items worth fixing before un-drafting. Severity tags: 🔴 fix-before-merge · 🟡 concern · ⚪ nit.

Claims verified

  • group<BK scale staging — the num_scales_stages = ceildiv(num_pipeline_stages-1, max(group//BK,1))+1 fix is correct and general for group<BK (not just g16): ring depth == pipeline depth, prefetch writes the correct oldest slot, SMEM sizing uses the same formula (no overrun), and the stage→k-sub-tile pairing is off-by-one-clean across the tile boundary.
  • Marlin E2M1 decode + 2^14 fold through the staged path — bit-exact for all 16 codes incl. {0, ±0.5} and signed zero; the e==0 case rides through as an fp16 denormal; the scale product is computed in f32 before the bf16 cast.
  • flag-off ⇒ int4/GGUF byte-identical — every new behavior is behind comptime if is_nvfp4 / the _NVFP4_USE_SKELETON_GEMM env gate (default off; only exact "1" enables); for GGUF group≥BK the new loops degenerate to the original single-tile copy and the new stages formula equals the old whenever group % BK == 0. (One exception below.)
  • byte-exact host repack — traced _nvfp4_repack_host.py step-by-step against repack_nvfp4_for_sm8x/pack_Q_tile: nibble/bit layout, distribute[row_major(8,4)] thread→element mapping, nested destination layout, scale permutation (Layout((4,8),(16,1))), fp8-e4m3 decode, RNE bf16 round, region order/sizes — all match.
  • original weight freed / no double-register — combined uint8 buffer replaces the weight in the state dict, scale tensors are skipped, Linear registers only the combined buffer with weight_scale=weight_scale_2=None; no dangling ref, no GPU upload of the original. Parallel host repack is race-free (inputs materialized in the main thread; per-task output buffers).
  • no new SMEM/cp.async race — the scale staging reuses the proven int4 discipline (async_copy_wait_group + block-wide barrier() before the cross-warp read); the bespoke-path WAR race does not recur.
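
The scale-staging claim in the first bullet can be checked by evaluating the quoted formula directly. A small sketch: `num_scales_stages` transcribes the expression from the review, while `unclamped` is an assumed form of the pre-fix formula (the same expression without the `max(..., 1)` clamp), not something read from the source.

```python
def ceildiv(a: int, b: int) -> int:
    """Integer ceiling division."""
    return -(-a // b)


def num_scales_stages(num_pipeline_stages: int, group: int, bk: int) -> int:
    """Scale ring depth as quoted in the review; the clamp handles group < BK."""
    return ceildiv(num_pipeline_stages - 1, max(group // bk, 1)) + 1


def unclamped(num_pipeline_stages: int, group: int, bk: int) -> int:
    """Assumed pre-fix form (no clamp); only defined when group >= BK."""
    return ceildiv(num_pipeline_stages - 1, group // bk) + 1


for stages in (2, 3, 4, 5):
    # group < BK: ring depth equals pipeline depth.
    assert num_scales_stages(stages, group=16, bk=32) == stages
    # group a multiple of BK: the clamp is a no-op.
    for group in (32, 64, 128):
        assert num_scales_stages(stages, group, 32) == unclamped(stages, group, 32)
```

With `group < BK`, `group // BK` is zero, so an unclamped divisor would fail outright. The clamp maps that case to one scale slot per pipeline stage.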

Fix before merge

  • 🔴 Unguarded partial-tile precondition (qmatmul_gpu.mojo): no comptime assert for N % BN == 0 / K % BK == 0 (GEMM) nor N % BN == 0 / K % WGROUP == 0 (repack). repack_nvfp4_for_sm8x has no N%BN masking and the weight copy_dram_to_sram is unguarded against K%1024 (only the write is guarded). Holds for gemma4 shapes (N mult of 128/256, K mult of 1024) but will OOB-read on a non-conforming shape. → add the asserts. (Confirm by running the repack on N%128≠0 / K%1024≠0 under --config=asan.)
  • 🟡 "bit-exact" is overstated. Both tests use loose tolerance (atol 1e-4…1e-3) AND the references bake in the same bf16 fold the kernel uses (scale fold at repack → product cast in decode = two bf16 roundings), so they cannot detect a fold-precision divergence from a true f32-accumulated reference. → soften the docstring/PR wording to "matches within bf16 rounding" and add one run against an independent f32/torch reference to bound the real error.
  • 🟡 Ungated default-path change (graph/weight.py:130-135): the new single-device early return in head_aware_col_sharding_strategy affects all single-device callers, not just NVFP4, and is not behind the flag. Equivalent to the prior slice for the standard column shapes (so almost certainly a latent-bug fix), but it rides into a flagged PR. → guard it to the combined-buffer shape or call it out explicitly.
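
To bound the error the "bit-exact" item asks about, an independent reference can be built in a few lines of numpy. A sketch only: the E2M1 value table and round-to-nearest-even bf16 rounding are standard, but the fold order in `fold_error` (scale product rounded to bf16 once, then the value-scale product rounded again) is a reading of the review text, not a trace of the kernel, and all function names are hypothetical.

```python
import numpy as np

# E2M1 magnitudes for the 3 low bits (2 exponent bits, 1 mantissa bit).
E2M1_MAGNITUDES = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=np.float32)


def decode_e2m1(codes: np.ndarray) -> np.ndarray:
    """Decode 4-bit E2M1 codes (0..15) to float32; bit 3 is the sign."""
    codes = np.asarray(codes).astype(np.uint8)
    sign = np.where(codes & 0x8, -1.0, 1.0).astype(np.float32)
    return sign * E2M1_MAGNITUDES[codes & 0x7]


def round_to_bf16(x) -> np.ndarray:
    """Round finite float32 values to the nearest bf16 (ties to even), as float32."""
    bits = np.ascontiguousarray(x, dtype=np.float32).view(np.uint32)
    bias = np.uint32(0x7FFF) + ((bits >> np.uint32(16)) & np.uint32(1))
    return ((bits + bias) & np.uint32(0xFFFF0000)).view(np.float32)


def fold_error(codes, block_scale: float, global_scale: float) -> float:
    """Max relative gap between the two-rounding bf16 path and an f32 reference."""
    vals = decode_e2m1(codes)
    scale_f32 = np.float32(block_scale) * np.float32(global_scale)
    ref = vals * scale_f32                    # f32 throughout
    folded = round_to_bf16(scale_f32)         # rounding 1: scale fold
    out = round_to_bf16(vals * folded)        # rounding 2: product cast
    nz = ref != 0
    return float(np.max(np.abs(out[nz] - ref[nz]) / np.abs(ref[nz])))


print(fold_error(np.arange(16), block_scale=0.8125, global_scale=3.1e-3))
```

Each bf16 rounding contributes at most 2^-8 relative error, so two roundings stay under roughly 2^-7. A reference that shares the fold cannot see this gap; one computed in f32 throughout can.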

Worth a guard (latent fragility, harmless today)

  • 🟡 e4m3fn scale-dtype gate vs MXFP4 e8m0 coercion (weight_adapters.py:89-90 vs :180): the skeleton fuse only matches float8_e4m3fn scales; an earlier pass coerces uint8 scales to e8m0. If an NVFP4 checkpoint ever delivered raw-uint8 block scales, the layer would be skipped by the fuse and mismatch the already-registered combined-buffer shape → loud load-time shape error (not silent). → assert fused-prefixes == skeleton-Linears, and load-test both modelopt and compressed-tensors (RedHatAI) NVFP4 variants under the flag.
  • q_smem_usage sizes the scale SMEM with size_of[a_type] not size_of[scales_type] — correct only because both are bf16 for NVFP4; an fp8-activation variant would under-allocate 2x → SMEM corruption. Use scales_type explicitly.
  • ⚪ K-recovery K = shape[1]//group_bytes*group_size is a group=16-only algebraic coincidence (5/8·16/10==1). Add comptime assert group_size == 16 on the NVFP4 entry (the op is already g16-hardcoded).
  • tensor_core.mojo load_b_nvfp4 docstring says "assumes BK == group_size" — contradicts the group<BK purpose; stale, fix to avoid future misuse.

Note on what the tests actually cover

The active flagged path never calls the in-graph repack_nvfp4_g16 op — production repacks on the host (nvfp4_repack_host); the GPU repack op is exercised only by tests. So production correctness rests entirely on the host replica matching the GEMM's expected layout. Make sure test_nvfp4_repack_e2e diffs host-output vs GPU-repack-output (not host-to-host) on a real gemma4 shape — that one parity test is the single load-bearing safeguard for the whole load-time path.


Adversarial review assisted by AI (two independent passes). Findings are code-trace based; the three items tagged "confirm by running" need a GPU run to settle.


BEGIN_PUBLIC
[Kernels][GPU] Address NVFP4 skeleton review: shape guards + SMEM sizing + docstrings

Adversarial-review follow-ups (no data-corruption bug found; these are guards
and clarity):
- matmul_gpu_nvfp4: comptime-assert N % 128 == 0 and K % 32 == 0 (BK) so a
  non-conforming shape fails at compile time instead of OOB-reading the tiles.
- repack_nvfp4_for_sm8x: comptime-assert group_size == 16, N % 64 == 0 (repack
  tile), K % 128 == 0 (WGROUP) -- the divisibility the repack relies on.
- q_smem_usage: size the scale SMEM by bfloat16 (the kernel's scales_type)
  explicitly, not the activation dtype, so an fp8-activation variant would not
  under-allocate it.
- load_b_nvfp4 docstring: correct the stale "assumes BK == group_size" note to
  document group_size <= BK (the scale_sub subgroup selection).

All asserts hold for gemma4 + the test shapes; the repack e2e and engine
equivalence tests still pass.
END_PUBLIC
@msaelices

Contributor Author

Review follow-ups addressed (pushed in 8f8db2f955)

Thanks for the two-pass adversarial review. No data-corruption bug ⇒ I addressed
the guards/clarity items:

  • 🔴 Partial-tile preconditions — added comptime asserts: matmul_gpu_nvfp4
    now guards N % 128 == 0 and K % 32 == 0 (BK); repack_nvfp4_for_sm8x guards
    group_size == 16, N % 64 == 0 (repack tile) and K % WGROUP(128) == 0. So a
    non-conforming shape fails at compile time instead of OOB-reading. (Note: the
    repack tiles in 128-wide WGROUP chunks, not 1024 — all gemma4 K and the test
    shapes satisfy K % 128 == 0.)
  • q_smem_usage scale dtype — now sizes the scale SMEM by bfloat16 (the
    kernel's scales_type) explicitly, so an fp8-activation variant can't under-allocate.
  • load_b_nvfp4 docstring — corrected the stale "assumes BK == group_size"
    to document group_size <= BK via the scale_sub subgroup selection.
  • 🟡 "bit-exact" wording — softened in the PR description to "matches within bf16
    rounding (loose tol)" for the dequant-reference tests. (The host-vs-GPU repack
    parity test stays genuinely byte-exact (np.array_equal); that test is the
    single load-bearing safeguard you flagged, and it's green.)
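
For callers on the Python side, the divisibility rules listed above could be mirrored as a runtime pre-check before a weight reaches the kernels. A hypothetical helper: the constants are the ones quoted in this comment, but the function itself and its name are not part of the PR.

```python
# (name, dimension, divisor) triples taken from the comptime asserts described above.
GEMM_GUARDS = (("N", "n", 128), ("K", "k", 32))
REPACK_GUARDS = (("N", "n", 64), ("K", "k", 128))


def check_nvfp4_skeleton_shape(n: int, k: int, group_size: int = 16) -> None:
    """Raise ValueError if (N, K) violates the divisibility the kernels assume."""
    dims = {"n": n, "k": k}
    problems = []
    if group_size != 16:
        problems.append(f"group_size must be 16, got {group_size}")
    for stage, guards in (("GEMM", GEMM_GUARDS), ("repack", REPACK_GUARDS)):
        for label, key, divisor in guards:
            if dims[key] % divisor != 0:
                problems.append(f"{stage}: {label}={dims[key]} % {divisor} != 0")
    if problems:
        raise ValueError("; ".join(problems))


# The gate/up and down shapes from the tuning bench both pass.
check_nvfp4_skeleton_shape(21504, 5376)
check_nvfp4_skeleton_shape(5376, 21504)
```

A runtime check like this only gives an earlier, friendlier error message; the comptime asserts remain the actual guard against out-of-bounds reads.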

Acknowledged / notes

  • 🟡 graph/weight.py single-device early return — kept, with an explicit
    comment: at num_devices == 1 the head-aware math already returns the whole
    weight to device 0, so the early return is behavior-preserving for standard
    column shapes and only additionally handles the combined-buffer layout. Called
    out in-code.
  • 🟡 e4m3fn fuse gate vs MXFP4 e8m0 — the failure mode is a loud load-time
    shape error, not silent corruption; the byte-exact host-vs-GPU repack parity test
    guards the layout. A fused-prefixes == skeleton-Linears assert is reasonable
    future hardening.

Re-validated after syncing main (toolchain bump) + these fixes

  • Skeleton GEMM correctness (g16/g32): ~167 TFLOP/s, bit-close to the E2M1 reference.
  • Byte-exact host repack vs the GPU kernel: pass (all shapes).
  • Engine equivalence (skeleton vs bespoke): 3 passed.
  • 31B serving @conc64: ~647 tok/s (≈ pre-merge; ~1.5x bespoke).

Assisted by: AI.

BEGIN_PUBLIC
[Kernels][GPU] Fix CI: load_b_nvfp4 Args docstring + apply formatter

mojodoc failed because TensorCore.load_b_nvfp4 takes arguments but its docstring
had no Args section; add it. Also apply the repo formatter (mblack/ruff) to the
NVFP4 skeleton sources and tests that were not format-clean (lint failure).
END_PUBLIC

Assisted-by: AI

BEGIN_PUBLIC
[Kernels][GPU] Fix CI round 2: ruff F841, mypy weight_scale, lazy repack import

- Remove unused locals (BN/BK/BK_groups) in the host repack (ruff F841).
- quantized_matmul: widen weight_scale to TensorValue | None (the NVFP4 skeleton
  path folds it away) and assert non-None in the MXFP4/MXFP8/FP8 branches that
  require it (mypy arg-type).
- Import nvfp4_repack_host lazily inside the fuse function so the numpy host-repack
  module stays off the `max ... --help` startup path (CLI help perf test).
END_PUBLIC

Assisted-by: AI
…ge-skeleton

Sync the NVFP4 skeleton branch with 38 upstream commits. The only conflict
is the `.kernels` import block in nn/quant_ops.py: upstream adds
`_fused_qkv_ragged_matmul_scaled_mxfp8` (fused MXFP8 QKV) next to this
branch's `_is_pre_sm100_nvidia_gpu` import. Resolved by keeping both.

Everything else auto-merges (moe_fp8.py, pipelines/weights/quant.py,
nn/kernels.py, test/gpu/linalg/BUILD.bazel). No semantic drift: the
QuantStrategy protocol widening predates this branch, so Nvfp4DequantStrategy
already conforms.

Assisted-by: AI