[CUDA] cuDNN forward attention #2743
Conversation
Did you time prefill at all? I'd imagine the graph overhead there would not matter as much? For decode, shouldn't we still route to the custom …
Ah, I totally forgot about prefill. I have updated this PR to use cuDNN for prefill only. For a long prompt it shows a 30% improvement on A100:

```
# mlx_lm.benchmark --model meta-llama/Meta-Llama-3.1-8B --p 2049 -g 128 -b 1 -n 4
Running warmup..
Timing with prompt_tokens=2049, generation_tokens=128, batch_size=1.
Trial 1: prompt_tps=13290.771, generation_tps=92.388, peak_memory=20.444
Trial 2: prompt_tps=13329.313, generation_tps=92.252, peak_memory=20.444
Trial 3: prompt_tps=13370.043, generation_tps=92.463, peak_memory=20.444
Trial 4: prompt_tps=13303.649, generation_tps=92.389, peak_memory=20.444
Averages: prompt_tps=13323.444, generation_tps=92.373, peak_memory=20.444
```
```
# MLX_CUDA_USE_CUDNN_SPDA=0 mlx_lm.benchmark --model meta-llama/Meta-Llama-3.1-8B --p 2049 -g 128 -b 1 -n 4
Running warmup..
Timing with prompt_tokens=2049, generation_tokens=128, batch_size=1.
Trial 1: prompt_tps=10531.077, generation_tps=92.438, peak_memory=21.182
Trial 2: prompt_tps=10507.410, generation_tps=92.017, peak_memory=21.706
Trial 3: prompt_tps=10509.630, generation_tps=92.432, peak_memory=21.707
Trial 4: prompt_tps=10517.068, generation_tps=92.459, peak_memory=21.707
Averages: prompt_tps=10516.296, generation_tps=92.336, peak_memory=21.576
```

Even for a very short prompt there is still a performance gain:

```
# mlx_lm.benchmark --model meta-llama/Meta-Llama-3.1-8B --p 8 -g 128 -b 1 -n 4
Running warmup..
Timing with prompt_tokens=8, generation_tokens=128, batch_size=1.
Trial 1: prompt_tps=236.468, generation_tps=99.045, peak_memory=16.122
Trial 2: prompt_tps=237.316, generation_tps=99.044, peak_memory=16.122
Trial 3: prompt_tps=236.706, generation_tps=99.064, peak_memory=16.122
Trial 4: prompt_tps=237.126, generation_tps=99.064, peak_memory=16.123
Averages: prompt_tps=236.904, generation_tps=99.054, peak_memory=16.122
```
```
# MLX_CUDA_USE_CUDNN_SPDA=0 mlx_lm.benchmark --model meta-llama/Meta-Llama-3.1-8B --p 8 -g 128 -b 1 -n 4
Running warmup..
Timing with prompt_tokens=8, generation_tokens=128, batch_size=1.
Trial 1: prompt_tps=226.593, generation_tps=99.118, peak_memory=16.113
Trial 2: prompt_tps=227.747, generation_tps=99.123, peak_memory=16.114
Trial 3: prompt_tps=227.008, generation_tps=99.126, peak_memory=16.114
Trial 4: prompt_tps=227.179, generation_tps=98.821, peak_memory=16.114
Averages: prompt_tps=227.132, generation_tps=99.047, peak_memory=16.114
```

However, this benchmark warms things up before timing, so the cost of building the cuDNN graph is hidden. If I remove the warmup the results are more mixed:

```
# mlx_lm.benchmark --model meta-llama/Meta-Llama-3.1-8B --p 2049 -g 128 -b 1 -n 4
Timing with prompt_tokens=2049, generation_tokens=128, batch_size=1.
Trial 1: prompt_tps=2890.866, generation_tps=92.512, peak_memory=17.587
Trial 2: prompt_tps=13196.311, generation_tps=92.495, peak_memory=20.444
Trial 3: prompt_tps=13182.864, generation_tps=92.404, peak_memory=20.444
Trial 4: prompt_tps=13169.970, generation_tps=92.475, peak_memory=20.444
Averages: prompt_tps=10610.003, generation_tps=92.472, peak_memory=19.730
```
```
# MLX_CUDA_USE_CUDNN_SPDA=0 mlx_lm.benchmark --model meta-llama/Meta-Llama-3.1-8B --p 2049 -g 128 -b 1 -n 4
Timing with prompt_tokens=2049, generation_tokens=128, batch_size=1.
Trial 1: prompt_tps=3321.275, generation_tps=92.462, peak_memory=18.166
Trial 2: prompt_tps=10478.404, generation_tps=92.432, peak_memory=21.706
Trial 3: prompt_tps=10484.616, generation_tps=92.473, peak_memory=21.707
Trial 4: prompt_tps=10481.885, generation_tps=92.422, peak_memory=21.707
Averages: prompt_tps=8691.545, generation_tps=92.447, peak_memory=20.822
```

Using cuDNN makes the time to first output longer in exchange for faster subsequent prefills. But for common use cases I think users only need to prefill once?
For irregular prefill sizes this gives a massive improvement on B200:

```
mlx_lm.benchmark --model meta-llama/Meta-Llama-3.1-8B --p 2048 -g 128 -b 1 -n 4
```

Pre: `Averages: prompt_tps=27617.508, generation_tps=230.929, peak_memory=20.798`
```cpp
// Only use cuDNN for prefilling.
if (q.shape(2) != k.shape(2)) {
  return false;
}
```
That's not a great condition, because when we prefill in steps we can have cases where q.shape != k.shape. Currently mlx-lm prefills in steps of 2048, so for a prompt of 4096, on the second step the k length would be 4096 while the q length would be only 2048.
I need to make cuDNN work with dynamic sequence lengths to be able to remove this condition. But I'm still not sure whether that is feasible with cuDNN: the API pytorch/pytorch#155958 uses was designed for ragged tensors, so it might not deliver the best performance.
None of the popular inference engines seem to use cuDNN for attention, so I think we might have to add another backend for inference. I'll investigate after making training work.
Thanks for the reviews, I have updated the code.
```cpp
// TODO: Do contiguous copy for inputs.
if (q.strides(-1) != 1 || k.strides(-1) != 1 || v.strides(-1) != 1) {
  return false;
}
```
I would not check the strides here, as they are only set after the graph is evaluated (not when it is built). This check will basically always pass, even if the inputs turn out not to be contiguous in the last dimension during the eval.
As a corollary, I would implement the contiguous copies in the eval_gpu if they are needed. Otherwise this will be a bug when you get unsupported strides.
Sorry, I forgot strides are set during evaluation; I have updated the code.
This is actually not the first time I have been bitten by this: before evaluation the strides are not initialized and could contain anything. Do you think it makes sense to add a debug-time assertion in strides()?
> do you think if it makes sense adding a debug time assertion in strides()
Yea that seems reasonable to me.
This PR adds a cuDNN implementation of the `fast.scaled_dot_product_attention` op, which is currently disabled behind the `MLX_CUDA_USE_CUDNN_SPDA` env. It is not very useful for inference at the moment because:
I'll create follow-up PRs to:
`TORCH_CUDNN_SDPA_AVOID_RECOMPILE=1` (pytorch/pytorch#155958);