Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions mlx/backend/metal/kernels/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ build_kernel(random)
build_kernel(rms_norm)
build_kernel(rope)
build_kernel(scaled_dot_product_attention sdpa_vector.h)
build_kernel(steel/attn/kernels/sdpa_chunked_reduce
steel/attn/kernels/sdpa_chunked_reduce.h)
if(MLX_METAL_VERSION GREATER_EQUAL 320)
build_kernel(fence)
endif()
Expand Down
88 changes: 88 additions & 0 deletions mlx/backend/metal/kernels/steel/attn/kernels/sdpa_chunked_reduce.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
// Copyright © 2025 Apple Inc.

#pragma once

#include <metal_stdlib>

using namespace metal;

///////////////////////////////////////////////////////////////////////////////
// Reduction kernel for chunked SDPA
//
// Combines N per-chunk outputs using logsumexp-weighted averaging:
//
// max_lse = max(lse_1, ..., lse_N)
// w_c = exp(lse_c - max_lse)
// out = sum(w_c * out_c) / sum(w_c)
//
// Grid: (D, qL, B*H) — one thread per output element, dispatch_threads
///////////////////////////////////////////////////////////////////////////////

template <typename T>
[[kernel]] void sdpa_chunked_reduce(
    const device T* chunk_outs [[buffer(0)]],
    const device float* chunk_lses [[buffer(1)]],
    device T* output [[buffer(2)]],
    const constant int& n_chunks [[buffer(3)]],
    const constant int& D [[buffer(4)]],
    const constant int& qL [[buffer(5)]],
    const constant int& H [[buffer(6)]],
    const constant int64_t* O_strides [[buffer(7)]],
    const constant int& BHqL [[buffer(8)]],
    uint3 tid [[thread_position_in_grid]]) {
  // Thread mapping (dispatch_threads, one thread per output element):
  //   tid.x -> head-dim element, tid.y -> query position, tid.z -> batch*head.
  const int dim_idx = tid.x;
  const int q_idx = tid.y;
  const int bh_idx = tid.z;

  if (dim_idx >= D || q_idx >= qL) {
    return;
  }

  // Recover (batch, head) from the fused batch*head coordinate; only needed
  // for the strided write to the final output tensor.
  const int head = bh_idx % H;
  const int batch = bh_idx / H;

  // Row offset of this (batch, head, query) inside one chunk's BHqL plane,
  // shared by both chunk_outs and chunk_lses.
  const int64_t row = int64_t(bh_idx) * int64_t(qL) + int64_t(q_idx);

  // Sweep 1: running maximum of the per-chunk logsumexp values, so the
  // exponentials below stay in a numerically safe range.
  float lse_max = -INFINITY;
  for (int c = 0; c < n_chunks; ++c) {
    lse_max = metal::max(lse_max, chunk_lses[int64_t(c) * int64_t(BHqL) + row]);
  }

  // Sweep 2: logsumexp-weighted accumulation of the chunk outputs.
  // A chunk whose keys are all causally masked yields a NaN output (0/0 in
  // softmax) together with lse == -inf. Its weight exp(-inf - lse_max) is 0,
  // but IEEE 754 gives 0 * NaN == NaN, so zero-weight (and NaN-weight) chunks
  // are skipped instead of multiplied through.
  float weighted_sum = 0.0f;
  float weight_total = 0.0f;
  for (int c = 0; c < n_chunks; ++c) {
    const int64_t plane = int64_t(c) * int64_t(BHqL);
    const float weight = metal::exp(chunk_lses[plane + row] - lse_max);
    if (!(weight > 0.0f)) {
      continue;
    }
    weight_total += weight;
    const int64_t src =
        (plane + row) * int64_t(D) + int64_t(dim_idx);
    weighted_sum += weight * float(chunk_outs[src]);
  }

  // If every chunk was fully masked, fall back to 0 rather than 0/0.
  const float value =
      (weight_total > 0.0f) ? (weighted_sum / weight_total) : 0.0f;

  // O_strides = [batch_stride, head_stride, seq_stride]; the head-dim axis
  // is contiguous (stride 1, innermost).
  const int64_t dst = int64_t(batch) * O_strides[0] +
      int64_t(head) * O_strides[1] +
      int64_t(q_idx) * O_strides[2] +
      int64_t(dim_idx);
  output[dst] = T(value);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// Copyright © 2025 Apple Inc.

// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/attn/kernels/sdpa_chunked_reduce.h"

// Emits an explicit specialization of sdpa_chunked_reduce<type>, exported to
// the Metal runtime under the host-visible name "sdpa_chunked_reduce_<tname>"
// (the name the host side must use to look the kernel up in the library).
#define instantiate_chunked_reduce(tname, type) \
template [[host_name("sdpa_chunked_reduce_" #tname)]] \
[[kernel]] decltype(sdpa_chunked_reduce<type>) \
sdpa_chunked_reduce<type>;

// One instantiation per supported element dtype of the chunk outputs;
// the per-chunk logsumexp buffer is always float32 regardless of tname.
instantiate_chunked_reduce(float32, float)
instantiate_chunked_reduce(float16, half)
instantiate_chunked_reduce(bfloat16, bfloat16_t)
// clang-format on
16 changes: 16 additions & 0 deletions mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ constant bool align_K [[function_constant(201)]];
constant bool has_mask [[function_constant(300)]];
constant bool do_causal [[function_constant(301)]];
constant bool has_sinks [[function_constant(302)]];
constant bool output_logsumexp [[function_constant(304)]];

struct MaxOp {
template <typename T>
Expand Down Expand Up @@ -76,6 +77,7 @@ template <
const constant AttnMaskParams* mask_params [[buffer(5), function_constant(has_mask)]],
const device MaskType* mask [[buffer(6), function_constant(has_mask)]],
const device T* sinks [[buffer(7), function_constant(has_sinks)]],
device float* lse_out [[buffer(8), function_constant(output_logsumexp)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
Expand Down Expand Up @@ -473,4 +475,18 @@ template <
} else {
Otile.template store<T, 1, 1>(O, params->O_strides[2]);
}

// Write per-row logsumexp if requested
if (output_logsumexp) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kRowsPT; ++i) {
int row = int(tid.x) * BQ + tm + sm + (i * kFragSize);
if (row < params->qL) {
int64_t idx = int64_t(tid.z) * params->H * params->qL
+ int64_t(tid.y) * params->qL + row;
lse_out[idx] = float(max_score[i]) * M_LN2_F
+ metal::precise::log(float(sum_score[i]));
}
}
}
}
Loading