softmax: vectorized fast path is dead-coded off and does not compile (fastmath=True + LLVM cast assertion) #627

@jhinpan

Summary

In kernels/softmax_kernel.py the vectorized fast path (BufferCopy128b, VEC_WIDTH=8) is permanently disabled:

# line ~104
if const_expr(False and N >= tile_cols and N % tile_cols == 0):   # <- `False and` => unreachable

So every softmax shape, even fp16/bf16 with N a multiple of tile_cols (= BLOCK_THREADS * VEC_WIDTH = 2048), runs the generic scalar path (1 element per thread, BufferCopy16b/32b), leaving bandwidth on the table.
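
The effect of the guard can be sketched in plain Python. This is an illustration rather than the kernel's code: `takes_fast_path` is a hypothetical helper, and `BLOCK_THREADS = 256` is an assumption inferred from tile_cols = BLOCK_THREADS * VEC_WIDTH = 2048 with VEC_WIDTH = 8.

```python
# Illustrative sketch only; takes_fast_path and guard_disabled are hypothetical names.
VEC_WIDTH = 8          # elements per thread on the vectorized path (from the issue)
BLOCK_THREADS = 256    # assumption: inferred from tile_cols = 2048 = BLOCK_THREADS * VEC_WIDTH
TILE_COLS = BLOCK_THREADS * VEC_WIDTH


def takes_fast_path(n: int, guard_disabled: bool = True) -> bool:
    """Mirror the compile-time predicate quoted above.

    guard_disabled=True models the current `False and ...` form, which
    short-circuits to False for every N. guard_disabled=False models the
    predicate with `False and` dropped.
    """
    eligible = n >= TILE_COLS and n % TILE_COLS == 0
    if guard_disabled:
        return False and eligible
    return eligible


if __name__ == "__main__":
    for n in (1024, 2048, 3000, 4096):
        print(n, takes_fast_path(n), takes_fast_path(n, guard_disabled=False))
```

With the guard in place no N qualifies; without it, only N that is a whole number of 2048-column tiles does.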

Re-enabling it (dropping False and) surfaces two real failures, which is presumably why it was switched off:

  1. Invalid MLIR fastmath attr (line ~148): exp_val = fmath.exp2(scaled, fastmath=True) passes a Python bool, producing #arith.fastmath<True>

    MLIRError: Unable to parse attribute: error: "#arith.fastmath<True>": expected ::mlir
    

    The generic path (line ~218) correctly uses fastmath=fm_fast (arith.FastMathFlags.fast). Fixing line 148 to fastmath=fm_fast clears this.

  2. LLVM cast assertion (after fix 1): building any fast-path shape then aborts in codegen:

    llvm/lib/IR/Instructions.cpp:3045: CastInst::Create: Assertion `castIsValid(op, S, Ty) && "Invalid cast!"' failed.
    

    i.e. the vectorized load/store/exp2 path has a type mismatch the generic path doesn't.
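
On failure 1: `#arith.fastmath<...>` expects a flag keyword such as `fast`, so a stringified Python bool cannot parse. The sketch below is a hypothetical guard, not FlyDSL or MLIR Python-binding API. It shows one way a wrapper could reject a bool before attribute construction. The flag list follows the upstream arith dialect to my understanding.

```python
# Hypothetical helper for illustration; not part of FlyDSL or the MLIR Python bindings.
ARITH_FASTMATH_FLAGS = frozenset(
    {"none", "reassoc", "nnan", "ninf", "nsz", "arcp", "contract", "afn", "fast"}
)


def render_fastmath_attr(flag) -> str:
    """Return the textual attribute for a flag name, rejecting Python bools."""
    # bool is checked first: str(True) would otherwise yield "#arith.fastmath<True>".
    if isinstance(flag, bool):
        raise TypeError(
            f"fastmath={flag!r} is a bool; pass a flag such as 'fast' instead"
        )
    name = str(flag)
    if name not in ARITH_FASTMATH_FLAGS:
        raise ValueError(f"unknown fastmath flag: {name!r}")
    return f"#arith.fastmath<{name}>"
```

A check like this would turn the late MLIR parse error into an immediate TypeError at the call site.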

Why it matters

Even on the generic scalar path, FlyDSL softmax already beats the baselines on a multi-shape MI350X/gfx950 sweep (kernel-only geomean ~1.13× vs best of torch/aiter-triton/standalone-triton). The vectorized path is pure upside for the bandwidth-bound large-N shapes — once the fastmath attr and the cast bug are fixed.

Repro (MI350X / gfx950, ROCm 7.2)

# after dropping `False and` at line 104 and setting fastmath=fm_fast at line 148:
from kernels.softmax_kernel import build_softmax_module
build_softmax_module(4096, 2048, "bf16")   # N % 2048 == 0 -> fast path -> LLVM Invalid cast! abort
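
Because the LLVM assertion aborts the whole process, it can help to run the repro in a child interpreter and capture its stderr. This is a minimal sketch using only the standard library; `run_isolated` and `REPRO` are hypothetical names, and the snippet assumes the two local edits described above have been applied.

```python
import subprocess
import sys

# Repro from this issue, run as a child process (assumes the local edits are applied).
REPRO = (
    "from kernels.softmax_kernel import build_softmax_module\n"
    "build_softmax_module(4096, 2048, 'bf16')\n"
)


def run_isolated(code: str, timeout: float = 600.0) -> subprocess.CompletedProcess:
    """Run `code` in a fresh interpreter so an abort() cannot kill the caller."""
    return subprocess.run(
        [sys.executable, "-c", code],
        capture_output=True,
        text=True,
        timeout=timeout,
    )


if __name__ == "__main__":
    result = run_isolated(REPRO)
    # On POSIX, a negative return code means the child was killed by a signal.
    print("returncode:", result.returncode)
    print(result.stderr[-2000:])
```

The same wrapper can be looped over several (M, N, dtype) combinations so that one aborting shape does not end the sweep.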

Fix 1 (fastmath) is a one-liner; fix 2 needs a look at the element types used by the vectorized path's load/store/exp2. Happy to help verify a fix against the multi-shape softmax benchmark.

Found via a multi-shape kernel benchmark sweep (FlyDSL softmax vs torch/Triton/AITER).
