From a435cd176e3f4877a90dd91ff4262840ac4a8639 Mon Sep 17 00:00:00 2001 From: mm65x Date: Fri, 20 Mar 2026 10:32:50 +0000 Subject: [PATCH 1/2] add causal_upper_left mask option to scaled_dot_product_attention --- .../cuda/scaled_dot_product_attention.cpp | 41 ++++++++++++++--- .../cuda/scaled_dot_product_attention.cu | 21 ++++++--- mlx/backend/metal/kernels/sdpa_vector.h | 6 ++- .../metal/scaled_dot_product_attention.cpp | 21 ++++++--- mlx/fast.cpp | 30 +++++++++---- mlx/fast_primitives.h | 16 ++++++- python/src/fast.cpp | 28 ++++++++---- python/tests/test_fast_sdpa.py | 45 ++++++++++++++++++- 8 files changed, 170 insertions(+), 38 deletions(-) diff --git a/mlx/backend/cuda/scaled_dot_product_attention.cpp b/mlx/backend/cuda/scaled_dot_product_attention.cpp index 93f310f56b..bfd246f152 100644 --- a/mlx/backend/cuda/scaled_dot_product_attention.cpp +++ b/mlx/backend/cuda/scaled_dot_product_attention.cpp @@ -133,6 +133,7 @@ struct SDPACacheKey { std::array k_strides; std::array v_strides; bool do_causal; + bool causal_upper_left; std::array mask_shape; std::array mask_strides; bool has_sinks; @@ -145,6 +146,7 @@ inline BytesKey build_sdpa_cache_key( const array& k, const array& v, bool do_causal, + bool causal_upper_left, const std::optional& mask_arr, const std::optional& sinks, bool decoding = false, @@ -159,6 +161,7 @@ inline BytesKey build_sdpa_cache_key( cache_key.pod.k_strides = vector_key(k.strides()); cache_key.pod.v_strides = vector_key(v.strides()); cache_key.pod.do_causal = do_causal; + cache_key.pod.causal_upper_left = causal_upper_left; cache_key.pod.has_sinks = sinks.has_value(); cache_key.pod.output_logsumexp = output_logsumexp; if (mask_arr) { @@ -211,6 +214,7 @@ DnnGraph build_sdpa_graph( const array& k, const array& v, bool do_causal, + bool causal_upper_left, const std::optional& mask_arr, const std::optional& sinks, const std::optional& seq_len_q, @@ -229,7 +233,11 @@ DnnGraph build_sdpa_graph( .set_attn_scale(graph.scalar("Scale", SCALE, float32)) .set_generate_stats(output_logsumexp); if (do_causal) { - options.set_causal_mask_bottom_right(do_causal); + if (causal_upper_left) { + options.set_causal_mask(do_causal); + } else { + options.set_causal_mask_bottom_right(do_causal); + } } if (mask_arr) { options.set_bias(graph.tensor("BIAS", BIAS, *mask_arr)); @@ -262,6 +270,7 @@ DnnGraph build_sdpa_backward_graph( const array& k, const array& v, bool do_causal, + bool causal_upper_left, const std::optional& mask_arr, const std::optional& sinks, const array& o, @@ -283,7 +292,11 @@ DnnGraph build_sdpa_backward_graph( .set_name("sdpa_backward_cudnn") .set_attn_scale(graph.scalar("Scale", SCALE, float32)); if (do_causal) { - options.set_causal_mask_bottom_right(do_causal); + if (causal_upper_left) { + options.set_causal_mask(do_causal); + } else { + options.set_causal_mask_bottom_right(do_causal); + } } if (mask_arr) { options.set_bias(graph.tensor("BIAS", BIAS, *mask_arr)); @@ -352,6 +365,7 @@ void sdpa_cudnn( array& o, std::optional& stats, bool do_causal, + bool causal_upper_left, const std::optional& mask_arr, const std::optional& sinks, bool output_logsumexp, @@ -400,7 +414,16 @@ void sdpa_cudnn( // Search cache. auto cache_key = build_sdpa_cache_key( - encoder, q, k, v, do_causal, mask_arr, sinks, decoding, output_logsumexp); + encoder, + q, + k, + v, + do_causal, + causal_upper_left, + mask_arr, + sinks, + decoding, + output_logsumexp); auto it = sdpa_cache().find(cache_key); if (it == sdpa_cache().end()) { auto graph = build_sdpa_graph( @@ -409,6 +432,7 @@ void sdpa_cudnn( k, v, do_causal, + causal_upper_left, mask_arr, sinks, seq_len_q, @@ -451,6 +475,7 @@ void sdpa_backward_cudnn( const array& o, const array& stats, bool do_causal, + bool causal_upper_left, const std::optional& mask_arr, const std::optional& sinks, const array& d_o, @@ -482,8 +507,8 @@ void sdpa_backward_cudnn( } // Search cache. - auto cache_key = - build_sdpa_cache_key(encoder, q, k, v, do_causal, mask_arr, sinks); + auto cache_key = build_sdpa_cache_key( + encoder, q, k, v, do_causal, causal_upper_left, mask_arr, sinks); auto it = sdpa_backward_cache().find(cache_key); if (it == sdpa_backward_cache().end()) { auto graph = build_sdpa_backward_graph( @@ -492,6 +517,7 @@ void sdpa_backward_cudnn( k, v, do_causal, + causal_upper_left, mask_arr, sinks, o, @@ -539,6 +565,7 @@ void sdpa_vector( float scale, array& o, bool do_causal, + bool causal_upper_left, const std::optional& sinks, Stream s); @@ -605,12 +632,14 @@ void ScaledDotProductAttention::eval_gpu( out, stats, do_causal_, + causal_upper_left_, mask_arr, sinks, output_logsumexp_, s); } else { - sdpa_vector(q, k, v, scale_, out, do_causal_, sinks, s); + sdpa_vector( + q, k, v, scale_, out, do_causal_, causal_upper_left_, sinks, s); } } diff --git a/mlx/backend/cuda/scaled_dot_product_attention.cu b/mlx/backend/cuda/scaled_dot_product_attention.cu index 6fafb0ba98..636641188c 100644 --- a/mlx/backend/cuda/scaled_dot_product_attention.cu +++ b/mlx/backend/cuda/scaled_dot_product_attention.cu @@ -32,6 +32,8 @@ struct AttnParams { int gqa_factor; float scale; + int causal_offset; + int64_t Q_strides[3]; int64_t K_strides[3]; int64_t V_strides[3]; @@ -118,7 +120,7 @@ __global__ void kernel_sdpav_1pass( for (int i = kv_seq_idx; i < params.kL; i += BN) { bool use_key = true; if constexpr (do_causal) { - use_key = i <= (params.kL - params.qL + q_seq_idx); + use_key = i <= (params.causal_offset + q_seq_idx); } if (use_key) { @@ -283,7 +285,7 @@ __global__ void kernel_sdpav_2pass_1( for (int i = kv_seq_idx; i < params.kL; i += blocks * BN) { bool use_key = true; if constexpr (do_causal) { - use_key = i <= (params.kL - params.qL + q_seq_idx); + use_key = i <= (params.causal_offset + q_seq_idx); } if (use_key) { @@ -472,6 +474,7 @@ void sdpa_vector_1pass_fallback( const float scale, array& o, bool do_causal, + bool causal_upper_left, const std::optional& sinks) { encoder.set_input_array(q); encoder.set_input_array(k); @@ -492,6 +495,8 @@ void sdpa_vector_1pass_fallback( /* int gqa_factor = */ q.shape(1) / k.shape(1), /* float scale = */ scale, + /* int causal_offset = */ (causal_upper_left ? 0 : k.shape(2) - q.shape(2)), + /* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)}, /* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)}, /* int64_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)}, @@ -531,6 +536,7 @@ void sdpa_vector_2pass_fallback( const float scale, array& o, bool do_causal, + bool causal_upper_left, const std::optional& sinks) { cu::AttnParams params{ /* int B = */ q.shape(0), @@ -543,6 +549,8 @@ void sdpa_vector_2pass_fallback( /* int gqa_factor = */ q.shape(1) / k.shape(1), /* float scale = */ scale, + /* int causal_offset = */ (causal_upper_left ? 0 : k.shape(2) - q.shape(2)), + /* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)}, /* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)}, /* int64_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)}, @@ -644,15 +652,16 @@ void sdpa_vector_fallback( const float scale, array& o, bool do_causal, + bool causal_upper_left, const std::optional& sinks) { int kL = k.shape(2); if (kL > 1024) { return sdpa_vector_2pass_fallback( - s, encoder, q, k, v, scale, o, do_causal, sinks); + s, encoder, q, k, v, scale, o, do_causal, causal_upper_left, sinks); } else { return sdpa_vector_1pass_fallback( - s, encoder, q, k, v, scale, o, do_causal, sinks); + s, encoder, q, k, v, scale, o, do_causal, causal_upper_left, sinks); } } @@ -689,6 +698,7 @@ void sdpa_vector( float scale, array& o, bool do_causal, + bool causal_upper_left, const std::optional& sinks_pre, Stream s) { auto& encoder = cu::get_command_encoder(s); @@ -781,7 +791,8 @@ void sdpa_vector( encoder.add_temporary(cp); } - sdpa_vector_fallback(s, encoder, q, k, v, scale, o, do_causal, sinks); + sdpa_vector_fallback( + s, encoder, q, k, v, scale, o, do_causal, causal_upper_left, sinks); } // Full attention mode should never reach here diff --git a/mlx/backend/metal/kernels/sdpa_vector.h b/mlx/backend/metal/kernels/sdpa_vector.h index 1eec72be31..f23fcb3ed6 100644 --- a/mlx/backend/metal/kernels/sdpa_vector.h +++ b/mlx/backend/metal/kernels/sdpa_vector.h @@ -36,6 +36,7 @@ template const device T* sinks [[buffer(16), function_constant(has_sinks)]], const constant int& num_q_heads [[buffer(17), function_constant(has_sinks)]], + const constant int& causal_offset [[buffer(18)]], uint3 tid [[threadgroup_position_in_grid]], uint3 tpg [[threadgroups_per_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], @@ -99,7 +100,7 @@ template for (int i = simd_gid; i < N; i += BN) { bool use_key = true; if (do_causal) { - use_key = i <= (N - int(tpg.y) + int(q_seq_idx)); + use_key = i <= (causal_offset + int(q_seq_idx)); } else if (bool_mask) { use_key = bmask[0]; } else if (float_mask) { @@ -199,6 +200,7 @@ template const constant int& mask_head_stride [[buffer(17), function_constant(has_mask)]], const device T* sinks [[buffer(18), function_constant(has_sinks)]], + const constant int& causal_offset [[buffer(19)]], uint3 tptg [[threads_per_threadgroup]], uint3 tidtg [[thread_position_in_threadgroup]], uint3 tid [[threadgroup_position_in_grid]], @@ -263,7 +265,7 @@ template for (int i = block_idx; i < N; i += blocks) { bool use_key = true; if (do_causal) { - use_key = i <= (N - q_seq_len + int(q_seq_idx)); + use_key = i <= (causal_offset + int(q_seq_idx)); } else if (bool_mask) { use_key = bmask[0]; } else if (float_mask) { diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index 37e554f183..97f64e4a69 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -24,6 +24,7 @@ void sdpa_full_self_attention_nax( const float scale, array& o, bool do_causal_, + bool causal_upper_left, const std::optional& mask, const std::optional& sinks) { using namespace mlx::steel; @@ -131,7 +132,7 @@ void sdpa_full_self_attention_nax( /* int qL_rem = */ (qL - NQ_aligned * bq), /* int kL_rem = */ (kL - NK_aligned * bk), - /* int qL_off = */ (kL - qL), + /* int qL_off = */ (causal_upper_left ? 0 : kL - qL), /* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)}, /* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)}, @@ -172,6 +173,7 @@ void sdpa_full_self_attention_metal( const float scale, array& o, bool do_causal_, + bool causal_upper_left, const std::optional& mask, const std::optional& sinks) { if (metal::is_nax_available() && q.shape(3) != 80 && @@ -185,6 +187,7 @@ void sdpa_full_self_attention_metal( /* const float scale = */ scale, /* array& o = */ o, /* bool do_causal_ = */ do_causal_, + /* bool causal_upper_left = */ causal_upper_left, /* const std::optional& mask = */ mask, /* const std::optional& sinks = */ sinks); } @@ -294,7 +297,7 @@ void sdpa_full_self_attention_metal( /* int qL_rem = */ (qL - NQ_aligned * bq), /* int kL_rem = */ (kL - NK_aligned * bk), - /* int qL_off = */ (kL - qL), + /* int qL_off = */ (causal_upper_left ? 0 : kL - qL), /* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)}, /* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)}, @@ -335,6 +338,7 @@ void sdpa_vector( array& out, float scale, bool do_causal, + bool causal_upper_left, const std::optional& mask, const std::optional& sinks) { // Set the kernel name @@ -410,6 +414,8 @@ void sdpa_vector( compute_encoder.set_input_array(*sinks, 16); compute_encoder.set_bytes(q.shape(1), 17); } + int32_t causal_offset = causal_upper_left ? 0 : N - q.shape(2); + compute_encoder.set_bytes(causal_offset, 18); // Launch compute_encoder.dispatch_threadgroups(grid_dims, group_dims); @@ -424,6 +430,7 @@ void sdpa_vector_2pass( array& out, float scale, bool do_causal, + bool causal_upper_left, const std::optional& mask, const std::optional& sinks) { // Set the kernel name @@ -554,6 +561,8 @@ void sdpa_vector_2pass( if (has_sinks) { compute_encoder.set_input_array(*sinks, 18); } + int32_t causal_offset = causal_upper_left ? 0 : N - q.shape(2); + compute_encoder.set_bytes(causal_offset, 19); // Launch compute_encoder.dispatch_threadgroups(grid_dims, group_dims); @@ -744,9 +753,11 @@ void ScaledDotProductAttention::eval_gpu( char devc = d.get_architecture().back(); if (((devc == 'd' || devc == 's') && k.shape(2) >= 1024) || (k.shape(1) < q.shape(1) && k.shape(2) >= 4096)) { - sdpa_vector_2pass(s, d, q, k, v, o, scale_, do_causal, mask, sinks); + sdpa_vector_2pass( + s, d, q, k, v, o, scale_, do_causal, causal_upper_left_, mask, sinks); } else { - sdpa_vector(s, d, q, k, v, o, scale_, do_causal, mask, sinks); + sdpa_vector( + s, d, q, k, v, o, scale_, do_causal, causal_upper_left_, mask, sinks); } } @@ -779,7 +790,7 @@ void ScaledDotProductAttention::eval_gpu( : std::nullopt; sdpa_full_self_attention_metal( - s, d, q, k, v, scale_, o, do_causal_, mask, sinks); + s, d, q, k, v, scale_, o, do_causal_, causal_upper_left_, mask, sinks); } d.add_temporaries(std::move(copies), s.index); diff --git a/mlx/fast.cpp b/mlx/fast.cpp index a668fe9abd..883c07a326 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -628,26 +628,31 @@ array scaled_dot_product_attention( } } // Check valid mask - if (mask_mode != "" && mask_mode != "causal" && mask_mode != "array") { + bool is_causal_mode = mask_mode == "causal" || + mask_mode == "causal_lower_right" || mask_mode == "causal_upper_left"; + if (mask_mode != "" && !is_causal_mode && mask_mode != "array") { std::ostringstream msg; - msg << "[scaled_dot_product_attention] Invalid mask_mode " << mask_mode - << ". mask_mode must be 'causal', 'array' or ''."; + msg << "[scaled_dot_product_attention] Invalid mask_mode '" << mask_mode + << "'. Must be 'causal', 'causal_lower_right', " + << "'causal_upper_left', 'array' or ''."; throw std::invalid_argument(msg.str()); } bool do_causal = false; + bool causal_upper_left = false; bool has_mask = false; bool has_arr_mask = false; bool has_bool_mask = false; - if (mask_mode == "causal") { + if (is_causal_mode) { has_mask = true; do_causal = true; + causal_upper_left = (mask_mode == "causal_upper_left"); if (mask_arr) { std::ostringstream msg; msg << "[scaled_dot_product_attention] Invalid mask_arr for mask_mode " - << "'casusal'. No array mask should be passed."; + << "'" << mask_mode << "'. No array mask should be passed."; throw std::invalid_argument(msg.str()); } } else if (mask_arr) { @@ -718,6 +723,7 @@ array scaled_dot_product_attention( n_q_heads, n_kv_heads, do_causal, + causal_upper_left, has_sinks, has_arr_mask, s](const std::vector& inputs) { @@ -737,7 +743,7 @@ array scaled_dot_product_attention( if (do_causal) { int kL = k.shape(-2); int qL = q.shape(-2); - int offset = kL - qL; + int offset = causal_upper_left ? 0 : kL - qL; auto q_idx = arange(offset, qL + offset, s); auto k_idx = arange(0, kL, s); q_idx = expand_dims(q_idx, 1, s); @@ -846,7 +852,13 @@ array scaled_dot_product_attention( } Shape out_shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)}; auto primitive = std::make_shared( - stream, fallback, scale, do_causal, has_sinks, output_logsumexp); + stream, + fallback, + scale, + do_causal, + causal_upper_left, + has_sinks, + output_logsumexp); if (output_logsumexp) { return array::make_arrays( {std::move(out_shape), Shape{q.shape(0), q.shape(1), q.shape(2), 1}}, @@ -888,7 +900,7 @@ std::vector ScaledDotProductAttention::vjp( dtypes.push_back(primals[i].dtype()); } auto primitive = std::make_shared( - s, fallback, scale_, do_causal_, has_sinks_); + s, fallback, scale_, do_causal_, causal_upper_left_, has_sinks_); std::vector inputs = primals; inputs.push_back(outputs[0]); inputs.push_back(outputs[1]); @@ -911,6 +923,7 @@ bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const { const ScaledDotProductAttention& a_other = static_cast(other); return scale_ == a_other.scale_ && do_causal_ == a_other.do_causal_ && + causal_upper_left_ == a_other.causal_upper_left_ && has_sinks_ == a_other.has_sinks_ && output_logsumexp_ == a_other.output_logsumexp_; } @@ -919,6 +932,7 @@ bool ScaledDotProductAttentionVJP::is_equivalent(const Primitive& other) const { const ScaledDotProductAttentionVJP& a_other = static_cast(other); return scale_ == a_other.scale_ && do_causal_ == a_other.do_causal_ && + causal_upper_left_ == a_other.causal_upper_left_ && has_sinks_ == a_other.has_sinks_; } diff --git a/mlx/fast_primitives.h b/mlx/fast_primitives.h index 4434830875..48e21e3d35 100644 --- a/mlx/fast_primitives.h +++ b/mlx/fast_primitives.h @@ -210,11 +210,13 @@ class ScaledDotProductAttention : public Custom { std::function(std::vector)> fallback, float scale, bool do_causal, + bool causal_upper_left, bool has_sinks, bool output_logsumexp) : Custom(stream, std::move(fallback)), scale_(scale), do_causal_(do_causal), + causal_upper_left_(causal_upper_left), has_sinks_(has_sinks), output_logsumexp_(output_logsumexp) {} @@ -250,12 +252,18 @@ class ScaledDotProductAttention : public Custom { DEFINE_INPUT_OUTPUT_SHAPE() auto state() const { return std::make_tuple( - nullptr, scale_, do_causal_, has_sinks_, output_logsumexp_); + nullptr, + scale_, + do_causal_, + causal_upper_left_, + has_sinks_, + output_logsumexp_); } private: float scale_; bool do_causal_; + bool causal_upper_left_; bool has_sinks_; bool output_logsumexp_; }; @@ -267,10 +275,12 @@ class ScaledDotProductAttentionVJP : public Custom { std::function(std::vector)> fallback, float scale, bool do_causal, + bool causal_upper_left, bool has_sinks) : Custom(stream, std::move(fallback)), scale_(scale), do_causal_(do_causal), + causal_upper_left_(causal_upper_left), has_sinks_(has_sinks) {} static bool use_fallback(const array& q, Stream s); @@ -286,12 +296,14 @@ class ScaledDotProductAttentionVJP : public Custom { DEFINE_NAME(ScaledDotProductAttentionVJP); bool is_equivalent(const Primitive& other) const override; auto state() const { - return std::make_tuple(nullptr, scale_, do_causal_, has_sinks_); + return std::make_tuple( + nullptr, scale_, do_causal_, causal_upper_left_, has_sinks_); } private: float scale_; bool do_causal_; + bool causal_upper_left_; bool has_sinks_; }; diff --git a/python/src/fast.cpp b/python/src/fast.cpp index 1a43d89d9b..9d6216639d 100644 --- a/python/src/fast.cpp +++ b/python/src/fast.cpp @@ -206,10 +206,13 @@ void init_fast(nb::module_& parent_module) { if (has_mask) { if (has_str_mask) { auto mask_str = std::get(mask); - if (mask_str != "causal") { + if (mask_str != "causal" && mask_str != "causal_lower_right" && + mask_str != "causal_upper_left") { std::ostringstream msg; msg << "[scaled_dot_product_attention] invalid mask option '" - << mask_str << "'. Must be 'causal', or an array."; + << mask_str + << "'. Must be 'causal', 'causal_lower_right', " + << "'causal_upper_left', or an array."; throw std::invalid_argument(msg.str()); } return mx::fast::scaled_dot_product_attention( @@ -267,13 +270,20 @@ void init_fast(nb::module_& parent_module) { scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape(-1)``). mask (str or array, optional): The mask to apply to the query-key scores. The mask can be an array or a string indicating - the mask type. The only supported string type is ``"causal"``. If - the mask is an array it can be a boolean or additive mask. The mask - can have at most 4 dimensions and must be broadcast-compatible with - the shape ``[B, N, T_q, T_kv]``. If an additive mask is given its - type must promote to the promoted type of ``q``, ``k``, and ``v``. - The ``"causal"`` mask uses lower-right alignment where the - last query aligns with the last key. + the mask type. Supported string types are: + + * ``"causal"`` or ``"causal_lower_right"``: Lower-right + aligned causal mask. The last query attends to the last key. + This is the standard mask for autoregressive decoding. + * ``"causal_upper_left"``: Upper-left aligned causal mask. + Query ``i`` attends to keys ``0..i``. This matches PyTorch's + default ``is_causal=True`` behavior. + + If the mask is an array it can be a boolean or additive mask. + The mask can have at most 4 dimensions and must be + broadcast-compatible with the shape ``[B, N, T_q, T_kv]``. If + an additive mask is given its type must promote to the promoted + type of ``q``, ``k``, and ``v``. sinks (array, optional): An optional array of attention sinks. Default: ``None``. diff --git a/python/tests/test_fast_sdpa.py b/python/tests/test_fast_sdpa.py index 7606373ce4..bceed39b45 100644 --- a/python/tests/test_fast_sdpa.py +++ b/python/tests/test_fast_sdpa.py @@ -26,7 +26,6 @@ def mlx_ref_attn(q, k, v, scale=1.0, mask=None, sinks=None): scores = q @ mx.swapaxes(k, -1, -2) is_causal = mask == "causal" if mask is not None: - if is_causal: offset = kL - L q_indices = mx.arange(L) + offset @@ -642,6 +641,50 @@ def test_sdpa_sliced(self): tolerance = {"rtol": 1e-2, "atol": 1e-2} self.assertTrue(mx.allclose(ref, out, **tolerance)) + def test_causal_mask_alignment(self): + B, H, D = 1, 2, 64 + qL, kL = 4, 8 + scale = 1.0 / math.sqrt(D) + + mx.random.seed(0) + q = mx.random.normal((B, H, qL, D)) + k = mx.random.normal((B, H, kL, D)) + v = mx.random.normal((B, H, kL, D)) + + # "causal" and "causal_lower_right" should be identical + out_causal = mx.fast.scaled_dot_product_attention( + q, k, v, scale=scale, mask="causal" + ) + out_lr = mx.fast.scaled_dot_product_attention( + q, k, v, scale=scale, mask="causal_lower_right" + ) + self.assertTrue(mx.allclose(out_causal, out_lr, atol=1e-6, rtol=1e-5)) + + # "causal_upper_left" should match a manual upper-left mask + q_idx = mx.arange(qL) + k_idx = mx.arange(kL) + ul_mask = q_idx[:, None] >= k_idx[None] + out_ul = mx.fast.scaled_dot_product_attention( + q, k, v, scale=scale, mask="causal_upper_left" + ) + out_manual = mx.fast.scaled_dot_product_attention( + q, k, v, scale=scale, mask=ul_mask + ) + self.assertTrue(mx.allclose(out_ul, out_manual, atol=1e-5, rtol=1e-4)) + + # upper-left != lower-right when qL != kL + self.assertFalse(mx.allclose(out_ul, out_lr, atol=1e-2, rtol=1e-2)) + + # when qL == kL, both should be identical + q_eq = mx.random.normal((B, H, kL, D)) + out_lr_eq = mx.fast.scaled_dot_product_attention( + q_eq, k, v, scale=scale, mask="causal_lower_right" + ) + out_ul_eq = mx.fast.scaled_dot_product_attention( + q_eq, k, v, scale=scale, mask="causal_upper_left" + ) + self.assertTrue(mx.allclose(out_lr_eq, out_ul_eq, atol=1e-6, rtol=1e-5)) + if __name__ == "__main__": mlx_tests.MLXTestRunner(failfast=True) From 9e85fd0635cebb2fc60ae02f080f5d63c10ddbfe Mon Sep 17 00:00:00 2001 From: mm65x Date: Thu, 26 Mar 2026 16:39:09 +0000 Subject: [PATCH 2/2] fix formatting in causal mask update --- mlx/backend/cuda/scaled_dot_product_attention.cpp | 3 +-- mlx/backend/cuda/scaled_dot_product_attention.cu | 6 ++++-- python/src/fast.cpp | 3 +-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/mlx/backend/cuda/scaled_dot_product_attention.cpp b/mlx/backend/cuda/scaled_dot_product_attention.cpp index bfd246f152..cc8f848330 100644 --- a/mlx/backend/cuda/scaled_dot_product_attention.cpp +++ b/mlx/backend/cuda/scaled_dot_product_attention.cpp @@ -638,8 +638,7 @@ void ScaledDotProductAttention::eval_gpu( output_logsumexp_, s); } else { - sdpa_vector( - q, k, v, scale_, out, do_causal_, causal_upper_left_, sinks, s); + sdpa_vector(q, k, v, scale_, out, do_causal_, causal_upper_left_, sinks, s); } } diff --git a/mlx/backend/cuda/scaled_dot_product_attention.cu b/mlx/backend/cuda/scaled_dot_product_attention.cu index 636641188c..c147b99e46 100644 --- a/mlx/backend/cuda/scaled_dot_product_attention.cu +++ b/mlx/backend/cuda/scaled_dot_product_attention.cu @@ -495,7 +495,8 @@ void sdpa_vector_1pass_fallback( /* int gqa_factor = */ q.shape(1) / k.shape(1), /* float scale = */ scale, - /* int causal_offset = */ (causal_upper_left ? 0 : k.shape(2) - q.shape(2)), + /* int causal_offset = */ + (causal_upper_left ? 0 : k.shape(2) - q.shape(2)), /* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)}, /* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)}, @@ -549,7 +550,8 @@ void sdpa_vector_2pass_fallback( /* int gqa_factor = */ q.shape(1) / k.shape(1), /* float scale = */ scale, - /* int causal_offset = */ (causal_upper_left ? 0 : k.shape(2) - q.shape(2)), + /* int causal_offset = */ + (causal_upper_left ? 0 : k.shape(2) - q.shape(2)), /* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)}, /* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)}, diff --git a/python/src/fast.cpp b/python/src/fast.cpp index 9d6216639d..57b94700f3 100644 --- a/python/src/fast.cpp +++ b/python/src/fast.cpp @@ -210,8 +210,7 @@ void init_fast(nb::module_& parent_module) { mask_str != "causal_upper_left") { std::ostringstream msg; msg << "[scaled_dot_product_attention] invalid mask option '" - << mask_str - << "'. Must be 'causal', 'causal_lower_right', " + << mask_str << "'. Must be 'causal', 'causal_lower_right', " << "'causal_upper_left', or an array."; throw std::invalid_argument(msg.str()); }