
[CUDA] cuDNN forward attention #2743

Merged
zcbenz merged 8 commits into ml-explore:main from zcbenz:cudnn-spda
Nov 14, 2025

Conversation

@zcbenz
Collaborator

@zcbenz zcbenz commented Nov 7, 2025

This PR adds a cuDNN implementation of the fast.scaled_dot_product_attention op, currently gated behind the MLX_CUDA_USE_CUDNN_SPDA environment variable.
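An environment-variable gate like the one described above could be read as follows; this is a minimal sketch with illustrative names, not MLX's actual helper:

```cpp
#include <cstdlib>
#include <cstring>

// Hypothetical helper: reads a boolean feature flag from the environment.
// Unset means the given default; any value other than "0" means enabled.
bool env_flag_enabled(const char* name, bool default_value) {
  const char* value = std::getenv(name);
  if (value == nullptr) {
    return default_value;
  }
  return std::strcmp(value, "0") != 0;
}
```

With this shape, `MLX_CUDA_USE_CUDNN_SPDA=0` turns the path off while leaving it on by default once it is enabled.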

It is not very useful for inference at the moment because:

  1. the overhead of building the cuDNN graph is very high;
  2. the cuDNN graph is currently recreated whenever the sequence length increases.

I'll create follow-up PRs to:

  1. implement backward attention and use cuDNN attention for training;
  2. implement missing features (array masks, sinks, ...);
  3. investigate ways to avoid recreating the graph when the sequence length changes ([cuDNN][SDPA] Introduce TORCH_CUDNN_SDPA_AVOID_RECOMPILE=1 pytorch/pytorch#155958);
  4. use cuDNN for batch inference;
  5. remove duplicated code between the cuDNN conv and attention implementations; I probably have to update the conv code to use the new APIs.
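The rebuild cost in point 3 comes from the graph plan being tied to fixed sequence lengths, so a cache keyed by shape only helps when the same lengths recur. A toy sketch of that caching pattern (all names illustrative, no real cuDNN calls):

```cpp
#include <cstdint>
#include <unordered_map>

// Stand-in for a compiled cuDNN execution plan; real plans bake the
// sequence lengths into the graph, hence one entry per (q_len, k_len).
struct FakeGraph {
  int64_t q_len;
  int64_t k_len;
};

class GraphCache {
 public:
  // Returns a cached graph or builds one; builds() counts expensive rebuilds.
  const FakeGraph& get(int64_t q_len, int64_t k_len) {
    uint64_t key = (static_cast<uint64_t>(q_len) << 32) |
                   static_cast<uint64_t>(k_len);
    auto it = cache_.find(key);
    if (it == cache_.end()) {
      ++builds_;  // this is the slow path a growing sequence keeps hitting
      it = cache_.emplace(key, FakeGraph{q_len, k_len}).first;
    }
    return it->second;
  }
  int builds() const { return builds_; }

 private:
  std::unordered_map<uint64_t, FakeGraph> cache_;
  int builds_ = 0;
};
```

During decoding the key changes on every step, so every call rebuilds; that is why avoiding recompilation (point 3) matters.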

@zcbenz zcbenz requested review from awni and jagrit06 November 7, 2025 22:34
@awni
Member

awni commented Nov 8, 2025

Did you time prefill at all? I'd imagine the graph overhead there would not matter as much?

For decode shouldn't we still route to the custom sdpa_vector implementation? I imagine that will be faster?

@zcbenz
Collaborator Author

zcbenz commented Nov 8, 2025

Ah, I totally forgot about prefill. I have updated this PR to use sdpa_cudnn for prefill and sdpa_vector for decoding. It is enabled by default but can be turned off with MLX_CUDA_USE_CUDNN_SPDA=0.
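The routing described above can be sketched as follows; the enum and function names are hypothetical, not MLX's actual API:

```cpp
#include <cstdint>

// Illustrative dispatch rule: multi-token steps (prompt processing) go to
// the cuDNN path, single-token steps (generation) stay on sdpa_vector.
enum class SdpaBackend { CudnnPrefill, VectorDecode };

SdpaBackend choose_sdpa_backend(int64_t q_len, bool cudnn_enabled) {
  if (cudnn_enabled && q_len > 1) {
    return SdpaBackend::CudnnPrefill;  // prefill: many queries at once
  }
  return SdpaBackend::VectorDecode;    // decode: one query per step
}
```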

For a long prompt it shows a 30% improvement on an A100:

# mlx_lm.benchmark --model meta-llama/Meta-Llama-3.1-8B --p 2049 -g 128 -b 1 -n 4
Running warmup..
Timing with prompt_tokens=2049, generation_tokens=128, batch_size=1.
Trial 1:  prompt_tps=13290.771, generation_tps=92.388, peak_memory=20.444
Trial 2:  prompt_tps=13329.313, generation_tps=92.252, peak_memory=20.444
Trial 3:  prompt_tps=13370.043, generation_tps=92.463, peak_memory=20.444
Trial 4:  prompt_tps=13303.649, generation_tps=92.389, peak_memory=20.444
Averages: prompt_tps=13323.444, generation_tps=92.373, peak_memory=20.444
# MLX_CUDA_USE_CUDNN_SPDA=0 mlx_lm.benchmark --model meta-llama/Meta-Llama-3.1-8B --p 2049 -g 128 -b 1 -n 4
Running warmup..
Timing with prompt_tokens=2049, generation_tokens=128, batch_size=1.
Trial 1:  prompt_tps=10531.077, generation_tps=92.438, peak_memory=21.182
Trial 2:  prompt_tps=10507.410, generation_tps=92.017, peak_memory=21.706
Trial 3:  prompt_tps=10509.630, generation_tps=92.432, peak_memory=21.707
Trial 4:  prompt_tps=10517.068, generation_tps=92.459, peak_memory=21.707
Averages: prompt_tps=10516.296, generation_tps=92.336, peak_memory=21.576

Even for a very short prompt there is still a performance gain:

# mlx_lm.benchmark --model meta-llama/Meta-Llama-3.1-8B --p 8 -g 128 -b 1 -n 4
Running warmup..
Timing with prompt_tokens=8, generation_tokens=128, batch_size=1.
Trial 1:  prompt_tps=236.468, generation_tps=99.045, peak_memory=16.122
Trial 2:  prompt_tps=237.316, generation_tps=99.044, peak_memory=16.122
Trial 3:  prompt_tps=236.706, generation_tps=99.064, peak_memory=16.122
Trial 4:  prompt_tps=237.126, generation_tps=99.064, peak_memory=16.123
Averages: prompt_tps=236.904, generation_tps=99.054, peak_memory=16.122
# MLX_CUDA_USE_CUDNN_SPDA=0 mlx_lm.benchmark --model meta-llama/Meta-Llama-3.1-8B --p 8 -g 128 -b 1 -n 4
Running warmup..
Timing with prompt_tokens=8, generation_tokens=128, batch_size=1.
Trial 1:  prompt_tps=226.593, generation_tps=99.118, peak_memory=16.113
Trial 2:  prompt_tps=227.747, generation_tps=99.123, peak_memory=16.114
Trial 3:  prompt_tps=227.008, generation_tps=99.126, peak_memory=16.114
Trial 4:  prompt_tps=227.179, generation_tps=98.821, peak_memory=16.114
Averages: prompt_tps=227.132, generation_tps=99.047, peak_memory=16.114

However, this benchmark warms things up before timing, so the cost of creating the cuDNN graph is hidden. If I remove the warmup the results are more nuanced:

# mlx_lm.benchmark --model meta-llama/Meta-Llama-3.1-8B --p 2049 -g 128 -b 1 -n 4
Timing with prompt_tokens=2049, generation_tokens=128, batch_size=1.
Trial 1:  prompt_tps=2890.866, generation_tps=92.512, peak_memory=17.587
Trial 2:  prompt_tps=13196.311, generation_tps=92.495, peak_memory=20.444
Trial 3:  prompt_tps=13182.864, generation_tps=92.404, peak_memory=20.444
Trial 4:  prompt_tps=13169.970, generation_tps=92.475, peak_memory=20.444
Averages: prompt_tps=10610.003, generation_tps=92.472, peak_memory=19.730
# MLX_CUDA_USE_CUDNN_SPDA=0 mlx_lm.benchmark --model meta-llama/Meta-Llama-3.1-8B --p 2049 -g 128 -b 1 -n 4
Timing with prompt_tokens=2049, generation_tokens=128, batch_size=1.
Trial 1:  prompt_tps=3321.275, generation_tps=92.462, peak_memory=18.166
Trial 2:  prompt_tps=10478.404, generation_tps=92.432, peak_memory=21.706
Trial 3:  prompt_tps=10484.616, generation_tps=92.473, peak_memory=21.707
Trial 4:  prompt_tps=10481.885, generation_tps=92.422, peak_memory=21.707
Averages: prompt_tps=8691.545, generation_tps=92.447, peak_memory=20.822

Using cuDNN makes the time to first output longer in exchange for faster subsequent prefills. But for common use cases, I think prefill only happens once?

@awni
Member

awni commented Nov 11, 2025

For irregular prefill sizes this gives a massive improvement on a B200:

mlx_lm.benchmark --model meta-llama/Meta-Llama-3.1-8B --p 2048 -g 128 -b 1 -n 4

Pre: Averages: prompt_tps=27617.508, generation_tps=230.929, peak_memory=20.798
Post: Averages: prompt_tps=58095.697, generation_tps=231.540, peak_memory=19.748

Comment on lines +159 to +162
// Only use cuDNN for prefilling.
if (q.shape(2) != k.shape(2)) {
return false;
}
Member

That's not a great condition because when we prefill in steps we can have cases where q.shape != k.shape. Currently in mlx-lm we prefill in steps of 2048. So for a prompt of 4096, on the second step the k length would be 4096 and the q length would be 2048.
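The stepped-prefill lengths described above can be sketched with a toy helper (illustrative names, assuming an empty KV cache at the start of the prompt):

```cpp
#include <algorithm>
#include <cstdint>
#include <utility>
#include <vector>

// Per-step (q_len, k_len) pairs for chunked prefill. The key length includes
// the already-cached tokens, so q_len == k_len only holds on the first step.
std::vector<std::pair<int64_t, int64_t>> prefill_steps(int64_t prompt_len,
                                                       int64_t step) {
  std::vector<std::pair<int64_t, int64_t>> out;
  for (int64_t done = 0; done < prompt_len; done += step) {
    int64_t q_len = std::min(step, prompt_len - done);
    int64_t k_len = done + q_len;  // cached tokens plus this chunk
    out.push_back({q_len, k_len});
  }
  return out;
}
```

For a 4096-token prompt with step 2048, the second step sees q_len = 2048 against k_len = 4096, which the `q.shape(2) != k.shape(2)` check would wrongly reject.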

Collaborator Author

I need to make cuDNN work with dynamic sequence lengths to be able to remove this condition. But I'm still not sure if that is feasible with cuDNN; the API that pytorch/pytorch#155958 uses was designed for ragged tensors, so it might not deliver the best performance.

None of the popular inference engines seem to be using cuDNN for attention, so I think we might have to add another backend for inference. I'll investigate after making training work.

@zcbenz
Collaborator Author

zcbenz commented Nov 12, 2025

Thanks for the reviews, I have updated the code.

Comment on lines +170 to +173
// TODO: Do contiguous copy for inputs.
if (q.strides(-1) != 1 || k.strides(-1) != 1 || v.strides(-1) != 1) {
return false;
}
Member

I would not check the strides here as they are set after the graph is evaluated (not when it is built). This condition will basically always return true, even if the inputs might not be contiguous in the last dimension during the eval.

Member

As a corollary I would implement the contiguous copies in the eval_gpu if they are needed. Otherwise it will be a bug when you get unsupported strides.

Collaborator Author

Sorry, I forgot strides are set during evaluation. I have updated the code.

This is actually not the first time I've been bitten by this; before evaluation the strides are not initialized and could be anything. Do you think it makes sense to add a debug-time assertion in strides()?

Member

do you think it makes sense to add a debug-time assertion in strides()

Yea that seems reasonable to me.
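Such a debug-time check might look like the following toy sketch; the class and its layout are purely illustrative, not MLX's actual array type:

```cpp
#include <stdexcept>
#include <utility>
#include <vector>

// Toy array whose strides() accessor refuses to answer before evaluation
// has populated the strides, instead of silently returning garbage.
class ToyArray {
 public:
  // In the real backend this would happen when the graph is evaluated.
  void set_strides(std::vector<long> s) {
    strides_ = std::move(s);
    evaluated_ = true;
  }
  const std::vector<long>& strides() const {
    if (!evaluated_) {
      // A debug build could use an assert here instead of throwing.
      throw std::logic_error("strides() called before evaluation");
    }
    return strides_;
  }

 private:
  std::vector<long> strides_;
  bool evaluated_ = false;
};
```

Failing loudly here would have caught the premature `q.strides(-1)` check above at the first test run.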

Member

@awni awni left a comment


LGTM!!

@zcbenz zcbenz merged commit 3b2ffce into ml-explore:main Nov 14, 2025
8 checks passed
@zcbenz zcbenz deleted the cudnn-spda branch November 14, 2025 00:23