PyTorch 2.7 crashes when using flex attention with torch.amp #153042
Labels: module: flex attention, module: higher order operators, module: pt2-dispatcher, needs reproduction, oncall: pt2
🐛 Describe the bug
I believe this bug has existed for a long time but is still not fixed, so I am posting a new issue here.
Basically, the current flex attention is incompatible with torch.amp.autocast. The bug can be reproduced with the following (extremely simple) code:
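(The original snippet was not preserved in this report; below is a minimal sketch of what such a reproduction typically looks like. The tensor shapes and the explicit torch.compile wrapper are assumptions.)

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

# Compiling flex_attention is the usual way to run it with fused kernels;
# the crash occurs during this compilation when autocast is enabled.
compiled_flex_attention = torch.compile(flex_attention)

# Standard (batch, heads, seq_len, head_dim) layout; shapes are illustrative.
q, k, v = (torch.randn(1, 8, 1024, 64, device="cuda") for _ in range(3))

with torch.amp.autocast(device_type="cuda", enabled=True):
    out = compiled_flex_attention(q, k, v)  # crashes during compilation
```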
The error message is: `RuntimeError: A compilation subprocess exited unexpectedly.`
However, if we change the `enabled` parameter in `with torch.amp.autocast(device_type='cuda', enabled=True)` to `False`, the program runs normally without crashing. This bug has existed since PyTorch 2.5 and persists in the latest PyTorch 2.7. Similar issues may have been reported before, but there have been no updates; see #135723 for a relevant one.
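For reference, a sketch of the working variant, assuming the same compiled flex_attention setup as in the reproduction above:

```python
# Same setup as above, but with autocast disabled: this runs without crashing.
with torch.amp.autocast(device_type="cuda", enabled=False):
    out = compiled_flex_attention(q, k, v)
```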
Versions
Since my company's desktop is not connected to the Internet, I manually typed the output of collect_env.py.
cc @chauhang @penguinwu @zou3519 @ydwu4 @bdhirsh @Chillee @drisspg @yanboliang @BoyuanFeng