diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index 543e9fd589..892d46c510 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -44,6 +44,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu ${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu ${CMAKE_CURRENT_SOURCE_DIR}/rope.cu + ${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cpp ${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cu ${CMAKE_CURRENT_SOURCE_DIR}/scan.cu ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp diff --git a/mlx/backend/cuda/conv.cpp b/mlx/backend/cuda/conv.cpp index 4f7a971ca7..da7cd2dd2a 100644 --- a/mlx/backend/cuda/conv.cpp +++ b/mlx/backend/cuda/conv.cpp @@ -281,7 +281,8 @@ void Convolution::eval_gpu(const std::vector& inputs, array& out_) { Dtype dtype = out.dtype(); // Search cache. - ConvCacheKey cache_key{ + BytesKey cache_key; + cache_key.pod = { encoder.device().cuda_device(), dtype_to_cudnn_type(dtype), vector_key(in.shape()), diff --git a/mlx/backend/cuda/cudnn_utils.h b/mlx/backend/cuda/cudnn_utils.h index b282496782..537e17c44c 100644 --- a/mlx/backend/cuda/cudnn_utils.h +++ b/mlx/backend/cuda/cudnn_utils.h @@ -44,13 +44,13 @@ inline SmallVector convert_vector(const Vec& vec) { // There are 2 differences from the const_param util from kernel_utils.cuh: // 1. The rest of array is filled with 0. // 2. This util can be used in .cpp files. -template class Vec> -inline std::array vector_key(const Vec& vec) { - if (vec.size() > MAX_NDIM) { +template class Vec> +inline std::array vector_key(const Vec& vec) { + if (vec.size() > NDIM) { throw std::runtime_error( - fmt::format("ndim can not be larger than {}.", MAX_NDIM)); + fmt::format("ndim can not be larger than {}.", NDIM)); } - std::array result = {}; + std::array result = {}; std::copy_n(vec.begin(), vec.size(), result.begin()); return result; } diff --git a/mlx/backend/cuda/lru_cache.h b/mlx/backend/cuda/lru_cache.h index dc8325fcd4..94a96a9d68 100644 --- a/mlx/backend/cuda/lru_cache.h +++ b/mlx/backend/cuda/lru_cache.h @@ -135,12 +135,19 @@ class LRUCache { }; // Turn a POD struct into a container key by doing bytes compare. +// +// Usage: +// BytesKey key; +// key.pod = { ... }; template struct BytesKey { T pod; static_assert(std::is_standard_layout_v, "T is not POD"); - BytesKey(T pod) : pod(std::move(pod)) {} + BytesKey() { + // Make sure the paddings between members are filled with 0. + memset(&pod, 0, sizeof(T)); + } BytesKey(const BytesKey& other) { memcpy(&pod, &other.pod, sizeof(T)); diff --git a/mlx/backend/cuda/scaled_dot_product_attention.cpp b/mlx/backend/cuda/scaled_dot_product_attention.cpp new file mode 100644 index 0000000000..8cc0e988a9 --- /dev/null +++ b/mlx/backend/cuda/scaled_dot_product_attention.cpp @@ -0,0 +1,321 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/cudnn_utils.h" +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/lru_cache.h" +#include "mlx/backend/gpu/copy.h" +#include "mlx/fast_primitives.h" +#include "mlx/transforms_impl.h" + +#include + +namespace mlx::core { + +namespace fe = cudnn_frontend; + +namespace { + +#define CHECK_CUDNN_FE_ERROR(cmd) \ + do { \ + auto error = cmd; \ + if (!error.is_good()) { \ + throw std::runtime_error( \ + fmt::format("{} failed: {}.", #cmd, error.get_message())); \ + } \ + } while (0) + +std::vector normalized_strides(const array& x) { + std::vector strides(x.strides().begin(), x.strides().end()); + if (!x.flags().row_contiguous || x.ndim() < 2) { + return strides; + } + for (int i = x.ndim() - 2; i >= 0; --i) { + if (x.shape(i) == 1) { + strides[i] = x.shape(i + 1) * strides[i + 1]; + } + } + return strides; +} + +void set_tensor_attrs( + std::shared_ptr& tensor, + int64_t uid, + const array& x) { + tensor->set_uid(uid) + .set_dim({x.shape().begin(), x.shape().end()}) + .set_stride(normalized_strides(x)); +} + +array prepare_sdpa_input(const array& x, Stream s) { + // SDPA kernel's requirements on inputs: + // 1. last dim's stride be 1; + // 2. pointer be aligned. + if (x.strides(-1) != 1 || get_alignment(x) < 16) { + array x_copy = contiguous_copy_gpu(x, s); + auto& encoder = cu::get_command_encoder(s); + encoder.add_temporary(x_copy); + return x_copy; + } + return x; +} + +constexpr int QKV_NDIM = 4; + +struct SDPACacheKey { + int device_id; + cudnnDataType_t cudnn_dtype; + std::array q_shape; + std::array k_shape; + std::array v_shape; + std::array q_strides; + std::array k_strides; + std::array v_strides; + bool do_causal; +}; + +auto& sdpa_cache() { + static LRUBytesKeyCache cache( + "MLX_CUDA_SDPA_CACHE_SIZE", /* default_capacity */ 128); + return cache; +} + +enum UIDS { + Q, + K, + V, + SCALE, + O, +}; + +fe::graph::Graph build_sdpa_graph( + cudnnHandle_t handle, + const array& q, + const array& k, + const array& v, + bool do_causal, + const array& o) { + auto dtype = fe::DataType_t::HALF; + if (q.dtype() == bfloat16) { + dtype = fe::DataType_t::BFLOAT16; + } + + fe::graph::Graph graph; + graph.set_io_data_type(dtype) + .set_intermediate_data_type(fe::DataType_t::FLOAT) + .set_compute_data_type(fe::DataType_t::FLOAT); + + auto q_ = graph.tensor(fe::graph::Tensor_attributes().set_name("Q")); + auto k_ = graph.tensor(fe::graph::Tensor_attributes().set_name("K")); + auto v_ = graph.tensor(fe::graph::Tensor_attributes().set_name("V")); + set_tensor_attrs(q_, Q, q); + set_tensor_attrs(k_, K, k); + set_tensor_attrs(v_, V, v); + + auto scale = graph.tensor(fe::graph::Tensor_attributes() + .set_name("Scale") + .set_uid(SCALE) + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_is_pass_by_value(true) + .set_data_type(fe::DataType_t::FLOAT)); + + auto sdpa_options = fe::graph::SDPA_attributes() + .set_name("sdpa_cudnn") + .set_attn_scale(scale) + .set_causal_mask(do_causal) + .set_generate_stats(false); + + auto [o_, _] = graph.sdpa(q_, k_, v_, sdpa_options); + o_->set_output(true); + set_tensor_attrs(o_, O, o); + + CHECK_CUDNN_FE_ERROR(graph.validate()); + CHECK_CUDNN_FE_ERROR(graph.build_operation_graph(handle)); + CHECK_CUDNN_FE_ERROR(graph.create_execution_plans({fe::HeurMode_t::A})); + graph.select_behavior_notes( + {fe::BehaviorNote_t::SUPPORTS_CUDA_GRAPH_NATIVE_API}); + CHECK_CUDNN_FE_ERROR(graph.check_support(handle)); + CHECK_CUDNN_FE_ERROR(graph.build_plans(handle)); + + return graph; +} + +} // namespace + +bool supports_sdpa_cudnn( + const array& q, + const array& k, + const array& v, + bool has_mask, + bool do_causal, + Stream s) { + static bool enabled = env::get_var("MLX_CUDA_USE_CUDNN_SPDA", 1); + if (!enabled) { + return false; + } + + // cuDNN SDPA requires Ampere and later. + if (cu::device(s.device).compute_capability_major() < 8) { + return false; + } + + if (has_mask) { + // TODO: Support array masks. + if (!do_causal) { + return false; + } + // FIXME: Causal mask generates wrong results when L_Q != L_K. + if (q.shape(2) != k.shape(2)) { + return false; + } + } + + // Only use cuDNN for prefilling. + if (q.shape(2) != k.shape(2)) { + return false; + } + + // D_qk and D_v must be a multiple of 8 with maximum value 128. + if ((q.shape(-1) % 8 != 0) || (q.shape(-1) > 128) || (v.shape(-1) % 8 != 0) || + (v.shape(-1) > 128)) { + return false; + } + + Dtype dtype = q.dtype(); + return dtype == float16 || dtype == bfloat16; +} + +void sdpa_cudnn( + const array& q, + const array& k, + const array& v, + float scale, + array& o, + bool do_causal, + Stream s) { + auto& encoder = cu::get_command_encoder(s); + // TODO: Handle donation. + // TODO: Make O use same memory layout with Q. + o.set_data(cu::malloc_async(o.nbytes(), encoder.stream())); + + encoder.set_input_array(q); + encoder.set_input_array(k); + encoder.set_input_array(v); + encoder.set_output_array(o); + + auto handle = encoder.device().cudnn_handle(); + cudnnSetStream(handle, encoder.stream()); + + // Search cache. + BytesKey cache_key; + cache_key.pod = { + encoder.device().cuda_device(), + dtype_to_cudnn_type(q.dtype()), + vector_key(q.shape()), + vector_key(k.shape()), + vector_key(v.shape()), + vector_key(q.strides()), + vector_key(k.strides()), + vector_key(v.strides()), + do_causal, + }; + auto it = sdpa_cache().find(cache_key); + if (it == sdpa_cache().end()) { + it = + sdpa_cache() + .emplace(cache_key, build_sdpa_graph(handle, q, k, v, do_causal, o)) + .first; + } + auto& graph = it->second; + + std::unordered_map variant_pack{ + {Q, const_cast(gpu_ptr(q))}, + {K, const_cast(gpu_ptr(k))}, + {V, const_cast(gpu_ptr(v))}, + {SCALE, &scale}, + {O, gpu_ptr(o)}}; + + int64_t workspace_size = 0; + CHECK_CUDNN_FE_ERROR(graph.get_workspace_size(workspace_size)); + void* workspace_ptr = nullptr; + if (workspace_size > 0) { + array workspace( + cu::malloc_async(workspace_size, encoder.stream()), + {static_cast(workspace_size)}, + uint8); + encoder.add_temporary(workspace); + workspace_ptr = gpu_ptr(workspace); + } + + CudaGraph cuda_graph(encoder.device()); + CHECK_CUDNN_FE_ERROR(graph.populate_cuda_graph( + handle, variant_pack, workspace_ptr, cuda_graph)); + encoder.add_graph_node(cuda_graph); +} + +// Defined in scaled_dot_product_attention.cu file. +bool supports_sdpa_vector( + const array& q, + const array& k, + const array& v, + bool has_mask, + bool has_arr_mask, + bool do_causal); +void sdpa_vector( + const array& q, + const array& k, + const array& v, + float scale, + array& o, + bool do_causal, + const std::optional& sinks, + Stream s); + +namespace fast { + +bool ScaledDotProductAttention::use_fallback( + const array& q, + const array& k, + const array& v, + bool has_mask, + bool has_arr_mask, + bool do_causal, + Stream s) { + if (detail::in_grad_tracing()) { + return true; + } + if (s.device == Device::cpu) { + return true; + } + + return !supports_sdpa_vector(q, k, v, has_mask, has_arr_mask, do_causal) && + !supports_sdpa_cudnn(q, k, v, has_mask, do_causal, s); +} + +void ScaledDotProductAttention::eval_gpu( + const std::vector& inputs, + array& out) { + nvtx3::scoped_range r("ScaledDotProductAttention::eval_gpu"); + + auto& s = stream(); + + array q = prepare_sdpa_input(inputs[0], s); + array k = prepare_sdpa_input(inputs[1], s); + array v = prepare_sdpa_input(inputs[2], s); + bool has_mask = inputs.size() - has_sinks_ > 3; + bool has_arr_mask = has_mask && !do_causal_; + + if (supports_sdpa_vector(q, k, v, has_mask, has_arr_mask, do_causal_)) { + if (has_sinks_) { + sdpa_vector(q, k, v, scale_, out, do_causal_, inputs.back(), s); + } else { + sdpa_vector(q, k, v, scale_, out, do_causal_, std::nullopt, s); + } + } else { + sdpa_cudnn(q, k, v, scale_, out, do_causal_, s); + } +} + +} // namespace fast + +} // namespace mlx::core diff --git a/mlx/backend/cuda/scaled_dot_product_attention.cu b/mlx/backend/cuda/scaled_dot_product_attention.cu index 151fca0414..438409e85c 100644 --- a/mlx/backend/cuda/scaled_dot_product_attention.cu +++ b/mlx/backend/cuda/scaled_dot_product_attention.cu @@ -6,10 +6,6 @@ #include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/backend/gpu/copy.h" #include "mlx/dtype_utils.h" -#include "mlx/fast_primitives.h" -#include "mlx/transforms_impl.h" - -#include #include #include @@ -663,23 +659,13 @@ void sdpa_vector_fallback( } // namespace -namespace fast { - -bool ScaledDotProductAttention::use_fallback( +bool supports_sdpa_vector( const array& q, const array& k, const array& v, bool has_mask, bool has_arr_mask, - bool do_causal, - Stream s) { - if (detail::in_grad_tracing()) { - return true; - } - if (s.device == Device::cpu) { - return true; - } - + bool do_causal) { const int value_head_dim = v.shape(-1); const int query_head_dim = q.shape(-1); const int query_sequence_length = q.shape(2); @@ -691,29 +677,24 @@ bool ScaledDotProductAttention::use_fallback( const bool supported_vector_config = sdpa_supported_head_dim && query_sequence_length < 4; - const bool supported_config = supported_vector_config; - - return has_arr_mask || !supported_config; + return supported_vector_config && !has_arr_mask; } -void ScaledDotProductAttention::eval_gpu( - const std::vector& inputs, - array& out) { - nvtx3::scoped_range r("ScaledDotProductAttention::eval_gpu"); - - auto& s = stream(); +void sdpa_vector( + const array& q_pre, + const array& k_pre, + const array& v_pre, + float scale, + array& o, + bool do_causal, + const std::optional& sinks_pre, + Stream s) { auto& encoder = cu::get_command_encoder(s); - - auto& q_pre = inputs[0]; - auto& k_pre = inputs[1]; - auto& v_pre = inputs[2]; - auto& o = out; - std::vector copies; // Define some copy functions to ensure the layout of the inputs is as // expected. - copies.reserve(inputs.size()); + copies.reserve(4); auto copy_unless = [&copies, &s]( auto predicate, const array& arr) -> const array& { if (!predicate(arr)) { @@ -731,8 +712,8 @@ void ScaledDotProductAttention::eval_gpu( }; std::optional sinks = std::nullopt; - if (has_sinks_) { - sinks = copy_unless(is_matrix_contiguous, inputs.back()); + if (sinks_pre) { + sinks = copy_unless(is_matrix_contiguous, sinks_pre.value()); } // We are in vector mode ie single query @@ -798,8 +779,7 @@ void ScaledDotProductAttention::eval_gpu( encoder.add_temporary(cp); } - return sdpa_vector_fallback( - s, encoder, q, k, v, scale_, o, do_causal_, sinks); + sdpa_vector_fallback(s, encoder, q, k, v, scale, o, do_causal, sinks); } // Full attention mode should never reach here @@ -808,6 +788,4 @@ void ScaledDotProductAttention::eval_gpu( } } -} // namespace fast - } // namespace mlx::core diff --git a/python/tests/test_fast_sdpa.py b/python/tests/test_fast_sdpa.py index 19af012c6a..e5c3412414 100644 --- a/python/tests/test_fast_sdpa.py +++ b/python/tests/test_fast_sdpa.py @@ -168,7 +168,7 @@ def test_fast_sdpa(self): Dk = 64 - if self.is_apple_silicon or mx.cuda.is_available(): + if mx.is_available(mx.gpu): dtypes.append(np.half) for SEQUENCE_LENGTH in [63, 129, 400]: @@ -240,7 +240,7 @@ def test_fast_sdpa(self): B = 1 H = 32 dtypes = [np.float32] - if self.is_apple_silicon or mx.cuda.is_available(): + if mx.is_available(mx.gpu): dtypes.append(np.half) for SEQUENCE_LENGTH in [1, 7, 9, 32, 63, 67, 129, 400, 2000]: @@ -549,12 +549,8 @@ def test_sdpa_vector_batched(self): class TestSDPA(mlx_tests.MLXTestCase): - @property - def dtypes(self): - return ["float32", "float16"] if mx.metal.is_available() else ["float32"] - def test_sdpa(self): - if not mx.metal.is_available(): + if not mx.is_available(mx.gpu): return # fmt: off @@ -578,10 +574,11 @@ def test_sdpa(self): # fmt: on shapes = shapes_64 + shapes_128 + dtypes = ["float32", "float16"] masks = [None, "additive", "bool", "causal"] transposes = (False, True) - for dtype in self.dtypes: + for dtype in dtypes: for t in transposes: for mask_str in masks: for B, qL, kL, D, qH, kH in shapes: