From cd40ad246aa8aad310e45b11fd4d21091340368a Mon Sep 17 00:00:00 2001 From: Tarjei Mandt Date: Sat, 28 Mar 2026 00:52:14 +1100 Subject: [PATCH] Fix gated delta kernel precision --- mlx_lm/models/gated_delta.py | 67 ++++++++++++++++++++---------------- 1 file changed, 37 insertions(+), 30 deletions(-) diff --git a/mlx_lm/models/gated_delta.py b/mlx_lm/models/gated_delta.py index af983f6d3..f4126acfb 100644 --- a/mlx_lm/models/gated_delta.py +++ b/mlx_lm/models/gated_delta.py @@ -59,37 +59,44 @@ def _make_gated_delta_kernel(has_mask=False, vectorized=False): {g_setup} auto beta_ = beta + b_idx * T * Hv; - for (int t = 0; t < T; ++t) {{ - if ({mask_source}) {{ - float kv_mem = 0.0f; - for (int i = 0; i < n_per_t; ++i) {{ - auto s_idx = n_per_t * dk_idx + i; - state[i] = state[i] * {g_access}; - kv_mem += state[i] * k_[s_idx]; - }} - kv_mem = simd_sum(kv_mem); - - auto delta = (v_[dv_idx] - kv_mem) * beta_[hv_idx]; - - float out = 0.0f; - for (int i = 0; i < n_per_t; ++i) {{ - auto s_idx = n_per_t * dk_idx + i; - state[i] = state[i] + k_[s_idx] * delta; - out += state[i] * q_[s_idx]; - }} - out = simd_sum(out); - if (thread_index_in_simdgroup == 0) {{ - y[dv_idx] = static_cast(out); - }} - }} - // Increment data pointers to next time step - q_ += Hk * Dk; - k_ += Hk * Dk; - v_ += Hv * Dv; - y += Hv * Dv; - {g_advance} - beta_ += Hv; + #define BODY() {{ \ + float kv_mem = 0.0f, kv_c = 0.0f; \ + for (int i = 0; i < n_per_t; ++i) {{ \ + auto s_idx = n_per_t * dk_idx + i; \ + state[i] *= {g_access}; \ + auto p = state[i] * k_[s_idx]; \ + auto a = p - kv_c; \ + auto b = kv_mem + a; \ + kv_c = (b - kv_mem) - a; \ + kv_mem = b; \ + }} \ + kv_mem = simd_sum(kv_mem); \ + auto delta = (v_[dv_idx] - kv_mem) * beta_[hv_idx]; \ + float out = 0.0f; \ + for (int i = 0; i < n_per_t; ++i) {{ \ + auto s_idx = n_per_t * dk_idx + i; \ + state[i] += k_[s_idx] * delta; \ + out += state[i] * q_[s_idx]; \ + }} \ + out = simd_sum(out); \ + if (thread_index_in_simdgroup == 0) \ + y[dv_idx] = static_cast(out); \ }} + #define ADV() q_ += Hk * Dk; k_ += Hk * Dk; v_ += Hv * Dv; y += Hv * Dv; {g_advance} beta_ += Hv; + + int t = 0; + for (; t + 3 < T; t += 4) {{ + if ({mask_source}) BODY() ADV() + if ({"mask[b_idx * T + t + 1]" if has_mask else "true"}) BODY() ADV() + if ({"mask[b_idx * T + t + 2]" if has_mask else "true"}) BODY() ADV() + if ({"mask[b_idx * T + t + 3]" if has_mask else "true"}) BODY() ADV() + }} + for (; t < T; ++t) {{ + if ({mask_source}) BODY() + ADV() + }} + #undef BODY + #undef ADV for (int i = 0; i < n_per_t; ++i) {{ auto s_idx = n_per_t * dk_idx + i; o_state[s_idx] = static_cast(state[i]);