Flex attention: batch-index-dependent block mask causes error with changing batch size #152297
Labels
- module: flex attention
- module: higher order operators (torch.cond and similar)
- module: pt2-dispatcher (PT2 dispatcher-related issues, e.g., aotdispatch, functionalization, faketensor, custom-op, …)
- oncall: pt2
- triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
🐛 Describe the bug
I'm trying to do attention with a custom attention mask that depends on the batch index. My square attention mask has the following structure: `n` rows are causal, where `n` is different for each batch index and is specified through a tensor of integers named `cutoffs`. During training, the last batch might be smaller, and this causes an error related to flex attention.

The error goes away if I remove `mode="max-autotune-no-cudagraphs"`. I'm hoping to keep it (or an alternative) because it's the best practice for speedup. Below is a minimal example that reproduces the error:
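The original repro code did not survive extraction, so the following is only a minimal sketch of the setup described above. The name `cutoffs` comes from the issue text; the exact `mask_mod` logic, tensor shapes, and batch sizes are assumptions made for illustration.

```python
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

device = "cuda"
H, S, D = 4, 256, 64  # heads, sequence length, head dim (assumed values)

# Compile with the mode mentioned in the issue.
flex_attention = torch.compile(flex_attention, mode="max-autotune-no-cudagraphs")

def run(batch_size: int) -> torch.Tensor:
    # One cutoff per batch element; the mask therefore depends on the batch index b.
    cutoffs = torch.randint(1, S, (batch_size,), device=device)

    def mask_mod(b, h, q_idx, kv_idx):
        # Illustrative only: causal below the per-batch cutoff, unmasked at/after it.
        return (kv_idx <= q_idx) | (q_idx >= cutoffs[b])

    block_mask = create_block_mask(
        mask_mod, B=batch_size, H=None, Q_LEN=S, KV_LEN=S, device=device
    )
    q, k, v = (
        torch.randn(batch_size, H, S, D, device=device, dtype=torch.float16)
        for _ in range(3)
    )
    return flex_attention(q, k, v, block_mask=block_mask)

run(8)  # full-sized batch works
run(3)  # smaller last batch: this is where the error is reported to appear
```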
Versions
cc @chauhang @penguinwu @zou3519 @ydwu4 @bdhirsh @Chillee @drisspg @yanboliang @BoyuanFeng