MPS varying seq len SDPA memory leak #152550
Comments
Probably due to the attention on MPS not using a flash-attention implementation. I plan to implement it as a shader this week if I get some time, and will report back here once I have something.
By the way, I don't think this is specific to SDPA; every MPS op is susceptible to this "bug", since we cache graphs keyed by input tensor shapes... I wonder if a "solution", at least short term (and on macOS 15), could be to make cache entries shape-agnostic, or to put a limit on the maximum number of entries.
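For illustration only, here is a minimal Python sketch of the second idea (capping the number of cache entries): a shape-keyed graph cache with LRU eviction. Everything here is hypothetical: `BoundedGraphCache`, `get_or_build`, and `max_entries=256` are made-up names, and PyTorch's actual MPS graph cache lives in the Objective-C++ backend, not in Python.

```python
# Conceptual sketch only (not PyTorch internals): a graph cache keyed by input
# shapes, capped at max_entries, evicting the least-recently-used graph.
from collections import OrderedDict


class BoundedGraphCache:
    def __init__(self, max_entries=256):
        self.max_entries = max_entries
        self._cache = OrderedDict()  # (op_name, input_shapes) -> compiled graph

    def get_or_build(self, op_name, input_shapes, build_fn):
        key = (op_name, tuple(tuple(s) for s in input_shapes))
        if key in self._cache:
            self._cache.move_to_end(key)  # mark as most recently used
            return self._cache[key]
        graph = build_fn()  # compile a graph specialized to these shapes
        self._cache[key] = graph
        if len(self._cache) > self.max_entries:
            self._cache.popitem(last=False)  # drop the least recently used entry
        return graph
```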
Partially fixes #139668 and #152550. Still work in progress. The following needs to be addressed:
- [x] Some tests are failing; need to check why and bugfix
- [x] Benchmark the new kernels and add results to this PR for varying sequence lengths and head dimensions (the ones that get dispatched to these kernels)
- [x] Add tests to cover the specialized paths (if applicable)
- [x] Code cleanup

**Tested on MacBook M1 Pro**

### Vector Fast Path (q_len=1, k_len=256)
- Old: 0.378 ms
- New: 0.260 ms
- **31.2% speed improvement**

### Vector 2-pass (q_len=1, k_len=4096)
- Old: 0.627 ms
- New: 0.370 ms
- **41.0% speed improvement**

### Vector Fast Path (q_len=8, k_len=256)
- Old: 0.545 ms
- New: 0.322 ms
- **40.9% speed improvement**

### Vector 2-pass (q_len=8, k_len=4096)
- Old: 1.318 ms
- New: 1.057 ms
- **19.8% speed improvement**

Script to get perf:
```python
import torch
import time


def benchmark_sdpa(config, iterations=100):
    device = config.get("device", "cpu")
    batch = config["batch"]
    heads = config["heads"]
    q_len = config["q_len"]
    k_len = config["k_len"]
    head_dim = config["head_dim"]

    q = torch.randn(batch, heads, q_len, head_dim, device=device, dtype=torch.float32)
    k = torch.randn(batch, heads, k_len, head_dim, device=device, dtype=torch.float32)
    v = torch.randn(batch, heads, k_len, head_dim, device=device, dtype=torch.float32)

    # Warmup runs before timing
    for _ in range(5):
        _ = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    if device == "mps":
        torch.mps.synchronize()

    total_time = 0.0
    for i in range(iterations):
        start = time.perf_counter()
        _ = torch.nn.functional.scaled_dot_product_attention(q, k, v)
        if device == "mps":
            torch.mps.synchronize()
        end = time.perf_counter()
        total_time += end - start

    avg_time = total_time / iterations
    print(f"[{config['name']}] Avg time per run: {avg_time * 1000:.3f} ms over {iterations} iterations")
    return avg_time


def main():
    device = "mps" if torch.backends.mps.is_available() else "cpu"
    print(f"Running benchmarks on device: {device}")

    benchmarks = [
        {
            "name": "Vector Fast - Small q_len & moderate k_len",
            "batch": 1,
            "heads": 8,
            "q_len": 1,    # small query sequence length triggers vector fast path
            "k_len": 256,  # moderate key length
            "head_dim": 64,
            "device": device,
        },
        {
            "name": "Vector 2-pass - Small q_len & long k_len",
            "batch": 1,
            "heads": 8,
            "q_len": 1,     # small query sequence length
            "k_len": 4096,  # long key length triggers the 2-pass variant
            "head_dim": 64,
            "device": device,
        },
        # {
        #     "name": "Full Attention - Moderate q_len/k_len",
        #     "batch": 1,
        #     "heads": 8,
        #     "q_len": 128,   # longer query sequence length
        #     "k_len": 8192,  # matching key length for full attention paths
        #     "head_dim": 64,
        #     "device": device,
        # },
        # {
        #     "name": "Full Attention - Longer q_len/k_len",
        #     "batch": 1,
        #     "heads": 8,
        #     "q_len": 128,   # very long sequence length
        #     "k_len": 8192,
        #     "head_dim": 64,
        #     "device": device,
        # },
    ]

    iterations = 100
    for config in benchmarks:
        benchmark_sdpa(config, iterations=iterations)


if __name__ == "__main__":
    main()
```

Pull Request resolved: #152781
Approved by: https://github.com/malfet
🐛 Describe the bug
After trying the fix from #152371 (thanks so much for landing this so quickly), I was still seeing memory leaks. I found another issue where memory usage on MPS explodes when the sequence length passed to SDPA varies sufficiently - this does not occur with CUDA.
Reproduction script:
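The original reproduction script did not survive in this copy of the issue. Below is a minimal sketch of the reported pattern, assuming the setup described above: SDPA on MPS called with a different sequence length on every iteration, with `torch.mps.driver_allocated_memory()` used to watch memory growth. The shapes and loop bounds are illustrative, not the reporter's.

```python
# Illustrative sketch (not the original script): vary the sequence length on
# every SDPA call on MPS and print driver-allocated memory as it grows.
import torch

device = "mps"
heads, head_dim = 8, 64

for step, seq_len in enumerate(range(64, 64 + 512)):
    q = torch.randn(1, heads, seq_len, head_dim, device=device)
    k = torch.randn(1, heads, seq_len, head_dim, device=device)
    v = torch.randn(1, heads, seq_len, head_dim, device=device)
    torch.nn.functional.scaled_dot_product_attention(q, k, v)
    torch.mps.synchronize()
    if step % 64 == 0:
        mem_mb = torch.mps.driver_allocated_memory() / 1024**2
        print(f"seq_len={seq_len:5d}  driver memory: {mem_mb:.1f} MB")
```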
CUDA results:
MPS results:
Versions
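(Environment reports in this format are typically produced with `python -m torch.utils.collect_env`.)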
On MPS:
```
Collecting environment information...
PyTorch version: 2.8.0.dev20250430
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 15.1.1 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.3.9.4)
CMake version: Could not collect
Libc version: N/A

Python version: 3.12.8 (main, Jan 5 2025, 06:55:30) [Clang 19.1.6 ] (64-bit runtime)
Python platform: macOS-15.1.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU: Apple M1 Max

Versions of relevant libraries:
[pip3] numpy==2.2.3
[pip3] pytorch_sphinx_theme==0.0.24
[pip3] torch==2.8.0.dev20250430
[pip3] torchao==0.10.0+cpu
[pip3] torchaudio==2.6.0.dev20250430
[pip3] torchdata==0.11.0
[pip3] torchtune==0.0.0
[pip3] torchvision==0.22.0.dev20250430
[conda] No relevant packages
```
On CUDA:
cc @kulinseth @albanD @malfet @DenisVieriu97 @jhavukainen