Thanks to visit codestin.com
Credit goes to github.com

Skip to content

[Kernels][Pipelines] NVFP4 fallback for pre-Blackwell NVIDIA GPUs via fused dequant-GEMV#6668

Open
msaelices wants to merge 11 commits into
modular:mainfrom
msaelices:nvfp4-pre-blackwell-fallback
Open

[Kernels][Pipelines] NVFP4 fallback for pre-Blackwell NVIDIA GPUs via fused dequant-GEMV#6668
msaelices wants to merge 11 commits into
modular:mainfrom
msaelices:nvfp4-pre-blackwell-fallback

Conversation

@msaelices

@msaelices msaelices commented Jun 11, 2026

Copy link
Copy Markdown
Contributor

Summary

Serve NVFP4 checkpoints on GPUs without native FP4 matmul (pre-Blackwell NVIDIA). Fixes #6667.

The fallback is a Marlin-style fused dequant-GEMV: packed E2M1 is decoded in registers inside the matmul kernel, so no BF16 weight copy is ever materialized and per-token DRAM traffic is the packed bytes (~0.5 B/element vs ~4.5 for dequantize-then-matmul).

  • nvfp4_gemv / mo.gemv.nvfp4: fused kernel behind _matmul_float4, which covers every NVFP4 linear in-tree (dense/MLP and the StackedLinear QKV projections). The unreachable NVFP4 branch of quantized_fused_qkv_matmul now raises a clear error instead of carrying dead fallback code.
  • dequant_nvfp4 / mo.dequant.nvfp4: NVFP4 generalization of the MXFP4 dequant kernel, kept for the MoE Nvfp4DequantStrategy.
  • cast_uint_to_fp4e2m1 rewritten branchless: the old LUT's runtime SIMD indexing spilled to local memory on every element (35x slowdown).
  • Active only off SM100+ (Blackwell unchanged); clear errors for pre-interleaved (TCGEN 5D) scales and expert-parallel NVFP4.

Performance (NVIDIA A10G, sm_86, 24 GB; gemma-4 12B NVFP4, batch 1)

Stage Warm decode
Initial dequantize-then-matmul fallback 0.38 tok/s
+ branchless E2M1 decode (LUT spill fix) 0.94 tok/s
+ fused GEMV (all NVFP4 linears, incl. QKV via StackedLinear) 31.5 tok/s (83x)
vLLM nightly Marlin, same checkpoint and GPU, eager ~16 tok/s

Measured with device graph capture on (serving config, not part of this PR); first request after startup pays ~17 s of capture warmup. A 58-token prefill plus 150-token generation sustains 29.6 tok/s. 31.7 ms/token is approximately the bandwidth bound for reading the packed weights, i.e. the fused kernel leaves little on the table at batch 1. Model output is token-identical to the dequantize-then-matmul path at temperature 0.

Verification

