Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
430 changes: 430 additions & 0 deletions benchmarks/python/sdpa_vector_vjp_bench.py

Large diffs are not rendered by default.

16 changes: 11 additions & 5 deletions mlx/backend/cuda/scaled_dot_product_attention.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -551,7 +551,6 @@ bool ScaledDotProductAttention::use_fallback(
bool has_mask,
bool has_arr_mask,
bool do_causal,
bool is_training,
bool output_logsumexp,
Stream s) {
if (s.device == Device::cpu) {
Expand Down Expand Up @@ -614,7 +613,14 @@ void ScaledDotProductAttention::eval_gpu(
}
}

bool ScaledDotProductAttentionVJP::use_fallback(const array& q, Stream s) {
bool ScaledDotProductAttentionVJP::use_fallback(
const array& q,
const array& /* k */,
Stream s,
bool,
bool,
bool,
int) {
// The frontend adds a padding mask when sequence length is not a multiple of
// tile size.
if (q.shape(2) % 128 != 0) {
Expand All @@ -630,8 +636,8 @@ void ScaledDotProductAttentionVJP::eval_gpu(

auto& s = stream();

assert(inputs.size() >= 6);
int primals_size = inputs.size() - 3;
assert(inputs.size() >= 7); // primals(>=3) + O + LSE + dO + delta
int primals_size = inputs.size() - 4;
bool has_arr_mask = primals_size > 3 + has_sinks_;

array q = prepare_sdpa_input(inputs[0], s);
Expand All @@ -647,7 +653,7 @@ void ScaledDotProductAttentionVJP::eval_gpu(
}
std::optional<array> sinks;
if (has_sinks_) {
sinks = prepare_sdpa_sinks(inputs.back(), s);
sinks = prepare_sdpa_sinks(inputs[primals_size - 1], s);
}

assert(outputs.size() == 3);
Expand Down
2 changes: 2 additions & 0 deletions mlx/backend/metal/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ if(MLX_METAL_JIT)
make_jit_source(gemv_masked)

make_jit_source(steel/attn/kernels/steel_attention)
make_jit_source(steel/attn/kernels/steel_attention_vjp_dq)
make_jit_source(steel/attn/kernels/steel_attention_vjp_dkv)

make_jit_source(
steel/gemm/gemm_nax kernels/steel/utils.h kernels/steel/gemm/nax.h
Expand Down
2 changes: 2 additions & 0 deletions mlx/backend/metal/jit/includes.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ const char* steel_conv_3d();
const char* steel_conv_general();
const char* gemv_masked();
const char* steel_attention();
const char* steel_attention_vjp_dq();
const char* steel_attention_vjp_dkv();

const char* gemm_nax();
const char* steel_gemm_fused_nax();
Expand Down
178 changes: 174 additions & 4 deletions mlx/backend/metal/jit_kernels.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
// Copyright © 2024 Apple Inc.
#include <cmath>
#include <cstdio>
#include <cstring>

#include "mlx/backend/common/compiled.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/kernels.h"
Expand Down Expand Up @@ -1076,6 +1080,85 @@ MTL::ComputePipelineState* get_gather_qmm_nax_kernel(
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
}

namespace {

// Produce a valid MSL float literal with enough precision for exact
// round-trip: %.9g emits 9 significant digits (FLT_DECIMAL_DIG), enough for
// any finite float to survive decimal round-trip exactly.
std::string float_to_msl(float v) {
  char buf[32];
  std::snprintf(buf, sizeof(buf), "%.9g", v);
  std::string s(buf);
  // %g drops the decimal point for integral values (e.g. "1"), but "1f" is
  // not a valid MSL floating literal — force a floating form before adding
  // the suffix. Non-finite values ("inf"/"nan") have no MSL literal and pass
  // through unchanged, matching prior behavior.
  if (s.find_first_not_of("-0123456789") == std::string::npos) {
    s += ".0";
  }
  s += 'f';
  return s;
}

// Render a float's raw bit pattern as 8 lowercase hex digits. Used in JIT
// cache keys so two scales collide only when they are bitwise identical.
std::string scale_to_hex(float v) {
  uint32_t bits = 0;
  // memcpy is the portable, UB-free way to reinterpret float bits.
  std::memcpy(&bits, &v, sizeof(v));
  char out[16];
  std::snprintf(out, sizeof(out), "%08x", bits);
  return out;
}

// Shared implementation for the attention VJP dQ/dKV JIT kernel getters.
// Both variants build the library name, bake their parameters into the
// shader via #defines, and cache through the Metal device identically; only
// the shader source and template name differ.
MTL::ComputePipelineState* get_steel_attention_vjp_kernel_impl(
    metal::Device& d,
    const std::string& kernel_name,
    const array& q,
    int bq,
    int bk,
    int bd,
    int wm,
    int wn,
    int gqa_factor,
    float scale,
    float scale_log2,
    bool align_Q,
    bool align_K,
    bool do_causal,
    bool has_block_mask,
    const char* shader_source(),
    const char* template_name) {
  // The library name doubles as the cache key: every value baked into the
  // generated source below must be encoded here so that distinct
  // configurations never alias. (scale_log2 is derived from scale, so the
  // scale bits alone suffice.)
  std::string lib_name = kernel_name + "_gqa" + std::to_string(gqa_factor) +
      "_s" + scale_to_hex(scale) + (align_Q ? "_aQ" : "_nQ") +
      (align_K ? "_aK" : "_nK") + (do_causal ? "_c" : "_nc") +
      (has_block_mask ? "_bm" : "_nbm");

  auto lib = d.get_library(lib_name, [&]() {
    // Emit "#define NAME true/false\n" for the boolean toggles.
    auto bool_def = [](const char* name, bool value) {
      return std::string("#define ") + name + (value ? " true" : " false") +
          "\n";
    };
    std::string defines = "#define VJP_GQA_FACTOR " +
        std::to_string(gqa_factor) + "\n" + "#define VJP_SCALE " +
        float_to_msl(scale) + "\n" + "#define VJP_SCALE_LOG2 " +
        float_to_msl(scale_log2) + "\n" + "#define VJP_BAKED_FC 1\n" +
        bool_def("VJP_ALIGN_Q", align_Q) + bool_def("VJP_ALIGN_K", align_K) +
        bool_def("VJP_DO_CAUSAL", do_causal) +
        bool_def("VJP_HAS_BLOCK_MASK", has_block_mask);

    std::string kernel_source;
    concatenate(
        kernel_source,
        metal::utils(),
        defines,
        shader_source(),
        get_template_definition(
            kernel_name,
            template_name,
            get_type_string(q.dtype()),
            bq,
            bk,
            bd,
            wm,
            wn));
    return kernel_source;
  });
  return d.get_kernel(kernel_name, lib);
}

} // namespace

MTL::ComputePipelineState* get_steel_attention_kernel(
metal::Device& d,
const std::string& kernel_name,
Expand All @@ -1087,16 +1170,53 @@ MTL::ComputePipelineState* get_steel_attention_kernel(
int bd,
int wm,
int wn,
const array& m) {
const auto& lib_name = kernel_name;
const array& m,
int gqa_factor,
float scale,
bool align_Q,
bool align_K,
bool has_mask,
bool do_causal,
bool has_sinks,
bool output_logsumexp) {
std::string lib_name = kernel_name;
lib_name += "_gqa" + std::to_string(gqa_factor);
lib_name += "_s" + scale_to_hex(scale);
lib_name += align_Q ? "_aQ" : "_nQ";
lib_name += align_K ? "_aK" : "_nK";
lib_name += has_mask ? "_m" : "_nm";
lib_name += do_causal ? "_c" : "_nc";
lib_name += has_sinks ? "_sk" : "_nsk";
lib_name += output_logsumexp ? "_lse" : "_nlse";

float scale_log2 = static_cast<float>(scale * M_LOG2E);

auto lib = d.get_library(lib_name, [&]() {
std::string defines;
defines += "#define FWD_GQA_FACTOR " + std::to_string(gqa_factor) + "\n";
defines += "#define FWD_SCALE_LOG2 " + float_to_msl(scale_log2) + "\n";
defines += "#define FWD_BAKED_FC 1\n";
defines +=
"#define FWD_ALIGN_Q " + std::string(align_Q ? "true" : "false") + "\n";
defines +=
"#define FWD_ALIGN_K " + std::string(align_K ? "true" : "false") + "\n";
defines +=
"#define FWD_HAS_MASK " + std::string(has_mask ? "true" : "false") +
"\n";
defines += "#define FWD_DO_CAUSAL " +
std::string(do_causal ? "true" : "false") + "\n";
defines += "#define FWD_HAS_SINKS " +
std::string(has_sinks ? "true" : "false") + "\n";
defines += "#define FWD_OUTPUT_LOGSUMEXP " +
std::string(output_logsumexp ? "true" : "false") + "\n";
std::string kernel_source;
concatenate(
kernel_source,
metal::utils(),
defines,
metal::steel_attention(),
get_template_definition(
lib_name,
kernel_name,
"attention",
get_type_string(q.dtype()),
bq,
Expand All @@ -1107,7 +1227,57 @@ MTL::ComputePipelineState* get_steel_attention_kernel(
get_type_string(m.dtype())));
return kernel_source;
});
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
return d.get_kernel(kernel_name, lib);
}

// JIT entry point for the attention VJP dQ kernel. hash_name/func_consts are
// accepted for signature parity with the other attention kernel getters but
// unused: every variant is baked in via #defines rather than Metal function
// constants.
MTL::ComputePipelineState* get_steel_attention_vjp_dq_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& /*hash_name*/,
    const metal::MTLFCList& /*func_consts*/,
    const array& q,
    int bq,
    int bk,
    int bd,
    int wm,
    int wn,
    int gqa_factor,
    float scale,
    float scale_log2,
    bool align_Q,
    bool align_K,
    bool do_causal,
    bool has_block_mask) {
  return get_steel_attention_vjp_kernel_impl(
      d,
      kernel_name,
      q,
      bq,
      bk,
      bd,
      wm,
      wn,
      gqa_factor,
      scale,
      scale_log2,
      align_Q,
      align_K,
      do_causal,
      has_block_mask,
      metal::steel_attention_vjp_dq,
      "attention_vjp_dq");
}

// JIT entry point for the attention VJP dK/dV kernel. hash_name/func_consts
// are accepted for signature parity with the other attention kernel getters
// but unused: every variant is baked in via #defines rather than Metal
// function constants.
MTL::ComputePipelineState* get_steel_attention_vjp_dkv_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& /*hash_name*/,
    const metal::MTLFCList& /*func_consts*/,
    const array& q,
    int bq,
    int bk,
    int bd,
    int wm,
    int wn,
    int gqa_factor,
    float scale,
    float scale_log2,
    bool align_Q,
    bool align_K,
    bool do_causal,
    bool has_block_mask) {
  return get_steel_attention_vjp_kernel_impl(
      d,
      kernel_name,
      q,
      bq,
      bk,
      bd,
      wm,
      wn,
      gqa_factor,
      scale,
      scale_log2,
      align_Q,
      align_K,
      do_causal,
      has_block_mask,
      metal::steel_attention_vjp_dkv,
      "attention_vjp_dkv");
}

MTL::ComputePipelineState* get_steel_attention_nax_kernel(
Expand Down
48 changes: 47 additions & 1 deletion mlx/backend/metal/kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,53 @@ MTL::ComputePipelineState* get_steel_attention_kernel(
int bd,
int wm,
int wn,
const array& m);
const array& m,
int gqa_factor,
float scale,
bool align_Q,
bool align_K,
bool has_mask,
bool do_causal,
bool has_sinks,
bool output_logsumexp);

MTL::ComputePipelineState* get_steel_attention_vjp_dq_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const array& q,
int bq,
int bk,
int bd,
int wm,
int wn,
int gqa_factor,
float scale,
float scale_log2,
bool align_Q,
bool align_K,
bool do_causal,
bool has_block_mask = false);

MTL::ComputePipelineState* get_steel_attention_vjp_dkv_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const array& q,
int bq,
int bk,
int bd,
int wm,
int wn,
int gqa_factor,
float scale,
float scale_log2,
bool align_Q,
bool align_K,
bool do_causal,
bool has_block_mask = false);

MTL::ComputePipelineState* get_steel_attention_nax_kernel(
metal::Device& d,
Expand Down
13 changes: 13 additions & 0 deletions mlx/backend/metal/kernels/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,17 @@ set(STEEL_ATTN_HEADERS
steel/attn/transforms.h
steel/attn/kernels/steel_attention.h)

set(STEEL_ATTN_VJP_HEADERS
steel/defines.h
steel/utils.h
steel/attn/attn.h
steel/attn/loader.h
steel/attn/mma.h
steel/attn/transforms.h
steel/attn/params.h
steel/attn/kernels/steel_attention_vjp_dq.h
steel/attn/kernels/steel_attention_vjp_dkv.h)

set(STEEL_NAX_HEADERS
steel/defines.h
steel/utils.h
Expand Down Expand Up @@ -153,6 +164,8 @@ if(NOT MLX_METAL_JIT)
build_kernel(steel/gemm/kernels/steel_gemm_segmented ${STEEL_HEADERS})
build_kernel(gemv_masked steel/utils.h)
build_kernel(steel/attn/kernels/steel_attention ${STEEL_ATTN_HEADERS})
build_kernel(steel/attn/kernels/steel_attention_vjp_dq ${STEEL_ATTN_VJP_HEADERS})
build_kernel(steel/attn/kernels/steel_attention_vjp_dkv ${STEEL_ATTN_VJP_HEADERS})

if((MLX_METAL_VERSION GREATER_EQUAL 400) AND (MACOS_SDK_VERSION GREATER_EQUAL
26.2))
Expand Down
Loading