
Flex attention: batch-index-dependent block mask causes error with changing batch size #152297


Closed
zhihanyang2022 opened this issue Apr 28, 2025 · 1 comment
Labels: module: flex attention, module: higher order operators, module: pt2-dispatcher, oncall: pt2, triaged


zhihanyang2022 commented Apr 28, 2025

πŸ› Describe the bug

I'm trying to compute attention with a custom attention mask that depends on the batch index.

My square attention mask has the following structure:

  • The first n rows are causal.
  • All remaining rows are bidirectional.

n differs for each batch index and is given by the integer tensor cutoffs.
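For checking results, the same mask can be materialized densely with plain tensor broadcasting. This is an illustration, not part of the original report: `dense_mixed_diffusion_mask` is a hypothetical helper that mirrors the `mixed_diffusion_mask` mask_mod in the reproduction below, and its output is passed as a boolean `attn_mask` to `F.scaled_dot_product_attention` (where `True` means "may attend") to get a CPU reference output.

```python
import torch
import torch.nn.functional as F

def dense_mixed_diffusion_mask(cutoffs: torch.Tensor, seq_len: int) -> torch.Tensor:
    """Hypothetical helper: dense [B, seq_len, seq_len] version of the mask.

    Entry [b, q, kv] is True when query q may attend to key kv in batch element b.
    """
    pos = torch.arange(seq_len, device=cutoffs.device)
    q_idx = pos.view(1, seq_len, 1)                   # [1, Q, 1]
    kv_idx = pos.view(1, 1, seq_len)                  # [1, 1, KV]
    causal = q_idx >= kv_idx                          # rows before the cutoff: causal
    bidirectional = q_idx >= cutoffs.view(-1, 1, 1)   # rows at/after cutoffs[b]: see every key
    return causal | bidirectional                     # broadcasts to [B, Q, KV]

cutoffs = torch.tensor([2, 0])   # batch 0: first 2 rows causal; batch 1: fully bidirectional
mask = dense_mixed_diffusion_mask(cutoffs, seq_len=4)

q = k = v = torch.randn(2, 8, 4, 32)   # [B, H, L, D]; CPU is enough for a reference
# Add a head dimension so the mask broadcasts to [B, H, Q, KV].
ref = F.scaled_dot_product_attention(q, k, v, attn_mask=mask[:, None])
```

For cutoffs = [2, 0], the first element's mask is causal in rows 0 and 1 and all True in rows 2 and 3, while the second element's mask is all True. Every row keeps its diagonal, so no row is fully masked.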

During training, the last batch may be smaller than the others. Calling the compiled function on that smaller batch fails with an Inductor lowering error inside flex attention.

The error goes away if I remove mode="max-autotune-no-cudagraphs". I'd like to keep it (or an equivalent), because it is the recommended mode for getting the best speedup.
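If dropping a few samples per epoch is acceptable, one general way to avoid the smaller final batch altogether is the standard `drop_last=True` option of `torch.utils.data.DataLoader`, so the compiled function only ever sees one batch size. This is a suggestion, not something from the thread, and the toy dataset below is made up purely to show the effect on batch sizes.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy dataset: 300 samples do not divide evenly by 256, so the last batch has 44 rows.
dataset = TensorDataset(torch.randn(300, 128, 32))

ragged = DataLoader(dataset, batch_size=256)                  # keeps the short final batch
uniform = DataLoader(dataset, batch_size=256, drop_last=True)  # discards it

print([batch[0].shape[0] for batch in ragged])   # [256, 44]
print([batch[0].shape[0] for batch in uniform])  # [256]
```

The trade-off is that the dropped samples are skipped for that epoch; with shuffling enabled, a different subset is dropped each time.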

Below is a minimal example of the error:

```python
from torch.nn.attention.flex_attention import flex_attention, create_block_mask
import torch
import torch.nn.functional as F

# Flags required to enable jit fusion kernels
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)

@torch.compile(fullgraph=True, mode="max-autotune-no-cudagraphs")
def fused_flex_attention(q, k, v, mask=None):
    return flex_attention(q, k, v, block_mask=mask)

def create_mixed_diffusion_mask(cutoffs):
    def mixed_diffusion_mask(b, h, q_idx, kv_idx):
        causal = q_idx >= kv_idx                # causal part
        block_identity = q_idx >= cutoffs[b]    # rows at or after cutoffs[b] attend to every key
        return causal | block_identity
    return mixed_diffusion_mask

large_batch_size = 256
large_qkv = torch.randn(large_batch_size, 8, 3, 128, 32).cuda()
large_cutoffs = torch.randint(0, 128, (large_batch_size,)).cuda()

small_batch_size = 64
small_qkv = torch.randn(small_batch_size, 8, 3, 128, 32).cuda()
small_cutoffs = torch.randint(0, 128, (small_batch_size,)).cuda()

# First call (batch size 256) compiles and runs.
block_mask = create_block_mask(create_mixed_diffusion_mask(large_cutoffs), B=large_batch_size, H=None, Q_LEN=128, KV_LEN=128)
fused_flex_attention(large_qkv[:, :, 0], large_qkv[:, :, 1], large_qkv[:, :, 2], mask=block_mask)

# Second call (batch size 64) triggers a recompile, which fails with the traceback below.
block_mask = create_block_mask(create_mixed_diffusion_mask(small_cutoffs), B=small_batch_size, H=None, Q_LEN=128, KV_LEN=128)
fused_flex_attention(small_qkv[:, :, 0], small_qkv[:, :, 1], small_qkv[:, :, 2], mask=block_mask)
```
```
Traceback (most recent call last):
  File "/share/thickstun/zhihan/ELMO/test_flex_attention_2.py", line 38, in <module>
    fused_flex_attention(small_qkv[:, :, 0], small_qkv[:, :, 1], small_qkv[:, :, 2], mask=block_mask)
  File "/share/thickstun/zhihan/.conda/bd3lm/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 663, in _fn
    raise e.remove_dynamo_frames() from None  # see TORCHDYNAMO_VERBOSE=1
  File "/share/thickstun/zhihan/.conda/bd3lm/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 760, in _compile_fx_inner
    raise InductorError(e, currentframe()).with_traceback(
  File "/share/thickstun/zhihan/.conda/bd3lm/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 745, in _compile_fx_inner
    mb_compiled_graph = fx_codegen_and_compile(
  File "/share/thickstun/zhihan/.conda/bd3lm/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 1293, in fx_codegen_and_compile
    return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)
  File "/share/thickstun/zhihan/.conda/bd3lm/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 1119, in codegen_and_compile
    graph.run(*example_inputs)
  File "/share/thickstun/zhihan/.conda/bd3lm/lib/python3.9/site-packages/torch/_inductor/graph.py", line 877, in run
    return super().run(*args)
  File "/share/thickstun/zhihan/.conda/bd3lm/lib/python3.9/site-packages/torch/fx/interpreter.py", line 171, in run
    self.env[node] = self.run_node(node)
  File "/share/thickstun/zhihan/.conda/bd3lm/lib/python3.9/site-packages/torch/_inductor/graph.py", line 1527, in run_node
    result = super().run_node(n)
  File "/share/thickstun/zhihan/.conda/bd3lm/lib/python3.9/site-packages/torch/fx/interpreter.py", line 240, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "/share/thickstun/zhihan/.conda/bd3lm/lib/python3.9/site-packages/torch/_inductor/graph.py", line 1198, in call_function
    raise LoweringException(e, target, args, kwargs).with_traceback(
  File "/share/thickstun/zhihan/.conda/bd3lm/lib/python3.9/site-packages/torch/_inductor/graph.py", line 1188, in call_function
    out = lowerings[target](*args, **kwargs)  # type: ignore[index]
  File "/share/thickstun/zhihan/.conda/bd3lm/lib/python3.9/site-packages/torch/_inductor/lowering.py", line 465, in wrapped
    out = decomp_fn(*args, **kwargs)
  File "/share/thickstun/zhihan/.conda/bd3lm/lib/python3.9/site-packages/torch/_inductor/kernel/flex_attention.py", line 1533, in flex_attention
    autotune_select_algorithm(
  File "/share/thickstun/zhihan/.conda/bd3lm/lib/python3.9/site-packages/torch/_inductor/select_algorithm.py", line 2344, in autotune_select_algorithm
    return _ALGORITHM_SELECTOR_CACHE(*args, **kwargs)
  File "/share/thickstun/zhihan/.conda/bd3lm/lib/python3.9/site-packages/torch/_inductor/select_algorithm.py", line 1734, in __call__
    inputs_key = create_inputs_key(input_nodes)
  File "/share/thickstun/zhihan/.conda/bd3lm/lib/python3.9/site-packages/torch/_inductor/select_algorithm.py", line 1624, in create_inputs_key
    return repr([AlgorithmSelectorCache.key_of(x) for x in input_nodes])
  File "/share/thickstun/zhihan/.conda/bd3lm/lib/python3.9/site-packages/torch/_inductor/select_algorithm.py", line 1624, in <listcomp>
    return repr([AlgorithmSelectorCache.key_of(x) for x in input_nodes])
  File "/share/thickstun/zhihan/.conda/bd3lm/lib/python3.9/site-packages/torch/_inductor/select_algorithm.py", line 2306, in key_of
    node.get_device().type,
torch._inductor.exc.InductorError: LoweringException: AttributeError: 'Symbol' object has no attribute 'get_device'
  target: flex_attention
  args[0]: TensorBox(StorageBox(
    InputBuffer(name='arg2_1', layout=FixedLayout('cuda:0', torch.float32, size=[s1, 8, 128, 32], stride=[8*s2, s2, 32, 1]))
  ))
  args[1]: TensorBox(StorageBox(
    InputBuffer(name='arg4_1', layout=FixedLayout('cuda:0', torch.float32, size=[s1, 8, 128, 32], stride=[8*s2, s2, 32, 1]))
  ))
  args[2]: TensorBox(StorageBox(
    InputBuffer(name='arg7_1', layout=FixedLayout('cuda:0', torch.float32, size=[s5, 8, 128, 32], stride=[8*s2, s2, 32, 1]))
  ))
  args[3]: Subgraph(name='sdpa_score0', graph_module=<lambda>(), graph=None)
  args[4]: (128, 128, TensorBox(StorageBox(
    InputBuffer(name='arg11_1', layout=FixedLayout('cuda:0', torch.int32, size=[s8, 1, 1], stride=[1, 1, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='arg9_1', layout=FixedLayout('cuda:0', torch.int32, size=[s7, 1, 1, 1], stride=[1, 1, 1, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='arg15_1', layout=FixedLayout('cuda:0', torch.int32, size=[s10, 1, 1], stride=[1, 1, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='arg17_1', layout=FixedLayout('cuda:0', torch.int32, size=[s11, 1, 1, 1], stride=[1, 1, 1, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='arg19_1', layout=FixedLayout('cuda:0', torch.int32, size=[s12, 1, 1], stride=[1, 1, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='arg21_1', layout=FixedLayout('cuda:0', torch.int32, size=[s13, 1, 1, 1], stride=[1, 1, 1, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='arg23_1', layout=FixedLayout('cuda:0', torch.int32, size=[s14, 1, 1], stride=[1, 1, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='arg25_1', layout=FixedLayout('cuda:0', torch.int32, size=[s15, 1, 1, 1], stride=[1, 1, 1, 1]))
  )), 128, 128, Subgraph(name='sdpa_mask0', graph_module=<lambda>(), graph=None))
  args[5]: 0.17677669529663687
  args[6]: {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True}
  args[7]: ()
  args[8]: (s9, TensorBox(StorageBox(
    InputBuffer(name='arg13_1', layout=FixedLayout('cuda:0', torch.int64, size=[s9], stride=[1]))
  )))

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
```
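Another possible workaround, if every sample must be kept, is to pad the smaller batch up to the batch size of the first call and slice the output afterwards, so the compiled function keeps seeing the same shapes. This is a sketch under assumptions: `pad_to_batch` is a hypothetical helper, and it has not been verified here that keeping shapes static avoids this particular lowering error. Padded rows cost extra compute, but their outputs are discarded and attention never mixes batch elements, so their contents and cutoffs do not affect the real rows.

```python
import torch

def pad_to_batch(t: torch.Tensor, batch_size: int) -> torch.Tensor:
    """Hypothetical helper: zero-pad `t` along dim 0 to exactly `batch_size` rows."""
    missing = batch_size - t.shape[0]
    if missing < 0:
        raise ValueError(f"batch of {t.shape[0]} exceeds target {batch_size}")
    if missing == 0:
        return t
    filler = t.new_zeros((missing, *t.shape[1:]))  # same dtype and device as t
    return torch.cat([t, filler], dim=0)

# Sketch of use with the reproduction above (needs CUDA, so left as comments):
# q, k, v = (pad_to_batch(small_qkv[:, :, i], large_batch_size) for i in range(3))
# cutoffs = pad_to_batch(small_cutoffs, large_batch_size)  # padded rows get cutoff 0
# block_mask = create_block_mask(create_mixed_diffusion_mask(cutoffs),
#                                B=large_batch_size, H=None, Q_LEN=128, KV_LEN=128)
# out = fused_flex_attention(q, k, v, mask=block_mask)[:small_batch_size]
```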

Versions

PyTorch version: 2.7.0.dev20250308+cu126
Is debug build: False
CUDA used to build PyTorch: 12.6
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: 10.0.0-4ubuntu1 
CMake version: version 3.16.3
Libc version: glibc-2.31

Python version: 3.9.21 (main, Dec 11 2024, 16:24:11)  [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.4.0-205-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 12.4.99
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA RTX 6000 Ada Generation
Nvidia driver version: 530.30.02
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Byte Order:                         Little Endian
Address sizes:                      52 bits physical, 57 bits virtual
CPU(s):                             384
On-line CPU(s) list:                0-383
Thread(s) per core:                 2
Core(s) per socket:                 96
Socket(s):                          2
NUMA node(s):                       2
Vendor ID:                          AuthenticAMD
CPU family:                         25
Model:                              17
Model name:                         AMD EPYC 9654 96-Core Processor
Stepping:                           1
Frequency boost:                    enabled
CPU MHz:                            1479.987
CPU max MHz:                        2400.0000
CPU min MHz:                        1500.0000
BogoMIPS:                           4800.06
Virtualization:                     AMD-V
L1d cache:                          6 MiB
L1i cache:                          6 MiB
L2 cache:                           192 MiB
L3 cache:                           768 MiB
NUMA node0 CPU(s):                  0-95,192-287
NUMA node1 CPU(s):                  96-191,288-383
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Not affected
Vulnerability Retbleed:             Not affected
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Enhanced / Automatic IBRS; IBPB conditional; STIBP always-on; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Not affected
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 invpcid_single hw_pstate ssbd mba ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local avx512_bf16 clzero irperf xsaveerptr wbnoinvd arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid overflow_recov succor smca flush_l1d

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.6.4.1
[pip3] nvidia-cuda-cupti-cu12==12.6.80
[pip3] nvidia-cuda-nvrtc-cu12==12.6.77
[pip3] nvidia-cuda-runtime-cu12==12.6.77
[pip3] nvidia-cudnn-cu12==9.5.1.17
[pip3] nvidia-cufft-cu12==11.3.0.4
[pip3] nvidia-curand-cu12==10.3.7.77
[pip3] nvidia-cusolver-cu12==11.7.1.2
[pip3] nvidia-cusparse-cu12==12.5.4.2
[pip3] nvidia-cusparselt-cu12==0.6.3
[pip3] nvidia-nccl-cu12==2.25.1
[pip3] nvidia-nvjitlink-cu12==12.6.85
[pip3] nvidia-nvtx-cu12==12.6.77
[pip3] pytorch-lightning==2.5.1.post0
[pip3] pytorch-triton==3.2.0+git4b3bb1f8
[pip3] torch==2.7.0.dev20250308+cu126
[pip3] torchaudio==2.6.0.dev20250308+cu126
[pip3] torchmetrics==1.6.2
[pip3] torchvision==0.22.0.dev20250308+cu126
[pip3] triton==3.2.0
[conda] cuda-cudart               12.4.127             h99ab3db_0  
[conda] cuda-cudart-dev           12.4.127             h99ab3db_0  
[conda] cuda-cudart-dev_linux-64  12.4.127             hd681fbe_0  
[conda] cuda-cudart-static        12.4.127             h99ab3db_0  
[conda] cuda-cudart-static_linux-64 12.4.127             hd681fbe_0  
[conda] cuda-cudart_linux-64      12.4.127             hd681fbe_0  
[conda] cuda-cupti                12.4.127             h6a678d5_1  
[conda] cuda-cupti-dev            12.4.127             h6a678d5_1  
[conda] cuda-libraries            12.4.1               h06a4308_1  
[conda] cuda-libraries-dev        12.4.1               h06a4308_1  
[conda] cuda-libraries-static     12.4.1               h06a4308_1  
[conda] cuda-nvrtc                12.4.127             h99ab3db_1  
[conda] cuda-nvrtc-dev            12.4.127             h99ab3db_1  
[conda] cuda-nvrtc-static         12.4.127             h99ab3db_1  
[conda] cuda-nvtx                 12.4.127             h6a678d5_1  
[conda] cuda-opencl               12.4.127             h6a678d5_0  
[conda] cuda-opencl-dev           12.4.127             h6a678d5_0  
[conda] libcublas                 12.4.5.8             h99ab3db_1  
[conda] libcublas-dev             12.4.5.8             h99ab3db_1  
[conda] libcublas-static          12.4.5.8             h99ab3db_1  
[conda] libcufft                  11.2.1.3             h99ab3db_1  
[conda] libcufft-dev              11.2.1.3             h99ab3db_1  
[conda] libcufft-static           11.2.1.3             h99ab3db_1  
[conda] libcurand                 10.3.5.147           h99ab3db_1  
[conda] libcurand-dev             10.3.5.147           h99ab3db_1  
[conda] libcurand-static          10.3.5.147           h99ab3db_1  
[conda] libcusolver               11.6.1.9             h99ab3db_1  
[conda] libcusolver-dev           11.6.1.9             h99ab3db_1  
[conda] libcusolver-static        11.6.1.9             h99ab3db_1  
[conda] libcusparse               12.3.1.170           h99ab3db_1  
[conda] libcusparse-dev           12.3.1.170           h99ab3db_1  
[conda] libcusparse-static        12.3.1.170           h99ab3db_1  
[conda] libnvjitlink              12.4.127             h99ab3db_1  
[conda] libnvjitlink-dev          12.4.127             h99ab3db_1  
[conda] libnvjitlink-static       12.4.127             h99ab3db_1  
[conda] numpy                     1.26.4                   pypi_0    pypi
[conda] nvidia-cublas-cu12        12.6.4.1                 pypi_0    pypi
[conda] nvidia-cuda-cupti-cu12    12.6.80                  pypi_0    pypi
[conda] nvidia-cuda-nvrtc-cu12    12.6.77                  pypi_0    pypi
[conda] nvidia-cuda-runtime-cu12  12.6.77                  pypi_0    pypi
[conda] nvidia-cudnn-cu12         9.5.1.17                 pypi_0    pypi
[conda] nvidia-cufft-cu12         11.3.0.4                 pypi_0    pypi
[conda] nvidia-curand-cu12        10.3.7.77                pypi_0    pypi
[conda] nvidia-cusolver-cu12      11.7.1.2                 pypi_0    pypi
[conda] nvidia-cusparse-cu12      12.5.4.2                 pypi_0    pypi
[conda] nvidia-cusparselt-cu12    0.6.3                    pypi_0    pypi
[conda] nvidia-nccl-cu12          2.25.1                   pypi_0    pypi
[conda] nvidia-nvjitlink-cu12     12.6.85                  pypi_0    pypi
[conda] nvidia-nvtx-cu12          12.6.77                  pypi_0    pypi
[conda] pytorch-lightning         2.5.1.post0              pypi_0    pypi
[conda] pytorch-triton            3.2.0+git4b3bb1f8          pypi_0    pypi
[conda] torch                     2.7.0.dev20250308+cu126          pypi_0    pypi
[conda] torchaudio                2.6.0.dev20250308+cu126          pypi_0    pypi
[conda] torchmetrics              1.6.2                    pypi_0    pypi
[conda] torchvision               0.22.0.dev20250308+cu126          pypi_0    pypi
[conda] triton                    3.2.0                    pypi_0    pypi

cc @chauhang @penguinwu @zou3519 @ydwu4 @bdhirsh @Chillee @drisspg @yanboliang @BoyuanFeng

@pytorch-bot added the module: higher order operators, oncall: pt2, and module: pt2-dispatcher labels on Apr 28, 2025
@eellison added the triaged label on Apr 28, 2025
drisspg (Contributor) commented May 5, 2025

I just tried to reproduce this with torch==2.8.0.dev20250505+cu128 and was unable to. I think this has recently been fixed. Feel free to reopen if that's not the case.

@drisspg closed this as completed on May 5, 2025