Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions training/stories_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,18 @@
#define SEQ 256
#define NLAYERS 12
#define VOCAB 32000
#define ACCUM_STEPS 10
#define MAX_COMPILES 100
#define ACCUM_STEPS_DEFAULT 10
#define MAX_COMPILES_DEFAULT 100

// Parse a strictly positive int from environment variable `name`, falling
// back to `fallback` otherwise. atoi() returns 0 for unset-looking garbage
// ("", "abc") and happily parses negatives; either value would silently
// break the training loops (a 0-step accumulation batch never executes,
// a 0-compile budget halts immediately), so anything <= 0 is rejected.
// NOTE(review): relies on <stdlib.h> (getenv/atoi) being included
// upstream of this header — confirm the include list.
static inline int env_positive_int(const char *name, int fallback) {
    const char *env = getenv(name);
    if (!env) return fallback;
    int v = atoi(env);
    return (v > 0) ? v : fallback;
}

// Gradient-accumulation steps per recompile (override: ANE_ACCUM_STEPS).
static inline int get_accum_steps(void) {
    return env_positive_int("ANE_ACCUM_STEPS", ACCUM_STEPS_DEFAULT);
}

// Per-process compile budget (override: ANE_MAX_COMPILES).
static inline int get_max_compiles(void) {
    return env_positive_int("ANE_MAX_COMPILES", MAX_COMPILES_DEFAULT);
}

// Per compile: 5 weight-bearing kernels per layer + 1 classifier = 5*12+1 = 61
// Plus 1 static (sdpaBwd2 per layer, no weights) = 12 more but those are weight-free
Expand Down Expand Up @@ -86,7 +96,7 @@ typedef struct {
} LayerGrads;

// ANE kernels per layer
typedef struct { void *model; IOSurfaceRef ioIn, ioOut; void *request; void *tmpDir; } Kern;
typedef struct { void *model; IOSurfaceRef ioIn, ioOut; void *request; void *tmpDir; size_t inBytes, outBytes; } Kern;
typedef struct {
Kern *fwdAttn, *fwdFFN, *ffnBwd, *sdpaBwd1, *sdpaBwd2, *qkvBwd;
} LayerKernels;
Expand Down
33 changes: 32 additions & 1 deletion training/stories_io.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,41 @@
#include "stories_config.h"
#include <arm_neon.h>
#include <pthread.h>

// IOSurface pool — reuse freed surfaces of the same byte size instead of
// paying IOSurfaceCreate on every kernel recompile. Pooled surfaces stay
// retained for process lifetime; the pool only CFReleases on overflow.
//
// Guarded by a mutex: kernel compilation (which allocates surfaces via
// make_surface) runs on a background GCD queue while free_kern() returns
// surfaces from the training thread (see the pipeline in tiny_train.m), so
// unsynchronized access to g_iosurf_pool can race.
#define IOSURF_POOL_MAX 128
static pthread_mutex_t g_iosurf_pool_lock = PTHREAD_MUTEX_INITIALIZER;
static struct {
    IOSurfaceRef surfaces[IOSURF_POOL_MAX];
    size_t sizes[IOSURF_POOL_MAX];
    int count;                      // live entries in the parallel arrays
} g_iosurf_pool = { .count = 0 };

// Return a surface of exactly `bytes` bytes: recycled from the pool when a
// same-sized one is parked there, freshly created otherwise.
// NOTE(review): IOSurfaceCreate can return NULL on allocation failure and
// callers appear to assume success — confirm whether a check is needed.
static IOSurfaceRef make_surface(size_t bytes) {
    pthread_mutex_lock(&g_iosurf_pool_lock);
    for (int i = 0; i < g_iosurf_pool.count; i++) {
        if (g_iosurf_pool.sizes[i] == bytes) {
            IOSurfaceRef s = g_iosurf_pool.surfaces[i];
            // Swap-remove: overwrite slot i with the last entry, shrink count.
            g_iosurf_pool.surfaces[i] = g_iosurf_pool.surfaces[--g_iosurf_pool.count];
            g_iosurf_pool.sizes[i] = g_iosurf_pool.sizes[g_iosurf_pool.count];
            pthread_mutex_unlock(&g_iosurf_pool_lock);
            return s;
        }
    }
    pthread_mutex_unlock(&g_iosurf_pool_lock);
    // Pool miss: a `bytes`-wide, 1-high, 1-byte-per-element surface gives
    // exactly `bytes` of linear memory.
    return IOSurfaceCreate((__bridge CFDictionaryRef)@{
        (id)kIOSurfaceWidth:@(bytes), (id)kIOSurfaceHeight:@1,
        (id)kIOSurfaceBytesPerElement:@1, (id)kIOSurfaceBytesPerRow:@(bytes),
        (id)kIOSurfaceAllocSize:@(bytes), (id)kIOSurfacePixelFormat:@0});
}

