Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 100 additions & 56 deletions training/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,30 +82,57 @@ static int model_load_weights(Model *m, const char *path) {
fprintf(stderr, "ERROR: failed to read config from %s\n", path);
fclose(f); return -1;
}

if (m->cfg.n_layers < 1 || m->cfg.n_layers > N_LAYERS) {
fprintf(stderr, "ERROR: n_layers (%d) exceeds maximum allowed (%d)\n", m->cfg.n_layers, N_LAYERS);
fclose(f); return -1;
}

if (m->cfg.dim < 1 || m->cfg.dim > 8192 ||
m->cfg.hidden_dim < 1 || m->cfg.hidden_dim > 32768) {
fprintf(stderr, "ERROR: model dimensions out of safe bounds\n");
fclose(f); return -1;
}

bool shared = m->cfg.vocab_size > 0;
if (m->cfg.vocab_size < 0) m->cfg.vocab_size = -m->cfg.vocab_size;

if (m->cfg.vocab_size == 0 || m->cfg.vocab_size > 256000) {
fprintf(stderr, "ERROR: vocab_size out of safe bounds\n");
fclose(f); return -1;
}

printf("Model: dim=%d hidden=%d layers=%d heads=%d vocab=%d seq=%d\n",
m->cfg.dim, m->cfg.hidden_dim, m->cfg.n_layers, m->cfg.n_heads,
m->cfg.vocab_size, m->cfg.seq_len);

int d = m->cfg.dim, hd = m->cfg.hidden_dim, nl = m->cfg.n_layers, vs = m->cfg.vocab_size;
size_t d = (size_t)m->cfg.dim, hd = (size_t)m->cfg.hidden_dim, nl = (size_t)m->cfg.n_layers, vs = (size_t)m->cfg.vocab_size;

