[Kernels][GPU] Route block_scaled_matmul (NVFP4) to naive kernel on consumer Blackwell (sm_121) #6596
Open
lightofbaldr wants to merge 6 commits into
…sm_121)

BEGIN_PUBLIC
[Stdlib][GPU] Enable NVFP4 (block-scaled FP4) on consumer Blackwell (sm_121)

Builds on the sm_121a target fix (modular#6593). With the arch-specific target in place, two stdlib gaps still blocked the arch-agnostic `naive_block_scaled_matmul` (NVFP4) path on the DGX Spark (GB10, sm_121):

1. sm_121 was absent from the architecture-capability predicates. Add an `_SM_121X_ARCHS` set with `_is_sm_121x` / `_has_sm_121x`, and include it in `_is_sm_120x_or_newer` / `_has_sm_120x_or_newer` so the GB10 is recognized as a Blackwell-class arch (this is what gates the FP4 `cast_f4e2m1x2_to_fp16x2` decode helper, among others).
2. The float8 -> float hardware conversion fast-paths emit a `pop.cast` whose lowering is not implemented for the sm_121 backend, so on sm_121 route float8 conversions to the existing pure bit-manipulation fallback (always correct; the only cost is losing the hardware-accelerated path).

Together with modular#6593 these let a block-scaled FP4 (NVFP4) GEMM compile, launch, and produce numerically correct results on the GB10.

Adds a known-answer correctness test for the naive NVFP4 path (no reference matmul needed: all-1.0 FP4 inputs with unit scales => C[i, j] == K), which also exercises the path on existing supported GPUs.

Both stdlib changes are scoped to sm_121 and do not affect other targets.
END_PUBLIC

Co-Authored-By: Claude Opus 4.7 <[email protected]>
Signed-off-by: Adam Kruger <[email protected]>
…onsumer Blackwell (sm_121)

BEGIN_PUBLIC
[Kernels][GPU] Route block_scaled_matmul (NVFP4) to naive kernel on consumer Blackwell (sm_121)

`block_scaled_matmul` previously hard-asserted SM100 (B200) only, so NVFP4 was unreachable from the inference dispatch on consumer Blackwell. Guard the SM100 warp-specialized body in `comptime if ctx.default_device_info.compute == B200.compute` and add an `elif ctx.default_device_info.compute >= 12.0` (sm_120 / sm_121) branch that routes to the arch-agnostic CUDA-core `naive_block_scaled_matmul`. (The device check uses ctx.default_device_info, not a global arch predicate, because the dispatch's `target` defaults to "cpu".)

To accept the dispatch's immutable inputs, the naive kernel's read-only operands (a, b, a_scales, b_scales) and the `get_scale_factor` reader are changed from MutAnyOrigin to ImmutAnyOrigin; only the output `c` stays mutable. Mutable callers are unaffected.

Most of the fp4_quantization.mojo diff is whitespace: the existing SM100 body is unchanged but re-indented under the new `comptime if`.

Builds on the sm_121a target (modular#6593) and NVFP4 enablement (modular#6594).

Verified on a DGX Spark (GB10): a known-answer NVFP4 GEMM driven through `block_scaled_matmul` routes to the naive path and is numerically correct (all 65,536 output elements exact). Adds test_block_scaled_fp4_dispatch_sm121.mojo.
END_PUBLIC

Co-Authored-By: Claude Opus 4.7 <[email protected]>
Signed-off-by: Adam Kruger <[email protected]>
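For reviewers who want the branch order at a glance, the routing this commit describes can be modelled in a few lines. This is a Python sketch of the decision only, not the Mojo source: the function name, the returned labels, and the `B200_COMPUTE = 10.0` constant (B200 being sm_100) are assumptions made for illustration, and the real code resolves the branch at compile time with `comptime if`.

```python
# Hypothetical model of the dispatch order described in the commit message.
# Assumption: B200 (sm_100) reports compute capability 10.0.
B200_COMPUTE = 10.0


def select_block_scaled_kernel(compute: float) -> str:
    """Return a label for the kernel path chosen for a device compute capability."""
    if compute == B200_COMPUTE:
        # Existing SM100 warp-specialized body.
        return "sm100_warp_specialized"
    elif compute >= 12.0:
        # sm_120 / sm_121 (compute 12.x): arch-agnostic CUDA-core kernel.
        return "naive_block_scaled_matmul"
    # Anything else hits the unsupported-arch assert in the real code.
    raise ValueError(f"block_scaled_matmul: unsupported compute capability {compute}")


if __name__ == "__main__":
    for cc in (10.0, 12.0, 12.1):
        print(cc, "->", select_block_scaled_kernel(cc))
```

In this model the equality check comes first, so only an exact B200 match takes the warp-specialized path; every 12.x device falls through to the naive kernel.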
8aa66a5 to
3fc20d4
Rebased on top of current Diff summary post-rebase:
Verified on a GB10 (sm_121a): the known-answer NVFP4 GEMM driven through
Ready for review.
All contributors have signed the CLA ✍️ ✅ |
Five mblack wrap fixes the lint check called out:
- `var k = ... if ... else ...` ternary line-wrap
- `logger.info` long string concat split
- `comptime assert (a or b)` multi-line
- `Optional[elementwise_epilogue_type](...)` generic arg wrap
- drop a stray blank line inside a kwarg list
- `scaling_kind =` -> `scaling_kind=` (kwarg spacing)
e359a9d to
7280cee
Part of #6570 (sm_120 / sm_121 consumer Blackwell support epic).
Stacked on #6593 (sm_121a target) and #6594 (NVFP4 enablement). Until those merge, the diff here also includes their changes.
What
`block_scaled_matmul` hard-asserted SM100 (B200) only, so NVFP4 was unreachable from the inference dispatch on consumer Blackwell. This routes it:

- `comptime if ctx.default_device_info.compute == B200.compute` → the existing SM100 warp-specialized body.
- `elif ctx.default_device_info.compute >= 12.0` (sm_120 / sm_121, compute 12.x) → arch-agnostic CUDA-core `naive_block_scaled_matmul`.
- `else` → unsupported-arch assert.

The device check uses `ctx.default_device_info` (the actual GPU), not a global arch predicate: `block_scaled_matmul`'s `target` param defaults to `"cpu"`, so global predicates would read CPU.

To accept the dispatch's immutable inputs, the naive kernel's read-only operands (`a`, `b`, `a_scales`, `b_scales`) and the `get_scale_factor` reader move from `MutAnyOrigin` → `ImmutAnyOrigin`; only the output `c` stays mutable. Mutable callers are unaffected (mutable borrows as immutable).

Diff note
Most of the `fp4_quantization.mojo` diff is whitespace: the existing SM100 body is unchanged but re-indented one level under the new `comptime if`.

Test
`test_block_scaled_fp4_dispatch_sm121.mojo` drives `block_scaled_matmul` (not the naive kernel directly), so on sm_121 it exercises the routing. Known-answer construction: all-1.0 FP4 inputs with unit scales ⇒ `C[i, j] == K`. Verified on a DGX Spark (GB10): routes to naive, all 65,536 output elements exact.
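To see why the known-answer construction needs no reference matmul, here is a small pure-Python model of a block-scaled FP4 GEMM. It is a sketch under stated assumptions, not the Mojo kernel or the test in this PR: all function names are hypothetical, FP4 codes are held unpacked (one E2M1 nibble per element), `B` is assumed to be stored as N×K, scales are plain floats in a row-major `[row][k // 16]` layout rather than the kernel's real scale-factor layout, and the block size of 16 is the NVFP4 convention.

```python
# E2M1 magnitudes indexed by the low three bits of the code; bit 3 is the sign.
E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
BLOCK = 16  # assumed NVFP4 scaling block size along K
FP4_ONE = 0x2  # E2M1 code for +1.0 (exponent 01, mantissa 0)


def decode_e2m1(code: int) -> float:
    """Decode a 4-bit E2M1 code into a Python float."""
    magnitude = E2M1_MAGNITUDES[code & 0x7]
    return -magnitude if code & 0x8 else magnitude


def block_scaled_matmul_model(a, b, a_scales, b_scales):
    """C[i][j] = sum_k (a[i][k] * sa[i][k // BLOCK]) * (b[j][k] * sb[j][k // BLOCK])."""
    m, n, k = len(a), len(b), len(a[0])
    c = [[0.0] * n for _ in range(m)]
    for i in range(m):
        for j in range(n):
            acc = 0.0
            for kk in range(k):
                blk = kk // BLOCK
                lhs = decode_e2m1(a[i][kk]) * a_scales[i][blk]
                rhs = decode_e2m1(b[j][kk]) * b_scales[j][blk]
                acc += lhs * rhs
            c[i][j] = acc
    return c


def known_answer_case(m: int, n: int, k: int):
    """All-1.0 FP4 inputs with unit scales, so every product term equals 1."""
    a = [[FP4_ONE] * k for _ in range(m)]
    b = [[FP4_ONE] * k for _ in range(n)]
    a_scales = [[1.0] * (k // BLOCK) for _ in range(m)]
    b_scales = [[1.0] * (k // BLOCK) for _ in range(n)]
    return block_scaled_matmul_model(a, b, a_scales, b_scales)


if __name__ == "__main__":
    k = 64
    c = known_answer_case(4, 3, k)
    # Every output sums K ones, so each element must equal K exactly.
    assert all(value == float(k) for row in c for value in row)
    print("all", len(c) * len(c[0]), "elements equal K =", k)
```

Because each of the K product terms is exactly 1.0, the sum is a small integer that floating point represents exactly, which is why the check can demand exact equality rather than a tolerance.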