300 changes: 299 additions & 1 deletion convert_hf_to_gguf.py

Large diffs are not rendered by default.

23 changes: 23 additions & 0 deletions ggml/include/ggml.h
@@ -539,6 +539,7 @@ extern "C" {
GGML_OP_FLASH_ATTN_BACK,
GGML_OP_SSM_CONV,
GGML_OP_SSM_SCAN,
GGML_OP_KDA_SCAN,
GGML_OP_WIN_PART,
GGML_OP_WIN_UNPART,
GGML_OP_GET_REL_POS,
@@ -2336,6 +2337,28 @@ extern "C" {
struct ggml_tensor * C,
struct ggml_tensor * ids);

// KDA (Kimi Delta Attention) scan
// Delta attention recurrence:
// h[t] = exp(g[t]) * h[t-1] + k[t]^T * (beta[t] * (v[t] - h[t-1] @ k[t]))
// o[t] = q[t]^T @ h[t]
// Parameters:
// h: hidden state {head_dim, head_dim, n_head, n_seqs+}
// q: query {head_dim, n_head, n_seq_tokens, n_seqs}
// k: key {head_dim, n_head, n_seq_tokens, n_seqs}
// v: value {head_dim, n_head, n_seq_tokens, n_seqs}
// g: gate {head_dim, n_head, n_seq_tokens, n_seqs}
// beta: per-head mixing coefficients (delta-rule write strength) {n_head, n_seq_tokens, n_seqs}
// ids: seq indices {n_seqs}
GGML_API struct ggml_tensor * ggml_kda_scan(
struct ggml_context * ctx,
struct ggml_tensor * h,
struct ggml_tensor * q,
struct ggml_tensor * k,
struct ggml_tensor * v,
struct ggml_tensor * g,
struct ggml_tensor * beta,
struct ggml_tensor * ids);
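
For reference, a minimal sketch of how a caller might construct the operands and invoke the new op, following the shapes documented in the comment above (the wrapper name and concrete dimensions are illustrative, not part of this PR):

```c
// Illustrative only: build KDA scan operands with the documented shapes
// and call the op; assumes an already-initialized ggml_context.
static struct ggml_tensor * build_kda_scan(struct ggml_context * ctx) {
    const int64_t head_dim = 128, n_head = 32, n_seq_tokens = 16, n_seqs = 1;

    struct ggml_tensor * h    = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_dim, head_dim, n_head, n_seqs);
    struct ggml_tensor * q    = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_dim, n_head, n_seq_tokens, n_seqs);
    struct ggml_tensor * k    = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_dim, n_head, n_seq_tokens, n_seqs);
    struct ggml_tensor * v    = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_dim, n_head, n_seq_tokens, n_seqs);
    struct ggml_tensor * g    = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_dim, n_head, n_seq_tokens, n_seqs);
    struct ggml_tensor * beta = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_head, n_seq_tokens, n_seqs);
    struct ggml_tensor * ids  = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_seqs);

    // the result packs the per-token outputs first, then the updated hidden state
    return ggml_kda_scan(ctx, h, q, k, v, g, beta, ids);
}
```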

// partition into non-overlapping windows with padding if needed
// example:
// a: 768 64 64 1
5 changes: 5 additions & 0 deletions ggml/src/ggml-cpu/ggml-cpu.c
@@ -1962,6 +1962,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
{
ggml_compute_forward_ssm_scan(params, tensor);
} break;
case GGML_OP_KDA_SCAN:
{
ggml_compute_forward_kda_scan(params, tensor);
} break;
case GGML_OP_WIN_PART:
{
ggml_compute_forward_win_part(params, tensor);
@@ -2320,6 +2324,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_OP_FLASH_ATTN_BACK:
case GGML_OP_SSM_CONV:
case GGML_OP_SSM_SCAN:
case GGML_OP_KDA_SCAN:
case GGML_OP_RWKV_WKV6:
case GGML_OP_GATED_LINEAR_ATTN:
case GGML_OP_RWKV_WKV7:
196 changes: 196 additions & 0 deletions ggml/src/ggml-cpu/ops.cpp
@@ -8627,6 +8627,9 @@ static void ggml_compute_forward_ssm_conv_f32(
const int ir1 = MIN(ir0 + dr, nr);
const int ir = ir1 - ir0;

for (int i3 = 0; i3 < n_s; ++i3) {
for (int i2 = 0; i2 < n_t; ++i2) {
// {d_conv - 1 + n_t, d_inner, n_seqs}
@@ -8647,6 +8650,13 @@
sumf += s[i0 + i1*ncs] * c[i0 + i1*nc];
}
x[i1] = sumf;
}
}
}
@@ -8897,6 +8907,192 @@ void ggml_compute_forward_ssm_scan(
}
}

// ggml_compute_forward_kda_scan
// KDA (Kimi Delta Attention) recurrence:
// h[t] = exp(g[t]) * h[t-1] + k[t]^T * (beta[t] * (v[t] - h[t-1] @ k[t]))
// o[t] = q[t]^T @ h[t]

static void ggml_compute_forward_kda_scan_f32(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0]; // h {head_dim, head_dim, n_head, n_seqs+}
const ggml_tensor * src1 = dst->src[1]; // q {head_dim, n_head, n_seq_tokens, n_seqs}
const ggml_tensor * src2 = dst->src[2]; // k {head_dim, n_head, n_seq_tokens, n_seqs}
const ggml_tensor * src3 = dst->src[3]; // v {head_dim, n_head, n_seq_tokens, n_seqs}
const ggml_tensor * src4 = dst->src[4]; // g {head_dim, n_head, n_seq_tokens, n_seqs}
const ggml_tensor * src5 = dst->src[5]; // beta {n_head, n_seq_tokens, n_seqs}
const ggml_tensor * src6 = dst->src[6]; // ids {n_seqs}

const int ith = params->ith;
const int nth = params->nth;

const int64_t head_dim = src0->ne[0];
const int64_t n_head = src1->ne[1];
const int64_t n_seq_tokens = src1->ne[2];
const int64_t n_seqs = src1->ne[3];

// dst stores the per-token outputs y first, then the updated hidden states at this byte offset
const int64_t y_off = ggml_nelements(src1) * sizeof(float);

GGML_ASSERT(src0->nb[0] == sizeof(float));
GGML_ASSERT(src1->nb[0] == sizeof(float));
GGML_ASSERT(src2->nb[0] == sizeof(float));
GGML_ASSERT(src3->nb[0] == sizeof(float));
GGML_ASSERT(src4->nb[0] == sizeof(float));
GGML_ASSERT(src5->nb[0] == sizeof(float));
GGML_ASSERT(src6->nb[0] == sizeof(int32_t));

// Parallelize over heads
const int dh = (n_head + nth - 1) / nth;
const int ih0 = dh * ith;
const int ih1 = MIN(ih0 + dh, (int)n_head);

const int32_t * ids = (const int32_t *) src6->data;

// Temporary buffer for the h^T @ k intermediate
GGML_ASSERT(head_dim <= 128); // the fixed-size q/k buffers below assume this
float * hk_buf = (float *) malloc(head_dim * sizeof(float));
GGML_ASSERT(hk_buf != NULL);

for (int i3 = 0; i3 < n_seqs; ++i3) {
// Get initial hidden state for this sequence
const float * h0 = (const float *) ((const char *) src0->data + ids[i3] * src0->nb[3]);
// Output hidden state location
float * h_out = (float *) ((char *) dst->data + i3 * src0->nb[3] + y_off);

for (int ih = ih0; ih < ih1; ++ih) {
// Per-head hidden state: [head_dim, head_dim]
// Copy initial state to output (will be updated in place)
const float * h_in = h0 + ih * head_dim * head_dim;
float * h = h_out + ih * head_dim * head_dim;

// Copy initial state, but check for invalid values and clear if needed
bool need_clear = false;
for (int i = 0; i < head_dim * head_dim && !need_clear; ++i) {
if (!isfinite(h_in[i]) || fabsf(h_in[i]) > 1e6f) {
need_clear = true;
}
}
for (int i = 0; i < head_dim * head_dim; ++i) {
h[i] = need_clear ? 0.0f : h_in[i];
}

for (int it = 0; it < n_seq_tokens; ++it) {
const float * q_raw = (const float *) ((const char *) src1->data +
it * src1->nb[2] + i3 * src1->nb[3]) + ih * head_dim;
const float * k_raw = (const float *) ((const char *) src2->data +
it * src2->nb[2] + i3 * src2->nb[3]) + ih * head_dim;
const float * v = (const float *) ((const char *) src3->data +
it * src3->nb[2] + i3 * src3->nb[3]) + ih * head_dim;
const float * g = (const float *) ((const char *) src4->data +
it * src4->nb[2] + i3 * src4->nb[3]) + ih * head_dim;
const float beta = ((const float *) ((const char *) src5->data +
it * src5->nb[1] + i3 * src5->nb[2]))[ih];

float * y = (float *) dst->data +
it * n_head * head_dim + i3 * n_seq_tokens * n_head * head_dim + ih * head_dim;

// L2 normalize q and k (critical for KDA stability)
float q_norm = 0.0f, k_norm = 0.0f;
for (int i = 0; i < head_dim; ++i) {
q_norm += q_raw[i] * q_raw[i];
k_norm += k_raw[i] * k_raw[i];
}
q_norm = sqrtf(q_norm + 1e-6f);
k_norm = sqrtf(k_norm + 1e-6f);

// Normalized q and k with scale = 1/sqrt(head_dim)
// Note: scale is applied only to q after L2 normalization
const float scale = 1.0f / sqrtf((float)head_dim);
float q[128], k[128]; // head_dim <= 128 is asserted above
for (int i = 0; i < head_dim; ++i) {
// L2 normalize then scale q
q[i] = (q_raw[i] / q_norm) * scale;
// L2 normalize k (no scale)
k[i] = k_raw[i] / k_norm;
}

// KDA recurrence: h[t] = exp(g[t]) * h[t-1] + k[t]^T * (beta[t] * (v[t] - h[t-1] @ k[t]))
// Note: Apply decay first, then compute retrieval and update

// Step 1: Apply decay to h first: h = h * exp(g)
for (int i = 0; i < head_dim; ++i) {
const float exp_gi = expf(g[i]);
for (int j = 0; j < head_dim; ++j) {
h[i * head_dim + j] *= exp_gi;
}
}

// Step 2: Compute h^T @ k -> hk_buf [head_dim]
// hk_buf[j] = sum_i (h[i,j] * k[i]) which is column j of h dotted with k
for (int j = 0; j < head_dim; ++j) {
float sum = 0.0f;
for (int i = 0; i < head_dim; ++i) {
sum += h[i * head_dim + j] * k[i];
}
hk_buf[j] = sum;
}

// Step 3: Compute delta = beta * (v - hk) and update h
// h = h + outer(k, delta) where outer(k, delta)[i,j] = k[i] * delta[j]
for (int j = 0; j < head_dim; ++j) {
hk_buf[j] = beta * (v[j] - hk_buf[j]); // reuse hk_buf to hold delta
}
for (int i = 0; i < head_dim; ++i) {
for (int j = 0; j < head_dim; ++j) {
h[i * head_dim + j] += k[i] * hk_buf[j];
}
}

// Step 4: Compute output y = h^T @ q -> [head_dim]
// vLLM: b_o = tl.sum(b_h * b_q[:, None], 0) means o[j] = sum_i(h[i,j] * q[i])
for (int j = 0; j < head_dim; ++j) {
float sum = 0.0f;
for (int i = 0; i < head_dim; ++i) {
sum += h[i * head_dim + j] * q[i];
}
y[j] = sum;
}

}
}
}

free(hk_buf);
}

void ggml_compute_forward_kda_scan(
const ggml_compute_params * params,
ggml_tensor * dst) {
switch (dst->src[0]->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_kda_scan_f32(params, dst);
} break;
default:
{
GGML_ABORT("fatal error");
}
}
}
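
To make the state-layout convention explicit (each head's h is a D x D row-major block, with i indexing the key dimension and j the value dimension), here is a stripped-down single-head, single-token version of the same update. It is a reference sketch only, with the L2 normalization, q scaling, and state sanitization of the kernel above omitted:

```c
#include <math.h>

// Reference single-step KDA update for one head (sketch, assumes D <= 128):
//   decay:    h[i,j] *= exp(g[i])
//   retrieve: hk[j]   = sum_i h[i,j] * k[i]
//   update:   h[i,j] += k[i] * beta * (v[j] - hk[j])
//   output:   y[j]    = sum_i h[i,j] * q[i]
static void kda_step_ref(int D, float * h, const float * q, const float * k,
                         const float * v, const float * g, float beta, float * y) {
    float hk[128];
    for (int i = 0; i < D; ++i) {
        const float eg = expf(g[i]);
        for (int j = 0; j < D; ++j) {
            h[i*D + j] *= eg;
        }
    }
    for (int j = 0; j < D; ++j) {
        float s = 0.0f;
        for (int i = 0; i < D; ++i) {
            s += h[i*D + j] * k[i];
        }
        hk[j] = s;
    }
    for (int i = 0; i < D; ++i) {
        for (int j = 0; j < D; ++j) {
            h[i*D + j] += k[i] * beta * (v[j] - hk[j]);
        }
    }
    for (int j = 0; j < D; ++j) {
        float s = 0.0f;
        for (int i = 0; i < D; ++i) {
            s += h[i*D + j] * q[i];
        }
        y[j] = s;
    }
}
```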

// ggml_compute_forward_win_part

static void ggml_compute_forward_win_part_f32(
1 change: 1 addition & 0 deletions ggml/src/ggml-cpu/ops.h
@@ -92,6 +92,7 @@ void ggml_compute_forward_flash_attn_back(
struct ggml_tensor * dst);
void ggml_compute_forward_ssm_conv(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_ssm_scan(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_kda_scan(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_win_part(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_win_unpart(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_unary(const struct ggml_compute_params * params, struct ggml_tensor * dst);
9 changes: 9 additions & 0 deletions ggml/src/ggml-cuda/ggml-cuda.cu
@@ -41,6 +41,7 @@
#include "ggml-cuda/softmax.cuh"
#include "ggml-cuda/ssm-conv.cuh"
#include "ggml-cuda/ssm-scan.cuh"
#include "ggml-cuda/kda-scan.cuh"
#include "ggml-cuda/sum.cuh"
#include "ggml-cuda/sumrows.cuh"
#include "ggml-cuda/mean.cuh"
@@ -2692,6 +2693,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst) {
case GGML_OP_SSM_SCAN:
ggml_cuda_op_ssm_scan(ctx, dst);
break;
case GGML_OP_KDA_SCAN:
ggml_cuda_op_kda_scan(ctx, dst);
break;
case GGML_OP_ARGSORT:
ggml_cuda_op_argsort(ctx, dst);
break;
@@ -4503,6 +4507,11 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
return op->src[0]->ne[0] == 16 && op->src[0]->ne[1] == 1 && op->src[0]->ne[2] % 128 == 0 && op->src[4]->ne[1] == 1;
}
}
case GGML_OP_KDA_SCAN: {
// KDA scan kernel supports head_dim 64 or 128
const int64_t head_dim = op->src[0]->ne[0];
return head_dim == 64 || head_dim == 128;
}
case GGML_OP_SSM_CONV: {
// assumes d_inner % threads == 0
return op->src[0]->ne[1] % 128 == 0;