Try not to fail when there should be memory available#2869

Merged
awni merged 3 commits into ml-explore:main from awni:improve_cuda_allocator on Dec 7, 2025

Conversation

@awni (Member) commented on Dec 4, 2025

There are a couple of cases where MLX can fail with an OOM error even when memory is actually available:

  • Sometimes cudaMallocAsync returns nullptr even when there should be enough RAM outside the cache + used memory. I believe this is due to fragmentation. Instead of failing in this case, we free buffers from the cache and then try again.
  • Sometimes kernel / graph execution fails due to OOM (I'm very curious what could cause that). If the OS-reported free memory is below a limit, we clear buffers from the cache if possible.
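The first recovery path can be sketched as follows. All names and signatures here are hypothetical mocks illustrating the "free from the cache, then retry" pattern, not MLX's actual allocator code:

```cpp
#include <cstddef>
#include <functional>

// Hypothetical sketch: `raw_alloc` stands in for cudaMallocAsync, and
// `release_cached` stands in for releasing buffers held in the buffer
// cache back to the pool.
void* alloc_with_retry(
    const std::function<void*(size_t)>& raw_alloc,
    const std::function<void(size_t)>& release_cached,
    size_t size) {
  void* buf = raw_alloc(size);
  if (buf == nullptr) {
    // The pool may be fragmented even though cache + used memory is
    // under the limit: return cached buffers and try once more.
    release_cached(size);
    buf = raw_alloc(size);
  }
  return buf;  // still nullptr here means a genuine OOM
}
```

The key design point is that the cache release only happens on the failure path, so the common case pays no extra cost.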

@awni awni force-pushed the improve_cuda_allocator branch from 5eb48e1 to 158accb on December 4, 2025 00:54
@awni (Member, Author) commented on Dec 4, 2025

I think there is a performance issue here, so I'm moving this to draft.

@awni awni marked this pull request as draft December 4, 2025 03:40
@awni (Member, Author) commented on Dec 4, 2025

Evidently calling CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total)); in malloc is a bad idea :(.
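A common mitigation for an expensive driver query on a hot path is to cache the result and only refresh it at a bounded rate. This is a sketch of that idea with hypothetical names, not necessarily what MLX ended up doing:

```cpp
#include <chrono>
#include <cstddef>
#include <functional>

// Wraps an expensive query (e.g. a cudaMemGetInfo call) so that it is
// re-issued at most once per `interval`; otherwise the cached value is
// returned.
class ThrottledQuery {
 public:
  ThrottledQuery(std::function<size_t()> query,
                 std::chrono::milliseconds interval)
      : query_(std::move(query)), interval_(interval) {}

  size_t get() {
    auto now = std::chrono::steady_clock::now();
    if (!valid_ || now - last_ >= interval_) {
      cached_ = query_();
      last_ = now;
      valid_ = true;
    }
    return cached_;
  }

 private:
  std::function<size_t()> query_;
  std::chrono::milliseconds interval_;
  std::chrono::steady_clock::time_point last_{};
  size_t cached_ = 0;
  bool valid_ = false;
};
```

The trade-off is staleness: the cached free-memory figure can lag behind reality by up to one interval.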

@awni awni force-pushed the improve_cuda_allocator branch 3 times, most recently from c4b20cd to 7a09a4a on December 4, 2025 19:37
@awni awni force-pushed the improve_cuda_allocator branch 2 times, most recently from c5ff2e2 to 3a4b72d on December 5, 2025 23:51
@awni awni force-pushed the improve_cuda_allocator branch from 3a4b72d to b3d6566 on December 5, 2025 23:52
@awni (Member, Author) commented on Dec 5, 2025

OK, I fixed this and I don't see a performance regression.

I think the basic premise of what is happening is this: even though the MLX cache + active memory is well under the limit, the pool is fragmented, and since we use async frees, the device is not always able to return memory to the OS before the next call to malloc. So CUDA can fail an allocation even when the total amount of free memory exceeds the requested size.
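The second failure mode described in the PR description (execution-time OOM) can be sketched the same way. Again, these names are hypothetical mocks of the pattern, not MLX's code:

```cpp
#include <cstddef>
#include <functional>

// If running a kernel/graph fails with OOM and the OS reports that free
// memory is below a fixed limit, clear cached buffers and retry once.
bool exec_with_oom_recovery(
    const std::function<bool()>& run,        // false signals an OOM failure
    const std::function<size_t()>& os_free,  // OS-reported free bytes
    const std::function<void()>& clear_cache,
    size_t free_limit) {
  if (run()) {
    return true;
  }
  if (os_free() < free_limit) {
    clear_cache();  // give the pool some room, then retry the launch
    return run();
  }
  return false;  // OOM with plenty of free memory: something else is wrong
}
```

Gating the cache clear on the OS-reported free memory avoids throwing away warm buffers when the failure was not actually caused by memory pressure.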

@awni awni marked this pull request as ready for review December 5, 2025 23:55
@awni awni requested review from angeloskath and zcbenz December 5, 2025 23:56
@awni (Member, Author) commented on Dec 5, 2025

I uploaded a script that reproduces the issue on B200 (and should on H100 with a smaller batch size). Just leaving it here for reference.

run.py

@zcbenz (Collaborator) left a comment:

Looks good to me!

return loc;
}
#else
int cuda_mem_loc(int i) {
@zcbenz (Collaborator) commented:

nitpick: add inline.

size_t used = 0;
CHECK_CUDA_ERROR(cudaMemPoolGetAttribute(
p, cudaMemPoolAttrReservedMemCurrent, &used));
if (used > (total_memory_ - free_limit_)) {
@zcbenz (Collaborator) commented:

Why have a free_limit_? The code would read more easily for me if it were just:

if (used > memory_limit_) {
  buffer_cache_.release_cached_buffers(total_memory_ - memory_limit_);
}

@awni (Member, Author) commented:

Good question. memory_limit_ can change (the user can set the memory limit to be higher or lower). I wanted a value that was fixed based on the total device memory.

@zcbenz (Collaborator) commented:

What do you think about using hard_memory_limit_/soft_memory_limit_? (Just nitpicking, I'm good with free_limit_ too.)

@awni (Member, Author) commented:

I don't really love hard_memory_limit because it's not a hard limit.

It's more like a soft memory limit on the underlying CUDA pool. I'll think a bit more about how to phrase it.

@awni (Member, Author) commented:

The other thing that is a bit of a mess, in our allocator especially, is how we deal with multi-device on a discrete setup where each device has its own memory.

I think at some point it might make sense to have a separate buffer cache for each device and one for the managed allocator.

@zcbenz (Collaborator) commented:

That makes sense!

@awni awni merged commit a4b3bc9 into ml-explore:main on Dec 7, 2025
12 checks passed
@awni awni deleted the improve_cuda_allocator branch on December 9, 2025 14:19
