[Kernels][GPU] Route block_scaled_matmul (NVFP4) to naive kernel on consumer Blackwell (sm_121)#6596

Open
lightofbaldr wants to merge 6 commits into
modular:mainfrom
lightofbaldr:fix/nvfp4-dispatch-sm121

@lightofbaldr

Contributor

Part of #6570 (sm_120 / sm_121 consumer Blackwell support epic).
Stacked on #6593 (sm_121a target) and #6594 (NVFP4 enablement). Until those merge, the diff here also includes their changes.

What

block_scaled_matmul hard-asserted SM100 (B200) only, so NVFP4 was unreachable from the inference dispatch on consumer Blackwell. This routes it:

  • Guard the SM100 warp-specialized body in comptime if ctx.default_device_info.compute == B200.compute.
  • Add elif ctx.default_device_info.compute >= 12.0 (sm_120 / sm_121, compute 12.x) → arch-agnostic CUDA-core naive_block_scaled_matmul.
  • else → unsupported-arch assert.

The device check uses ctx.default_device_info (the actual GPU), not a global arch predicate — block_scaled_matmul's target param defaults to "cpu", so global predicates would read CPU.
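The routing described above can be modelled in a few lines. This is a minimal Python sketch of the branch structure only, not the Mojo code: in the PR the selection happens at compile time via `comptime if`, whereas this model decides at run time. The function name `select_block_scaled_matmul_path`, the returned labels, and the constant `B200_COMPUTE = 10.0` (the compute capability of SM100) are assumptions introduced for illustration.

```python
# Hypothetical model of the dispatch; names and labels are illustrative only.
B200_COMPUTE = 10.0  # assumption: B200 (SM100) reports compute capability 10.0


def select_block_scaled_matmul_path(device_compute: float) -> str:
    """Return which kernel family the dispatch would pick for a device."""
    if device_compute == B200_COMPUTE:
        # Existing SM100 warp-specialized body, unchanged by the PR.
        return "sm100_warp_specialized"
    elif device_compute >= 12.0:
        # sm_120 / sm_121 (compute 12.x): arch-agnostic CUDA-core kernel.
        return "naive_block_scaled_matmul"
    else:
        # Mirrors the unsupported-arch assert in the final branch.
        raise ValueError(f"unsupported compute capability: {device_compute}")
```

Under this model a GB10 (compute 12.1) selects the naive path, while a compute 9.0 device falls through to the error branch. The key point from the description is that the value tested is the actual device's compute capability, not a predicate derived from the `target` parameter.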

To accept the dispatch's immutable inputs, the naive kernel's read-only operands (a, b, a_scales, b_scales) and the get_scale_factor reader move from MutAnyOrigin to ImmutAnyOrigin; only the output c stays mutable. Mutable callers are unaffected (a mutable value can be borrowed as immutable).

Diff note

Most of the fp4_quantization.mojo diff is whitespace — the existing SM100 body is unchanged but re-indented one level under the new comptime if.

Test

test_block_scaled_fp4_dispatch_sm121.mojo drives block_scaled_matmul (not the naive kernel directly), so on sm_121 it exercises the routing. Known-answer construction (all-1.0 FP4 inputs, unit scales ⇒ C[i, j] == K). Verified on a DGX Spark (GB10): routes to naive, all 65,536 output elements exact.
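To see why the known-answer construction needs no reference matmul, here is a small Python model of it. It is a sketch under stated assumptions, not the test in the PR: the E2M1 nibble encoding (1.0 is `0x2`) is the standard FP4 layout, but the packing order (low nibble first), the N×K layout of `b`, the block size of 16, and all function names are assumptions made for illustration.

```python
# Python model of the known-answer check; not the Mojo kernel or its test.
E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def decode_e2m1(nibble):
    """Decode one 4-bit E2M1 value (1 sign, 2 exponent, 1 mantissa bit)."""
    sign = -1.0 if nibble & 0x8 else 1.0
    return sign * E2M1_MAGNITUDES[nibble & 0x7]


def unpack_fp4_row(packed):
    """Unpack bytes holding two FP4 values each (low nibble first, assumed)."""
    out = []
    for byte in packed:
        out.append(decode_e2m1(byte & 0xF))
        out.append(decode_e2m1(byte >> 4))
    return out


def block_scaled_matmul_model(a_packed, b_packed, a_scales, b_scales, block=16):
    """C[i][j] = sum_k (a[i][k] * sa[i][k // block]) * (b[j][k] * sb[j][k // block])."""
    a = [unpack_fp4_row(row) for row in a_packed]
    b = [unpack_fp4_row(row) for row in b_packed]  # b stored as N x K (assumed)
    k = len(a[0])
    c = []
    for i in range(len(a)):
        c_row = []
        for j in range(len(b)):
            acc = 0.0
            for kk in range(k):
                blk = kk // block
                acc += a[i][kk] * a_scales[i][blk] * b[j][kk] * b_scales[j][blk]
            c_row.append(acc)
        c.append(c_row)
    return c


def known_answer_case(m, n, k, block=16):
    """All-1.0 FP4 inputs with unit scales, so every output should equal K."""
    one_pair = 0x22  # two E2M1 values of 1.0 packed into one byte
    a_packed = [[one_pair] * (k // 2) for _ in range(m)]
    b_packed = [[one_pair] * (k // 2) for _ in range(n)]
    a_scales = [[1.0] * (k // block) for _ in range(m)]
    b_scales = [[1.0] * (k // block) for _ in range(n)]
    return block_scaled_matmul_model(a_packed, b_packed, a_scales, b_scales, block)
```

Every product in the inner sum is exactly 1.0, so each output is the integer K, which is exactly representable. That is why the check can demand exact equality rather than a tolerance.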

lightofbaldr and others added 2 commits June 4, 2026 09:02
…sm_121)

BEGIN_PUBLIC
[Stdlib][GPU] Enable NVFP4 (block-scaled FP4) on consumer Blackwell (sm_121)

Builds on the sm_121a target fix (modular#6593). With the arch-specific target in
place, two stdlib gaps still blocked the arch-agnostic
`naive_block_scaled_matmul` (NVFP4) path on the DGX Spark (GB10, sm_121):

1. sm_121 was absent from the architecture-capability predicates. Add an
   `_SM_121X_ARCHS` set with `_is_sm_121x` / `_has_sm_121x`, and include it
   in `_is_sm_120x_or_newer` / `_has_sm_120x_or_newer` so the GB10 is
   recognized as a Blackwell-class arch (this is what gates the FP4
   `cast_f4e2m1x2_to_fp16x2` decode helper, among others).

2. The float8 -> float hardware conversion fast-paths emit a `pop.cast`
   whose lowering is not implemented for the sm_121 backend, so on sm_121
   route float8 conversions to the existing pure bit-manipulation fallback
   (always correct; the only cost is the hardware-accelerated path).
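As an illustration of what a pure bit-manipulation float8 decode looks like, here is a minimal Python sketch. The commit message does not say which float8 formats are affected, so the choice of E4M3FN (1 sign, 4 exponent, 3 mantissa bits, bias 7, no infinities) is an assumption, and `decode_float8_e4m3fn` is a hypothetical name, not the stdlib's fallback.

```python
import math


def decode_float8_e4m3fn(byte):
    """Decode one E4M3FN byte to a Python float using only bit operations."""
    sign = -1.0 if byte & 0x80 else 1.0
    exponent = (byte >> 3) & 0xF
    mantissa = byte & 0x7
    if exponent == 0xF and mantissa == 0x7:
        # E4M3FN has no infinities; this single pattern per sign is NaN.
        return math.nan
    if exponent == 0:
        # Subnormal: 2^-6 * (mantissa / 8) == mantissa * 2^-9.
        return sign * mantissa * 2.0 ** -9
    # Normal: implicit leading 1, exponent bias of 7.
    return sign * (1.0 + mantissa / 8.0) * 2.0 ** (exponent - 7)
```

For example, `0x38` decodes to 1.0 and `0x7E` to 448.0, the largest finite E4M3FN value. A decode like this depends on no backend-specific cast lowering, which is the property the commit relies on for sm_121.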

Together with modular#6593 these let a block-scaled FP4 (NVFP4) GEMM compile,
launch, and produce numerically correct results on the GB10. Adds a
known-answer correctness test for the naive NVFP4 path (no reference
matmul needed: all-1.0 FP4 inputs with unit scales => C[i, j] == K),
which also exercises the path on existing supported GPUs.

Both stdlib changes are scoped to sm_121 and do not affect other targets.
END_PUBLIC

Co-Authored-By: Claude Opus 4.7 <[email protected]>
Signed-off-by: Adam Kruger <[email protected]>
…onsumer Blackwell (sm_121)

BEGIN_PUBLIC
[Kernels][GPU] Route block_scaled_matmul (NVFP4) to naive kernel on consumer Blackwell (sm_121)

`block_scaled_matmul` previously hard-asserted SM100 (B200) only, so NVFP4
was unreachable from the inference dispatch on consumer Blackwell. Guard the
SM100 warp-specialized body in `comptime if ctx.default_device_info.compute
== B200.compute` and add an `elif ctx.default_device_info.compute >= 12.0`
(sm_120 / sm_121) branch that routes to the arch-agnostic CUDA-core
`naive_block_scaled_matmul`. (The device check uses ctx.default_device_info,
not a global arch predicate, because the dispatch's `target` defaults to
"cpu".)

To accept the dispatch's immutable inputs, the naive kernel's read-only
operands (a, b, a_scales, b_scales) and the `get_scale_factor` reader are
changed from MutAnyOrigin to ImmutAnyOrigin; only the output `c` stays
mutable. Mutable callers are unaffected.

Most of the fp4_quantization.mojo diff is whitespace: the existing SM100 body
is unchanged but re-indented under the new `comptime if`.

Builds on the sm_121a target (modular#6593) and NVFP4 enablement (modular#6594). Verified
on a DGX Spark (GB10): a known-answer NVFP4 GEMM driven through
`block_scaled_matmul` routes to the naive path and is numerically correct
(all 65,536 output elements exact). Adds test_block_scaled_fp4_dispatch_sm121.mojo.
END_PUBLIC

Co-Authored-By: Claude Opus 4.7 <[email protected]>
Signed-off-by: Adam Kruger <[email protected]>
@lightofbaldr lightofbaldr force-pushed the fix/nvfp4-dispatch-sm121 branch from 8aa66a5 to 3fc20d4 Compare June 4, 2026 14:03
@CINOAdam

CINOAdam commented Jun 4, 2026

Rebased on top of current modular/main (HEAD 3fc20d40), now stacked cleanly on #6594. The conflict against main was the in-flight _width: Int → _width: SIMDSize change in compute_lambda_wrapper; carried that over into the rebased commit so the dispatch wrapper builds against current main.

Diff summary post-rebase:

  • max/kernels/src/linalg/fp4_quantization.mojo (+189/-165): remove the hard assert ctx.default_device_info.compute == B200.compute, wrap the SM100 warp-specialized body in comptime if ctx.default_device_info.compute == B200.compute, add an elif compute >= 12.0 branch that routes to naive_block_scaled_matmul. Most of the diff is the indentation shift on the existing SM100 body.
  • max/kernels/src/linalg/fp4_utils.mojo (+1/-1): tiny supporting tweak.
  • max/kernels/test/gpu/linalg/test_block_scaled_fp4_dispatch_sm121.mojo (new, 127 lines): end-to-end dispatch test exercising the new sm_121 path.

Verified on a GB10 (sm_121a): the known-answer NVFP4 GEMM driven through block_scaled_matmul routes to the naive path and produces correct results (all 65,536 output elements exact). No changes to SM100 codegen.

Ready for review.

@github-actions

github-actions Bot commented Jun 9, 2026

All contributors have signed the CLA ✍️ ✅
Posted by the CLA Assistant Lite bot.

Six mblack formatting fixes the lint check called out:
  - var k = ... if ... else ... ternary line-wrap
  - logger.info long string concat split
  - comptime assert (a or b) multi-line
  - Optional[elementwise_epilogue_type](...) generic arg wrap
  - drop a stray blank line inside a kwarg list
  - scaling_kind = -> scaling_kind= (kwarg spacing)
@lightofbaldr lightofbaldr force-pushed the fix/nvfp4-dispatch-sm121 branch from e359a9d to 7280cee Compare June 9, 2026 14:21
Labels

mojo-stdlib (Tag for issues related to standard library), waiting-on-review
