Thanks to visit codestin.com
Credit goes to github.com

Skip to content

[Kernels] causal conv1d optimization#6632

Draft
ulmentflam wants to merge 23 commits into
modular:mainfrom
ulmentflam:nightly/causal-conv1d-optimization
Draft

[Kernels] causal conv1d optimization#6632
ulmentflam wants to merge 23 commits into
modular:mainfrom
ulmentflam:nightly/causal-conv1d-optimization

Conversation

@ulmentflam

Copy link
Copy Markdown
Contributor

Linked issue

Part of #5772
Helps PR #6625

Type of change

  • Bug fix (non-breaking change that fixes an issue)
  • Performance improvement (includes benchmark results below)
  • Documentation update
  • New feature or public API (requires prior proposal or issue approval)
  • Refactor / internal cleanup (no user-visible change)
  • Build, CI, or tooling change

Motivation

Tri Dao's Causal Conv1D is fast, but we can do better. @gabrieldemarmiesse made an initial pas to bring to parity with causal conv1d. This branch makes an attempt to beat causal conv1d.

What changed

Summary

Optimizes the causal_conv1d GPU kernels (channel-first + channel-last forward,
and the decode/update kernel) so MAX meets or beats Dao-AILab/causal-conv1d at
mamba prefill/decode shapes on GB10. Channel-first is now vectorized
(issue-bound → vector loads + comptime-unrolled compute); channel-last was
rewritten to a coalesced thread→channel sliding scan (was 4× slower);
accumulation moved to float32 so bf16/fp16 stay within parity.

Benchmarks (GB10, nsys eager kernel-trace median, same-session vs Dao)

Path dtype shape MAX Dao Result
channel-first fwd fp32 L=256 3.17 µs 3.30 µs MAX 4% faster
channel-first fwd bf16 L=256 2.82 µs 4.06 µs MAX 31% faster
channel-last fwd fp32 L=256 3.15 µs 3.55 µs MAX 11% faster
update (decode) fp32 B=128 19.6 µs 20.5 µs MAX 4% faster
update (decode) fp32 B=1 1.70 µs 1.22 µs Dao faster (tiny, launch-bound)

Testing

I asserted no regressions of the Mamba model running on GB10. I also kept test cases in the green.

Checklist

  • The linked issue above has been reviewed by a maintainer and is
    agreed-upon, or this is a trivial fix that does not need prior
    approval
  • PR is small and focused — I've split larger changes into a sequence of
    smaller PRs where possible (see
    pull request sizes)
  • I ran ./bazelw run format to format my changes
  • I added or updated tests to cover my changes
  • If AI tools assisted with this contribution, I have included an
    Assisted-by: trailer in my commit message or this PR description (see
    AI Tool Use Policy)

nsys_logs.zip
traces.zip

ulmentflam and others added 7 commits May 29, 2026 14:50
BEGIN_PUBLIC
[Kernels][GPU] Collapse causal_conv1d, add channel-last support + tests

Collapse the 20 near-duplicate causal_conv1d kernels into 7 parameterized
functions. Bias and packed-sequence (seq_idx) presence are now compile-time
Bool parameters on the CPU and runtime Int8 arguments on the GPU (mirroring the
varlen_causal_conv1d idiom), replacing the hand-copied {bias|no_bias} x
{seq_idx|none} variants. A single stride-driven CPU core serves both
channel-first and channel-last layouts; the two GPU kernels share a
width-generic scalar accumulation over the conv taps. Read-only inputs (x,
weight, bias, seq_idx) are immutable borrows.

Wire channel-last as a first-class registered op (causal_conv1d_channel_last)
and add CPU + GPU channel-last tests, including a seq_idx packed-sequence
masking case. Migrate the existing channel-first forward and update tests
(CPU + GPU) to the unified signatures. This shrinks causal_conv1d.mojo from
3769 to 918 lines with no change to the causal_conv1d / causal_conv1d_update
graph op interfaces.
END_PUBLIC

Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>
Signed-off-by: Evan <[email protected]>
…on GB10

BEGIN_PUBLIC
[Kernels][GPU] Vectorize causal_conv1d channel-first; beat Dao-AILab on GB10

The channel-first forward GPU kernel issued kNElts*kWidth scalar global
loads per thread. The working set (~3 MB at mamba-130m prefill) is L2-
resident, so the kernel was issue-bound on load instructions rather than
bandwidth-bound.

Add a vectorized fast path for the dense case (L-contiguous, no seq_idx,
full tile, kNElts-aligned): one width-kNElts vector load for the thread's
tile plus kWidth-1 L2-cached halo scalars, with the convolution fully
unrolled at compile time so every tap/halo index resolves at comptime, and
a single width-kNElts vector store. seq_idx / non-contiguous / ragged-tail
inputs fall back to the existing scalar accumulation, so correctness is
unchanged. The channel-first op now launches kNThreads=64, kNElts=4 (float4)
so 64*4 == 256 keeps every thread busy at the common L=256, instead of the
half-idle block that kNThreads=128 (tile 512) leaves. The kernel also gains
an optional block_dim.y channel-fold hook (default 1, backward compatible).

