[Kernels] Selective Scan Ops Refactor and d_state support for 32, 64, 128, and 256#6622
Open
ulmentflam wants to merge 12 commits into
Open
[Kernels] Selective Scan Ops Refactor and d_state support for 32, 64, 128, and 256#6622ulmentflam wants to merge 12 commits into
ulmentflam wants to merge 12 commits into
Conversation
…upport
BEGIN_PUBLIC
[Kernels] Refactor selective_scan ops and add d_state 32/64/128/256 support
Refactors the state-space selective_scan op wrappers and generalizes the
state dimension from the Mamba1-only d_state=16 to arbitrary d_state, adding
support for d_state in {32, 64, 128, 256} so Mamba2/Mamba3 and hybrid models
can share the same kernels.
* Renames the state dimension to `d_state` throughout the ops and kernels
(mathematically correct naming; behavior-preserving in the kernel body).
* Reworks selective_scan_ops and varlen_selective_scan_ops to thread the
generalized strides/layout structs and dispatch on d_state.
* Keeps the existing decode/forward kernels' numerics unchanged.
This is the refactor base; the decode-kernel performance work
(coalesce/vectorize/parallelize selective_scan_update) stacks on top in a
follow-up PR.
END_PUBLIC
Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>
Signed-off-by: Evan <[email protected]>
73f461a to
bc9c04b
Compare
10 tasks
…ed list BEGIN_PUBLIC [Kernels][NFC] Drive selective_scan d_state dispatch from one supported list The selective_scan op wrappers validated and dispatched d_state with hand-rolled 8-branch if/elif chains repeated across every Args struct (three in selective_scan_ops, two in varlen_selective_scan_ops) plus _validate_d_state -- each chain enumerating the supported d_state values and re-emitting the same "Unsupported d_state" error. Introduce a single `_SUPPORTED_D_STATE` list per file as the source of truth and iterate it with `comptime for`: * `_validate_d_state` loops the list instead of a `!=` chain. * every `dispatch_for_d_state` loops the list, calling `run_cpu[ds]` / `launch_gpu[ds]` for the matching value. * the duplicated error is factored into `_unsupported_d_state_error`. The varlen list keeps its extra d_state=4 entry, so supported sets are unchanged. The `comptime for` unrolls to exactly the prior per-value instantiations, so dispatch behavior is identical. Verified: CPU 10/10 + 9/9 and GPU 12/12 + 5/5 selective_scan / varlen goldens pass. END_PUBLIC Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]> Signed-off-by: Evan <[email protected]>
BEGIN_PUBLIC [Kernels][NFC] Derive selective_scan op layouts from to_tile_tensor Every selective_scan / varlen op execute() recomputed a per-tensor `comptime <name>_LT = TileLayout[shape_types=..._shape_types, stride_types=..._stride_types]` for all 11-19 operands, then passed those aliases to the Args struct. That `TileLayout[...]` is exactly the `LayoutType` that `to_tile_tensor()` already declares for the matching `<name>_tt`, so the aliases just restated it. Drop the alias blocks (53 across the two files) and pass `<name>_tt.LayoutType` directly at the Args instantiation, matching the idiom already used in causal_conv1d_ops. Removes the now-unused `TileLayout` import. Pure refactor; the layout types are identical. Verified: CPU 10/10 + 9/9 and GPU 12/12 + 5/5 selective_scan / varlen goldens pass. END_PUBLIC Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]> Signed-off-by: Evan <[email protected]>
…n-update-dstate-128 # Conflicts: # max/kernels/src/state_space/selective_scan.mojo # max/kernels/src/state_space/varlen_selective_scan.mojo # max/kernels/src/state_space/varlen_selective_scan_ops.mojo
Contributor
Author
|
@BradLarson Let me know when I can get a review on this. I'm eager to get the Mamba2 kernels integrated, and this is the primary blocking PR for the Mamba2 kernels. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Linked issue
Part of #5772
Type of change
Motivation
The
selective_scan_op.mojowas messy. It has to dispatch a kernel perd_state. Mamba2 and Mamba3 will require the selective scan kernel to support state sizes of 32, 64, and 128. I updated the kernels to support this dimension size. I also refactored the selective scan op to better leverage structs and hide some of the branchy execution logic.What changed
The
selective_scan_ops.mojo,selective_scan.mojo,varlen_selective_scan.mojo, andvarlen_selective_scan_ops.mojofiles all changed, with the addition of multiple new structs, helper functions, and@fieldwise_initdecorators. I also updated the selective_scan kernels with aMAX_D_STATEto support the new stride deployments and refactored the D_STATE.Testing
I ran the full test suite, and verified the output of the mamba-130m model on each new d_state. This was tested on an Nvidia DGX spark.
Checklist
agreed-upon, or this is a trivial fix that does not need prior
approval
smaller PRs where possible (see
pull request sizes)
./bazelw run formatto format my changesAssisted-by:trailer in my commit message or this PR description (seeAI Tool Use Policy)
Follow-up (stacked): the decode-kernel performance work — coalesce/vectorize/parallelize
selective_scan_update— is in #6633, which stacks on this PR and should merge after it.