Bug Description
When multiple metalKernel() calls with the same kernel name but different weight buffers are evaluated in a single graph (e.g., 12 transformer layers all using a fused LN+GEMV kernel named fused_ln_gemv_768_768), the Metal pipeline state appears to confuse buffer bindings between invocations, producing deterministic but incorrect results.
This is potentially related to the family of stale-value issues in compiled execution reported in #3201.
Minimal Reproduction
Two chained attention layers using the same custom kernel name with different weights:
```swift
// Layer 1: kernel "fused_768" with weights W1
let out1 = metalKernel("fused_768")(x, W1)    // correct

// Layer 2: kernel "fused_768" with weights W2, consuming layer 1's output
let out2 = metalKernel("fused_768")(out1, W2) // WRONG — uses W1's buffer bindings
```
Expected: Each invocation binds its own weight buffers via setBuffer().
Actual: The second invocation reads from the first invocation's weight buffer. The result is deterministic but incorrect (diff ≈ 370 for fp16 values near 1.0).
The issue reproduces both inside and outside compile(). It is NOT related to compile(shapeless=True) — it occurs in standard lazy evaluation.
Workaround
Give each call site a unique kernel name:
```swift
let out1 = metalKernel("fused_768_layer0")(x, W1)    // unique name → correct
let out2 = metalKernel("fused_768_layer1")(out1, W2) // unique name → correct
```
Analysis
The CustomKernel::eval_gpu function in backend/metal/custom_kernel.cpp looks up the Metal library and pipeline state by name_. When two dispatches share the same name_ and source_, they reuse the same pipeline state (correct) and the same library cache entry (correct). The buffer bindings should be set independently per dispatch via compute_encoder.set_input_array() / set_output_array().
Suspicion: the CommandEncoder's DispatchTypeConcurrent mode combined with the barrier tracking (prev_outputs_, needs_barrier_) may incorrectly handle the case where two dispatches of the same pipeline state write to different output buffers but the second reads from a buffer that shares the same Metal buffer().ptr() as the first's input (due to copy_shared_buffer in Slice views).
However, forcing a memoryBarrier(MTL::BarrierScopeBuffers) before every dispatch does NOT fix the issue, ruling out simple barrier timing.
Environment
- mlx commit 185b06d9 (vendored in mlx-swift)
- macOS 15.5, Apple Silicon (M-series)
- Metal with MTL::DispatchTypeConcurrent command encoder
- Reproduces with both compile() and standard lazy eval