def prepare_inputs(B, qL, kL, D, qH, kH, dtype):
    """Draw random query/key/value tensors for one benchmark shape.

    ``dtype`` is the *name* of a NumPy dtype and is resolved with
    ``getattr(np, dtype)``. Queries are sampled with shape (B, qH, qL, D) and
    unit standard deviation; keys and values with shape (B, kH, kL, D) and
    standard deviation ``1 / sqrt(D)``.

    Returns:
        ``(q, k, v, scale)`` where the tensors are ``mx.array`` objects and
        ``scale`` is ``1 / sqrt(D)``.
    """
    target_dtype = getattr(np, dtype)
    scale = 1.0 / math.sqrt(D)

    def sample(std, n_heads, seq_len):
        # One normal draw cast to the requested NumPy dtype.
        return np.random.normal(0.0, std, (B, n_heads, seq_len, D)).astype(target_dtype)

    # Draw in q, k, v order so a seeded NumPy RNG produces the same tensors
    # as before.
    host_q = sample(1.0, qH, qL)
    host_k = sample(scale, kH, kL)
    host_v = sample(scale, kH, kL)
    return mx.array(host_q), mx.array(host_k), mx.array(host_v), scale
def do_vjp_bench(f, q, k, v, scale, causal=False):
    """Evaluate the q/k/v gradients of ``f(...).sum()`` ``N_iter_func`` times.

    The gradients from every repetition are added together and the sums are
    passed to ``mx.eval`` once at the end, so each repetition contributes to
    the values that are finally evaluated.
    """
    def summed_output(q_in, k_in, v_in):
        return f(q_in, k_in, v_in, scale, causal=causal).sum()

    backward = mx.grad(summed_output, argnums=(0, 1, 2))

    # The first call seeds the accumulators; the remaining N_iter_func - 1
    # calls are folded in.
    totals = list(backward(q, k, v))
    for _ in range(1, N_iter_func):
        totals = [acc + grad for acc, grad in zip(totals, backward(q, k, v))]
    mx.eval(*totals)
def print_row(B, qsl, ksl, D, qH, kvH, causal, dtype, unfused_p50, fused_p50,
              unfused_p90, fused_p90):
    """Print one comma-separated result row.

    The speedup column is ``unfused_p50 / fused_p50``; when ``fused_p50`` is
    not positive it is reported as ``inf`` instead of dividing. The mask
    column reads ``causal`` or ``dense`` depending on ``causal``.
    """
    if fused_p50 > 0:
        speedup = unfused_p50 / fused_p50
    else:
        speedup = float("inf")

    if causal:
        mask_str = "causal"
    else:
        mask_str = "dense"

    row = (
        f"{B:3d}, {qsl:5d}, {ksl:5d}, {D:4d}, {qH:4d}, {kvH:4d}, "
        f"{mask_str:>6s}, {dtype:>8s}, {unfused_p50:8.3f}, {fused_p50:8.3f}, "
        f"{speedup:7.2f}x, {unfused_p90:8.3f}, {fused_p90:8.3f}"
    )
    print(row)
def run_vector_section():
    """Benchmark the short-query shapes (qsl of 1, 4 or 8) in float16 and float32.

    Every shape uses B=1, 32 query heads and no causal mask, so the table
    below lists only the fields that vary; the full 7-tuples expected by
    ``run_section`` are expanded from it.
    """
    # fmt: off
    varying = (
        # (qsl,   ksl, hdim, n_kvh)
        (   1,   512,  128,    32),
        (   1,  2048,  128,    32),
        (   1,  4096,  128,    32),
        (   1,  8192,  128,    32),
        (   1, 16384,  128,    32),
        (   1,  2048,   64,    32),
        (   1,  2048,   96,    32),
        (   4,  2048,  128,    32),
        (   8,  2048,  128,    32),
        # D=256
        (   1,  2048,  256,    32),
        (   4,  2048,  256,    32),
        # GQA
        (   1,  2048,  128,     8),
        (   4,  2048,  128,     8),
    )
    # fmt: on
    # Expand to (B, qsl, ksl, hdim, n_qh, n_kvh, causal).
    shapes = tuple(
        (1, qsl, ksl, hdim, 32, n_kvh, False)
        for qsl, ksl, hdim, n_kvh in varying
    )
    run_section("VECTOR VJP (query seq len <= 8)", shapes, ("float16", "float32"))
def run_auto_section():
    """Benchmark shapes that straddle the auto-dispatch thresholds.

    The section does not override ``MLX_SDPA_VJP_MODE`` (``set_mode=None``),
    so whatever mode is already in the environment is used. The thresholds
    read here are only used to label the section title; when the environment
    variables are unset the title shows the defaults documented in this
    file's header (L threshold 1024, bytes threshold 128MB).
    """
    # fmt: off
    shapes = (
        # (  B,   qsl,   ksl, hdim, n_qh, n_kvh, causal)
        # L boundary around default threshold (1024)
        (   1,   512,   512,   64,   32,    32,  False),
        (   1,  1023,  1023,   64,   32,    32,  False),
        (   1,  1024,  1024,   64,   32,    32,  False),
        (   1,  1025,  1025,   64,   32,    32,  False),
        (   1,  2048,  2048,   64,   32,    32,  False),
        # Same for D=128
        (   1,   512,   512,  128,   32,    32,  False),
        (   1,  1023,  1023,  128,   32,    32,  False),
        (   1,  1024,  1024,  128,   32,    32,  False),
        (   1,  2048,  2048,  128,   32,    32,  False),
        # Batch scaling: B*H*L*L*2 >= 128MB triggers fused even for short L
        # B=4, H=32, L=512: 4*32*512*512*2 = 64MB (below threshold)
        (   4,   512,   512,  128,   32,    32,  False),
        # B=8, H=32, L=512: 8*32*512*512*2 = 128MB (at threshold)
        (   8,   512,   512,  128,   32,    32,  False),
    )
    # fmt: on
    l_thresh = os.environ.get("MLX_SDPA_VJP_LONG_L_THRESHOLD", "1024")
    # Fall back to the documented 128MB default (128 * 2**20 bytes). The
    # previous fallback of 1 << 30 labelled the section with a 1 GiB threshold
    # that contradicts the shape comments above. The title prints decimal
    # megabytes, so the default shows as 134MB.
    bytes_thresh = os.environ.get(
        "MLX_SDPA_VJP_ATTENTION_BYTES_THRESHOLD", str(128 << 20)
    )
    run_section(
        f"AUTO-DISPATCH THRESHOLD (L_thresh={l_thresh}, bytes_thresh={int(bytes_thresh)/1e6:.0f}MB)",
        shapes,
        ("float16",),
        set_mode=None,  # Use whatever MLX_SDPA_VJP_MODE is set (default: auto)
    )
Matrix':>12s}" + ) + + for B, qsl, ksl, D, qH, kvH, causal in mem_configs: + _scale = 1.0 / math.sqrt(D) + attn_bytes = B * qH * qsl * ksl * 2 # float16 + mask_str = "causal" if causal else "dense" + label = f"D={D},L={qsl},{mask_str},H={qH}{'/' + str(kvH) if qH != kvH else ''}" + + # Measure unfused + os.environ["MLX_SDPA_VJP_MODE"] = "unfused" + mx.clear_cache() + q, k, v, scale = prepare_inputs(B, qsl, ksl, D, qH, kvH, "float16") + mx.eval(q, k, v) + + def loss_ref(q, k, v): + return mlx_ref_attn(q, k, v, _scale, causal=causal).sum() + + grad_ref = mx.grad(loss_ref, argnums=(0, 1, 2)) + mx.reset_peak_memory() + mx.eval(grad_ref(q, k, v)) + mem_unfused = mx.get_peak_memory() + + # Measure fused + os.environ["MLX_SDPA_VJP_MODE"] = "fused" + mx.clear_cache() + q, k, v, scale = prepare_inputs(B, qsl, ksl, D, qH, kvH, "float16") + mx.eval(q, k, v) + + def loss_fused(q, k, v): + return mlx_fused_attn(q, k, v, _scale, causal=causal).sum() + + grad_fused = mx.grad(loss_fused, argnums=(0, 1, 2)) + mx.reset_peak_memory() + mx.eval(grad_fused(q, k, v)) + mem_fused = mx.get_peak_memory() + + os.environ.pop("MLX_SDPA_VJP_MODE", None) + + savings = 1.0 - mem_fused / mem_unfused if mem_unfused > 0 else 0.0 + print( + f"{label:>40s}, {mem_unfused / 1e6:>8.1f}MB, " + f"{mem_fused / 1e6:>8.1f}MB, {100 * savings:>6.1f}%, " + f"{attn_bytes / 1e6:>10.1f}MB" + ) + + print() + print("Note: Measured peak memory is allocator-reported and may be affected by") + print("caching and memory pooling. 
if __name__ == "__main__":
    # Command-line entry point: choose one benchmark section or run them all.
    parser = argparse.ArgumentParser(description="SDPA VJP Benchmark")
    parser.add_argument(
        "--section",
        choices=["vector", "steel", "memory", "auto", "all"],
        default="all",
        help="Which section to run (default: all)",
    )
    args = parser.parse_args()

    print(f"Device: {device_name}")
    print()
    print("Benchmark measures fused (steel) vs unfused (materialized) VJP backward.")
    print("Speedup > 1.0x means fused is faster. P50 (median) of interleaved runs.")
    print()

    # Report the dispatch settings in effect. When a variable is unset, show
    # the default documented in this file's header; the bytes threshold
    # default is 128MB (128 * 2**20 bytes), not the 1 GiB that the previous
    # fallback of 1 << 30 reported.
    vjp_mode = os.environ.get("MLX_SDPA_VJP_MODE", "auto")
    l_thresh = os.environ.get("MLX_SDPA_VJP_LONG_L_THRESHOLD", "1024")
    bytes_thresh = os.environ.get(
        "MLX_SDPA_VJP_ATTENTION_BYTES_THRESHOLD", str(128 << 20)
    )
    print(f"Default dispatch: MLX_SDPA_VJP_MODE={vjp_mode}, "
          f"L_threshold={l_thresh}, bytes_threshold={int(bytes_thresh) / 1e6:.0f}MB")

    # Insertion order is the execution order used for --section all.
    sections = {
        "vector": run_vector_section,
        "steel": run_steel_section,
        "auto": run_auto_section,
        "memory": run_memory_section,
    }

    if args.section == "all":
        for fn in sections.values():
            fn()
    else:
        sections[args.section]()
