
GPU watchdog kills process during long-context SDPA prefill (65K+ keys) #3302

@Thump604

Description


Summary

steel_attention (full attention, qL > 8) dispatches all keys in a single Metal compute command. At 65K+ key sequence length, this dispatch exceeds the macOS GPU watchdog threshold (~5 seconds), resulting in kIOGPUCommandBufferCallbackErrorImpactingInteractivity and process termination.

This is distinct from #3267 (display-active watchdog during training with short sequences). This issue is about the absolute GPU time of a single SDPA dispatch during inference prefill at long context.

Reproduction

Any model with long context prefill triggers this. Tested on M2 Ultra 128GB, macOS 26.3.1:

  • Qwen3.5-2B at 262K: GPU watchdog kill during prefill
  • Qwen3.5-35B at 262K with YaRN: GPU watchdog kill during prefill
  • Qwen3.5-122B at 128K+: borderline (single dispatch ~4-5s, close to threshold)

The vector SDPA path (qL ≤ 8, decode) is not affected — it already has 2-pass chunking.

Root Cause

sdpa_full_self_attention in scaled_dot_product_attention.cpp dispatches a single steel_attention kernel covering the entire key sequence. The dispatch grid is proportional to kL (number of key tokens). At 65K+ keys with typical head counts (8-64 GQA heads), the single dispatch runs for >5 seconds on Apple Silicon, triggering the macOS GPU watchdog.

The watchdog fires per-command-buffer, not per-dispatch. Multiple dispatches within the same command buffer don't individually reset the timer.
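As a rough sanity check that a single long-context dispatch plausibly exceeds the watchdog (illustrative numbers only, not measurements; head count and head dim are assumptions, not taken from any specific model):

```python
def attn_flops(qL, kL, n_heads=32, head_dim=128):
    """Back-of-envelope attention cost for one layer.

    Two matmuls (Q @ K^T and P @ V), each 2 FLOPs per multiply-add,
    over qL query tokens x kL key tokens x n_heads x head_dim.
    """
    return 2 * 2 * qL * kL * n_heads * head_dim

# One 2048-token prefill step against a 65K-entry KV cache:
# ~2.2e12 FLOPs per layer, before any memory traffic.
per_layer = attn_flops(2048, 65536)
```

Multiply by a few dozen layers, all encoded into the same command buffer, and the total sits in the tens of TFLOPs; at a few TFLOP/s of sustained attention throughput that is plausibly past the ~5 second watchdog threshold.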

What I Tried

  1. Mid-primitive command buffer commits: Inserting mx.eval() or explicit command buffer commits between SDPA chunks triggers a Metal assertion: "A command encoder is already encoding to this command buffer". MLX's eval model doesn't support mid-graph command buffer boundaries.

  2. AGX_RELAX_CDM_CTXSTORE_TIMEOUT=1: Works as a workaround (per [BUG] Metal GPU watchdog kills LoRA training when display is active #3267), but it's a system-wide env var that disables the watchdog entirely — not appropriate as a permanent solution.

  3. Adaptive prefill step size (application-level): Reducing prefill_step_size in the LLM framework so each generate() call processes fewer tokens. This creates separate mx.eval() calls and thus separate command buffers. Partially works, but adds significant overhead and doesn't help when a single SDPA layer's key length exceeds 65K (which happens during later prefill steps as KV cache accumulates).
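The application-level workaround in (3) can be sketched as follows. `forward` and `eval_fn` are stand-ins for a hypothetical model call and mx.eval (the exact framework API varies), so the structure is the point here, not the names:

```python
def chunked_prefill(forward, eval_fn, tokens, step_size=2048):
    """Process the prompt in steps so each forward/eval pair lands in
    its own Metal command buffer, resetting the watchdog between steps."""
    out = None
    for start in range(0, len(tokens), step_size):
        out = forward(tokens[start:start + step_size])
        eval_fn(out)  # force evaluation here -> command buffer boundary
    return out
```

As noted above, this only bounds the query length per step; once the accumulated KV cache itself exceeds ~65K keys, the per-layer SDPA dispatch inside a single step is still too long.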

Proposed Solution: Chunked Full Attention

Split the key sequence into chunks (e.g., 65K tokens each), dispatch steel_attention per chunk with partial outputs, then reduce via online softmax. This mirrors the existing sdpa_vector_2pass pattern already used for decode.

Design:

  • Each chunk dispatches steel_attention with write_partial=true → outputs unnormalized O + partial max/sum
  • Reduction merges chunk results: O_final = Σ(O_chunk × exp2(max_chunk - max_global)) / Σ(sum_chunk × exp2(max_chunk - max_global))
  • Routes automatically when kL ≥ threshold
  • Single-pass path unchanged for short sequences
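A minimal NumPy sketch of the two-pass merge, for a single query vector with base-e exp instead of the kernel's exp2 and the softmax scale omitted (helper names are mine, not the kernel API):

```python
import numpy as np

def chunk_partial(q, k, v):
    """One chunk's partial result: unnormalized output + running max/sum.
    Shapes (illustrative): q (d,), k (n, d), v (n, d)."""
    s = k @ q                    # raw scores for this chunk
    m = s.max()                  # chunk-local max
    p = np.exp(s - m)            # numerically stable exponentials
    return p @ v, m, p.sum()     # unnormalized O, partial max, partial sum

def merge(partials):
    """Online-softmax reduction across chunk partials."""
    m_global = max(m for _, m, _ in partials)
    o = sum(o_c * np.exp(m_c - m_global) for o_c, m_c, _ in partials)
    l = sum(l_c * np.exp(m_c - m_global) for _, m_c, l_c in partials)
    return o / l

# Chunked result matches a single full-softmax pass.
rng = np.random.default_rng(0)
q = rng.standard_normal(4)
k = rng.standard_normal((16, 4))
v = rng.standard_normal((16, 4))
o_chunked = merge([chunk_partial(q, k[i:i+4], v[i:i+4]) for i in range(0, 16, 4)])
w = np.exp(k @ q - (k @ q).max())
o_full = (w / w.sum()) @ v
assert np.allclose(o_chunked, o_full)
```

This is the same algebra the existing sdpa_vector_2pass path relies on, just applied per key-chunk of the full-attention kernel instead of per decode partition.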

I have a working implementation on a local branch (5 commits, ~575 lines across 3 files). Verified at 512K context with Qwen3.5-2B. Happy to open a PR if this approach aligns with the project's direction.

Key design question: Should this be part of the existing steel_attention infrastructure, or a separate kernel? The existing sdpa_vector_2pass sets a precedent for chunking within the SDPA dispatch.

