
PyTorch 2.7 crashes when using flex attention with torch.amp #153042

Open
zbh2047 opened this issue May 7, 2025 · 5 comments

Labels
module: flex attention · module: higher order operators (torch.cond and similar) · module: pt2-dispatcher (PT2 dispatcher-related issues, e.g. aotdispatch, functionalization, faketensor, custom-op) · needs reproduction (Someone else needs to try reproducing the issue given the instructions. No action needed from user) · oncall: pt2

Comments

@zbh2047

zbh2047 commented May 7, 2025

πŸ› Describe the bug

I believe this bug has existed for a long time but has still not been fixed, so I am posting a new issue here.

Basically, flex attention is currently incompatible with torch.amp.autocast. The bug can be reproduced with the following (extremely simple) code:

import torch
import torch.nn as nn
from torch.nn.attention import flex_attention

class MultiheadSelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()  # required before registering submodules on an nn.Module
        assert embed_dim % num_heads == 0
        self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=False)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.num_heads = num_heads

    def forward(self, qkv):
        qkv = self.in_proj(qkv)
        qkv = qkv.view(qkv.size(0), qkv.size(1), 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        output = flex_attention.flex_attention(q, k, v)
        output = output.permute(0, 2, 1, 3)
        output = output.reshape(output.size(0), output.size(1), -1)
        output = self.out_proj(output)
        return output

def test():
    model = MultiheadSelfAttention(64, 4)
    model = model.cuda()
    model = torch.compile(model)
    x = torch.randn((1, 100, 64), dtype=torch.float, device='cuda:0')
    with torch.amp.autocast(device_type='cuda', enabled=True):
        y = model(x).sum()
        y.backward()

test()

The error message is:

RuntimeError: A compilation subprocess exited unexpectedly.

However, if the enabled parameter in torch.amp.autocast(device_type='cuda', enabled=True) is changed to False, the program runs normally without crashing.
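
For reference, here is a minimal sketch of that working variant (same model and x as in test() above, only with autocast disabled):

# Same forward/backward pass as in test(); per the report, this completes
# without crashing once autocast is disabled.
with torch.amp.autocast(device_type='cuda', enabled=False):
    y = model(x).sum()
    y.backward()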

This bug has existed since PyTorch 2.5 and persists in the latest PyTorch 2.7. Similar issues may have been reported before, but they have seen no updates. See #135723 for a related issue.

Versions

Here are the concrete environment settings:

PyTorch version: 2.7.0
Is debug build: False
CUDA used to build PyTorch: 12.3
ROCM used to build PyTorch: N/A

OS: Red Hat Enterprise Linux release 9.2 (Plow) (x86_64)
GCC version: (realm gcc 12.1.0-19) 12.1.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.34

Python version: 3.11.10 (main, Dec 10 2024, 18:31:47) [GCC 12.1.0] (64-bit runtime)
Python platform: Linux-4.18.0-348.23.1.el8.criu_rseq.x86_64-with-glibc2.34
Is CUDA available: True
CUDA runtime version: 12.3.107
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA H100 80GB HBM3
GPU 1: NVIDIA H100 80GB HBM3
GPU 2: NVIDIA H100 80GB HBM3
GPU 3: NVIDIA H100 80GB HBM3

Nvidia driver version: 550.54.15
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 52 bits physical, 57 bits virtual
Byte Order: Little Endian
CPU(s): 104
On-line CPU(s) list: 0-103
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) Platinum 8481C CPU @ 2.70GHz
CPU family: 6
Model: 143
Thread(s) per core: 2
Core(s) per socket: 52
Socket(s): 1
Stepping: 8
BogoMIPS: 5399.9
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 2.4 MiB (52 instances)
L1i cache: 1.6 MiB (52 instances)
L2 cache: 104 MiB (52 instances)
L3 cache: 105 MiB (1 instance)
NUMA node(s): 1
NUMA node0 CPU(s): 0-103

Since the company's desktop is not connected to the Internet, I manually typed out the output of collect_env.py.
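
For reference, the same report can normally be produced with PyTorch's bundled utility; a minimal sketch (equivalent to running python -m torch.utils.collect_env):

# Print PyTorch's standard environment report (the same output collect_env.py produces).
from torch.utils.collect_env import get_pretty_env_info

print(get_pretty_env_info())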

cc @chauhang @penguinwu @zou3519 @ydwu4 @bdhirsh @Chillee @drisspg @yanboliang @BoyuanFeng

@malfet malfet added module: crash (Problem manifests as a hard crash, as opposed to a RuntimeError) and module: flex attention labels May 7, 2025
@pytorch-bot pytorch-bot bot added module: higher order operators (torch.cond and similar), oncall: pt2, and module: pt2-dispatcher (PT2 dispatcher-related issues, e.g. aotdispatch, functionalization, faketensor, custom-op) labels May 7, 2025
@ngimel
Collaborator

ngimel commented May 8, 2025

I can't repro on H100, but it's quite possible that L4 is different

@zbh2047
Author

zbh2047 commented May 8, 2025

I tried it on an H100 and the code still crashes. I've typed out the environment settings above. I hope this is helpful. Thank you!

@ngimel
Collaborator

ngimel commented May 8, 2025

What's your triton version?
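
A quick way to check, assuming Triton is importable in the same environment:

# Print the Triton version installed alongside PyTorch.
import triton

print(triton.__version__)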

@ngimel
Collaborator

ngimel commented May 8, 2025

This

import torch
import torch.nn as nn
from torch.nn.attention import flex_attention

class MultiheadSelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=False)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.num_heads = num_heads

    def forward(self, qkv):
        qkv = self.in_proj(qkv)
        qkv = qkv.view(qkv.size(0), qkv.size(1), 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        output = flex_attention.flex_attention(q, k, v)
        output = output.permute(0, 2, 1, 3)
        output = output.reshape(output.size(0), output.size(1), -1)
        output = self.out_proj(output)
        return output

def test():
    model = MultiheadSelfAttention(64, 4)
    model = model.cuda()
    model = torch.compile(model)
    x = torch.randn((1, 100, 64), dtype=torch.float, device='cuda:0')
    with torch.amp.autocast(device_type='cuda', enabled=True):
        y = model(x).sum()
        y.backward()

test()

runs on H100 + Triton 3.3 + CUDA 12.8 (that's likely not important) + a recent build of PyTorch
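
For comparison, a minimal sketch for printing the corresponding build and device information on your side (standard torch attributes only):

# Report the PyTorch build, the CUDA version it was built against, and the active GPU.
import torch

print(torch.__version__)               # e.g. 2.7.0 or a recent nightly
print(torch.version.cuda)              # CUDA toolkit the build targets, e.g. 12.8
print(torch.cuda.get_device_name(0))   # e.g. "NVIDIA H100 80GB HBM3"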

@drisspg
Contributor

drisspg commented May 8, 2025

For completeness, I tried the repro on nightly and 2.7.0 and was unable to reproduce it.

@drisspg drisspg added needs reproduction (Someone else needs to try reproducing the issue given the instructions. No action needed from user) and removed module: crash (Problem manifests as a hard crash, as opposed to a RuntimeError) labels May 8, 2025