Thanks to visit codestin.com
Credit goes to github.com

Skip to content

[Kernels] Coalesce, vectorize and parallelize selective_scan_update decode#6633

Open
ulmentflam wants to merge 14 commits into
modular:mainfrom
ulmentflam:nightly/selective-scan-decode-opt
Open

[Kernels] Coalesce, vectorize and parallelize selective_scan_update decode#6633
ulmentflam wants to merge 14 commits into
modular:mainfrom
ulmentflam:nightly/selective-scan-decode-opt

Conversation

@ulmentflam

@ulmentflam ulmentflam commented May 31, 2026

Copy link
Copy Markdown
Contributor

Linked issue

Part of #5772

Type of change

  • Bug fix (non-breaking change that fixes an issue)
  • Performance improvement (includes benchmark results below)
  • Documentation update
  • New feature or public API (requires prior proposal or issue approval)
  • Refactor / internal cleanup (no user-visible change)
  • Build, CI, or tooling change

Stacked on #6622

This stacks on top of #6622 (Selective Scan Ops Refactor and d_state support for 32/64/128/256) and should merge after it. Because an upstream base must live in modular/modular, this PR is opened against main, so the diff currently also shows the refactor commit from #6622. The net-new content here is the decode-kernel optimization; once #6622 merges and this is rebased, the diff reduces to the optimization alone.

Motivation

The Mamba/Mamba2 per-token decode kernel selective_scan_update_gpu was a register-spilling, uncoalesced scalar kernel: one thread owned a whole (batch, dim) row, held four SIMD[MAX_D_STATE=256] vectors live (255-register cap → ~3.6 MB local spill, ~8% occupancy even at d_state=16), read state with a full d_state stride between adjacent threads (~78% memory-bound on scattered transactions), and launched only batch*dim/128 blocks.

What changed

Rewrites selective_scan_update_gpu into two coalesced, vectorized, state-parallel layouts selected at compile time by d_state, launched over a 2D (dim-tile, batch) grid so (b, d) come from block indices with no per-thread division:

  • d_state > WARP_SIZE (Mamba2/3, d_state ∈ {64,128,256}) — warp-per-row: one warp owns a (batch, dim) row, each lane owns d_state/WARP_SIZE contiguous state elements loaded as one aligned 128-bit float4, summed in-register, finished with a single warp.sum (no shared memory, no barrier). A block packs 4 rows, matching the state-spaces/mamba Triton BLOCK_SIZE_M tiling.
  • d_state <= WARP_SIZE (Mamba1/2, d_state ∈ {2..32}) — warp-cooperative: each lane owns min(4, d_state) contiguous elements (an aligned float4), d_state/that lanes cooperate on a row and finish with a narrow warp.lane_group_sum (no shared memory, no barriers).

