[Kernels] Coalesce, vectorize and parallelize selective_scan_update decode#6633
Open
ulmentflam wants to merge 14 commits into
Open
[Kernels] Coalesce, vectorize and parallelize selective_scan_update decode#6633ulmentflam wants to merge 14 commits into
ulmentflam wants to merge 14 commits into
Conversation
…upport
BEGIN_PUBLIC
[Kernels] Refactor selective_scan ops and add d_state 32/64/128/256 support
Refactors the state-space selective_scan op wrappers and generalizes the
state dimension from the Mamba1-only d_state=16 to arbitrary d_state, adding
support for d_state in {32, 64, 128, 256} so Mamba2/Mamba3 and hybrid models
can share the same kernels.
* Renames the state dimension to `d_state` throughout the ops and kernels
(mathematically correct naming; behavior-preserving in the kernel body).
* Reworks selective_scan_ops and varlen_selective_scan_ops to thread the
generalized strides/layout structs and dispatch on d_state.
* Keeps the existing decode/forward kernels' numerics unchanged.
This is the refactor base; the decode-kernel performance work
(coalesce/vectorize/parallelize selective_scan_update) stacks on top in a
follow-up PR.
END_PUBLIC
Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>
Signed-off-by: Evan <[email protected]>
11 tasks
…ed list BEGIN_PUBLIC [Kernels][NFC] Drive selective_scan d_state dispatch from one supported list The selective_scan op wrappers validated and dispatched d_state with hand-rolled 8-branch if/elif chains repeated across every Args struct (three in selective_scan_ops, two in varlen_selective_scan_ops) plus _validate_d_state -- each chain enumerating the supported d_state values and re-emitting the same "Unsupported d_state" error. Introduce a single `_SUPPORTED_D_STATE` list per file as the source of truth and iterate it with `comptime for`: * `_validate_d_state` loops the list instead of a `!=` chain. * every `dispatch_for_d_state` loops the list, calling `run_cpu[ds]` / `launch_gpu[ds]` for the matching value. * the duplicated error is factored into `_unsupported_d_state_error`. The varlen list keeps its extra d_state=4 entry, so supported sets are unchanged. The `comptime for` unrolls to exactly the prior per-value instantiations, so dispatch behavior is identical. Verified: CPU 10/10 + 9/9 and GPU 12/12 + 5/5 selective_scan / varlen goldens pass. END_PUBLIC Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]> Signed-off-by: Evan <[email protected]>
28965ff to
95dd03f
Compare
BEGIN_PUBLIC [Kernels][NFC] Derive selective_scan op layouts from to_tile_tensor Every selective_scan / varlen op execute() recomputed a per-tensor `comptime <name>_LT = TileLayout[shape_types=..._shape_types, stride_types=..._stride_types]` for all 11-19 operands, then passed those aliases to the Args struct. That `TileLayout[...]` is exactly the `LayoutType` that `to_tile_tensor()` already declares for the matching `<name>_tt`, so the aliases just restated it. Drop the alias blocks (53 across the two files) and pass `<name>_tt.LayoutType` directly at the Args instantiation, matching the idiom already used in causal_conv1d_ops. Removes the now-unused `TileLayout` import. Pure refactor; the layout types are identical. Verified: CPU 10/10 + 9/9 and GPU 12/12 + 5/5 selective_scan / varlen goldens pass. END_PUBLIC Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]> Signed-off-by: Evan <[email protected]>
…ate decode
BEGIN_PUBLIC
[Kernels][GPU] Coalesce, vectorize and parallelize selective_scan_update decode
Rewrites the Mamba/Mamba2 per-token decode kernel selective_scan_update_gpu
from a register-spilling, uncoalesced scalar kernel into two coalesced,
vectorized, state-parallel layouts selected at compile time by d_state, launched
over a 2D (dim-tile, batch) grid so (b, d) come from block indices with no
per-thread division:
* d_state > WARP_SIZE (Mamba2/3, d_state in {64,128,256}): warp-per-row -- one
warp owns a (batch, dim) row, each lane owns d_state/WARP_SIZE CONTIGUOUS
state elements loaded as one aligned 128-bit float4, sums them in-register,
and the output sum over the state dim finishes with a single warp.sum (no
shared memory, no barrier). A block packs 4 rows, matching the
state-spaces/mamba Triton BLOCK_SIZE_M tiling.
* d_state <= WARP_SIZE (Mamba1/2, d_state in {2..32}): warp-cooperative -- each
lane owns min(4, d_state) contiguous elements (an aligned float4),
d_state/that lanes cooperate on a row and finish with a narrow
warp.lane_group_sum (no shared memory, no barriers).
Key efficiency points, each verified on GB10 with ncu:
* Consecutive lanes on consecutive state elements -> coalesced loads.
* 16-byte-aligned width-4 loads emit one LD.E.128 instead of 4 scalar loads,
cutting L1 load sectors ~3.6x (now identical to Triton's).
* 2D grid removes the per-thread divmod by the runtime dim, dropping executed
instructions below Triton's.
* All four global loads (A, B, state, C) issue before any is consumed --
the kernel is long-scoreboard (global-load-latency) bound at ~99% occupancy
(ncu), so 4 loads in flight share one latency stall instead of serializing.
* O(1) registers/thread (28-35, zero spill) vs the old kernel's four
SIMD[MAX_D_STATE=256] vectors (255-register cap -> ~3.6 MB spill, ~8%
occupancy, ~78% memory-bound on scattered transactions, batch*dim/128 blocks).
Measured on GB10 (FP32, dim=1536, nsys eager kernel-trace median, MAX and the
Triton selective_state_update collected back-to-back / clock-matched): the
d_state=128 path reaches parity with Triton at the Mamba2 batched decode
profile (B=128: ~833 us on nsys kernel-trace both sides; ~860-896 us vs ~836 us
on amortized wall-clock -- parity within the measurement spread, from 2.9x
behind on the original kernel), and every profile is many-x faster than the
prior MAX kernel.
Triton retains a ~1.1-1.5x edge at small d_state / batch=1, where the kernel is
memory-latency-bound at ~99% occupancy; MAX matches or beats Triton there on
every measurable resource (instructions, L1/L2 sectors, occupancy, registers),
so the residual is instruction-scheduling codegen, not algorithm or traffic.
Adds real-dim Mamba1 (d_state=16) and Mamba2 (d_state=128) decode regression
tests and a microbenchmark.
Stacks on the selective_scan ops refactor / d_state generalization.
END_PUBLIC
Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>
Signed-off-by: Evan <[email protected]>
BEGIN_PUBLIC [Kernels][GPU][NFC] Factor shared selective_scan_update decode helpers The two compile-time decode layouts in selective_scan_update_gpu (warp-per-row for d_state > WARP_SIZE, warp-cooperative for d_state <= WARP_SIZE) carried near-identical per-row scalar prologue, vectorized recurrence, and output epilogue, duplicated across the `comptime if`/`else` arms. Factor the shared work into three `@always_inline` helpers -- `_decode_row_prologue` (dt + bias + softplus, x load), `_decode_row_recur` (the hoisted A/B/state/C loads, state recurrence, state store, and in-register partial sum), and `_decode_row_finalize` (D skip term, silu(z) gating, output store). Each layout arm now contains only its distinctive logic: index setup, vector width, and the cross-lane reduction (`warp.sum` vs `lane_group_sum`). Pure refactor. The helpers are `@always_inline`, so the generated code -- and the load-hoisting order the kernel depends on -- is unchanged. Verified on GB10 (FP32, dim=1536): 14/14 GPU goldens pass, and the bench is within run-to-run noise at every profile (d=128 B=128 ~823 us, parity with Triton preserved; d=16/128 B=1 and d=16 B=128 all unchanged). END_PUBLIC Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]> Signed-off-by: Evan <[email protected]>
… tidy prologue BEGIN_PUBLIC [Kernels][GPU][NFC] Share selective_scan_update decode launch config; tidy prologue Readability cleanup of the selective_scan_update decode path, no behavior change: * Extract the decode launch grid/block math into two helpers next to the kernel constants -- `selective_scan_update_decode_grid_dim_x(dim, d_state)` and `selective_scan_update_decode_block_dim(d_state)`. The op dispatch, the microbenchmark, and the GPU test now all derive grid/block from these instead of each re-deriving the warp-per-row vs warp-cooperative tiling inline. This also lets selective_scan_ops.mojo drop its duplicated `_WARP_COOP_BLOCK` / `_DECODE_WARPS_PER_BLOCK` / `_DECODE_ROWS_PER_THREAD` "must match" copies and the now-unused WARP_SIZE import. * `_decode_row_prologue` now returns `(dt_val, x_val)`; the trivial `dt_x = dt_val * x_val` moves into `_decode_row_recur` where it is consumed. Call sites use named tuple unpacking instead of positional `pro[0/1/2]`. * Drop a duplicate `from std.math import ceildiv` in the test and the unused `ceildiv` import in the bench. The launch helpers compute exactly the values the inline code did, and the prologue/recur change is `@always_inline` (relocated multiply, same inlined code). Verified on GB10 (FP32, dim=1536): 14/14 GPU goldens pass and the bench is within run-to-run noise at every profile (d=128 B=128 ~825 us parity with Triton preserved; d=16 B=128 104-124 us noise band unchanged). END_PUBLIC Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]> Signed-off-by: Evan <[email protected]>
95dd03f to
f6af50e
Compare
…n-decode-opt # Conflicts: # max/kernels/src/state_space/selective_scan.mojo # max/kernels/src/state_space/varlen_selective_scan.mojo # max/kernels/src/state_space/varlen_selective_scan_ops.mojo
Contributor
Author
|
@BradLarson This would be nice to have in before I PR the Mamba2 kernels. As I benchmarked above, it gives Mamba2 a 3x performance gain over the standard kernel from the initial implementation and achieves approximate parity with Triton. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Linked issue
Part of #5772
Type of change
Stacked on #6622
This stacks on top of #6622 (Selective Scan Ops Refactor and d_state support for 32/64/128/256) and should merge after it. Because an upstream base must live in
modular/modular, this PR is opened againstmain, so the diff currently also shows the refactor commit from #6622. The net-new content here is the decode-kernel optimization; once #6622 merges and this is rebased, the diff reduces to the optimization alone.Motivation
The Mamba/Mamba2 per-token decode kernel
selective_scan_update_gpuwas a register-spilling, uncoalesced scalar kernel: one thread owned a whole(batch, dim)row, held fourSIMD[MAX_D_STATE=256]vectors live (255-register cap → ~3.6 MB local spill, ~8% occupancy even atd_state=16), read state with a fulld_statestride between adjacent threads (~78% memory-bound on scattered transactions), and launched onlybatch*dim/128blocks.What changed
Rewrites
selective_scan_update_gpuinto two coalesced, vectorized, state-parallel layouts selected at compile time byd_state, launched over a 2D(dim-tile, batch)grid so(b, d)come from block indices with no per-thread division:d_state > WARP_SIZE(Mamba2/3,d_state ∈ {64,128,256}) — warp-per-row: one warp owns a(batch, dim)row, each lane ownsd_state/WARP_SIZEcontiguous state elements loaded as one aligned 128-bitfloat4, summed in-register, finished with a singlewarp.sum(no shared memory, no barrier). A block packs 4 rows, matching thestate-spaces/mambaTritonBLOCK_SIZE_Mtiling.d_state <= WARP_SIZE(Mamba1/2,d_state ∈ {2..32}) — warp-cooperative: each lane ownsmin(4, d_state)contiguous elements (an alignedfloat4),d_state/thatlanes cooperate on a row and finish with a narrowwarp.lane_group_sum(no shared memory, no barriers).Key wins (each verified on GB10 with ncu): coalesced + 16-byte-aligned
float4loads emit oneLD.E.128instead of 4 scalar loads (L1 load sectors ~3.6x → identical to Triton's); the 2D grid removes the per-thread divmod; all four global loads (A, B, state, C) issue before any is consumed (the kernel is long-scoreboard / global-load-latency bound at ~99% occupancy); registers255 → 28, zero spill.Files:
max/kernels/src/state_space/selective_scan.mojo— kernel rewrite; adds_WARP_COOP_BLOCK,_DECODE_WARPS_PER_BLOCK,_DECODE_ELEMS_PER_THREAD,_DECODE_ROWS_PER_THREAD.max/kernels/src/state_space/selective_scan_ops.mojo— launch grid/block dispatched byd_state.max/kernels/test/gpu/state_space/test_selective_scan.mojo— launch mirrors dispatch; adds real-dim..._mamba1_130m(d=16, warp-coop) and..._mamba2_130m(d=128, warp-per-row) regression guards.max/kernels/benchmarks/gpu/state_space/bench_selective_scan_update.mojo(new) +_EXTRA_DEPSwiring inmax/kernels/benchmarks/BUILD.bazel.Testing
14/14 GPU goldens pass at rtol 1e-2 (including both new real-dim regression tests covering both dispatch paths). Verified on an NVIDIA DGX Spark (GB10, sm_121).
Performance (GB10, FP32, dim=1536, n_groups=1)
Over the prior MAX kernel (amortized wall-clock, same harness):
(
aftercolumn re-measured on a DGX Spark / GB10 for this PR's kernel, amortized wall-clock frombench_selective_scan_update;d=16 B=128is run-to-run noisy at 86–123 µs.beforeis the prior shipped kernel on the same harness.)vs Triton
selective_state_update(nsys eager kernel-trace median GPU time, MAX and Triton collected back-to-back / clock-matched):The
d=128 B=128profile moves ~200 MB of state at ~240 GB/s = the measured GB10 DRAM-bandwidth ceiling on both sides, so parity there is the physical optimum. The residual gap atd=16/batch=1is pure instruction-scheduling codegen (MAX matches or beats Triton on grid, occupancy, registers, L1/L2 sectors, and executed instructions), left as follow-up.Checklist
./bazelw run //:formatto format my changesAssisted-by:/Co-Authored-By:trailer in the commit)Assisted-by: Claude Opus 4.8 (1M context)