From 08be80b48a7193b817dda47cd499c9fabdd429b9 Mon Sep 17 00:00:00 2001 From: Cheng Date: Thu, 6 Nov 2025 02:49:53 -0800 Subject: [PATCH 1/8] Separate sdpa kernels in another file --- mlx/backend/cuda/CMakeLists.txt | 1 + .../cuda/scaled_dot_product_attention.cpp | 52 +++++++++++++++++ .../cuda/scaled_dot_product_attention.cu | 56 ++++++------------- .../cuda/scaled_dot_product_attention.h | 29 ++++++++++ 4 files changed, 99 insertions(+), 39 deletions(-) create mode 100644 mlx/backend/cuda/scaled_dot_product_attention.cpp create mode 100644 mlx/backend/cuda/scaled_dot_product_attention.h diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index 543e9fd589..892d46c510 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -44,6 +44,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu ${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu ${CMAKE_CURRENT_SOURCE_DIR}/rope.cu + ${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cpp ${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cu ${CMAKE_CURRENT_SOURCE_DIR}/scan.cu ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp diff --git a/mlx/backend/cuda/scaled_dot_product_attention.cpp b/mlx/backend/cuda/scaled_dot_product_attention.cpp new file mode 100644 index 0000000000..8802082b23 --- /dev/null +++ b/mlx/backend/cuda/scaled_dot_product_attention.cpp @@ -0,0 +1,52 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/scaled_dot_product_attention.h" +#include "mlx/fast_primitives.h" +#include "mlx/transforms_impl.h" + +#include + +namespace mlx::core { + +namespace fast { + +bool ScaledDotProductAttention::use_fallback( + const array& q, + const array& k, + const array& v, + bool has_mask, + bool has_arr_mask, + bool do_causal, + Stream s) { + if (detail::in_grad_tracing()) { + return true; + } + if (s.device == Device::cpu) { + return true; + } + + return !supports_sdpa_vector(q, k, v, has_mask, has_arr_mask, do_causal); +} + +void ScaledDotProductAttention::eval_gpu( + const std::vector& inputs, + array& out) { + nvtx3::scoped_range r("ScaledDotProductAttention::eval_gpu"); + + auto& s = stream(); + + assert(inputs.size() == 3 || inputs.size() == 4); + const auto& q = inputs[0]; + const auto& k = inputs[1]; + const auto& v = inputs[2]; + + if (has_sinks_) { + sdpa_vector(q, k, v, scale_, out, do_causal_, inputs.back(), s); + } else { + sdpa_vector(q, k, v, scale_, out, do_causal_, std::nullopt, s); + } +} + +} // namespace fast + +} // namespace mlx::core diff --git a/mlx/backend/cuda/scaled_dot_product_attention.cu b/mlx/backend/cuda/scaled_dot_product_attention.cu index 151fca0414..fefc6d9ee0 100644 --- a/mlx/backend/cuda/scaled_dot_product_attention.cu +++ b/mlx/backend/cuda/scaled_dot_product_attention.cu @@ -1,15 +1,11 @@ // Copyright © 2025 Apple Inc. -#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device/config.h" #include "mlx/backend/cuda/device/utils.cuh" #include "mlx/backend/cuda/kernel_utils.cuh" +#include "mlx/backend/cuda/scaled_dot_product_attention.h" #include "mlx/backend/gpu/copy.h" #include "mlx/dtype_utils.h" -#include "mlx/fast_primitives.h" -#include "mlx/transforms_impl.h" - -#include #include #include @@ -663,23 +659,13 @@ void sdpa_vector_fallback( } // namespace -namespace fast { - -bool ScaledDotProductAttention::use_fallback( +bool supports_sdpa_vector( const array& q, const array& k, const array& v, bool has_mask, bool has_arr_mask, - bool do_causal, - Stream s) { - if (detail::in_grad_tracing()) { - return true; - } - if (s.device == Device::cpu) { - return true; - } - + bool do_causal) { const int value_head_dim = v.shape(-1); const int query_head_dim = q.shape(-1); const int query_sequence_length = q.shape(2); @@ -691,29 +677,24 @@ bool ScaledDotProductAttention::use_fallback( const bool supported_vector_config = sdpa_supported_head_dim && query_sequence_length < 4; - const bool supported_config = supported_vector_config; - - return has_arr_mask || !supported_config; + return supported_vector_config && !has_arr_mask; } -void ScaledDotProductAttention::eval_gpu( - const std::vector& inputs, - array& out) { - nvtx3::scoped_range r("ScaledDotProductAttention::eval_gpu"); - - auto& s = stream(); +void sdpa_vector( + const array& q_pre, + const array& k_pre, + const array& v_pre, + float scale, + array& o, + bool do_causal, + const std::optional& sinks_pre, + Stream s) { auto& encoder = cu::get_command_encoder(s); - - auto& q_pre = inputs[0]; - auto& k_pre = inputs[1]; - auto& v_pre = inputs[2]; - auto& o = out; - std::vector copies; // Define some copy functions to ensure the layout of the inputs is as // expected. - copies.reserve(inputs.size()); + copies.reserve(4); auto copy_unless = [&copies, &s]( auto predicate, const array& arr) -> const array& { if (!predicate(arr)) { @@ -731,8 +712,8 @@ void ScaledDotProductAttention::eval_gpu( }; std::optional sinks = std::nullopt; - if (has_sinks_) { - sinks = copy_unless(is_matrix_contiguous, inputs.back()); + if (sinks_pre) { + sinks = copy_unless(is_matrix_contiguous, sinks_pre.value()); } // We are in vector mode ie single query @@ -798,8 +779,7 @@ void ScaledDotProductAttention::eval_gpu( encoder.add_temporary(cp); } - return sdpa_vector_fallback( - s, encoder, q, k, v, scale_, o, do_causal_, sinks); + sdpa_vector_fallback(s, encoder, q, k, v, scale, o, do_causal, sinks); } // Full attention mode should never reach here @@ -808,6 +788,4 @@ void ScaledDotProductAttention::eval_gpu( } } -} // namespace fast - } // namespace mlx::core diff --git a/mlx/backend/cuda/scaled_dot_product_attention.h b/mlx/backend/cuda/scaled_dot_product_attention.h new file mode 100644 index 0000000000..b47959a077 --- /dev/null +++ b/mlx/backend/cuda/scaled_dot_product_attention.h @@ -0,0 +1,29 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/backend/cuda/device.h" + +#include + +namespace mlx::core { + +bool supports_sdpa_vector( + const array& q, + const array& k, + const array& v, + bool has_mask, + bool has_arr_mask, + bool do_causal); + +void sdpa_vector( + const array& q, + const array& k, + const array& v, + float scale, + array& o, + bool do_causal, + const std::optional& sinks, + Stream s); + +} // namespace mlx::core From 709501bdd03f3b48a582f2283847328c8ed7d2a8 Mon Sep 17 00:00:00 2001 From: Cheng Date: Fri, 7 Nov 2025 00:32:42 -0800 Subject: [PATCH 2/8] Initial support for cuDNN SDPA --- mlx/backend/cuda/cudnn_utils.h | 10 +- .../cuda/scaled_dot_product_attention.cpp | 208 +++++++++++++++++- python/tests/test_fast_sdpa.py | 9 +- 3 files changed, 219 insertions(+), 8 deletions(-) diff --git a/mlx/backend/cuda/cudnn_utils.h b/mlx/backend/cuda/cudnn_utils.h index b282496782..537e17c44c 100644 --- a/mlx/backend/cuda/cudnn_utils.h +++ b/mlx/backend/cuda/cudnn_utils.h @@ -44,13 +44,13 @@ inline SmallVector convert_vector(const Vec& vec) { // There are 2 differences from the const_param util from kernel_utils.cuh: // 1. The rest of array is filled with 0. // 2. This util can be used in .cpp files. -template class Vec> -inline std::array vector_key(const Vec& vec) { - if (vec.size() > MAX_NDIM) { +template class Vec> +inline std::array vector_key(const Vec& vec) { + if (vec.size() > NDIM) { throw std::runtime_error( - fmt::format("ndim can not be larger than {}.", MAX_NDIM)); + fmt::format("ndim can not be larger than {}.", NDIM)); } - std::array result = {}; + std::array result = {}; std::copy_n(vec.begin(), vec.size(), result.begin()); return result; } diff --git a/mlx/backend/cuda/scaled_dot_product_attention.cpp b/mlx/backend/cuda/scaled_dot_product_attention.cpp index 8802082b23..5662e6a4d2 100644 --- a/mlx/backend/cuda/scaled_dot_product_attention.cpp +++ b/mlx/backend/cuda/scaled_dot_product_attention.cpp @@ -1,6 +1,8 @@ // Copyright © 2025 Apple Inc. #include "mlx/backend/cuda/scaled_dot_product_attention.h" +#include "mlx/backend/cuda/cudnn_utils.h" +#include "mlx/backend/cuda/lru_cache.h" #include "mlx/fast_primitives.h" #include "mlx/transforms_impl.h" @@ -8,6 +10,205 @@ namespace mlx::core { +namespace fe = cudnn_frontend; + +namespace { + +#define CHECK_CUDNN_FE_ERROR(cmd) \ + do { \ + auto error = cmd; \ + if (!error.is_good()) { \ + throw std::runtime_error( \ + fmt::format("{} failed: {}.", #cmd, error.get_message())); \ + } \ + } while (0) + +std::vector normalized_strides(const array& x) { + std::vector strides(x.strides().begin(), x.strides().end()); + if (!x.flags().row_contiguous || x.ndim() < 2) { + return strides; + } + for (int i = x.ndim() - 2; i >= 0; --i) { + if (x.shape(i) == 1) { + strides[i] = x.shape(i + 1) * strides[i + 1]; + } + } + return strides; +} + +void set_tensor_attrs( + std::shared_ptr& tensor, + int64_t uid, + const array& x) { + tensor->set_uid(uid) + .set_dim({x.shape().begin(), x.shape().end()}) + .set_stride(normalized_strides(x)); +} + +constexpr int QKV_NDIM = 4; + +struct SDPACacheKey { + int device_id; + cudnnDataType_t cudnn_dtype; + std::array q_shape; + std::array k_shape; + std::array v_shape; + std::array q_strides; + std::array k_strides; + std::array v_strides; + bool do_causal; +}; + +auto& sdpa_cache() { + static LRUBytesKeyCache cache( + "MLX_CUDA_SDPA_CACHE_SIZE", /* default_capacity */ 128); + return cache; +} + +enum UIDS { + Q, + K, + V, + SCALE, + O, +}; + +fe::graph::Graph build_sdpa_graph( + cudnnHandle_t handle, + const array& q, + const array& k, + const array& v, + bool do_causal, + const array& o) { + auto dtype = fe::DataType_t::HALF; + if (q.dtype() == bfloat16) { + dtype = fe::DataType_t::BFLOAT16; + } + + fe::graph::Graph graph; + graph.set_io_data_type(dtype) + .set_intermediate_data_type(fe::DataType_t::FLOAT) + .set_compute_data_type(fe::DataType_t::FLOAT); + + auto q_ = graph.tensor(fe::graph::Tensor_attributes().set_name("Q")); + auto k_ = graph.tensor(fe::graph::Tensor_attributes().set_name("K")); + auto v_ = graph.tensor(fe::graph::Tensor_attributes().set_name("V")); + set_tensor_attrs(q_, Q, q); + set_tensor_attrs(k_, K, k); + set_tensor_attrs(v_, V, v); + + auto scale = graph.tensor(fe::graph::Tensor_attributes() + .set_name("Scale") + .set_uid(SCALE) + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_is_pass_by_value(true) + .set_data_type(fe::DataType_t::FLOAT)); + + auto sdpa_options = fe::graph::SDPA_attributes() + .set_name("sdpa_cudnn") + .set_attn_scale(scale) + .set_causal_mask(do_causal) + .set_generate_stats(false); + + auto [o_, _] = graph.sdpa(q_, k_, v_, sdpa_options); + o_->set_output(true); + set_tensor_attrs(o_, O, o); + + CHECK_CUDNN_FE_ERROR(graph.validate()); + CHECK_CUDNN_FE_ERROR(graph.build_operation_graph(handle)); + CHECK_CUDNN_FE_ERROR(graph.create_execution_plans({fe::HeurMode_t::A})); + graph.select_behavior_notes( + {fe::BehaviorNote_t::SUPPORTS_CUDA_GRAPH_NATIVE_API}); + CHECK_CUDNN_FE_ERROR(graph.check_support(handle)); + CHECK_CUDNN_FE_ERROR(graph.build_plans(handle)); + + return graph; +} + +} // namespace + +bool supports_sdpa_cudnn( + const array& q, + const array& k, + const array& v, + bool has_mask, + bool has_arr_mask, + bool do_causal) { + if (has_mask && !do_causal) { // causal mask only + return false; + } + Dtype dtype = q.dtype(); + return dtype == float16 || dtype == bfloat16; +} + +void sdpa_cudnn( + const array& q, + const array& k, + const array& v, + float scale, + array& o, + bool do_causal, + const std::optional& sinks, + Stream s) { + auto& encoder = cu::get_command_encoder(s); + // TODO: Handle donation. + o.set_data(cu::malloc_async(o.nbytes(), encoder.stream())); + + encoder.set_input_array(q); + encoder.set_input_array(k); + encoder.set_input_array(v); + encoder.set_output_array(o); + + auto handle = encoder.device().cudnn_handle(); + cudnnSetStream(handle, encoder.stream()); + + // Search cache. + SDPACacheKey cache_key{ + encoder.device().cuda_device(), + dtype_to_cudnn_type(q.dtype()), + vector_key(q.shape()), + vector_key(k.shape()), + vector_key(v.shape()), + vector_key(q.strides()), + vector_key(k.strides()), + vector_key(v.strides()), + do_causal, + }; + auto it = sdpa_cache().find(cache_key); + if (it == sdpa_cache().end()) { + it = + sdpa_cache() + .emplace(cache_key, build_sdpa_graph(handle, q, k, v, do_causal, o)) + .first; + } + auto& graph = it->second; + + std::unordered_map variant_pack{ + {Q, const_cast(gpu_ptr(q))}, + {K, const_cast(gpu_ptr(k))}, + {V, const_cast(gpu_ptr(v))}, + {SCALE, &scale}, + {O, gpu_ptr(o)}}; + + int64_t workspace_size = 0; + CHECK_CUDNN_FE_ERROR(graph.get_workspace_size(workspace_size)); + void* workspace_ptr = nullptr; + if (workspace_size > 0) { + array workspace( + cu::malloc_async(workspace_size, encoder.stream()), + {static_cast(workspace_size)}, + uint8); + encoder.add_temporary(workspace); + workspace_ptr = gpu_ptr(workspace); + } + + CudaGraph cuda_graph(encoder.device()); + CHECK_CUDNN_FE_ERROR(graph.populate_cuda_graph( + handle, variant_pack, workspace_ptr, cuda_graph)); + encoder.add_graph_node(cuda_graph); +} + namespace fast { bool ScaledDotProductAttention::use_fallback( @@ -25,7 +226,8 @@ bool ScaledDotProductAttention::use_fallback( return true; } - return !supports_sdpa_vector(q, k, v, has_mask, has_arr_mask, do_causal); + return !supports_sdpa_vector(q, k, v, has_mask, has_arr_mask, do_causal) && + !supports_sdpa_cudnn(q, k, v, has_mask, has_arr_mask, do_causal); } void ScaledDotProductAttention::eval_gpu( @@ -40,11 +242,15 @@ void ScaledDotProductAttention::eval_gpu( const auto& k = inputs[1]; const auto& v = inputs[2]; + sdpa_cudnn(q, k, v, scale_, out, do_causal_, std::nullopt, s); + +#if 0 if (has_sinks_) { sdpa_vector(q, k, v, scale_, out, do_causal_, inputs.back(), s); } else { sdpa_vector(q, k, v, scale_, out, do_causal_, std::nullopt, s); } +#endif } } // namespace fast diff --git a/python/tests/test_fast_sdpa.py b/python/tests/test_fast_sdpa.py index 19af012c6a..44828fded0 100644 --- a/python/tests/test_fast_sdpa.py +++ b/python/tests/test_fast_sdpa.py @@ -551,10 +551,15 @@ def test_sdpa_vector_batched(self): class TestSDPA(mlx_tests.MLXTestCase): @property def dtypes(self): - return ["float32", "float16"] if mx.metal.is_available() else ["float32"] + if mx.metal.is_available(): + return ["float32", "float16"] + elif mx.cuda.is_available(): + return ["float16"] + else: + return ["float32"] def test_sdpa(self): - if not mx.metal.is_available(): + if not mx.metal.is_available() and not mx.cuda.is_available(): return # fmt: off From 88bb16d8b4767ef8fcf8307fa3f418b8dfefa074 Mon Sep 17 00:00:00 2001 From: Cheng Date: Fri, 7 Nov 2025 02:57:38 -0800 Subject: [PATCH 3/8] Diable a few corner cases --- .../cuda/scaled_dot_product_attention.cpp | 35 ++++++++++++------- 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/mlx/backend/cuda/scaled_dot_product_attention.cpp b/mlx/backend/cuda/scaled_dot_product_attention.cpp index 5662e6a4d2..3d91a365c6 100644 --- a/mlx/backend/cuda/scaled_dot_product_attention.cpp +++ b/mlx/backend/cuda/scaled_dot_product_attention.cpp @@ -133,11 +133,23 @@ bool supports_sdpa_cudnn( const array& k, const array& v, bool has_mask, - bool has_arr_mask, bool do_causal) { - if (has_mask && !do_causal) { // causal mask only + static bool enabled = env::get_var("MLX_CUDA_USE_CUDNN_SPDA", 0); + if (!enabled) { return false; } + + if (has_mask) { + // TODO: Support array masks. + if (!do_causal) { + return false; + } + // TODO: Fix causal mask when L_Q != L_K. + if (q.shape(2) != k.shape(2)) { + return false; + } + } + Dtype dtype = q.dtype(); return dtype == float16 || dtype == bfloat16; } @@ -149,7 +161,6 @@ void sdpa_cudnn( float scale, array& o, bool do_causal, - const std::optional& sinks, Stream s) { auto& encoder = cu::get_command_encoder(s); // TODO: Handle donation. @@ -227,7 +238,7 @@ bool ScaledDotProductAttention::use_fallback( } return !supports_sdpa_vector(q, k, v, has_mask, has_arr_mask, do_causal) && - !supports_sdpa_cudnn(q, k, v, has_mask, has_arr_mask, do_causal); + !supports_sdpa_cudnn(q, k, v, has_mask, do_causal); } void ScaledDotProductAttention::eval_gpu( @@ -237,20 +248,20 @@ void ScaledDotProductAttention::eval_gpu( auto& s = stream(); - assert(inputs.size() == 3 || inputs.size() == 4); const auto& q = inputs[0]; const auto& k = inputs[1]; const auto& v = inputs[2]; + bool has_mask = inputs.size() - has_sinks_ > 3; - sdpa_cudnn(q, k, v, scale_, out, do_causal_, std::nullopt, s); - -#if 0 - if (has_sinks_) { - sdpa_vector(q, k, v, scale_, out, do_causal_, inputs.back(), s); + if (supports_sdpa_cudnn(q, k, v, has_mask, do_causal_)) { + sdpa_cudnn(q, k, v, scale_, out, do_causal_, s); } else { - sdpa_vector(q, k, v, scale_, out, do_causal_, std::nullopt, s); + if (has_sinks_) { + sdpa_vector(q, k, v, scale_, out, do_causal_, inputs.back(), s); + } else { + sdpa_vector(q, k, v, scale_, out, do_causal_, std::nullopt, s); + } } -#endif } } // namespace fast From a13f295ae5f0b2a828feef3a9c72d094233a496c Mon Sep 17 00:00:00 2001 From: Cheng Date: Fri, 7 Nov 2025 03:04:21 -0800 Subject: [PATCH 4/8] Remove scaled_dot_product_attention.h --- .../cuda/scaled_dot_product_attention.cpp | 20 ++++++++++++- .../cuda/scaled_dot_product_attention.cu | 2 +- .../cuda/scaled_dot_product_attention.h | 29 ------------------- 3 files changed, 20 insertions(+), 31 deletions(-) delete mode 100644 mlx/backend/cuda/scaled_dot_product_attention.h diff --git a/mlx/backend/cuda/scaled_dot_product_attention.cpp b/mlx/backend/cuda/scaled_dot_product_attention.cpp index 3d91a365c6..572deb6764 100644 --- a/mlx/backend/cuda/scaled_dot_product_attention.cpp +++ b/mlx/backend/cuda/scaled_dot_product_attention.cpp @@ -1,7 +1,7 @@ // Copyright © 2025 Apple Inc. -#include "mlx/backend/cuda/scaled_dot_product_attention.h" #include "mlx/backend/cuda/cudnn_utils.h" +#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/lru_cache.h" #include "mlx/fast_primitives.h" #include "mlx/transforms_impl.h" @@ -220,6 +220,24 @@ void sdpa_cudnn( encoder.add_graph_node(cuda_graph); } +// Defined in scaled_dot_product_attention.cu file. +bool supports_sdpa_vector( + const array& q, + const array& k, + const array& v, + bool has_mask, + bool has_arr_mask, + bool do_causal); +void sdpa_vector( + const array& q, + const array& k, + const array& v, + float scale, + array& o, + bool do_causal, + const std::optional& sinks, + Stream s); + namespace fast { bool ScaledDotProductAttention::use_fallback( diff --git a/mlx/backend/cuda/scaled_dot_product_attention.cu b/mlx/backend/cuda/scaled_dot_product_attention.cu index fefc6d9ee0..438409e85c 100644 --- a/mlx/backend/cuda/scaled_dot_product_attention.cu +++ b/mlx/backend/cuda/scaled_dot_product_attention.cu @@ -1,9 +1,9 @@ // Copyright © 2025 Apple Inc. +#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device/config.h" #include "mlx/backend/cuda/device/utils.cuh" #include "mlx/backend/cuda/kernel_utils.cuh" -#include "mlx/backend/cuda/scaled_dot_product_attention.h" #include "mlx/backend/gpu/copy.h" #include "mlx/dtype_utils.h" diff --git a/mlx/backend/cuda/scaled_dot_product_attention.h b/mlx/backend/cuda/scaled_dot_product_attention.h deleted file mode 100644 index b47959a077..0000000000 --- a/mlx/backend/cuda/scaled_dot_product_attention.h +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright © 2025 Apple Inc. - -#pragma once - -#include "mlx/backend/cuda/device.h" - -#include - -namespace mlx::core { - -bool supports_sdpa_vector( - const array& q, - const array& k, - const array& v, - bool has_mask, - bool has_arr_mask, - bool do_causal); - -void sdpa_vector( - const array& q, - const array& k, - const array& v, - float scale, - array& o, - bool do_causal, - const std::optional& sinks, - Stream s); - -} // namespace mlx::core From 96891d31157a92a7fcf26d89b37ca09b43689e66 Mon Sep 17 00:00:00 2001 From: Cheng Date: Sat, 8 Nov 2025 03:29:24 -0800 Subject: [PATCH 5/8] Use cuDNN attention for prefilling --- mlx/backend/cuda/conv.cpp | 3 ++- mlx/backend/cuda/lru_cache.h | 9 ++++++++- mlx/backend/cuda/scaled_dot_product_attention.cpp | 12 +++++++++--- 3 files changed, 19 insertions(+), 5 deletions(-) diff --git a/mlx/backend/cuda/conv.cpp b/mlx/backend/cuda/conv.cpp index 4f7a971ca7..da7cd2dd2a 100644 --- a/mlx/backend/cuda/conv.cpp +++ b/mlx/backend/cuda/conv.cpp @@ -281,7 +281,8 @@ void Convolution::eval_gpu(const std::vector& inputs, array& out_) { Dtype dtype = out.dtype(); // Search cache. - ConvCacheKey cache_key{ + BytesKey cache_key; + cache_key.pod = { encoder.device().cuda_device(), dtype_to_cudnn_type(dtype), vector_key(in.shape()), diff --git a/mlx/backend/cuda/lru_cache.h b/mlx/backend/cuda/lru_cache.h index dc8325fcd4..94a96a9d68 100644 --- a/mlx/backend/cuda/lru_cache.h +++ b/mlx/backend/cuda/lru_cache.h @@ -135,12 +135,19 @@ class LRUCache { }; // Turn a POD struct into a container key by doing bytes compare. +// +// Usage: +// BytesKey key; +// key.pod = { ... }; template struct BytesKey { T pod; static_assert(std::is_standard_layout_v, "T is not POD"); - BytesKey(T pod) : pod(std::move(pod)) {} + BytesKey() { + // Make sure the paddings between members are filled with 0. + memset(&pod, 0, sizeof(T)); + } BytesKey(const BytesKey& other) { memcpy(&pod, &other.pod, sizeof(T)); diff --git a/mlx/backend/cuda/scaled_dot_product_attention.cpp b/mlx/backend/cuda/scaled_dot_product_attention.cpp index 572deb6764..3ff77a214d 100644 --- a/mlx/backend/cuda/scaled_dot_product_attention.cpp +++ b/mlx/backend/cuda/scaled_dot_product_attention.cpp @@ -134,7 +134,7 @@ bool supports_sdpa_cudnn( const array& v, bool has_mask, bool do_causal) { - static bool enabled = env::get_var("MLX_CUDA_USE_CUDNN_SPDA", 0); + static bool enabled = env::get_var("MLX_CUDA_USE_CUDNN_SPDA", 1); if (!enabled) { return false; } @@ -144,12 +144,17 @@ bool supports_sdpa_cudnn( if (!do_causal) { return false; } - // TODO: Fix causal mask when L_Q != L_K. + // FIXME: Causal mask generates wrong results when L_Q != L_K. if (q.shape(2) != k.shape(2)) { return false; } } + // Only use cuDNN for prefilling. + if (q.shape(2) != k.shape(2)) { + return false; + } + Dtype dtype = q.dtype(); return dtype == float16 || dtype == bfloat16; } @@ -175,7 +180,8 @@ void sdpa_cudnn( cudnnSetStream(handle, encoder.stream()); // Search cache. - SDPACacheKey cache_key{ + BytesKey cache_key; + cache_key.pod = { encoder.device().cuda_device(), dtype_to_cudnn_type(q.dtype()), vector_key(q.shape()), From 1327af535a73f92a6f1a8aa97a5781958e8f4b6d Mon Sep 17 00:00:00 2001 From: Cheng Date: Mon, 10 Nov 2025 15:32:04 -0800 Subject: [PATCH 6/8] cuDNN SDPA requires Ampere and later --- mlx/backend/cuda/scaled_dot_product_attention.cpp | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/mlx/backend/cuda/scaled_dot_product_attention.cpp b/mlx/backend/cuda/scaled_dot_product_attention.cpp index 3ff77a214d..5067dc2502 100644 --- a/mlx/backend/cuda/scaled_dot_product_attention.cpp +++ b/mlx/backend/cuda/scaled_dot_product_attention.cpp @@ -133,12 +133,18 @@ bool supports_sdpa_cudnn( const array& k, const array& v, bool has_mask, - bool do_causal) { + bool do_causal, + Stream s) { static bool enabled = env::get_var("MLX_CUDA_USE_CUDNN_SPDA", 1); if (!enabled) { return false; } + // cuDNN SDPA requires Ampere and later. + if (cu::device(s.device).compute_capability_major() < 8) { + return false; + } + if (has_mask) { // TODO: Support array masks. if (!do_causal) { @@ -262,7 +268,7 @@ bool ScaledDotProductAttention::use_fallback( } return !supports_sdpa_vector(q, k, v, has_mask, has_arr_mask, do_causal) && - !supports_sdpa_cudnn(q, k, v, has_mask, do_causal); + !supports_sdpa_cudnn(q, k, v, has_mask, do_causal, s); } void ScaledDotProductAttention::eval_gpu( @@ -277,7 +283,7 @@ void ScaledDotProductAttention::eval_gpu( const auto& v = inputs[2]; bool has_mask = inputs.size() - has_sinks_ > 3; - if (supports_sdpa_cudnn(q, k, v, has_mask, do_causal_)) { + if (supports_sdpa_cudnn(q, k, v, has_mask, do_causal_, s)) { sdpa_cudnn(q, k, v, scale_, out, do_causal_, s); } else { if (has_sinks_) { From c7cc5eb569de4d5febb25d4fcb8191fcd0614f15 Mon Sep 17 00:00:00 2001 From: Cheng Date: Tue, 11 Nov 2025 18:24:35 -0800 Subject: [PATCH 7/8] Address reviews --- .../cuda/scaled_dot_product_attention.cpp | 19 ++++++++++++++++--- python/tests/test_fast_sdpa.py | 18 +++++------------- 2 files changed, 21 insertions(+), 16 deletions(-) diff --git a/mlx/backend/cuda/scaled_dot_product_attention.cpp b/mlx/backend/cuda/scaled_dot_product_attention.cpp index 5067dc2502..5793ae2f5b 100644 --- a/mlx/backend/cuda/scaled_dot_product_attention.cpp +++ b/mlx/backend/cuda/scaled_dot_product_attention.cpp @@ -161,6 +161,17 @@ bool supports_sdpa_cudnn( return false; } + // D_qk and D_v must be a multiple of 8 with maximum value 128. + if ((q.shape(-1) % 8 != 0) || (q.shape(-1) > 128) || (v.shape(-1) % 8 != 0) || + (v.shape(-1) > 128)) { + return false; + } + + // TODO: Do contiguous copy for inputs. + if (q.strides(-1) != 1 || k.strides(-1) != 1 || v.strides(-1) != 1) { + return false; + } + Dtype dtype = q.dtype(); return dtype == float16 || dtype == bfloat16; } @@ -175,6 +186,7 @@ void sdpa_cudnn( Stream s) { auto& encoder = cu::get_command_encoder(s); // TODO: Handle donation. + // TODO: Make O use same memory layout with Q. o.set_data(cu::malloc_async(o.nbytes(), encoder.stream())); encoder.set_input_array(q); @@ -282,15 +294,16 @@ void ScaledDotProductAttention::eval_gpu( const auto& k = inputs[1]; const auto& v = inputs[2]; bool has_mask = inputs.size() - has_sinks_ > 3; + bool has_arr_mask = has_mask && !do_causal_; - if (supports_sdpa_cudnn(q, k, v, has_mask, do_causal_, s)) { - sdpa_cudnn(q, k, v, scale_, out, do_causal_, s); - } else { + if (supports_sdpa_vector(q, k, v, has_mask, has_arr_mask, do_causal_)) { if (has_sinks_) { sdpa_vector(q, k, v, scale_, out, do_causal_, inputs.back(), s); } else { sdpa_vector(q, k, v, scale_, out, do_causal_, std::nullopt, s); } + } else { + sdpa_cudnn(q, k, v, scale_, out, do_causal_, s); } } diff --git a/python/tests/test_fast_sdpa.py b/python/tests/test_fast_sdpa.py index 44828fded0..e5c3412414 100644 --- a/python/tests/test_fast_sdpa.py +++ b/python/tests/test_fast_sdpa.py @@ -168,7 +168,7 @@ def test_fast_sdpa(self): Dk = 64 - if self.is_apple_silicon or mx.cuda.is_available(): + if mx.is_available(mx.gpu): dtypes.append(np.half) for SEQUENCE_LENGTH in [63, 129, 400]: @@ -240,7 +240,7 @@ def test_fast_sdpa(self): B = 1 H = 32 dtypes = [np.float32] - if self.is_apple_silicon or mx.cuda.is_available(): + if mx.is_available(mx.gpu): dtypes.append(np.half) for SEQUENCE_LENGTH in [1, 7, 9, 32, 63, 67, 129, 400, 2000]: @@ -549,17 +549,8 @@ def test_sdpa_vector_batched(self): class TestSDPA(mlx_tests.MLXTestCase): - @property - def dtypes(self): - if mx.metal.is_available(): - return ["float32", "float16"] - elif mx.cuda.is_available(): - return ["float16"] - else: - return ["float32"] - def test_sdpa(self): - if not mx.metal.is_available() and not mx.cuda.is_available(): + if not mx.is_available(mx.gpu): return # fmt: off @@ -583,10 +574,11 @@ def test_sdpa(self): # fmt: on shapes = shapes_64 + shapes_128 + dtypes = ["float32", "float16"] masks = [None, "additive", "bool", "causal"] transposes = (False, True) - for dtype in self.dtypes: + for dtype in dtypes: for t in transposes: for mask_str in masks: for B, qL, kL, D, qH, kH in shapes: From 78d10a29705a378b2d3d9a312c982d8bb2eff3dc Mon Sep 17 00:00:00 2001 From: Cheng Date: Wed, 12 Nov 2025 14:52:04 -0800 Subject: [PATCH 8/8] Do contiguous copy of inputs --- .../cuda/scaled_dot_product_attention.cpp | 25 +++++++++++++------ 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/mlx/backend/cuda/scaled_dot_product_attention.cpp b/mlx/backend/cuda/scaled_dot_product_attention.cpp index 5793ae2f5b..8cc0e988a9 100644 --- a/mlx/backend/cuda/scaled_dot_product_attention.cpp +++ b/mlx/backend/cuda/scaled_dot_product_attention.cpp @@ -3,6 +3,7 @@ #include "mlx/backend/cuda/cudnn_utils.h" #include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/lru_cache.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/fast_primitives.h" #include "mlx/transforms_impl.h" @@ -45,6 +46,19 @@ void set_tensor_attrs( .set_stride(normalized_strides(x)); } +array prepare_sdpa_input(const array& x, Stream s) { + // SDPA kernel's requirements on inputs: + // 1. last dim's stride be 1; + // 2. pointer be aligned. + if (x.strides(-1) != 1 || get_alignment(x) < 16) { + array x_copy = contiguous_copy_gpu(x, s); + auto& encoder = cu::get_command_encoder(s); + encoder.add_temporary(x_copy); + return x_copy; + } + return x; +} + constexpr int QKV_NDIM = 4; struct SDPACacheKey { @@ -167,11 +181,6 @@ bool supports_sdpa_cudnn( return false; } - // TODO: Do contiguous copy for inputs. - if (q.strides(-1) != 1 || k.strides(-1) != 1 || v.strides(-1) != 1) { - return false; - } - Dtype dtype = q.dtype(); return dtype == float16 || dtype == bfloat16; } @@ -290,9 +299,9 @@ void ScaledDotProductAttention::eval_gpu( auto& s = stream(); - const auto& q = inputs[0]; - const auto& k = inputs[1]; - const auto& v = inputs[2]; + array q = prepare_sdpa_input(inputs[0], s); + array k = prepare_sdpa_input(inputs[1], s); + array v = prepare_sdpa_input(inputs[2], s); bool has_mask = inputs.size() - has_sinks_ > 3; bool has_arr_mask = has_mask && !do_causal_;