Summary
steel_attention (full attention, qL > 8) dispatches all keys in a single Metal compute command. At 65K+ key sequence length, this dispatch exceeds the macOS GPU watchdog threshold (~5 seconds), resulting in kIOGPUCommandBufferCallbackErrorImpactingInteractivity and process termination.
This is distinct from #3267 (display-active watchdog during training with short sequences). This issue is about the absolute GPU time of a single SDPA dispatch during inference prefill at long context.
Reproduction
Any model with long context prefill triggers this. Tested on M2 Ultra 128GB, macOS 26.3.1:
- Qwen3.5-2B at 262K: GPU watchdog kill during prefill
- Qwen3.5-35B at 262K with YaRN: GPU watchdog kill during prefill
- Qwen3.5-122B at 128K+: borderline (single dispatch ~4-5s, close to threshold)
The vector SDPA path (qL ≤ 8, decode) is not affected — it already has 2-pass chunking.
Root Cause
sdpa_full_self_attention in scaled_dot_product_attention.cpp dispatches a single steel_attention kernel covering the entire key sequence. The dispatch grid is proportional to kL (number of key tokens). At 65K+ keys with typical head counts (8-64 GQA heads), the single dispatch runs for >5 seconds on Apple Silicon, triggering the macOS GPU watchdog.
The watchdog fires per-command-buffer, not per-dispatch. Multiple dispatches within the same command buffer don't individually reset the timer.
What I Tried
- Mid-primitive command buffer commits: Inserting `mx.eval()` or command buffer commits between SDPA chunks. Triggers the Metal assertion "A command encoder is already encoding to this command buffer". MLX's eval model doesn't support mid-graph command buffer boundaries.
- `AGX_RELAX_CDM_CTXSTORE_TIMEOUT=1`: Works as a workaround (per [BUG] Metal GPU watchdog kills LoRA training when display is active #3267), but it's a system-wide env var that disables the watchdog entirely, so it's not appropriate as a permanent solution.
- Adaptive prefill step size (application-level): Reducing `prefill_step_size` in the LLM framework so each `generate()` call processes fewer tokens. This creates separate `mx.eval()` calls and thus separate command buffers. Partially works, but adds significant overhead and doesn't help once a single SDPA layer's key length exceeds 65K (which happens during later prefill steps as the KV cache accumulates).
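To make the third workaround concrete, here is a minimal pure-Python sketch of splitting a prompt into prefill steps. The function name `prefill_steps` is hypothetical; in mlx-lm-style code each returned range would be forwarded through the model and evaluated separately, so each step lands in its own Metal command buffer:

```python
def prefill_steps(n_tokens, step):
    """Split a prompt of n_tokens into [start, end) ranges of at most
    `step` tokens. Hypothetical helper: each range would be processed
    and mx.eval()'d on its own, giving it a separate command buffer.
    Note this does NOT bound the key length seen by a single SDPA
    dispatch once the KV cache has grown past the watchdog limit."""
    return [(s, min(s + step, n_tokens)) for s in range(0, n_tokens, step)]
```
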
Proposed Solution: Chunked Full Attention
Split the key sequence into chunks (e.g., 65K tokens each), dispatch steel_attention per chunk with partial outputs, then reduce via online softmax. This mirrors the existing sdpa_vector_2pass pattern already used for decode.
Design:
- Each chunk dispatches `steel_attention` with `write_partial=true`, producing an unnormalized `O` plus a partial `max`/`sum`
- Reduction merges chunk results: `O_final = Σ(O_chunk × exp2(max_chunk - max_global)) / sum_global`, where `sum_global = Σ(sum_chunk × exp2(max_chunk - max_global))`
- Routes automatically when `kL ≥ threshold`
- Single-pass path unchanged for short sequences
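For reference, a minimal NumPy sketch of the chunked reduction described above. This is not the Metal kernel; it uses `exp` rather than `exp2`, unbatched 2-D shapes, and a running (online) merge instead of a separate reduction dispatch, but the math is the same partial-max/partial-sum rescaling:

```python
import numpy as np

def chunked_softmax_attention(q, k, v, chunk=4):
    """Sketch of chunked full attention with an online-softmax merge.
    Each chunk yields an unnormalized partial output o_c plus its row
    max m_c and exp-sum s_c; partials are rescaled to the running max
    and accumulated, then normalized once at the end."""
    m_g = np.full(q.shape[0], -np.inf)        # running row max
    s_g = np.zeros(q.shape[0])                # running exp-sum
    o_g = np.zeros((q.shape[0], v.shape[1]))  # running unnormalized output
    scale = 1.0 / np.sqrt(q.shape[1])
    for start in range(0, k.shape[0], chunk):
        kc, vc = k[start:start + chunk], v[start:start + chunk]
        logits = (q @ kc.T) * scale           # (qL, chunk)
        m_c = logits.max(axis=1)
        p = np.exp(logits - m_c[:, None])     # chunk-local numerator
        s_c = p.sum(axis=1)
        o_c = p @ vc                          # unnormalized partial output
        m_new = np.maximum(m_g, m_c)
        # rescale running and chunk partials to the shared max, then add
        o_g = (o_g * np.exp(m_g - m_new)[:, None]
               + o_c * np.exp(m_c - m_new)[:, None])
        s_g = s_g * np.exp(m_g - m_new) + s_c * np.exp(m_c - m_new)
        m_g = m_new
    return o_g / s_g[:, None]
```

The result is bit-for-bit the standard softmax attention up to floating-point reordering, which is why the chunk count (and hence the per-dispatch GPU time) can be chosen freely without affecting correctness.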
I have a working implementation on a local branch (5 commits, ~575 lines across 3 files). Verified at 512K context with Qwen3.5-2B. Happy to open a PR if this approach aligns with the project's direction.
Key design question: Should this be part of the existing steel_attention infrastructure, or a separate kernel? The existing sdpa_vector_2pass sets a precedent for chunking within the SDPA dispatch.
Related
- [BUG] Metal GPU watchdog kills LoRA training when display is active #3267 — GPU watchdog during training (display-active, different trigger)
- fix: add head_dim=256 to fused SDPA full attention kernel #3293 — `head_dim=256` fused SDPA (fixes the crash at 32K, but 65K+ still hits the watchdog)
- SIGSEGV in QuantizedMatmul::eval_gpu during long token generation on Mac Studio M2 Ultra #3216 — thread-safety crashes (fixed by Make each thread have its own default stream #3281)