CPU-reference smoke tests for both kernels (test_nvfp4_gemv_smoke, test_nvfp4_dequant_smoke); nvidia/Llama-3.1-8B-Instruct-NVFP4 and berkerdooo/gemma-4-12B-it-NVFP4 (with #6670) serve E2E with coherent output.

Out of scope (tracked in #6667): compressed-tensors NVFP4 parsing, Gemma 4 activation-memory headroom, tensor-core GEMM tile for large-batch prefill.

Assisted-by: AI

@msaelices

Copy link
Copy Markdown
Contributor Author

Perf follow-up: profiling the fallback on an A10G found the software LUT was the bottleneck — cast_uint_to_fp4e2m1 indexed a SIMD[float32, 16] with a runtime index, which spills the vector to GPU local memory on every element (kernel ran at ~12 GB/s effective, fully local-memory bound). Replaced with a branchless exp/mantissa decode in 4135fb1: the dequant kernel now hits 422 GB/s (~35x) on the same device, verified bit-exact by the smoke test. End-to-end on the runtime-dequant serving experiment this was 2.5x (0.38 -> 0.94 tok/s on a 12B; remaining gap vs a fused Marlin-style dequant-GEMM is per-op dispatch overhead and the 3x memory traffic of dequantize-then-matmul).

@msaelices

Copy link
Copy Markdown
Contributor Author

Implemented the Marlin-style fused dequant-GEMV (4831155): one warp per output column decodes packed E2M1 in registers and dot-products against BF16 activations — the dequantized weight is never materialized, so per-token DRAM traffic drops ~9x and the per-forward transient buffers disappear. _matmul_float4's pre-Blackwell path now routes through it (QKV keeps dequantize-then-matmul for now since fused_qkv_ragged_matmul owns the KV-cache write).

A10G, gemma-4 12B NVFP4, bs=1: per-linear 2.7 ms -> 0.133 ms through the graph executor (20x); end-to-end decode 0.38 -> 4.98 tok/s cumulative with the E2M1 LUT fix (vLLM nightly's Marlin on the same checkpoint: ~16 tok/s eager — remaining gap is the QKV path, M_TILE prefill scaling, and per-op dispatch). Correctness: scalar CPU reference in test_nvfp4_gemv_smoke, plus token-identical model output vs the dequant path.

@msaelices

Copy link
Copy Markdown
Contributor Author

QKV now also routes through the fused GEMV (split + store_k/v_cache_ragged, with K-equals-V handling for Gemma 4 full-attention layers) — no BF16 copy of any quantized weight is materialized anywhere on the fallback anymore.

Final numbers on the A10G (gemma-4 12B NVFP4, bs=1, warm decode): 31.5 tok/s, vs 0.38 at the start of this work (83x) and ~16 tok/s for vLLM nightly's Marlin on the same checkpoint/GPU. 58-token prefill + 150-token generation sustains 29.6 tok/s. Output stays token-identical to the dequant path. Earlier per-request numbers in this thread (~5 tok/s) were first-request measurements that included device graph capture warmup.

@msaelices msaelices changed the title [Kernels][Pipelines] NVFP4 dequant fallback for pre-Blackwell NVIDIA GPUs [Kernels][Pipelines] NVFP4 fallback for pre-Blackwell NVIDIA GPUs via fused dequant-GEMV Jun 12, 2026
@msaelices msaelices marked this pull request as ready for review June 12, 2026 14:56
@msaelices msaelices requested review from a team as code owners June 12, 2026 14:56
Copilot AI review requested due to automatic review settings June 12, 2026 14:56

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Note

Copilot was unable to run its full agentic suite in this review.

Adds a pre-Blackwell (pre-SM100) NVFP4 execution path by introducing fused NVFP4 dequant-GEMV and dequant kernels, then wiring them into dense and MoE inference to enable serving NVFP4 checkpoints on older NVIDIA GPUs.

Changes:

  • Introduces NVFP4 fused dequant-GEMV (mo.gemv.nvfp4) and NVFP4 dequant (mo.dequant.nvfp4) kernels + graph-compiler registrations.
  • Routes NVFP4 matmul / fused QKV paths to the new fallback on pre-SM100 NVIDIA GPUs, including explicit KV-cache stores.
  • Adds an MoE NVFP4 “dequant-to-BF16 then BF16 grouped matmul” fallback strategy and GPU smoke tests for the new kernels.

Reviewed changes

Copilot reviewed 12 out of 12 changed files in this pull request and generated 10 comments.

Show a summary per file
File Description
max/python/max/nn/quant_ops.py Adds pre-SM100 NVIDIA NVFP4 fallback paths for float4 matmul and fused QKV, including KV-cache writes.
max/python/max/nn/moe/quant_strategy.py Adds Nvfp4DequantStrategy that dequantizes NVFP4 expert weights to BF16 and uses ragged grouped matmul.
max/python/max/nn/moe/moe_fp8.py Selects the NVFP4 dequant fallback strategy on pre-SM100 NVIDIA; blocks EP NVFP4 on those GPUs.
max/python/max/nn/moe/init.py Exports the new NVFP4 dequant strategy.
max/python/max/nn/kernels.py Adds Python wrappers for nvfp4_gemv, nvfp4_dequant, and a pre-SM100 NVIDIA detection helper.
max/kernels/src/linalg/nvfp4_gemv.mojo Implements the fused NVFP4 dequant-GEMV kernel.
max/kernels/src/linalg/mxfp4_dequant.mojo Generalizes FP4 dequant into a shared kernel and adds NVFP4 dequant entrypoint.
max/kernels/src/linalg/fp4_utils.mojo Reworks FP4 E2M1 decode to avoid SIMD spills and dramatically speed up dequant.
max/kernels/src/graph_compiler/builtin_kernels/quantization.mojo Registers mo.dequant.nvfp4 and mo.gemv.nvfp4 builtins for graph execution.
max/kernels/test/gpu/linalg/test_nvfp4_gemv_smoke.mojo Adds correctness + quick throughput smoke test for NVFP4 dequant-GEMV.
max/kernels/test/gpu/linalg/test_nvfp4_dequant_smoke.mojo Adds smoke test for NVFP4 dequant with multiple dtypes/shapes.
max/kernels/test/gpu/linalg/BUILD.bazel Wires the new GPU tests into Bazel with platform constraints.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +5365 to +5376
def _is_pre_sm100_nvidia_gpu() -> bool:
"""Checks for an NVIDIA GPU without native FP4 matmul (pre-Blackwell).

Deliberately False on AMD and Apple GPUs: the NVFP4 dequant fallback is
only validated on NVIDIA, so other vendors keep their previous behavior
(erroring at the Blackwell-only kernels).
"""
try:
arch = accelerator_architecture_name()
except Exception:
return False
return arch.startswith("sm_") and not arch.startswith("sm_10")
Comment on lines +5279 to +5284
def nvfp4_dequant(
packed_weights: TensorValue,
scales: TensorValue,
out_type: DType = DType.bfloat16,
) -> TensorValue:
"""Dequantizes NVFP4 packed weights to BF16 or FP8 on GPU.
Comment thread max/python/max/nn/quant_ops.py Outdated
Comment on lines +76 to +81
def _dequant_weight_nvfp4(
weight: TensorValue,
weight_scale: TensorValue,
weight_scale_2: TensorValue,
scales_pre_interleaved: bool,
) -> TensorValue:
Comment thread max/python/max/nn/quant_ops.py Outdated
Comment on lines +89 to +98
if scales_pre_interleaved:
raise ValueError(
"NVFP4 checkpoints with pre-interleaved (TCGEN 5D) scales are"
" not supported on pre-Blackwell GPUs: the dequant fallback"
" needs the flat [N, K//16] scale layout."
)
scales_f32 = weight_scale.to(weight.device).cast(
DType.float32
) * weight_scale_2.to(weight.device)
return nvfp4_dequant(weight, scales_f32, out_type=DType.bfloat16)
Returns:
The output tensor in bf16.
"""
if _is_pre_sm100_nvidia_gpu():
Comment on lines +135 to +142
if scales_pre_interleaved:
raise ValueError(
"NVFP4 checkpoints with pre-interleaved (TCGEN 5D) scales"
" are not supported on pre-Blackwell GPUs"
)
scales_f32 = weight_scale.to(weight.device).cast(
DType.float32
) * weight_scale_2.to(weight.device)
Comment on lines +433 to +449
# Pre-multiply the per-tensor scale into the block scales:
# [E, 1, 1] * [E, N, K//16] -> float32 [E, N, K//16].
scales_f32 = weight_scales.cast(DType.float32) * ops.reshape(
expert_scales.to(weight_scales.device).cast(DType.float32),
[expert_scales.shape[0], 1, 1],
)
# Activations are BF16 on this path by construction; self.dtype is
# the packed storage dtype (uint8) and must not be used here.
dequanted = nvfp4_dequant(weight, scales_f32, out_type=DType.bfloat16)

return grouped_matmul_ragged(
hidden,
dequanted,
expert_start,
expert_ids,
usage_stats,
)
Comment on lines +430 to +431
if usage_stats.device.is_gpu():
usage_stats = usage_stats.to(DeviceRef.CPU())
# Correctness + quick throughput check for the fused NVFP4 dequant-GEMV.
# ===----------------------------------------------------------------------=== #

from std.math import ceildiv
from std.time import perf_counter_ns

from std.gpu.host import DeviceContext
from layout import Coord, TileTensor, row_major
msaelices added 11 commits June 13, 2026 22:39
…GPUs

Serve NVFP4 checkpoints on GPUs without native FP4 matmul support
(sm_8x/sm_90) by dequantizing weights to BF16 with a software-LUT kernel
and routing through the regular BF16 matmuls, mirroring the existing
MXFP4-on-H100 dequant approach. Fixes modular#6667.

Kernels:
- Generalize the MXFP4 dequant kernel into a format-agnostic _dequant_fp4
  (block-scale dtype and SF_VECTOR_SIZE were already parameters) and add a
  dequant_nvfp4 entry point: E4M3 or pre-multiplied float32 scales,
  SF_VECTOR_SIZE=16. The E2M1 decode is a software LUT, so it runs on any
  GPU. Register as mo.dequant.nvfp4.
- Smoke test (CPU reference vs GPU) covering both scale dtypes, several
  shapes and an FP8 output; unlisted in _EXTRA_CONSTRAINTS so it runs on
  any GPU.

Python (active only when not on SM100+):
- kernels.nvfp4_dequant() wrapper (rank 2/3, flat [.., K//16] scales).
- quant_ops: _matmul_float4 and the NVFP4 fused-QKV case dequantize the
  weight (weight_scale_2 pre-multiplied into float32 scales) and use the
  unquantized matmul / fused_qkv_ragged_matmul. Pre-interleaved (TCGEN 5D)
  scale checkpoints are rejected with a clear error.
- MoE: new Nvfp4DequantStrategy — activations stay BF16 (no dynamic
  quantization), expert weights dequantized per matmul, BF16
  grouped_matmul_ragged. Selected in MoEQuantized._strategy(); expert
  epilogue scales carry only weight_scale_2 (no activation input-scale
  factor, since activations are never quantized on this path).

Blackwell behavior is unchanged: every fallback is gated on
not _is_sm10x_gpu().
END_PUBLIC

Assisted-by: AI
Signed-off-by: Manuel Saelices <[email protected]>
Two fixes found while validating end to end on an A10G (sm_86):

- _dequant_weight_nvfp4: move weight_scale_2 (a CPU rank-0 weight) to the
  weight's device before multiplying into the block scales ("Input values
  must be on the same device").
- Nvfp4DequantStrategy.grouped_matmul: dequantize to bfloat16 explicitly;
  self.dtype is the packed storage dtype (uint8) on this path and tripped
  the kernel's output-dtype constraint at graph compile time.

Validated: nvidia/Llama-3.1-8B-Instruct-NVFP4 (dense, modelopt) serves on
the A10G and generates coherent text. nvidia/Gemma-4-26B-A4B-NVFP4 (MoE)
compiles and loads weights through the fallback, but 18.8 GiB of packed
weights exhaust the 24 GiB card during setup (allocator reports 18.73 GiB
in use, largest free block 359 MB) -- a checkpoint-size limit, not a
fallback defect.

Assisted-by: AI
Signed-off-by: Manuel Saelices <[email protected]>
Assisted-by: AI
Signed-off-by: Manuel Saelices <[email protected]>
- moe/__init__.py: restore isort order in the import block and __all__
  (ruff I001/RUF022 with the CI-pinned ruff).
- moe_fp8.py: raise an explicit NotImplementedError when selecting the
  NVFP4 dequant strategy with an active expert-parallel batch manager --
  EP dispatch quantizes activations to FP4, so the fallback would
  otherwise die later in grouped-matmul shape validation with a
  confusing error.
- Smoke test: vary block scales per index so wrong scale indexing cannot
  pass unnoticed (previously uniform scales masked that class of bug);
  re-run on the A10G, all cases still PASS with max_err=0.0.
- BUILD.bazel: constrain the smoke test to nvidia_gpu|amd_gpu (mirrors
  the mxfp4 smoke test) instead of enrolling untested Apple CI.

Assisted-by: AI
Signed-off-by: Manuel Saelices <[email protected]>
BEGIN_PUBLIC
[Kernels][GPU] Branchless E2M1 decode in cast_uint_to_fp4e2m1

The software LUT indexed a SIMD[float32, 16] table with a runtime
index, which spills the vector to GPU local memory on every element.
That made the FP4 dequant kernels local-memory bound: 12 GB/s effective
on an A10G (sm=100%, dram=99% from spill traffic). Replacing the lookup
with branchless bit arithmetic (exp/mantissa decode + select) makes the
kernel memory-bound as intended: 422 GB/s on the same device, a 35x
kernel speedup, 2.5x end-to-end on the runtime-dequant serving path.

Verified with test_nvfp4_dequant_smoke (bit-exact vs CPU reference).
END_PUBLIC

Assisted-by: AI
Signed-off-by: Manuel Saelices <[email protected]>
BEGIN_PUBLIC
[Kernels][GPU] Fused NVFP4 dequant-GEMV for the pre-Blackwell fallback

Marlin-style fused kernel: one warp per output column decodes packed
E2M1 in registers (branchless, no LUT) and dot-products against the
BF16 activations — the dequantized weight is never materialized in
global memory. Per-token DRAM traffic drops from ~4.5 B/element
(dequantize-then-matmul: bf16 write + re-read + packed read) to the
packed ~0.5 B/element, and no per-forward transient buffers exist.
Activation rows are tiled (M_TILE=4) so a P-token prefill re-reads the
packed weight ceil(P/4) times; decode (M=1) is the optimized case.

_matmul_float4's pre-Blackwell path (o_proj and dense MLP projections)
now routes through the fused op. The QKV path keeps dequantize-then-
matmul for now (fused_qkv_ragged_matmul owns the KV-cache write).

Measured on an NVIDIA A10G (gemma-4 12B NVFP4, bs=1):
- kernel: 0.126 ms for [15360x3840] at M=1 (350 GB/s of packed reads)
- per-linear through the graph executor: 2.7 ms -> 0.133 ms (20x)
- end-to-end decode: 0.94 -> 4.98 tok/s (5.3x; 13x vs the original
  fallback). Output is token-identical to the dequant path.

Correctness verified against a scalar CPU reference for M={1,3,5},
including non-multiple-of-warp N (test_nvfp4_gemv_smoke).
END_PUBLIC

Assisted-by: AI
Signed-off-by: Manuel Saelices <[email protected]>
BEGIN_PUBLIC
[Kernels] Use CamelCase names for the new NVFP4 op structs

Rename Struct_dequant_nvfp4 -> NVFP4Dequant and Struct_gemv_nvfp4 ->
NVFP4Gemv. The Struct_ prefix is the legacy machine-generated style;
the hand-written registrations in quantization.mojo use CamelCase.
Pre-existing Struct_* registrations are left untouched.
END_PUBLIC

Assisted-by: AI
Signed-off-by: Manuel Saelices <[email protected]>
BEGIN_PUBLIC
[Kernels] Route the pre-Blackwell QKV projection through the fused GEMV

The QKV path was the last consumer of dequantize-then-matmul on the
pre-Blackwell NVFP4 fallback. Compute QKV with the fused dequant-GEMV,
then split and store K/V into the paged cache via store_k_cache_ragged
and store_v_cache_ragged (fused_qkv_ragged_matmul owns that write on
the unquantized path, which is why this stays decomposed). K-equals-V
layers (Gemma 4 full attention packs only Q and K rows in wqkv) store
the K projection into both caches.

With this no BF16 copy of any quantized weight is ever materialized.
End-to-end on an NVIDIA A10G (gemma-4 12B NVFP4, bs=1), warm decode:
31.5 tok/s vs 0.38 at the start of the fallback work (83x), with
token-identical output. For reference, vLLM nightly's Marlin path
serves the same checkpoint at ~16 tok/s eager on the same GPU.
END_PUBLIC

Assisted-by: AI
Signed-off-by: Manuel Saelices <[email protected]>
… fallback

BEGIN_PUBLIC
[Kernels][Pipelines] Address adversarial-review findings on the fused fallback

- Gate the NVFP4 fallback on pre-SM100 *NVIDIA* explicitly
  (_is_pre_sm100_nvidia_gpu): _is_sm10x_gpu() is False on AMD/Apple,
  which silently routed those vendors onto an unvalidated dequant path;
  they now keep their previous behavior.
- moe_fp8: fix a comment describing the Blackwell scale-epilogue
  mechanism on the fallback branch (the per-tensor scale is folded into
  the dequant block scales; the BF16 grouped matmul has no epilogue).
- Soften the dequant docstrings: raw float8_e4m3fn block scales are
  accepted but only exercised by tests today (the modelopt parser
  always provides a per-tensor scale).
END_PUBLIC

Assisted-by: AI
Signed-off-by: Manuel Saelices <[email protected]>
BEGIN_PUBLIC
[Kernels] Address cycle-2 adversarial-review findings on the fused GEMV

- nvfp4_gemv: all threads now enter the PDL region (the early return for
  tail-block threads moved inside it) and the dead pdl_level parameter
  is gone (it was accepted but never forwarded).
- quant_ops: drop _dequant_weight_nvfp4 (no callers since the fused
  reroute), refresh _matmul_float4's docstring, and validate the wqkv
  layout explicitly instead of silently aliasing K as V for anything
  that is not exactly Q+K+V or Q+K rows.
- Stale software-LUT wording updated after the branchless decode.
- mojo format on the two new kernel files; smoke test now covers >1
  K-chunk per lane (K=2048) and the production K=3840.
END_PUBLIC

Assisted-by: AI
Signed-off-by: Manuel Saelices <[email protected]>
BEGIN_PUBLIC
[Kernels] Address cycle-3 adversarial-review findings

- Replace the NVFP4 branch of quantized_fused_qkv_matmul with a clear
  error: no in-tree model reaches it (callers assert on input_scale
  first, and NVFP4 QKV routes through StackedLinear/_matmul_float4,
  which is already fused). _dequant_weight_nvfp4 went unused with it.
- Add the Apache license header to test_nvfp4_gemv_smoke.mojo.
END_PUBLIC

Assisted-by: AI
Signed-off-by: Manuel Saelices <[email protected]>
@msaelices msaelices force-pushed the nvfp4-pre-blackwell-fallback branch from 1b7eb1d to c781b87 Compare June 13, 2026 22:40
modularbot pushed a commit that referenced this pull request Jun 16, 2026
…4 12B) and serve them text-only (#88920)

[External] [Pipelines] Load public gemma4_unified checkpoints (Gemma 4
12B) and serve them text-only

## Summary

Makes the public `gemma4_unified` checkpoints (the Gemma 4 12B line:
`google/gemma-4-12b-it` and derived quants) loadable by MAX, serving
them **text-only**. Fixes the first two blockers of #6669
and implements the text-only path proposed there.

- Register an `AutoConfig` shim for `model_type: "gemma4_unified"`
(transformers < 5.5 path) and `gemma4_unified_arch`, a
`dataclasses.replace` alias of `gemma4_arch` under the public
`Gemma4UnifiedForConditionalGeneration` name — these checkpoints carry
the regular Gemma 4 language layout with **no bundled MTP draft
weights**, so they belong to `gemma4`, not `unified_mtp_gemma4`.
- Text-only serving: the unified checkpoints use a lightweight
`vision_embedder` with a different schema than `Gemma4VisionModel`;
detect it, set `vision_config=None`, skip building/loading the vision
graph, and return 0 vision-cache bytes.
- `finalize` now passes the real `model.language_model.` prefixes to
`parse_quant_config`, so per-layer quantized/ignored classification
stops looking up `model.layers.*` keys that never exist (ignore-listed
BF16 attention in modelopt 12B quants was never recognized as ignored).
- Honor `quant_config.attn_quantized_layers` per layer instead of
assuming attention is always BF16 under NVFP4 (true for nvidia
first-party checkpoints, false for the community 12B quants), and drop
modelopt `k_scale`/`v_scale` KV-cache scales in the weight adapter.

Self-review (adversarial pass) caught and fixed an inference-path crash
before any human review: `prepare_initial_token_inputs` read
`vision_config.pooling_kernel_size` unconditionally, so text-only
serving would have failed on the first prefill; image requests now get a
clear rejection and the unified alias declares TEXT-only modalities.

## Verification (NVIDIA A10G 24 GB; bazel-built serve)

Each fix was driven by serving real checkpoints; the chain `config →
arch resolution → graph build → weight load` now completes for
`google/gemma-4-12b-it`-family checkpoints where each stage failed
before (`gemma4_unified` not recognized → no matching architecture →
vision `AttributeError` → o_proj shape mismatch → unexpected
`k_scale`/`v_scale`).

**Full text-generation E2E verified on the A10G**
(`berkerdooo/gemma-4-12B-it-NVFP4`, attention quantized): correct,
coherent EN/ES completions via `/v1/chat/completions`. With
#6668's fused dequant-GEMV (no BF16 weight
materialization), the 12B sustains **31.5 tok/s warm decode at batch 1
on the 24 GB card** (measured on a branch combining this PR with
#6668) — ~2x vLLM nightly's Marlin path on the same
checkpoint and GPU — with token-identical output across the fallback's
kernel revisions. Note: on a 24 GB card the stock activation-memory
estimate (~15 GiB headroom, sized for the vision configs) still blocks
startup — the measurement lowered it locally; right-sizing that estimate
for text-only unified configs is intentionally left out of this PR
pending maintainer guidance.

Related: #6669 (blocker 3 — implementing the unified
`vision_embedder` — remains open), #6668,
#6665.

Assisted-by: AI
ORIGINAL_AUTHOR=Manuel Saelices <[email protected]>
ORIGINAL_USER=@msaelices

---------

Co-authored-by: Manuel Saelices <[email protected]>
Co-authored-by: Kathy Wu <[email protected]>
Closes #6670
MODULAR_ORIG_COMMIT_REV_ID: baef6c1c5c38e4923a55ece1caf8d6d2e42f9a98
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Feature Request] NVFP4 dequant fallback for pre-Blackwell NVIDIA GPUs (serve Gemma 4 NVFP4 on Ampere)

2 participants