16 changes: 16 additions & 0 deletions mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h
@@ -14,6 +14,7 @@ constant bool align_K [[function_constant(201)]];
constant bool has_mask [[function_constant(300)]];
constant bool do_causal [[function_constant(301)]];
constant bool has_sinks [[function_constant(302)]];
constant bool output_logsumexp [[function_constant(304)]];

struct MaxOp {
template <typename T>
@@ -76,6 +77,7 @@ template <
const constant AttnMaskParams* mask_params [[buffer(5), function_constant(has_mask)]],
const device MaskType* mask [[buffer(6), function_constant(has_mask)]],
const device T* sinks [[buffer(7), function_constant(has_sinks)]],
device float* lse_out [[buffer(8), function_constant(output_logsumexp)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
@@ -473,4 +475,18 @@ template <
} else {
Otile.template store<T, 1, 1>(O, params->O_strides[2]);
}

// Write per-row logsumexp if requested
if (output_logsumexp) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kRowsPT; ++i) {
int row = int(tid.x) * BQ + tm + sm + (i * kFragSize);
if (row < params->qL) {
int64_t idx = int64_t(tid.z) * params->H * params->qL
+ int64_t(tid.y) * params->qL + row;
lse_out[idx] = float(max_score[i]) * M_LN2_F
+ metal::precise::log(float(sum_score[i]));
}
}
}
}
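
For reference, the value written to lse_out above is the per-row logsumexp converted back to natural log: the kernel tracks the running row maximum and the exponential sum in base 2, hence the M_LN2_F factor. A minimal host-side sketch of the same quantity, assuming `scores` holds one query row's logits already scaled into the base-2 domain the way max_score and sum_score are accumulated:

// Host-side reference for the quantity the kernel writes to lse_out (a sketch,
// not part of the kernel). `scores` is one query row's unnormalized logits,
// pre-scaled by log2(e) to mirror how max_score / sum_score are accumulated.
#include <algorithm>
#include <cmath>
#include <vector>

float row_logsumexp_base2(const std::vector<float>& scores) {
  float m = -INFINITY;
  for (float s : scores) m = std::max(m, s);       // running row maximum
  float sum = 0.0f;
  for (float s : scores) sum += std::exp2(s - m);  // base-2 exponential sum
  // Convert to natural log, matching
  // max_score[i] * M_LN2_F + metal::precise::log(sum_score[i]) above.
  return m * std::log(2.0f) + std::log(sum);
}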
35 changes: 26 additions & 9 deletions mlx/backend/metal/scaled_dot_product_attention.cpp
@@ -173,8 +173,11 @@ void sdpa_full_self_attention_metal(
array& o,
bool do_causal_,
const std::optional<array>& mask,
const std::optional<array>& sinks) {
if (metal::is_nax_available() && q.shape(3) != 80 &&
const std::optional<array>& sinks,
bool output_logsumexp_flag,
array* lse) {
// NAX path does not support logsumexp output
if (!output_logsumexp_flag && metal::is_nax_available() && q.shape(3) != 80 &&
(env::enable_tf32() || q.dtype() != float32)) {
return sdpa_full_self_attention_nax(
/* const Stream& s = */ s,
@@ -217,7 +220,8 @@ void sdpa_full_self_attention_metal(
{&align_K, MTL::DataType::DataTypeBool, 201},
{&has_mask, MTL::DataType::DataTypeBool, 300},
{&do_causal, MTL::DataType::DataTypeBool, 301},
{&has_sinks, MTL::DataType::DataTypeBool, 302}};
{&has_sinks, MTL::DataType::DataTypeBool, 302},
{&output_logsumexp_flag, MTL::DataType::DataTypeBool, 304}};

std::string base_name;
concatenate(
@@ -250,7 +254,9 @@
"_do_causal_",
(do_causal ? 't' : 'n'),
"_has_sinks_",
(has_sinks ? 't' : 'n'));
(has_sinks ? 't' : 'n'),
"_lse_",
(output_logsumexp_flag ? 't' : 'n'));

auto& compute_encoder = d.get_command_encoder(s.index);

@@ -319,6 +325,9 @@
if (has_sinks) {
compute_encoder.set_input_array(*sinks, 7);
}
if (output_logsumexp_flag && lse != nullptr) {
compute_encoder.set_output_array(*lse, 8);
}

MTL::Size grid_dims = MTL::Size(NQ, H, B);
MTL::Size group_dims = MTL::Size(32, wm, wn);
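
The array bound at buffer index 8 is addressed in the kernel as tid.z * H * qL + tid.y * qL + row, so the logsumexp output is expected to be a contiguous [B, H, qL] buffer. A one-line sketch of the corresponding host-side flat offset, under that layout assumption:

// Flat offset into the logsumexp output for a contiguous [B, H, qL] layout
// (an assumption inferred from the kernel's idx computation, not an MLX API).
inline int64_t lse_offset(int64_t b, int64_t h, int64_t q, int64_t H, int64_t qL) {
  return (b * H + h) * qL + q;
}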
@@ -600,9 +609,6 @@ bool ScaledDotProductAttention::use_fallback(
// forward and backward.
return true;
}
if (output_logsumexp) {
return true;
}
if (s.device == Device::cpu) {
return true;
}
@@ -628,7 +634,9 @@
const bool supports_sdpa_full = query_sequence_length > 8 &&
sdpa_full_supported_mask && sdpa_full_supported_head_dim;

const bool supports_sdpa_vector = (query_sequence_length <= 8) &&
// Vector kernels do not support logsumexp output
const bool supports_sdpa_vector = !output_logsumexp &&
(query_sequence_length <= 8) &&
(query_sequence_length <= key_sequence_length) &&
sdpa_vector_supported_head_dim &&
(query_sequence_length * gqa_factor) <= 32;
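
Net effect of these checks together with the NAX guard above: requesting a logsumexp output no longer forces the generic fallback, but it does route around the vector and NAX kernels, leaving only the full steel kernel to service it. A condensed sketch of that dispatch, with illustrative names rather than the actual MLX functions:

// Condensed sketch of the dispatch after this change (names are illustrative).
enum class SdpaPath { Vector, NaxFull, SteelFull, Fallback };

SdpaPath choose_path(bool wants_lse, bool vector_ok, bool nax_ok, bool full_ok) {
  if (!wants_lse && vector_ok) return SdpaPath::Vector;  // short query sequences
  if (full_ok) {
    if (!wants_lse && nax_ok) return SdpaPath::NaxFull;  // NAX cannot emit lse
    return SdpaPath::SteelFull;                          // steel kernel emits lse
  }
  return SdpaPath::Fallback;                             // unfused reference path
}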
@@ -778,8 +786,17 @@ void ScaledDotProductAttention::eval_gpu(
? std::optional<array>{copy_unless(is_matrix_contiguous, inputs[3])}
: std::nullopt;

// Set up logsumexp output if requested
bool lse_flag = output_logsumexp_ && outputs.size() > 1;
array* lse_ptr = nullptr;
if (lse_flag) {
auto& lse = outputs[1];
lse.set_data(allocator::malloc(lse.nbytes()));
lse_ptr = &lse;
}

sdpa_full_self_attention_metal(
s, d, q, k, v, scale_, o, do_causal_, mask, sinks);
s, d, q, k, v, scale_, o, do_causal_, mask, sinks, lse_flag, lse_ptr);
}

d.add_temporaries(std::move(copies), s.index);
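
As a usage note, one standard reason to expose the per-row logsumexp is that it allows partial attention results computed over disjoint key blocks to be combined exactly; a minimal sketch of that merge for a single query row, with hypothetical names and plain std::vector inputs:

// Exact merge of two partial attention results for one query row (a sketch,
// not part of this PR). o1/o2 are partial outputs over disjoint key blocks,
// lse1/lse2 their per-row logsumexps; the weights renormalize the softmaxes.
#include <algorithm>
#include <cmath>
#include <vector>

std::vector<float> merge_partial_attention(
    const std::vector<float>& o1, float lse1,
    const std::vector<float>& o2, float lse2) {
  float m = std::max(lse1, lse2);
  float w1 = std::exp(lse1 - m);
  float w2 = std::exp(lse2 - m);
  std::vector<float> out(o1.size());
  for (size_t i = 0; i < o1.size(); ++i) {
    out[i] = (w1 * o1[i] + w2 * o2[i]) / (w1 + w2);
  }
  return out;
}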