[Kernels] causal conv1d optimization #6632
Draft
ulmentflam wants to merge 23 commits into
BEGIN_PUBLIC
[Kernels][GPU] Collapse causal_conv1d, add channel-last support + tests
Collapse the 20 near-duplicate causal_conv1d kernels into 7 parameterized
functions. Bias and packed-sequence (seq_idx) presence are now compile-time
Bool parameters on the CPU and runtime Int8 arguments on the GPU (mirroring the
varlen_causal_conv1d idiom), replacing the hand-copied {bias|no_bias} x
{seq_idx|none} variants. A single stride-driven CPU core serves both
channel-first and channel-last layouts; the two GPU kernels share a
width-generic scalar accumulation over the conv taps. Read-only inputs (x,
weight, bias, seq_idx) are immutable borrows.
Wire channel-last as a first-class registered op (causal_conv1d_channel_last)
and add CPU + GPU channel-last tests, including a seq_idx packed-sequence
masking case. Migrate the existing channel-first forward and update tests
(CPU + GPU) to the unified signatures. This shrinks causal_conv1d.mojo from
3769 to 918 lines with no change to the causal_conv1d / causal_conv1d_update
graph op interfaces.
END_PUBLIC
Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>
Signed-off-by: Evan <[email protected]>
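To illustrate how one stride-driven core can serve both layouts, here is a minimal pure-Python sketch. It is not the Mojo code from this PR: the function names, the flat-buffer addressing, and the tap convention (tap `width - 1` multiplies the current position) are assumptions made for illustration, and `has_bias` is an ordinary argument standing in for the compile-time Bool parameter.

```python
def strides_for(layout, dim, seqlen):
    """Return (stride_b, stride_c, stride_l) for a dense buffer.

    "channel_first" is (B, C, L) with L contiguous;
    "channel_last" is (B, L, C) with C contiguous.
    """
    if layout == "channel_first":
        return dim * seqlen, seqlen, 1
    if layout == "channel_last":
        return seqlen * dim, 1, dim
    raise ValueError(f"unknown layout: {layout}")


def causal_conv1d_strided(x, weight, bias, batch, dim, seqlen, width,
                          strides, has_bias):
    """Causal depthwise conv over a flat buffer addressed through strides.

    weight[c * width + k] is tap k of channel c (assumed convention).
    Positions before the start of the sequence contribute zero.
    """
    stride_b, stride_c, stride_l = strides
    out = [0.0] * (batch * dim * seqlen)
    for b in range(batch):
        for c in range(dim):
            base = b * stride_b + c * stride_c
            for l in range(seqlen):
                acc = bias[c] if has_bias else 0.0
                for k in range(width):
                    src = l - (width - 1) + k
                    if src >= 0:
                        acc += weight[c * width + k] * x[base + src * stride_l]
                out[base + l * stride_l] = acc
    return out


def to_channel_last(x_cf, batch, dim, seqlen):
    """Repack a flat (B, C, L) buffer as a flat (B, L, C) buffer."""
    return [x_cf[b * dim * seqlen + c * seqlen + l]
            for b in range(batch) for l in range(seqlen) for c in range(dim)]


# Same data in both layouts goes through the same core; only strides differ.
B, C, L, W = 1, 2, 4, 2
x_cf = [1.0, 2.0, 3.0, 4.0, 10.0, 20.0, 30.0, 40.0]
weight = [1.0, 1.0, 0.0, 1.0]
bias = [0.5, 0.0]
out_cf = causal_conv1d_strided(x_cf, weight, bias, B, C, L, W,
                               strides_for("channel_first", C, L), True)
out_cl = causal_conv1d_strided(to_channel_last(x_cf, B, C, L), weight, bias,
                               B, C, L, W, strides_for("channel_last", C, L), True)
```

Because the layout is carried entirely by the three strides, the inner loop is identical for both cases, which is the property the commit message relies on to drop the duplicated variants.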
…on GB10
BEGIN_PUBLIC
[Kernels][GPU] Vectorize causal_conv1d channel-first; beat Dao-AILab on GB10
The channel-first forward GPU kernel issued kNElts*kWidth scalar global loads
per thread. The working set (~3 MB at mamba-130m prefill) is L2-resident, so
the kernel was issue-bound on load instructions rather than bandwidth-bound.
Add a vectorized fast path for the dense case (L-contiguous, no seq_idx, full
tile, kNElts-aligned): one width-kNElts vector load for the thread's tile plus
kWidth-1 L2-cached halo scalars, with the convolution fully unrolled at compile
time so every tap/halo index resolves at comptime, and a single width-kNElts
vector store. seq_idx / non-contiguous / ragged-tail inputs fall back to the
existing scalar accumulation, so correctness is unchanged.
The channel-first op now launches kNThreads=64, kNElts=4 (float4) so
64*4 == 256 keeps every thread busy at the common L=256, instead of the
half-idle block that kNThreads=128 (tile 512) leaves. The kernel also gains an
optional block_dim.y channel-fold hook (default 1, backward compatible).
Measured apples-to-apples on GB10 (nsys CUDA-trace, eager, median kernel GPU
duration over 110-220 launches, replicated, clock-drift controlled): the kernel
goes from 5.54 us to 3.17 us (43% faster) and beats the
Dao-AILab/causal-conv1d reference at 3.30 us.
Add GPU tests covering the fast path (widths 1-4, SiLU on/off, multi-block,
and mamba prefill dims) plus a config-sweep benchmark.
END_PUBLIC
Signed-off-by: Evan <[email protected]>
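The fast-path conditions and the block-utilization arithmetic in this commit can be sketched in plain Python. This is only an illustration under stated assumptions: all function names are hypothetical, a list slice stands in for the width-kNElts vector load, and the tap convention (the last tap multiplies the current position) is assumed rather than taken from the kernel source.

```python
def active_thread_fraction(seqlen, n_threads, n_elts):
    """Fraction of launched threads along L that own at least one element."""
    tile = n_threads * n_elts
    n_blocks = -(-seqlen // tile)          # ceil(seqlen / tile)
    threads_needed = -(-seqlen // n_elts)  # ceil(seqlen / n_elts)
    return threads_needed / (n_blocks * n_threads)


def fast_path_ok(seqlen, tile_start, n_elts, l_contiguous, has_seq_idx):
    """Dense case only: contiguous L, no seq_idx, aligned and full tile."""
    return (l_contiguous and not has_seq_idx
            and tile_start % n_elts == 0
            and tile_start + n_elts <= seqlen)


def conv_tile_fast(x_row, w, tile_start, n_elts):
    """One thread's tile: one 'vector' load plus width-1 halo scalars."""
    width = len(w)
    tile = x_row[tile_start:tile_start + n_elts]  # stands in for a vector load
    halo = []
    for j in range(width - 1):
        src = tile_start - (width - 1) + j
        halo.append(x_row[src] if src >= 0 else 0.0)  # zero before sequence start
    window = halo + tile
    # Output i reads window[i .. i + width - 1]; every index is a constant
    # once n_elts and width are fixed, which is what allows full unrolling.
    return [sum(w[k] * window[i + k] for k in range(width))
            for i in range(n_elts)]


def conv_scalar(x_row, w, l):
    """Scalar fallback for a single output position."""
    width = len(w)
    acc = 0.0
    for k in range(width):
        src = l - (width - 1) + k
        if src >= 0:
            acc += w[k] * x_row[src]
    return acc


full = active_thread_fraction(256, 64, 4)    # 64 threads x 4 elements covers L=256
half = active_thread_fraction(256, 128, 4)   # tile of 512 leaves half the block idle
```

With these definitions `full` is 1.0 and `half` is 0.5, matching the "half-idle block" argument for choosing 64x4 at L=256.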
…-last
BEGIN_PUBLIC
[Kernels][GPU] causal_conv1d: bf16/fp16 fast path + coalesced channel-last
Extend the vectorized causal_conv1d work to 16-bit dtypes and rewrite the
channel-last GPU kernel so both layouts beat Dao-AILab/causal-conv1d on GB10.
- Channel-first fast path now accumulates in float32 regardless of storage
  dtype (the CPU reference too), so bf16/fp16 stay within parity tolerance.
  bf16 at mamba prefill (B=1, dim=1536, L=256, W=4): 2.82 us vs Dao 4.06 us
  (31% faster). Dao idles 3/4 of its 128-thread block with kNElts=8 at L=256,
  while MAX stays fully utilized at 64x4.
- Channel-last kernel rewritten to map threads to channels instead of
  positions. Because C is the contiguous axis, neighbouring threads read
  neighbouring memory at each L position, so the per-position loads coalesce
  across the warp. Each thread does a sliding scan over kNElts positions,
  loading each window element once. Result: 3.15 us vs Dao 3.55 us (11%
  faster), down from 14.5 us for the naive vectorization.
- silu is now branchless/width-generic so it applies to SIMD vectors.
- Ops launch the validated configs (channel-first 64x4, channel-last 64x8);
  grids scale with L and dim.
Adds bf16/fp16 and coalesced channel-last GPU tests (widths 1-4, SiLU, mamba
dims, ragged-channel fallback).
END_PUBLIC
Signed-off-by: Evan <[email protected]>
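The thread-to-channel mapping and the sliding scan described above can be sketched as follows. This is a sequential Python model, not the GPU kernel: the function names are hypothetical, a single batch is assumed, the buffer is a flat (L, C) list, and the tap convention (last tap multiplies the current position) is an assumption.

```python
def warp_addresses(l, dim, first_channel, warp_size=32):
    """Flat (L, C) offsets read by one warp at position l, one thread per channel."""
    return [l * dim + first_channel + t for t in range(warp_size)]


def scan_one_thread(x, weight, seqlen, dim, width, n_elts, l_start, channel):
    """One 'thread': a single channel, n_elts consecutive positions.

    The window carries the previous width-1 inputs, so each new position
    costs exactly one load.
    """
    window = []
    for j in range(width - 1):
        src = l_start - (width - 1) + j
        window.append(x[src * dim + channel] if src >= 0 else 0.0)
    out = []
    for i in range(n_elts):
        l = l_start + i
        if l >= seqlen:          # ragged tail along L
            break
        window.append(x[l * dim + channel])  # the only load for this position
        out.append(sum(weight[channel * width + k] * window[k]
                       for k in range(width)))
        window.pop(0)            # slide: drop the oldest element
    return out


def channel_last_forward(x, weight, seqlen, dim, width, n_elts):
    """Run every (tile, channel) 'thread' and scatter results into (L, C)."""
    out = [0.0] * (seqlen * dim)
    for l_start in range(0, seqlen, n_elts):
        for c in range(dim):
            vals = scan_one_thread(x, weight, seqlen, dim, width, n_elts, l_start, c)
            for i, v in enumerate(vals):
                out[(l_start + i) * dim + c] = v
    return out


def reference_channel_last(x, weight, seqlen, dim, width):
    """Direct definition, used only to check the scan."""
    out = [0.0] * (seqlen * dim)
    for l in range(seqlen):
        for c in range(dim):
            acc = 0.0
            for k in range(width):
                src = l - (width - 1) + k
                if src >= 0:
                    acc += weight[c * width + k] * x[src * dim + c]
            out[l * dim + c] = acc
    return out
```

`warp_addresses` returns consecutive offsets for any fixed `l`, which is the coalescing property the commit relies on: with C contiguous, adjacent threads touch adjacent memory.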
…tests
BEGIN_PUBLIC
[Kernels][GPU] causal_conv1d: update-kernel bench + datacenter-scale tests
Add a decode/update GPU benchmark and validate the forward kernels at
larger-model and batched shapes.
- bench_causal_conv1d_update.mojo: times causal_conv1d_update_gpu at decode
  shapes (seqlen=1), batch configurable. Used to confirm parity with
  Dao-AILab/causal_conv1d_update on GB10: at B=128 (datacenter batched decode,
  bandwidth-bound) MAX is 19.6 us vs Dao 20.5 us; at B=1 the kernels are within
  ~0.5 us (both launch-bound and negligible against a full decode step).
- test_causal_conv1d datacenter-scale case: wide channels (dim up to 5120),
  long prefill (L up to 2048), and batched (B=8), fp32 and bf16. Confirms
  correctness and that the launch grids stay within limits
  (dim, batch <= 65535) at scale.
END_PUBLIC
Signed-off-by: Evan <[email protected]>
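For readers unfamiliar with the decode/update shape (seqlen=1), the following Python sketch shows the usual idea: keep a rolling per-channel state of the last `width` inputs and emit one output per step. The state layout (oldest input first), the function names, and the tap convention are assumptions for illustration; the actual causal_conv1d_update_gpu signature is not shown in this PR description.

```python
def conv_update_step(state, x_t, w, bias=0.0):
    """Advance one decode step for a single channel.

    state holds the last len(w) inputs, oldest first, and is updated in place.
    """
    state.pop(0)        # drop the oldest input
    state.append(x_t)   # newest input sits under the last tap
    return bias + sum(wk * sk for wk, sk in zip(w, state))


def decode_sequence(xs, w, bias=0.0):
    """Feed a sequence one token at a time, starting from an all-zero state."""
    state = [0.0] * len(w)
    return [conv_update_step(state, x_t, w, bias) for x_t in xs]


def prefill_reference(xs, w, bias=0.0):
    """Full causal convolution over the whole sequence, for comparison."""
    width = len(w)
    out = []
    for l in range(len(xs)):
        acc = bias
        for k in range(width):
            src = l - (width - 1) + k
            if src >= 0:
                acc += w[k] * xs[src]
        out.append(acc)
    return out


steps = decode_sequence([1.0, 2.0, 3.0, 4.0, 5.0], [0.5, -1.0, 2.0])
```

Stepping the update from a zero state reproduces the prefill convolution position by position, which is the invariant a decode kernel has to preserve.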
BEGIN_PUBLIC
[Kernels][GPU] causal_conv1d: add packed-sequence (seq_idx) GPU tests
The vectorized channel-first fast path and the coalesced channel-last path
both defer to their scalar fallbacks whenever a seq_idx packed-sequence mask is
active, and that fallback was not exercised on GPU. Thread an optional
has_seq_idx through both GPU test harnesses (a two-segment mask split at
seqlen//2) and add seq_idx cases covering widths 1-4 at a small shape and at
mamba dims, validated against the CPU reference computing the same masked
convolution.
END_PUBLIC
Signed-off-by: Evan <[email protected]>
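A small Python sketch of the two-segment mask and the masked convolution it is checked against. The helper names are hypothetical, and the masking rule (a tap contributes only when its source position has the same seq_idx as the output position) is an assumption modelled on the usual packed-sequence semantics, not copied from the test harness.

```python
def two_segment_seq_idx(seqlen):
    """Segment ids for two packed sequences split at seqlen // 2."""
    split = seqlen // 2
    return [0 if l < split else 1 for l in range(seqlen)]


def masked_causal_conv1d(x, w, seq_idx):
    """Causal conv for one channel; taps never cross a segment boundary."""
    width = len(w)
    out = []
    for l in range(len(x)):
        acc = 0.0
        for k in range(width):
            src = l - (width - 1) + k
            if src >= 0 and seq_idx[src] == seq_idx[l]:
                acc += w[k] * x[src]
        out.append(acc)
    return out


x = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
w = [1.0, 1.0]
seq_idx = two_segment_seq_idx(len(x))        # [0, 0, 0, 1, 1, 1]
masked = masked_causal_conv1d(x, w, seq_idx)
unmasked = masked_causal_conv1d(x, w, [0] * len(x))
```

Only the first position of the second segment differs between `masked` and `unmasked` here, because that is the only output whose window reaches back across the split.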
BEGIN_PUBLIC
[Kernels][GPU] causal_conv1d: document the GPU optimizations inline
Expand the channel-first and channel-last forward-kernel docstrings with the
optimization rationale a maintainer needs: the bound each kernel hits
(channel-first issue-bound on loads; channel-last coalescing), what the
vectorized fast path / thread->channel mapping does and why, the measured GB10
result vs Dao-AILab, the config the op selects, float32 accumulation for
bf16/fp16 parity, and a pointer to the full notes (including what regressed).
END_PUBLIC
Signed-off-by: Evan <[email protected]>
…-conv1d-cleanup
# Conflicts:
#	max/kernels/src/state_space/causal_conv1d.mojo
BEGIN_PUBLIC
[Kernels] Apply mojo format to causal_conv1d.mojo
Reformat the output TileTensor parameter declarations to wrap across multiple
lines, matching the Mojo formatter output. Fixes the CI lint (mblack) check
that failed after the upstream merge.
END_PUBLIC
Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>
…-optimization
# Conflicts:
#	max/kernels/benchmarks/BUILD.bazel
#	max/kernels/src/state_space/causal_conv1d.mojo
…ization
Integrate upstream cleanup (UntrackedOrigin NFC) while keeping the optimized
GPU kernel implementations from this branch.
Co-authored-by: Cursor <[email protected]>
…tflam/modular into nightly/causal-conv1d-optimization
Linked issue
Part of #5772
Helps PR #6625
Type of change
Motivation
Tri Dao's Causal Conv1D is fast, but we can do better. @gabrieldemarmiesse made an initial pass to bring the MAX kernels to parity with causal-conv1d. This branch attempts to beat it.
What changed
Summary
Optimizes the causal_conv1d GPU kernels (channel-first and channel-last
forward, and the decode/update kernel) so MAX meets or beats
Dao-AILab/causal-conv1d at mamba prefill/decode shapes on GB10. Channel-first
is now vectorized (issue-bound → vector loads + comptime-unrolled compute);
channel-last was rewritten to a coalesced thread→channel sliding scan (the
naive vectorization was about 4× slower, at 14.5 us); accumulation moved to
float32 so bf16/fp16 stay within parity tolerance.
Benchmarks (GB10, nsys eager kernel-trace median, same-session vs Dao)
Testing
I verified that the Mamba model shows no regressions when running on GB10, and all existing test cases still pass.
Checklist
agreed-upon, or this is a trivial fix that does not need prior approval
smaller PRs where possible (see pull request sizes)
./bazelw run format to format my changes
Assisted-by: trailer in my commit message or this PR description (see AI Tool Use Policy)
nsys_logs.zip
traces.zip