From 726c9a0ade5dd279dd3eed8785ac91a76256a42c Mon Sep 17 00:00:00 2001 From: Thump604 Date: Sat, 21 Mar 2026 23:27:48 -0500 Subject: [PATCH 1/3] fix: add head_dim=256 to fused SDPA full attention kernel MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit sdpa_full_supported_head_dim only included {64, 80, 128}. Models with head_dim=256 (Qwen3.5 family) fell back to the unfused naive attention path which materializes the full score matrix as a single matmul. At 32K+ context this creates 8+ GB single allocations that crash Metal's buffer allocator. Add head_dim=256 to the dispatch gate and instantiate steel_attention kernel with bd=256. The Metal kernel template handles arbitrary BD via template parameter — no kernel code changes needed. Verified: 32K, 64K, 128K context on M2 Ultra with Qwen3.5-122B-A10B. --- .../metal/kernels/steel/attn/kernels/steel_attention.metal | 1 + mlx/backend/metal/scaled_dot_product_attention.cpp | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal index 0ff9d91b00..4a67826951 100644 --- a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +++ b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal @@ -12,6 +12,7 @@ attention, dtype, bq, bk, bd, wm, wn, mtype, float) #define instantiate_attn_shapes_helper(iname, itype, mname, mtype) \ + instantiate_attn(iname, itype, 32, 16, 256, 4, 1, mname, mtype) \ instantiate_attn(iname, itype, 32, 16, 128, 4, 1, mname, mtype) \ instantiate_attn(iname, itype, 32, 32, 80, 4, 1, mname, mtype) \ instantiate_attn(iname, itype, 32, 32, 64, 4, 1, mname, mtype) diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index 37e554f183..e76e86d959 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ 
b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -620,7 +620,8 @@ bool ScaledDotProductAttention::use_fallback( (query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128 || query_head_dim == 256); const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim && - (query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128); + (query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128 || + query_head_dim == 256); const bool sdpa_full_supported_mask = !has_mask || has_arr_mask || (query_sequence_length <= key_sequence_length && do_causal); From 73974355f10d9da3d2536ace1fdd6c6d5c29fb9b Mon Sep 17 00:00:00 2001 From: Thump604 Date: Tue, 24 Mar 2026 17:26:29 -0500 Subject: [PATCH 2/3] perf: route head_dim=256 to unfused SDPA for short sequences MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The fused steel_attention kernel with bd=256 is ~30% slower than the unfused (matmul + softmax + matmul) path. Route head_dim=256 to unfused by default and only use the fused kernel when key_sequence_length > 16384, where unfused would exceed Metal buffer limits. Benchmark (M2 Ultra, H=64, qL=2048, float16): kL=16384: unfused 124ms vs fused 249ms (2.0x faster with routing) kL=32768: fused only (unfused crashes) Vector path (qL<=8, decode) is unaffected — already supports head_dim=256. 
--- mlx/backend/metal/scaled_dot_product_attention.cpp | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index e76e86d959..9d657b4305 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -619,9 +619,14 @@ bool ScaledDotProductAttention::use_fallback( query_head_dim == value_head_dim && (query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128 || query_head_dim == 256); + // For head_dim=256, the fused full-attention kernel is ~30% slower than + // unfused. Only route to fused when kL is large enough that unfused would + // exceed Metal buffer limits (the fused kernel tiles K/V so it scales). + const bool sdpa_full_256_ok = + query_head_dim == 256 && key_sequence_length > 16384; const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim && (query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128 || - query_head_dim == 256); + sdpa_full_256_ok); const bool sdpa_full_supported_mask = !has_mask || has_arr_mask || (query_sequence_length <= key_sequence_length && do_causal); From d08d2dab74882c0b2eb62cc99abedf698dfeb627 Mon Sep 17 00:00:00 2001 From: Thump604 Date: Tue, 24 Mar 2026 19:55:48 -0500 Subject: [PATCH 3/3] feat: add head_dim=192 to fused SDPA kernel support Same pattern as head_dim=256: unfused by default for short sequences, fused when kL > 16384 (where unfused would exceed Metal buffer limits). Adds vector kernel instantiations for decode path. Fixes #3312. 
---
 .../kernels/scaled_dot_product_attention.metal     |  2 ++
 .../metal/kernels/steel/attn/kernels/steel_attention.metal | 1 +
 mlx/backend/metal/scaled_dot_product_attention.cpp | 16 +++++++++++-------
 3 files changed, 12 insertions(+), 7 deletions(-)

diff --git a/mlx/backend/metal/kernels/scaled_dot_product_attention.metal b/mlx/backend/metal/kernels/scaled_dot_product_attention.metal
index c668d9d8c5..ca2fd358ef 100644
--- a/mlx/backend/metal/kernels/scaled_dot_product_attention.metal
+++ b/mlx/backend/metal/kernels/scaled_dot_product_attention.metal
@@ -32,10 +32,12 @@ using namespace metal;
   instantiate_sdpa_vector(type, 64, 64)          \
   instantiate_sdpa_vector(type, 96, 96)          \
   instantiate_sdpa_vector(type, 128, 128)        \
+  instantiate_sdpa_vector(type, 192, 192)        \
   instantiate_sdpa_vector(type, 256, 256)        \
   instantiate_sdpa_vector_aggregation(type, 64)  \
   instantiate_sdpa_vector_aggregation(type, 96)  \
   instantiate_sdpa_vector_aggregation(type, 128) \
+  instantiate_sdpa_vector_aggregation(type, 192) \
   instantiate_sdpa_vector_aggregation(type, 256)

 instantiate_sdpa_vector_heads(float)
diff --git a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal
index 4a67826951..b3f1c05d9e 100644
--- a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal
+++ b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal
@@ -12,6 +12,7 @@
     attention, dtype, bq, bk, bd, wm, wn, mtype, float)

 #define instantiate_attn_shapes_helper(iname, itype, mname, mtype) \
   instantiate_attn(iname, itype, 32, 16, 256, 4, 1, mname, mtype) \
+  instantiate_attn(iname, itype, 32, 16, 192, 4, 1, mname, mtype) \
   instantiate_attn(iname, itype, 32, 16, 128, 4, 1, mname, mtype) \
   instantiate_attn(iname, itype, 32, 32, 80, 4, 1, mname, mtype) \
diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp
index 9d657b4305..4e6ea61880 100644
--- a/mlx/backend/metal/scaled_dot_product_attention.cpp
+++ b/mlx/backend/metal/scaled_dot_product_attention.cpp
@@ -618,15 +618,17 @@ bool ScaledDotProductAttention::use_fallback(
   const bool sdpa_vector_supported_head_dim =
       query_head_dim == value_head_dim &&
       (query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128 ||
-       query_head_dim == 256);
-  // For head_dim=256, the fused full-attention kernel is ~30% slower than
-  // unfused. Only route to fused when kL is large enough that unfused would
-  // exceed Metal buffer limits (the fused kernel tiles K/V so it scales).
-  const bool sdpa_full_256_ok =
-      query_head_dim == 256 && key_sequence_length > 16384;
+       query_head_dim == 192 || query_head_dim == 256);
+  // For head_dim >= 192, the fused full-attention kernel is slower than
+  // unfused for short sequences. Only route to fused when kL is large enough
+  // that the unfused path would exceed Metal buffer limits (the fused kernel
+  // tiles K/V so it scales to arbitrary sequence lengths).
+  const bool sdpa_full_large_hd_ok =
+      (query_head_dim == 192 || query_head_dim == 256) &&
+      key_sequence_length > 16384;
   const bool sdpa_full_supported_head_dim =
       query_head_dim == value_head_dim &&
       (query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128 ||
-      sdpa_full_256_ok);
+      sdpa_full_large_hd_ok);
   const bool sdpa_full_supported_mask = !has_mask || has_arr_mask ||
       (query_sequence_length <= key_sequence_length && do_causal);