Summary
In `kernels/softmax_kernel.py` the vectorized fast path (BufferCopy128b, VEC_WIDTH=8) is permanently disabled:

    # line ~104
    if const_expr(False and N >= tile_cols and N % tile_cols == 0):  # <- `False and` => unreachable

So every softmax shape, even N a multiple of `tile_cols` (= BLOCK_THREADS*VEC_WIDTH = 2048) in fp16/bf16, runs the generic scalar path (1 element/thread, BufferCopy16b/32b), leaving bandwidth on the table.
Re-enabling it (dropping `False and`) surfaces two real failures, which is presumably why it was switched off:
1. Invalid MLIR fastmath attr (line ~148): `exp_val = fmath.exp2(scaled, fastmath=True)` passes a Python bool, producing `#arith.fastmath<True>`, which is not a valid fastmath attribute. The generic path (line ~218) correctly uses `fastmath=fm_fast` (arith.FastMathFlags.fast). Changing line 148 to `fastmath=fm_fast` clears this.
2. LLVM cast assertion (after fix 1): building any fast-path shape then aborts in codegen with an LLVM `Invalid cast!` assertion, i.e. the vectorized load/store/exp2 path has a type mismatch the generic path doesn't.
Why it matters
Even on the generic scalar path, FlyDSL softmax already beats the baselines on a multi-shape MI350X/gfx950 sweep (kernel-only geomean ~1.13× vs best of torch/aiter-triton/standalone-triton). The vectorized path is pure upside for the bandwidth-bound large-N shapes once the fastmath attr and the cast bug are fixed.
Repro (MI350X / gfx950, ROCm 7.2)

    # after dropping `False and` at line 104 and setting fastmath=fm_fast at line 148:
    from kernels.softmax_kernel import build_softmax_module
    build_softmax_module(4096, 2048, "bf16")  # N % 2048 == 0 -> fast path -> LLVM Invalid cast! abort

Fix 1 (fastmath) is a one-liner; fix 2 needs a look at the vectorized path's load/store/exp2 element types. Happy to help verify a fix against the multi-shape softmax benchmark.
Found via a multi-shape kernel benchmark sweep (FlyDSL softmax vs torch/Triton/AITER).
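To make it easier to tell which shapes would hit the vectorized path once the guard is restored, the condition quoted from line ~104 can be mirrored in plain Python. This is only a sketch, not FlyDSL code: `takes_fast_path` is a hypothetical helper name, and `BLOCK_THREADS = 256` is an assumption inferred from tile_cols = 2048 with VEC_WIDTH = 8.

```python
# Hypothetical helper mirroring the guard at line ~104 with `False and` dropped.
# It does not import or call FlyDSL; it only reproduces the shape arithmetic.

VEC_WIDTH = 8        # elements per thread on the vectorized path (from the issue)
BLOCK_THREADS = 256  # assumption: inferred from tile_cols = 2048 and VEC_WIDTH = 8
TILE_COLS = BLOCK_THREADS * VEC_WIDTH  # 2048, as stated in the issue


def takes_fast_path(n: int, tile_cols: int = TILE_COLS) -> bool:
    """Return True if row length n satisfies N >= tile_cols and N % tile_cols == 0."""
    return n >= tile_cols and n % tile_cols == 0


if __name__ == "__main__":
    # Report which path each row length would select under the restored guard.
    for n in (1024, 2048, 3000, 4096, 8192):
        path = "vectorized" if takes_fast_path(n) else "generic scalar"
        print(f"N={n} -> {path}")
```

Under these assumptions, the repro shape (N = 2048) satisfies the guard, which is consistent with it reaching the failing vectorized codegen, while row lengths that are not multiples of 2048 would stay on the generic path.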