
MPS varying seq len SDPA memory leak #152550


Open

SalmanMohammadi opened this issue Apr 30, 2025 · 2 comments
Labels
module: memory usage PyTorch is using more memory than it should, or it is leaking memory · module: mps Related to Apple Metal Performance Shaders framework · module: sdpa All things related to torch.nn.functional.scaled_dot_product_attention · triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

SalmanMohammadi (Contributor) commented Apr 30, 2025

🐛 Describe the bug

After trying the fix from #152371 (thanks so much for landing this so quickly), I was still seeing memory leaks. I found another issue where memory usage on MPS explodes when the sequence length varies sufficiently across SDPA calls; this does not occur on CUDA.


Reproduction script:

```python
import torch
import torch.nn.functional as F
import sys


def get_memory_stats(device: torch.device):
    if device.type == "mps":
        peak_active = torch.mps.current_allocated_memory()
        peak_alloc = torch.mps.driver_allocated_memory()
        return peak_active, peak_alloc
    elif device.type == "cuda":
        peak_active = torch.cuda.memory_stats().get("active_bytes.all.peak", 0)
        peak_alloc = torch.cuda.max_memory_allocated()
        return peak_active, peak_alloc


def format_bytes(size_bytes):
    """Converts bytes to a readable string (KB, MB, GB)."""
    if size_bytes < 1024:
        return f"{size_bytes} B"
    elif size_bytes < 1024**2:
        return f"{size_bytes / 1024:.2f} KB"
    elif size_bytes < 1024**3:
        return f"{size_bytes / 1024**2:.2f} MB"
    else:
        return f"{size_bytes / 1024**3:.2f} GB"


def run_sdpa_test_single_bs(batch_size, num_iterations, num_heads, head_dim, min_seq_len, max_seq_len, device, dtype):
    actual_max_seq_len = max(max_seq_len, min_seq_len + 1)
    peak_active, peak_alloc = get_memory_stats(device)
    print(f"  Initial Memory: Active={format_bytes(peak_active)}, Alloc={format_bytes(peak_alloc)}")

    for i in range(num_iterations):
        seq_len = torch.randint(min_seq_len, actual_max_seq_len, (1,)).item()

        query = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype)
        key = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype)
        value = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype)
        with torch.no_grad():
            F.scaled_dot_product_attention(query, key, value)  

        peak_active, peak_alloc = get_memory_stats(device)

        if (i + 1) % (num_iterations // 10 or 1) == 0:
            print(f"  Step {i + 1}/{num_iterations}: Active={format_bytes(peak_active)}, Alloc={format_bytes(peak_alloc)}")

    final_peak_active, final_peak_alloc = get_memory_stats(device)
    print(f"  Final Memory: Active={format_bytes(final_peak_active)}, Alloc={format_bytes(final_peak_alloc)}")
    print(f"--- Finished SDPA Test for BS={batch_size}, SeqLen Range=({min_seq_len}-{actual_max_seq_len - 1}) ---")


if __name__ == "__main__":
    batch_size = 4
    num_iterations = 400
    num_heads = 8
    head_dim = 128
    min_seq_len = 128
    max_seq_len = min_seq_len + int(sys.argv[1])
    device = torch.device(sys.argv[2])
    dtype = torch.bfloat16
    run_sdpa_test_single_bs(batch_size, num_iterations, num_heads, head_dim, min_seq_len, max_seq_len, device, dtype)
```

CUDA results:

```
root@cb0c541d80f5:/workspace/axolotl# python ../mem_test.py 128 cuda
  Initial Memory: Active=0 B, Alloc=0 B
  Step 40/400: Active=8.71 MB, Alloc=8.71 MB
  Step 80/400: Active=8.71 MB, Alloc=8.71 MB
  Step 120/400: Active=9.66 MB, Alloc=9.66 MB
  Step 160/400: Active=9.66 MB, Alloc=9.66 MB
  Step 200/400: Active=9.66 MB, Alloc=9.66 MB
  Step 240/400: Active=9.66 MB, Alloc=9.66 MB
  Step 280/400: Active=9.66 MB, Alloc=9.66 MB
  Step 320/400: Active=9.66 MB, Alloc=9.66 MB
  Step 360/400: Active=9.66 MB, Alloc=9.66 MB
  Step 400/400: Active=9.66 MB, Alloc=9.66 MB
  Final Memory: Active=9.66 MB, Alloc=9.66 MB
--- Finished SDPA Test for BS=4, SeqLen Range=(128-255) ---
root@cb0c541d80f5:/workspace/axolotl# python ../mem_test.py 256 cuda
  Initial Memory: Active=0 B, Alloc=0 B
  Step 40/400: Active=12.00 MB, Alloc=12.00 MB
  Step 80/400: Active=12.00 MB, Alloc=12.00 MB
  Step 120/400: Active=13.17 MB, Alloc=13.17 MB
  Step 160/400: Active=13.17 MB, Alloc=13.17 MB
  Step 200/400: Active=13.17 MB, Alloc=13.17 MB
  Step 240/400: Active=13.17 MB, Alloc=13.17 MB
  Step 280/400: Active=13.17 MB, Alloc=13.17 MB
  Step 320/400: Active=13.17 MB, Alloc=13.17 MB
  Step 360/400: Active=13.17 MB, Alloc=13.17 MB
  Step 400/400: Active=13.17 MB, Alloc=13.17 MB
  Final Memory: Active=13.17 MB, Alloc=13.17 MB
--- Finished SDPA Test for BS=4, SeqLen Range=(128-383) ---
root@cb0c541d80f5:/workspace/axolotl# python ../mem_test.py 512 cuda
  Initial Memory: Active=0 B, Alloc=0 B
  Step 40/400: Active=20.78 MB, Alloc=20.78 MB
  Step 80/400: Active=20.78 MB, Alloc=20.78 MB
  Step 120/400: Active=20.78 MB, Alloc=20.78 MB
  Step 160/400: Active=20.78 MB, Alloc=20.78 MB
  Step 200/400: Active=20.78 MB, Alloc=20.78 MB
  Step 240/400: Active=20.78 MB, Alloc=20.78 MB
  Step 280/400: Active=20.78 MB, Alloc=20.78 MB
  Step 320/400: Active=20.78 MB, Alloc=20.78 MB
  Step 360/400: Active=20.78 MB, Alloc=20.78 MB
  Step 400/400: Active=20.78 MB, Alloc=20.78 MB
  Final Memory: Active=20.78 MB, Alloc=20.78 MB
--- Finished SDPA Test for BS=4, SeqLen Range=(128-639) ---
root@cb0c541d80f5:/workspace/axolotl# python ../mem_test.py 2048 cuda
  Initial Memory: Active=0 B, Alloc=0 B
  Step 40/400: Active=67.58 MB, Alloc=67.58 MB
  Step 80/400: Active=67.58 MB, Alloc=67.58 MB
  Step 120/400: Active=67.58 MB, Alloc=67.58 MB
  Step 160/400: Active=67.58 MB, Alloc=67.58 MB
  Step 200/400: Active=67.58 MB, Alloc=67.58 MB
  Step 240/400: Active=67.58 MB, Alloc=67.58 MB
  Step 280/400: Active=67.58 MB, Alloc=67.58 MB
  Step 320/400: Active=67.89 MB, Alloc=67.89 MB
  Step 360/400: Active=67.89 MB, Alloc=67.89 MB
  Step 400/400: Active=68.14 MB, Alloc=68.14 MB
  Final Memory: Active=68.14 MB, Alloc=68.14 MB
--- Finished SDPA Test for BS=4, SeqLen Range=(128-2175) ---
```

MPS Results:

```
> python minimal_test.py 128 mps
  Initial Memory: Active=0 B, Alloc=384.00 KB
  Step 40/400: Active=5.86 MB, Alloc=77.17 MB
  Step 80/400: Active=5.86 MB, Alloc=85.52 MB
  Step 120/400: Active=5.83 MB, Alloc=117.83 MB
  Step 160/400: Active=5.86 MB, Alloc=118.02 MB
  Step 200/400: Active=4.17 MB, Alloc=118.28 MB
  Step 240/400: Active=5.83 MB, Alloc=118.41 MB
  Step 280/400: Active=5.84 MB, Alloc=118.47 MB
  Step 320/400: Active=5.84 MB, Alloc=118.48 MB
  Step 360/400: Active=5.83 MB, Alloc=118.56 MB
  Step 400/400: Active=5.83 MB, Alloc=118.61 MB
  Final Memory: Active=5.83 MB, Alloc=118.61 MB
--- Finished SDPA Test for BS=4, SeqLen Range=(128-255) ---
> python minimal_test.py 256 mps
  Initial Memory: Active=0 B, Alloc=384.00 KB
  Step 40/400: Active=7.81 MB, Alloc=143.22 MB
  Step 80/400: Active=7.81 MB, Alloc=151.73 MB
  Step 120/400: Active=7.81 MB, Alloc=184.08 MB
  Step 160/400: Active=7.81 MB, Alloc=184.47 MB
  Step 200/400: Active=7.81 MB, Alloc=184.77 MB
  Step 240/400: Active=7.81 MB, Alloc=185.03 MB
  Step 280/400: Active=8.11 MB, Alloc=185.28 MB
  Step 320/400: Active=7.81 MB, Alloc=185.50 MB
  Step 360/400: Active=7.81 MB, Alloc=185.78 MB
  Step 400/400: Active=17.01 MB, Alloc=185.88 MB
  Final Memory: Active=17.01 MB, Alloc=185.88 MB
--- Finished SDPA Test for BS=4, SeqLen Range=(128-383) ---
> python minimal_test.py 512 mps
  Initial Memory: Active=0 B, Alloc=384.00 KB
  Step 40/400: Active=5.06 MB, Alloc=1.13 GB
  Step 80/400: Active=17.57 MB, Alloc=1.13 GB
  Step 120/400: Active=15.55 MB, Alloc=1.13 GB
  Step 160/400: Active=10.97 MB, Alloc=1.13 GB
  Step 200/400: Active=7.15 MB, Alloc=1.13 GB
  Step 240/400: Active=15.55 MB, Alloc=1.13 GB
  Step 280/400: Active=10.97 MB, Alloc=1.13 GB
  Step 320/400: Active=17.57 MB, Alloc=1.13 GB
  Step 360/400: Active=10.97 MB, Alloc=1.13 GB
  Step 400/400: Active=17.57 MB, Alloc=1.13 GB
  Final Memory: Active=17.57 MB, Alloc=1.13 GB
--- Finished SDPA Test for BS=4, SeqLen Range=(128-639) ---
```
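One way to probe where the growth lives is to periodically call `torch.mps.empty_cache()`, which returns the MPS allocator's cached blocks to the driver. Below is a minimal sketch of that experiment (the loop is illustrative, mirroring the reproduction above; it is not a confirmed fix, and whether it helps depends on whether the growth sits in the allocator cache or elsewhere):

```python
# Illustrative workaround attempt (not a confirmed fix): release the MPS
# allocator's cached blocks every 50 iterations. If driver_allocated_memory
# keeps growing anyway, the retained memory is likely held outside the
# allocator cache.
import torch
import torch.nn.functional as F

device = torch.device("mps")
for i in range(400):
    seq_len = int(torch.randint(128, 640, (1,)).item())
    q = torch.randn(4, 8, seq_len, 128, device=device, dtype=torch.bfloat16)
    k = torch.randn(4, 8, seq_len, 128, device=device, dtype=torch.bfloat16)
    v = torch.randn(4, 8, seq_len, 128, device=device, dtype=torch.bfloat16)
    with torch.no_grad():
        F.scaled_dot_product_attention(q, k, v)
    if (i + 1) % 50 == 0:
        torch.mps.empty_cache()  # return cached allocator memory to the driver
```

If `driver_allocated_memory` keeps climbing even with the cache emptied, the memory is likely retained elsewhere, e.g. by shape-specialized cached graphs as discussed below.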

Versions

On MPS:

```
Collecting environment information...
PyTorch version: 2.8.0.dev20250430
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 15.1.1 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.3.9.4)
CMake version: Could not collect
Libc version: N/A

Python version: 3.12.8 (main, Jan  5 2025, 06:55:30) [Clang 19.1.6 ] (64-bit runtime)
Python platform: macOS-15.1.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M1 Max

Versions of relevant libraries:
[pip3] numpy==2.2.3
[pip3] pytorch_sphinx_theme==0.0.24
[pip3] torch==2.8.0.dev20250430
[pip3] torchao==0.10.0+cpu
[pip3] torchaudio==2.6.0.dev20250430
[pip3] torchdata==0.11.0
[pip3] torchtune==0.0.0
[pip3] torchvision==0.22.0.dev20250430
[conda] No relevant packages
```

On CUDA:

```
Collecting environment information...
PyTorch version: 2.6.0+cu124
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 4.0.0
Libc version: glibc-2.35

Python version: 3.11.11 (main, Dec 11 2024, 16:28:39) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.4.0-196-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.4.131
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: NVIDIA RTX A6000
GPU 1: NVIDIA RTX A6000
GPU 2: NVIDIA RTX A6000
GPU 3: NVIDIA RTX A6000

Nvidia driver version: 550.127.08
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.9.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv.so.9.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn.so.9.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_graph.so.9.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_heuristic.so.9.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops.so.9.1.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Address sizes:                      48 bits physical, 48 bits virtual
Byte Order:                         Little Endian
CPU(s):                             128
On-line CPU(s) list:                0-127
Vendor ID:                          AuthenticAMD
Model name:                         AMD EPYC 7543 32-Core Processor
CPU family:                         25
Model:                              1
Thread(s) per core:                 2
Core(s) per socket:                 32
Socket(s):                          2
Stepping:                           1
Frequency boost:                    enabled
CPU max MHz:                        2800.0000
CPU min MHz:                        1500.0000
BogoMIPS:                           5599.84
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 invpcid_single hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 invpcid cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr wbnoinvd arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold v_vmsave_vmload vgif umip pku ospke vaes vpclmulqdq rdpid overflow_recov succor smca
Virtualization:                     AMD-V
L1d cache:                          2 MiB (64 instances)
L1i cache:                          2 MiB (64 instances)
L2 cache:                           32 MiB (64 instances)
L3 cache:                           512 MiB (16 instances)
NUMA node(s):                       2
NUMA node0 CPU(s):                  0-31,64-95
NUMA node1 CPU(s):                  32-63,96-127
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Not affected
Vulnerability Retbleed:             Not affected
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Retpolines; IBPB conditional; IBRS_FW; STIBP always-on; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Not affected

Versions of relevant libraries:
[pip3] apollo-torch==1.0.3
[pip3] galore-torch==1.0
[pip3] numpy==2.0.1
[pip3] nvidia-cublas-cu12==12.4.5.8
[pip3] nvidia-cuda-cupti-cu12==12.4.127
[pip3] nvidia-cuda-nvrtc-cu12==12.4.127
[pip3] nvidia-cuda-runtime-cu12==12.4.127
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.2.1.3
[pip3] nvidia-curand-cu12==10.3.5.147
[pip3] nvidia-cusolver-cu12==11.6.1.9
[pip3] nvidia-cusparse-cu12==12.3.1.170
[pip3] nvidia-cusparselt-cu12==0.6.2
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.4.127
[pip3] torch==2.6.0+cu124
[pip3] torch-optimi==0.2.1
[pip3] torchao==0.9.0
[pip3] torchvision==0.21.0+cu124
[pip3] triton==3.2.0
[conda] No relevant packages
```

cc @kulinseth @albanD @malfet @DenisVieriu97 @jhavukainen

Isalia20 added the module: mps (Related to Apple Metal Performance Shaders framework) and module: sdpa (All things related to torch.nn.functional.scaled_dot_product_attention) labels Apr 30, 2025
Isalia20 (Collaborator) commented:

Probably due to attention on MPS not using a flash attention implementation. I plan to implement it as a shader this week if I get some time, and will report back here once I have something.

malfet added the module: memory usage (PyTorch is using more memory than it should, or it is leaking memory) and triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) labels Apr 30, 2025
malfet (Contributor) commented Apr 30, 2025

By the way, I don't think this is specific to SDPA; every MPS op is susceptible to this "bug", since we cache graphs keyed by input tensor shapes... I wonder if a "solution", at least short term (and on macOS 15), could be to make cache entries shape agnostic... or to put a limit on the max number of entries.
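For illustration only, here is a minimal sketch of the failure mode described above and the proposed entry limit; `GraphCache`, `get_or_compile`, and `compile_fn` are hypothetical names, not PyTorch's actual MPS cache implementation:

```python
# Hypothetical sketch of a shape-keyed graph cache (not PyTorch's actual MPS
# code). Keying compiled graphs on exact input shapes means every new sequence
# length compiles and retains a new entry; an LRU bound on the number of
# entries (or shape-agnostic keys) caps the growth.
from collections import OrderedDict

class GraphCache:
    def __init__(self, max_entries=None):
        self.max_entries = max_entries
        self._cache = OrderedDict()  # shape tuple -> compiled graph

    def get_or_compile(self, shapes, compile_fn):
        key = tuple(shapes)
        if key in self._cache:
            self._cache.move_to_end(key)  # mark as most recently used
            return self._cache[key]
        graph = compile_fn(shapes)        # each distinct shape compiles anew
        self._cache[key] = graph
        if self.max_entries is not None and len(self._cache) > self.max_entries:
            self._cache.popitem(last=False)  # evict least recently used entry
        return graph
```

With an unbounded cache, each distinct sequence length retains its own compiled graph, consistent with the steadily growing Alloc numbers in the MPS runs above; a `max_entries` bound (or shape-agnostic keys) would cap that growth.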

pytorchmergebot pushed a commit that referenced this issue May 7, 2025
Partially fixes #139668 and #152550

Still work in progress. Following needs to be addressed:
- [x] Some tests are failing and need to check why and bugfix
- [x] Benchmark the new kernels and add results to this PR for varying sequence lengths and head dimensions (the ones that get dispatched to the kernels)
- [x] Add tests to cover the specialized paths(if applicable)
- [x] Code cleanup

**Tested on Macbook M1 Pro**
### Vector Fast Path (q_len=1, k_len=256)
- Old: 0.378 ms
- New: 0.260 ms
- **31.2% speed improvement**

### Vector 2-pass (q_len=1, k_len=4096)
- Old: 0.627 ms
- New: 0.370 ms
- **41.0% speed improvement**

### Vector Fast Path (q_len=8, k_len=256)
- Old: 0.545 ms
- New: 0.322 ms
- **40.9% speed improvement**

### Vector 2-pass (q_len=8, k_len=4096)
- Old: 1.318 ms
- New: 1.057 ms
- **19.8% speed improvement**

Script to get perf:
```python
import torch
import time

def benchmark_sdpa(config, iterations=100):
    device = config.get("device", "cpu")
    batch = config["batch"]
    heads = config["heads"]
    q_len = config["q_len"]
    k_len = config["k_len"]
    head_dim = config["head_dim"]

    q = torch.randn(batch, heads, q_len, head_dim, device=device, dtype=torch.float32)
    k = torch.randn(batch, heads, k_len, head_dim, device=device, dtype=torch.float32)
    v = torch.randn(batch, heads, k_len, head_dim, device=device, dtype=torch.float32)

    for _ in range(5):
        _ = torch.nn.functional.scaled_dot_product_attention(q, k, v)
        if device == "mps":
            torch.mps.synchronize()

    total_time = 0.0
    for i in range(iterations):
        start = time.perf_counter()
        _ = torch.nn.functional.scaled_dot_product_attention(q, k, v)
        if device == "mps":
            torch.mps.synchronize()
        end = time.perf_counter()
        total_time += end - start

    avg_time = total_time / iterations
    print(f"[{config['name']}] Avg time per run: {avg_time * 1000:.3f} ms over {iterations} iterations")
    return avg_time

def main():
    device = "mps" if torch.backends.mps.is_available() else "cpu"
    print(f"Running benchmarks on device: {device}")

    benchmarks = [
        {
            "name": "Vector Fast - Small q_len & moderate k_len",
            "batch": 1,
            "heads": 8,
            "q_len": 1,      # small query sequence length triggers vector fast path
            "k_len": 256,    # moderate key length
            "head_dim": 64,
            "device": device,
        },
        {
            "name": "Vector 2-pass - Small q_len & long k_len",
            "batch": 1,
            "heads": 8,
            "q_len": 1,      # small query sequence length
            "k_len": 4096,   # long key length triggers the 2-pass variant
            "head_dim": 64,
            "device": device,
        },
        # {
        #     "name": "Full Attention - Moderate q_len/k_len",
        #     "batch": 1,
        #     "heads": 8,
        #     "q_len": 128,    # longer query sequence length
        #     "k_len": 8192,    # matching key length for full attention paths
        #     "head_dim": 64,
        #     "device": device,
        # },
        # {
        #     "name": "Full Attention - Longer q_len/k_len",
        #     "batch": 1,
        #     "heads": 8,
        #     "q_len": 128,    # very long sequence length
        #     "k_len": 8192,
        #     "head_dim": 64,
        #     "device": device,
        # },
    ]

    iterations = 100
    for config in benchmarks:
        benchmark_sdpa(config, iterations=iterations)

if __name__ == "__main__":
    main()

```
Pull Request resolved: #152781
Approved by: https://github.com/malfet