Measured apples-to-apples on GB10 (nsys CUDA-trace, eager, median kernel
GPU duration over 110-220 launches, replicated, clock-drift controlled):
the kernel goes from 5.54 us to 3.17 us (43% faster) and beats the
Dao-AILab/causal-conv1d reference at 3.30 us.

Add GPU tests covering the fast path (widths 1-4, SiLU on/off, multi-block,
and mamba prefill dims) plus a config-sweep benchmark.
END_PUBLIC

Signed-off-by: Evan <[email protected]>
…-last

BEGIN_PUBLIC
[Kernels][GPU] causal_conv1d: bf16/fp16 fast path + coalesced channel-last

Extend the vectorized causal_conv1d work to 16-bit dtypes and rewrite the
channel-last GPU kernel so both layouts beat Dao-AILab/causal-conv1d on GB10.

- Channel-first fast path now accumulates in float32 regardless of storage
  dtype (the CPU reference too), so bf16/fp16 stay within parity tolerance.
  bf16 at mamba prefill (B=1, dim=1536, L=256, W=4): 2.82 us vs Dao 4.06 us
  (31% faster) — Dao idles 3/4 of its 128-thread block with kNElts=8 at
  L=256, while MAX stays fully utilized at 64x4.

- Channel-last kernel rewritten to map threads to channels instead of
  positions. Because C is the contiguous axis, neighbouring threads read
  neighbouring memory at each L position, so the per-position loads coalesce
  across the warp. Each thread does a sliding scan over kNElts positions,
  loading each window element once. Result: 3.15 us vs Dao 3.55 us (11%
  faster), down from 14.5 us for the naive vectorization.

- silu is now branchless/width-generic so it applies to SIMD vectors.

- Ops launch the validated configs (channel-first 64x4, channel-last 64x8);
  grids scale with L and dim. Adds bf16/fp16 and coalesced channel-last GPU
  tests (widths 1-4, SiLU, mamba dims, ragged-channel fallback).
END_PUBLIC

Signed-off-by: Evan <[email protected]>
…tests

BEGIN_PUBLIC
[Kernels][GPU] causal_conv1d: update-kernel bench + datacenter-scale tests

Add a decode/update GPU benchmark and validate the forward kernels at
larger-model and batched shapes.

- bench_causal_conv1d_update.mojo: times causal_conv1d_update_gpu at decode
  shapes (seqlen=1), batch configurable. Used to confirm parity with
  Dao-AILab/causal_conv1d_update on GB10: at B=128 (datacenter batched decode,
  bandwidth-bound) MAX is 19.6 us vs Dao 20.5 us; at B=1 the kernels are within
  ~0.5 us (both launch-bound and negligible against a full decode step).

- test_causal_conv1d datacenter-scale case: wide channels (dim up to 5120),
  long prefill (L up to 2048), and batched (B=8), fp32 and bf16. Confirms
  correctness and that the launch grids stay within limits (dim, batch <=
  65535) at scale.
END_PUBLIC

Signed-off-by: Evan <[email protected]>
BEGIN_PUBLIC
[Kernels][GPU] causal_conv1d: add packed-sequence (seq_idx) GPU tests

The vectorized channel-first fast path and the coalesced channel-last path both
defer to their scalar fallbacks whenever a seq_idx packed-sequence mask is
active, and that fallback was not exercised on GPU. Thread an optional
has_seq_idx through both GPU test harnesses (a two-segment mask split at
seqlen//2) and add seq_idx cases covering widths 1-4 at a small shape and at
mamba dims, validated against the CPU reference computing the same masked
convolution.
END_PUBLIC

Signed-off-by: Evan <[email protected]>
BEGIN_PUBLIC
[Kernels][GPU] causal_conv1d: document the GPU optimizations inline

Expand the channel-first and channel-last forward-kernel docstrings with the
optimization rationale a maintainer needs: the bound each kernel hits
(channel-first issue-bound on loads; channel-last coalescing), what the
vectorized fast path / thread->channel mapping does and why, the measured GB10
result vs Dao-AILab, the config the op selects, float32 accumulation for
bf16/fp16 parity, and a pointer to the full notes (including what regressed).
END_PUBLIC

Signed-off-by: Evan <[email protected]>
@ulmentflam ulmentflam requested a review from a team as a code owner May 30, 2026 20:48
@ulmentflam ulmentflam changed the title Nightly/causal conv1d optimization [Kernels] causal conv1d optimization May 30, 2026
@ulmentflam ulmentflam marked this pull request as draft May 30, 2026 21:27
ulmentflam and others added 7 commits June 11, 2026 22:06
…-conv1d-cleanup

# Conflicts:
#	max/kernels/src/state_space/causal_conv1d.mojo
BEGIN_PUBLIC
[Kernels] Apply mojo format to causal_conv1d.mojo

Reformat the output TileTensor parameter declarations to wrap across
multiple lines, matching the Mojo formatter output. Fixes the CI lint
(mblack) check that failed after the upstream merge.
END_PUBLIC

Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>
…-optimization

# Conflicts:
#	max/kernels/benchmarks/BUILD.bazel
#	max/kernels/src/state_space/causal_conv1d.mojo
…ization

Integrate upstream cleanup (UntrackedOrigin NFC) while keeping the
optimized GPU kernel implementations from this branch.

Co-authored-by: Cursor <[email protected]>
…tflam/modular into nightly/causal-conv1d-optimization
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant