
head_dim=192 crashes at 128K context (unfused SDPA fallback) #3312

@hnshah

Description


Summary

SDPA with head_dim=192 crashes at 128K+ context with a Metal allocation error. This is similar to the head_dim=256 issue fixed in PR #3293, but head_dim=192 was not included in that PR's routing logic.

Environment

  • Hardware: Mac Studio M3 Ultra (256GB)
  • OS: macOS 25.3.0 (Darwin 25.3.0)
  • MLX: 0.31.2.dev20260324+63b73e7a (main branch, latest)

Reproduction

import mlx.core as mx

B, H, D = 1, 8, 192
context_len = 128 * 1024  # 128K tokens

q = mx.random.normal((B, H, context_len, D))
k = mx.random.normal((B, H, context_len, D))
v = mx.random.normal((B, H, context_len, D))

out = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0 / (D ** 0.5))
mx.eval(out)

Result:

[metal::malloc] Attempting to allocate 549755813888 bytes which is greater 
than the maximum allowed buffer size of 179357827072 bytes.

Analysis

Working head dimensions (128K context): 64, 128, and 256 all complete successfully via the fused kernel; only 192 falls back and crashes.

Root Cause

scaled_dot_product_attention.cpp only includes 64, 128, and 256 in sdpa_full_supported_head_dim:

inline bool sdpa_full_supported_head_dim(int query_head_dim) {
  return query_head_dim == 64 || 
         query_head_dim == 128 || 
         query_head_dim == 256;  // Added in PR #3293
}

head_dim=192 is NOT included, so it falls back to the unfused naive attention path, which materializes the full L×L score matrix. At 128K context with 8 heads in float32, that is a single 512 GiB allocation (549,755,813,888 bytes), far beyond this machine's ~179 GB Metal buffer limit.
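The reported failure size matches the float32 score matrix exactly; a quick plain-Python sanity check (shapes taken from the reproduction above, float32 assumed):

```python
# Unfused SDPA materializes scores = q @ k^T with shape (B, H, L, L).
# Using the reproduction's shapes and float32 (4 bytes/element):
B, H = 1, 8
L = 128 * 1024        # 128K context
bytes_per_elem = 4    # float32

score_bytes = B * H * L * L * bytes_per_elem
print(score_bytes)           # 549755813888 -- matches the Metal error byte-for-byte
print(score_bytes / 2**30)   # 512.0 (GiB)
```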

Proposed Solution

Add head_dim=192 to the supported dimensions and instantiate the corresponding kernel:

File: scaled_dot_product_attention.cpp (where sdpa_full_supported_head_dim is defined, per the root cause above)

inline bool sdpa_full_supported_head_dim(int query_head_dim) {
  return query_head_dim == 64 || 
         query_head_dim == 128 || 
         query_head_dim == 192 ||  // ADD THIS
         query_head_dim == 256;
}

File: mlx/backend/metal/kernels/steel/attention/steel_attention.metal

// Add instantiation for head_dim=192
instantiate_attn(iname, itype, 32, 16, 192, 4, 1, mname, mtype)

This follows the same pattern as PR #3293's fix for head_dim=256.

Impact

Affected use cases:

  • Models with 192-dimensional heads (less common than 128 or 256, but they exist)
  • Long-context inference (128K+ tokens)
  • Users hitting this will see cryptic Metal allocation errors

Workaround:
None currently within mx.fast.scaled_dot_product_attention; models with head_dim=192 cannot use long contexts through the fast path.
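For completeness, one manual stopgap outside the fast path (not proposed in this issue, and untested at 128K scale on-device) is query chunking: softmax is independent per query row, so attention can be computed in (B, H, q_chunk, L) score slices rather than one (B, H, L, L) matrix. Sketched with NumPy for illustration; the same shape logic would apply with mx ops:

```python
import numpy as np

def chunked_attention(q, k, v, scale, q_chunk=1024):
    """Query-chunked attention sketch: each score slice has shape
    (B, H, q_chunk, L) instead of (B, H, L, L), so peak memory is
    bounded by the chunk size rather than the full context length."""
    outs = []
    L = q.shape[2]
    for start in range(0, L, q_chunk):
        qc = q[:, :, start:start + q_chunk, :]
        # scores for this query chunk only
        scores = (qc * scale) @ np.swapaxes(k, -1, -2)
        scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
        weights = np.exp(scores)
        weights /= weights.sum(axis=-1, keepdims=True)
        outs.append(weights @ v)
    return np.concatenate(outs, axis=2)
```

Because softmax normalization only runs over the key axis, the chunked result is exact (up to floating-point order), not an approximation.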

Related

  • PR #3293: fixed the analogous crash for head_dim=256

Testing

Tested on M3 Ultra 256GB: the reproduction above crashes for head_dim=192, while head dimensions 64, 128, and 256 complete at the same 128K context.

Discovery Method

Found via automated edge-case testing across multiple head dimensions at long context lengths.
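The sweep outcome can also be predicted analytically: whether a given head_dim survives depends only on whether it routes to the fused kernel, since the unfused fallback's score matrix is head_dim-independent. A minimal predictor sketch in plain Python (FUSED_DIMS and the buffer limit are taken from this report, not queried from MLX):

```python
MAX_METAL_BUFFER = 179_357_827_072  # max buffer size from the error message (machine-specific)
FUSED_DIMS = {64, 128, 256}         # current sdpa_full_supported_head_dim (pre-fix)

def sdpa_fits(B, H, L, D):
    """Predict whether SDPA fits in one Metal buffer: the fused kernel
    never builds the L x L score matrix; the unfused fallback does, in float32."""
    if D in FUSED_DIMS:
        return True
    return B * H * L * L * 4 <= MAX_METAL_BUFFER

for D in (64, 128, 192, 256):
    print(D, sdpa_fits(1, 8, 128 * 1024, D))
# 192 is the only tested dim predicted to fail at 128K on this machine
```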


Note: This bug was discovered while validating PR #3293. The same routing-plus-kernel-instantiation approach used there for head_dim=256 should extend directly to head_dim=192.
