[Kernels] Causal conv1d cleanup #6628
Open
ulmentflam wants to merge 11 commits into
BEGIN_PUBLIC
[Kernels][GPU] Collapse causal_conv1d, add channel-last support + tests
Collapse the 20 near-duplicate causal_conv1d kernels into 7 parameterized
functions. Bias and packed-sequence (seq_idx) presence are now compile-time
Bool parameters on the CPU and runtime Int8 arguments on the GPU (mirroring the
varlen_causal_conv1d idiom), replacing the hand-copied {bias|no_bias} x
{seq_idx|none} variants. A single stride-driven CPU core serves both
channel-first and channel-last layouts; the two GPU kernels share a
width-generic scalar accumulation over the conv taps. Read-only inputs (x,
weight, bias, seq_idx) are immutable borrows.
Wire channel-last as a first-class registered op (causal_conv1d_channel_last)
and add CPU + GPU channel-last tests, including a seq_idx packed-sequence
masking case. Migrate the existing channel-first forward and update tests
(CPU + GPU) to the unified signatures. This shrinks causal_conv1d.mojo from
3769 to 918 lines with no change to the causal_conv1d / causal_conv1d_update
graph op interfaces.
END_PUBLIC
Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>
Signed-off-by: Evan <[email protected]>
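The stride-driven core described in the commit message can be illustrated with a small reference sketch. This is plain Python rather than Mojo, and it is not the code from this PR: the function name `causal_conv1d_ref`, its argument order, and the seq_idx masking rule (a tap is skipped when it comes from a position whose sequence id differs from the output position's) are assumptions made for illustration. In the PR, bias and seq_idx presence are compile-time parameters on the CPU; here they are ordinary optional arguments.

```python
from typing import List, Optional, Sequence, Tuple


def causal_conv1d_ref(
    x: Sequence[float],
    weight: Sequence[Sequence[float]],
    batch: int,
    channels: int,
    seqlen: int,
    strides: Tuple[int, int, int],
    bias: Optional[Sequence[float]] = None,
    seq_idx: Optional[Sequence[Sequence[int]]] = None,
) -> List[float]:
    """Hypothetical reference for a depthwise causal conv1d over a flat buffer.

    strides = (batch, channel, time) element strides, so the same loop
    handles channel-first (B, C, L) and channel-last (B, L, C) data.
    weight[c][k] is tap k of channel c; the last tap aligns with position t.
    The output is written with the same strides as the input.
    """
    sb, sc, st = strides
    width = len(weight[0])
    out = [0.0] * len(x)
    for b in range(batch):
        for c in range(channels):
            for t in range(seqlen):
                acc = bias[c] if bias is not None else 0.0
                for k in range(width):
                    src = t - (width - 1) + k
                    if src < 0:
                        continue  # causal zero padding on the left
                    # Assumed masking rule: ignore taps from another packed sequence.
                    if seq_idx is not None and seq_idx[b][src] != seq_idx[b][t]:
                        continue
                    acc += weight[c][k] * x[b * sb + c * sc + src * st]
                out[b * sb + c * sc + t * st] = acc
    return out


# One batch, two channels, four time steps, conv width 2.
w = [[1.0, 1.0], [0.0, 1.0]]
x_channel_first = [1, 2, 3, 4, 10, 20, 30, 40]  # (B, C, L), strides (C*L, L, 1)
x_channel_last = [1, 10, 2, 20, 3, 30, 4, 40]   # (B, L, C), strides (L*C, 1, C)

y_first = causal_conv1d_ref(x_channel_first, w, 1, 2, 4, (8, 4, 1))
y_last = causal_conv1d_ref(x_channel_last, w, 1, 2, 4, (8, 1, 2))
```

Only the stride tuple differs between the two calls, which is the property that lets a single core serve both layouts. Passing `seq_idx=[[0, 0, 1, 1]]` to either call would stop the first position of the second packed sequence from reading the last position of the first.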
@gabrieldemarmiesse This should be good to go, and it is the version you should base your further optimizations on.
…-conv1d-cleanup
# Conflicts:
#	max/kernels/src/state_space/causal_conv1d.mojo
BEGIN_PUBLIC
[Kernels] Apply mojo format to causal_conv1d.mojo
Reformat the output TileTensor parameter declarations to wrap across multiple
lines, matching the Mojo formatter output. Fixes the CI lint (mblack) check
that failed after the upstream merge.
END_PUBLIC
Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>
@BradLarson Let me know when I can get a review on this. This refactor is worth getting in sooner rather than later: it cuts about 3k lines of code from the previous permuted version and moves us to the newer syntax.
Linked issue
Part of #5772
Helps PR #6625
Type of change
Motivation
Causal Conv had about 3k LOC that needed to be cleaned and deduplicated.
What changed
The causal_conv1d kernels (max/kernels/src/state_space/causal_conv1d.mojo) and the ops that call them.
Testing
Verified that mamba1 has no regressions, and added CPU and GPU tests for the channel-last layout.
Checklist
agreed-upon, or this is a trivial fix that does not need prior approval
smaller PRs where possible (see pull request sizes)
./bazelw run format to format my changes
Assisted-by: trailer in my commit message or this PR description (see AI Tool Use Policy)