Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 37 additions & 30 deletions mlx_lm/models/gated_delta.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,37 +59,44 @@ def _make_gated_delta_kernel(has_mask=False, vectorized=False):
{g_setup}
auto beta_ = beta + b_idx * T * Hv;

for (int t = 0; t < T; ++t) {{
if ({mask_source}) {{
float kv_mem = 0.0f;
for (int i = 0; i < n_per_t; ++i) {{
auto s_idx = n_per_t * dk_idx + i;
state[i] = state[i] * {g_access};
kv_mem += state[i] * k_[s_idx];
}}
kv_mem = simd_sum(kv_mem);

auto delta = (v_[dv_idx] - kv_mem) * beta_[hv_idx];

float out = 0.0f;
for (int i = 0; i < n_per_t; ++i) {{
auto s_idx = n_per_t * dk_idx + i;
state[i] = state[i] + k_[s_idx] * delta;
out += state[i] * q_[s_idx];
}}
out = simd_sum(out);
if (thread_index_in_simdgroup == 0) {{
y[dv_idx] = static_cast<InT>(out);
}}
}}
// Increment data pointers to next time step
q_ += Hk * Dk;
k_ += Hk * Dk;
v_ += Hv * Dv;
y += Hv * Dv;
{g_advance}
beta_ += Hv;
#define BODY() {{ \
float kv_mem = 0.0f, kv_c = 0.0f; \
for (int i = 0; i < n_per_t; ++i) {{ \
auto s_idx = n_per_t * dk_idx + i; \
state[i] *= {g_access}; \
auto p = state[i] * k_[s_idx]; \
auto a = p - kv_c; \
auto b = kv_mem + a; \
kv_c = (b - kv_mem) - a; \
kv_mem = b; \
}} \
kv_mem = simd_sum(kv_mem); \
auto delta = (v_[dv_idx] - kv_mem) * beta_[hv_idx]; \
float out = 0.0f; \
for (int i = 0; i < n_per_t; ++i) {{ \
auto s_idx = n_per_t * dk_idx + i; \
state[i] += k_[s_idx] * delta; \
out += state[i] * q_[s_idx]; \
}} \
out = simd_sum(out); \
if (thread_index_in_simdgroup == 0) \
y[dv_idx] = static_cast<InT>(out); \
}}
#define ADV() q_ += Hk * Dk; k_ += Hk * Dk; v_ += Hv * Dv; y += Hv * Dv; {g_advance} beta_ += Hv;

int t = 0;
for (; t + 3 < T; t += 4) {{
if ({mask_source}) BODY() ADV()
if ({"mask[b_idx * T + t + 1]" if has_mask else "true"}) BODY() ADV()
if ({"mask[b_idx * T + t + 2]" if has_mask else "true"}) BODY() ADV()
if ({"mask[b_idx * T + t + 3]" if has_mask else "true"}) BODY() ADV()
}}
for (; t < T; ++t) {{
if ({mask_source}) BODY()
ADV()
}}
#undef BODY
#undef ADV
for (int i = 0; i < n_per_t; ++i) {{
auto s_idx = n_per_t * dk_idx + i;
o_state[s_idx] = static_cast<StT>(state[i]);
Expand Down