
[CUDA] Reduce use of managed memory#2725

Merged
awni merged 10 commits into main from async_cuda_malloc
Nov 6, 2025

Conversation

Member

@awni awni commented Nov 1, 2025

This is an attempt to reduce the use of managed memory in our CUDA backend. There are still a few details to figure out and ideas to explore, but for the common use cases of LLM training and inference it works quite well.

Some key differences:

  • malloc is still the same as before: it gives you managed memory and caches it in the buffer cache.
  • cu::malloc_async gives you async-allocated, non-managed memory and caches it in the buffer cache.
  • eval_gpu should always prefer cu::malloc_async(size, stream); to get the faster non-managed memory.
  • Pointers passed to kernels should be obtained with gpu_ptr<T>(array).
  • Accessing arr.data<T>() can trigger a copy if the data is not in the right place, so don't access it unless you need the data to be accessible from the CPU.
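The allocation split above can be illustrated with a minimal mock (plain C++, no CUDA; MemKind, Buffer, and the *_mock functions are hypothetical stand-ins for the real allocator, not MLX code):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Stand-in for the two memory kinds: Managed mimics cudaMallocManaged
// memory (CPU-accessible), DeviceAsync mimics stream-ordered device memory.
enum class MemKind { Managed, DeviceAsync };

struct Buffer {
  MemKind kind;
  std::vector<char> bytes;  // stand-in for the actual allocation
};

// Plain malloc: same as before, hands back managed memory.
Buffer malloc_mock(std::size_t size) {
  return {MemKind::Managed, std::vector<char>(size)};
}

// malloc_async: non-managed, device-only memory; preferred in eval_gpu.
Buffer malloc_async_mock(std::size_t size) {
  return {MemKind::DeviceAsync, std::vector<char>(size)};
}

// Mirrors the arr.data<T>() caveat: a CPU access of device-only memory
// forces a (potentially slow) migration to managed memory first.
char* cpu_data(Buffer& b) {
  if (b.kind == MemKind::DeviceAsync) {
    b.kind = MemKind::Managed;  // simulate copy-to-managed on CPU access
  }
  return b.bytes.data();
}
```

The point of the mock is the asymmetry: kernels see device memory for free, while CPU access pays a migration cost.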

Some improvements to think about:

  • All allocations still go in the buffer cache. When doing malloc_async or malloc we first check the buffer cache. It might be good to prefer the appropriate type of memory (managed or not) when pulling from the cache.

  • When doing the copy to managed memory, we rely on the fact that it does a device synchronize. This can be slow. So it may be good to make it async when possible.
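One way the first improvement could look, sketched with hypothetical stand-in types (a multimap keyed on size and memory kind, not MLX's actual buffer cache):

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <utility>

// Key the cache on both size and memory kind, so malloc_async only reuses
// async buffers and malloc only reuses managed ones.
enum class MemKind { Managed, DeviceAsync };

struct CachedBuf {
  MemKind kind;
  std::size_t size;
};

using CacheKey = std::pair<std::size_t, MemKind>;
std::multimap<CacheKey, CachedBuf> cache;

// Returns true and pops a matching buffer on a hit; a miss means the
// caller falls through to a fresh allocation of the right kind.
bool try_pop(std::size_t size, MemKind kind, CachedBuf* out) {
  auto it = cache.find({size, kind});
  if (it == cache.end()) return false;
  *out = it->second;
  cache.erase(it);
  return true;
}
```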

Benchmarks

Command:

mlx_lm.benchmark --model meta-llama/Meta-Llama-3.1-8B --p 2049 -g 128 -b 1 -n 4

DGX Spark
Pre: Averages: prompt_tps=744.025, generation_tps=10.301, peak_memory=21.707
Post: Averages: prompt_tps=2093.590, generation_tps=14.553, peak_memory=21.707

B200
Pre: Averages: prompt_tps=4780.945, generation_tps=207.900, peak_memory=21.707
Post: Averages: prompt_tps=37887.799, generation_tps=217.166, peak_memory=21.707

Note: the prompt length of 2049 is chosen to get an aligned size (2048) for prompt processing; otherwise it's quite a bit slower.

@awni awni requested review from angeloskath and zcbenz November 1, 2025 19:42
@awni awni force-pushed the async_cuda_malloc branch from 5313209 to c27a064 Compare November 1, 2025 20:19
Collaborator

zcbenz commented Nov 2, 2025

A future improvement we can do is to use cudaFreeAsync to free buffers.

@awni awni force-pushed the async_cuda_malloc branch 3 times, most recently from 74c6ceb to 6adde62 Compare November 3, 2025 23:05
Member Author

awni commented Nov 3, 2025

Ok I think this is good to go. @zcbenz, thanks for the review. I addressed your comments. The main thing was the removal of the cuda pool which simplified stuff nicely.

@awni awni force-pushed the async_cuda_malloc branch from 6adde62 to cc6df9f Compare November 3, 2025 23:07
Collaborator

@zcbenz zcbenz left a comment


Awesome improvement!

Collaborator

zcbenz commented Nov 3, 2025

It would be really nice if we could replace the MLX buffer cache with CUDA memory pools. In CUDA 13 it's possible to use a pool with managed memory, but not in CUDA 12.

What do you think about using the MLX buffer cache for managed memory and a CUDA memory pool for everything else (as future work)?

note the prompt length of 2049 is to get an aligned size (2048) for prompt processing, otherwise it's quite a bit slower.

Can you elaborate on this? I'm trying to understand why an extra token is needed to get an aligned size.

Member Author

awni commented Nov 4, 2025

What do you think if we use MLX buffer cache for managed memory and CUDA memory pool for others (as a future work)?

Worth exploring. We might also want to give the cache a similar API, to make it easy to swap in a CUDA pool even for managed memory when you are on CUDA toolkit 13 and up.

Member Author

awni commented Nov 4, 2025

Can you elaborate on this? I'm trying to understand why an extra token is needed to get an aligned size.

It's an artifact of how we process the prompt to avoid computing the LM head for all but the last token. So basically we process the prompt in two steps: the first n-1 tokens without the LM head, then the last token with the LM head.
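Concretely, with a 2049-token prompt the split yields an aligned 2048-token chunk plus one final token. A toy helper (hypothetical, not mlx-lm code) showing the arithmetic:

```cpp
#include <utility>

// Two-step prompt processing: the first n-1 tokens are processed without
// the LM head, then the final token is processed with it. Passing 2049
// tokens makes the bulk chunk an aligned 2048.
std::pair<int, int> split_prompt(int prompt_len) {
  int bulk = prompt_len - 1;  // processed without the LM head
  int last = 1;               // processed with the LM head
  return {bulk, last};
}
```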

@awni awni force-pushed the async_cuda_malloc branch from 4f8574e to fc00f16 Compare November 4, 2025 15:57

@awni awni force-pushed the async_cuda_malloc branch from 234000e to 7eaa504 Compare November 4, 2025 17:32
Member

@angeloskath angeloskath left a comment


This looks awesome!

I still think we need to add a check in CudaAllocator::malloc_impl for zero-sized buffers and return early with {nullptr, 0, -1}.

Also the corresponding check in CudaAllocator::free should be if (!buf || !buf->data). Or the latter can be moved into CudaAllocator::cuda_free; dealer's choice.
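A sketch of the suggested guards (the Buffer struct and mock functions are hypothetical stand-ins for CudaAllocator's actual types and allocation calls):

```cpp
#include <cstddef>

// Hypothetical stand-in for the allocator's buffer record.
struct Buffer {
  void* data;
  std::size_t size;
  int device;
};

// malloc_impl guard: return early for zero-sized buffers instead of
// hitting the real allocation path.
Buffer malloc_impl_mock(std::size_t size) {
  if (size == 0) {
    return {nullptr, 0, -1};
  }
  return {::operator new(size), size, 0};  // stand-in for the real alloc
}

// free guard: a null buffer, or a buffer with no data (e.g. the
// zero-sized case above), is a no-op.
void free_mock(Buffer* buf) {
  if (!buf || !buf->data) {
    return;
  }
  ::operator delete(buf->data);
  buf->data = nullptr;
}
```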

@awni awni force-pushed the async_cuda_malloc branch 8 times, most recently from c77269b to c4c8690 Compare November 5, 2025 22:12
@awni awni force-pushed the async_cuda_malloc branch from c4c8690 to b741d8b Compare November 5, 2025 23:07
@awni awni merged commit df58b41 into main Nov 6, 2025
7 checks passed
@awni awni deleted the async_cuda_malloc branch November 6, 2025 00:05
@yiakwy-xpu-ml-framework-team

Hi @awni, I want to help support this framework on H100/H800 DGX machines with standard IB and NCCL support.

How can I start?

@yiakwy-xpu-ml-framework-team

@awni Is there anybody working on VMM? Maybe that can be the first topic I contribute to:

https://github.com/ruizhang1230/vTensor
