diff --git a/training/Makefile b/training/Makefile
index 7f16c1a..a5fbf55 100644
--- a/training/Makefile
+++ b/training/Makefile
@@ -8,7 +8,7 @@ HEADERS_LARGE = stories_config.h stories_io.h stories_mil.h stories_cpu_ops.h
 HEADERS_ANE = $(HEADERS_LARGE) ane_rmsnorm_bwd.h ane_classifier.h
 
 train: train.m ane_runtime.h ane_mil_gen.h model.h forward.h backward.h
-	$(CC) $(CFLAGS) -o $@ train.m $(LDFLAGS)
+	$(CC) $(CFLAGS) -o $@ train.m $(LDFLAGS) -framework Accelerate
 
 train_large: train_large.m $(HEADERS_LARGE)
 	$(CC) $(CFLAGS) -o $@ train_large.m $(LDFLAGS) -framework Accelerate
diff --git a/training/backward.h b/training/backward.h
index 138ea7c..c50ce9f 100644
--- a/training/backward.h
+++ b/training/backward.h
@@ -4,25 +4,21 @@
 #include "forward.h"
 #include <math.h>
 #include <stdlib.h>
+#include <Accelerate/Accelerate.h>
 
-// dW += dy @ x^T — dy: [S, out_dim], x: [S, in_dim], dW: [out_dim, in_dim]
+// dW += dy^T @ x — dy: [S, out_dim], x: [S, in_dim], dW: [out_dim, in_dim]
 static void cpu_accum_dW(float *dW, const float *dy, const float *x, int S, int out_dim, int in_dim) {
-    for (int t = 0; t < S; t++)
-        for (int i = 0; i < out_dim; i++)
-            for (int j = 0; j < in_dim; j++)
-                dW[i*in_dim+j] += dy[t*out_dim+i] * x[t*in_dim+j];
+    cblas_sgemm(CblasRowMajor, CblasTrans, CblasNoTrans,
+                out_dim, in_dim, S, 1.0f,
+                dy, out_dim, x, in_dim, 1.0f, dW, in_dim);
 }
 
 // dx = W^T @ dy — W: [out_dim, in_dim], dy: [S, out_dim] → dx: [S, in_dim]
 static void cpu_matmul_backward_dx(const float *W, const float *dy, float *dx,
                                    int S, int out_dim, int in_dim) {
-    for (int t = 0; t < S; t++)
-        for (int j = 0; j < in_dim; j++) {
-            float sum = 0;
-            for (int i = 0; i < out_dim; i++)
-                sum += W[i*in_dim+j] * dy[t*out_dim+i];
-            dx[t*in_dim+j] = sum;
-        }
+    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
+                S, in_dim, out_dim, 1.0f,
+                dy, out_dim, W, in_dim, 0.0f, dx, in_dim);
 }
 
 static void cpu_rmsnorm_backward(float *dx, const float *dy, const float *x, const float *w,
@@ -278,18 +274,30 @@ static void model_adam_step(Model *m, float lr, float beta1, float beta2, float
     m->adam_step++;
     float bc1 = 1.0f - powf(beta1, m->adam_step);
     float bc2 = 1.0f - powf(beta2, m->adam_step);
+    float neg_lr_over_bc1 = -lr / bc1;
+    float inv_bc2 = 1.0f / bc2;
+    float one_minus_b1 = 1.0f - beta1;
+    float one_minus_b2 = 1.0f - beta2;
     size_t idx = 0;
 
+// Vectorized Adam update for a contiguous chunk
 #define ADAM_UPDATE(param, grad, size) do { \
-    for (size_t _i = 0; _i < (size_t)(size); _i++) { \
-        float g = (grad)[_i]; \
-        m->adam_m[idx] = beta1 * m->adam_m[idx] + (1-beta1) * g; \
-        m->adam_v[idx] = beta2 * m->adam_v[idx] + (1-beta2) * g * g; \
-        float m_hat = m->adam_m[idx] / bc1; \
-        float v_hat = m->adam_v[idx] / bc2; \
-        (param)[_i] -= lr * m_hat / (sqrtf(v_hat) + eps); \
-        idx++; \
-    } \
+    size_t _n = (size_t)(size); \
+    float *_m = m->adam_m + idx; \
+    float *_v = m->adam_v + idx; \
+    float *_tmp = (float*)malloc(_n * sizeof(float)); \
+    vDSP_vsmul(_m, 1, &beta1, _m, 1, _n); \
+    vDSP_vsma((grad), 1, &one_minus_b1, _m, 1, _m, 1, _n); \
+    vDSP_vsq((grad), 1, _tmp, 1, _n); \
+    vDSP_vsmul(_v, 1, &beta2, _v, 1, _n); \
+    vDSP_vsma(_tmp, 1, &one_minus_b2, _v, 1, _v, 1, _n); \
+    vDSP_vsmul(_v, 1, &inv_bc2, _tmp, 1, _n); \
+    int _nn = (int)_n; vvsqrtf(_tmp, _tmp, &_nn); \
+    vDSP_vsadd(_tmp, 1, &eps, _tmp, 1, _n); \
+    vDSP_vdiv(_tmp, 1, _m, 1, _tmp, 1, _n); \
+    vDSP_vsma(_tmp, 1, &neg_lr_over_bc1, (param), 1, (param), 1, _n); \
+    free(_tmp); \
+    idx += _n; \
 } while(0)
 
     int d = m->cfg.dim, hd = m->cfg.hidden_dim, vs = m->cfg.vocab_size;
diff --git a/training/stories_cpu_ops.h b/training/stories_cpu_ops.h
index c9f2cfa..507e6be 100644
--- a/training/stories_cpu_ops.h
+++ b/training/stories_cpu_ops.h
@@ -55,12 +55,31 @@ static void rmsnorm_bwd(float *dx, float *dw, const float *dy, const float *x, c
 static void adam_update(float *w, const float *g, AdamState *s, int t,
                         float lr, float b1, float b2, float eps) {
     float bc1 = 1.0f - powf(b1, t), bc2 = 1.0f - powf(b2, t);
-    for (size_t i=0; i<s->n; i++) {
-        s->m[i] = b1*s->m[i] + (1-b1)*g[i];
-        s->v[i] = b2*s->v[i] + (1-b2)*g[i]*g[i];
-        float mh = s->m[i]/bc1, vh = s->v[i]/bc2;
-        w[i] -= lr * mh / (sqrtf(vh) + eps);
-    }
+    size_t n = s->n;
+    float one_minus_b1 = 1.0f - b1;
+    float one_minus_b2 = 1.0f - b2;
+    float neg_lr_over_bc1 = -lr / bc1;
+    float inv_bc2 = 1.0f / bc2;
+
+    // m = b1*m + (1-b1)*g
+    vDSP_vsmul(s->m, 1, &b1, s->m, 1, n);
+    vDSP_vsma(g, 1, &one_minus_b1, s->m, 1, s->m, 1, n);
+
+    // v = b2*v + (1-b2)*g^2
+    float *tmp = (float*)malloc(n * sizeof(float));
+    vDSP_vsq(g, 1, tmp, 1, n);
+    vDSP_vsmul(s->v, 1, &b2, s->v, 1, n);
+    vDSP_vsma(tmp, 1, &one_minus_b2, s->v, 1, s->v, 1, n);
+
+    // update = m / (sqrt(v/bc2) + eps), then w -= (lr/bc1) * update
+    vDSP_vsmul(s->v, 1, &inv_bc2, tmp, 1, n);
+    int nn = (int)n;
+    vvsqrtf(tmp, tmp, &nn);
+    vDSP_vsadd(tmp, 1, &eps, tmp, 1, n);
+    vDSP_vdiv(tmp, 1, s->m, 1, tmp, 1, n);
+    vDSP_vsma(tmp, 1, &neg_lr_over_bc1, w, 1, w, 1, n);
+
+    free(tmp);
 }
 
 // Cross-entropy loss + gradient for logits (column-major: [VOCAB, SEQ])
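
Not part of the patch: a minimal standalone sanity check for the cblas_sgemm-based backward ops, replaying the removed reference loops on random data and comparing. The file name, sizes, and build command are illustrative assumptions, not anything from the repo.

// check_sgemm_bwd.c — compare the sgemm calls above against the removed naive loops.
// Hypothetical build: clang -O2 check_sgemm_bwd.c -framework Accelerate -o check_sgemm_bwd
#include <Accelerate/Accelerate.h>
#include <stdio.h>
#include <stdlib.h>
#include <math.h>

int main(void) {
    const int S = 7, out_dim = 5, in_dim = 11;
    float *dy = (float*)malloc(S * out_dim * sizeof(float));
    float *x  = (float*)malloc(S * in_dim * sizeof(float));
    float *W  = (float*)malloc(out_dim * in_dim * sizeof(float));
    float *dW_ref  = (float*)calloc(out_dim * in_dim, sizeof(float));
    float *dW_blas = (float*)calloc(out_dim * in_dim, sizeof(float));
    float *dx_ref  = (float*)malloc(S * in_dim * sizeof(float));
    float *dx_blas = (float*)malloc(S * in_dim * sizeof(float));
    for (int i = 0; i < S * out_dim; i++)       dy[i] = (float)rand() / RAND_MAX - 0.5f;
    for (int i = 0; i < S * in_dim; i++)        x[i]  = (float)rand() / RAND_MAX - 0.5f;
    for (int i = 0; i < out_dim * in_dim; i++)  W[i]  = (float)rand() / RAND_MAX - 0.5f;

    // Reference: the loops the patch removes.
    for (int t = 0; t < S; t++)
        for (int i = 0; i < out_dim; i++)
            for (int j = 0; j < in_dim; j++)
                dW_ref[i*in_dim+j] += dy[t*out_dim+i] * x[t*in_dim+j];
    for (int t = 0; t < S; t++)
        for (int j = 0; j < in_dim; j++) {
            float sum = 0;
            for (int i = 0; i < out_dim; i++)
                sum += W[i*in_dim+j] * dy[t*out_dim+i];
            dx_ref[t*in_dim+j] = sum;
        }

    // BLAS: dW += dy^T @ x and dx = dy @ W, matching the patched functions.
    cblas_sgemm(CblasRowMajor, CblasTrans, CblasNoTrans,
                out_dim, in_dim, S, 1.0f, dy, out_dim, x, in_dim, 1.0f, dW_blas, in_dim);
    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
                S, in_dim, out_dim, 1.0f, dy, out_dim, W, in_dim, 0.0f, dx_blas, in_dim);

    float max_err = 0;
    for (int i = 0; i < out_dim * in_dim; i++) max_err = fmaxf(max_err, fabsf(dW_ref[i] - dW_blas[i]));
    for (int i = 0; i < S * in_dim; i++)       max_err = fmaxf(max_err, fabsf(dx_ref[i] - dx_blas[i]));
    printf("max abs error: %g\n", max_err);
    return max_err < 1e-4f ? 0 : 1;
}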
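
Likewise a sketch, not part of the patch: a standalone check that the vDSP sequence used in adam_update (and in the ADAM_UPDATE macro) reproduces the scalar Adam recurrence. Buffer size, step count, and tolerance are arbitrary assumptions.

// check_vdsp_adam.c — scalar Adam vs. the vDSP sequence from the patch.
// Hypothetical build: clang -O2 check_vdsp_adam.c -framework Accelerate -o check_vdsp_adam
#include <Accelerate/Accelerate.h>
#include <stdio.h>
#include <stdlib.h>
#include <math.h>

int main(void) {
    const size_t n = 1000;
    const float lr = 1e-3f, b1 = 0.9f, b2 = 0.999f, eps = 1e-8f;
    const int t = 5;
    float bc1 = 1.0f - powf(b1, t), bc2 = 1.0f - powf(b2, t);

    float *g  = (float*)malloc(n * sizeof(float));
    float *w0 = (float*)malloc(n * sizeof(float)), *m0 = (float*)malloc(n * sizeof(float)), *v0 = (float*)malloc(n * sizeof(float));
    float *w1 = (float*)malloc(n * sizeof(float)), *m1 = (float*)malloc(n * sizeof(float)), *v1 = (float*)malloc(n * sizeof(float));
    for (size_t i = 0; i < n; i++) {
        g[i]  = (float)rand() / RAND_MAX - 0.5f;
        w0[i] = w1[i] = (float)rand() / RAND_MAX - 0.5f;
        m0[i] = m1[i] = 0.01f * g[i];
        v0[i] = v1[i] = 0.001f * g[i] * g[i];
    }

    // Scalar reference (the loop the patch removes).
    for (size_t i = 0; i < n; i++) {
        m0[i] = b1*m0[i] + (1-b1)*g[i];
        v0[i] = b2*v0[i] + (1-b2)*g[i]*g[i];
        float mh = m0[i]/bc1, vh = v0[i]/bc2;
        w0[i] -= lr * mh / (sqrtf(vh) + eps);
    }

    // vDSP sequence, mirroring adam_update() in the patch.
    float one_minus_b1 = 1.0f - b1, one_minus_b2 = 1.0f - b2;
    float neg_lr_over_bc1 = -lr / bc1, inv_bc2 = 1.0f / bc2;
    float *tmp = (float*)malloc(n * sizeof(float));
    vDSP_vsmul(m1, 1, &b1, m1, 1, n);                      // m *= b1
    vDSP_vsma(g, 1, &one_minus_b1, m1, 1, m1, 1, n);       // m += (1-b1)*g
    vDSP_vsq(g, 1, tmp, 1, n);                             // tmp = g^2
    vDSP_vsmul(v1, 1, &b2, v1, 1, n);                      // v *= b2
    vDSP_vsma(tmp, 1, &one_minus_b2, v1, 1, v1, 1, n);     // v += (1-b2)*g^2
    vDSP_vsmul(v1, 1, &inv_bc2, tmp, 1, n);                // tmp = v/bc2
    int nn = (int)n;
    vvsqrtf(tmp, tmp, &nn);                                // tmp = sqrt(v/bc2)
    vDSP_vsadd(tmp, 1, &eps, tmp, 1, n);                   // tmp += eps
    vDSP_vdiv(tmp, 1, m1, 1, tmp, 1, n);                   // tmp = m / tmp
    vDSP_vsma(tmp, 1, &neg_lr_over_bc1, w1, 1, w1, 1, n);  // w -= (lr/bc1)*tmp
    free(tmp);

    float max_err = 0;
    for (size_t i = 0; i < n; i++) max_err = fmaxf(max_err, fabsf(w0[i] - w1[i]));
    printf("max abs weight diff: %g\n", max_err);
    return max_err < 1e-5f ? 0 : 1;
}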