Flex attention strides #152683
Labels
module: flex attention
module: higher order operators
torch.cond and similar
module: pt2-dispatcher
PT2 dispatcher-related issues (e.g., aotdispatch, functionalization, faketensor, custom-op)
oncall: pt2
triaged
This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
🐛 The doc issue
The FlexAttention docs nicely document the expected shapes of the inputs, but say nothing about their strides. In contrast, cuDNN, for example, documents that strides can be chosen freely except that the last dimension must be contiguous. Knowing which stride layouts are supported is important, since that determines, e.g., whether the Q, K, and V projections can be merged into a single matmul.
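To illustrate why this matters, here is a hedged sketch (the shapes and fused weight are illustrative, not from the docs): fusing the QKV projections into one matmul produces `q`, `k`, and `v` as views of a single buffer, so they are non-contiguous with only the last dimension having stride 1, i.e. exactly the layout cuDNN permits. Whether `flex_attention` accepts such views is the question the docs currently leave open.

```python
import torch

B, H, S, D = 2, 4, 128, 64  # batch, heads, sequence length, head dim (arbitrary)

x = torch.randn(B, S, H * D)
w_qkv = torch.randn(3 * H * D, H * D)  # fused QKV projection weight (illustrative)

qkv = x @ w_qkv.T                                     # single matmul: (B, S, 3*H*D)
qkv = qkv.view(B, S, 3, H, D).permute(2, 0, 3, 1, 4)  # (3, B, H, S, D)
q, k, v = qkv.unbind(0)                               # views, each (B, H, S, D)

print(q.is_contiguous())  # False: strided view into the fused buffer
print(q.stride(-1))       # 1: last dim is contiguous, as cuDNN requires
```

If FlexAttention only supports contiguous inputs, each of these views would need a `.contiguous()` copy before the call, which is exactly the cost the fused matmul was meant to avoid.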
Also (and independently), `return_lse` is missing from the output documentation.

Suggest a potential alternative/fix
No response
cc @chauhang @penguinwu @zou3519 @ydwu4 @bdhirsh @Chillee @drisspg @yanboliang @BoyuanFeng