head_dim=192 crashes at 128K context (unfused SDPA fallback)
Summary
SDPA with head_dim=192 crashes at 128K+ context with a Metal allocation error. This is similar to the head_dim=256 issue fixed in PR #3293, but head_dim=192 was not included in that PR's routing logic.
Environment
- Hardware: Mac Studio M3 Ultra (256GB)
- OS: macOS 25.3.0 (Darwin 25.3.0)
- MLX: 0.31.2.dev20260324+63b73e7a (main branch, latest)
Reproduction
```python
import mlx.core as mx

B, H, D = 1, 8, 192
context_len = 128 * 1024  # 128K tokens

q = mx.random.normal((B, H, context_len, D))
k = mx.random.normal((B, H, context_len, D))
v = mx.random.normal((B, H, context_len, D))

out = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0 / (D ** 0.5))
mx.eval(out)
```

Result:

```
[metal::malloc] Attempting to allocate 549755813888 bytes which is greater
than the maximum allowed buffer size of 179357827072 bytes.
```
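The failed allocation size is exactly what a materialized float32 score matrix of shape `(B, H, L, L)` would require. A quick back-of-the-envelope check (plain Python, no MLX needed):

```python
# Size of the attention score matrix the unfused SDPA path materializes.
B, H = 1, 8
L = 128 * 1024           # 128K tokens
bytes_per_float = 4      # float32 scores

score_bytes = B * H * L * L * bytes_per_float
print(score_bytes)              # 549755813888 -- exactly the failed allocation
print(score_bytes / 2**30)      # 512.0 GiB
```

Note that `head_dim` does not appear in this product: the score matrix is `L x L` per head, so the crash size is the same for any head dimension that takes the unfused path.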
Analysis
Working head dimensions (128K context):
- ✅ `head_dim=64` - works
- ✅ `head_dim=128` - works
- ❌ `head_dim=192` - crashes (this issue)
- ❌ `head_dim=256` - crashes (PR #3293, "fix: add head_dim=256 to fused SDPA full attention kernel", adds support)
Root Cause
`scaled_dot_product_attention.cpp` only includes 64, 128, and 256 in `sdpa_full_supported_head_dim`:

```cpp
inline bool sdpa_full_supported_head_dim(int query_head_dim) {
  return query_head_dim == 64 ||
         query_head_dim == 128 ||
         query_head_dim == 256; // Added in PR #3293
}
```

`head_dim=192` is NOT included, so it falls back to the unfused naive attention path, which materializes the full score matrix. At 128K context that is a ~512 GB allocation, exceeding Metal's maximum buffer size.
Proposed Solution
Add head_dim=192 to the supported dimensions and instantiate the corresponding kernel:
File: `mlx/ops.cpp` (or wherever `sdpa_full_supported_head_dim` is defined)

```cpp
inline bool sdpa_full_supported_head_dim(int query_head_dim) {
  return query_head_dim == 64 ||
         query_head_dim == 128 ||
         query_head_dim == 192 || // ADD THIS
         query_head_dim == 256;
}
```

File: `mlx/backend/metal/kernels/steel/attention/steel_attention.metal`

```cpp
// Add instantiation for head_dim=192
instantiate_attn(iname, itype, 32, 16, 192, 4, 1, mname, mtype)
```

This follows the same pattern as PR #3293's fix for head_dim=256.
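For a quick sanity check of the routing after the change, here is a Python mirror of the updated predicate (illustrative only, not MLX code):

```python
def sdpa_full_supported_head_dim(query_head_dim: int) -> bool:
    # Mirrors the proposed C++ predicate with 192 added.
    return query_head_dim in (64, 128, 192, 256)

print(sdpa_full_supported_head_dim(192))  # True: now routed to the fused kernel
print(sdpa_full_supported_head_dim(96))   # False: other dims still fall back
```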
Impact
Affected use cases:
- Models with 192-dimensional heads (less common than 128/256, but such models exist)
- Long-context inference (128K+ tokens)
- Users hitting this will see cryptic Metal allocation errors
Workaround:
None currently - models with head_dim=192 cannot use long contexts.
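To make the ceiling concrete: under the same B=1, H=8, float32 assumptions as the repro, the longest context the unfused path can serve before hitting the buffer limit is:

```python
import math

# Assumed shapes from the repro; float32 scores on the naive path.
B, H = 1, 8
METAL_MAX_BUFFER = 179_357_827_072  # bytes, from the error message

# Largest L such that B * H * L * L * 4 <= METAL_MAX_BUFFER
max_L = math.isqrt(METAL_MAX_BUFFER // (B * H * 4))
print(max_L)  # 74866 -- roughly 73K tokens, well short of 128K
```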
Related
- PR #3293 ("fix: add head_dim=256 to fused SDPA full attention kernel") - adds `head_dim=256` support; the same pattern is needed for 192
- Issue #3302 ("GPU watchdog kills process during long-context SDPA prefill (65K+ keys)") - GPU watchdog at long contexts; a different issue, fixed by chunking
Testing
Tested on M3 Ultra 256GB:
- 128K context with `head_dim=192`: ❌ crashes
- 128K context with `head_dim=64` or `128`: ✅ works
- 128K context with `head_dim=256`: ❌ crashes (expected; needs PR #3293 plus PR #3307, "Chunked full-attention SDPA for long key sequences")
Discovery Method
Found via automated edge-case testing across multiple head dimensions at long context lengths.
Note: This bug was discovered while validating PR #3293. The same approach used there for head_dim=256 should also cover head_dim=192.