[Kernels] Selective Scan Ops Refactor and d_state support for 32, 64, 128, and 256 #6622

Open
ulmentflam wants to merge 12 commits into modular:main from ulmentflam:nightly/selective-scan-update-dstate-128

@ulmentflam ulmentflam commented May 29, 2026

Contributor

Linked issue

Part of #5772

Type of change

  • Bug fix (non-breaking change that fixes an issue)
  • Performance improvement (includes benchmark results below)
  • Documentation update
  • New feature or public API (requires prior proposal or issue approval)
  • Refactor / internal cleanup (no user-visible change)
  • Build, CI, or tooling change

Motivation

The selective_scan_ops.mojo file was messy: it has to dispatch a kernel per d_state. Mamba2 and Mamba3 will require the selective scan kernel to support state sizes of 32, 64, and 128, so I updated the kernels to support these state sizes, along with 256. I also refactored the selective scan op to make better use of structs and hide some of the branchy dispatch logic.

What changed

The selective_scan_ops.mojo, selective_scan.mojo, varlen_selective_scan.mojo, and varlen_selective_scan_ops.mojo files all changed: I added several new structs, helper functions, and @fieldwise_init decorators. I also updated the selective_scan kernels with a MAX_D_STATE to support the new stride deployments and refactored the D_STATE.

Testing

I ran the full test suite and verified the output of the mamba-130m model for each new d_state. This was tested on an NVIDIA DGX Spark.

Checklist

  • The linked issue above has been reviewed by a maintainer and is
    agreed-upon, or this is a trivial fix that does not need prior
    approval
  • PR is small and focused — I've split larger changes into a sequence of
    smaller PRs where possible (see
    pull request sizes)
  • I ran ./bazelw run format to format my changes
  • I added or updated tests to cover my changes
  • If AI tools assisted with this contribution, I have included an
    Assisted-by: trailer in my commit message or this PR description (see
    AI Tool Use Policy)

Follow-up (stacked): the decode-kernel performance work — coalesce/vectorize/parallelize selective_scan_update — is in #6633, which stacks on this PR and should merge after it.

@ulmentflam ulmentflam requested a review from a team as a code owner May 29, 2026 03:51
…upport

BEGIN_PUBLIC
[Kernels] Refactor selective_scan ops and add d_state 32/64/128/256 support

Refactors the state-space selective_scan op wrappers and generalizes the
state dimension from the Mamba1-only d_state=16 to arbitrary d_state, adding
support for d_state in {32, 64, 128, 256} so Mamba2/Mamba3 and hybrid models
can share the same kernels.

* Renames the state dimension to `d_state` throughout the ops and kernels
  (mathematically correct naming; behavior-preserving in the kernel body).
* Reworks selective_scan_ops and varlen_selective_scan_ops to thread the
  generalized strides/layout structs and dispatch on d_state.
* Keeps the existing decode/forward kernels' numerics unchanged.

This is the refactor base; the decode-kernel performance work
(coalesce/vectorize/parallelize selective_scan_update) stacks on top in a
follow-up PR.
END_PUBLIC

Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>
Signed-off-by: Evan <[email protected]>
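
For readers less familiar with what "generalizing d_state" means, here is a minimal single-channel reference scan in plain Python. It assumes the standard Mamba-style recurrence (h_t = exp(delta_t * A) * h_{t-1} + delta_t * B_t * x_t, y_t = C_t . h_t + D * x_t) and is not taken from the kernels in this PR; the function name and argument layout are hypothetical. The point is only that the state size is a loop bound, so nothing in the math is specific to d_state=16.

```python
import math


def selective_scan_ref(x, delta, A, B, C, D=0.0):
    """Hypothetical single-channel reference scan (not the PR's kernel).

    x, delta: lists of length seq_len
    A:        list of length d_state
    B, C:     lists of seq_len rows, each of length d_state
    D:        scalar skip connection
    """
    d_state = len(A)          # state size is inferred, not hard-coded
    h = [0.0] * d_state       # hidden state, one entry per state dimension
    y = []
    for t in range(len(x)):
        for n in range(d_state):
            # Discretized update: decay the old state, add the new input.
            h[n] = math.exp(delta[t] * A[n]) * h[n] + delta[t] * B[t][n] * x[t]
        # Output is the state projected through C, plus the skip term.
        y.append(sum(C[t][n] * h[n] for n in range(d_state)) + D * x[t])
    return y


def ones_case(d_state, seq_len=2):
    """Run the scan with A = 0 and B = C = 1 for a given state size."""
    ones = [[1.0] * d_state for _ in range(seq_len)]
    return selective_scan_ref(
        x=[1.0] * seq_len,
        delta=[1.0] * seq_len,
        A=[0.0] * d_state,
        B=ones,
        C=ones,
    )
```

With A = 0 the state never decays, so each step adds 1 to every state entry and the output at step t is d_state * (t + 1). That makes it a convenient sanity check across state sizes.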
ulmentflam and others added 6 commits May 31, 2026 16:26
…ed list

BEGIN_PUBLIC
[Kernels][NFC] Drive selective_scan d_state dispatch from one supported list

The selective_scan op wrappers validated and dispatched d_state with hand-rolled
8-branch if/elif chains repeated across every Args struct (three in
selective_scan_ops, two in varlen_selective_scan_ops) plus _validate_d_state --
each chain enumerating the supported d_state values and re-emitting the same
"Unsupported d_state" error.

Introduce a single `_SUPPORTED_D_STATE` list per file as the source of truth and
iterate it with `comptime for`:

* `_validate_d_state` loops the list instead of a `!=` chain.
* every `dispatch_for_d_state` loops the list, calling `run_cpu[ds]` /
  `launch_gpu[ds]` for the matching value.
* the duplicated error is factored into `_unsupported_d_state_error`.

The varlen list keeps its extra d_state=4 entry, so supported sets are
unchanged. The `comptime for` unrolls to exactly the prior per-value
instantiations, so dispatch behavior is identical. Verified: CPU 10/10 + 9/9
and GPU 12/12 + 5/5 selective_scan / varlen goldens pass.
END_PUBLIC

Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>
Signed-off-by: Evan <[email protected]>
BEGIN_PUBLIC
[Kernels][NFC] Derive selective_scan op layouts from to_tile_tensor

Every selective_scan / varlen op execute() recomputed a per-tensor
`comptime <name>_LT = TileLayout[shape_types=..._shape_types,
stride_types=..._stride_types]` for all 11-19 operands, then passed those aliases
to the Args struct. That `TileLayout[...]` is exactly the `LayoutType` that
`to_tile_tensor()` already declares for the matching `<name>_tt`, so the aliases
just restated it.

Drop the alias blocks (53 across the two files) and pass `<name>_tt.LayoutType`
directly at the Args instantiation, matching the idiom already used in
causal_conv1d_ops. Removes the now-unused `TileLayout` import.

Pure refactor; the layout types are identical. Verified: CPU 10/10 + 9/9 and
GPU 12/12 + 5/5 selective_scan / varlen goldens pass.
END_PUBLIC

Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>
Signed-off-by: Evan <[email protected]>
…n-update-dstate-128

# Conflicts:
#	max/kernels/src/state_space/selective_scan.mojo
#	max/kernels/src/state_space/varlen_selective_scan.mojo
#	max/kernels/src/state_space/varlen_selective_scan_ops.mojo

ulmentflam commented Jun 16, 2026

Contributor Author

@BradLarson Let me know when I can get a review on this. I'm eager to get the Mamba2 kernels integrated, and this is the primary blocking PR for the Mamba2 kernels.