Key wins (each verified on GB10 with ncu): coalesced + 16-byte-aligned float4 loads emit one LD.E.128 instead of 4 scalar loads (L1 load sectors ~3.6x → identical to Triton's); the 2D grid removes the per-thread divmod; all four global loads (A, B, state, C) issue before any is consumed (the kernel is long-scoreboard / global-load-latency bound at ~99% occupancy); registers 255 → 28, zero spill.

Files:

  • max/kernels/src/state_space/selective_scan.mojo — kernel rewrite; adds _WARP_COOP_BLOCK, _DECODE_WARPS_PER_BLOCK, _DECODE_ELEMS_PER_THREAD, _DECODE_ROWS_PER_THREAD.
  • max/kernels/src/state_space/selective_scan_ops.mojo — launch grid/block dispatched by d_state.
  • max/kernels/test/gpu/state_space/test_selective_scan.mojo — launch mirrors dispatch; adds real-dim ..._mamba1_130m (d=16, warp-coop) and ..._mamba2_130m (d=128, warp-per-row) regression guards.
  • max/kernels/benchmarks/gpu/state_space/bench_selective_scan_update.mojo (new) + _EXTRA_DEPS wiring in max/kernels/benchmarks/BUILD.bazel.

Testing

14/14 GPU goldens pass at rtol 1e-2 (including both new real-dim regression tests covering both dispatch paths). Verified on an NVIDIA DGX Spark (GB10, sm_121).

Performance (GB10, FP32, dim=1536, n_groups=1)

Over the prior MAX kernel (amortized wall-clock, same harness):

profile before after speedup
Mamba1 d=16 B=1 5.02 µs 1.91 µs 2.6x
Mamba2 d=128 B=1 47.6 µs 2.58 µs 18x
Mamba1 d=16 B=128 201 µs ~123 µs 1.6x
Mamba2 d=128 B=128 2452 µs ~825 µs 3.0x

(after column re-measured on a DGX Spark / GB10 for this PR's kernel, amortized wall-clock from bench_selective_scan_update; d=16 B=128 is run-to-run noisy at 86–123 µs. before is the prior shipped kernel on the same harness.)

vs Triton selective_state_update (nsys eager kernel-trace median GPU time, MAX and Triton collected back-to-back / clock-matched):

profile MAX Triton ratio
d=16 B=1 1.44 µs 1.06 µs 1.36x behind
d=128 B=1 2.50 µs 2.14 µs 1.16x behind
d=16 B=128 ~120 µs 67 µs ~1.7x behind
d=128 B=128 ~830 µs ~833 µs parity

The d=128 B=128 profile moves ~200 MB of state at ~240 GB/s = the measured GB10 DRAM-bandwidth ceiling on both sides, so parity there is the physical optimum. The residual gap at d=16 / batch=1 is pure instruction-scheduling codegen (MAX matches or beats Triton on grid, occupancy, registers, L1/L2 sectors, and executed instructions), left as follow-up.

Checklist

Assisted-by: Claude Opus 4.8 (1M context)

…upport

BEGIN_PUBLIC
[Kernels] Refactor selective_scan ops and add d_state 32/64/128/256 support

Refactors the state-space selective_scan op wrappers and generalizes the
state dimension from the Mamba1-only d_state=16 to arbitrary d_state, adding
support for d_state in {32, 64, 128, 256} so Mamba2/Mamba3 and hybrid models
can share the same kernels.

* Renames the state dimension to `d_state` throughout the ops and kernels
  (mathematically correct naming; behavior-preserving in the kernel body).
* Reworks selective_scan_ops and varlen_selective_scan_ops to thread the
  generalized strides/layout structs and dispatch on d_state.
* Keeps the existing decode/forward kernels' numerics unchanged.

This is the refactor base; the decode-kernel performance work
(coalesce/vectorize/parallelize selective_scan_update) stacks on top in a
follow-up PR.
END_PUBLIC

Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>
Signed-off-by: Evan <[email protected]>
@ulmentflam ulmentflam requested a review from a team as a code owner May 31, 2026 19:19
@ulmentflam ulmentflam closed this May 31, 2026
@ulmentflam ulmentflam changed the title [Kernels][GPU] Coalesce, vectorize and parallelize selective_scan_update decode [Kernels] Coalesce, vectorize and parallelize selective_scan_update decode May 31, 2026
@ulmentflam ulmentflam reopened this May 31, 2026
…ed list

BEGIN_PUBLIC
[Kernels][NFC] Drive selective_scan d_state dispatch from one supported list

The selective_scan op wrappers validated and dispatched d_state with hand-rolled
8-branch if/elif chains repeated across every Args struct (three in
selective_scan_ops, two in varlen_selective_scan_ops) plus _validate_d_state --
each chain enumerating the supported d_state values and re-emitting the same
"Unsupported d_state" error.

Introduce a single `_SUPPORTED_D_STATE` list per file as the source of truth and
iterate it with `comptime for`:

* `_validate_d_state` loops the list instead of a `!=` chain.
* every `dispatch_for_d_state` loops the list, calling `run_cpu[ds]` /
  `launch_gpu[ds]` for the matching value.
* the duplicated error is factored into `_unsupported_d_state_error`.

The varlen list keeps its extra d_state=4 entry, so supported sets are
unchanged. The `comptime for` unrolls to exactly the prior per-value
instantiations, so dispatch behavior is identical. Verified: CPU 10/10 + 9/9
and GPU 12/12 + 5/5 selective_scan / varlen goldens pass.
END_PUBLIC

Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>
Signed-off-by: Evan <[email protected]>
@ulmentflam ulmentflam force-pushed the nightly/selective-scan-decode-opt branch from 28965ff to 95dd03f Compare May 31, 2026 20:27
ulmentflam and others added 4 commits May 31, 2026 16:36
BEGIN_PUBLIC
[Kernels][NFC] Derive selective_scan op layouts from to_tile_tensor

Every selective_scan / varlen op execute() recomputed a per-tensor
`comptime <name>_LT = TileLayout[shape_types=..._shape_types,
stride_types=..._stride_types]` for all 11-19 operands, then passed those aliases
to the Args struct. That `TileLayout[...]` is exactly the `LayoutType` that
`to_tile_tensor()` already declares for the matching `<name>_tt`, so the aliases
just restated it.

Drop the alias blocks (53 across the two files) and pass `<name>_tt.LayoutType`
directly at the Args instantiation, matching the idiom already used in
causal_conv1d_ops. Removes the now-unused `TileLayout` import.

Pure refactor; the layout types are identical. Verified: CPU 10/10 + 9/9 and
GPU 12/12 + 5/5 selective_scan / varlen goldens pass.
END_PUBLIC

Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>
Signed-off-by: Evan <[email protected]>
…ate decode

BEGIN_PUBLIC
[Kernels][GPU] Coalesce, vectorize and parallelize selective_scan_update decode

Rewrites the Mamba/Mamba2 per-token decode kernel selective_scan_update_gpu
from a register-spilling, uncoalesced scalar kernel into two coalesced,
vectorized, state-parallel layouts selected at compile time by d_state, launched
over a 2D (dim-tile, batch) grid so (b, d) come from block indices with no
per-thread division:

* d_state > WARP_SIZE (Mamba2/3, d_state in {64,128,256}): warp-per-row -- one
  warp owns a (batch, dim) row, each lane owns d_state/WARP_SIZE CONTIGUOUS
  state elements loaded as one aligned 128-bit float4, sums them in-register,
  and the output sum over the state dim finishes with a single warp.sum (no
  shared memory, no barrier). A block packs 4 rows, matching the
  state-spaces/mamba Triton BLOCK_SIZE_M tiling.
* d_state <= WARP_SIZE (Mamba1/2, d_state in {2..32}): warp-cooperative -- each
  lane owns min(4, d_state) contiguous elements (an aligned float4),
  d_state/that lanes cooperate on a row and finish with a narrow
  warp.lane_group_sum (no shared memory, no barriers).

Key efficiency points, each verified on GB10 with ncu:
* Consecutive lanes on consecutive state elements -> coalesced loads.
* 16-byte-aligned width-4 loads emit one LD.E.128 instead of 4 scalar loads,
  cutting L1 load sectors ~3.6x (now identical to Triton's).
* 2D grid removes the per-thread divmod by the runtime dim, dropping executed
  instructions below Triton's.
* All four global loads (A, B, state, C) issue before any is consumed --
  the kernel is long-scoreboard (global-load-latency) bound at ~99% occupancy
  (ncu), so 4 loads in flight share one latency stall instead of serializing.
* O(1) registers/thread (28-35, zero spill) vs the old kernel's four
  SIMD[MAX_D_STATE=256] vectors (255-register cap -> ~3.6 MB spill, ~8%
  occupancy, ~78% memory-bound on scattered transactions, batch*dim/128 blocks).

Measured on GB10 (FP32, dim=1536, nsys eager kernel-trace median, MAX and the
Triton selective_state_update collected back-to-back / clock-matched): the
d_state=128 path reaches parity with Triton at the Mamba2 batched decode
profile (B=128: ~833 us on nsys kernel-trace both sides; ~860-896 us vs ~836 us
on amortized wall-clock -- parity within the measurement spread, from 2.9x
behind on the original kernel), and every profile is many-x faster than the
prior MAX kernel.
Triton retains a ~1.1-1.5x edge at small d_state / batch=1, where the kernel is
memory-latency-bound at ~99% occupancy; MAX matches or beats Triton there on
every measurable resource (instructions, L1/L2 sectors, occupancy, registers),
so the residual is instruction-scheduling codegen, not algorithm or traffic.

Adds real-dim Mamba1 (d_state=16) and Mamba2 (d_state=128) decode regression
tests and a microbenchmark.

Stacks on the selective_scan ops refactor / d_state generalization.
END_PUBLIC

Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>
Signed-off-by: Evan <[email protected]>
BEGIN_PUBLIC
[Kernels][GPU][NFC] Factor shared selective_scan_update decode helpers

The two compile-time decode layouts in selective_scan_update_gpu (warp-per-row
for d_state > WARP_SIZE, warp-cooperative for d_state <= WARP_SIZE) carried
near-identical per-row scalar prologue, vectorized recurrence, and output
epilogue, duplicated across the `comptime if`/`else` arms.

Factor the shared work into three `@always_inline` helpers -- `_decode_row_prologue`
(dt + bias + softplus, x load), `_decode_row_recur` (the hoisted A/B/state/C
loads, state recurrence, state store, and in-register partial sum), and
`_decode_row_finalize` (D skip term, silu(z) gating, output store). Each layout
arm now contains only its distinctive logic: index setup, vector width, and the
cross-lane reduction (`warp.sum` vs `lane_group_sum`).

Pure refactor. The helpers are `@always_inline`, so the generated code -- and
the load-hoisting order the kernel depends on -- is unchanged. Verified on GB10
(FP32, dim=1536): 14/14 GPU goldens pass, and the bench is within run-to-run
noise at every profile (d=128 B=128 ~823 us, parity with Triton preserved;
d=16/128 B=1 and d=16 B=128 all unchanged).
END_PUBLIC

Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>
Signed-off-by: Evan <[email protected]>
… tidy prologue

BEGIN_PUBLIC
[Kernels][GPU][NFC] Share selective_scan_update decode launch config; tidy prologue

Readability cleanup of the selective_scan_update decode path, no behavior
change:

* Extract the decode launch grid/block math into two helpers next to the
  kernel constants -- `selective_scan_update_decode_grid_dim_x(dim, d_state)`
  and `selective_scan_update_decode_block_dim(d_state)`. The op dispatch, the
  microbenchmark, and the GPU test now all derive grid/block from these instead
  of each re-deriving the warp-per-row vs warp-cooperative tiling inline. This
  also lets selective_scan_ops.mojo drop its duplicated `_WARP_COOP_BLOCK` /
  `_DECODE_WARPS_PER_BLOCK` / `_DECODE_ROWS_PER_THREAD` "must match" copies and
  the now-unused WARP_SIZE import.
* `_decode_row_prologue` now returns `(dt_val, x_val)`; the trivial
  `dt_x = dt_val * x_val` moves into `_decode_row_recur` where it is consumed.
  Call sites use named tuple unpacking instead of positional `pro[0/1/2]`.
* Drop a duplicate `from std.math import ceildiv` in the test and the unused
  `ceildiv` import in the bench.

The launch helpers compute exactly the values the inline code did, and the
prologue/recur change is `@always_inline` (relocated multiply, same inlined
code). Verified on GB10 (FP32, dim=1536): 14/14 GPU goldens pass and the bench
is within run-to-run noise at every profile (d=128 B=128 ~825 us parity with
Triton preserved; d=16 B=128 104-124 us noise band unchanged).
END_PUBLIC

Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>
Signed-off-by: Evan <[email protected]>
@ulmentflam ulmentflam force-pushed the nightly/selective-scan-decode-opt branch from 95dd03f to f6af50e Compare May 31, 2026 20:37
…n-decode-opt

# Conflicts:
#	max/kernels/src/state_space/selective_scan.mojo
#	max/kernels/src/state_space/varlen_selective_scan.mojo
#	max/kernels/src/state_space/varlen_selective_scan_ops.mojo
@ulmentflam

Copy link
Copy Markdown
Contributor Author

@BradLarson This would be nice to have in before I PR the Mamba2 kernels. As I benchmarked above, it gives Mamba2 a 3x performance gain over the standard kernel from the initial implementation and achieves approximate parity with Triton.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant