Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions mlx/backend/metal/kernels/scaled_dot_product_attention.metal
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,12 @@ using namespace metal;
instantiate_sdpa_vector(type, 64, 64) \
instantiate_sdpa_vector(type, 96, 96) \
instantiate_sdpa_vector(type, 128, 128) \
instantiate_sdpa_vector(type, 192, 192) \
instantiate_sdpa_vector(type, 256, 256) \
instantiate_sdpa_vector_aggregation(type, 64) \
instantiate_sdpa_vector_aggregation(type, 96) \
instantiate_sdpa_vector_aggregation(type, 128) \
instantiate_sdpa_vector_aggregation(type, 192) \
instantiate_sdpa_vector_aggregation(type, 256)

instantiate_sdpa_vector_heads(float)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
attention, dtype, bq, bk, bd, wm, wn, mtype, float)

#define instantiate_attn_shapes_helper(iname, itype, mname, mtype) \
instantiate_attn(iname, itype, 32, 16, 256, 4, 1, mname, mtype) \
instantiate_attn(iname, itype, 32, 16, 128, 4, 1, mname, mtype) \
instantiate_attn(iname, itype, 32, 32, 80, 4, 1, mname, mtype) \
instantiate_attn(iname, itype, 32, 32, 64, 4, 1, mname, mtype)
Expand Down
12 changes: 10 additions & 2 deletions mlx/backend/metal/scaled_dot_product_attention.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -618,9 +618,17 @@ bool ScaledDotProductAttention::use_fallback(
const bool sdpa_vector_supported_head_dim =
query_head_dim == value_head_dim &&
(query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128 ||
query_head_dim == 256);
query_head_dim == 192 || query_head_dim == 256);
// For head_dim >= 192, the fused full-attention kernel is slower than
// unfused for short sequences. Only route to fused when kL is large enough
// that the unfused path would exceed Metal buffer limits (the fused kernel
// tiles K/V so it scales to arbitrary sequence lengths).
const bool sdpa_full_large_hd_ok =
(query_head_dim == 192 || query_head_dim == 256) &&
key_sequence_length > 16384;
const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim &&
(query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128);
(query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128 ||
sdpa_full_large_hd_ok);

const bool sdpa_full_supported_mask = !has_mask || has_arr_mask ||
(query_sequence_length <= key_sequence_length && do_causal);
Expand Down
Loading