*/, + Stream s, + bool, + bool, + bool, + int) { // The frontend adds a padding mask when sequence length is not a multiple of // tile size. if (q.shape(2) % 128 != 0) { @@ -630,8 +636,8 @@ void ScaledDotProductAttentionVJP::eval_gpu( auto& s = stream(); - assert(inputs.size() >= 6); - int primals_size = inputs.size() - 3; + assert(inputs.size() >= 7); // primals(>=3) + O + LSE + dO + delta + int primals_size = inputs.size() - 4; bool has_arr_mask = primals_size > 3 + has_sinks_; array q = prepare_sdpa_input(inputs[0], s); @@ -647,7 +653,7 @@ void ScaledDotProductAttentionVJP::eval_gpu( } std::optional sinks; if (has_sinks_) { - sinks = prepare_sdpa_sinks(inputs.back(), s); + sinks = prepare_sdpa_sinks(inputs[primals_size - 1], s); } assert(outputs.size() == 3); diff --git a/mlx/backend/metal/CMakeLists.txt b/mlx/backend/metal/CMakeLists.txt index 67c69579ad..a99b8a1472 100644 --- a/mlx/backend/metal/CMakeLists.txt +++ b/mlx/backend/metal/CMakeLists.txt @@ -82,6 +82,8 @@ if(MLX_METAL_JIT) make_jit_source(gemv_masked) make_jit_source(steel/attn/kernels/steel_attention) + make_jit_source(steel/attn/kernels/steel_attention_vjp_dq) + make_jit_source(steel/attn/kernels/steel_attention_vjp_dkv) make_jit_source( steel/gemm/gemm_nax kernels/steel/utils.h kernels/steel/gemm/nax.h diff --git a/mlx/backend/metal/jit/includes.h b/mlx/backend/metal/jit/includes.h index dcaf09a1e9..439f508f94 100644 --- a/mlx/backend/metal/jit/includes.h +++ b/mlx/backend/metal/jit/includes.h @@ -45,6 +45,8 @@ const char* steel_conv_3d(); const char* steel_conv_general(); const char* gemv_masked(); const char* steel_attention(); +const char* steel_attention_vjp_dq(); +const char* steel_attention_vjp_dkv(); const char* gemm_nax(); const char* steel_gemm_fused_nax(); diff --git a/mlx/backend/metal/jit_kernels.cpp b/mlx/backend/metal/jit_kernels.cpp index a0703cd875..d46e269cb9 100644 --- a/mlx/backend/metal/jit_kernels.cpp +++ b/mlx/backend/metal/jit_kernels.cpp @@ -1,4 +1,8 @@ // Copyright © 2024 
// Produce a valid MSL float literal with enough precision for exact round-trip.
//
// "%.9g" prints integer-valued floats without a decimal point ("1", "-2") and
// non-finite values as "inf"/"nan". Appending the 'f' suffix directly to those
// would give "1f" or "inff", which are not valid floating-point literals and
// would make the JIT-compiled source fail to build (e.g. for scale == 1.0).
// Integer-valued results therefore get an explicit ".0", and non-finite values
// are mapped to the INFINITY / NAN constants.
std::string float_to_msl(float v) {
  char buf[32];
  std::snprintf(buf, sizeof(buf), "%.9g", v);
  std::string lit(buf);

  // Non-finite values have no literal spelling; use the math constants.
  if (lit.find("nan") != std::string::npos) {
    return "NAN";
  }
  if (lit.find("inf") != std::string::npos) {
    return lit[0] == '-' ? "(-INFINITY)" : "INFINITY";
  }

  // A float literal needs a decimal point or an exponent before the suffix.
  if (lit.find_first_of(".eE") == std::string::npos) {
    lit += ".0";
  }
  lit += 'f';
  return lit;
}
"_bm" : "_nbm"; + + auto lib = d.get_library(lib_name, [&]() { + std::string defines; + defines += "#define VJP_GQA_FACTOR " + std::to_string(gqa_factor) + "\n"; + defines += "#define VJP_SCALE " + float_to_msl(scale) + "\n"; + defines += "#define VJP_SCALE_LOG2 " + float_to_msl(scale_log2) + "\n"; + defines += "#define VJP_BAKED_FC 1\n"; + defines += + "#define VJP_ALIGN_Q " + std::string(align_Q ? "true" : "false") + "\n"; + defines += + "#define VJP_ALIGN_K " + std::string(align_K ? "true" : "false") + "\n"; + defines += "#define VJP_DO_CAUSAL " + + std::string(do_causal ? "true" : "false") + "\n"; + defines += "#define VJP_HAS_BLOCK_MASK " + + std::string(has_block_mask ? "true" : "false") + "\n"; + + std::string kernel_source; + concatenate( + kernel_source, + metal::utils(), + defines, + shader_source(), + get_template_definition( + kernel_name, + template_name, + get_type_string(q.dtype()), + bq, bk, bd, wm, wn)); + return kernel_source; + }); + return d.get_kernel(kernel_name, lib); +} + +} // namespace + MTL::ComputePipelineState* get_steel_attention_kernel( metal::Device& d, const std::string& kernel_name, @@ -1087,16 +1170,53 @@ MTL::ComputePipelineState* get_steel_attention_kernel( int bd, int wm, int wn, - const array& m) { - const auto& lib_name = kernel_name; + const array& m, + int gqa_factor, + float scale, + bool align_Q, + bool align_K, + bool has_mask, + bool do_causal, + bool has_sinks, + bool output_logsumexp) { + std::string lib_name = kernel_name; + lib_name += "_gqa" + std::to_string(gqa_factor); + lib_name += "_s" + scale_to_hex(scale); + lib_name += align_Q ? "_aQ" : "_nQ"; + lib_name += align_K ? "_aK" : "_nK"; + lib_name += has_mask ? "_m" : "_nm"; + lib_name += do_causal ? "_c" : "_nc"; + lib_name += has_sinks ? "_sk" : "_nsk"; + lib_name += output_logsumexp ? 
"_lse" : "_nlse"; + + float scale_log2 = static_cast(scale * M_LOG2E); + auto lib = d.get_library(lib_name, [&]() { + std::string defines; + defines += "#define FWD_GQA_FACTOR " + std::to_string(gqa_factor) + "\n"; + defines += "#define FWD_SCALE_LOG2 " + float_to_msl(scale_log2) + "\n"; + defines += "#define FWD_BAKED_FC 1\n"; + defines += + "#define FWD_ALIGN_Q " + std::string(align_Q ? "true" : "false") + "\n"; + defines += + "#define FWD_ALIGN_K " + std::string(align_K ? "true" : "false") + "\n"; + defines += + "#define FWD_HAS_MASK " + std::string(has_mask ? "true" : "false") + + "\n"; + defines += "#define FWD_DO_CAUSAL " + + std::string(do_causal ? "true" : "false") + "\n"; + defines += "#define FWD_HAS_SINKS " + + std::string(has_sinks ? "true" : "false") + "\n"; + defines += "#define FWD_OUTPUT_LOGSUMEXP " + + std::string(output_logsumexp ? "true" : "false") + "\n"; std::string kernel_source; concatenate( kernel_source, metal::utils(), + defines, metal::steel_attention(), get_template_definition( - lib_name, + kernel_name, "attention", get_type_string(q.dtype()), bq, @@ -1107,7 +1227,57 @@ MTL::ComputePipelineState* get_steel_attention_kernel( get_type_string(m.dtype()))); return kernel_source; }); - return d.get_kernel(kernel_name, lib, hash_name, func_consts); + return d.get_kernel(kernel_name, lib); +} + +MTL::ComputePipelineState* get_steel_attention_vjp_dq_kernel( + metal::Device& d, + const std::string& kernel_name, + const std::string& /*hash_name*/, + const metal::MTLFCList& /*func_consts*/, + const array& q, + int bq, + int bk, + int bd, + int wm, + int wn, + int gqa_factor, + float scale, + float scale_log2, + bool align_Q, + bool align_K, + bool do_causal, + bool has_block_mask) { + return get_steel_attention_vjp_kernel_impl( + d, kernel_name, q, bq, bk, bd, wm, wn, + gqa_factor, scale, scale_log2, align_Q, align_K, do_causal, + has_block_mask, + metal::steel_attention_vjp_dq, "attention_vjp_dq"); +} + +MTL::ComputePipelineState* 
get_steel_attention_vjp_dkv_kernel( + metal::Device& d, + const std::string& kernel_name, + const std::string& /*hash_name*/, + const metal::MTLFCList& /*func_consts*/, + const array& q, + int bq, + int bk, + int bd, + int wm, + int wn, + int gqa_factor, + float scale, + float scale_log2, + bool align_Q, + bool align_K, + bool do_causal, + bool has_block_mask) { + return get_steel_attention_vjp_kernel_impl( + d, kernel_name, q, bq, bk, bd, wm, wn, + gqa_factor, scale, scale_log2, align_Q, align_K, do_causal, + has_block_mask, + metal::steel_attention_vjp_dkv, "attention_vjp_dkv"); } MTL::ComputePipelineState* get_steel_attention_nax_kernel( diff --git a/mlx/backend/metal/kernels.h b/mlx/backend/metal/kernels.h index 63fccc59ff..ea23663910 100644 --- a/mlx/backend/metal/kernels.h +++ b/mlx/backend/metal/kernels.h @@ -344,7 +344,53 @@ MTL::ComputePipelineState* get_steel_attention_kernel( int bd, int wm, int wn, - const array& m); + const array& m, + int gqa_factor, + float scale, + bool align_Q, + bool align_K, + bool has_mask, + bool do_causal, + bool has_sinks, + bool output_logsumexp); + +MTL::ComputePipelineState* get_steel_attention_vjp_dq_kernel( + metal::Device& d, + const std::string& kernel_name, + const std::string& hash_name, + const metal::MTLFCList& func_consts, + const array& q, + int bq, + int bk, + int bd, + int wm, + int wn, + int gqa_factor, + float scale, + float scale_log2, + bool align_Q, + bool align_K, + bool do_causal, + bool has_block_mask = false); + +MTL::ComputePipelineState* get_steel_attention_vjp_dkv_kernel( + metal::Device& d, + const std::string& kernel_name, + const std::string& hash_name, + const metal::MTLFCList& func_consts, + const array& q, + int bq, + int bk, + int bd, + int wm, + int wn, + int gqa_factor, + float scale, + float scale_log2, + bool align_Q, + bool align_K, + bool do_causal, + bool has_block_mask = false); MTL::ComputePipelineState* get_steel_attention_nax_kernel( metal::Device& d, diff --git 
a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt index 8d3d8a1953..ee27bbb33e 100644 --- a/mlx/backend/metal/kernels/CMakeLists.txt +++ b/mlx/backend/metal/kernels/CMakeLists.txt @@ -98,6 +98,17 @@ set(STEEL_ATTN_HEADERS steel/attn/transforms.h steel/attn/kernels/steel_attention.h) +set(STEEL_ATTN_VJP_HEADERS + steel/defines.h + steel/utils.h + steel/attn/attn.h + steel/attn/loader.h + steel/attn/mma.h + steel/attn/transforms.h + steel/attn/params.h + steel/attn/kernels/steel_attention_vjp_dq.h + steel/attn/kernels/steel_attention_vjp_dkv.h) + set(STEEL_NAX_HEADERS steel/defines.h steel/utils.h @@ -153,6 +164,8 @@ if(NOT MLX_METAL_JIT) build_kernel(steel/gemm/kernels/steel_gemm_segmented ${STEEL_HEADERS}) build_kernel(gemv_masked steel/utils.h) build_kernel(steel/attn/kernels/steel_attention ${STEEL_ATTN_HEADERS}) + build_kernel(steel/attn/kernels/steel_attention_vjp_dq ${STEEL_ATTN_VJP_HEADERS}) + build_kernel(steel/attn/kernels/steel_attention_vjp_dkv ${STEEL_ATTN_VJP_HEADERS}) if((MLX_METAL_VERSION GREATER_EQUAL 400) AND (MACOS_SDK_VERSION GREATER_EQUAL 26.2)) diff --git a/mlx/backend/metal/kernels/sdpa_vector.h b/mlx/backend/metal/kernels/sdpa_vector.h index 1eec72be31..6363d06364 100644 --- a/mlx/backend/metal/kernels/sdpa_vector.h +++ b/mlx/backend/metal/kernels/sdpa_vector.h @@ -11,6 +11,7 @@ constant bool bool_mask [[function_constant(23)]]; constant bool float_mask [[function_constant(24)]]; constant bool has_sinks [[function_constant(25)]]; constant int blocks [[function_constant(26)]]; +constant bool output_lse [[function_constant(28)]]; template [[kernel]] void sdpa_vector( @@ -81,8 +82,10 @@ template out += o_offset * V + simd_gid * v_per_thread; // Read the query and 0 the output accumulator + // Scale by M_LOG2E_F to match STEEL attention domain (exp2 instead of exp) + const U log2e_scale = static_cast(scale * M_LOG2E_F); for (int i = 0; i < qk_per_thread; i++) { - q[i] = static_cast(scale) * queries[i]; + q[i] = 
log2e_scale * queries[i]; } for (int i = 0; i < v_per_thread; i++) { o[i] = 0; @@ -91,7 +94,8 @@ template U max_score = Limits::finite_min; U sum_exp_score = 0; if (has_sinks && simd_gid == 0) { - max_score = static_cast(sinks[q_batch_head_idx % num_q_heads]); + // Scale sink by M_LOG2E_F to match log2 domain + max_score = static_cast(M_LOG2E_F) * static_cast(sinks[q_batch_head_idx % num_q_heads]); sum_exp_score = 1; } @@ -118,13 +122,14 @@ template } score = simd_sum(score); if (float_mask) { - score += static_cast(fmask[0]); + // Scale float mask by M_LOG2E_F to match log2 domain + score += static_cast(M_LOG2E_F) * static_cast(fmask[0]); } - // Update the accumulators + // Update the accumulators (using exp2 to match STEEL attention) U new_max = max(max_score, score); - U factor = fast::exp(max_score - new_max); - U exp_score = fast::exp(score - new_max); + U factor = fast::exp2(max_score - new_max); + U exp_score = fast::exp2(score - new_max); max_score = new_max; sum_exp_score = sum_exp_score * factor + exp_score; @@ -156,7 +161,7 @@ template threadgroup_barrier(mem_flags::mem_threadgroup); max_score = max_scores[simd_lid]; U new_max = simd_max(max_score); - U factor = fast::exp(max_score - new_max); + U factor = fast::exp2(max_score - new_max); sum_exp_score = simd_sum(sum_exp_scores[simd_lid] * factor); // Now we need to aggregate all the outputs @@ -199,6 +204,7 @@ template const constant int& mask_head_stride [[buffer(17), function_constant(has_mask)]], const device T* sinks [[buffer(18), function_constant(has_sinks)]], + device float* lse_out [[buffer(19), function_constant(output_lse)]], uint3 tptg [[threads_per_threadgroup]], uint3 tidtg [[thread_position_in_threadgroup]], uint3 tid [[threadgroup_position_in_grid]], @@ -246,16 +252,22 @@ template } sums += o_offset * blocks + block_idx; maxs += o_offset * blocks + block_idx; + if (output_lse) { + lse_out += o_offset; + } - // Read the query + // Read the query and 0 the output accumulator + // Scale by 
M_LOG2E_F to match STEEL attention domain (exp2 instead of exp) + const U log2e_scale = static_cast(scale * M_LOG2E_F); for (int i = 0; i < qk_per_thread; i++) { - q[i] = static_cast(scale) * queries[i]; + q[i] = log2e_scale * queries[i]; } U max_score = Limits::finite_min; U sum_exp_score = 0; if (has_sinks && block_idx == 0) { - max_score = static_cast(sinks[q_head_idx]); + // Scale sink by M_LOG2E_F to match log2 domain + max_score = static_cast(M_LOG2E_F) * static_cast(sinks[q_head_idx]); sum_exp_score = 1; } @@ -278,13 +290,14 @@ template score = simd_sum(score); if (float_mask) { - score += fmask[0]; + // Scale float mask by M_LOG2E_F to match log2 domain + score += static_cast(M_LOG2E_F) * static_cast(fmask[0]); } - // Update the accumulators + // Update the accumulators (using exp2 to match STEEL attention) U new_max = max(max_score, score); - U factor = fast::exp(max_score - new_max); - U exp_score = fast::exp(score - new_max); + U factor = fast::exp2(max_score - new_max); + U exp_score = fast::exp2(score - new_max); max_score = new_max; sum_exp_score = sum_exp_score * factor + exp_score; @@ -324,6 +337,7 @@ template const device float* maxs [[buffer(2)]], device T* out [[buffer(3)]], const constant int& blocks [[buffer(4)]], + device float* lse_out [[buffer(5), function_constant(output_lse)]], uint3 tid [[threadgroup_position_in_grid]], uint3 tpg [[threadgroups_per_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], @@ -345,6 +359,9 @@ template sums += q_offset * blocks; maxs += q_offset * blocks; out += q_offset * D + simd_gid * elem_per_thread; + if (output_lse) { + lse_out += q_offset; + } // Set defaults U sum_exp_score = 0.0; @@ -356,16 +373,16 @@ template } max_score = simd_max(max_score); - // Reduce the d + // Reduce the d (using exp2 to match log2 domain from pass 1) for (int b = 0; b < blocks / BN; ++b) { - U factor = fast::exp(maxs[simd_lid + BN * b] - max_score); + U factor = fast::exp2(maxs[simd_lid + BN * b] - max_score); sum_exp_score 
+= factor * sums[simd_lid + BN * b]; } sum_exp_score = simd_sum(sum_exp_score); // Reduce the sum exp and partials for (int b = 0; b < blocks / BN; ++b) { - U factor = fast::exp(maxs[simd_gid] - max_score); + U factor = fast::exp2(maxs[simd_gid] - max_score); // Update the output accumulator for (int i = 0; i < elem_per_thread; i++) { @@ -376,6 +393,12 @@ template partials += BN * D; } + // Write logsumexp if requested: lse = max_score + log2(sum_exp_score) + // max_score and sum_exp_score are in log2 domain + if (output_lse && simd_gid == 0 && simd_lid == 0) { + lse_out[0] = max_score + metal::log2(sum_exp_score); + } + // Use shared memory to transpose and reduce the final block for (int i = 0; i < elem_per_thread; i++) { outputs[simd_lid * BD + simd_gid] = o[i]; diff --git a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h index 0d9628e834..2e094c8296 100644 --- a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +++ b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h @@ -2,18 +2,39 @@ #include "mlx/backend/metal/kernels/steel/attn/attn.h" +// JIT-baked constants: when compiled via JIT, these are #defined as literals +// before this header is included. For metallib builds, they fall back to +// params-> reads at runtime (values) or function constants (booleans). +#ifndef FWD_GQA_FACTOR + #define FWD_GQA_FACTOR (params->gqa_factor) + #define FWD_SCALE_LOG2 (params->scale * M_LOG2E_F) + #define FWD_UNDEF_DEFINES +#endif + using namespace mlx::steel; /////////////////////////////////////////////////////////////////////////////// // GEMM kernels /////////////////////////////////////////////////////////////////////////////// +// When JIT-compiled, boolean flags are baked as constexpr booleans, +// enabling full dead-code elimination. Metallib builds use function constants. 
+#ifdef FWD_BAKED_FC +constexpr constant bool align_Q = FWD_ALIGN_Q; +constexpr constant bool align_K = FWD_ALIGN_K; +constexpr constant bool has_mask = FWD_HAS_MASK; +constexpr constant bool do_causal = FWD_DO_CAUSAL; +constexpr constant bool has_sinks = FWD_HAS_SINKS; +constexpr constant bool output_logsumexp = FWD_OUTPUT_LOGSUMEXP; +#else constant bool align_Q [[function_constant(200)]]; constant bool align_K [[function_constant(201)]]; constant bool has_mask [[function_constant(300)]]; constant bool do_causal [[function_constant(301)]]; constant bool has_sinks [[function_constant(302)]]; +constant bool output_logsumexp [[function_constant(303)]]; +#endif struct MaxOp { template @@ -76,6 +97,7 @@ template < const constant AttnMaskParams* mask_params [[buffer(5), function_constant(has_mask)]], const device MaskType* mask [[buffer(6), function_constant(has_mask)]], const device T* sinks [[buffer(7), function_constant(has_sinks)]], + device float* LSE [[buffer(8), function_constant(output_logsumexp)]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]], uint3 tid [[threadgroup_position_in_grid]], @@ -91,7 +113,7 @@ template < tidl.y * params->Q_strides[1] + // Head tidl.x * BQ * params->Q_strides[2]; // Sequence - ulong kv_head_idx = int(tid.y) / params->gqa_factor; + ulong kv_head_idx = int(tid.y) / FWD_GQA_FACTOR; K += tidl.z * params->K_strides[0] + // Batch kv_head_idx * params->K_strides[1]; // Head @@ -163,7 +185,7 @@ template < VBlockLoader loader_v( V, params->V_strides[2], Vs, simd_group_id, simd_lane_id); - const AccumType scale = params->scale * M_LOG2E_F; + const AccumType scale = FWD_SCALE_LOG2; // Prepare MMA tiles constexpr short kFragSize = 8; // MMAFrag size @@ -460,6 +482,32 @@ template < Otile.template row_bin_op(sum_score); threadgroup_barrier(mem_flags::mem_none); + // Output logsumexp if requested for VJP backward pass + // LSE = max_score + log2(sum_score) in log2 domain (matches STEEL 
convention) + // Physical storage: [B*H, qL] linear array. The element for a query row is at + // (batch * H + head) * LSE_strides[0] + query_pos * LSE_strides[1], where + // LSE_strides[0] = qL (one row per (batch, head)) and LSE_strides[1] = 1. + if (output_logsumexp) { + // Compute linear index for (batch, head) combination + // This matches the VJP kernel's indexing: (tidl.z * H + tidl.y) * + // LSE_strides[0] + device float* lse_out = + LSE + (tidl.z * params->H + tidl.y) * params->LSE_strides[0]; + + // Write one logsumexp per query position in this tile + // Each thread handles kRowsPT query positions + // align_Q=true: query length is block-aligned (all blocks full), always write. + // align_Q=false: the last block is partial, so bounds-check each row. + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kRowsPT; ++i) { + int row_pos = tid.x * BQ + tm + sm + (i * decltype(Stile)::kFragRows); + if (align_Q || row_pos < params->qL) { + AccumType lse_val = max_score[i] + fast::log2(sum_score[i]); + lse_out[row_pos * params->LSE_strides[1]] = static_cast(lse_val); + } + } + } + // Store results O += (tm + sm) * params->O_strides[2] + sn; @@ -474,3 +522,9 @@ template < Otile.template store(O, params->O_strides[2]); } } + +#ifdef FWD_UNDEF_DEFINES + #undef FWD_GQA_FACTOR + #undef FWD_SCALE_LOG2 + #undef FWD_UNDEF_DEFINES +#endif diff --git a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal index 0ff9d91b00..7bddfcb054 100644 --- a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +++ b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal @@ -13,6 +13,7 @@ #define instantiate_attn_shapes_helper(iname, itype, mname, mtype) \ instantiate_attn(iname, itype, 32, 16, 128, 4, 1, mname, mtype) \ + instantiate_attn(iname, itype, 32, 32, 96, 4, 1, mname, mtype) \ instantiate_attn(iname, itype, 32, 32, 80, 4, 1, mname, mtype) \ instantiate_attn(iname, itype, 32, 32, 64, 4, 
1, mname, mtype) diff --git a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_vjp_dkv.h b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_vjp_dkv.h new file mode 100644 index 0000000000..4165b945d5 --- /dev/null +++ b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_vjp_dkv.h @@ -0,0 +1,715 @@ +// Copyright © 2024-25 Apple Inc. + +#pragma once + +/////////////////////////////////////////////////////////////////////////////// +// STEEL VJP dKV Kernel for Scaled Dot-Product Attention +// +// Supports WM=1 (32 threads, single simdgroup) and WM=2 (64 threads, two +// simdgroups). WM=2 halves per-thread register pressure (~216 vs ~364 regs +// for D=128) at the cost of a threadgroup reduction at the end. +// dO is loaded on-the-fly from smem (not hoisted to registers) to further +// reduce live register count by TQ*TD*2 floats (~64 regs for D=128 WM=2). +// +// Grid: [NK, num_kv_heads, B] - one threadgroup per (kv_block, kv_head, batch) +// Loop: Over GQA query heads, then over Q blocks to accumulate dK/dV +// +// Algorithm (log2 domain): +// S = Q @ K^T (unscaled) +// S *= scale_log2 (post-scale in float32) +// P = exp2(S - LSE) +// dV += P^T @ dO (via scatter-to-smem transpose) +// dP = dO @ V^T +// dS = scale * P * (dP - delta) (scale baked into dS) +// dK += dS^T @ Q (via scatter-to-smem transpose) +// +// See companion kernel steel_attention_vjp_dq.h for dQ computation. +/////////////////////////////////////////////////////////////////////////////// + +// JIT-baked constants: when compiled via JIT, these are #defined as literals +// before this header is included. For metallib builds, they fall back to +// params-> reads at runtime. 
+#ifndef VJP_GQA_FACTOR + #define VJP_GQA_FACTOR (params->gqa_factor) + #define VJP_SCALE (params->scale) + #define VJP_SCALE_LOG2 (params->scale_log2) + #define VJP_UNDEF_DEFINES +#endif + +using namespace mlx::steel; + +// When JIT-compiled, align/causal flags are baked as constexpr booleans, +// enabling full dead-code elimination. Metallib builds use function constants. +#ifdef VJP_BAKED_FC +constexpr constant bool align_Q_vjp_dkv = VJP_ALIGN_Q; +constexpr constant bool align_K_vjp_dkv = VJP_ALIGN_K; +constexpr constant bool do_causal_vjp_dkv = VJP_DO_CAUSAL; +constexpr constant bool has_block_mask_vjp_dkv = VJP_HAS_BLOCK_MASK; +#else +constant bool align_Q_vjp_dkv [[function_constant(200)]]; +constant bool align_K_vjp_dkv [[function_constant(201)]]; +constant bool do_causal_vjp_dkv [[function_constant(301)]]; +constant bool has_block_mask_vjp_dkv [[function_constant(302)]]; +#endif + +/////////////////////////////////////////////////////////////////////////////// +// STEEL Attention VJP dKV Kernel +/////////////////////////////////////////////////////////////////////////////// + +// clang-format off +template < + typename T, + int BQ, // Query block size (16 or 32) + int BK, // KV block size (16 or 32) + int BD, // Head dimension (64, 96, 128) + int WM, // Warps in M dimension (1 or 2) + int WN, // Warps in N dimension (1) + typename AccumType = float> +[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] +void attention_vjp_dkv( + // Forward inputs + const device T* Q [[buffer(0)]], + const device T* K [[buffer(1)]], + const device T* V [[buffer(2)]], + const device float* delta [[buffer(3)]], + const device T* dO [[buffer(4)]], + const device float* LSE [[buffer(5)]], + // Gradient outputs (dK and dV) + device T* dK [[buffer(6)]], + device T* dV [[buffer(7)]], + // Parameters + const constant AttnVJPParams* params [[buffer(8)]], + // Sparse block mask (optional; only read when has_block_mask_vjp_dkv is true) + const device uint8_t* block_mask [[buffer(9)]], + 
// Thread info + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + // clang-format on + + (void)lid; + + // ========================================================================= + // Constants + // ========================================================================= + constexpr short kFragSize = 8; + using MMAFrag_acc_t = BaseMMAFrag; + + constexpr int kNWarps = WM * WN; + constexpr int TGP_SIZE = kNWarps * 32; + constexpr int TQ = BQ / (kNWarps * kFragSize); // BQ=32: 4 (WM=1) or 2 (WM=2); BQ=16, WM=2: 1 + constexpr int TK = BK / kFragSize; + constexpr int TD = BD / kFragSize; + + // ========================================================================= + // Simd coordinates + // WM=1: tm=0 always. WM=2: tm=0 (sg0) or TQ*8 (sg1). + // ========================================================================= + const short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id); + const short sm = simd_coord.y; + const short sn = simd_coord.x; + const short tm = kFragSize * TQ * simd_group_id; + + // ========================================================================= + // Thread/block IDs + // ========================================================================= + int kb = tid.x; + ulong3 tidl{tid.x, tid.y, tid.z}; + ulong kv_head_idx = int(tid.y); + + // ========================================================================= + // Shared memory layout + // ========================================================================= + constexpr short pad = 16 / sizeof(T); + constexpr short LDQ = BD + pad; // Q/dO row stride + constexpr short LDKt = BK + pad; // K^T/V^T row stride (transposed) + constexpr short LDT = BQ + pad; // P^T/dS^T row stride (scatter) + + // KV_smem: aliased for K^T, V^T, P^T scatter, dS^T scatter + constexpr int kv_s0 = BD * LDKt; // K^T/V^T (BD rows x LDKt cols) + constexpr int kv_s1 = BK * 
LDT; // P^T/dS^T (BK rows x LDT cols) + constexpr int kv_s = kv_s0 > kv_s1 ? kv_s0 : kv_s1; + + // Q_smem and dO_smem are used only during the iteration phase (Q-block loop). + // red_smem is used only during the post-loop reduction phase (WM>1). + // Since they are temporally disjoint, alias red_smem over the Q+dO region + // to reduce threadgroup memory (e.g. D=128: 23 KB → 15 KB, enabling 2 TGs/core). + constexpr int kQdO_elems = 2 * BQ * LDQ; + + constexpr int kRedTK = (TK <= 2) ? TK : TK / 2; + constexpr int kRedRows = kRedTK * kFragSize; + constexpr int kRedSize = (kNWarps > 1) ? kRedRows * BD : 1; + + static_assert( + kNWarps == 1 || kQdO_elems * sizeof(T) >= kRedSize * sizeof(AccumType), + "QdO smem region too small to alias with red_smem"); + + threadgroup T QdO_smem[kQdO_elems]; + threadgroup T* Q_smem = QdO_smem; + threadgroup T* dO_smem = QdO_smem + BQ * LDQ; + + threadgroup T KV_smem[kv_s]; + + // red_smem aliases over QdO_smem (safe: temporally disjoint with Q/dO usage). + // For WM=1, red_smem is never accessed (compiler eliminates the reduction block). 
+ threadgroup AccumType* red_smem = (threadgroup AccumType*)QdO_smem; + + // Smem offsets for fragment reads (each simdgroup reads its own Q rows) + const short Qs_off = (tm + sm) * LDQ + sn; + const short Kts_off = sm * LDKt + sn; + + // ========================================================================= + // K, V, dK, dV pointers (fixed for all Q iterations) + // ========================================================================= + const device T* K_block = K + tidl.z * params->K_strides[0] + + kv_head_idx * params->K_strides[1] + kb * BK * params->K_strides[2]; + + const device T* V_block = V + tidl.z * params->V_strides[0] + + kv_head_idx * params->V_strides[1] + kb * BK * params->V_strides[2]; + + device T* dK_block = dK + tidl.z * params->dK_strides[0] + + kv_head_idx * params->dK_strides[1] + kb * BK * params->dK_strides[2]; + + device T* dV_block = dV + tidl.z * params->dV_strides[0] + + kv_head_idx * params->dV_strides[1] + kb * BK * params->dV_strides[2]; + + // ========================================================================= + // Block loader types + // ========================================================================= + using QBlockLoader = BlockLoaderT; + using KtBlockLoader = BlockLoaderT; + + // ========================================================================= + // dK, dV accumulators — full [TK, TD] tiles (no D-column distribution) + // ========================================================================= + MMATile dKtile; + MMATile dVtile; + dKtile.clear(); + dVtile.clear(); + + // ========================================================================= + // Q block loop bounds (causal: skip Q-tiles fully below this K-tile) + // ========================================================================= + int qb_start = 0; + if (do_causal_vjp_dkv) { + int k_start = kb * BK; + qb_start = max(0, (k_start - params->qL_off) / BQ); + } + + // ========================================================================= 
+ // Main loop: iterate over GQA heads, then Q blocks + // ========================================================================= + const ulong q_head_start = kv_head_idx * VJP_GQA_FACTOR; + STEEL_PRAGMA_UNROLL + for (int gqa_idx = 0; gqa_idx < VJP_GQA_FACTOR; gqa_idx++) { + ulong q_head_idx = q_head_start + gqa_idx; + for (int qb = qb_start; qb < params->NQ; qb++) { + + // Block-sparse: skip Q-tiles where block_mask[qb][kb] == 0. + // All threads in a threadgroup share kb and qb, so this is a + // uniform branch — no warp divergence, just skips the iteration. + if (has_block_mask_vjp_dkv && !block_mask[qb * params->NK_tiles + kb]) { + continue; + } + + // Per-head pointers + const device T* Q_ptr = Q + tidl.z * params->Q_strides[0] + + q_head_idx * params->Q_strides[1] + + qb * BQ * params->Q_strides[2]; + + const device T* dO_ptr = dO + tidl.z * params->dO_strides[0] + + q_head_idx * params->dO_strides[1] + + qb * BQ * params->dO_strides[2]; + + // ======================================================================= + // Load Q, dO, K^T into shared memory + // ======================================================================= + QBlockLoader loader_q( + Q_ptr, params->Q_strides[2], Q_smem, simd_group_id, simd_lane_id); + QBlockLoader loader_do( + dO_ptr, params->dO_strides[2], dO_smem, simd_group_id, simd_lane_id); + KtBlockLoader loader_kt( + K_block, params->K_strides[2], KV_smem, simd_group_id, simd_lane_id); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (!align_Q_vjp_dkv && qb == params->NQ_aligned) { + loader_q.load_safe(short2(BD, params->qL_rem)); + loader_do.load_safe(short2(BD, params->qL_rem)); + } else { + loader_q.load_unsafe(); + loader_do.load_unsafe(); + } + + if (!align_K_vjp_dkv && kb == params->NK_aligned) { + loader_kt.load_safe(short2(BD, params->kL_rem)); + } else { + loader_kt.load_unsafe(); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // 
======================================================================= + // Hoist Q into register tile (each simdgroup loads its own rows). + // dO is loaded on-the-fly from smem to reduce register pressure. + // For D=128 WM=2, this saves ~64 registers/thread (TQ*TD*2 floats). + // ======================================================================= + MMATile Qreg; + Qreg.template load(&Q_smem[Qs_off]); + + // ======================================================================= + // S = Q @ K^T (unscaled) + // ======================================================================= + MMATile Stile; + MMATile Ktile; + Stile.clear(); + + STEEL_PRAGMA_UNROLL + for (short dd = 0; dd < TD; dd++) { + Ktile.template load( + &KV_smem[Kts_off + dd * kFragSize * LDKt]); + + STEEL_PRAGMA_UNROLL + for (short iq = 0; iq < TQ; iq++) { + STEEL_PRAGMA_UNROLL + for (short ik = 0; ik < TK; ik++) { + MMAFrag_acc_t::mma( + Stile.frag_at(iq, ik), + Qreg.frag_at(iq, dd), + Ktile.frag_at(0, ik), + Stile.frag_at(iq, ik)); + } + } + } + + // ======================================================================= + // Post-scale S *= scale_log2 (in float32) + // ======================================================================= + STEEL_PRAGMA_UNROLL + for (short ii = 0; ii < TQ * TK * 2; ii++) { + Stile.elems()[ii] *= VJP_SCALE_LOG2; + } + + // ======================================================================= + // K boundary mask (last K block) + // ======================================================================= + if (!align_K_vjp_dkv && kb == params->NK_aligned) { + constexpr AccumType neg_inf = -INFINITY; + + STEEL_PRAGMA_UNROLL + for (short iq = 0; iq < TQ; iq++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TK; j++) { + short col = sn + j * kFragSize; + if (col >= params->kL_rem) + Stile.frag_at(iq, j)[0] = neg_inf; + if ((col + 1) >= params->kL_rem) + Stile.frag_at(iq, j)[1] = neg_inf; + } + } + } + + // 
======================================================================= + // Q boundary mask (last Q block — ensures exp2 gives exact zeros) + // ======================================================================= + if (!align_Q_vjp_dkv && qb == params->NQ_aligned) { + constexpr AccumType neg_inf = -INFINITY; + + STEEL_PRAGMA_UNROLL + for (short iq = 0; iq < TQ; iq++) { + if ((tm + iq * kFragSize + sm) >= params->qL_rem) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TK; j++) { + Stile.frag_at(iq, j)[0] = neg_inf; + Stile.frag_at(iq, j)[1] = neg_inf; + } + } + } + } + + // ======================================================================= + // Causal mask + // ======================================================================= + if (do_causal_vjp_dkv) { + constexpr AccumType neg_inf = -INFINITY; + + STEEL_PRAGMA_UNROLL + for (short iq = 0; iq < TQ; iq++) { + int q_row = qb * BQ + params->qL_off + tm + iq * kFragSize + sm; + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TK; j++) { + int k_col = kb * BK + sn + j * kFragSize; + if (q_row < k_col) + Stile.frag_at(iq, j)[0] = neg_inf; + if (q_row < (k_col + 1)) + Stile.frag_at(iq, j)[1] = neg_inf; + } + } + } + + // ======================================================================= + // Read LSE and delta from device memory (no shared memory needed) + // ======================================================================= + const long lse_base = + (long)(tidl.z * params->H + q_head_idx) * params->LSE_strides[0]; + const long delta_base = + (long)(tidl.z * params->H + q_head_idx) * params->delta_strides[0]; + AccumType L_vals[TQ]; + AccumType delta_vals[TQ]; + + STEEL_PRAGMA_UNROLL + for (short iq = 0; iq < TQ; iq++) { + long q_row_idx = (long)qb * BQ + tm + iq * kFragSize + sm; + L_vals[iq] = (q_row_idx < params->qL) + ? LSE[lse_base + q_row_idx * params->LSE_strides[1]] + : AccumType(0); + delta_vals[iq] = (q_row_idx < params->qL) + ? 
delta[delta_base + q_row_idx * params->delta_strides[1]] + : AccumType(0); + } + + // ======================================================================= + // P = exp2(S - LSE) + // ======================================================================= + STEEL_PRAGMA_UNROLL + for (short iq = 0; iq < TQ; iq++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TK; j++) { + Stile.frag_at(iq, j)[0] = + fast::exp2(Stile.frag_at(iq, j)[0] - L_vals[iq]); + Stile.frag_at(iq, j)[1] = + fast::exp2(Stile.frag_at(iq, j)[1] - L_vals[iq]); + } + } + // Stile now holds P + + // ======================================================================= + // dV += P^T @ dO + // Step 1: Scatter P^T -> KV_smem[BK x BQ+pad] as type T + // ======================================================================= + threadgroup_barrier(mem_flags::mem_threadgroup); + + STEEL_PRAGMA_UNROLL + for (short iq = 0; iq < TQ; iq++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TK; j++) { + KV_smem[(j * kFragSize + sn) * LDT + tm + iq * kFragSize + sm] = + static_cast(Stile.frag_at(iq, j)[0]); + KV_smem[(j * kFragSize + sn + 1) * LDT + tm + iq * kFragSize + sm] = + static_cast(Stile.frag_at(iq, j)[1]); + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Step 2: Load Pt_tile[TK, TQ] from KV_smem + MMATile Pt_tile; + Pt_tile.template load( + &KV_smem[sm * LDT + tm + sn]); + + // Step 3: dV[TK, TD] += Pt[TK, TQ] @ dO[TQ, TD] + // dO loaded on-the-fly from smem (saves TQ*TD*2 = 64 regs for D=128). + // Reduction over TQ (=2 for WM=2), so only 1 dO fragment live at a time. 
+ STEEL_PRAGMA_UNROLL + for (short iq = 0; iq < TQ; iq++) { + STEEL_PRAGMA_UNROLL + for (short id = 0; id < TD; id++) { + typename MMAFrag_acc_t::frag_type dO_frag; + MMAFrag_acc_t::load( + dO_frag, + &dO_smem[(tm + iq * kFragSize + sm) * LDQ + + id * kFragSize + sn], + Int{}, + Int<1>{}); + + STEEL_PRAGMA_UNROLL + for (short ik = 0; ik < TK; ik++) { + MMAFrag_acc_t::mma( + dVtile.frag_at(ik, id), + Pt_tile.frag_at(ik, iq), + dO_frag, + dVtile.frag_at(ik, id)); + } + } + } + + // ======================================================================= + // dP = dO @ V^T + // Load V^T into KV_smem (aliased, overwrites P^T scatter) + // ======================================================================= + threadgroup_barrier(mem_flags::mem_threadgroup); + + KtBlockLoader loader_vt( + V_block, params->V_strides[2], KV_smem, simd_group_id, simd_lane_id); + if (!align_K_vjp_dkv && kb == params->NK_aligned) { + loader_vt.load_safe(short2(BD, params->kL_rem)); + } else { + loader_vt.load_unsafe(); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + MMATile dPtile; + dPtile.clear(); + + STEEL_PRAGMA_UNROLL + for (short dd = 0; dd < TD; dd++) { + // Reuse Ktile to load V^T row (same KV_smem layout as K^T) + Ktile.template load( + &KV_smem[Kts_off + dd * kFragSize * LDKt]); + + STEEL_PRAGMA_UNROLL + for (short iq = 0; iq < TQ; iq++) { + // Load dO fragment on-the-fly from smem + typename MMAFrag_acc_t::frag_type dO_frag; + MMAFrag_acc_t::load( + dO_frag, + &dO_smem[(tm + iq * kFragSize + sm) * LDQ + + dd * kFragSize + sn], + Int{}, + Int<1>{}); + + STEEL_PRAGMA_UNROLL + for (short ik = 0; ik < TK; ik++) { + MMAFrag_acc_t::mma( + dPtile.frag_at(iq, ik), + dO_frag, + Ktile.frag_at(0, ik), + dPtile.frag_at(iq, ik)); + } + } + } + + // ======================================================================= + // dS = scale * P * (dP - delta) + // Reuse Stile (which held P) — overwrite with dS + // ======================================================================= 
+ STEEL_PRAGMA_UNROLL + for (short iq = 0; iq < TQ; iq++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TK; j++) { + Stile.frag_at(iq, j)[0] = VJP_SCALE * Stile.frag_at(iq, j)[0] * + (dPtile.frag_at(iq, j)[0] - delta_vals[iq]); + Stile.frag_at(iq, j)[1] = VJP_SCALE * Stile.frag_at(iq, j)[1] * + (dPtile.frag_at(iq, j)[1] - delta_vals[iq]); + } + } + // Stile now holds dS + + // ======================================================================= + // dK += dS^T @ Q + // Step 1: Scatter dS^T -> KV_smem[BK x BQ+pad] as type T + // ======================================================================= + threadgroup_barrier(mem_flags::mem_threadgroup); + + STEEL_PRAGMA_UNROLL + for (short iq = 0; iq < TQ; iq++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TK; j++) { + KV_smem[(j * kFragSize + sn) * LDT + tm + iq * kFragSize + sm] = + static_cast(Stile.frag_at(iq, j)[0]); + KV_smem[(j * kFragSize + sn + 1) * LDT + tm + iq * kFragSize + sm] = + static_cast(Stile.frag_at(iq, j)[1]); + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Step 2: Load dSt_tile[TK, TQ] from KV_smem + MMATile dSt_tile; + dSt_tile.template load( + &KV_smem[sm * LDT + tm + sn]); + + // Step 3: dK[TK, TD] += dSt[TK, TQ] @ Q[TQ, TD] + tile_matmad(dKtile, dSt_tile, Qreg, dKtile); + + } // End Q block loop + } // End GQA loop + + // ========================================================================= + // Multi-warp reduction: sum partial dK/dV across simdgroups + // For WM=1 this block is eliminated by the compiler (kNWarps == 1). + // For WM=2: sg0 stores its partial to red_smem, sg1 reads and adds. + // Two phases: dV first, then dK (reusing the same red_smem buffer). + // NOTE: red_smem is aliased over QdO_smem. This barrier ensures all + // threads have finished reading Q_smem/dO_smem from the last iteration + // before red_smem writes begin overwriting that memory. 
+ // ========================================================================= + if constexpr (kNWarps > 1) { + threadgroup_barrier(mem_flags::mem_threadgroup); + constexpr int kChunks = TK / kRedTK; + + // Phase 1: Reduce dV (in chunks of kRedTK tile-rows) + for (int chunk = 0; chunk < kChunks; chunk++) { + int ik_base = chunk * kRedTK; + if (simd_group_id == 0) { + STEEL_PRAGMA_UNROLL + for (short ik = 0; ik < kRedTK; ik++) { + STEEL_PRAGMA_UNROLL + for (short id = 0; id < TD; id++) { + short row = ik * kFragSize + sm; + short col = id * kFragSize + sn; + red_smem[row * BD + col] = dVtile.frag_at(ik_base + ik, id)[0]; + red_smem[row * BD + col + 1] = dVtile.frag_at(ik_base + ik, id)[1]; + } + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (simd_group_id == kNWarps - 1) { + STEEL_PRAGMA_UNROLL + for (short ik = 0; ik < kRedTK; ik++) { + STEEL_PRAGMA_UNROLL + for (short id = 0; id < TD; id++) { + short row = ik * kFragSize + sm; + short col = id * kFragSize + sn; + dVtile.frag_at(ik_base + ik, id)[0] += red_smem[row * BD + col]; + dVtile.frag_at(ik_base + ik, id)[1] += red_smem[row * BD + col + 1]; + } + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + // Phase 2: Reduce dK (same chunked pattern) + for (int chunk = 0; chunk < kChunks; chunk++) { + int ik_base = chunk * kRedTK; + if (simd_group_id == 0) { + STEEL_PRAGMA_UNROLL + for (short ik = 0; ik < kRedTK; ik++) { + STEEL_PRAGMA_UNROLL + for (short id = 0; id < TD; id++) { + short row = ik * kFragSize + sm; + short col = id * kFragSize + sn; + red_smem[row * BD + col] = dKtile.frag_at(ik_base + ik, id)[0]; + red_smem[row * BD + col + 1] = dKtile.frag_at(ik_base + ik, id)[1]; + } + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (simd_group_id == kNWarps - 1) { + STEEL_PRAGMA_UNROLL + for (short ik = 0; ik < kRedTK; ik++) { + STEEL_PRAGMA_UNROLL + for (short id = 0; id < TD; id++) { + short row = ik * kFragSize + sm; + short col = id * kFragSize + sn; + 
dKtile.frag_at(ik_base + ik, id)[0] += red_smem[row * BD + col]; + dKtile.frag_at(ik_base + ik, id)[1] += red_smem[row * BD + col + 1]; + } + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + } + + // ========================================================================= + // Write dK and dV to device memory + // For WM>1, only the last simdgroup writes (it holds the reduced result). + // ========================================================================= + if (kNWarps == 1 || simd_group_id == (kNWarps - 1)) { + dV_block += sm * (long)params->dV_strides[2] + sn; + dK_block += sm * (long)params->dK_strides[2] + sn; + + if (!align_K_vjp_dkv && kb == params->NK_aligned) { + auto dims = short2((short)(BD - sn), (short)(params->kL_rem - sm)); + if (dims.x > 0 && dims.y > 0) { + dVtile.template store_safe( + dV_block, (int)params->dV_strides[2], dims); + dKtile.template store_safe( + dK_block, (int)params->dK_strides[2], dims); + } + } else { + dVtile.template store(dV_block, (int)params->dV_strides[2]); + dKtile.template store(dK_block, (int)params->dK_strides[2]); + } + } +} + +/////////////////////////////////////////////////////////////////////////////// +// Template instantiation macros +/////////////////////////////////////////////////////////////////////////////// + +// WM=1 kernel names (backward compatible, no _wm suffix) +#define instantiate_attention_vjp_dkv_kernel(tname, dtype, bq, bk, bd, wm, wn) \ + template [[host_name( \ + "attention_vjp_dkv_" #tname "_" #bq "_" #bk "_" #bd)]] [[kernel]] void \ + attention_vjp_dkv( \ + const device dtype*, \ + const device dtype*, \ + const device dtype*, \ + const device float*, \ + const device dtype*, \ + const device float*, \ + device dtype*, \ + device dtype*, \ + const constant AttnVJPParams*, \ + const device uint8_t*, \ + uint, \ + uint, \ + uint3, \ + uint3); + +// WM>1 kernel names (includes _wmN suffix for disambiguation) +#define instantiate_attention_vjp_dkv_kernel_wm(tname, dtype, bq, 
bk, bd, wm, wn) \ + template [[host_name( \ + "attention_vjp_dkv_" #tname "_" #bq "_" #bk "_" #bd "_wm" #wm)]] \ + [[kernel]] void attention_vjp_dkv( \ + const device dtype*, \ + const device dtype*, \ + const device dtype*, \ + const device float*, \ + const device dtype*, \ + const device float*, \ + device dtype*, \ + device dtype*, \ + const constant AttnVJPParams*, \ + const device uint8_t*, \ + uint, \ + uint, \ + uint3, \ + uint3); + +// WM=1, WN=1 for all configurations (single simdgroup, MFA-aligned) +// D=64: BK=32 (~14KB smem), D=96/128: BK=16 (~11/14KB smem) +// dKV dispatch always uses BK=16 for D>64 (higher per-thread register pressure +// at BK=32 outweighs benefit of fewer KV iterations for WM=1 kernel). +#define instantiate_attention_vjp_dkv_bd64(tname, dtype) \ + instantiate_attention_vjp_dkv_kernel(tname, dtype, 32, 32, 64, 1, 1) + +#define instantiate_attention_vjp_dkv_bd96(tname, dtype) \ + instantiate_attention_vjp_dkv_kernel(tname, dtype, 32, 16, 96, 1, 1) + +#define instantiate_attention_vjp_dkv_bd128(tname, dtype) \ + instantiate_attention_vjp_dkv_kernel(tname, dtype, 32, 16, 128, 1, 1) + +// WM=2 variants for D>=96 (reduced register pressure: ~280 vs ~428 regs/thread) +#define instantiate_attention_vjp_dkv_bd96_wm2(tname, dtype) \ + instantiate_attention_vjp_dkv_kernel_wm(tname, dtype, 32, 16, 96, 2, 1) + +#define instantiate_attention_vjp_dkv_bd128_wm2(tname, dtype) \ + instantiate_attention_vjp_dkv_kernel_wm(tname, dtype, 32, 16, 128, 2, 1) + +// BQ=16 WM=2 for D=128: TQ=1 per simdgroup → ~202 regs/thread (no spilling!) +// Smem: max(Q+dO(8704), red(8192)) + KV(6144) = 14848 bytes (2 TGs/core) +// Q/dO and red_smem are aliased (temporally disjoint: iterations vs reduction). +// Trade-off: 2x more Q-tile iterations vs BQ=32, but each runs spill-free. 
+#define instantiate_attention_vjp_dkv_bd128_bq16_wm2(tname, dtype) \ + instantiate_attention_vjp_dkv_kernel_wm(tname, dtype, 16, 16, 128, 2, 1) + +#define instantiate_attention_vjp_dkv_all(tname, dtype) \ + instantiate_attention_vjp_dkv_bd64(tname, dtype) \ + instantiate_attention_vjp_dkv_bd96(tname, dtype) \ + instantiate_attention_vjp_dkv_bd128(tname, dtype) \ + instantiate_attention_vjp_dkv_bd96_wm2(tname, dtype) \ + instantiate_attention_vjp_dkv_bd128_wm2(tname, dtype) \ + instantiate_attention_vjp_dkv_bd128_bq16_wm2(tname, dtype) + +#ifdef VJP_UNDEF_DEFINES + #undef VJP_GQA_FACTOR + #undef VJP_SCALE + #undef VJP_SCALE_LOG2 + #undef VJP_UNDEF_DEFINES +#endif diff --git a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_vjp_dkv.metal b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_vjp_dkv.metal new file mode 100644 index 0000000000..fa2afbce36 --- /dev/null +++ b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_vjp_dkv.metal @@ -0,0 +1,11 @@ +// Copyright © 2024-25 Apple Inc. + +// clang-format off +#include "mlx/backend/metal/kernels/utils.h" +#include "mlx/backend/metal/kernels/steel/attn/attn.h" + +#include "mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_vjp_dkv.h" + +instantiate_attention_vjp_dkv_all(float16, half); +instantiate_attention_vjp_dkv_all(bfloat16, bfloat16_t); +// clang-format on diff --git a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_vjp_dq.h b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_vjp_dq.h new file mode 100644 index 0000000000..aaa9d24a68 --- /dev/null +++ b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_vjp_dq.h @@ -0,0 +1,468 @@ +// Copyright © 2024-25 Apple Inc. + +#pragma once + +/////////////////////////////////////////////////////////////////////////////// +// STEEL VJP dQ Kernel for Scaled Dot-Product Attention +// +// Part of the two-kernel backward pass optimization that eliminates atomic +// operations. 
This kernel computes ONLY dQ gradients. +// +// Grid: [NQ, H, B] - one threadgroup per (query_block, head, batch) +// Loop: Over KV blocks to accumulate dQ +// +// Algorithm (MFA-aligned, log2 domain): +// S = Q @ K^T (unscaled) +// S *= scale_log2 (post-scale in float32) +// P = exp2(S - LSE) +// dP = dO @ V^T +// dS = scale * P * (dP - delta) (scale baked into dS) +// dQ += dS @ K (no write-back scale needed) +// +// Architecture (MFA-aligned): +// - Aliased KV_smem: single buffer for K^T, V^T, K row-major (3 phases) +// - Q/dO hoisted to register tiles after initial smem load +// - 3-phase per KV iteration: (1) K_t→S, (2) V_t→dP, (3) K_r→dQ +// - LSE/delta read from device memory +// - Per-fragment MMA for S and dP +// +// See companion kernel steel_attention_vjp_dkv.h for dK/dV computation. +/////////////////////////////////////////////////////////////////////////////// + +// JIT-baked constants: when compiled via JIT, these are #defined as literals +// before this header is included. For metallib builds, they fall back to +// params-> reads at runtime (values) or function constants (booleans). +#ifndef VJP_GQA_FACTOR + #define VJP_GQA_FACTOR (params->gqa_factor) + #define VJP_SCALE (params->scale) + #define VJP_SCALE_LOG2 (params->scale_log2) + #define VJP_UNDEF_DEFINES +#endif + +using namespace mlx::steel; + +// When JIT-compiled, align/causal flags are baked as constexpr booleans, +// enabling full dead-code elimination. Metallib builds use function constants. 
+#ifdef VJP_BAKED_FC +constexpr constant bool align_Q_vjp_dq = VJP_ALIGN_Q; +constexpr constant bool align_K_vjp_dq = VJP_ALIGN_K; +constexpr constant bool do_causal_vjp_dq = VJP_DO_CAUSAL; +constexpr constant bool has_block_mask_vjp_dq = VJP_HAS_BLOCK_MASK; +#else +constant bool align_Q_vjp_dq [[function_constant(200)]]; +constant bool align_K_vjp_dq [[function_constant(201)]]; +constant bool do_causal_vjp_dq [[function_constant(301)]]; +constant bool has_block_mask_vjp_dq [[function_constant(302)]]; +#endif + +/////////////////////////////////////////////////////////////////////////////// +// STEEL Attention VJP dQ Kernel +/////////////////////////////////////////////////////////////////////////////// + +// clang-format off +template < + typename T, + int BQ, // Query block size (32) + int BK, // KV block size (16) + int BD, // Head dimension (64, 96, 128) + int WM, // Warps in M dimension (4) + int WN, // Warps in N dimension (1) + typename AccumType = float> +[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] +void attention_vjp_dq( + // Forward inputs + const device T* Q [[buffer(0)]], + const device T* K [[buffer(1)]], + const device T* V [[buffer(2)]], + const device float* delta [[buffer(3)]], + const device T* dO [[buffer(4)]], + const device float* LSE [[buffer(5)]], + // Gradient output (dQ only) + device T* dQ [[buffer(6)]], + // Parameters + const constant AttnVJPParams* params [[buffer(7)]], + // Sparse block mask (optional, gated by has_block_mask function constant) + const device uint8_t* block_mask [[buffer(8)]], + // Thread info + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + // clang-format on + + (void)lid; + + ulong3 tidl{tid.x, tid.y, tid.z}; + + // Input pointer setup + const device T* Q_block = Q + + tidl.z * params->Q_strides[0] + + tidl.y * params->Q_strides[1] + + tidl.x * 
BQ * params->Q_strides[2]; + + ulong kv_head_idx = int(tid.y) / VJP_GQA_FACTOR; + const device T* K_base = K + + tidl.z * params->K_strides[0] + + kv_head_idx * params->K_strides[1]; + + const device T* V_base = V + + tidl.z * params->V_strides[0] + + kv_head_idx * params->V_strides[1]; + + const device T* dO_block = dO + + tidl.z * params->dO_strides[0] + + tidl.y * params->dO_strides[1] + + tidl.x * BQ * params->dO_strides[2]; + + device T* dQ_block = dQ + + tidl.z * params->dQ_strides[0] + + tidl.y * params->dQ_strides[1] + + tidl.x * BQ * params->dQ_strides[2]; + + // ========================================================================= + // Threadgroup memory setup (MFA-aligned) + // KV_smem aliased for K^T, V^T, K row-major (3 separate load phases) + // ========================================================================= + constexpr short pad = 16 / sizeof(T); + constexpr short LDQ = BD + pad; // Q/dO row stride in smem + constexpr short LDKt = BK + pad; // K^T/V^T row stride (transposed) + constexpr short LDKr = BD + pad; // K row-major row stride + + // KV_smem aliased for K^T, V^T, K row-major + constexpr int kv_s0 = BD * LDKt; // K^T or V^T size + constexpr int kv_s1 = BK * LDKr; // K row-major size + constexpr int kv_s = kv_s0 > kv_s1 ? 
kv_s0 : kv_s1; + + threadgroup T Q_smem[BQ * LDQ]; // staging only + threadgroup T dO_smem[BQ * LDQ]; // staging only + threadgroup T KV_smem[kv_s]; // aliased buffer + + // Block loaders + using QBlockLoader = BlockLoaderT; + using KtBlockLoader = BlockLoaderT; + using KrBlockLoader = BlockLoaderT; + + QBlockLoader loader_q(Q_block, params->Q_strides[2], Q_smem, simd_group_id, simd_lane_id); + QBlockLoader loader_do(dO_block, params->dO_strides[2], dO_smem, simd_group_id, simd_lane_id); + KtBlockLoader loader_kt(K_base, params->K_strides[2], KV_smem, simd_group_id, simd_lane_id); + KtBlockLoader loader_vt(V_base, params->V_strides[2], KV_smem, simd_group_id, simd_lane_id); + KrBlockLoader loader_kr(K_base, params->K_strides[2], KV_smem, simd_group_id, simd_lane_id); + + // MMA setup + constexpr short kFragSize = 8; + using MMAFrag_acc_t = BaseMMAFrag; + + constexpr int kNWarps = WM * WN; + constexpr int TQ = BQ / (kNWarps * kFragSize); + constexpr int TK = BK / kFragSize; + constexpr int TD = BD / kFragSize; + + static_assert(TQ == 1, "TQ must be 1"); + + // Register tiles + MMATile Qtile; // Q hoisted from smem + MMATile dOtile; // dO hoisted from smem + MMATile dQtile; // accumulator + MMATile Stile; // S / P / dS + MMATile dPtile; // dP + MMATile Ktile; // K fragment per dd + MMATile KRtile; // K row-major fragment + + // Coordinates + const short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id); + const short sm = simd_coord.y; + const short sn = simd_coord.x; + const short tm = kFragSize * TQ * simd_group_id; + + // Smem offsets + const short Qs_off = (tm + sm) * LDQ + sn; // Q/dO read offset + const short Kts_off = sm * LDKt + sn; // K^T/V^T read offset + const short KRs_off = sm * LDKr + sn; // K row-major read offset + + // ========================================================================= + // Pre-loop: Load Q/dO → hoist to registers → read LSE/delta + // ========================================================================= + if 
(!align_Q_vjp_dq && int(tid.x) == params->NQ_aligned) { + loader_q.load_safe(short2(BD, params->qL_rem)); + loader_do.load_safe(short2(BD, params->qL_rem)); + } else { + loader_q.load_unsafe(); + loader_do.load_unsafe(); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Hoist Q and dO to register tiles + Qtile.template load(&Q_smem[Qs_off]); + dOtile.template load(&dO_smem[Qs_off]); + + dQtile.clear(); + + // Read LSE and delta from device memory (scalar per row) + const long lse_base = (long)(tidl.z * params->H + tidl.y) * params->LSE_strides[0]; + const long delta_base = (long)(tidl.z * params->H + tidl.y) * params->delta_strides[0]; + const long q_row_idx = (long)tid.x * BQ + tm + sm; + const AccumType L_val = (q_row_idx < params->qL) + ? LSE[lse_base + q_row_idx * params->LSE_strides[1]] : AccumType(0); + const AccumType delta_val = (q_row_idx < params->qL) + ? delta[delta_base + q_row_idx * params->delta_strides[1]] : AccumType(0); + + // KV loop bounds + int kb_lim = params->NK; + if (do_causal_vjp_dq) { + int q_max = (tid.x + 1) * BQ + params->qL_off; + kb_lim = (q_max + BK - 1) / BK; + kb_lim = min(params->NK, kb_lim); + } + + // ========================================================================= + // Main loop over KV blocks (3-phase per iteration) + // ========================================================================= + const int qb = tid.x; // Q-block index for block_mask lookup + + for (int kb = 0; kb < kb_lim; kb++) { + + // Block-sparse: skip K-tiles where block_mask[qb][kb] == 0. + // All threads in a threadgroup share tid.x and kb, so this is a + // uniform branch — no warp divergence, just skips the barriers and math. 
+ if (has_block_mask_vjp_dq && !block_mask[qb * params->NK_tiles + kb]) { + loader_kt.next(); + loader_vt.next(); + loader_kr.next(); + continue; + } + + // ===================================================================== + // Phase 1: K^T → S = Q @ K^T + // ===================================================================== + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (!align_K_vjp_dq && kb == params->NK_aligned) { + loader_kt.load_safe(short2(BD, params->kL_rem)); + } else { + loader_kt.load_unsafe(); + } + + Stile.clear(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + STEEL_PRAGMA_UNROLL + for (short dd = 0; dd < TD; dd++) { + Ktile.template load(&KV_smem[Kts_off + dd * kFragSize * LDKt]); + + STEEL_PRAGMA_UNROLL + for (short ik = 0; ik < TK; ik++) { + MMAFrag_acc_t::mma( + Stile.frag_at(0, ik), + Qtile.frag_at(0, dd), + Ktile.frag_at(0, ik), + Stile.frag_at(0, ik)); + } + } + + // Post-scale S to log2 domain: S *= scale_log2 (in float32) + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TK; j++) { + Stile.frag_at(0, j)[0] *= VJP_SCALE_LOG2; + Stile.frag_at(0, j)[1] *= VJP_SCALE_LOG2; + } + + // Apply sequence length mask + if (!align_K_vjp_dq && kb == params->NK_aligned) { + using stile_t = decltype(Stile); + constexpr AccumType neg_inf = -INFINITY; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < stile_t::kTileRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < stile_t::kTileCols; j++) { + short col_pos = sn + (j * stile_t::kFragCols); + STEEL_PRAGMA_UNROLL + for (short jj = 0; jj < stile_t::MMAFrag_t::kElemCols; jj++) { + if ((col_pos + jj) >= params->kL_rem) { + Stile.frag_at(i, j)[jj] = neg_inf; + } + } + } + } + } + + // Apply causal mask + if (do_causal_vjp_dq && kb >= (kb_lim - ((BQ + BK - 1) / BK) - int(!align_K_vjp_dq))) { + using stile_t = decltype(Stile); + constexpr AccumType neg_inf = -INFINITY; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < stile_t::kTileRows; i++) { + const int row_pos = tid.x * BQ + 
params->qL_off + tm + sm + (i * stile_t::kFragRows); + STEEL_PRAGMA_UNROLL + for (short j = 0; j < stile_t::kTileCols; j++) { + const int col_pos = kb * BK + sn + (j * stile_t::kFragCols); + STEEL_PRAGMA_UNROLL + for (short jj = 0; jj < stile_t::MMAFrag_t::kElemCols; jj++) { + if (row_pos < (col_pos + jj)) { + Stile.frag_at(i, j)[jj] = neg_inf; + } + } + } + } + } + + // ===================================================================== + // Phase 2: V^T → dP = dO @ V^T + // ===================================================================== + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (!align_K_vjp_dq && kb == params->NK_aligned) { + loader_vt.load_safe(short2(BD, params->kL_rem)); + } else { + loader_vt.load_unsafe(); + } + + dPtile.clear(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + STEEL_PRAGMA_UNROLL + for (short dd = 0; dd < TD; dd++) { + Ktile.template load(&KV_smem[Kts_off + dd * kFragSize * LDKt]); + + STEEL_PRAGMA_UNROLL + for (short ik = 0; ik < TK; ik++) { + MMAFrag_acc_t::mma( + dPtile.frag_at(0, ik), + dOtile.frag_at(0, dd), + Ktile.frag_at(0, ik), + dPtile.frag_at(0, ik)); + } + } + + // ===================================================================== + // Softmax + dS (no barrier needed — purely register operations) + // ===================================================================== + + // P = exp2(S - LSE) + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TK; j++) { + Stile.frag_at(0, j)[0] = fast::exp2(Stile.frag_at(0, j)[0] - L_val); + Stile.frag_at(0, j)[1] = fast::exp2(Stile.frag_at(0, j)[1] - L_val); + } + + // dS = scale * P * (dP - delta) + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TK; j++) { + Stile.frag_at(0, j)[0] = VJP_SCALE * Stile.frag_at(0, j)[0] * (dPtile.frag_at(0, j)[0] - delta_val); + Stile.frag_at(0, j)[1] = VJP_SCALE * Stile.frag_at(0, j)[1] * (dPtile.frag_at(0, j)[1] - delta_val); + } + // Stile now holds dS + + // ===================================================================== + 
// Phase 3: K row-major → dQ += dS @ K + // ===================================================================== + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (!align_K_vjp_dq && kb == params->NK_aligned) { + loader_kr.load_safe(short2(BD, params->kL_rem)); + } else { + loader_kr.load_unsafe(); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + STEEL_PRAGMA_UNROLL + for (short iq = 0; iq < TQ; iq++) { + STEEL_PRAGMA_UNROLL + for (short ik = 0; ik < TK; ik++) { + STEEL_PRAGMA_UNROLL + for (short id = 0; id < TD; id++) { + KRtile.template load( + &KV_smem[KRs_off + ik * kFragSize * LDKr + id * kFragSize]); + MMAFrag_acc_t::mma( + dQtile.frag_at(iq, id), + Stile.frag_at(iq, ik), + KRtile.frag_at(0, 0), + dQtile.frag_at(iq, id)); + } + } + } + + loader_kt.next(); + loader_vt.next(); + loader_kr.next(); + } + + // ========================================================================= + // Write dQ output — no scale needed (scale already baked into dS) + // ========================================================================= + threadgroup_barrier(mem_flags::mem_none); + + dQ_block += (tm + sm) * params->dQ_strides[2] + sn; + + if (!align_Q_vjp_dq && int(tid.x) == params->NQ_aligned) { + auto dst_tile_dims = short2(BD - sn, params->qL_rem - (tm + sm)); + if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) + return; + dQtile.template store_safe(dQ_block, params->dQ_strides[2], dst_tile_dims); + } else { + dQtile.template store(dQ_block, params->dQ_strides[2]); + } +} + +/////////////////////////////////////////////////////////////////////////////// +// Template instantiation macros +/////////////////////////////////////////////////////////////////////////////// + +// tname is the string name used in kernel lookup (e.g., "float16", "bfloat16") +// dtype is the actual C++ type (e.g., half, bfloat16_t) +#define instantiate_attention_vjp_dq_kernel(tname, dtype, bq, bk, bd, wm, wn) \ + template [[host_name("attention_vjp_dq_" #tname "_" #bq 
"_" #bk "_" #bd)]] \ + [[kernel]] void attention_vjp_dq( \ + const device dtype*, \ + const device dtype*, \ + const device dtype*, \ + const device float*, \ + const device dtype*, \ + const device float*, \ + device dtype*, \ + const constant AttnVJPParams*, \ + const device uint8_t*, \ + uint, \ + uint, \ + uint3, \ + uint3); + +// Threadgroup memory per config (2-byte T, BQ=32, padded rows; KV buffer aliased): +// D=64: BK=32, Q+KV(aliased)+dO = ~14KB +// D=96: BK=16, Q+KV(aliased)+dO = ~17.5KB +// D=128: BK=16, Q+KV(aliased)+dO = ~23KB +#define instantiate_attention_vjp_dq_bd64(tname, dtype) \ + instantiate_attention_vjp_dq_kernel(tname, dtype, 32, 32, 64, 4, 1) + +#define instantiate_attention_vjp_dq_bd96(tname, dtype) \ + instantiate_attention_vjp_dq_kernel(tname, dtype, 32, 16, 96, 4, 1) + +// BK=32 for D=96 on M3+ (halves KV iterations, fits in ~20.5KB smem) +#define instantiate_attention_vjp_dq_bd96_bk32(tname, dtype) \ + instantiate_attention_vjp_dq_kernel(tname, dtype, 32, 32, 96, 4, 1) + +#define instantiate_attention_vjp_dq_bd128(tname, dtype) \ + instantiate_attention_vjp_dq_kernel(tname, dtype, 32, 16, 128, 4, 1) + +// BK=32 for D=128 on M3+ (halves KV iterations, fits in ~27KB smem) +#define instantiate_attention_vjp_dq_bd128_bk32(tname, dtype) \ + instantiate_attention_vjp_dq_kernel(tname, dtype, 32, 32, 128, 4, 1) + +#define instantiate_attention_vjp_dq_all(tname, dtype) \ + instantiate_attention_vjp_dq_bd64(tname, dtype) \ + instantiate_attention_vjp_dq_bd96(tname, dtype) \ + instantiate_attention_vjp_dq_bd96_bk32(tname, dtype) \ + instantiate_attention_vjp_dq_bd128(tname, dtype) \ + instantiate_attention_vjp_dq_bd128_bk32(tname, dtype) + +#ifdef VJP_UNDEF_DEFINES + #undef VJP_GQA_FACTOR + #undef VJP_SCALE + #undef VJP_SCALE_LOG2 + #undef VJP_UNDEF_DEFINES +#endif diff --git a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_vjp_dq.metal b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_vjp_dq.metal new file mode 100644 index 0000000000..1e2114ef5a --- 
/dev/null +++ b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_vjp_dq.metal @@ -0,0 +1,11 @@ +// Copyright © 2024-25 Apple Inc. + +// clang-format off +#include "mlx/backend/metal/kernels/utils.h" +#include "mlx/backend/metal/kernels/steel/attn/attn.h" + +#include "mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_vjp_dq.h" + +instantiate_attention_vjp_dq_all(float16, half); +instantiate_attention_vjp_dq_all(bfloat16, bfloat16_t); +// clang-format on diff --git a/mlx/backend/metal/kernels/steel/attn/loader.h b/mlx/backend/metal/kernels/steel/attn/loader.h index 7ec798146b..a43f039168 100644 --- a/mlx/backend/metal/kernels/steel/attn/loader.h +++ b/mlx/backend/metal/kernels/steel/attn/loader.h @@ -197,11 +197,41 @@ struct BlockLoaderT { /* Load from device memory into threadgroup memory - without bound checking */ METAL_FUNC void load_unsafe() const { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < BROWS; i += TROWS) { + // Use vec vectorized loads when destination is contiguous (row-major) + // and we have enough elements for 4-wide loads. This matches MFA's + // MFABlockLoaderT optimization for Q, dO, and K row-major loaders. + constexpr bool can_vectorize = (kDstStrCol == 1) && (vec_size % 4 == 0); + if constexpr (can_vectorize) { + // Runtime alignment check: vec4 cast requires src_ld to be a multiple + // of 4 so that src + i * src_ld + j stays 4-element aligned. This holds + // for contiguous tensors (src_ld = D ∈ {64, 96, 128}) but may not for + // sliced tensors that only satisfy is_matrix_contiguous (strides[-1]==1). 
+ if (src_ld % 4 == 0) { + using vec4_t = vec; + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j += 4) { + *(threadgroup vec4_t*)(dst + i * kDstStrRow + j) = + *(const device vec4_t*)(src + i * src_ld + j); + } + } + } else { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * kDstStrRow + j] = src[i * src_ld + j]; + } + } + } + } else { STEEL_PRAGMA_UNROLL - for (short j = 0; j < vec_size; j++) { - dst[i * kDstStrRow + j * kDstStrCol] = src[i * src_ld + j]; + for (short i = 0; i < BROWS; i += TROWS) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * kDstStrRow + j * kDstStrCol] = src[i * src_ld + j]; + } } } } diff --git a/mlx/backend/metal/kernels/steel/attn/params.h b/mlx/backend/metal/kernels/steel/attn/params.h index f1cf09fada..1aedc3b86c 100644 --- a/mlx/backend/metal/kernels/steel/attn/params.h +++ b/mlx/backend/metal/kernels/steel/attn/params.h @@ -34,11 +34,50 @@ struct AttnParams { int64_t K_strides[3]; ///< Key strides (B, H, L, D = 1) int64_t V_strides[3]; ///< Value strides (B, H, L, D = 1) int64_t O_strides[3]; ///< Output strides (B, H, L, D = 1) + int64_t LSE_strides[2]; ///< LSE strides (B*H, L) - logsumexp output for VJP }; struct AttnMaskParams { int64_t M_strides[3]; ///< Mask strides (B, H, qL, kL = 1) }; +struct AttnVJPParams { + int B; ///< Batch Size + int H; ///< Heads (query heads) + int D; ///< Head Dim + + int qL; ///< Query Sequence Length + int kL; ///< Key Sequence Length + + int gqa_factor; ///< Group Query factor + float scale; ///< Attention scale + float scale_log2; ///< scale * log2(e) for log2-domain scoring + + int NQ; ///< Number of query blocks + int NK; ///< Number of key/value blocks + + int NQ_aligned; ///< Number of full query blocks + int NK_aligned; ///< Number of full key/value blocks + + int qL_rem; ///< Remainder in last query block 
+ int kL_rem; ///< Remainder in last key/value block + int qL_off; ///< Offset in query sequence start + + int64_t Q_strides[3]; ///< Query strides (B, H, L, D = 1) + int64_t K_strides[3]; ///< Key strides (B, H, L, D = 1) + int64_t V_strides[3]; ///< Value strides (B, H, L, D = 1) + int64_t dO_strides[3]; ///< dO strides (B, H, L, D = 1) + int64_t LSE_strides[2]; ///< LSE strides (B*H, L) - logsumexp + int64_t delta_strides[2]; ///< delta strides (B*H, L) - rowsum(dO*O) + + // VJP-specific output strides + int64_t dQ_strides[3]; ///< dQ strides (B, H, L, D = 1) + int64_t dK_strides[3]; ///< dK strides (B, H, L, D = 1) + int64_t dV_strides[3]; ///< dV strides (B, H, L, D = 1) + + // Sparse block mask support + int NK_tiles; ///< Number of K-tile columns (for block_mask indexing) +}; + } // namespace steel } // namespace mlx diff --git a/mlx/backend/metal/nojit_kernels.cpp b/mlx/backend/metal/nojit_kernels.cpp index a0b02084c2..5d11690124 100644 --- a/mlx/backend/metal/nojit_kernels.cpp +++ b/mlx/backend/metal/nojit_kernels.cpp @@ -400,12 +400,34 @@ MTL::ComputePipelineState* get_steel_attention_kernel( const std::string& hash_name, const metal::MTLFCList& func_consts, const array&, - int, - int, - int, - int, - int, - const array&) { + int, int, int, int, int, + const array&, + int, float, + bool, bool, bool, bool, bool, bool) { + return d.get_kernel(kernel_name, hash_name, func_consts); +} + +MTL::ComputePipelineState* get_steel_attention_vjp_dq_kernel( + metal::Device& d, + const std::string& kernel_name, + const std::string& hash_name, + const metal::MTLFCList& func_consts, + const array&, + int, int, int, int, int, + int, float, float, + bool, bool, bool, bool) { + return d.get_kernel(kernel_name, hash_name, func_consts); +} + +MTL::ComputePipelineState* get_steel_attention_vjp_dkv_kernel( + metal::Device& d, + const std::string& kernel_name, + const std::string& hash_name, + const metal::MTLFCList& func_consts, + const array&, + int, int, int, int, int, + int, 
float, float, + bool, bool, bool, bool) { return d.get_kernel(kernel_name, hash_name, func_consts); } diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index 37e554f183..962770ccad 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -1,5 +1,9 @@ // Copyright © 2024 Apple Inc. +#include +#include #include +#include +#include #include "mlx/backend/common/compiled.h" #include "mlx/backend/gpu/copy.h" @@ -9,12 +13,67 @@ #include "mlx/backend/metal/kernels/steel/attn/params.h" #include "mlx/backend/metal/utils.h" #include "mlx/fast_primitives.h" +#include "mlx/ops.h" #include "mlx/utils.h" namespace mlx::core::fast { namespace { +// Shared eligibility check for STEEL fused VJP kernels. +// Must be consistent between use_fallback() and eval_gpu() to avoid UB. +inline bool steel_vjp_eligible(int head_dim, Dtype dtype) { + return (head_dim == 64 || head_dim == 96 || head_dim == 128) && + (dtype == float16 || dtype == bfloat16); +} + +// Copy predicates shared between forward and VJP eval_gpu methods. + +// Returns true if the array's last dimension has stride 1. +bool is_matrix_contiguous(const array& arr) { + return arr.strides(-1) == 1; +} + +// Returns true if Q doesn't need a contiguous copy for the vector path. +// Allows row-contiguous or transposed layouts where batch/head dims are +// interchangeable with sequence when one is a singleton. +bool q_is_vector_compatible(const array& arr) { + if (arr.flags().row_contiguous) { + return true; + } + auto& strides = arr.strides(); + auto& shape = arr.shape(); + if (shape[0] == 1 || shape[1] == 1) { + auto bidx = shape[0] == 1 ? 1 : 0; + return (strides[3] == 1) && (strides[2] == shape[3] * shape[bidx]) && + (strides[bidx] == shape[3]); + } + return false; +} + +// Returns true if K/V doesn't need a contiguous copy for the vector path. 
+// Requires last dim stride=1 and contiguous batch/head dimensions. +bool kv_is_vector_compatible(const array& arr) { + auto& strides = arr.strides(); + auto& shape = arr.shape(); + if (strides.back() != 1) { + return false; + } + if (shape[0] == 1 || shape[1] == 1) { + return true; + } + return (strides[0] == strides[1] * shape[1]); +} + +// Returns true if mask doesn't need a contiguous copy. +// Checks row-contiguity or batch/head dimension compatibility. +bool mask_is_compatible(const array& q, const array& arr) { + auto& strides = arr.strides(); + auto& shape = arr.shape(); + return arr.flags().row_contiguous || q.shape(0) == 1 || q.shape(1) == 1 || + (strides[0] == strides[1] * shape[1]); +} + void sdpa_full_self_attention_nax( const Stream& s, metal::Device& d, @@ -112,6 +171,7 @@ void sdpa_full_self_attention_nax( const int NQ_aligned = qL / bq; const int NK_aligned = kL / bk; + // NAX doesn't support logsumexp output - provide dummy strides AttnParams params{ /* int B = */ B, /* int H = */ H, @@ -136,7 +196,8 @@ void sdpa_full_self_attention_nax( /* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)}, /* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)}, /* int64_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)}, - /* int64_t O_strides[3] = */ {o.strides(0), o.strides(1), o.strides(2)}}; + /* int64_t O_strides[3] = */ {o.strides(0), o.strides(1), o.strides(2)}, + /* int64_t LSE_strides[2] = */ {0, 0}}; compute_encoder.set_input_array(q, 0); compute_encoder.set_input_array(k, 1); @@ -173,9 +234,12 @@ void sdpa_full_self_attention_metal( array& o, bool do_causal_, const std::optional& mask, - const std::optional& sinks) { + const std::optional& sinks, + bool output_logsumexp_ = false, + array* lse_out = nullptr) { + // NAX path does not support logsumexp output - skip when VJP needs it if (metal::is_nax_available() && q.shape(3) != 80 && - (env::enable_tf32() || q.dtype() != float32)) { + (env::enable_tf32() 
|| q.dtype() != float32) && !output_logsumexp_) { return sdpa_full_self_attention_nax( /* const Stream& s = */ s, /* metal::Device& d = */ d, @@ -211,13 +275,15 @@ void sdpa_full_self_attention_metal( const bool has_mask = mask.has_value(); const bool do_causal = do_causal_; const bool has_sinks = sinks.has_value(); + const bool output_logsumexp = output_logsumexp_; metal::MTLFCList func_consts = { {&align_Q, MTL::DataType::DataTypeBool, 200}, {&align_K, MTL::DataType::DataTypeBool, 201}, {&has_mask, MTL::DataType::DataTypeBool, 300}, {&do_causal, MTL::DataType::DataTypeBool, 301}, - {&has_sinks, MTL::DataType::DataTypeBool, 302}}; + {&has_sinks, MTL::DataType::DataTypeBool, 302}, + {&output_logsumexp, MTL::DataType::DataTypeBool, 303}}; std::string base_name; concatenate( @@ -250,7 +316,9 @@ void sdpa_full_self_attention_metal( "_do_causal_", (do_causal ? 't' : 'n'), "_has_sinks_", - (has_sinks ? 't' : 'n')); + (has_sinks ? 't' : 'n'), + "_lse_", + (output_logsumexp ? 't' : 'n')); auto& compute_encoder = d.get_command_encoder(s.index); @@ -265,7 +333,15 @@ void sdpa_full_self_attention_metal( bd, wm, wn, - (has_mask ? *mask : q)); + (has_mask ? 
*mask : q), + gqa_factor, + scale, + align_Q, + align_K, + has_mask, + do_causal, + has_sinks, + output_logsumexp); compute_encoder.set_compute_pipeline_state(kernel); @@ -275,6 +351,14 @@ void sdpa_full_self_attention_metal( const int NQ_aligned = qL / bq; const int NK_aligned = kL / bk; + // Compute LSE strides if outputting logsumexp: shape [B, H, qL, 1] + // The VJP kernel expects strides as: + // LSE_strides[0] = qL (stride between heads within same batch) + // LSE_strides[1] = 1 (stride between query positions) + // Linear index = (batch * H + head) * qL + query_pos + int64_t lse_str_head = qL; // Stride between heads + int64_t lse_str_qpos = 1; // Stride between query positions + AttnParams params{ /* int B = */ B, /* int H = */ H, @@ -299,7 +383,8 @@ void sdpa_full_self_attention_metal( /* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)}, /* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)}, /* int64_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)}, - /* int64_t O_strides[3] = */ {o.strides(0), o.strides(1), o.strides(2)}}; + /* int64_t O_strides[3] = */ {o.strides(0), o.strides(1), o.strides(2)}, + /* int64_t LSE_strides[2] = */ {lse_str_head, lse_str_qpos}}; compute_encoder.set_input_array(q, 0); compute_encoder.set_input_array(k, 1); @@ -319,6 +404,9 @@ void sdpa_full_self_attention_metal( if (has_sinks) { compute_encoder.set_input_array(*sinks, 7); } + if (output_logsumexp && lse_out != nullptr) { + compute_encoder.set_output_array(*lse_out, 8); + } MTL::Size grid_dims = MTL::Size(NQ, H, B); MTL::Size group_dims = MTL::Size(32, wm, wn); @@ -410,7 +498,6 @@ void sdpa_vector( compute_encoder.set_input_array(*sinks, 16); compute_encoder.set_bytes(q.shape(1), 17); } - // Launch compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } @@ -425,7 +512,8 @@ void sdpa_vector_2pass( float scale, bool do_causal, const std::optional& mask, - const std::optional& sinks) { + const std::optional& sinks, + 
array* lse_out = nullptr) { // Set the kernel name std::string kname; kname.reserve(64); @@ -565,8 +653,15 @@ void sdpa_vector_2pass( kname += "_"; kname += std::to_string(v.shape(-1)); + bool do_output_lse = (lse_out != nullptr); + metal::MTLFCList pass2_func_consts = { + {&do_output_lse, MTL::DataType::DataTypeBool, 28}, + }; + std::string pass2_hash_name = kname; + pass2_hash_name += do_output_lse ? "_lse" : "_nolse"; + // Get the kernel - kernel = d.get_kernel(kname); + kernel = d.get_kernel(kname, pass2_hash_name, pass2_func_consts); compute_encoder.set_compute_pipeline_state(kernel); // Set its arguments @@ -575,6 +670,9 @@ void sdpa_vector_2pass( compute_encoder.set_input_array(maxs, 2); compute_encoder.set_output_array(out, 3); compute_encoder.set_bytes(blocks, 4); + if (do_output_lse) { + compute_encoder.set_output_array(*lse_out, 5); + } // Launch group_dims = MTL::Size(1024, 1, 1); @@ -592,17 +690,11 @@ bool ScaledDotProductAttention::use_fallback( bool has_mask, bool has_arr_mask, bool do_causal, - bool is_training, bool output_logsumexp, Stream s) { - if (is_training) { - // It's faster for training on Metal to use the unfused SDPA for both - // forward and backward. - return true; - } - if (output_logsumexp) { - return true; - } + // Note: When output_logsumexp is true, the caller (fast.cpp) has already + // verified VJP availability with proper has_mask/has_sinks parameters. + // No redundant check needed here. 
if (s.device == Device::cpu) { return true; } @@ -620,7 +712,8 @@ bool ScaledDotProductAttention::use_fallback( (query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128 || query_head_dim == 256); const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim && - (query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128); + (query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 96 || + query_head_dim == 128); const bool sdpa_full_supported_mask = !has_mask || has_arr_mask || (query_sequence_length <= key_sequence_length && do_causal); @@ -653,8 +746,6 @@ void ScaledDotProductAttention::eval_gpu( std::vector copies; - // Define some copy functions to ensure the layout of the inputs is as - // expected. copies.reserve(inputs.size()); auto copy_unless = [&copies, &s]( auto predicate, const array& arr) -> const array& { @@ -667,11 +758,6 @@ void ScaledDotProductAttention::eval_gpu( } }; - // Checks that the headdim dimension has stride 1. - auto is_matrix_contiguous = [](const array& arr) { - return arr.strides(-1) == 1; - }; - std::optional sinks = std::nullopt; if (has_sinks_) { sinks = copy_unless(is_matrix_contiguous, inputs.back()); @@ -680,41 +766,10 @@ void ScaledDotProductAttention::eval_gpu( // We are in vector mode ie single query if (q_pre.shape(2) <= 8) { - auto q_copy_unless = [](const array& arr) { - if (arr.flags().row_contiguous) { - return true; - } - auto& strides = arr.strides(); - auto& shape = arr.shape(); - if (shape[0] == 1 || shape[1] == 1) { - // If either the batch or head dimension is a singleton, the other can - // be transposed with the sequence dimension - auto bidx = shape[0] == 1 ? 
1 : 0; - return (strides[3] == 1) && (strides[2] == shape[3] * shape[bidx]) && - (strides[bidx] == shape[3]); - } - return false; - }; - - auto kv_copy_unless = [](const array& arr) { - // keys and values should be copied if: - // - the last dimension is not contiguous - // - the batch and head dim are not contiguous - auto& strides = arr.strides(); - auto& shape = arr.shape(); - if (strides.back() != 1) { - return false; - } - if (shape[0] == 1 || shape[1] == 1) { - return true; - } - return (strides[0] == strides[1] * shape[1]); - }; - - bool q_copied = !q_copy_unless(q_pre); + bool q_copied = !q_is_vector_compatible(q_pre); array q = (q_copied) ? contiguous_copy_gpu(q_pre, s) : q_pre; - const auto& k = copy_unless(kv_copy_unless, k_pre); - const auto& v = copy_unless(kv_copy_unless, v_pre); + const auto& k = copy_unless(kv_is_vector_compatible, k_pre); + const auto& v = copy_unless(kv_is_vector_compatible, v_pre); // Donate the query if possible if (q.is_donatable() && q.flags().row_contiguous && q.size() == o.size()) { @@ -726,15 +781,19 @@ void ScaledDotProductAttention::eval_gpu( o.set_data(allocator::malloc(o.nbytes())); } - auto mask_copy_unless = [&q](const array& arr) { - auto& strides = arr.strides(); - auto& shape = arr.shape(); - return arr.flags().row_contiguous || q.shape(0) == 1 || q.shape(1) == 1 || - (strides[0] == strides[1] * shape[1]); - }; + // Handle logsumexp output for VJP backward pass + array* lse_out = nullptr; + if (output_logsumexp_ && outputs.size() > 1) { + auto& lse = outputs[1]; + lse.set_data(allocator::malloc(lse.nbytes())); + lse_out = &outputs[1]; + } + auto mask_pred = [&q](const array& arr) { + return mask_is_compatible(q, arr); + }; auto mask = has_arr_mask - ? std::optional{copy_unless(mask_copy_unless, inputs[3])} + ? 
std::optional{copy_unless(mask_pred, inputs[3])} : std::nullopt; // We route to the 2 pass fused attention if @@ -744,7 +803,8 @@ void ScaledDotProductAttention::eval_gpu( char devc = d.get_architecture().back(); if (((devc == 'd' || devc == 's') && k.shape(2) >= 1024) || (k.shape(1) < q.shape(1) && k.shape(2) >= 4096)) { - sdpa_vector_2pass(s, d, q, k, v, o, scale_, do_causal, mask, sinks); + sdpa_vector_2pass( + s, d, q, k, v, o, scale_, do_causal, mask, sinks, lse_out); } else { sdpa_vector(s, d, q, k, v, o, scale_, do_causal, mask, sinks); } @@ -774,25 +834,612 @@ void ScaledDotProductAttention::eval_gpu( {str_oB, str_oH, str_oL, str_oD}, flags); + // Handle logsumexp output for VJP backward pass + array* lse_out = nullptr; + if (output_logsumexp_ && outputs.size() > 1) { + auto& lse = outputs[1]; + lse.set_data(allocator::malloc(lse.nbytes())); + lse_out = &outputs[1]; + } + auto mask = has_arr_mask ? std::optional{copy_unless(is_matrix_contiguous, inputs[3])} : std::nullopt; sdpa_full_self_attention_metal( - s, d, q, k, v, scale_, o, do_causal_, mask, sinks); + s, + d, + q, + k, + v, + scale_, + o, + do_causal_, + mask, + sinks, + output_logsumexp_, + lse_out); } d.add_temporaries(std::move(copies), s.index); } -bool ScaledDotProductAttentionVJP::use_fallback(const array& q, Stream s) { - return true; +bool ScaledDotProductAttentionVJP::use_fallback( + const array& q, + const array& k, + Stream s, + bool do_causal, + bool has_mask, + bool has_sinks, + int n_kv_heads) { + // Use fallback on CPU + if (s.device == Device::cpu) { + return true; + } + + const int query_head_dim = q.shape(-1); + const int query_seq_len = q.shape(2); + + // Vector path (qL <= 8): the 1-pass sdpa_vector kernel does not write LSE, + // which is required by the VJP. Fall back to unfused backward for short + // sequences until the vector kernel is extended to support LSE output. + if (query_seq_len <= 8) { + return true; + } + + // STEEL VJP dispatch policy. 
+ // + // Fused VJP avoids materializing the O(L^2) attention matrix, providing + // 70-95% memory savings. However, fused uses BQ=32 STEEL tiles (~1.9 + // TFLOPS) while unfused uses NAX-optimized large-tile matmul (~10.7 + // TFLOPS), so performance varies by head dim, causal mode, and GQA. + // + // Policy: + // Causal: fused by default (tile skipping gives speed advantage) + // Dense: unfused by default (fused always slower), memory ceiling for OOM + // Memory: 1 GB ceiling forces fused to prevent OOM on large dense attention + // + // Environment overrides: + // MLX_SDPA_VJP_MODE={auto|unfused|fused} + // MLX_SDPA_VJP_LONG_L_THRESHOLD=N (overrides auto L thresholds) + // MLX_SDPA_VJP_ATTENTION_BYTES_THRESHOLD=N (memory ceiling, default 1GB) + + const bool steel_eligible = + steel_vjp_eligible(query_head_dim, q.dtype()) && !has_mask && !has_sinks; + + if (!steel_eligible) { + return true; // Not eligible for fused VJP, use unfused + } + + // Read dispatch policy from environment (cached — read once per process) + static const char* mode_env = std::getenv("MLX_SDPA_VJP_MODE"); + std::string_view mode = mode_env ? mode_env : "auto"; + + if (mode == "unfused") { + return true; + } + if (mode == "fused") { + return false; // Force fused VJP + } + + // --- Auto mode: causal thresholds (speed) + memory ceiling (OOM) --- + + const int B = q.shape(0); + const int n_q_heads = q.shape(1); + const int gqa_factor = (n_kv_heads > 0) ? (n_q_heads / n_kv_heads) : 1; + const bool is_gqa = gqa_factor > 1; + + // Causal attention: fused is competitive or faster due to ~50% tile skipping. + // Use L thresholds to decide (speed-driven, not memory-driven). + if (do_causal) { + int l_threshold; + if (query_head_dim <= 96) { + // D=64/96: fused is faster at all L (non-GQA). + // GQA erodes advantage; fused at L>=1024. + l_threshold = is_gqa ? 1024 : 0; + } else { + // D=128: fused is slower but saves a lot of memory at L>=1024. + // Acceptable trade-off for training. 
GQA needs higher L. + l_threshold = is_gqa ? 2048 : 1024; + } + + static const char* thresh_env = + std::getenv("MLX_SDPA_VJP_LONG_L_THRESHOLD"); + if (thresh_env) { + l_threshold = std::atoi(thresh_env); + } + + if (query_seq_len >= l_threshold) { + return false; // Use fused VJP + } + } + + // Memory ceiling: OOM protection for dense attention (and short causal). + // Dense fused is always slower — only forced to prevent OOM. + // Default 1 GB: covers L>=4096 at H=32 (attn matrix = 1+ GB). + const int key_seq_len = k.shape(2); + const size_t attn_bytes = + static_cast(B) * n_q_heads * query_seq_len * key_seq_len * 2; + + static const char* bytes_env = + std::getenv("MLX_SDPA_VJP_ATTENTION_BYTES_THRESHOLD"); + const size_t bytes_threshold = bytes_env + ? static_cast(std::atoll(bytes_env)) + : static_cast(1) << 30; // 1 GB + + if (attn_bytes >= bytes_threshold) { + return false; // Use fused to avoid OOM + } + + return true; // Default: unfused (dense is never speed-competitive) } +namespace { + +// Dispatch the STEEL VJP dQ kernel. +// Computes dQ gradients using tiled matrix multiply on the GPU. 
+// Grid: [NQ, H, B] - one threadgroup per (query_block, head, batch) +void sdpa_steel_vjp_dq_dispatch( + const Stream& s, + metal::Device& d, + const array& q, + const array& k, + const array& v, + const array& delta, + const array& d_out, + const array& logsumexp, + array& d_q, + float scale, + bool do_causal, + const std::optional& block_mask = std::nullopt, + int qL_off_override = -1) { + using namespace mlx::steel; + + constexpr int bq = 32; + constexpr int wm = 4; + constexpr int wn = 1; + + int B = q.shape(0); + int H = q.shape(1); + int D = q.shape(3); + int gqa_factor = q.shape(1) / k.shape(1); + + // Select BK based on head dimension and architecture (matches MFA): + // D=64: BK=32 (all architectures) + // D=96/128: BK=32 on M3+ (halves KV iterations), BK=16 on M1/M2 + int bk; + if (D <= 64) { + bk = 32; + } else if (D <= 128 && d.get_architecture_gen() >= 15) { + bk = 32; // M3+ has dynamic register allocation + } else { + bk = 16; + } + + int qL = q.shape(2); + int kL = k.shape(2); + + const int NQ = (qL + bq - 1) / bq; + const int NK = (kL + bk - 1) / bk; + + const int NQ_aligned = qL / bq; + const int NK_aligned = kL / bk; + + const bool align_Q = (qL % bq) == 0; + const bool align_K = (kL % bk) == 0; + + bool has_block_mask_flag = block_mask.has_value(); + + // Function constants (same indices as forward kernel) + metal::MTLFCList func_consts = { + {&align_Q, MTL::DataType::DataTypeBool, 200}, + {&align_K, MTL::DataType::DataTypeBool, 201}, + {&do_causal, MTL::DataType::DataTypeBool, 301}, + {&has_block_mask_flag, MTL::DataType::DataTypeBool, 302}, + }; + + // Kernel name: matches host_name from instantiation macro + // Format: attention_vjp_dq_{type}_{bq}_{bk}_{bd} + std::string kname = "attention_vjp_dq_"; + kname += type_to_name(q); + kname += "_"; + kname += std::to_string(bq); + kname += "_"; + kname += std::to_string(bk); + kname += "_"; + kname += std::to_string(D); + + std::string hash_name = kname; + hash_name += "_align_Q_"; + hash_name += 
(align_Q ? 't' : 'n'); + hash_name += "_align_K_"; + hash_name += (align_K ? 't' : 'n'); + hash_name += "_causal_"; + hash_name += (do_causal ? 't' : 'n'); + hash_name += "_bmask_"; + hash_name += (has_block_mask_flag ? 't' : 'n'); + + float scale_log2 = static_cast(scale * M_LOG2E); + + auto& compute_encoder = d.get_command_encoder(s.index); + auto kernel = get_steel_attention_vjp_dq_kernel( + d, + kname, + hash_name, + func_consts, + q, + bq, + bk, + D, + wm, + wn, + gqa_factor, + scale, + scale_log2, + align_Q, + align_K, + do_causal, + has_block_mask_flag); + compute_encoder.set_compute_pipeline_state(kernel); + + // LSE strides: shape [B, H, qL] stored linearly as (batch * H + head) * qL + + // pos + int64_t lse_str_head = qL; + int64_t lse_str_qpos = 1; + + // qL_off: causal mask offset. When caller pads Q/K to different block sizes, + // the padded kL - qL may differ from the original. Use override if provided. + int qL_off = (qL_off_override >= 0) ? qL_off_override : (kL - qL); + + AttnVJPParams params{ + /* int B = */ B, + /* int H = */ H, + /* int D = */ D, + + /* int qL = */ qL, + /* int kL = */ kL, + + /* int gqa_factor = */ gqa_factor, + /* float scale = */ scale, + /* float scale_log2 = */ scale_log2, + + /* int NQ = */ NQ, + /* int NK = */ NK, + + /* int NQ_aligned = */ NQ_aligned, + /* int NK_aligned = */ NK_aligned, + + /* int qL_rem = */ (qL - NQ_aligned * bq), + /* int kL_rem = */ (kL - NK_aligned * bk), + /* int qL_off = */ qL_off, + + /* int64_t Q_strides[3] = */ + {q.strides(0), q.strides(1), q.strides(2)}, + /* int64_t K_strides[3] = */ + {k.strides(0), k.strides(1), k.strides(2)}, + /* int64_t V_strides[3] = */ + {v.strides(0), v.strides(1), v.strides(2)}, + /* int64_t dO_strides[3] = */ + {d_out.strides(0), d_out.strides(1), d_out.strides(2)}, + /* int64_t LSE_strides[2] = */ {lse_str_head, lse_str_qpos}, + /* int64_t delta_strides[2] = */ {qL, 1}, + + /* int64_t dQ_strides[3] = */ + {d_q.strides(0), d_q.strides(1), d_q.strides(2)}, + /* 
int64_t dK_strides[3] = */ {0, 0, 0}, + /* int64_t dV_strides[3] = */ {0, 0, 0}, + + /* int NK_tiles = */ NK, + }; + + // Set buffers (must match kernel signature in steel_attention_vjp_dq.h) + compute_encoder.set_input_array(q, 0); + compute_encoder.set_input_array(k, 1); + compute_encoder.set_input_array(v, 2); + compute_encoder.set_input_array(delta, 3); + compute_encoder.set_input_array(d_out, 4); + compute_encoder.set_input_array(logsumexp, 5); + compute_encoder.set_output_array(d_q, 6); + compute_encoder.set_bytes(params, 7); + if (has_block_mask_flag) { + compute_encoder.set_input_array(*block_mask, 8); + } + + // Grid: [NQ, H, B] - one threadgroup per (query_block, head, batch) + MTL::Size grid_dims = MTL::Size(NQ, H, B); + // Group: WM * WN * 32 threads = 4 * 1 * 32 = 128 + MTL::Size group_dims = MTL::Size(wm * wn * 32, 1, 1); + + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); +} + +// Dispatch the STEEL VJP dKV kernel. +// Computes dK and dV gradients using tiled matrix multiply on the GPU. +// Grid: [NK, n_kv_heads, B] - one threadgroup per (kv_block, kv_head, batch) +void sdpa_steel_vjp_dkv_dispatch( + const Stream& s, + metal::Device& d, + const array& q, + const array& k, + const array& v, + const array& delta, + const array& d_out, + const array& logsumexp, + array& d_k, + array& d_v, + float scale, + bool do_causal, + const std::optional& block_mask = std::nullopt, + int qL_off_override = -1) { + using namespace mlx::steel; + + constexpr int wn = 1; + + int B = q.shape(0); + int H = q.shape(1); + int D = q.shape(3); + int gqa_factor = q.shape(1) / k.shape(1); + int n_kv_heads = k.shape(1); + + // Select BQ, WM, BK based on head dimension: + // D=64: BQ=32, WM=1, BK=32 (32 threads, register pressure OK) + // D=96: BQ=32, WM=2, BK=16 (64 threads, ~280 regs/thread) + // D=128: BQ=16, WM=2, BK=16 (64 threads, ~202 regs/thread, no spilling!) + // BQ=16 halves Q-tile register tiles (TQ=1 vs TQ=2) at cost of 2x Q + // iterations. 
+ int bq = (D >= 128) ? 16 : 32; + int wm = (D >= 96) ? 2 : 1; + int bk = (D <= 64) ? 32 : 16; + + int qL = q.shape(2); + int kL = k.shape(2); + + const int NQ = (qL + bq - 1) / bq; + const int NK = (kL + bk - 1) / bk; + + const int NQ_aligned = qL / bq; + const int NK_aligned = kL / bk; + + const bool align_Q = (qL % bq) == 0; + const bool align_K = (kL % bk) == 0; + + bool has_block_mask_flag = block_mask.has_value(); + + // Function constants + metal::MTLFCList func_consts = { + {&align_Q, MTL::DataType::DataTypeBool, 200}, + {&align_K, MTL::DataType::DataTypeBool, 201}, + {&do_causal, MTL::DataType::DataTypeBool, 301}, + {&has_block_mask_flag, MTL::DataType::DataTypeBool, 302}, + }; + + // Kernel name: matches host_name from instantiation macro + // Format: attention_vjp_dkv_{type}_{bq}_{bk}_{bd}[_wmN] + std::string kname = "attention_vjp_dkv_"; + kname += type_to_name(q); + kname += "_"; + kname += std::to_string(bq); + kname += "_"; + kname += std::to_string(bk); + kname += "_"; + kname += std::to_string(D); + if (wm > 1) { + kname += "_wm"; + kname += std::to_string(wm); + } + + std::string hash_name = kname; + hash_name += "_align_Q_"; + hash_name += (align_Q ? 't' : 'n'); + hash_name += "_align_K_"; + hash_name += (align_K ? 't' : 'n'); + hash_name += "_causal_"; + hash_name += (do_causal ? 't' : 'n'); + hash_name += "_bmask_"; + hash_name += (has_block_mask_flag ? 't' : 'n'); + + float scale_log2 = static_cast(scale * M_LOG2E); + + auto& compute_encoder = d.get_command_encoder(s.index); + auto kernel = get_steel_attention_vjp_dkv_kernel( + d, + kname, + hash_name, + func_consts, + q, + bq, + bk, + D, + wm, + wn, + gqa_factor, + scale, + scale_log2, + align_Q, + align_K, + do_causal, + has_block_mask_flag); + compute_encoder.set_compute_pipeline_state(kernel); + + int64_t lse_str_head = qL; + int64_t lse_str_qpos = 1; + + // qL_off: causal mask offset. When caller pads Q/K to different block sizes, + // the padded kL - qL may differ from the original. 
Use override if provided. + int qL_off = (qL_off_override >= 0) ? qL_off_override : (kL - qL); + + AttnVJPParams params{ + /* int B = */ B, + /* int H = */ H, + /* int D = */ D, + + /* int qL = */ qL, + /* int kL = */ kL, + + /* int gqa_factor = */ gqa_factor, + /* float scale = */ scale, + /* float scale_log2 = */ scale_log2, + + /* int NQ = */ NQ, + /* int NK = */ NK, + + /* int NQ_aligned = */ NQ_aligned, + /* int NK_aligned = */ NK_aligned, + + /* int qL_rem = */ (qL - NQ_aligned * bq), + /* int kL_rem = */ (kL - NK_aligned * bk), + /* int qL_off = */ qL_off, + + /* int64_t Q_strides[3] = */ + {q.strides(0), q.strides(1), q.strides(2)}, + /* int64_t K_strides[3] = */ + {k.strides(0), k.strides(1), k.strides(2)}, + /* int64_t V_strides[3] = */ + {v.strides(0), v.strides(1), v.strides(2)}, + /* int64_t dO_strides[3] = */ + {d_out.strides(0), d_out.strides(1), d_out.strides(2)}, + /* int64_t LSE_strides[2] = */ {lse_str_head, lse_str_qpos}, + /* int64_t delta_strides[2] = */ {qL, 1}, + + /* int64_t dQ_strides[3] = */ {0, 0, 0}, + /* int64_t dK_strides[3] = */ + {d_k.strides(0), d_k.strides(1), d_k.strides(2)}, + /* int64_t dV_strides[3] = */ + {d_v.strides(0), d_v.strides(1), d_v.strides(2)}, + + /* int NK_tiles = */ NK, + }; + + // Set buffers (must match kernel signature in steel_attention_vjp_dkv.h) + compute_encoder.set_input_array(q, 0); + compute_encoder.set_input_array(k, 1); + compute_encoder.set_input_array(v, 2); + compute_encoder.set_input_array(delta, 3); + compute_encoder.set_input_array(d_out, 4); + compute_encoder.set_input_array(logsumexp, 5); + compute_encoder.set_output_array(d_k, 6); + compute_encoder.set_output_array(d_v, 7); + compute_encoder.set_bytes(params, 8); + if (has_block_mask_flag) { + compute_encoder.set_input_array(*block_mask, 9); + } + + // Grid: [NK, n_kv_heads, B] - one threadgroup per (kv_block, kv_head, batch) + MTL::Size grid_dims = MTL::Size(NK, n_kv_heads, B); + // Group: WM * WN * 32 threads (D=64: 32, D>=96: 64) + 
MTL::Size group_dims = MTL::Size(wm * wn * 32, 1, 1); + + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); +} + +} // namespace + void ScaledDotProductAttentionVJP::eval_gpu( const std::vector& inputs, std::vector& outputs) { - throw std::runtime_error("NYI"); + auto& s = stream(); + auto& d = metal::device(s.device); + + // Parse inputs: + // inputs = [Q, K, V, (optional mask), (optional sinks), O, logsumexp, dO, + // delta] The last 4 are always O, logsumexp, dO, delta + const auto& q_pre = inputs[0]; + const auto& k_pre = inputs[1]; + const auto& v_pre = inputs[2]; + + // Determine indices based on optional inputs + // primals can have mask and/or sinks appended + size_t num_primals = inputs.size() - 4; // Subtract O, logsumexp, dO, delta + const auto& out = inputs[num_primals]; + const auto& logsumexp = inputs[num_primals + 1]; + const auto& d_out = inputs[num_primals + 2]; + const auto& delta_pre = inputs[num_primals + 3]; + + auto& d_q = outputs[0]; + auto& d_k = outputs[1]; + auto& d_v = outputs[2]; + + std::vector copies; + copies.reserve(inputs.size()); + + auto copy_unless = [&copies, &s]( + auto predicate, const array& arr) -> const array& { + if (!predicate(arr)) { + array arr_copy = contiguous_copy_gpu(arr, s); + copies.push_back(std::move(arr_copy)); + return copies.back(); + } else { + return arr; + } + }; + + // Handle optional sinks + std::optional sinks = std::nullopt; + if (has_sinks_) { + sinks = copy_unless(is_matrix_contiguous, inputs[num_primals - 1]); + } + + // Determine if we have a mask + bool has_arr_mask = num_primals > (3 + has_sinks_); + + // STEEL VJP: re-enabled behind policy control. On Apple Silicon with + // NAX-optimized matmuls, unfused is faster for typical L. Fused VJP + // avoids materializing O(L^2) attention matrix (84-96% memory savings + // at L>=1024). Dispatch controlled by MLX_SDPA_VJP_MODE env var. + // See use_fallback() for policy details. 
+ const int query_head_dim_pre = q_pre.shape(-1); + bool use_steel_vjp = steel_vjp_eligible(query_head_dim_pre, q_pre.dtype()) && + (q_pre.shape(2) > 8) && !has_arr_mask && !has_sinks_; + + auto is_row_contiguous = [](const array& arr) { + return arr.flags().row_contiguous; + }; + + // STEEL VJP requires row-contiguous Q and dO for kernel pointer arithmetic. + const auto& q = use_steel_vjp ? copy_unless(is_row_contiguous, q_pre) + : copy_unless(q_is_vector_compatible, q_pre); + const auto& k = copy_unless(is_matrix_contiguous, k_pre); + const auto& v = copy_unless(is_matrix_contiguous, v_pre); + const auto& dO = copy_unless(is_row_contiguous, d_out); + const auto& lse = copy_unless(is_matrix_contiguous, logsumexp); + + // Handle mask + auto mask_pred = [&q](const array& arr) { + return mask_is_compatible(q, arr); + }; + std::optional mask = std::nullopt; + if (has_arr_mask) { + mask = copy_unless(mask_pred, inputs[3]); + } + + bool do_causal = do_causal_ && q.shape(2) > 1; + + // Dispatch to appropriate kernel + if (use_steel_vjp) { + // delta = sum(dO * O, axis=-1) was precomputed in fast.cpp's VJP function + // as a lazy graph op. By the time eval_gpu runs, it's already evaluated. + const auto& delta_arr = copy_unless(is_row_contiguous, delta_pre); + d_q.set_data(allocator::malloc(d_q.nbytes())); + d_k.set_data(allocator::malloc(d_k.nbytes())); + d_v.set_data(allocator::malloc(d_v.nbytes())); + + { + auto& enc = d.get_command_encoder(s.index); + auto concurrent = enc.start_concurrent(); + sdpa_steel_vjp_dq_dispatch( + s, d, q, k, v, delta_arr, dO, lse, d_q, scale_, do_causal); + sdpa_steel_vjp_dkv_dispatch( + s, d, q, k, v, delta_arr, dO, lse, d_k, d_v, scale_, do_causal); + } + } else { + throw std::runtime_error( + "[ScaledDotProductAttentionVJP::eval_gpu] " + "use_steel_vjp is false but eval_gpu was called. 
" + "This indicates a mismatch between use_fallback() and eval_gpu() " + "eligibility checks."); + } + + d.add_temporaries(std::move(copies), s.index); } } // namespace mlx::core::fast diff --git a/mlx/backend/no_gpu/primitives.cpp b/mlx/backend/no_gpu/primitives.cpp index 4819ed2724..d4ebe3935f 100644 --- a/mlx/backend/no_gpu/primitives.cpp +++ b/mlx/backend/no_gpu/primitives.cpp @@ -30,7 +30,6 @@ bool fast::ScaledDotProductAttention::use_fallback( bool has_mask, bool has_arr_mask, bool do_causal, - bool is_training, bool output_logsumexp, Stream s) { return true; @@ -42,7 +41,12 @@ bool fast::ScaledDotProductAttention::supports_bool_mask() { bool fast::ScaledDotProductAttentionVJP::use_fallback( const array& q, - Stream s) { + const array& /* k */, + Stream s, + bool /* do_causal */, + bool /* has_mask */, + bool /* has_sinks */, + int /* n_kv_heads */) { return true; } diff --git a/mlx/fast.cpp b/mlx/fast.cpp index a668fe9abd..73427c44aa 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -822,9 +822,19 @@ array scaled_dot_product_attention( inputs.push_back(astype(*sinks, final_type, stream)); } - bool is_training = detail::in_grad_tracing(); - bool has_fast_vjp = !ScaledDotProductAttentionVJP::use_fallback(q, stream); - bool output_logsumexp = is_training && has_fast_vjp; + // Note: pass has_arr_mask (not has_mask) because the STEEL VJP kernels + // handle causal masking natively via function constants. Only array masks + // are unsupported by the fused backward path. 
+ bool has_fast_vjp = !ScaledDotProductAttentionVJP::use_fallback( + q, + k, + stream, + do_causal, + has_arr_mask, + has_sinks, + static_cast(n_kv_heads)); + bool output_logsumexp = detail::in_grad_tracing() && has_fast_vjp; + if (!ScaledDotProductAttention::use_fallback( q, k, @@ -832,7 +842,6 @@ array scaled_dot_product_attention( has_mask, has_arr_mask, do_causal, - is_training, output_logsumexp, stream)) { if (has_bool_mask && !ScaledDotProductAttention::supports_bool_mask()) { @@ -870,14 +879,41 @@ std::vector ScaledDotProductAttention::vjp( assert(cotangents.size() == outputs.size()); auto s = stream(); - if (ScaledDotProductAttentionVJP::use_fallback(primals[0], s)) { - assert(outputs.size() == 1); + + // Determine if mask is present: primals = [Q, K, V, (mask), (sinks)] + bool has_mask = primals.size() > static_cast(3 + has_sinks_); + int n_kv_heads = primals[1].shape(1); // K is at index 1 + + // Check if we can use Flash Attention VJP + if (ScaledDotProductAttentionVJP::use_fallback( + primals[0], + primals[1], + s, + do_causal_, + has_mask, + has_sinks_, + n_kv_heads) || + !output_logsumexp_) { return Custom::vjp(primals, cotangents, argnums, outputs); } + // When output_logsumexp_ is true, the forward pass creates 2 sibling arrays: + // outputs[0] = attention output, outputs[1] = logsumexp + // Even though only outputs[0] is returned to the user, the tape tracks both + // siblings. + assert( + outputs.size() >= 2 && + "Expected logsumexp in outputs[1] when output_logsumexp_ is true"); + + // Fallback for higher-order gradients (e.g., Hessian-vector products). + // inputs = [Q, K, V, (mask), (sinks), O, LSE, dO, delta] + // The last 4 arrays (O, LSE, dO, delta) were appended below for eval_gpu; + // strip them to recover the original primals, and use dO as the cotangent. 
auto fallback = [sdpa = fallback_, s](const std::vector& inputs) { - std::vector primals(inputs.begin(), std::prev(inputs.end())); - auto [_, vjps] = mlx::core::vjp(sdpa, primals, {inputs.back()}); + constexpr int n_extra = 4; // O, LSE, dO, delta + std::vector primals(inputs.begin(), inputs.end() - n_extra); + auto& dO = inputs[inputs.size() - 2]; // dO is before delta + auto [_, vjps] = mlx::core::vjp(sdpa, primals, {dO}); return vjps; }; @@ -890,9 +926,16 @@ std::vector ScaledDotProductAttention::vjp( auto primitive = std::make_shared( s, fallback, scale_, do_causal_, has_sinks_); std::vector inputs = primals; - inputs.push_back(outputs[0]); - inputs.push_back(outputs[1]); + inputs.push_back(outputs[0]); // Attention output + inputs.push_back(outputs[1]); // Logsumexp inputs.push_back(cotangents[0]); + // Precompute delta = sum(dO * O, axis=-1) as a lazy graph op. + // This must be done here (not in eval_gpu) because eval_gpu cannot + // create new lazy ops — it runs during the evaluation pass. 
+ auto O_f32 = astype(outputs[0], float32, s); + auto dO_f32 = astype(cotangents[0], float32, s); + auto delta = sum(multiply(dO_f32, O_f32, s), std::vector{3}, false, s); + inputs.push_back(delta); // delta = sum(dO * O, axis=-1), shape [B, H, qL] auto vjps = array::make_arrays(std::move(shapes), dtypes, primitive, inputs); std::vector returned_vjps; diff --git a/mlx/fast_primitives.h b/mlx/fast_primitives.h index 4434830875..6071b11eba 100644 --- a/mlx/fast_primitives.h +++ b/mlx/fast_primitives.h @@ -225,7 +225,6 @@ class ScaledDotProductAttention : public Custom { bool has_mask, bool has_arr_mask, bool do_causal, - bool is_training, bool output_logsumexp, Stream s); static bool supports_bool_mask(); @@ -273,7 +272,14 @@ class ScaledDotProductAttentionVJP : public Custom { do_causal_(do_causal), has_sinks_(has_sinks) {} - static bool use_fallback(const array& q, Stream s); + static bool use_fallback( + const array& q, + const array& k, + Stream s, + bool do_causal = false, + bool has_mask = false, + bool has_sinks = false, + int n_kv_heads = -1); void eval_cpu(const std::vector& inputs, std::vector& outputs) override { diff --git a/python/tests/test_fast_sdpa.py b/python/tests/test_fast_sdpa.py index 7606373ce4..b93340321d 100644 --- a/python/tests/test_fast_sdpa.py +++ b/python/tests/test_fast_sdpa.py @@ -1,10 +1,12 @@ import math +import os import unittest from itertools import product import mlx.core as mx import mlx_tests import numpy as np +import pytest def mlx_ref_attn(q, k, v, scale=1.0, mask=None, sinks=None): @@ -642,6 +644,461 @@ def test_sdpa_sliced(self): tolerance = {"rtol": 1e-2, "atol": 1e-2} self.assertTrue(mx.allclose(ref, out, **tolerance)) + def test_sdpa_steel_vjp_grad(self): + """Test STEEL VJP correctness for D=64/96/128 with L>8 sequences. + + The STEEL kernel path is used for longer sequences (L>8) where the + vector path is not applicable. Tests MHA and GQA configurations + with causal and non-causal masks across fp32 and fp16. 
+ """ + mx.random.seed(42) + + B = 1 + + # fmt: off + configs = [ + # (qL, kL, D, n_q_heads, n_kv_heads) + # D=64 configs + ( 16, 16, 64, 8, 8), + ( 32, 32, 64, 8, 8), + (128, 128, 64, 8, 8), + # D=96 configs + ( 16, 16, 96, 8, 8), + ( 32, 32, 96, 8, 8), + (128, 128, 96, 8, 8), + # D=128 configs + ( 16, 16, 128, 8, 8), + ( 32, 32, 128, 8, 8), + (128, 128, 128, 8, 8), + # GQA configs (heads=8, kv_heads=2) + ( 32, 32, 64, 8, 2), + (128, 128, 64, 8, 2), + ( 32, 32, 128, 8, 2), + (128, 128, 128, 8, 2), + # Longer sequences (skip 8192 - too slow for unit tests) + (1024, 1024, 64, 8, 8), + (1024, 1024, 64, 8, 2), + ] + # fmt: on + + dtypes = [mx.float32] + if mx.metal.is_available(): + dtypes.append(mx.float16) + + for dtype in dtypes: + for qL, kL, D, n_q, n_kv in configs: + for mask_type in (None, "causal"): + with self.subTest( + dtype=dtype, + qL=qL, + kL=kL, + D=D, + n_q_heads=n_q, + n_kv_heads=n_kv, + mask=mask_type, + ): + scale = D**-0.5 + + q = mx.random.normal(shape=(B, n_q, qL, D), dtype=dtype) + k = mx.random.normal(shape=(B, n_kv, kL, D), dtype=dtype) + v = mx.random.normal(shape=(B, n_kv, kL, D), dtype=dtype) + + mask = mask_type # None or "causal" + + def ref_fn(q, k, v): + return mlx_ref_attn(q, k, v, scale=scale, mask=mask) + + def fused_fn(q, k, v): + return mx.fast.scaled_dot_product_attention( + q, k, v, scale=scale, mask=mask + ) + + primals = [q, k, v] + out_ref = ref_fn(q, k, v) + cotan = mx.random.normal(shape=out_ref.shape, dtype=dtype) + + _, vjp_ref = mx.vjp(ref_fn, primals, [cotan]) + _, vjp_fused = mx.vjp(fused_fn, primals, [cotan]) + + atol = 1e-4 if dtype == mx.float32 else 5e-2 + rtol = 1e-4 if dtype == mx.float32 else 5e-2 + tol = {"atol": atol, "rtol": rtol} + + for i, name in enumerate(["dQ", "dK", "dV"]): + self.assertTrue( + mx.allclose(vjp_ref[i], vjp_fused[i], **tol), + msg=( + f"{name} mismatch: dtype={dtype}, qL={qL}, " + f"kL={kL}, D={D}, n_q={n_q}, n_kv={n_kv}, " + f"mask={mask_type}, " + f"max_diff={mx.max(mx.abs(vjp_ref[i] - 
vjp_fused[i])).item()}" + ), + ) + + def test_sdpa_steel_vjp_masks(self): + """Test STEEL VJP with explicit masks (bool and additive). + + Since STEEL VJP may fall back to unfused for explicit masks, this + verifies that the fallback path still produces correct gradients. + """ + mx.random.seed(88) + + B = 1 + D = 64 + L = 32 + n_q, n_kv = 8, 8 + scale = D**-0.5 + + dtypes = [mx.float32] + if mx.metal.is_available(): + dtypes.append(mx.float16) + + for dtype in dtypes: + for mask_kind in ("bool", "additive"): + with self.subTest(dtype=dtype, mask_kind=mask_kind): + q = mx.random.normal(shape=(B, n_q, L, D), dtype=dtype) + k = mx.random.normal(shape=(B, n_kv, L, D), dtype=dtype) + v = mx.random.normal(shape=(B, n_kv, L, D), dtype=dtype) + + if mask_kind == "bool": + mask = mx.random.uniform(shape=(1, 1, L, L)) > 0.3 + else: + bool_mask = mx.random.uniform(shape=(1, 1, L, L)) > 0.3 + mask = mx.where( + bool_mask, + mx.zeros((1, 1, L, L), dtype=dtype), + mx.full((1, 1, L, L), -1e9, dtype=dtype), + ) + + def ref_fn(q, k, v): + return mlx_ref_attn(q, k, v, scale=scale, mask=mask) + + def fused_fn(q, k, v): + return mx.fast.scaled_dot_product_attention( + q, k, v, scale=scale, mask=mask + ) + + primals = [q, k, v] + out_ref = ref_fn(q, k, v) + cotan = mx.random.normal(shape=out_ref.shape, dtype=dtype) + + _, vjp_ref = mx.vjp(ref_fn, primals, [cotan]) + _, vjp_fused = mx.vjp(fused_fn, primals, [cotan]) + + atol = 1e-4 if dtype == mx.float32 else 5e-2 + rtol = 1e-4 if dtype == mx.float32 else 5e-2 + tol = {"atol": atol, "rtol": rtol} + + for i, name in enumerate(["dQ", "dK", "dV"]): + self.assertTrue( + mx.allclose(vjp_ref[i], vjp_fused[i], **tol), + msg=( + f"{name} mismatch: dtype={dtype}, " + f"mask_kind={mask_kind}, " + f"max_diff={mx.max(mx.abs(vjp_ref[i] - vjp_fused[i])).item()}" + ), + ) + + def test_sdpa_vector_vjp_d256(self): + """Test D=256 two-stage tiling in the vector VJP path. 
+ + D=256 exceeds the single-tile head dimension limit and requires + two-stage tiling in the vector kernel. Tests with small query + sequence lengths (vector path: qL <= 8). + """ + mx.random.seed(256) + + B = 1 + D = 256 + n_q, n_kv = 8, 8 + scale = D**-0.5 + + dtypes = [mx.float32] + if mx.metal.is_available(): + dtypes.append(mx.float16) + + for dtype in dtypes: + for qL in (1, 2, 4, 8): + for kL in (32, 128): + with self.subTest(dtype=dtype, qL=qL, kL=kL): + q = mx.random.normal( + shape=(B, n_q, qL, D), dtype=dtype + ) + k = mx.random.normal( + shape=(B, n_kv, kL, D), dtype=dtype + ) + v = mx.random.normal( + shape=(B, n_kv, kL, D), dtype=dtype + ) + + def ref_fn(q, k, v): + return mlx_ref_attn(q, k, v, scale=scale) + + def fused_fn(q, k, v): + return mx.fast.scaled_dot_product_attention( + q, k, v, scale=scale + ) + + primals = [q, k, v] + out_ref = ref_fn(q, k, v) + cotan = mx.random.normal( + shape=out_ref.shape, dtype=dtype + ) + + _, vjp_ref = mx.vjp(ref_fn, primals, [cotan]) + _, vjp_fused = mx.vjp(fused_fn, primals, [cotan]) + + atol = 1e-4 if dtype == mx.float32 else 5e-2 + rtol = 1e-4 if dtype == mx.float32 else 5e-2 + tol = {"atol": atol, "rtol": rtol} + + for i, name in enumerate(["dQ", "dK", "dV"]): + self.assertTrue( + mx.allclose(vjp_ref[i], vjp_fused[i], **tol), + msg=( + f"{name} mismatch: dtype={dtype}, qL={qL}, " + f"kL={kL}, " + f"max_diff={mx.max(mx.abs(vjp_ref[i] - vjp_fused[i])).item()}" + ), + ) + + def test_sdpa_steel_vjp_unaligned(self): + """Test STEEL VJP with unaligned sequence lengths. + + Exercises the sequence padding logic by using sequence lengths that + are not multiples of 16 or 32 (the tile sizes used by STEEL kernels). 
+ """ + mx.random.seed(17) + + B = 1 + + # fmt: off + configs = [ + # (qL, kL, D) + # Not multiples of 32 for queries + (17, 17, 64), + (33, 33, 64), + (63, 63, 64), + (100, 100, 64), + # Asymmetric lengths + (17, 33, 64), + (33, 63, 64), + (63, 100, 64), + (100, 17, 64), + # D=96 + (17, 17, 96), + (33, 33, 96), + (63, 63, 96), + (100, 100, 96), + # D=128 (exercises O/dO aliasing + padding) + (17, 17, 128), + (33, 63, 128), + (63, 100, 128), + ] + # fmt: on + + dtypes = [mx.float32] + if mx.metal.is_available(): + dtypes.append(mx.float16) + + for dtype in dtypes: + for qL, kL, D in configs: + n_q, n_kv = 8, 8 + for mask_type in (None, "causal"): + with self.subTest( + dtype=dtype, + qL=qL, + kL=kL, + D=D, + mask=mask_type, + ): + scale = D**-0.5 + + q = mx.random.normal( + shape=(B, n_q, qL, D), dtype=dtype + ) + k = mx.random.normal( + shape=(B, n_kv, kL, D), dtype=dtype + ) + v = mx.random.normal( + shape=(B, n_kv, kL, D), dtype=dtype + ) + + mask = mask_type + + def ref_fn(q, k, v): + return mlx_ref_attn(q, k, v, scale=scale, mask=mask) + + def fused_fn(q, k, v): + return mx.fast.scaled_dot_product_attention( + q, k, v, scale=scale, mask=mask + ) + + primals = [q, k, v] + out_ref = ref_fn(q, k, v) + cotan = mx.random.normal( + shape=out_ref.shape, dtype=dtype + ) + + _, vjp_ref = mx.vjp(ref_fn, primals, [cotan]) + _, vjp_fused = mx.vjp(fused_fn, primals, [cotan]) + + atol = 1e-4 if dtype == mx.float32 else 5e-2 + rtol = 1e-4 if dtype == mx.float32 else 5e-2 + tol = {"atol": atol, "rtol": rtol} + + # For causal mask when qL > kL, skip first rows + # (they are fully masked and undefined) + if mask_type == "causal" and qL > kL: + offset = qL - kL + for i, name in enumerate(["dQ", "dK", "dV"]): + ref_g = vjp_ref[i] + fused_g = vjp_fused[i] + if name == "dQ": + ref_g = ref_g[:, :, offset:, :] + fused_g = fused_g[:, :, offset:, :] + self.assertTrue( + mx.allclose(ref_g, fused_g, **tol), + msg=( + f"{name} mismatch: dtype={dtype}, qL={qL}, " + f"kL={kL}, D={D}, 
@pytest.mark.skipif(
    # The `metal` submodule can exist without a usable Metal device, so also
    # ask the backend whether it is actually available.
    not (hasattr(mx, "metal") and mx.metal.is_available()),
    reason="Metal GPU required",
)
class TestSDPALongSequenceVJP(unittest.TestCase):
    """Long-sequence tests comparing fused SDPA VJP against an unfused reference.

    Each test computes gradients of ``sum(attention(q, k, v))`` twice: once
    through ``mx.fast.scaled_dot_product_attention`` with
    ``MLX_SDPA_VJP_MODE=fused`` and once through a hand-written
    matmul/softmax/matmul reference that materializes the ``L x L`` score
    matrix. Gradients are compared for agreement and peak memory is recorded
    for both paths.

    Run with: pytest -m slow python/tests/test_fast_sdpa.py
    Or:       python -m pytest python/tests/test_fast_sdpa.py -k "LongSequence" -v
    """

    _VJP_MODE_ENV = "MLX_SDPA_VJP_MODE"

    def setUp(self):
        """Force the fused VJP path for the test and restore the env afterwards."""
        super().setUp()
        previous = os.environ.get(self._VJP_MODE_ENV)
        os.environ[self._VJP_MODE_ENV] = "fused"
        # addCleanup runs even when the test fails or errors, and puts back
        # whatever value (if any) the caller had exported before the test.
        self.addCleanup(self._restore_vjp_mode, previous)

    def _restore_vjp_mode(self, previous):
        """Reset ``MLX_SDPA_VJP_MODE`` to ``previous`` (``None`` means unset)."""
        if previous is None:
            os.environ.pop(self._VJP_MODE_ENV, None)
        else:
            os.environ[self._VJP_MODE_ENV] = previous

    def _run_vjp_memory_test(self, L, D=64, H=4, B=1, dtype=mx.float16):
        """Run fused vs unfused VJP and compare memory + correctness.

        Args:
            L: Sequence length used for queries, keys and values.
            D: Head dimension.
            H: Number of heads (same for q and k/v).
            B: Batch size.
            dtype: Element type of q, k and v.

        Returns:
            ``(mem_fused, mem_unfused)``: peak memory in bytes, as reported by
            ``mx.get_peak_memory()``, for the fused and unfused gradient
            evaluations respectively.
        """
        scale = 1.0 / math.sqrt(D)

        q = mx.random.normal(shape=(B, H, L, D)).astype(dtype)
        k = mx.random.normal(shape=(B, H, L, D)).astype(dtype)
        v = mx.random.normal(shape=(B, H, L, D)).astype(dtype)
        mx.eval(q, k, v)

        def loss_fused(q, k, v):
            return mx.fast.scaled_dot_product_attention(
                q, k, v, scale=scale
            ).sum()

        def loss_unfused(q, k, v):
            s = (q * scale) @ k.swapaxes(-1, -2)
            s = mx.softmax(s, axis=-1, precise=True)
            return (s @ v).sum()

        grad_fused = mx.grad(loss_fused, argnums=(0, 1, 2))
        grad_unfused = mx.grad(loss_unfused, argnums=(0, 1, 2))

        # Measure fused memory
        mx.clear_cache()
        mx.eval(q, k, v)
        mx.reset_peak_memory()
        grads_f = grad_fused(q, k, v)
        mx.eval(grads_f)
        mem_fused = mx.get_peak_memory()

        # Measure unfused memory
        mx.clear_cache()
        mx.eval(q, k, v)
        mx.reset_peak_memory()
        grads_u = grad_unfused(q, k, v)
        mx.eval(grads_u)
        mem_unfused = mx.get_peak_memory()

        names = ("dQ", "dK", "dV")

        # Check no NaN
        for name, g_f in zip(names, grads_f):
            self.assertFalse(
                mx.any(mx.isnan(g_f)).item(),
                f"NaN in fused {name} at L={L}",
            )

        # Check correctness (fused matches unfused)
        atol = 1e-2  # tolerance chosen for the float16 default
        for name, g_f, g_u in zip(names, grads_f, grads_u):
            max_diff = mx.max(mx.abs(g_f - g_u)).item()
            self.assertTrue(
                mx.allclose(g_f, g_u, atol=atol, rtol=atol).item(),
                f"{name} mismatch at L={L}: max|diff|={max_diff:.2e}",
            )

        # Report memory. Size of one materialized B x H x L x L score matrix
        # in the requested dtype (not hard-coded to 2-byte elements).
        attn_matrix_bytes = B * H * L * L * dtype.size
        savings = 1.0 - mem_fused / mem_unfused if mem_unfused > 0 else 0.0
        print(
            f"\n  L={L}: fused={mem_fused/1e6:.1f}MB, unfused={mem_unfused/1e6:.1f}MB, "
            f"savings={100*savings:.1f}%, "
            f"theoretical_attn_matrix={attn_matrix_bytes/1e6:.1f}MB"
        )

        return mem_fused, mem_unfused

    def _assert_fused_uses_less_memory(self, L):
        """Assert the fused VJP peaks below the unfused reference at length ``L``."""
        mem_fused, mem_unfused = self._run_vjp_memory_test(L=L, H=4)
        self.assertLess(
            mem_fused, mem_unfused,
            f"Fused ({mem_fused/1e6:.1f}MB) should use less memory than "
            f"unfused ({mem_unfused/1e6:.1f}MB) at L={L}",
        )

    @pytest.mark.slow
    def test_long_sequence_L8192(self):
        """L=8192: one float16 score matrix is 512 MiB (B=1, H=4)."""
        self._assert_fused_uses_less_memory(8192)

    @pytest.mark.slow
    def test_long_sequence_L16384(self):
        """L=16384: one float16 score matrix is 2 GiB (B=1, H=4)."""
        self._assert_fused_uses_less_memory(16384)

    @pytest.mark.slow
    def test_fused_correctness_sweep(self):
        """Verify fused VJP correctness across multiple L values."""
        for L in [512, 1024, 2048, 4096]:
            with self.subTest(L=L):
                self._run_vjp_memory_test(L=L, H=4)


if __name__ == "__main__":
    mlx_tests.MLXTestRunner(failfast=True)