metalKernel: same-named kernels with different weights produce wrong results in single eval graph #3263

@pHequals7

Bug Description

When multiple metalKernel() calls with the same kernel name but different weight buffers are evaluated in a single graph (e.g., 12 transformer layers all using a fused LN+GEMV kernel named fused_ln_gemv_768_768), buffer bindings appear to leak between invocations of the shared Metal pipeline state, producing deterministic but incorrect results.

This is potentially related to the family of stale-value issues in compiled execution reported in #3201.

Minimal Reproduction

Two chained attention layers using the same custom kernel name with different weights:

// Layer 1: kernel "fused_768" with weights W1
let out1 = metalKernel("fused_768")(x, W1)  // correct

// Layer 2: kernel "fused_768" with weights W2, consuming layer 1's output
let out2 = metalKernel("fused_768")(out1, W2)  // WRONG — uses W1's buffer bindings

Expected: Each invocation binds its own weight buffers via setBuffer().
Actual: The second invocation reads from the first invocation's weight buffer. Result is deterministic but incorrect (diff ≈ 370 for fp16 values near 1.0).

The issue reproduces both inside and outside compile(). It is NOT related to compile(shapeless=True) — it occurs in standard lazy evaluation.

Workaround

Give each call site a unique kernel name:

let out1 = metalKernel("fused_768_layer0")(x, W1)   // unique name → correct
let out2 = metalKernel("fused_768_layer1")(out1, W2) // unique name → correct

Analysis

The CustomKernel::eval_gpu function in backend/metal/custom_kernel.cpp looks up the Metal library and pipeline state by name_. When two dispatches share the same name_ and source_, they reuse the same pipeline state (correct) and the same library cache entry (correct). The buffer bindings should be set independently per dispatch via compute_encoder.set_input_array() / set_output_array().

Suspicion: the CommandEncoder's DispatchTypeConcurrent mode, combined with its barrier tracking (prev_outputs_, needs_barrier_), may mishandle the case where two dispatches of the same pipeline state write to different output buffers, but the second dispatch reads from an array whose underlying Metal buffer().ptr() matches the first dispatch's input (because copy_shared_buffer in Slice views lets multiple arrays share one Metal buffer).

However, forcing a memoryBarrier(MTL::BarrierScopeBuffers) before every dispatch does NOT fix the issue, which rules out simple barrier timing as the cause.

Environment

  • mlx commit 185b06d9 (vendored in mlx-swift)
  • macOS 15.5, Apple Silicon (M-series)
  • Metal with MTL::DispatchTypeConcurrent command encoder
  • Reproduces with both compile() and standard lazy eval