// Park `s` (of `bytes` bytes) for reuse; release it when the pool is full.
// Ownership transfers to the pool — callers must not CFRelease afterwards.
static void pool_return_surface(IOSurfaceRef s, size_t bytes) {
    pthread_mutex_lock(&g_iosurf_pool_lock);
    if (g_iosurf_pool.count < IOSURF_POOL_MAX) {
        g_iosurf_pool.surfaces[g_iosurf_pool.count] = s;
        g_iosurf_pool.sizes[g_iosurf_pool.count] = bytes;
        g_iosurf_pool.count++;
        pthread_mutex_unlock(&g_iosurf_pool_lock);
    } else {
        pthread_mutex_unlock(&g_iosurf_pool_lock);
        CFRelease(s);  // overflow: release outside the lock
    }
}

static NSData *build_blob(const float *w, int rows, int cols) {
int ws=rows*cols*2, tot=128+ws;
uint8_t *b=(uint8_t*)calloc(tot,1);
Expand Down Expand Up @@ -110,6 +138,8 @@ static Kern *compile_kern_mil_w(NSString *mil, NSDictionary *weights, int ic_byt
k->model = (void*)CFBridgingRetain(mdl);
k->ioIn = make_surface(ic_bytes);
k->ioOut = make_surface(oc_bytes);
k->inBytes = ic_bytes;
k->outBytes = oc_bytes;
id wI = ((id(*)(Class,SEL,IOSurfaceRef))objc_msgSend)(g_AIO, @selector(objectWithIOSurface:), k->ioIn);
id wO = ((id(*)(Class,SEL,IOSurfaceRef))objc_msgSend)(g_AIO, @selector(objectWithIOSurface:), k->ioOut);
k->request = (void*)CFBridgingRetain(((id(*)(Class,SEL,id,id,id,id,id,id,id))objc_msgSend)(g_AR,
Expand All @@ -123,7 +153,8 @@ static void free_kern(Kern *k) {
if (!k) return;
id mdl = (__bridge id)k->model; NSError *e = nil;
((BOOL(*)(id,SEL,unsigned int,NSError**))objc_msgSend)(mdl, @selector(unloadWithQoS:error:), 21, &e);
CFRelease(k->ioIn); CFRelease(k->ioOut);
pool_return_surface(k->ioIn, k->inBytes);
pool_return_surface(k->ioOut, k->outBytes);
[[NSFileManager defaultManager] removeItemAtPath:(__bridge id)k->tmpDir error:nil];
CFRelease(k->model); CFRelease(k->request); CFRelease(k->tmpDir);
free(k);
Expand Down
30 changes: 19 additions & 11 deletions training/tiny_train.m
Original file line number Diff line number Diff line change
Expand Up @@ -198,9 +198,15 @@ static bool load_checkpoint(const char *path, CkptHeader *hdr,
return true;
}

#define MAX_COMPILES 100
// Env-overridable knobs for the tiny trainer. These duplicate the get_*()
// helpers in stories_config.h (defaults 100 / 10) — NOTE(review): consider
// sharing one header if tiny_train.m can include stories_config.h.
// Values <= 0 (also what atoi() yields for non-numeric input) are rejected:
// a 0 compile budget would halt the run immediately and a 0 accumulation
// count would make every accum batch empty.
static inline int get_max_compiles_tiny(void) {
    const char *env = getenv("ANE_MAX_COMPILES");
    int v = env ? atoi(env) : 100;
    return (v > 0) ? v : 100;
}
static inline int get_accum_steps_tiny(void) {
    const char *env = getenv("ANE_ACCUM_STEPS");
    int v = env ? atoi(env) : 10;
    return (v > 0) ? v : 10;
}
#define KERNELS_PER_STEP 4
#define ACCUM_STEPS 10

// === Pipeline: background compile via GCD ===
typedef struct {
Expand Down Expand Up @@ -231,6 +237,8 @@ int main(int argc, char *argv[]) {
float lr = 1.0f;
int start_step = 0;
bool resuming = false;
int accum_steps = get_accum_steps_tiny();
int max_compiles = get_max_compiles_tiny();

float *W1 = (float*)malloc(H * D * sizeof(float));
float *W2 = (float*)malloc(D * H * sizeof(float));
Expand Down Expand Up @@ -278,12 +286,12 @@ int main(int argc, char *argv[]) {
for (int i = 0; i < D*H; i++) W2[i] = 0.01f * cosf(i * 0.9f + 1.1f);
printf("=== ANE Training: Pipeline Parallel + Grad Accumulation ===\n");
printf("x:[%d,%d] -> W1:[%d,%d] -> ReLU -> W2:[%d,%d] -> y:[%d,%d]\n", S,D, H,D, D,H, S,D);
printf("Accum %d steps per recompile | Pipeline: compile overlaps ANE eval\n", ACCUM_STEPS);
printf("Accum %d steps per recompile | Pipeline: compile overlaps ANE eval\n", accum_steps);
printf("ANE FP16 peak: 15.8 TFLOPS (M4) | Weights: %.1f KB\n\n", weight_bytes/1024.0);
printf("FLOPs/step: ANE=%.0f (fwd+bwd) CPU=%.0f (dW) Total=%.0f\n",
ane_flops_per_step, cpu_flops_per_step, total_flops_per_step);
printf("Steps: %d, LR: %.4f, exec() budget: %d compiles\n\n",
total_steps, lr, MAX_COMPILES);
total_steps, lr, max_compiles);
}

float *x = (float*)calloc(S * D, sizeof(float));
Expand Down Expand Up @@ -332,7 +340,7 @@ int main(int argc, char *argv[]) {
int step = start_step;
while (step < total_steps) {
// Check compile budget
if (g_compile_count + KERNELS_PER_STEP > MAX_COMPILES) {
if (g_compile_count + KERNELS_PER_STEP > max_compiles) {
free_kern(k1_fwd); free_kern(k2_fwd);
free_kern(k1_bwd); free_kern(k2_bwd);
save_checkpoint(CKPT_PATH, step, last_loss, D, H, S, total_steps, lr, W1, W2,
Expand All @@ -358,7 +366,7 @@ int main(int argc, char *argv[]) {
// So we need to update weights BEFORE launching background compile

uint64_t t_batch = mach_absolute_time();
for (int a = 0; a < ACCUM_STEPS && step < total_steps; a++, step++) {
for (int a = 0; a < accum_steps && step < total_steps; a++, step++) {
ane_eval_k(k1_fwd, x, h, D, H, S);
for (int i = 0; i < S*H; i++) h_relu[i] = h[i] > 0 ? h[i] : 0;
ane_eval_k(k2_fwd, h_relu, y, H, D, S);
Expand Down Expand Up @@ -412,7 +420,7 @@ int main(int argc, char *argv[]) {
// Pipeline: launch background compile with updated weights,
// then immediately start NEXT batch's ANE evals with OLD kernels
// while compile runs concurrently on GCD queue
bool can_pipeline = (step < total_steps) && (g_compile_count + KERNELS_PER_STEP <= MAX_COMPILES);
bool can_pipeline = (step < total_steps) && (g_compile_count + KERNELS_PER_STEP <= max_compiles);

if (can_pipeline) {
// Snapshot weights for background compile
Expand Down Expand Up @@ -445,7 +453,7 @@ int main(int argc, char *argv[]) {
int steps_overlap = 0;
uint64_t t_overlap = mach_absolute_time();

for (int a = 0; a < ACCUM_STEPS && step < total_steps; a++, step++) {
for (int a = 0; a < accum_steps && step < total_steps; a++, step++) {
ane_eval_k(k1_fwd, x, h, D, H, S);
for (int i = 0; i < S*H; i++) h_relu[i] = h[i] > 0 ? h[i] : 0;
ane_eval_k(k2_fwd, h_relu, y, H, D, S);
Expand Down Expand Up @@ -552,7 +560,7 @@ int main(int argc, char *argv[]) {
// === Efficiency Report ===
printf("\n=== Efficiency Report ===\n");
printf("Total steps: %d\n", total_steps_done);
printf("Total batches: %d (accum %d steps each)\n", total_batches, ACCUM_STEPS);
printf("Total batches: %d (accum %d steps each)\n", total_batches, accum_steps);
printf("Wall time: %.0f ms\n", total_wall_ms);
printf("Compile time: %.0f ms (%.1f%%)\n", total_compile_ms, 100.0*total_compile_ms/total_wall_ms);
printf("Train time: %.0f ms (%.1f%%)\n", total_train_ms, 100.0*total_train_ms/total_wall_ms);
Expand All @@ -579,8 +587,8 @@ int main(int argc, char *argv[]) {
printf("Weight params: %d (%.1f KB FP16)\n",
H*D + D*H, weight_bytes / 1024.0);
printf("Compile amortization: %.1f ms compile / %d steps = %.2f ms/step overhead\n",
total_compile_ms / total_batches, ACCUM_STEPS,
total_compile_ms / total_batches / ACCUM_STEPS);
total_compile_ms / total_batches, accum_steps,
total_compile_ms / total_batches / accum_steps);
printf("Compile fraction: %.1f%% of wall time\n", 100.0 * total_compile_ms / total_wall_ms);
printf("Train fraction: %.1f%% of wall time (useful work)\n", 100.0 * total_train_ms / total_wall_ms);

Expand Down
10 changes: 6 additions & 4 deletions training/train_large.m
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,8 @@ int main(int argc, char *argv[]) {
float lr = 3e-4f;
float adam_b1=0.9f, adam_b2=0.999f, adam_eps=1e-8f;
int adam_t = 0, start_step = 0;
int accum_steps = get_accum_steps();
int max_compiles = get_max_compiles();

// Parse args
const char *ckpt_path = CKPT_PATH_DEFAULT;
Expand Down Expand Up @@ -270,7 +272,7 @@ int main(int argc, char *argv[]) {
printf("Params: %.2fM (transformer %.2fM + embed %.2fM)\n", tp/1e6, xfmr_params/1e6, embed_params/1e6);
printf("Kernels: %d (%d weight-bearing + %d static sdpaBwd2)\n",
TOTAL_WEIGHT_KERNELS+NLAYERS, TOTAL_WEIGHT_KERNELS, NLAYERS);
printf("Accum %d steps per recompile | Adam LR=%.1e b1=%.1f b2=%.3f\n", ACCUM_STEPS, lr, adam_b1, adam_b2);
printf("Accum %d steps per recompile | Adam LR=%.1e b1=%.1f b2=%.3f\n", accum_steps, lr, adam_b1, adam_b2);
double fwd_f = NLAYERS*(4.0*2*DIM*DIM*SEQ + 2.0*2*DIM*HIDDEN*SEQ + 2.0*HIDDEN*DIM*SEQ);
double bwd_dx_f = fwd_f, bwd_dw_f = fwd_f;
double sdpa_f = NLAYERS*2.0*HEADS*5*SEQ*SEQ*HD;
Expand Down Expand Up @@ -331,7 +333,7 @@ int main(int argc, char *argv[]) {
int step = start_step;
while (step < total_steps) {
// Check compile budget
if (g_compile_count + TOTAL_WEIGHT_KERNELS > MAX_COMPILES) {
if (g_compile_count + TOTAL_WEIGHT_KERNELS > max_compiles) {
for (int L=0; L<NLAYERS; L++) { free_layer_kernels(&kern[L]); free_kern(sdpaBwd2[L]); }
double wall = tb_ms(mach_absolute_time() - t_wall_start);
save_checkpoint(ckpt_path, step, total_steps, lr, last_loss,
Expand All @@ -357,7 +359,7 @@ int main(int argc, char *argv[]) {
compile_ok = false; break;
}
}
if (!compile_ok) { g_compile_count = MAX_COMPILES; continue; }
if (!compile_ok) { g_compile_count = max_compiles; continue; }

// Re-compile sdpaBwd2 if needed (after exec restart)
for (int L=0; L<NLAYERS; L++) {
Expand All @@ -380,7 +382,7 @@ int main(int argc, char *argv[]) {
uint64_t tt = mach_absolute_time();
double t_ane=0,t_io=0,t_elem=0,t_rms=0,t_cblas_wait=0,t_cls=0;

for (int a=0; a<ACCUM_STEPS && step<total_steps; a++, step++) {
for (int a=0; a<accum_steps && step<total_steps; a++, step++) {
uint64_t t0,t1;
// Sample random position in token data
size_t max_pos = n_tokens - SEQ - 1;
Expand Down