Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 34 additions & 6 deletions mlx/backend/cuda/scaled_dot_product_attention.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ struct SDPACacheKey {
std::array<int64_t, QKV_NDIM> k_strides;
std::array<int64_t, QKV_NDIM> v_strides;
bool do_causal;
bool causal_upper_left;
std::array<int, QKV_NDIM> mask_shape;
std::array<int64_t, QKV_NDIM> mask_strides;
bool has_sinks;
Expand All @@ -145,6 +146,7 @@ inline BytesKey<SDPACacheKey> build_sdpa_cache_key(
const array& k,
const array& v,
bool do_causal,
bool causal_upper_left,
const std::optional<array>& mask_arr,
const std::optional<array>& sinks,
bool decoding = false,
Expand All @@ -159,6 +161,7 @@ inline BytesKey<SDPACacheKey> build_sdpa_cache_key(
cache_key.pod.k_strides = vector_key<QKV_NDIM>(k.strides());
cache_key.pod.v_strides = vector_key<QKV_NDIM>(v.strides());
cache_key.pod.do_causal = do_causal;
cache_key.pod.causal_upper_left = causal_upper_left;
cache_key.pod.has_sinks = sinks.has_value();
cache_key.pod.output_logsumexp = output_logsumexp;
if (mask_arr) {
Expand Down Expand Up @@ -211,6 +214,7 @@ DnnGraph build_sdpa_graph(
const array& k,
const array& v,
bool do_causal,
bool causal_upper_left,
const std::optional<array>& mask_arr,
const std::optional<array>& sinks,
const std::optional<array>& seq_len_q,
Expand All @@ -229,7 +233,11 @@ DnnGraph build_sdpa_graph(
.set_attn_scale(graph.scalar("Scale", SCALE, float32))
.set_generate_stats(output_logsumexp);
if (do_causal) {
options.set_causal_mask_bottom_right(do_causal);
if (causal_upper_left) {
options.set_causal_mask(do_causal);
} else {
options.set_causal_mask_bottom_right(do_causal);
}
}
if (mask_arr) {
options.set_bias(graph.tensor("BIAS", BIAS, *mask_arr));
Expand Down Expand Up @@ -262,6 +270,7 @@ DnnGraph build_sdpa_backward_graph(
const array& k,
const array& v,
bool do_causal,
bool causal_upper_left,
const std::optional<array>& mask_arr,
const std::optional<array>& sinks,
const array& o,
Expand All @@ -283,7 +292,11 @@ DnnGraph build_sdpa_backward_graph(
.set_name("sdpa_backward_cudnn")
.set_attn_scale(graph.scalar("Scale", SCALE, float32));
if (do_causal) {
options.set_causal_mask_bottom_right(do_causal);
if (causal_upper_left) {
options.set_causal_mask(do_causal);
} else {
options.set_causal_mask_bottom_right(do_causal);
}
}
if (mask_arr) {
options.set_bias(graph.tensor("BIAS", BIAS, *mask_arr));
Expand Down Expand Up @@ -352,6 +365,7 @@ void sdpa_cudnn(
array& o,
std::optional<array>& stats,
bool do_causal,
bool causal_upper_left,
const std::optional<array>& mask_arr,
const std::optional<array>& sinks,
bool output_logsumexp,
Expand Down Expand Up @@ -400,7 +414,16 @@ void sdpa_cudnn(

// Search cache.
auto cache_key = build_sdpa_cache_key(
encoder, q, k, v, do_causal, mask_arr, sinks, decoding, output_logsumexp);
encoder,
q,
k,
v,
do_causal,
causal_upper_left,
mask_arr,
sinks,
decoding,
output_logsumexp);
auto it = sdpa_cache().find(cache_key);
if (it == sdpa_cache().end()) {
auto graph = build_sdpa_graph(
Expand All @@ -409,6 +432,7 @@ void sdpa_cudnn(
k,
v,
do_causal,
causal_upper_left,
mask_arr,
sinks,
seq_len_q,
Expand Down Expand Up @@ -451,6 +475,7 @@ void sdpa_backward_cudnn(
const array& o,
const array& stats,
bool do_causal,
bool causal_upper_left,
const std::optional<array>& mask_arr,
const std::optional<array>& sinks,
const array& d_o,
Expand Down Expand Up @@ -482,8 +507,8 @@ void sdpa_backward_cudnn(
}

// Search cache.
auto cache_key =
build_sdpa_cache_key(encoder, q, k, v, do_causal, mask_arr, sinks);
auto cache_key = build_sdpa_cache_key(
encoder, q, k, v, do_causal, causal_upper_left, mask_arr, sinks);
auto it = sdpa_backward_cache().find(cache_key);
if (it == sdpa_backward_cache().end()) {
auto graph = build_sdpa_backward_graph(
Expand All @@ -492,6 +517,7 @@ void sdpa_backward_cudnn(
k,
v,
do_causal,
causal_upper_left,
mask_arr,
sinks,
o,
Expand Down Expand Up @@ -539,6 +565,7 @@ void sdpa_vector(
float scale,
array& o,
bool do_causal,
bool causal_upper_left,
const std::optional<array>& sinks,
Stream s);

Expand Down Expand Up @@ -605,12 +632,13 @@ void ScaledDotProductAttention::eval_gpu(
out,
stats,
do_causal_,
causal_upper_left_,
mask_arr,
sinks,
output_logsumexp_,
s);
} else {
sdpa_vector(q, k, v, scale_, out, do_causal_, sinks, s);
sdpa_vector(q, k, v, scale_, out, do_causal_, causal_upper_left_, sinks, s);
}
}

Expand Down
23 changes: 18 additions & 5 deletions mlx/backend/cuda/scaled_dot_product_attention.cu
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ struct AttnParams {
int gqa_factor;
float scale;

int causal_offset;

int64_t Q_strides[3];
int64_t K_strides[3];
int64_t V_strides[3];
Expand Down Expand Up @@ -118,7 +120,7 @@ __global__ void kernel_sdpav_1pass(
for (int i = kv_seq_idx; i < params.kL; i += BN) {
bool use_key = true;
if constexpr (do_causal) {
use_key = i <= (params.kL - params.qL + q_seq_idx);
use_key = i <= (params.causal_offset + q_seq_idx);
}

if (use_key) {
Expand Down Expand Up @@ -283,7 +285,7 @@ __global__ void kernel_sdpav_2pass_1(
for (int i = kv_seq_idx; i < params.kL; i += blocks * BN) {
bool use_key = true;
if constexpr (do_causal) {
use_key = i <= (params.kL - params.qL + q_seq_idx);
use_key = i <= (params.causal_offset + q_seq_idx);
}

if (use_key) {
Expand Down Expand Up @@ -472,6 +474,7 @@ void sdpa_vector_1pass_fallback(
const float scale,
array& o,
bool do_causal,
bool causal_upper_left,
const std::optional<array>& sinks) {
encoder.set_input_array(q);
encoder.set_input_array(k);
Expand All @@ -492,6 +495,9 @@ void sdpa_vector_1pass_fallback(
/* int gqa_factor = */ q.shape(1) / k.shape(1),
/* float scale = */ scale,

/* int causal_offset = */
(causal_upper_left ? 0 : k.shape(2) - q.shape(2)),

/* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)},
/* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)},
/* int64_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)},
Expand Down Expand Up @@ -531,6 +537,7 @@ void sdpa_vector_2pass_fallback(
const float scale,
array& o,
bool do_causal,
bool causal_upper_left,
const std::optional<array>& sinks) {
cu::AttnParams params{
/* int B = */ q.shape(0),
Expand All @@ -543,6 +550,9 @@ void sdpa_vector_2pass_fallback(
/* int gqa_factor = */ q.shape(1) / k.shape(1),
/* float scale = */ scale,

/* int causal_offset = */
(causal_upper_left ? 0 : k.shape(2) - q.shape(2)),

/* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)},
/* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)},
/* int64_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)},
Expand Down Expand Up @@ -644,15 +654,16 @@ void sdpa_vector_fallback(
const float scale,
array& o,
bool do_causal,
bool causal_upper_left,
const std::optional<array>& sinks) {
int kL = k.shape(2);

if (kL > 1024) {
return sdpa_vector_2pass_fallback(
s, encoder, q, k, v, scale, o, do_causal, sinks);
s, encoder, q, k, v, scale, o, do_causal, causal_upper_left, sinks);
} else {
return sdpa_vector_1pass_fallback(
s, encoder, q, k, v, scale, o, do_causal, sinks);
s, encoder, q, k, v, scale, o, do_causal, causal_upper_left, sinks);
}
}

Expand Down Expand Up @@ -689,6 +700,7 @@ void sdpa_vector(
float scale,
array& o,
bool do_causal,
bool causal_upper_left,
const std::optional<array>& sinks_pre,
Stream s) {
auto& encoder = cu::get_command_encoder(s);
Expand Down Expand Up @@ -781,7 +793,8 @@ void sdpa_vector(
encoder.add_temporary(cp);
}

sdpa_vector_fallback(s, encoder, q, k, v, scale, o, do_causal, sinks);
sdpa_vector_fallback(
s, encoder, q, k, v, scale, o, do_causal, causal_upper_left, sinks);
}

// Full attention mode should never reach here
Expand Down
6 changes: 4 additions & 2 deletions mlx/backend/metal/kernels/sdpa_vector.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ template <typename T, int D, int V = D>
const device T* sinks [[buffer(16), function_constant(has_sinks)]],
const constant int& num_q_heads
[[buffer(17), function_constant(has_sinks)]],
const constant int& causal_offset [[buffer(18)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 tpg [[threadgroups_per_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
Expand Down Expand Up @@ -99,7 +100,7 @@ template <typename T, int D, int V = D>
for (int i = simd_gid; i < N; i += BN) {
bool use_key = true;
if (do_causal) {
use_key = i <= (N - int(tpg.y) + int(q_seq_idx));
use_key = i <= (causal_offset + int(q_seq_idx));
} else if (bool_mask) {
use_key = bmask[0];
} else if (float_mask) {
Expand Down Expand Up @@ -199,6 +200,7 @@ template <typename T, int D, int V = D>
const constant int& mask_head_stride
[[buffer(17), function_constant(has_mask)]],
const device T* sinks [[buffer(18), function_constant(has_sinks)]],
const constant int& causal_offset [[buffer(19)]],
uint3 tptg [[threads_per_threadgroup]],
uint3 tidtg [[thread_position_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
Expand Down Expand Up @@ -263,7 +265,7 @@ template <typename T, int D, int V = D>
for (int i = block_idx; i < N; i += blocks) {
bool use_key = true;
if (do_causal) {
use_key = i <= (N - q_seq_len + int(q_seq_idx));
use_key = i <= (causal_offset + int(q_seq_idx));
} else if (bool_mask) {
use_key = bmask[0];
} else if (float_mask) {
Expand Down
21 changes: 16 additions & 5 deletions mlx/backend/metal/scaled_dot_product_attention.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ void sdpa_full_self_attention_nax(
const float scale,
array& o,
bool do_causal_,
bool causal_upper_left,
const std::optional<array>& mask,
const std::optional<array>& sinks) {
using namespace mlx::steel;
Expand Down Expand Up @@ -131,7 +132,7 @@ void sdpa_full_self_attention_nax(

/* int qL_rem = */ (qL - NQ_aligned * bq),
/* int kL_rem = */ (kL - NK_aligned * bk),
/* int qL_off = */ (kL - qL),
/* int qL_off = */ (causal_upper_left ? 0 : kL - qL),

/* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)},
/* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)},
Expand Down Expand Up @@ -172,6 +173,7 @@ void sdpa_full_self_attention_metal(
const float scale,
array& o,
bool do_causal_,
bool causal_upper_left,
const std::optional<array>& mask,
const std::optional<array>& sinks) {
if (metal::is_nax_available() && q.shape(3) != 80 &&
Expand All @@ -185,6 +187,7 @@ void sdpa_full_self_attention_metal(
/* const float scale = */ scale,
/* array& o = */ o,
/* bool do_causal_ = */ do_causal_,
/* bool causal_upper_left = */ causal_upper_left,
/* const std::optional<array>& mask = */ mask,
/* const std::optional<array>& sinks = */ sinks);
}
Expand Down Expand Up @@ -294,7 +297,7 @@ void sdpa_full_self_attention_metal(

/* int qL_rem = */ (qL - NQ_aligned * bq),
/* int kL_rem = */ (kL - NK_aligned * bk),
/* int qL_off = */ (kL - qL),
/* int qL_off = */ (causal_upper_left ? 0 : kL - qL),

/* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)},
/* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)},
Expand Down Expand Up @@ -335,6 +338,7 @@ void sdpa_vector(
array& out,
float scale,
bool do_causal,
bool causal_upper_left,
const std::optional<array>& mask,
const std::optional<array>& sinks) {
// Set the kernel name
Expand Down Expand Up @@ -410,6 +414,8 @@ void sdpa_vector(
compute_encoder.set_input_array(*sinks, 16);
compute_encoder.set_bytes(q.shape(1), 17);
}
int32_t causal_offset = causal_upper_left ? 0 : N - q.shape(2);
compute_encoder.set_bytes(causal_offset, 18);

// Launch
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
Expand All @@ -424,6 +430,7 @@ void sdpa_vector_2pass(
array& out,
float scale,
bool do_causal,
bool causal_upper_left,
const std::optional<array>& mask,
const std::optional<array>& sinks) {
// Set the kernel name
Expand Down Expand Up @@ -554,6 +561,8 @@ void sdpa_vector_2pass(
if (has_sinks) {
compute_encoder.set_input_array(*sinks, 18);
}
int32_t causal_offset = causal_upper_left ? 0 : N - q.shape(2);
compute_encoder.set_bytes(causal_offset, 19);

// Launch
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
Expand Down Expand Up @@ -744,9 +753,11 @@ void ScaledDotProductAttention::eval_gpu(
char devc = d.get_architecture().back();
if (((devc == 'd' || devc == 's') && k.shape(2) >= 1024) ||
(k.shape(1) < q.shape(1) && k.shape(2) >= 4096)) {
sdpa_vector_2pass(s, d, q, k, v, o, scale_, do_causal, mask, sinks);
sdpa_vector_2pass(
s, d, q, k, v, o, scale_, do_causal, causal_upper_left_, mask, sinks);
} else {
sdpa_vector(s, d, q, k, v, o, scale_, do_causal, mask, sinks);
sdpa_vector(
s, d, q, k, v, o, scale_, do_causal, causal_upper_left_, mask, sinks);
}
}

Expand Down Expand Up @@ -779,7 +790,7 @@ void ScaledDotProductAttention::eval_gpu(
: std::nullopt;

sdpa_full_self_attention_metal(
s, d, q, k, v, scale_, o, do_causal_, mask, sinks);
s, d, q, k, v, scale_, o, do_causal_, causal_upper_left_, mask, sinks);
}

d.add_temporaries(std::move(copies), s.index);
Expand Down
Loading
Loading