m->token_embedding = (float*)malloc(vs * d * sizeof(float));
if (fread(m->token_embedding, sizeof(float), vs * d, f) != (size_t)(vs * d)) {
if (!m->token_embedding) {
fprintf(stderr, "ERROR: OOM allocating token_embedding\n");
fclose(f); return -1;
}
if (fread(m->token_embedding, sizeof(float), vs * d, f) != (vs * d)) {
fprintf(stderr, "ERROR: short read on token_embedding (file truncated?)\n");
fclose(f); return -1;
}

float *rms_att_all = (float*)malloc(nl * d * sizeof(float));
float *wq_all = (float*)malloc(nl * d * d * sizeof(float));
float *wk_all = (float*)malloc(nl * d * d * sizeof(float));
float *wv_all = (float*)malloc(nl * d * d * sizeof(float));
float *wo_all = (float*)malloc(nl * d * d * sizeof(float));
float *wq_all = (float*)malloc(nl * d * d * sizeof(float));
float *wk_all = (float*)malloc(nl * d * d * sizeof(float));
float *wv_all = (float*)malloc(nl * d * d * sizeof(float));
float *wo_all = (float*)malloc(nl * d * d * sizeof(float));
float *rms_ffn_all = (float*)malloc(nl * d * sizeof(float));
float *w1_all = (float*)malloc(nl * hd * d * sizeof(float));
float *w2_all = (float*)malloc(nl * d * hd * sizeof(float));
float *w3_all = (float*)malloc(nl * hd * d * sizeof(float));
float *w1_all = (float*)malloc(nl * hd * d * sizeof(float));
float *w2_all = (float*)malloc(nl * d * hd * sizeof(float));
float *w3_all = (float*)malloc(nl * hd * d * sizeof(float));

if (!rms_att_all || !wq_all || !wk_all || !wv_all || !wo_all ||
!rms_ffn_all || !w1_all || !w2_all || !w3_all) {
fprintf(stderr, "ERROR: OOM allocating layer weights\n");
fclose(f); return -1;
}

#define FREAD_CHECK(buf, count, file, label) do { \
size_t _n = fread(buf, sizeof(float), count, file); \
Expand All @@ -126,26 +153,28 @@ static int model_load_weights(Model *m, const char *path) {
FREAD_CHECK(w2_all, nl * d * hd, f, "w2");
FREAD_CHECK(w3_all, nl * hd * d, f, "w3");

#define SAFE_MALLOC_MEMCPY(dest, src, size) do { \
dest = (float*)malloc(size); \
if (!(dest)) { \
fprintf(stderr, "ERROR: memory allocation failed for size %zu\n", (size_t)(size)); \
fclose(f); return -1; \
} \
memcpy(dest, src, size); \
} while(0)

for (int l = 0; l < nl; l++) {
m->rms_att_w[l] = (float*)malloc(d * sizeof(float));
memcpy(m->rms_att_w[l], rms_att_all + l*d, d * sizeof(float));
m->wq[l] = (float*)malloc(d*d*sizeof(float));
memcpy(m->wq[l], wq_all + l*d*d, d*d*sizeof(float));
m->wk[l] = (float*)malloc(d*d*sizeof(float));
memcpy(m->wk[l], wk_all + l*d*d, d*d*sizeof(float));
m->wv[l] = (float*)malloc(d*d*sizeof(float));
memcpy(m->wv[l], wv_all + l*d*d, d*d*sizeof(float));
m->wo[l] = (float*)malloc(d*d*sizeof(float));
memcpy(m->wo[l], wo_all + l*d*d, d*d*sizeof(float));
m->rms_ffn_w[l] = (float*)malloc(d * sizeof(float));
memcpy(m->rms_ffn_w[l], rms_ffn_all + l*d, d * sizeof(float));
m->w1[l] = (float*)malloc(hd*d*sizeof(float));
memcpy(m->w1[l], w1_all + l*hd*d, hd*d*sizeof(float));
m->w2[l] = (float*)malloc(d*hd*sizeof(float));
memcpy(m->w2[l], w2_all + l*d*hd, d*hd*sizeof(float));
m->w3[l] = (float*)malloc(hd*d*sizeof(float));
memcpy(m->w3[l], w3_all + l*hd*d, hd*d*sizeof(float));
SAFE_MALLOC_MEMCPY(m->rms_att_w[l], rms_att_all + l*d, d * sizeof(float));
SAFE_MALLOC_MEMCPY(m->wq[l], wq_all + l*d*d, d*d*sizeof(float));
SAFE_MALLOC_MEMCPY(m->wk[l], wk_all + l*d*d, d*d*sizeof(float));
SAFE_MALLOC_MEMCPY(m->wv[l], wv_all + l*d*d, d*d*sizeof(float));
SAFE_MALLOC_MEMCPY(m->wo[l], wo_all + l*d*d, d*d*sizeof(float));
SAFE_MALLOC_MEMCPY(m->rms_ffn_w[l], rms_ffn_all + l*d, d * sizeof(float));
SAFE_MALLOC_MEMCPY(m->w1[l], w1_all + l*hd*d, hd*d*sizeof(float));
SAFE_MALLOC_MEMCPY(m->w2[l], w2_all + l*d*hd, d*hd*sizeof(float));
SAFE_MALLOC_MEMCPY(m->w3[l], w3_all + l*hd*d, hd*d*sizeof(float));
}

#undef SAFE_MALLOC_MEMCPY
free(rms_att_all); free(wq_all); free(wk_all); free(wv_all); free(wo_all);
free(rms_ffn_all); free(w1_all); free(w2_all); free(w3_all);

Expand Down Expand Up @@ -246,40 +275,55 @@ static int model_recompile_kernels(Model *m) {
return 0;
}

/*
 * Allocate per-layer activation and gradient buffers plus Adam optimizer
 * state for training.
 *
 * Returns 0 on success, -1 on allocation failure (a message is printed to
 * stderr).  NOTE(review): this span was pasted from a unified diff with the
 * +/- markers stripped, leaving both the old and new lines interleaved;
 * this is the reconstructed post-change ("+") side of the hunk.
 * NOTE(review): on failure, buffers allocated earlier in this call are
 * leaked — fine if the caller treats -1 as fatal and exits; otherwise a
 * cleanup path should be added.  Confirm against the caller.
 */
static int model_alloc_training(Model *m) {

    /* Widen config dims to size_t up front so every count below is computed
     * in size_t arithmetic (avoids int overflow in products like S * d). */
    size_t d = (size_t)m->cfg.dim, hd = (size_t)m->cfg.hidden_dim;
    size_t vs = (size_t)m->cfg.vocab_size, S = (size_t)m->seq_len;

/* calloc `count` floats into `dest`; on OOM, report and bail out with -1.
 * calloc both zero-fills (the gradient accumulators rely on starting at 0)
 * and checks the count*size multiplication for overflow. */
#define SAFE_CALLOC(dest, count) do { \
    dest = (float*)calloc(count, sizeof(float)); \
    if (!(dest)) { \
        fprintf(stderr, "ERROR: OOM in model_alloc_training for size %zu\n", (size_t)(count)); \
        return -1; \
    } \
} while(0)

    for (int l = 0; l < N_LAYERS; l++) {
        /* Forward-pass activation buffers, one set per layer. */
        SAFE_CALLOC(m->act_x[l], S * d);
        SAFE_CALLOC(m->act_xnorm[l], S * d);
        SAFE_CALLOC(m->act_q[l], S * d);
        SAFE_CALLOC(m->act_k[l], S * d);
        SAFE_CALLOC(m->act_v[l], S * d);
        SAFE_CALLOC(m->act_attn_out[l], S * d);
        SAFE_CALLOC(m->act_ffn_in[l], S * d);
        SAFE_CALLOC(m->act_h1[l], S * hd);
        SAFE_CALLOC(m->act_h3[l], S * hd);
        SAFE_CALLOC(m->act_silu[l], S * hd);

        /* Weight-gradient buffers, zero-initialized by calloc. */
        SAFE_CALLOC(m->grad_wq[l], d * d);
        SAFE_CALLOC(m->grad_wk[l], d * d);
        SAFE_CALLOC(m->grad_wv[l], d * d);
        SAFE_CALLOC(m->grad_wo[l], d * d);
        SAFE_CALLOC(m->grad_w1[l], hd * d);
        SAFE_CALLOC(m->grad_w2[l], d * hd);
        SAFE_CALLOC(m->grad_w3[l], hd * d);
    }
    SAFE_CALLOC(m->act_final, S * d);
    SAFE_CALLOC(m->act_pre_final, S * d);
    SAFE_CALLOC(m->logits, S * vs);
    SAFE_CALLOC(m->grad_wcls, vs * d);
    SAFE_CALLOC(m->grad_emb, vs * d);

    /* Parameter count: per layer, four d*d attention matrices plus the FFN
     * w1/w3 (hd*d each) and w2 (d*hd); then embedding + classifier (vs*d
     * each).  All operands are already size_t, so no casts are needed. */
    m->total_params = 0;
    for (int l = 0; l < N_LAYERS; l++)
        m->total_params += 4*d*d + 2*hd*d + d*hd;
    m->total_params += vs * d * 2;

    /* Adam first/second moment vectors span every trainable parameter. */
    SAFE_CALLOC(m->adam_m, m->total_params);
    SAFE_CALLOC(m->adam_v, m->total_params);
    m->adam_step = 0;

#undef SAFE_CALLOC

    printf("Total trainable params: %zu (%.1f M)\n", m->total_params, m->total_params/1e6);
    return 0;
}