training/Makefile: 2 changes (1 addition, 1 deletion)
@@ -8,7 +8,7 @@ HEADERS_LARGE = stories_config.h stories_io.h stories_mil.h stories_cpu_ops.h
 HEADERS_ANE = $(HEADERS_LARGE) ane_rmsnorm_bwd.h ane_classifier.h
 
 train: train.m ane_runtime.h ane_mil_gen.h model.h forward.h backward.h
-	$(CC) $(CFLAGS) -o $@ train.m $(LDFLAGS)
+	$(CC) $(CFLAGS) -o $@ train.m $(LDFLAGS) -framework Accelerate
 
 train_large: train_large.m $(HEADERS_LARGE)
 	$(CC) $(CFLAGS) -o $@ train_large.m $(LDFLAGS) -framework Accelerate
training/backward.h: 50 changes (29 additions, 21 deletions)
@@ -4,25 +4,21 @@
 #include "forward.h"
 #include <math.h>
 #include <string.h>
+#include <Accelerate/Accelerate.h>
 
-// dW += dy @ x^T — dy: [S, out_dim], x: [S, in_dim], dW: [out_dim, in_dim]
+// dW += dy^T @ x — dy: [S, out_dim], x: [S, in_dim], dW: [out_dim, in_dim]
 static void cpu_accum_dW(float *dW, const float *dy, const float *x, int S, int out_dim, int in_dim) {
-    for (int t = 0; t < S; t++)
-        for (int i = 0; i < out_dim; i++)
-            for (int j = 0; j < in_dim; j++)
-                dW[i*in_dim+j] += dy[t*out_dim+i] * x[t*in_dim+j];
+    cblas_sgemm(CblasRowMajor, CblasTrans, CblasNoTrans,
+                out_dim, in_dim, S, 1.0f,
+                dy, out_dim, x, in_dim, 1.0f, dW, in_dim);
 }
 
 // dx = W^T @ dy — W: [out_dim, in_dim], dy: [S, out_dim] → dx: [S, in_dim]
 static void cpu_matmul_backward_dx(const float *W, const float *dy, float *dx,
                                    int S, int out_dim, int in_dim) {
-    for (int t = 0; t < S; t++)
-        for (int j = 0; j < in_dim; j++) {
-            float sum = 0;
-            for (int i = 0; i < out_dim; i++)
-                sum += W[i*in_dim+j] * dy[t*out_dim+i];
-            dx[t*in_dim+j] = sum;
-        }
+    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
+                S, in_dim, out_dim, 1.0f,
+                dy, out_dim, W, in_dim, 0.0f, dx, in_dim);
 }
 
 static void cpu_rmsnorm_backward(float *dx, const float *dy, const float *x, const float *w,
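Mapping the removed loops onto the sgemm arguments: with row-major storage, dW += dy^T @ x is a single sgemm with A = dy transposed (M = out_dim, N = in_dim, K = S, and beta = 1.0f so the gradient accumulates), while dx = dy @ W (each row dx[t] = W^T @ dy[t]) is a no-transpose sgemm with beta = 0.0f so dx is overwritten. A minimal standalone sanity check against the old scalar loops, assuming small arbitrary sizes and a hypothetical file name check_sgemm.c that is not part of this PR, could look like the sketch below; on macOS it builds with clang check_sgemm.c -framework Accelerate.

// check_sgemm.c: hypothetical standalone sanity check, not part of this PR.
#include <Accelerate/Accelerate.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>

enum { S = 7, OUT = 5, IN = 3 };   // arbitrary small sizes

int main(void) {
    float dy[S*OUT], x[S*IN], W[OUT*IN];
    float dW_blas[OUT*IN] = {0}, dW_ref[OUT*IN] = {0};
    float dx_blas[S*IN], dx_ref[S*IN];
    for (int i = 0; i < S*OUT;  i++) dy[i] = (float)rand()/RAND_MAX - 0.5f;
    for (int i = 0; i < S*IN;   i++) x[i]  = (float)rand()/RAND_MAX - 0.5f;
    for (int i = 0; i < OUT*IN; i++) W[i]  = (float)rand()/RAND_MAX - 0.5f;

    // dW += dy^T @ x (beta = 1 accumulates into dW)
    cblas_sgemm(CblasRowMajor, CblasTrans, CblasNoTrans,
                OUT, IN, S, 1.0f, dy, OUT, x, IN, 1.0f, dW_blas, IN);
    for (int t = 0; t < S; t++)
        for (int i = 0; i < OUT; i++)
            for (int j = 0; j < IN; j++)
                dW_ref[i*IN+j] += dy[t*OUT+i] * x[t*IN+j];

    // dx = dy @ W (beta = 0 overwrites dx)
    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
                S, IN, OUT, 1.0f, dy, OUT, W, IN, 0.0f, dx_blas, IN);
    for (int t = 0; t < S; t++)
        for (int j = 0; j < IN; j++) {
            float sum = 0;
            for (int i = 0; i < OUT; i++) sum += W[i*IN+j] * dy[t*OUT+i];
            dx_ref[t*IN+j] = sum;
        }

    float err = 0;
    for (int i = 0; i < OUT*IN; i++) err = fmaxf(err, fabsf(dW_blas[i] - dW_ref[i]));
    for (int i = 0; i < S*IN;   i++) err = fmaxf(err, fabsf(dx_blas[i] - dx_ref[i]));
    printf("max abs difference: %g\n", err);   // expect float-rounding noise, ~1e-6 or smaller
    return 0;
}

The two beta values follow the original semantics: dW is accumulated into across calls, while dx is recomputed fresh each call.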
@@ -278,18 +274,30 @@ static void model_adam_step(Model *m, float lr, float beta1, float beta2, float
     m->adam_step++;
     float bc1 = 1.0f - powf(beta1, m->adam_step);
     float bc2 = 1.0f - powf(beta2, m->adam_step);
+    float neg_lr_over_bc1 = -lr / bc1;
+    float inv_bc2 = 1.0f / bc2;
+    float one_minus_b1 = 1.0f - beta1;
+    float one_minus_b2 = 1.0f - beta2;
     size_t idx = 0;
 
+    // Vectorized Adam update for a contiguous chunk
     #define ADAM_UPDATE(param, grad, size) do { \
-        for (size_t _i = 0; _i < (size_t)(size); _i++) { \
-            float g = (grad)[_i]; \
-            m->adam_m[idx] = beta1 * m->adam_m[idx] + (1-beta1) * g; \
-            m->adam_v[idx] = beta2 * m->adam_v[idx] + (1-beta2) * g * g; \
-            float m_hat = m->adam_m[idx] / bc1; \
-            float v_hat = m->adam_v[idx] / bc2; \
-            (param)[_i] -= lr * m_hat / (sqrtf(v_hat) + eps); \
-            idx++; \
-        } \
+        size_t _n = (size_t)(size); \
+        float *_m = m->adam_m + idx; \
+        float *_v = m->adam_v + idx; \
+        float *_tmp = (float*)malloc(_n * sizeof(float)); \
+        vDSP_vsmul(_m, 1, &beta1, _m, 1, _n); \
+        vDSP_vsma((grad), 1, &one_minus_b1, _m, 1, _m, 1, _n); \
+        vDSP_vsq((grad), 1, _tmp, 1, _n); \
+        vDSP_vsmul(_v, 1, &beta2, _v, 1, _n); \
+        vDSP_vsma(_tmp, 1, &one_minus_b2, _v, 1, _v, 1, _n); \
+        vDSP_vsmul(_v, 1, &inv_bc2, _tmp, 1, _n); \
+        int _nn = (int)_n; vvsqrtf(_tmp, _tmp, &_nn); \
+        vDSP_vsadd(_tmp, 1, &eps, _tmp, 1, _n); \
+        vDSP_vdiv(_tmp, 1, _m, 1, _tmp, 1, _n); \
+        vDSP_vsma(_tmp, 1, &neg_lr_over_bc1, (param), 1, (param), 1, _n); \
+        free(_tmp); \
+        idx += _n; \
     } while(0)
 
     int d = m->cfg.dim, hd = m->cfg.hidden_dim, vs = m->cfg.vocab_size;
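The vDSP sequence inside ADAM_UPDATE is the removed scalar loop with the bias corrections folded into constants: lr * m_hat = lr * (m / bc1) = (lr / bc1) * m, and v_hat = v * inv_bc2, so each element ends up as param += (-lr / bc1) * m / (sqrt(v * inv_bc2) + eps), which is exactly what the final vDSP_vdiv and vDSP_vsma pair computes (vDSP_vdiv takes the divisor as its first argument). A scalar sketch of one chunk, purely illustrative and not part of the PR, with each comment naming the vDSP calls that do the same work over the whole chunk:

#include <math.h>
#include <stddef.h>

// Scalar equivalent of one ADAM_UPDATE chunk; mm and vv are the Adam moment slices.
void adam_chunk_reference(float *param, const float *grad, float *mm, float *vv,
                          size_t n, float lr, float beta1, float beta2,
                          float eps, float bc1, float bc2) {
    for (size_t i = 0; i < n; i++) {
        mm[i] = beta1 * mm[i] + (1.0f - beta1) * grad[i];            // vDSP_vsmul + vDSP_vsma
        vv[i] = beta2 * vv[i] + (1.0f - beta2) * grad[i] * grad[i];  // vDSP_vsq, vDSP_vsmul, vDSP_vsma
        float denom = sqrtf(vv[i] / bc2) + eps;                      // vDSP_vsmul(inv_bc2), vvsqrtf, vDSP_vsadd
        param[i] += (-lr / bc1) * (mm[i] / denom);                   // vDSP_vdiv, then vDSP_vsma(neg_lr_over_bc1)
    }
}

The macro allocates and frees one scratch buffer per parameter chunk; a single reusable buffer hoisted out of the macro would avoid the repeated malloc, at the cost of threading one more pointer through the call sites.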
training/stories_cpu_ops.h: 31 changes (25 additions, 6 deletions)
@@ -55,12 +55,31 @@ static void rmsnorm_bwd(float *dx, float *dw, const float *dy, const float *x, c
 
 static void adam_update(float *w, const float *g, AdamState *s, int t, float lr, float b1, float b2, float eps) {
     float bc1 = 1.0f - powf(b1, t), bc2 = 1.0f - powf(b2, t);
-    for (size_t i=0; i<s->n; i++) {
-        s->m[i] = b1*s->m[i] + (1-b1)*g[i];
-        s->v[i] = b2*s->v[i] + (1-b2)*g[i]*g[i];
-        float mh = s->m[i]/bc1, vh = s->v[i]/bc2;
-        w[i] -= lr * mh / (sqrtf(vh) + eps);
-    }
+    size_t n = s->n;
+    float one_minus_b1 = 1.0f - b1;
+    float one_minus_b2 = 1.0f - b2;
+    float neg_lr_over_bc1 = -lr / bc1;
+    float inv_bc2 = 1.0f / bc2;
+
+    // m = b1*m + (1-b1)*g
+    vDSP_vsmul(s->m, 1, &b1, s->m, 1, n);
+    vDSP_vsma(g, 1, &one_minus_b1, s->m, 1, s->m, 1, n);
+
+    // v = b2*v + (1-b2)*g^2
+    float *tmp = (float*)malloc(n * sizeof(float));
+    vDSP_vsq(g, 1, tmp, 1, n);
+    vDSP_vsmul(s->v, 1, &b2, s->v, 1, n);
+    vDSP_vsma(tmp, 1, &one_minus_b2, s->v, 1, s->v, 1, n);
+
+    // update = m / (sqrt(v/bc2) + eps), then w -= (lr/bc1) * update
+    vDSP_vsmul(s->v, 1, &inv_bc2, tmp, 1, n);
+    int nn = (int)n;
+    vvsqrtf(tmp, tmp, &nn);
+    vDSP_vsadd(tmp, 1, &eps, tmp, 1, n);
+    vDSP_vdiv(tmp, 1, s->m, 1, tmp, 1, n);
+    vDSP_vsma(tmp, 1, &neg_lr_over_bc1, w, 1, w, 1, n);
+
+    free(tmp);
 }
 
 // Cross-entropy loss + gradient for logits (column-major: [VOCAB, SEQ])
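The adam_update path can be checked end to end against the scalar loop it replaces. The standalone sketch below is hypothetical (the file name adam_vdsp_check.c and the test sizes are illustrative, and it avoids the project's AdamState type so it compiles on its own); on macOS it builds with clang adam_vdsp_check.c -framework Accelerate.

// adam_vdsp_check.c: hypothetical standalone check, not part of this PR.
#include <Accelerate/Accelerate.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>

enum { N = 1000 };

int main(void) {
    float lr = 1e-3f, b1 = 0.9f, b2 = 0.999f, eps = 1e-8f;
    int t = 3;   // arbitrary Adam step
    float bc1 = 1.0f - powf(b1, t), bc2 = 1.0f - powf(b2, t);
    float one_minus_b1 = 1.0f - b1, one_minus_b2 = 1.0f - b2;
    float neg_lr_over_bc1 = -lr / bc1, inv_bc2 = 1.0f / bc2;

    static float w[N], g[N], m[N], v[N], w_ref[N], m_ref[N], v_ref[N], tmp[N];
    for (int i = 0; i < N; i++) {
        w[i] = w_ref[i] = (float)rand()/RAND_MAX - 0.5f;
        g[i] = (float)rand()/RAND_MAX - 0.5f;
        m[i] = m_ref[i] = 0.1f * g[i];
        v[i] = v_ref[i] = 0.01f * g[i] * g[i];
    }

    // Scalar reference: the loop removed by this PR.
    for (int i = 0; i < N; i++) {
        m_ref[i] = b1*m_ref[i] + (1 - b1)*g[i];
        v_ref[i] = b2*v_ref[i] + (1 - b2)*g[i]*g[i];
        float mh = m_ref[i]/bc1, vh = v_ref[i]/bc2;
        w_ref[i] -= lr * mh / (sqrtf(vh) + eps);
    }

    // vDSP sequence, mirroring the new adam_update body.
    vDSP_vsmul(m, 1, &b1, m, 1, N);
    vDSP_vsma(g, 1, &one_minus_b1, m, 1, m, 1, N);
    vDSP_vsq(g, 1, tmp, 1, N);
    vDSP_vsmul(v, 1, &b2, v, 1, N);
    vDSP_vsma(tmp, 1, &one_minus_b2, v, 1, v, 1, N);
    vDSP_vsmul(v, 1, &inv_bc2, tmp, 1, N);
    int nn = N;
    vvsqrtf(tmp, tmp, &nn);
    vDSP_vsadd(tmp, 1, &eps, tmp, 1, N);
    vDSP_vdiv(tmp, 1, m, 1, tmp, 1, N);   // vDSP_vdiv: divisor first, so tmp = m / tmp
    vDSP_vsma(tmp, 1, &neg_lr_over_bc1, w, 1, w, 1, N);

    float err = 0;
    for (int i = 0; i < N; i++) err = fmaxf(err, fabsf(w[i] - w_ref[i]));
    printf("max abs difference: %g\n", err);   // expect float-rounding noise only
    return 0;
}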