From e0b1b27a3717e5183b953178347f9d791d0ab522 Mon Sep 17 00:00:00 2001 From: "codegen-sh[bot]" <131295404+codegen-sh[bot]@users.noreply.github.com> Date: Mon, 2 Mar 2026 21:26:28 +0000 Subject: [PATCH 1/3] Add pipeline scaffolding for multi-group ANE training New files: - model_config.h: Parameterized model config with presets (Stories42M/110M, LLaMA-1B/7B), pipeline planning (compute_pipeline_plan), memory/FLOP estimation - pipeline.h: Layer-group scheduler (PipelineScheduler state machine), compile budget tracking, mmap-based cross-exec() shared tensor state, exec() restart with automatic resume - gradient_checkpoint.h: Activation checkpointing policies (ALL/BOUNDARY/SQRT/NONE), recompute tracking, memory savings estimation - train_pipeline.m: Entry point with dry-run simulation mode -- prints full execution plan for any model config, simulates scheduler state machine - Makefile: train_pipeline and train_pipeline_live targets All additive -- existing train_large.m untouched. Co-authored-by: dermitchell1993 --- training/Makefile | 80 +++--- training/gradient_checkpoint.h | 170 +++++++++++ training/model_config.h | 310 ++++++++++++++++++++ training/pipeline.h | 497 +++++++++++++++++++++++++++++++++ training/train_pipeline.m | 258 +++++++++++++++++ 5 files changed, 1279 insertions(+), 36 deletions(-) create mode 100644 training/gradient_checkpoint.h create mode 100644 training/model_config.h create mode 100644 training/pipeline.h create mode 100644 training/train_pipeline.m diff --git a/training/Makefile b/training/Makefile index 9cc9e34..b0ff4bc 100644 --- a/training/Makefile +++ b/training/Makefile @@ -1,36 +1,44 @@ -CC = xcrun clang -CFLAGS = -O2 -Wall -Wno-deprecated-declarations -fobjc-arc -FRAMEWORKS = -framework Foundation -framework CoreML -framework IOSurface -LDFLAGS = $(FRAMEWORKS) -ldl - -HEADERS_LARGE = stories_config.h stories_io.h stories_mil.h stories_cpu_ops.h - -train: train.m ane_runtime.h ane_mil_gen.h model.h forward.h backward.h - $(CC) $(CFLAGS) -o $@ train.m $(LDFLAGS) - -train_large: train_large.m $(HEADERS_LARGE) - $(CC) $(CFLAGS) -o $@ train_large.m $(LDFLAGS) -framework Accelerate - -PROBES = test_weight_reload test_perf_stats test_qos_sweep test_ane_advanced - -test_weight_reload: test_weight_reload.m - $(CC) $(CFLAGS) -o $@ $< $(LDFLAGS) - -test_perf_stats: test_perf_stats.m - $(CC) $(CFLAGS) -o $@ $< $(LDFLAGS) - -test_qos_sweep: test_qos_sweep.m - $(CC) $(CFLAGS) -o $@ $< $(LDFLAGS) - -test_ane_advanced: test_ane_advanced.m - $(CC) $(CFLAGS) -o $@ $< $(LDFLAGS) - -probes: $(PROBES) - -tokenize: - python3 tokenize.py - -clean: - rm -f train train_large $(PROBES) - -.PHONY: clean tokenize probes +CC = xcrun clang +CFLAGS = -O2 -Wall -Wno-deprecated-declarations -fobjc-arc +FRAMEWORKS = -framework Foundation -framework CoreML -framework IOSurface +LDFLAGS = $(FRAMEWORKS) -ldl + +HEADERS_LARGE = stories_config.h stories_io.h stories_mil.h stories_cpu_ops.h +HEADERS_PIPELINE = model_config.h pipeline.h gradient_checkpoint.h + +train: train.m ane_runtime.h ane_mil_gen.h model.h forward.h backward.h + $(CC) $(CFLAGS) -o $@ train.m $(LDFLAGS) + +train_large: train_large.m $(HEADERS_LARGE) + $(CC) $(CFLAGS) -o $@ train_large.m $(LDFLAGS) -framework Accelerate + +train_pipeline: train_pipeline.m $(HEADERS_PIPELINE) + $(CC) $(CFLAGS) -o $@ train_pipeline.m $(LDFLAGS) -framework Accelerate + +train_pipeline_live: train_pipeline.m $(HEADERS_PIPELINE) $(HEADERS_LARGE) + $(CC) $(CFLAGS) -DANE_LIVE -o train_pipeline train_pipeline.m $(LDFLAGS) -framework Accelerate + +PROBES 
= test_weight_reload test_perf_stats test_qos_sweep test_ane_advanced + +test_weight_reload: test_weight_reload.m + $(CC) $(CFLAGS) -o $@ $< $(LDFLAGS) + +test_perf_stats: test_perf_stats.m + $(CC) $(CFLAGS) -o $@ $< $(LDFLAGS) + +test_qos_sweep: test_qos_sweep.m + $(CC) $(CFLAGS) -o $@ $< $(LDFLAGS) + +test_ane_advanced: test_ane_advanced.m + $(CC) $(CFLAGS) -o $@ $< $(LDFLAGS) + +probes: $(PROBES) + +tokenize: + python3 tokenize.py + +clean: + rm -f train train_large train_pipeline $(PROBES) + +.PHONY: clean tokenize probes + diff --git a/training/gradient_checkpoint.h b/training/gradient_checkpoint.h new file mode 100644 index 0000000..4065c04 --- /dev/null +++ b/training/gradient_checkpoint.h @@ -0,0 +1,170 @@ +// gradient_checkpoint.h — Activation checkpointing for deep models +// Trades compute for memory: recompute forward activations during backward +// instead of storing all layers' activations simultaneously +#pragma once +#include "model_config.h" + +// ===== Checkpoint policies ===== + +typedef enum { + CKPT_ALL, // save all layers' activations (current behavior) + CKPT_BOUNDARY, // save only group boundary activations, recompute within group + CKPT_SQRT, // save every √N layers (optimal memory/compute tradeoff) + CKPT_EVERY_N, // save every N-th layer (configurable) + CKPT_NONE // save nothing, recompute everything (max memory savings) +} CheckpointPolicy; + +typedef struct { + CheckpointPolicy policy; + int interval; // for CKPT_EVERY_N: save every N layers + int n_layers; // total layers in model + int n_groups; // layer groups in pipeline + int layers_per_group; // layers per group (from pipeline plan) + // Derived + int n_checkpointed; // how many layers have saved activations + bool *is_saved; // per-layer: true if activation is saved (not recomputed) +} CheckpointManager; + +// ===== Initialization ===== + +static CheckpointManager checkpoint_init(CheckpointPolicy policy, const ModelConfig *cfg, + const PipelinePlan *plan) { + CheckpointManager cm = {0}; + cm.policy = policy; + cm.n_layers = cfg->dims.n_layers; + cm.n_groups = plan->n_groups; + cm.layers_per_group = (plan->n_groups > 0) ? 
plan->groups[0].n_layers : cfg->dims.n_layers;
+    cm.is_saved = (bool *)calloc(cfg->dims.n_layers, sizeof(bool));
+
+    switch (policy) {
+    case CKPT_ALL:
+        // Save everything — no recompute needed
+        for (int i = 0; i < cm.n_layers; i++) cm.is_saved[i] = true;
+        break;
+
+    case CKPT_BOUNDARY:
+        // Save only the input to each layer group
+        for (int g = 0; g < plan->n_groups; g++) {
+            cm.is_saved[plan->groups[g].start_layer] = true;
+        }
+        // Always save the last layer's output (needed for loss backward)
+        cm.is_saved[cm.n_layers - 1] = true;
+        break;
+
+    case CKPT_SQRT: {
+        // Save every √N layers — optimal memory/compute balance
+        int interval = (int)sqrtf((float)cm.n_layers);
+        if (interval < 1) interval = 1;
+        cm.interval = interval;
+        for (int i = 0; i < cm.n_layers; i += interval) cm.is_saved[i] = true;
+        cm.is_saved[cm.n_layers - 1] = true;
+        break;
+    }
+
+    case CKPT_EVERY_N:
+        // Default stride; callers wanting a different N must rebuild the
+        // manager, since is_saved is computed here rather than lazily
+        cm.interval = 4;
+        for (int i = 0; i < cm.n_layers; i += cm.interval) cm.is_saved[i] = true;
+        cm.is_saved[cm.n_layers - 1] = true;
+        break;
+
+    case CKPT_NONE:
+        // Save nothing except layer 0 input (needed as recompute starting point)
+        cm.is_saved[0] = true;
+        break;
+    }
+
+    // Derive the count from the flags themselves: the per-policy closed forms
+    // undercount when the forced last-layer save falls off the interval grid
+    cm.n_checkpointed = 0;
+    for (int i = 0; i < cm.n_layers; i++)
+        if (cm.is_saved[i]) cm.n_checkpointed++;
+
+    return cm;
+}
+
+static void checkpoint_free(CheckpointManager *cm) {
+    free(cm->is_saved);
+    cm->is_saved = NULL;
+}
+
+// ===== Query functions =====
+
+// Should we save this layer's activations during forward pass?
+static bool checkpoint_should_save(const CheckpointManager *cm, int layer_idx) {
+    if (layer_idx < 0 || layer_idx >= cm->n_layers) return false;
+    return cm->is_saved[layer_idx];
+}
+
+// Does this layer need forward recompute during backward pass?
+static bool checkpoint_needs_recompute(const CheckpointManager *cm, int layer_idx) {
+    return !checkpoint_should_save(cm, layer_idx);
+}
+
+// Find the nearest saved checkpoint before this layer (for recompute starting point)
+static int checkpoint_nearest_saved_before(const CheckpointManager *cm, int layer_idx) {
+    for (int i = layer_idx; i >= 0; i--) {
+        if (cm->is_saved[i]) return i;
+    }
+    return 0; // fallback to first layer
+}
+
+// How many layers need recompute between the nearest checkpoint and this layer?
+static int checkpoint_recompute_depth(const CheckpointManager *cm, int layer_idx) {
+    int saved = checkpoint_nearest_saved_before(cm, layer_idx);
+    return layer_idx - saved;
+}
+
+// ===== Memory estimation =====
+
+// Memory for saved activations only (bytes)
+static size_t checkpoint_saved_memory(const CheckpointManager *cm, const ModelDims *d) {
+    return (size_t)cm->n_checkpointed * layer_activation_bytes(d);
+}
+
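// A minimal sketch (not part of the patch) of the policy arithmetic above, for
// the LLaMA-7B preset (32 layers): CKPT_SQRT gives interval = floor(sqrt(32)) = 5,
// so layers 0,5,10,15,20,25,30 are flagged plus the forced last layer 31, i.e.
// 8 checkpoints. Recomputing layer 9 restarts from the nearest checkpoint,
// layer 5, a depth of 4. Standalone usage, kept in comment form:
//
//     #include <stdio.h>
//     #include "gradient_checkpoint.h"   // pulls in model_config.h
//
//     int main(void) {
//         ModelConfig cfg = model_config_llama_7b();
//         PipelinePlan plan = compute_pipeline_plan(&cfg);
//         CheckpointManager cm = checkpoint_init(CKPT_SQRT, &cfg, &plan);
//         printf("interval=%d saved=%d/%d\n",
//                cm.interval, cm.n_checkpointed, cm.n_layers);   // 5, 8/32
//         printf("L9: recompute from L%d (depth %d)\n",
//                checkpoint_nearest_saved_before(&cm, 9),        // 5
//                checkpoint_recompute_depth(&cm, 9));            // 4
//         checkpoint_free(&cm);
//         pipeline_plan_free(&plan);
//         return 0;
//     }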
+// Memory savings vs. saving all layers (bytes)
+static size_t checkpoint_memory_saved(const CheckpointManager *cm, const ModelDims *d) {
+    size_t all = (size_t)cm->n_layers * layer_activation_bytes(d);
+    size_t used = checkpoint_saved_memory(cm, d);
+    return all - used;
+}
+
+// Extra forward FLOPs due to recompute, as a fraction of one forward pass
+static double checkpoint_recompute_overhead(const CheckpointManager *cm) {
+    int recomputed = cm->n_layers - cm->n_checkpointed;
+    return (double)recomputed / (double)cm->n_layers;
+}
+
+// ===== Pretty-print =====
+
+static const char *checkpoint_policy_name(CheckpointPolicy p) {
+    switch (p) {
+    case CKPT_ALL:      return "ALL";
+    case CKPT_BOUNDARY: return "BOUNDARY";
+    case CKPT_SQRT:     return "SQRT";
+    case CKPT_EVERY_N:  return "EVERY_N";
+    case CKPT_NONE:     return "NONE";
+    default:            return "UNKNOWN";
+    }
+}
+
+static void checkpoint_print(const CheckpointManager *cm, const ModelDims *d) {
+    printf("=== Checkpoint Policy: %s ===\n", checkpoint_policy_name(cm->policy));
+    printf("  %d/%d layers checkpointed", cm->n_checkpointed, cm->n_layers);
+    if (cm->policy == CKPT_SQRT || cm->policy == CKPT_EVERY_N)
+        printf(" (interval=%d)", cm->interval);
+    printf("\n");
+    printf("  Activation memory: %.1fMB (saved) / %.1fMB (all)\n",
+           checkpoint_saved_memory(cm, d) / 1e6,
+           (double)cm->n_layers * layer_activation_bytes(d) / 1e6);
+    printf("  Memory savings: %.1fMB (%.0f%%)\n",
+           checkpoint_memory_saved(cm, d) / 1e6,
+           100.0 * checkpoint_memory_saved(cm, d) / ((double)cm->n_layers * layer_activation_bytes(d)));
+    printf("  Recompute overhead: %.0f%% extra forward FLOPs\n",
+           100.0 * checkpoint_recompute_overhead(cm));
+    printf("  Saved layers: ");
+    for (int i = 0; i < cm->n_layers; i++) {
+        if (cm->is_saved[i]) printf("%d ", i);
+    }
+    printf("\n");
+}
diff --git a/training/model_config.h b/training/model_config.h
new file mode 100644
index 0000000..8e99090
--- /dev/null
+++ b/training/model_config.h
@@ -0,0 +1,310 @@
+// model_config.h — Parameterized model configuration for pipeline training
+// Replaces hardcoded #defines with portable structs + preset configs
+#pragma once
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <stdbool.h>
+#include <math.h>
+
+// ===== Model configuration =====
+
+typedef struct {
+    int dim;        // model dimension (embedding/residual width)
+    int hidden_dim; // FFN hidden dimension
+    int n_heads;    // number of attention heads
+    int n_kv_heads; // number of KV heads (for GQA; == n_heads for MHA)
+    int n_layers;   // total transformer layers
+    int vocab_size; // vocabulary size
+    int seq_len;    // maximum sequence length
+    // Derived (computed by model_dims_init)
+    int head_dim;   // dim / n_heads
+    int kv_dim;     // head_dim * n_kv_heads
+    int score_ch;   // n_heads * seq_len (attention score channels for SDPA bwd)
+} ModelDims;
+
+typedef struct {
+    int compile_budget;    // max ANE compilations per process (~119)
+    int kernels_per_layer; // weight-bearing kernels per layer (currently 5)
+    int static_per_layer;  // weight-free kernels per layer (sdpaBwd2 = 1)
+    int accum_steps;       // gradient accumulation steps per compile batch
+} CompileConfig;
+
+typedef struct {
+    ModelDims dims;
+    CompileConfig compile;
+    const char *name; // human-readable model name
+} ModelConfig;
+
+// ===== Layer group for pipeline scheduling =====
+
+typedef struct {
+    int start_layer;    // first layer index (inclusive)
+    int end_layer;      // last layer index (exclusive)
+    int n_layers;       // end_layer - start_layer
+    int weight_kernels; // weight-bearing kernels in this group
+    int static_kernels; // weight-free kernels in this group
+    int total_kernels;
// weight_kernels + static_kernels +} LayerGroup; + +typedef struct { + LayerGroup *groups; + int n_groups; + int total_exec_restarts; // estimated exec() restarts per training step +} PipelinePlan; + +// ===== Derived dimension helpers ===== + +static void model_dims_init(ModelDims *d) { + d->head_dim = d->dim / d->n_heads; + d->kv_dim = d->head_dim * d->n_kv_heads; + d->score_ch = d->n_heads * d->seq_len; +} + +// ===== Per-layer memory sizes (bytes) ===== + +// Weight sizes in floats (fp32) +static inline size_t wq_size(const ModelDims *d) { return (size_t)d->dim * d->dim; } +static inline size_t wo_size(const ModelDims *d) { return (size_t)d->dim * d->dim; } +static inline size_t w1_size(const ModelDims *d) { return (size_t)d->hidden_dim * d->dim; } +static inline size_t w2_size(const ModelDims *d) { return (size_t)d->dim * d->hidden_dim; } +static inline size_t w3_size(const ModelDims *d) { return (size_t)d->hidden_dim * d->dim; } + +static inline size_t layer_weight_floats(const ModelDims *d) { + return 4 * wq_size(d) // Wq, Wk, Wv, Wo + + w1_size(d) + w2_size(d) + w3_size(d) // W1, W2, W3 + + 2 * (size_t)d->dim; // rms_att, rms_ffn +} + +static inline size_t layer_weight_bytes(const ModelDims *d) { + return layer_weight_floats(d) * sizeof(float); +} + +// Adam state: 2x weight size (m + v vectors) +static inline size_t layer_adam_bytes(const ModelDims *d) { + return 2 * layer_weight_bytes(d); +} + +// Activation buffers per layer (saved for backward) +static inline size_t layer_activation_floats(const ModelDims *d) { + int S = d->seq_len, D = d->dim, H = d->hidden_dim; + // layer_in, xnorm, Q, K, V, attn_out, o_out, x2, x2norm = 9 * D*S + // h1, h3, silu_out = 3 * H*S + // ffn_out = D*S + return (size_t)(10 * D * S + 3 * H * S); +} + +static inline size_t layer_activation_bytes(const ModelDims *d) { + return layer_activation_floats(d) * sizeof(float); +} + +// Gradient accumulators per layer +static inline size_t layer_gradient_bytes(const ModelDims *d) { + return layer_weight_bytes(d); // same layout as weights +} + +// Total model memory (weights + adam + activations + gradients for all layers) +static inline size_t total_model_bytes(const ModelConfig *cfg) { + const ModelDims *d = &cfg->dims; + size_t per_layer = layer_weight_bytes(d) + layer_adam_bytes(d) + + layer_activation_bytes(d) + layer_gradient_bytes(d); + size_t global = (size_t)d->dim * sizeof(float) // rms_final + + (size_t)d->vocab_size * d->dim * sizeof(float) // embed + + (size_t)d->dim * 2 * sizeof(float) // rms_final adam + + (size_t)d->vocab_size * d->dim * 2 * sizeof(float); // embed adam + return per_layer * d->n_layers + global; +} + +// ===== Pipeline planning ===== + +// Compute how many layers can fit in one compile batch +static int max_layers_per_compile(const CompileConfig *cc) { + // Reserve some headroom (90% of budget) for safety + int usable = (int)(cc->compile_budget * 0.9); + int per_layer = cc->kernels_per_layer + cc->static_per_layer; + if (per_layer <= 0) return 1; + return usable / per_layer; +} + +// Compute optimal layer groups for a model given compile budget +// Returns a PipelinePlan (caller must free plan.groups) +static PipelinePlan compute_pipeline_plan(const ModelConfig *cfg) { + PipelinePlan plan = {0}; + int max_per = max_layers_per_compile(&cfg->compile); + if (max_per <= 0) max_per = 1; + + // Clamp to actual layer count + int group_size = (max_per < cfg->dims.n_layers) ? 
max_per : cfg->dims.n_layers; + + plan.n_groups = (cfg->dims.n_layers + group_size - 1) / group_size; + plan.groups = (LayerGroup *)calloc(plan.n_groups, sizeof(LayerGroup)); + + int kpl = cfg->compile.kernels_per_layer; + int spl = cfg->compile.static_per_layer; + + for (int g = 0; g < plan.n_groups; g++) { + LayerGroup *lg = &plan.groups[g]; + lg->start_layer = g * group_size; + lg->end_layer = lg->start_layer + group_size; + if (lg->end_layer > cfg->dims.n_layers) + lg->end_layer = cfg->dims.n_layers; + lg->n_layers = lg->end_layer - lg->start_layer; + lg->weight_kernels = lg->n_layers * kpl; + lg->static_kernels = lg->n_layers * spl; + lg->total_kernels = lg->weight_kernels + lg->static_kernels; + } + + // Each compile batch needs one exec() restart (except possibly the last) + // Forward: n_groups compiles. Backward: n_groups compiles. + // Per training step: forward + backward = 2 * n_groups compile batches + // Each batch may need exec() restart. Worst case: + plan.total_exec_restarts = 2 * plan.n_groups; + + return plan; +} + +static void pipeline_plan_free(PipelinePlan *plan) { + free(plan->groups); + plan->groups = NULL; + plan->n_groups = 0; +} + +// ===== Pretty-print plan ===== + +static void pipeline_plan_print(const ModelConfig *cfg, const PipelinePlan *plan) { + printf("=== Pipeline Plan: %s ===\n", cfg->name); + printf(" %d layers | dim=%d hidden=%d heads=%d seq=%d vocab=%d\n", + cfg->dims.n_layers, cfg->dims.dim, cfg->dims.hidden_dim, + cfg->dims.n_heads, cfg->dims.seq_len, cfg->dims.vocab_size); + printf(" Compile budget: %d | %d weight-kernels/layer + %d static/layer\n", + cfg->compile.compile_budget, cfg->compile.kernels_per_layer, + cfg->compile.static_per_layer); + printf(" %d layer group(s):\n", plan->n_groups); + for (int g = 0; g < plan->n_groups; g++) { + const LayerGroup *lg = &plan->groups[g]; + printf(" Group %d: layers [%d..%d) — %d layers, %d kernels (%d weight + %d static)\n", + g, lg->start_layer, lg->end_layer, lg->n_layers, + lg->total_kernels, lg->weight_kernels, lg->static_kernels); + } + printf(" Est. 
exec() restarts per step: %d\n", plan->total_exec_restarts); + printf(" Memory per layer: weights=%.1fMB adam=%.1fMB acts=%.1fMB grads=%.1fMB\n", + layer_weight_bytes(&cfg->dims)/1e6, layer_adam_bytes(&cfg->dims)/1e6, + layer_activation_bytes(&cfg->dims)/1e6, layer_gradient_bytes(&cfg->dims)/1e6); + printf(" Total model state: %.1fMB\n", total_model_bytes(cfg)/1e6); +} + +// ===== FLOP estimation per step ===== + +static inline double flops_per_step(const ModelConfig *cfg) { + const ModelDims *d = &cfg->dims; + int N = d->n_layers, D = d->dim, H = d->hidden_dim, S = d->seq_len; + int HD = d->head_dim, NH = d->n_heads; + // Forward: 4 linear projections (QKV+O) + 3 FFN projections per layer + double fwd = N * (4.0*2*D*D*S + 2.0*2*D*H*S + 2.0*H*D*S); + // Backward dx same flops as forward + double bwd_dx = fwd; + // Backward dW same flops as forward + double bwd_dw = fwd; + // SDPA backward (attention score computation) + double sdpa = N * 2.0 * NH * 5 * S * S * HD; + // Classifier (forward + backward) + double cls = 3.0 * 2.0 * d->vocab_size * D * S; + return fwd + bwd_dx + bwd_dw + sdpa + cls; +} + +static inline double ane_flops_per_step(const ModelConfig *cfg) { + const ModelDims *d = &cfg->dims; + int N = d->n_layers, D = d->dim, H = d->hidden_dim, S = d->seq_len; + int HD = d->head_dim, NH = d->n_heads; + double fwd = N * (4.0*2*D*D*S + 2.0*2*D*H*S + 2.0*H*D*S); + double bwd_dx = fwd; + double sdpa = N * 2.0 * NH * 5 * S * S * HD; + return fwd + bwd_dx + sdpa; // dW is on CPU (cblas) +} + +// ===== Model presets ===== + +static ModelConfig model_config_stories110m(void) { + ModelConfig cfg = {0}; + cfg.name = "Stories110M"; + cfg.dims = (ModelDims){ + .dim = 768, .hidden_dim = 2048, .n_heads = 12, + .n_kv_heads = 12, .n_layers = 12, .vocab_size = 32000, .seq_len = 256 + }; + cfg.compile = (CompileConfig){ + .compile_budget = 119, .kernels_per_layer = 5, + .static_per_layer = 1, .accum_steps = 10 + }; + model_dims_init(&cfg.dims); + return cfg; +} + +static ModelConfig model_config_stories42m(void) { + ModelConfig cfg = {0}; + cfg.name = "Stories42M"; + cfg.dims = (ModelDims){ + .dim = 512, .hidden_dim = 1376, .n_heads = 8, + .n_kv_heads = 8, .n_layers = 8, .vocab_size = 32000, .seq_len = 256 + }; + cfg.compile = (CompileConfig){ + .compile_budget = 119, .kernels_per_layer = 5, + .static_per_layer = 1, .accum_steps = 10 + }; + model_dims_init(&cfg.dims); + return cfg; +} + +static ModelConfig model_config_llama_1b(void) { + ModelConfig cfg = {0}; + cfg.name = "LLaMA-1.1B"; + cfg.dims = (ModelDims){ + .dim = 2048, .hidden_dim = 5504, .n_heads = 16, + .n_kv_heads = 16, .n_layers = 22, .vocab_size = 32000, .seq_len = 512 + }; + cfg.compile = (CompileConfig){ + .compile_budget = 119, .kernels_per_layer = 5, + .static_per_layer = 1, .accum_steps = 4 + }; + model_dims_init(&cfg.dims); + return cfg; +} + +static ModelConfig model_config_llama_7b(void) { + ModelConfig cfg = {0}; + cfg.name = "LLaMA-7B"; + cfg.dims = (ModelDims){ + .dim = 4096, .hidden_dim = 11008, .n_heads = 32, + .n_kv_heads = 32, .n_layers = 32, .vocab_size = 32000, .seq_len = 512 + }; + cfg.compile = (CompileConfig){ + .compile_budget = 119, .kernels_per_layer = 5, + .static_per_layer = 1, .accum_steps = 2 + }; + model_dims_init(&cfg.dims); + return cfg; +} + +// Parse config from command-line (returns preset, caller can override) +static ModelConfig model_config_from_args(int argc, char *argv[]) { + ModelConfig cfg = model_config_stories110m(); // default + for (int i = 1; i < argc; i++) { + if (strcmp(argv[i], "--model") 
== 0 && i+1 < argc) {
+            const char *name = argv[++i];
+            if (strcmp(name, "stories42m") == 0) cfg = model_config_stories42m();
+            else if (strcmp(name, "stories110m") == 0) cfg = model_config_stories110m();
+            else if (strcmp(name, "llama1b") == 0) cfg = model_config_llama_1b();
+            else if (strcmp(name, "llama7b") == 0) cfg = model_config_llama_7b();
+            else fprintf(stderr, "Unknown model: %s (using stories110m)\n", name);
+        }
+        else if (strcmp(argv[i], "--dim") == 0 && i+1 < argc) cfg.dims.dim = atoi(argv[++i]);
+        else if (strcmp(argv[i], "--hidden") == 0 && i+1 < argc) cfg.dims.hidden_dim = atoi(argv[++i]);
+        else if (strcmp(argv[i], "--heads") == 0 && i+1 < argc) cfg.dims.n_heads = atoi(argv[++i]);
+        else if (strcmp(argv[i], "--layers") == 0 && i+1 < argc) cfg.dims.n_layers = atoi(argv[++i]);
+        else if (strcmp(argv[i], "--seq") == 0 && i+1 < argc) cfg.dims.seq_len = atoi(argv[++i]);
+        else if (strcmp(argv[i], "--vocab") == 0 && i+1 < argc) cfg.dims.vocab_size = atoi(argv[++i]);
+        else if (strcmp(argv[i], "--budget") == 0 && i+1 < argc) cfg.compile.compile_budget = atoi(argv[++i]);
+        else if (strcmp(argv[i], "--accum") == 0 && i+1 < argc) cfg.compile.accum_steps = atoi(argv[++i]);
+    }
+    model_dims_init(&cfg.dims);
+    return cfg;
+}
diff --git a/training/pipeline.h b/training/pipeline.h
new file mode 100644
index 0000000..c04eb0c
--- /dev/null
+++ b/training/pipeline.h
@@ -0,0 +1,497 @@
+// pipeline.h — Layer-group scheduling and mmap state for multi-group ANE training
+// Manages compile budgets, exec() restarts, and cross-exec shared tensor state
+#pragma once
+#include "model_config.h"
+#include <fcntl.h>
+#include <unistd.h>
+#include <sys/mman.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+
+// ===== Compile budget tracker =====
+
+typedef struct {
+    int budget;   // max compilations allowed
+    int used;     // compilations consumed so far
+    int headroom; // safety margin (budget * 0.1)
+} CompileBudget;
+
+static CompileBudget budget_init(int max_compiles) {
+    CompileBudget b;
+    b.budget = max_compiles;
+    b.used = 0;
+    b.headroom = max_compiles / 10;
+    return b;
+}
+
+static bool budget_can_fit(const CompileBudget *b, int n_kernels) {
+    return (b->used + n_kernels) <= (b->budget - b->headroom);
+}
+
+static void budget_consume(CompileBudget *b, int n_kernels) {
+    b->used += n_kernels;
+}
+
+static bool budget_needs_restart(const CompileBudget *b) {
+    return b->used >= (b->budget - b->headroom);
+}
+
+static int budget_remaining(const CompileBudget *b) {
+    int r = b->budget - b->headroom - b->used;
+    return (r > 0) ?
r : 0; +} + +// ===== Pipeline execution phases ===== + +typedef enum { + PHASE_INIT = 0, + PHASE_FORWARD, // running forward pass through layer groups + PHASE_BACKWARD, // running backward pass through layer groups (reverse) + PHASE_WEIGHT_UPDATE, // Adam step on accumulated gradients + PHASE_DONE // training step complete +} PipelinePhase; + +typedef enum { + ACTION_COMPILE_GROUP, // compile kernels for current layer group + ACTION_RUN_FORWARD_GROUP, // execute forward pass for compiled group + ACTION_RUN_BACKWARD_GROUP, // execute backward pass for compiled group + ACTION_EXEC_RESTART, // save state and exec() to reset compile budget + ACTION_WEIGHT_UPDATE, // run optimizer on all layers + ACTION_STEP_DONE, // training step complete + ACTION_ERROR // something went wrong +} PipelineAction; + +// ===== Scheduler state ===== + +typedef struct { + ModelConfig config; + PipelinePlan plan; + CompileBudget budget; + + PipelinePhase phase; + int current_group; // index into plan.groups + int current_step; // training step number + int accum_step; // gradient accumulation step within batch + int total_steps; // total training steps requested + float learning_rate; + float last_loss; + + // Flags + bool group_compiled; // whether current group's kernels are compiled + bool needs_restart; // whether we need exec() before next group +} PipelineScheduler; + +static PipelineScheduler pipeline_scheduler_init(ModelConfig config, int total_steps, float lr) { + PipelineScheduler s = {0}; + s.config = config; + s.plan = compute_pipeline_plan(&config); + s.budget = budget_init(config.compile.compile_budget); + s.phase = PHASE_FORWARD; + s.current_group = 0; + s.current_step = 0; + s.accum_step = 0; + s.total_steps = total_steps; + s.learning_rate = lr; + s.last_loss = 999.0f; + s.group_compiled = false; + s.needs_restart = false; + return s; +} + +// Get the next action the training loop should take +static PipelineAction pipeline_next_action(PipelineScheduler *s) { + if (s->current_step >= s->total_steps) + return ACTION_STEP_DONE; + + switch (s->phase) { + case PHASE_FORWARD: + if (s->current_group >= s->plan.n_groups) { + // Forward pass complete for all groups — start backward + s->phase = PHASE_BACKWARD; + s->current_group = s->plan.n_groups - 1; + s->group_compiled = false; + return pipeline_next_action(s); + } + if (!s->group_compiled) { + // Check if we have compile budget for this group + LayerGroup *lg = &s->plan.groups[s->current_group]; + if (!budget_can_fit(&s->budget, lg->total_kernels)) { + s->needs_restart = true; + return ACTION_EXEC_RESTART; + } + return ACTION_COMPILE_GROUP; + } + return ACTION_RUN_FORWARD_GROUP; + + case PHASE_BACKWARD: + if (s->current_group < 0) { + // Backward complete — weight update + s->phase = PHASE_WEIGHT_UPDATE; + return ACTION_WEIGHT_UPDATE; + } + if (!s->group_compiled) { + LayerGroup *lg = &s->plan.groups[s->current_group]; + if (!budget_can_fit(&s->budget, lg->total_kernels)) { + s->needs_restart = true; + return ACTION_EXEC_RESTART; + } + return ACTION_COMPILE_GROUP; + } + return ACTION_RUN_BACKWARD_GROUP; + + case PHASE_WEIGHT_UPDATE: + return ACTION_WEIGHT_UPDATE; + + case PHASE_DONE: + return ACTION_STEP_DONE; + + default: + return ACTION_ERROR; + } +} + +// Called after successfully compiling a layer group's kernels +static void pipeline_group_compiled(PipelineScheduler *s) { + LayerGroup *lg = &s->plan.groups[s->current_group]; + budget_consume(&s->budget, lg->total_kernels); + s->group_compiled = true; +} + +// Called after successfully running 
forward for current group
+static void pipeline_forward_group_done(PipelineScheduler *s) {
+    s->current_group++;
+    s->group_compiled = false;
+}
+
+// Called after successfully running backward for current group
+static void pipeline_backward_group_done(PipelineScheduler *s) {
+    s->current_group--;
+    s->group_compiled = false;
+}
+
+// Called after weight update completes
+static void pipeline_weight_update_done(PipelineScheduler *s) {
+    s->accum_step++;
+    if (s->accum_step >= s->config.compile.accum_steps) {
+        s->accum_step = 0;
+        s->current_step++;
+    }
+    // Reset for next forward pass
+    s->phase = PHASE_FORWARD;
+    s->current_group = 0;
+    s->group_compiled = false;
+}
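// A minimal sketch (not part of the patch) of how a training loop consumes the
// scheduler: pull the next action, perform it, then report completion through
// the matching callback above. This condenses the dry-run loop in
// train_pipeline.m below; the compile/run bodies are stubbed, and the sketch is
// fenced with #if 0 so it never compiles into the header.
#if 0
static void pipeline_driver_sketch(void) {
    PipelineScheduler s = pipeline_scheduler_init(model_config_stories110m(), /*steps*/1, 3e-4f);
    for (;;) {
        PipelineAction a = pipeline_next_action(&s);
        if (a == ACTION_STEP_DONE || a == ACTION_ERROR) break;
        switch (a) {
        case ACTION_COMPILE_GROUP:      /* compile group kernels here */ pipeline_group_compiled(&s);      break;
        case ACTION_RUN_FORWARD_GROUP:  /* run forward on ANE here */    pipeline_forward_group_done(&s);  break;
        case ACTION_RUN_BACKWARD_GROUP: /* run backward on ANE here */   pipeline_backward_group_done(&s); break;
        case ACTION_WEIGHT_UPDATE:      /* Adam step here */             pipeline_weight_update_done(&s);  break;
        case ACTION_EXEC_RESTART:
            // live mode calls pipeline_exec_restart(); a dry run just resets the budget
            s.budget = budget_init(s.config.compile.compile_budget);
            s.needs_restart = false;
            break;
        default: break;
        }
    }
    pipeline_plan_free(&s.plan);
}
#endif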
+// ===== mmap-based cross-exec state =====
+//
+// Layout: [Header][Layer 0 weights][Layer 0 adam][Layer 0 grads]...[Global state]
+// All tensors stored as fp32. The mmap file persists across exec() restarts.
+
+#define MMAP_SENTINEL 0x414E4550 // "ANEP" — file format identifier
+#define MMAP_VERSION  1
+
+typedef struct {
+    int sentinel; // MMAP_SENTINEL for file identification
+    int version;
+    int n_layers;
+    int dim;
+    int hidden_dim;
+    int n_heads;
+    int vocab_size;
+    int seq_len;
+    // Scheduler state (for exec restart)
+    int phase;
+    int current_group;
+    int current_step;
+    int accum_step;
+    int total_steps;
+    int compile_count; // compiles used in current process
+    int adam_t;        // Adam timestep
+    float learning_rate;
+    float last_loss;
+    // Offsets into mmap (bytes from base)
+    size_t layer_weights_offset; // start of per-layer weight data
+    size_t layer_adam_offset;    // start of per-layer adam state
+    size_t layer_grads_offset;   // start of per-layer gradient accumulators
+    size_t layer_acts_offset;    // start of per-layer activation checkpoints
+    size_t global_offset;        // start of global state (rms_final, embed, etc.)
+    size_t total_size;           // total mmap size
+    int pad[4];                  // alignment
+} MmapHeader;
+
+typedef struct {
+    int fd;
+    void *base;
+    size_t size;
+    MmapHeader *header;
+    const char *path;
+} MmapState;
+
+// Compute mmap layout for a given config
+static size_t mmap_compute_size(const ModelConfig *cfg) {
+    const ModelDims *d = &cfg->dims;
+    size_t header = sizeof(MmapHeader);
+    // Round up to page boundary
+    header = (header + 4095) & ~(size_t)4095;
+
+    size_t per_layer_weights = layer_weight_bytes(d);
+    size_t per_layer_adam    = layer_adam_bytes(d);
+    size_t per_layer_grads   = layer_gradient_bytes(d);
+    size_t per_layer_acts    = layer_activation_bytes(d);
+
+    size_t all_layers = (size_t)d->n_layers * (per_layer_weights + per_layer_adam + per_layer_grads + per_layer_acts);
+
+    // Global: rms_final + embed + their adam states + embed gradients
+    size_t global = (size_t)d->dim * 4                     // rms_final
+                  + (size_t)d->vocab_size * d->dim * 4     // embed
+                  + (size_t)d->dim * 2 * 4                 // rms_final adam (m+v)
+                  + (size_t)d->vocab_size * d->dim * 2 * 4 // embed adam
+                  + (size_t)d->dim * 4                     // rms_final grad
+                  + (size_t)d->vocab_size * d->dim * 4;    // embed grad
+
+    return header + all_layers + global;
+}
+
+// Create a new mmap state file
+static MmapState *mmap_state_create(const char *path, const ModelConfig *cfg) {
+    size_t total = mmap_compute_size(cfg);
+    int fd = open(path, O_RDWR | O_CREAT | O_TRUNC, 0644);
+    if (fd < 0) { perror("mmap_state_create: open"); return NULL; }
+    if (ftruncate(fd, total) < 0) { perror("mmap_state_create: ftruncate"); close(fd); return NULL; }
+
+    void *base = mmap(NULL, total, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0);
+    if (base == MAP_FAILED) { perror("mmap_state_create: mmap"); close(fd); return NULL; }
+
+    MmapState *ms = (MmapState *)calloc(1, sizeof(MmapState));
+    ms->fd = fd;
+    ms->base = base;
+    ms->size = total;
+    ms->path = path;
+    ms->header = (MmapHeader *)base;
+
+    // Initialize header
+    MmapHeader *h = ms->header;
+    h->sentinel = MMAP_SENTINEL;
+    h->version = MMAP_VERSION;
+    h->n_layers = cfg->dims.n_layers;
+    h->dim = cfg->dims.dim;
+    h->hidden_dim = cfg->dims.hidden_dim;
+    h->n_heads = cfg->dims.n_heads;
+    h->vocab_size = cfg->dims.vocab_size;
+    h->seq_len = cfg->dims.seq_len;
+
+    // Compute offsets
+    size_t header_end = (sizeof(MmapHeader) + 4095) & ~(size_t)4095;
+    const ModelDims *d = &cfg->dims;
+    size_t pw = layer_weight_bytes(d);
+    size_t pa = layer_adam_bytes(d);
+    size_t pg = layer_gradient_bytes(d);
+    size_t pact = layer_activation_bytes(d);
+
+    h->layer_weights_offset = header_end;
+    h->layer_adam_offset  = h->layer_weights_offset + (size_t)d->n_layers * pw;
+    h->layer_grads_offset = h->layer_adam_offset + (size_t)d->n_layers * pa;
+    h->layer_acts_offset  = h->layer_grads_offset + (size_t)d->n_layers * pg;
+    h->global_offset      = h->layer_acts_offset + (size_t)d->n_layers * pact;
+    h->total_size = total;
+
+    return ms;
+}
+
+// Reopen existing mmap state (after exec() restart)
+static MmapState *mmap_state_open(const char *path) {
+    int fd = open(path, O_RDWR);
+    if (fd < 0) { perror("mmap_state_open: open"); return NULL; }
+    struct stat st;
+    if (fstat(fd, &st) < 0) { perror("mmap_state_open: fstat"); close(fd); return NULL; }
+
+    void *base = mmap(NULL, st.st_size, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0);
+    if (base == MAP_FAILED) { perror("mmap_state_open: mmap"); close(fd); return NULL; }
+
+    MmapHeader *h = (MmapHeader *)base;
+    if (h->sentinel != MMAP_SENTINEL || h->version != MMAP_VERSION) {
+        fprintf(stderr, "mmap_state_open: invalid header\n");
munmap(base, st.st_size); + close(fd); + return NULL; + } + + MmapState *ms = (MmapState *)calloc(1, sizeof(MmapState)); + ms->fd = fd; + ms->base = base; + ms->size = st.st_size; + ms->path = path; + ms->header = h; + return ms; +} + +// Close and unmap (does NOT delete the file) +static void mmap_state_close(MmapState *ms) { + if (!ms) return; + msync(ms->base, ms->size, MS_SYNC); + munmap(ms->base, ms->size); + close(ms->fd); + free(ms); +} + +// Delete the mmap file (call after training completes) +static void mmap_state_destroy(MmapState *ms) { + if (!ms) return; + const char *p = ms->path; + mmap_state_close(ms); + unlink(p); +} + +// ===== Typed accessors into mmap regions ===== + +// Get pointer to layer L's weights in mmap +static float *mmap_layer_weights(MmapState *ms, int layer) { + return (float *)((char *)ms->base + ms->header->layer_weights_offset + + (size_t)layer * layer_weight_bytes(&(ModelDims){ + .dim = ms->header->dim, + .hidden_dim = ms->header->hidden_dim, + .n_heads = ms->header->n_heads, + .vocab_size = ms->header->vocab_size, + .seq_len = ms->header->seq_len + })); +} + +// Get pointer to layer L's adam state in mmap +static float *mmap_layer_adam(MmapState *ms, int layer) { + ModelDims d = { + .dim = ms->header->dim, .hidden_dim = ms->header->hidden_dim, + .n_heads = ms->header->n_heads, .vocab_size = ms->header->vocab_size, + .seq_len = ms->header->seq_len + }; + return (float *)((char *)ms->base + ms->header->layer_adam_offset + + (size_t)layer * layer_adam_bytes(&d)); +} + +// Get pointer to layer L's gradient accumulators in mmap +static float *mmap_layer_grads(MmapState *ms, int layer) { + ModelDims d = { + .dim = ms->header->dim, .hidden_dim = ms->header->hidden_dim, + .n_heads = ms->header->n_heads, .vocab_size = ms->header->vocab_size, + .seq_len = ms->header->seq_len + }; + return (float *)((char *)ms->base + ms->header->layer_grads_offset + + (size_t)layer * layer_gradient_bytes(&d)); +} + +// Get pointer to layer L's activation checkpoint in mmap +static float *mmap_layer_acts(MmapState *ms, int layer) { + ModelDims d = { + .dim = ms->header->dim, .hidden_dim = ms->header->hidden_dim, + .n_heads = ms->header->n_heads, .vocab_size = ms->header->vocab_size, + .seq_len = ms->header->seq_len + }; + return (float *)((char *)ms->base + ms->header->layer_acts_offset + + (size_t)layer * layer_activation_bytes(&d)); +} + +// Get pointer to global state region (rms_final, embed, etc.) 
+static float *mmap_global(MmapState *ms) {
+    return (float *)((char *)ms->base + ms->header->global_offset);
+}
+
+// ===== Save/restore scheduler state to/from mmap header =====
+
+static void pipeline_save_to_mmap(const PipelineScheduler *s, MmapState *ms) {
+    MmapHeader *h = ms->header;
+    h->phase = (int)s->phase;
+    h->current_group = s->current_group;
+    h->current_step = s->current_step;
+    h->accum_step = s->accum_step;
+    h->total_steps = s->total_steps;
+    h->learning_rate = s->learning_rate;
+    h->last_loss = s->last_loss;
+    msync(ms->base, sizeof(MmapHeader), MS_SYNC);
+}
+
+static void pipeline_restore_from_mmap(PipelineScheduler *s, const MmapState *ms) {
+    const MmapHeader *h = ms->header;
+    s->phase = (PipelinePhase)h->phase;
+    s->current_group = h->current_group;
+    s->current_step = h->current_step;
+    s->accum_step = h->accum_step;
+    s->total_steps = h->total_steps;
+    s->learning_rate = h->learning_rate;
+    s->last_loss = h->last_loss;
+    // Reset compile budget (new process after exec)
+    s->budget = budget_init(s->config.compile.compile_budget);
+    s->group_compiled = false;
+    s->needs_restart = false;
+}
+
+// ===== exec() restart with mmap persistence =====
+
+// Call this when ACTION_EXEC_RESTART is returned.
+// Saves scheduler state to mmap, syncs, and exec()s.
+// Does not return on success.
+static void pipeline_exec_restart(PipelineScheduler *s, MmapState *ms, char *argv[]) {
+    pipeline_save_to_mmap(s, ms);
+    printf("[pipeline] exec() restart: step=%d phase=%d group=%d compiles=%d\n",
+           s->current_step, s->phase, s->current_group, s->budget.used);
+    fflush(stdout);
+
+    // Sync all mmap data before exec
+    msync(ms->base, ms->size, MS_SYNC);
+
+    // exec with --pipeline-resume flag
+    execl(argv[0], argv[0], "--pipeline-resume", ms->path, NULL);
+    perror("pipeline_exec_restart: execl");
+}
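// A minimal sketch (not part of the patch) of the intended restart round trip,
// condensed from main() in train_pipeline.m below. First launch creates the
// mmap file; when the budget is exhausted the process re-exec()s itself with
// --pipeline-resume, and the fresh process reattaches to the same tensor
// state. Assumes cfg/total_steps/lr and argc/argv are in scope; it uses
// pipeline_check_resume(), defined just below this note.
//
//     PipelineScheduler s = pipeline_scheduler_init(cfg, total_steps, lr);
//     MmapState *ms = NULL;
//     if (!pipeline_check_resume(argc, argv, &s, &ms))
//         ms = mmap_state_create("/tmp/ane_pipeline.mmap", &cfg);  // fresh start
//     for (;;) {
//         PipelineAction a = pipeline_next_action(&s);
//         if (a == ACTION_EXEC_RESTART)
//             pipeline_exec_restart(&s, ms, argv);  // saves + exec()s; no return
//         /* ...handle the other actions as in the driver sketch above... */
//         if (a == ACTION_STEP_DONE) break;
//     }
//     mmap_state_destroy(ms);  // training done: unmap and unlink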
+// Resume from exec() restart. Returns true if this is a resume.
+static bool pipeline_check_resume(int argc, char *argv[], PipelineScheduler *s, MmapState **ms_out) {
+    for (int i = 1; i < argc; i++) {
+        if (strcmp(argv[i], "--pipeline-resume") == 0 && i+1 < argc) {
+            const char *mmap_path = argv[i+1];
+            MmapState *ms = mmap_state_open(mmap_path);
+            if (!ms) {
+                fprintf(stderr, "[pipeline] Failed to reopen mmap at %s\n", mmap_path);
+                return false;
+            }
+            pipeline_restore_from_mmap(s, ms);
+            *ms_out = ms;
+            printf("[pipeline] Resumed: step=%d phase=%d group=%d\n",
+                   s->current_step, s->phase, s->current_group);
+            return true;
+        }
+    }
+    return false;
+}
+
+// ===== Pipeline pretty-print helpers =====
+
+static const char *phase_name(PipelinePhase p) {
+    switch (p) {
+    case PHASE_INIT:          return "INIT";
+    case PHASE_FORWARD:       return "FORWARD";
+    case PHASE_BACKWARD:      return "BACKWARD";
+    case PHASE_WEIGHT_UPDATE: return "WEIGHT_UPDATE";
+    case PHASE_DONE:          return "DONE";
+    default:                  return "UNKNOWN";
+    }
+}
+
+static const char *action_name(PipelineAction a) {
+    switch (a) {
+    case ACTION_COMPILE_GROUP:      return "COMPILE_GROUP";
+    case ACTION_RUN_FORWARD_GROUP:  return "RUN_FORWARD_GROUP";
+    case ACTION_RUN_BACKWARD_GROUP: return "RUN_BACKWARD_GROUP";
+    case ACTION_EXEC_RESTART:       return "EXEC_RESTART";
+    case ACTION_WEIGHT_UPDATE:      return "WEIGHT_UPDATE";
+    case ACTION_STEP_DONE:          return "STEP_DONE";
+    case ACTION_ERROR:              return "ERROR";
+    default:                        return "UNKNOWN";
+    }
+}
+
+static void pipeline_print_status(const PipelineScheduler *s) {
+    printf("[pipeline] step=%d/%d accum=%d/%d phase=%s group=%d/%d budget=%d/%d\n",
+           s->current_step, s->total_steps,
+           s->accum_step, s->config.compile.accum_steps,
+           phase_name(s->phase), s->current_group, s->plan.n_groups,
+           s->budget.used, s->budget.budget);
+}
diff --git a/training/train_pipeline.m b/training/train_pipeline.m
new file mode 100644
index 0000000..ef055b8
--- /dev/null
+++ b/training/train_pipeline.m
@@ -0,0 +1,258 @@
+// train_pipeline.m — Pipeline-scheduled multi-group ANE training
+//
+// Entry point that uses the pipeline scaffolding to train models
+// beyond the single-compile-batch limit.
+//
+// Architecture:
+//   ModelConfig       → what the model looks like
+//   PipelinePlan      → which layers go in which compile groups
+//   PipelineScheduler → state machine driving forward/backward/restart
+//   MmapState         → cross-exec() shared memory for all tensor state
+//   CheckpointManager → activation save/recompute policy
+//
+// Usage:
+//   ./train_pipeline --model stories110m --steps 100 --lr 3e-4
+//   ./train_pipeline --model llama1b --steps 50 --lr 1e-4 --checkpoint sqrt
+//   ./train_pipeline --pipeline-resume /tmp/ane_pipeline.mmap   (auto after exec restart)
+//
+// Build:
+//   make train_pipeline
+//
+// Currently runs in planning/dry-run mode — prints the full execution
+// plan and simulates the scheduler state machine without compiling
+// actual MIL programs. Enable ANE_LIVE for real kernels.
+
+#import <Foundation/Foundation.h>
+#import <CoreML/CoreML.h>
+#import <IOSurface/IOSurface.h>
+#import <Accelerate/Accelerate.h>
+#import <dlfcn.h>
+
+#include "model_config.h"
+#include "pipeline.h"
+#include "gradient_checkpoint.h"
+
+#define MMAP_PATH "/tmp/ane_pipeline.mmap"
+
+// ===== Forward declarations for ANE kernel operations =====
+// These would call into stories_io.h / stories_mil.h for real execution.
+// Stubbed here for planning mode.
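// Editorial worked example (not part of the patch): the budget arithmetic the
// dry run prints for the largest preset, using the formulas in model_config.h:
//
//   usable    = floor(119 * (1 - 0.10)) = 107 compiles
//   per layer = 5 weight + 1 static     = 6 kernels
//   max/group = floor(107 / 6)          = 17 layers
//   LLaMA-7B  = ceil(32 / 17)           = 2 groups: [0..17) and [17..32)
//   restarts  = 2 * n_groups            = 4 exec() restarts per step (worst case)
//
// The unit tests in the follow-up commit pin these same numbers.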
+ +#ifdef ANE_LIVE +#include "stories_io.h" +#include "stories_mil.h" +#include "stories_cpu_ops.h" +// Real ANE kernel compilation and execution would go here +#endif + +// ===== Dry-run simulation ===== + +static void simulate_compile_group(const PipelineScheduler *s, const LayerGroup *lg) { + printf(" [compile] Layers [%d..%d): %d weight-bearing + %d static kernels\n", + lg->start_layer, lg->end_layer, lg->weight_kernels, lg->static_kernels); + printf(" Budget: %d/%d used → %d/%d after\n", + s->budget.used, s->budget.budget, + s->budget.used + lg->total_kernels, s->budget.budget); +} + +static void simulate_forward_group(const PipelineScheduler *s, const LayerGroup *lg, + const CheckpointManager *cm) { + printf(" [forward] Layers [%d..%d)\n", lg->start_layer, lg->end_layer); + for (int L = lg->start_layer; L < lg->end_layer; L++) { + bool save = checkpoint_should_save(cm, L); + printf(" L%02d: fwdAttn → residual → fwdFFN → residual %s\n", + L, save ? "[SAVE acts]" : "[skip acts]"); + } +} + +static void simulate_backward_group(const PipelineScheduler *s, const LayerGroup *lg, + const CheckpointManager *cm) { + printf(" [backward] Layers [%d..%d) (reverse)\n", lg->start_layer, lg->end_layer); + for (int L = lg->end_layer - 1; L >= lg->start_layer; L--) { + bool recompute = checkpoint_needs_recompute(cm, L); + if (recompute) { + int from = checkpoint_nearest_saved_before(cm, L); + printf(" L%02d: [RECOMPUTE from L%02d] → ffnBwd → rmsnorm2_bwd → sdpaBwd1 → sdpaBwd2 → qkvBwd → rmsnorm1_bwd\n", + L, from); + } else { + printf(" L%02d: ffnBwd → rmsnorm2_bwd → sdpaBwd1 → sdpaBwd2 → qkvBwd → rmsnorm1_bwd\n", L); + } + } +} + +// ===== Main ===== + +int main(int argc, char *argv[]) { + @autoreleasepool { + // Parse model config from command line + ModelConfig cfg = model_config_from_args(argc, argv); + + // Parse additional training args + int total_steps = 100; + float lr = 3e-4f; + bool dry_run = true; + CheckpointPolicy ckpt_policy = CKPT_ALL; + + for (int i = 1; i < argc; i++) { + if (strcmp(argv[i], "--steps") == 0 && i+1 < argc) total_steps = atoi(argv[++i]); + else if (strcmp(argv[i], "--lr") == 0 && i+1 < argc) lr = atof(argv[++i]); + else if (strcmp(argv[i], "--live") == 0) dry_run = false; + else if (strcmp(argv[i], "--checkpoint") == 0 && i+1 < argc) { + const char *p = argv[++i]; + if (strcmp(p, "all") == 0) ckpt_policy = CKPT_ALL; + else if (strcmp(p, "boundary") == 0) ckpt_policy = CKPT_BOUNDARY; + else if (strcmp(p, "sqrt") == 0) ckpt_policy = CKPT_SQRT; + else if (strcmp(p, "none") == 0) ckpt_policy = CKPT_NONE; + else fprintf(stderr, "Unknown checkpoint policy: %s\n", p); + } + } + + // Check for exec() resume + PipelineScheduler sched = pipeline_scheduler_init(cfg, total_steps, lr); + MmapState *ms = NULL; + + if (pipeline_check_resume(argc, argv, &sched, &ms)) { + printf("[pipeline] Resumed from exec() restart\n"); + } else { + // Fresh start + printf("=== ANE Pipeline Training ===\n"); + if (dry_run) printf(" ** DRY RUN MODE — no ANE kernels compiled **\n\n"); + + // Print model config and pipeline plan + PipelinePlan plan = compute_pipeline_plan(&cfg); + pipeline_plan_print(&cfg, &plan); + printf("\n"); + + // Print checkpoint policy + CheckpointManager cm = checkpoint_init(ckpt_policy, &cfg, &plan); + checkpoint_print(&cm, &cfg.dims); + printf("\n"); + + // Print FLOP estimates + double total_flops = flops_per_step(&cfg); + double ane_flops = ane_flops_per_step(&cfg); + printf("=== Compute Estimate ===\n"); + printf(" FLOPs/step: %.0fM total, %.0fM ANE (%.0f%% 
on-engine)\n",
+                   total_flops/1e6, ane_flops/1e6, 100.0*ane_flops/total_flops);
+            printf("  At 15.8 TFLOPS ANE: %.1f ms/step theoretical minimum\n",
+                   ane_flops / 15.8e9);
+            printf("  Training: %d steps × %d accum = %d optimizer updates\n",
+                   total_steps, cfg.compile.accum_steps, total_steps * cfg.compile.accum_steps);
+            printf("\n");
+
+            // Print mmap state size
+            size_t mmap_sz = mmap_compute_size(&cfg);
+            printf("=== State Management ===\n");
+            printf("  mmap file: %s (%.1fMB)\n", MMAP_PATH, mmap_sz/1e6);
+            printf("  Per-layer: weights=%.1fMB adam=%.1fMB grads=%.1fMB acts=%.1fMB\n",
+                   layer_weight_bytes(&cfg.dims)/1e6, layer_adam_bytes(&cfg.dims)/1e6,
+                   layer_gradient_bytes(&cfg.dims)/1e6, layer_activation_bytes(&cfg.dims)/1e6);
+            printf("\n");
+
+            // Create mmap state
+            ms = mmap_state_create(MMAP_PATH, &cfg);
+            if (!ms) {
+                fprintf(stderr, "Failed to create mmap state\n");
+                checkpoint_free(&cm);
+                pipeline_plan_free(&plan);
+                return 1;
+            }
+
+            if (dry_run) {
+                // ===== Simulate the full scheduler state machine =====
+                printf("=== Execution Trace (1 training step) ===\n");
+                int max_actions = 200; // safety limit
+                int action_count = 0;
+
+                while (action_count < max_actions) {
+                    PipelineAction action = pipeline_next_action(&sched);
+                    action_count++;
+
+                    printf("\n  [%d] %s (phase=%s group=%d)\n",
+                           action_count, action_name(action),
+                           phase_name(sched.phase), sched.current_group);
+
+                    switch (action) {
+                    case ACTION_COMPILE_GROUP: {
+                        LayerGroup *lg = &sched.plan.groups[sched.current_group];
+                        simulate_compile_group(&sched, lg);
+                        pipeline_group_compiled(&sched);
+                        break;
+                    }
+                    case ACTION_RUN_FORWARD_GROUP: {
+                        LayerGroup *lg = &sched.plan.groups[sched.current_group];
+                        simulate_forward_group(&sched, lg, &cm);
+                        pipeline_forward_group_done(&sched);
+                        break;
+                    }
+                    case ACTION_RUN_BACKWARD_GROUP: {
+                        LayerGroup *lg = &sched.plan.groups[sched.current_group];
+                        simulate_backward_group(&sched, lg, &cm);
+                        pipeline_backward_group_done(&sched);
+                        break;
+                    }
+                    case ACTION_EXEC_RESTART:
+                        printf("    [exec] Would restart process to reset compile budget\n");
+                        printf("           Saving scheduler state to mmap, calling exec()\n");
+                        // In dry-run, just reset the budget and continue
+                        sched.budget = budget_init(cfg.compile.compile_budget);
+                        sched.needs_restart = false;
+                        break;
+
+                    case ACTION_WEIGHT_UPDATE:
+                        printf("    [adam] Optimizer step on all %d layers + global params\n",
+                               cfg.dims.n_layers);
+                        printf("           LR=%.1e adam_t=%d\n", sched.learning_rate, sched.current_step+1);
+                        pipeline_weight_update_done(&sched);
+                        break;
+
+                    case ACTION_STEP_DONE:
+                        printf("\n=== Training step complete ===\n");
+                        goto done_trace;
+
+                    case ACTION_ERROR:
+                        printf("    ERROR in scheduler\n");
+                        goto done_trace;
+                    }
+                }
+                done_trace:
+
+                printf("\nTotal actions simulated: %d\n", action_count);
+                printf("Compile budget consumed: %d/%d\n", sched.budget.used, sched.budget.budget);
+
+                // Summary for multi-group models
+                if (plan.n_groups > 1) {
+                    printf("\n=== Multi-Group Pipeline Summary ===\n");
+                    printf("  This model requires %d layer groups per training step\n", plan.n_groups);
+                    printf("  Forward pass: %d compile batches (left to right)\n", plan.n_groups);
+                    printf("  Backward pass: %d compile batches (right to left)\n", plan.n_groups);
+                    printf("  Each compile batch may need exec() restart\n");
+                    printf("  All tensor state survives restarts via mmap (%s)\n", MMAP_PATH);
+                    printf("  Checkpoint policy '%s' saves %d/%d layer activations (%.0f%% memory reduction)\n",
+                           checkpoint_policy_name(ckpt_policy), cm.n_checkpointed, cm.n_layers,
100.0 * checkpoint_memory_saved(&cm, &cfg.dims) / + ((double)cm.n_layers * layer_activation_bytes(&cfg.dims))); + } + + checkpoint_free(&cm); + } else { + // ===== Live training mode ===== +#ifdef ANE_LIVE + printf("Live training not yet implemented — use train_large.m for Stories110M\n"); + printf("This entry point will be wired up once the scaffolding is validated.\n"); +#else + printf("Compiled without ANE_LIVE — use --live with ANE_LIVE defined.\n"); + printf("Build with: xcrun clang -DANE_LIVE -O2 ... train_pipeline.m\n"); +#endif + checkpoint_free(&cm); + } + + pipeline_plan_free(&plan); + } + + // Cleanup + if (ms) mmap_state_destroy(ms); + } + return 0; +} + From 32b5c72a8526c3c79b8738f7ed97373a3f66ac40 Mon Sep 17 00:00:00 2001 From: "codegen-sh[bot]" <131295404+codegen-sh[bot]@users.noreply.github.com> Date: Mon, 2 Mar 2026 23:44:36 +0000 Subject: [PATCH 2/3] Address review feedback: configurable headroom, mmap hardening, unit tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - model_config.h: Added headroom_pct field to CompileConfig, used in max_layers_per_compile() with validation (falls back to 10% for invalid values). All presets include default. --headroom CLI flag added. - pipeline.h: Tightened mmap error handling — calloc checks, size validation in mmap_state_open (file size vs header, truncation detection), sentinel/version in error message, msync/munmap return checks in close. - test_pipeline_unit.c: 23 unit tests for model_config, pipeline planning, gradient checkpoint, and FLOP estimation. Pure C, no ANE dependency. All passing. Co-authored-by: dermitchell1993 --- training/Makefile | 6 +- training/model_config.h | 16 +- training/pipeline.h | 28 ++- training/test_pipeline_unit.c | 397 ++++++++++++++++++++++++++++++++++ 4 files changed, 434 insertions(+), 13 deletions(-) create mode 100644 training/test_pipeline_unit.c diff --git a/training/Makefile b/training/Makefile index b0ff4bc..adba2c6 100644 --- a/training/Makefile +++ b/training/Makefile @@ -37,8 +37,10 @@ probes: $(PROBES) tokenize: python3 tokenize.py +test_pipeline_unit: test_pipeline_unit.c $(HEADERS_PIPELINE) + cc -O2 -Wall -o $@ $< -lm + clean: - rm -f train train_large train_pipeline $(PROBES) + rm -f train train_large train_pipeline test_pipeline_unit $(PROBES) .PHONY: clean tokenize probes - diff --git a/training/model_config.h b/training/model_config.h index 8e99090..ddd9229 100644 --- a/training/model_config.h +++ b/training/model_config.h @@ -27,6 +27,7 @@ typedef struct { int kernels_per_layer; // weight-bearing kernels per layer (currently 5) int static_per_layer; // weight-free kernels per layer (sdpaBwd2 = 1) int accum_steps; // gradient accumulation steps per compile batch + float headroom_pct; // safety margin as fraction of budget (0.0-1.0, default 0.10) } CompileConfig; typedef struct { @@ -118,8 +119,9 @@ static inline size_t total_model_bytes(const ModelConfig *cfg) { // Compute how many layers can fit in one compile batch static int max_layers_per_compile(const CompileConfig *cc) { - // Reserve some headroom (90% of budget) for safety - int usable = (int)(cc->compile_budget * 0.9); + float headroom = (cc->headroom_pct > 0.0f && cc->headroom_pct < 1.0f) + ? 
cc->headroom_pct : 0.10f; + int usable = (int)(cc->compile_budget * (1.0f - headroom)); int per_layer = cc->kernels_per_layer + cc->static_per_layer; if (per_layer <= 0) return 1; return usable / per_layer; @@ -232,7 +234,7 @@ static ModelConfig model_config_stories110m(void) { }; cfg.compile = (CompileConfig){ .compile_budget = 119, .kernels_per_layer = 5, - .static_per_layer = 1, .accum_steps = 10 + .static_per_layer = 1, .accum_steps = 10, .headroom_pct = 0.10f }; model_dims_init(&cfg.dims); return cfg; @@ -247,7 +249,7 @@ static ModelConfig model_config_stories42m(void) { }; cfg.compile = (CompileConfig){ .compile_budget = 119, .kernels_per_layer = 5, - .static_per_layer = 1, .accum_steps = 10 + .static_per_layer = 1, .accum_steps = 10, .headroom_pct = 0.10f }; model_dims_init(&cfg.dims); return cfg; @@ -262,7 +264,7 @@ static ModelConfig model_config_llama_1b(void) { }; cfg.compile = (CompileConfig){ .compile_budget = 119, .kernels_per_layer = 5, - .static_per_layer = 1, .accum_steps = 4 + .static_per_layer = 1, .accum_steps = 4, .headroom_pct = 0.10f }; model_dims_init(&cfg.dims); return cfg; @@ -277,7 +279,7 @@ static ModelConfig model_config_llama_7b(void) { }; cfg.compile = (CompileConfig){ .compile_budget = 119, .kernels_per_layer = 5, - .static_per_layer = 1, .accum_steps = 2 + .static_per_layer = 1, .accum_steps = 2, .headroom_pct = 0.10f }; model_dims_init(&cfg.dims); return cfg; @@ -303,8 +305,8 @@ static ModelConfig model_config_from_args(int argc, char *argv[]) { else if (strcmp(argv[i], "--vocab") == 0 && i+1 < argc) cfg.dims.vocab_size = atoi(argv[++i]); else if (strcmp(argv[i], "--budget") == 0 && i+1 < argc) cfg.compile.compile_budget = atoi(argv[++i]); else if (strcmp(argv[i], "--accum") == 0 && i+1 < argc) cfg.compile.accum_steps = atoi(argv[++i]); + else if (strcmp(argv[i], "--headroom") == 0 && i+1 < argc) cfg.compile.headroom_pct = atof(argv[++i]); } model_dims_init(&cfg.dims); return cfg; } - diff --git a/training/pipeline.h b/training/pipeline.h index c04eb0c..798f6d6 100644 --- a/training/pipeline.h +++ b/training/pipeline.h @@ -263,6 +263,7 @@ static MmapState *mmap_state_create(const char *path, const ModelConfig *cfg) { if (base == MAP_FAILED) { perror("mmap_state_create: mmap"); close(fd); return NULL; } MmapState *ms = (MmapState *)calloc(1, sizeof(MmapState)); + if (!ms) { perror("mmap_state_create: calloc"); munmap(base, total); close(fd); return NULL; } ms->fd = fd; ms->base = base; ms->size = total; @@ -308,15 +309,32 @@ static MmapState *mmap_state_open(const char *path) { void *base = mmap(NULL, st.st_size, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); if (base == MAP_FAILED) { perror("mmap_state_open: mmap"); close(fd); return NULL; } + if ((size_t)st.st_size < sizeof(MmapHeader)) { + fprintf(stderr, "mmap_state_open: file too small (%lld bytes)\n", (long long)st.st_size); + munmap(base, st.st_size); + close(fd); + return NULL; + } + MmapHeader *h = (MmapHeader *)base; if (h->sentinel != MMAP_SENTINEL || h->version != MMAP_VERSION) { - fprintf(stderr, "mmap_state_open: invalid header\n"); + fprintf(stderr, "mmap_state_open: invalid header (sentinel=0x%08x version=%d)\n", + h->sentinel, h->version); + munmap(base, st.st_size); + close(fd); + return NULL; + } + + if (h->total_size != 0 && (size_t)st.st_size < h->total_size) { + fprintf(stderr, "mmap_state_open: file truncated (expected %zu, got %lld)\n", + h->total_size, (long long)st.st_size); munmap(base, st.st_size); close(fd); return NULL; } MmapState *ms = (MmapState *)calloc(1, sizeof(MmapState)); 
+    if (!ms) { perror("mmap_state_open: calloc"); munmap(base, st.st_size); close(fd); return NULL; }
     ms->fd = fd;
     ms->base = base;
     ms->size = st.st_size;
@@ -328,9 +346,11 @@ static MmapState *mmap_state_open(const char *path) {
 // Close and unmap (does NOT delete the file)
 static void mmap_state_close(MmapState *ms) {
     if (!ms) return;
-    msync(ms->base, ms->size, MS_SYNC);
-    munmap(ms->base, ms->size);
-    close(ms->fd);
+    if (ms->base && ms->base != MAP_FAILED) {
+        if (msync(ms->base, ms->size, MS_SYNC) < 0) perror("mmap_state_close: msync");
+        if (munmap(ms->base, ms->size) < 0) perror("mmap_state_close: munmap");
+    }
+    if (ms->fd >= 0) close(ms->fd);
     free(ms);
 }
diff --git a/training/test_pipeline_unit.c b/training/test_pipeline_unit.c
new file mode 100644
index 0000000..b641911
--- /dev/null
+++ b/training/test_pipeline_unit.c
@@ -0,0 +1,397 @@
+// test_pipeline_unit.c — Unit tests for pipeline scheduler + checkpoint manager
+// Pure C, no ANE dependency. Validates state machine transitions and checkpoint logic.
+// Build: cc -O2 -o test_pipeline_unit test_pipeline_unit.c -lm
+// Run:   ./test_pipeline_unit
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <stdbool.h>
+#include <stddef.h>
+#include <math.h>
+
+// Stub out mmap/exec dependencies — we only test the pure logic
+#define _PIPELINE_SKIP_MMAP 1
+
+#include "model_config.h"
+#include "gradient_checkpoint.h"
+
+// ===== Test helpers =====
+
+static int tests_run = 0;
+static int tests_passed = 0;
+
+#define TEST(name) do { \
+    tests_run++; \
+    printf("  %-50s", name); \
+} while(0)
+
+#define PASS() do { tests_passed++; printf("PASS\n"); } while(0)
+#define FAIL(msg) do { printf("FAIL: %s\n", msg); } while(0)
+
+#define ASSERT_EQ(a, b, msg) do { \
+    if ((a) != (b)) { FAIL(msg); printf("    got %d, expected %d\n", (int)(a), (int)(b)); return; } \
+} while(0)
+
+#define ASSERT_TRUE(cond, msg) do { \
+    if (!(cond)) { FAIL(msg); return; } \
+} while(0)
+
+// ===== model_config.h tests =====
+
+static void test_dims_init(void) {
+    TEST("model_dims_init computes derived fields");
+    ModelDims d = {.dim = 768, .n_heads = 12, .n_kv_heads = 12, .seq_len = 256};
+    model_dims_init(&d);
+    ASSERT_EQ(d.head_dim, 64, "head_dim = dim / n_heads");
+    ASSERT_EQ(d.kv_dim, 768, "kv_dim = head_dim * n_kv_heads");
+    ASSERT_EQ(d.score_ch, 12 * 256, "score_ch = n_heads * seq_len");
+    PASS();
+}
+
+static void test_stories110m_preset(void) {
+    TEST("Stories110M preset");
+    ModelConfig cfg = model_config_stories110m();
+    ASSERT_EQ(cfg.dims.dim, 768, "dim");
+    ASSERT_EQ(cfg.dims.n_layers, 12, "n_layers");
+    ASSERT_EQ(cfg.dims.n_heads, 12, "n_heads");
+    ASSERT_EQ(cfg.compile.compile_budget, 119, "compile_budget");
+    ASSERT_TRUE(cfg.compile.headroom_pct > 0.0f, "headroom > 0");
+    PASS();
+}
+
+static void test_llama7b_preset(void) {
+    TEST("LLaMA-7B preset");
+    ModelConfig cfg = model_config_llama_7b();
+    ASSERT_EQ(cfg.dims.dim, 4096, "dim");
+    ASSERT_EQ(cfg.dims.n_layers, 32, "n_layers");
+    ASSERT_EQ(cfg.dims.hidden_dim, 11008, "hidden_dim");
+    PASS();
+}
+
+static void test_layer_memory_nonzero(void) {
+    TEST("Per-layer memory sizes are nonzero");
+    ModelConfig cfg = model_config_stories110m();
+    ASSERT_TRUE(layer_weight_bytes(&cfg.dims) > 0, "weight bytes");
+    ASSERT_TRUE(layer_adam_bytes(&cfg.dims) > 0, "adam bytes");
+    ASSERT_TRUE(layer_activation_bytes(&cfg.dims) > 0, "activation bytes");
+    ASSERT_TRUE(layer_gradient_bytes(&cfg.dims) > 0, "gradient bytes");
+    ASSERT_TRUE(total_model_bytes(&cfg) > 0, "total model bytes");
+    PASS();
+}
+
+static void test_adam_is_2x_weights(void) {
+    TEST("Adam state =
2x weight size"); + ModelConfig cfg = model_config_stories110m(); + ASSERT_EQ(layer_adam_bytes(&cfg.dims), 2 * layer_weight_bytes(&cfg.dims), "adam = 2 * weights"); + PASS(); +} + +// ===== Pipeline planning tests ===== + +static void test_max_layers_per_compile(void) { + TEST("max_layers_per_compile respects budget"); + CompileConfig cc = {.compile_budget = 119, .kernels_per_layer = 5, + .static_per_layer = 1, .headroom_pct = 0.10f}; + int max = max_layers_per_compile(&cc); + // usable = floor(119 * 0.9) = 107, per_layer = 6, max = 107/6 = 17 + ASSERT_EQ(max, 17, "max layers = 17 for budget=119, 6 kernels/layer, 10% headroom"); + PASS(); +} + +static void test_configurable_headroom(void) { + TEST("Configurable headroom changes max layers"); + CompileConfig cc5 = {.compile_budget = 119, .kernels_per_layer = 5, + .static_per_layer = 1, .headroom_pct = 0.05f}; + CompileConfig cc20 = {.compile_budget = 119, .kernels_per_layer = 5, + .static_per_layer = 1, .headroom_pct = 0.20f}; + int max5 = max_layers_per_compile(&cc5); // floor(119*0.95/6) = 18 + int max20 = max_layers_per_compile(&cc20); // floor(119*0.80/6) = 15 + ASSERT_TRUE(max5 > max20, "5% headroom fits more layers than 20%"); + ASSERT_EQ(max5, 18, "5% headroom: 18 layers"); + ASSERT_EQ(max20, 15, "20% headroom: 15 layers"); + PASS(); +} + +static void test_invalid_headroom_defaults(void) { + TEST("Invalid headroom falls back to 10%"); + CompileConfig cc_neg = {.compile_budget = 119, .kernels_per_layer = 5, + .static_per_layer = 1, .headroom_pct = -0.5f}; + CompileConfig cc_over = {.compile_budget = 119, .kernels_per_layer = 5, + .static_per_layer = 1, .headroom_pct = 1.5f}; + CompileConfig cc_def = {.compile_budget = 119, .kernels_per_layer = 5, + .static_per_layer = 1, .headroom_pct = 0.10f}; + ASSERT_EQ(max_layers_per_compile(&cc_neg), max_layers_per_compile(&cc_def), + "negative headroom -> default"); + ASSERT_EQ(max_layers_per_compile(&cc_over), max_layers_per_compile(&cc_def), + "headroom > 1.0 -> default"); + PASS(); +} + +static void test_plan_stories110m(void) { + TEST("Stories110M fits in 1 group"); + ModelConfig cfg = model_config_stories110m(); + PipelinePlan plan = compute_pipeline_plan(&cfg); + ASSERT_EQ(plan.n_groups, 1, "1 group"); + ASSERT_EQ(plan.groups[0].start_layer, 0, "starts at 0"); + ASSERT_EQ(plan.groups[0].end_layer, 12, "ends at 12"); + ASSERT_EQ(plan.groups[0].n_layers, 12, "12 layers"); + ASSERT_EQ(plan.groups[0].total_kernels, 72, "72 total kernels"); + pipeline_plan_free(&plan); + PASS(); +} + +static void test_plan_llama7b_multiple_groups(void) { + TEST("LLaMA-7B needs multiple groups"); + ModelConfig cfg = model_config_llama_7b(); + PipelinePlan plan = compute_pipeline_plan(&cfg); + ASSERT_TRUE(plan.n_groups >= 2, "at least 2 groups for 32 layers"); + // Verify all layers covered + int total_layers = 0; + for (int g = 0; g < plan.n_groups; g++) { + total_layers += plan.groups[g].n_layers; + ASSERT_TRUE(plan.groups[g].n_layers > 0, "no empty groups"); + } + ASSERT_EQ(total_layers, 32, "all 32 layers covered"); + // Verify contiguous + for (int g = 1; g < plan.n_groups; g++) { + ASSERT_EQ(plan.groups[g].start_layer, plan.groups[g-1].end_layer, "contiguous groups"); + } + pipeline_plan_free(&plan); + PASS(); +} + +static void test_plan_kernel_budget(void) { + TEST("No group exceeds compile budget"); + ModelConfig cfg = model_config_llama_7b(); + PipelinePlan plan = compute_pipeline_plan(&cfg); + int usable = (int)(cfg.compile.compile_budget * (1.0f - cfg.compile.headroom_pct)); + for (int g = 0; g < 
plan.n_groups; g++) { + ASSERT_TRUE(plan.groups[g].total_kernels <= usable, + "group kernel count <= usable budget"); + } + pipeline_plan_free(&plan); + PASS(); +} + +// ===== Gradient checkpoint tests ===== + +static void test_ckpt_all_saves_everything(void) { + TEST("CKPT_ALL saves all layers"); + ModelConfig cfg = model_config_stories110m(); + PipelinePlan plan = compute_pipeline_plan(&cfg); + CheckpointManager cm = checkpoint_init(CKPT_ALL, &cfg, &plan); + ASSERT_EQ(cm.n_checkpointed, 12, "12 layers saved"); + for (int i = 0; i < 12; i++) { + ASSERT_TRUE(checkpoint_should_save(&cm, i), "every layer saved"); + ASSERT_TRUE(!checkpoint_needs_recompute(&cm, i), "no recompute needed"); + } + ASSERT_TRUE(checkpoint_recompute_overhead(&cm) < 0.001, "zero overhead"); + checkpoint_free(&cm); + pipeline_plan_free(&plan); + PASS(); +} + +static void test_ckpt_none_saves_minimum(void) { + TEST("CKPT_NONE saves only layer 0"); + ModelConfig cfg = model_config_stories110m(); + PipelinePlan plan = compute_pipeline_plan(&cfg); + CheckpointManager cm = checkpoint_init(CKPT_NONE, &cfg, &plan); + ASSERT_EQ(cm.n_checkpointed, 1, "only 1 layer saved"); + ASSERT_TRUE(checkpoint_should_save(&cm, 0), "layer 0 saved"); + ASSERT_TRUE(checkpoint_needs_recompute(&cm, 5), "layer 5 needs recompute"); + checkpoint_free(&cm); + pipeline_plan_free(&plan); + PASS(); +} + +static void test_ckpt_sqrt_interval(void) { + TEST("CKPT_SQRT uses sqrt(N) interval"); + ModelConfig cfg = model_config_llama_7b(); + PipelinePlan plan = compute_pipeline_plan(&cfg); + CheckpointManager cm = checkpoint_init(CKPT_SQRT, &cfg, &plan); + int expected_interval = (int)sqrtf(32.0f); // 5 + ASSERT_EQ(cm.interval, expected_interval, "interval = sqrt(32) = 5"); + // Layer 0 always saved, then 5, 10, 15, 20, 25, 30, 31 + ASSERT_TRUE(checkpoint_should_save(&cm, 0), "layer 0 saved"); + ASSERT_TRUE(checkpoint_should_save(&cm, 5), "layer 5 saved"); + ASSERT_TRUE(!checkpoint_should_save(&cm, 3), "layer 3 not saved"); + ASSERT_TRUE(checkpoint_should_save(&cm, 31), "last layer saved"); + checkpoint_free(&cm); + pipeline_plan_free(&plan); + PASS(); +} + +static void test_ckpt_boundary(void) { + TEST("CKPT_BOUNDARY saves group edges"); + ModelConfig cfg = model_config_llama_7b(); + PipelinePlan plan = compute_pipeline_plan(&cfg); + CheckpointManager cm = checkpoint_init(CKPT_BOUNDARY, &cfg, &plan); + // First layer of each group + last layer overall + for (int g = 0; g < plan.n_groups; g++) { + ASSERT_TRUE(checkpoint_should_save(&cm, plan.groups[g].start_layer), + "group start layer saved"); + } + ASSERT_TRUE(checkpoint_should_save(&cm, 31), "last layer saved"); + // Middle of first group should not be saved + if (plan.groups[0].n_layers > 2) { + int mid = plan.groups[0].start_layer + plan.groups[0].n_layers / 2; + ASSERT_TRUE(checkpoint_needs_recompute(&cm, mid), "mid-group needs recompute"); + } + checkpoint_free(&cm); + pipeline_plan_free(&plan); + PASS(); +} + +static void test_ckpt_memory_savings(void) { + TEST("Checkpoint memory savings are positive for non-ALL policies"); + ModelConfig cfg = model_config_llama_7b(); + PipelinePlan plan = compute_pipeline_plan(&cfg); + + CheckpointManager cm_all = checkpoint_init(CKPT_ALL, &cfg, &plan); + CheckpointManager cm_sqrt = checkpoint_init(CKPT_SQRT, &cfg, &plan); + CheckpointManager cm_none = checkpoint_init(CKPT_NONE, &cfg, &plan); + + size_t saved_sqrt = checkpoint_memory_saved(&cm_sqrt, &cfg.dims); + size_t saved_none = checkpoint_memory_saved(&cm_none, &cfg.dims); + + ASSERT_TRUE(saved_sqrt > 0, "SQRT 
saves memory"); + ASSERT_TRUE(saved_none > saved_sqrt, "NONE saves more than SQRT"); + ASSERT_EQ(checkpoint_memory_saved(&cm_all, &cfg.dims), 0, "ALL saves nothing"); + + checkpoint_free(&cm_all); + checkpoint_free(&cm_sqrt); + checkpoint_free(&cm_none); + pipeline_plan_free(&plan); + PASS(); +} + +static void test_ckpt_recompute_depth(void) { + TEST("Recompute depth counts layers from nearest checkpoint"); + ModelConfig cfg = model_config_llama_7b(); + PipelinePlan plan = compute_pipeline_plan(&cfg); + CheckpointManager cm = checkpoint_init(CKPT_SQRT, &cfg, &plan); + // With interval=5: checkpoints at 0, 5, 10, 15, 20, 25, 30, 31 + // Layer 3: nearest saved before = 0, depth = 3 + ASSERT_EQ(checkpoint_recompute_depth(&cm, 3), 3, "depth from layer 0 to 3"); + // Layer 7: nearest saved before = 5, depth = 2 + ASSERT_EQ(checkpoint_recompute_depth(&cm, 7), 2, "depth from layer 5 to 7"); + // Layer 5: nearest saved = 5, depth = 0 + ASSERT_EQ(checkpoint_recompute_depth(&cm, 5), 0, "checkpointed layer = 0 depth"); + checkpoint_free(&cm); + pipeline_plan_free(&plan); + PASS(); +} + +static void test_ckpt_out_of_bounds(void) { + TEST("Checkpoint queries handle out-of-bounds gracefully"); + ModelConfig cfg = model_config_stories110m(); + PipelinePlan plan = compute_pipeline_plan(&cfg); + CheckpointManager cm = checkpoint_init(CKPT_ALL, &cfg, &plan); + ASSERT_TRUE(!checkpoint_should_save(&cm, -1), "negative index returns false"); + ASSERT_TRUE(!checkpoint_should_save(&cm, 100), "over-max index returns false"); + checkpoint_free(&cm); + pipeline_plan_free(&plan); + PASS(); +} + +// ===== FLOP estimation tests ===== + +static void test_flops_nonzero(void) { + TEST("FLOP estimates are nonzero and ANE < total"); + ModelConfig cfg = model_config_stories110m(); + double total = flops_per_step(&cfg); + double ane = ane_flops_per_step(&cfg); + ASSERT_TRUE(total > 0, "total FLOPs > 0"); + ASSERT_TRUE(ane > 0, "ANE FLOPs > 0"); + ASSERT_TRUE(ane < total, "ANE FLOPs < total (dW is on CPU)"); + PASS(); +} + +static void test_flops_scale_with_layers(void) { + TEST("FLOPs scale roughly linearly with layer count"); + ModelConfig cfg12 = model_config_stories110m(); + ModelConfig cfg8 = model_config_stories42m(); + double f12 = flops_per_step(&cfg12); + double f8 = flops_per_step(&cfg8); + // Not exact linear due to different dims, but 12-layer should be >8-layer + ASSERT_TRUE(f12 > f8, "12 layers > 8 layers"); + PASS(); +} + +// ===== Pipeline plan edge cases ===== + +static void test_plan_single_layer(void) { + TEST("Single-layer model = 1 group"); + ModelConfig cfg = model_config_stories110m(); + cfg.dims.n_layers = 1; + PipelinePlan plan = compute_pipeline_plan(&cfg); + ASSERT_EQ(plan.n_groups, 1, "1 group"); + ASSERT_EQ(plan.groups[0].n_layers, 1, "1 layer in group"); + pipeline_plan_free(&plan); + PASS(); +} + +static void test_plan_exact_budget_fit(void) { + TEST("Layers that exactly fill budget = 1 group"); + ModelConfig cfg = model_config_stories110m(); + // 17 layers * 6 kernels = 102 <= 107 usable (10% headroom on 119) + cfg.dims.n_layers = 17; + PipelinePlan plan = compute_pipeline_plan(&cfg); + ASSERT_EQ(plan.n_groups, 1, "17 layers fit in 1 group"); + pipeline_plan_free(&plan); + PASS(); +} + +static void test_plan_one_over_budget(void) { + TEST("One layer over budget = 2 groups"); + ModelConfig cfg = model_config_stories110m(); + // 18 layers * 6 kernels = 108 > 107 usable -> 2 groups + cfg.dims.n_layers = 18; + PipelinePlan plan = compute_pipeline_plan(&cfg); + ASSERT_EQ(plan.n_groups, 2, "18 layers 
= 2 groups"); + int total = plan.groups[0].n_layers + plan.groups[1].n_layers; + ASSERT_EQ(total, 18, "all layers covered"); + pipeline_plan_free(&plan); + PASS(); +} + +// ===== Main ===== + +int main(void) { + printf("=== Pipeline Unit Tests ===\n\n"); + + printf("[model_config.h]\n"); + test_dims_init(); + test_stories110m_preset(); + test_llama7b_preset(); + test_layer_memory_nonzero(); + test_adam_is_2x_weights(); + + printf("\n[pipeline planning]\n"); + test_max_layers_per_compile(); + test_configurable_headroom(); + test_invalid_headroom_defaults(); + test_plan_stories110m(); + test_plan_llama7b_multiple_groups(); + test_plan_kernel_budget(); + test_plan_single_layer(); + test_plan_exact_budget_fit(); + test_plan_one_over_budget(); + + printf("\n[gradient_checkpoint.h]\n"); + test_ckpt_all_saves_everything(); + test_ckpt_none_saves_minimum(); + test_ckpt_sqrt_interval(); + test_ckpt_boundary(); + test_ckpt_memory_savings(); + test_ckpt_recompute_depth(); + test_ckpt_out_of_bounds(); + + printf("\n[FLOP estimation]\n"); + test_flops_nonzero(); + test_flops_scale_with_layers(); + + printf("\n=== Results: %d/%d passed ===\n", tests_passed, tests_run); + return (tests_passed == tests_run) ? 0 : 1; +} + From 1ccf13e97510af128c3bc70b5b279fb5b4c945fe Mon Sep 17 00:00:00 2001 From: "codegen-sh[bot]" <131295404+codegen-sh[bot]@users.noreply.github.com> Date: Tue, 3 Mar 2026 06:54:28 +0000 Subject: [PATCH 3/3] Fix 6 review issues: checkpoint counting bug, headroom/memory consistency, safety guards Bug fix: n_checkpointed counting wrong in CKPT_BOUNDARY/SQRT/EVERY_N - Replaced per-policy arithmetic with single post-switch loop that counts actual is_saved bits. Eliminates edge-case miscounts when last layer falls on an interval boundary. Inconsistency: headroom mismatch between planner and runtime budget - budget_init() now takes CompileConfig* and uses the same headroom_pct validation as max_layers_per_compile(). Both paths yield identical usable-budget calculations. Inconsistency: total_model_bytes() omitted global gradients - Added rms_final_grad and embed_grad terms to match mmap_compute_size(). Diagnostic output now agrees with actual allocation. Design: divide-by-zero in model_dims_init() if n_heads=0 - Guarded head_dim = dim / n_heads with n_heads > 0 check. Design: no bounds checking in mmap typed accessors - All four mmap_layer_* accessors now validate layer index and return NULL on out-of-bounds. Extracted shared mmap_dims() helper to deduplicate ModelDims reconstruction. Design: CKPT_EVERY_N interval hardcoded despite caller should set - Added custom_interval parameter to checkpoint_init(). Pass 0 for default (4), or any positive int for custom spacing. Tests: 26/26 passing (3 new: custom interval, n_checkpointed accuracy, zero-heads guard). 
Co-authored-by: dermitchell1993 --- training/gradient_checkpoint.h | 23 +++++------ training/model_config.h | 6 ++- training/pipeline.h | 60 ++++++++++++++-------------- training/test_pipeline_unit.c | 71 +++++++++++++++++++++++++++++----- training/train_pipeline.m | 5 +-- 5 files changed, 105 insertions(+), 60 deletions(-) diff --git a/training/gradient_checkpoint.h b/training/gradient_checkpoint.h index 4065c04..29a5aa0 100644 --- a/training/gradient_checkpoint.h +++ b/training/gradient_checkpoint.h @@ -27,8 +27,9 @@ typedef struct { // ===== Initialization ===== +// custom_interval: used for CKPT_EVERY_N (pass 0 for default=4, ignored for other policies) static CheckpointManager checkpoint_init(CheckpointPolicy policy, const ModelConfig *cfg, - const PipelinePlan *plan) { + const PipelinePlan *plan, int custom_interval) { CheckpointManager cm = {0}; cm.policy = policy; cm.n_layers = cfg->dims.n_layers; @@ -38,47 +39,42 @@ static CheckpointManager checkpoint_init(CheckpointPolicy policy, const ModelCon switch (policy) { case CKPT_ALL: - // Save everything — no recompute needed for (int i = 0; i < cm.n_layers; i++) cm.is_saved[i] = true; - cm.n_checkpointed = cm.n_layers; break; case CKPT_BOUNDARY: - // Save only the input to each layer group for (int g = 0; g < plan->n_groups; g++) { cm.is_saved[plan->groups[g].start_layer] = true; } - // Always save the last layer's output (needed for loss backward) cm.is_saved[cm.n_layers - 1] = true; - cm.n_checkpointed = plan->n_groups + 1; break; case CKPT_SQRT: { - // Save every √N layers — optimal memory/compute balance int interval = (int)sqrtf((float)cm.n_layers); if (interval < 1) interval = 1; cm.interval = interval; for (int i = 0; i < cm.n_layers; i += interval) cm.is_saved[i] = true; cm.is_saved[cm.n_layers - 1] = true; - cm.n_checkpointed = (cm.n_layers + interval - 1) / interval; break; } case CKPT_EVERY_N: - // Caller should set cm.interval before using - cm.interval = 4; // default + cm.interval = (custom_interval > 0) ? custom_interval : 4; for (int i = 0; i < cm.n_layers; i += cm.interval) cm.is_saved[i] = true; cm.is_saved[cm.n_layers - 1] = true; - cm.n_checkpointed = (cm.n_layers + cm.interval - 1) / cm.interval; break; case CKPT_NONE: - // Save nothing except layer 0 input (needed as recompute starting point) cm.is_saved[0] = true; - cm.n_checkpointed = 1; break; } + // Count actual saved layers — single source of truth, no fragile arithmetic + cm.n_checkpointed = 0; + for (int i = 0; i < cm.n_layers; i++) { + if (cm.is_saved[i]) cm.n_checkpointed++; + } + return cm; } @@ -167,4 +163,3 @@ static void checkpoint_print(const CheckpointManager *cm, const ModelDims *d) { } printf("\n"); } - diff --git a/training/model_config.h b/training/model_config.h index ddd9229..fa8be1c 100644 --- a/training/model_config.h +++ b/training/model_config.h @@ -56,7 +56,7 @@ typedef struct { // ===== Derived dimension helpers ===== static void model_dims_init(ModelDims *d) { - d->head_dim = d->dim / d->n_heads; + d->head_dim = (d->n_heads > 0) ? 
d->dim / d->n_heads : 0; d->kv_dim = d->head_dim * d->n_kv_heads; d->score_ch = d->n_heads * d->seq_len; } @@ -111,7 +111,9 @@ static inline size_t total_model_bytes(const ModelConfig *cfg) { size_t global = (size_t)d->dim * sizeof(float) // rms_final + (size_t)d->vocab_size * d->dim * sizeof(float) // embed + (size_t)d->dim * 2 * sizeof(float) // rms_final adam - + (size_t)d->vocab_size * d->dim * 2 * sizeof(float); // embed adam + + (size_t)d->vocab_size * d->dim * 2 * sizeof(float) // embed adam + + (size_t)d->dim * sizeof(float) // rms_final grad + + (size_t)d->vocab_size * d->dim * sizeof(float); // embed grad return per_layer * d->n_layers + global; } diff --git a/training/pipeline.h b/training/pipeline.h index 798f6d6..d3ceb95 100644 --- a/training/pipeline.h +++ b/training/pipeline.h @@ -16,11 +16,13 @@ typedef struct { int headroom; // safety margin (budget * 0.1) } CompileBudget; -static CompileBudget budget_init(int max_compiles) { +static CompileBudget budget_init(const CompileConfig *cc) { CompileBudget b; - b.budget = max_compiles; + b.budget = cc->compile_budget; b.used = 0; - b.headroom = max_compiles / 10; + float pct = (cc->headroom_pct > 0.0f && cc->headroom_pct < 1.0f) + ? cc->headroom_pct : 0.10f; + b.headroom = (int)(cc->compile_budget * pct); return b; } @@ -85,7 +87,7 @@ static PipelineScheduler pipeline_scheduler_init(ModelConfig config, int total_s PipelineScheduler s = {0}; s.config = config; s.plan = compute_pipeline_plan(&config); - s.budget = budget_init(config.compile.compile_budget); + s.budget = budget_init(&config.compile); s.phase = PHASE_FORWARD; s.current_group = 0; s.current_step = 0; @@ -364,47 +366,43 @@ static void mmap_state_destroy(MmapState *ms) { // ===== Typed accessors into mmap regions ===== -// Get pointer to layer L's weights in mmap +// Reconstruct ModelDims from mmap header (avoids repeating in each accessor) +static inline ModelDims mmap_dims(const MmapState *ms) { + return (ModelDims){ + .dim = ms->header->dim, .hidden_dim = ms->header->hidden_dim, + .n_heads = ms->header->n_heads, .vocab_size = ms->header->vocab_size, + .seq_len = ms->header->seq_len + }; +} + +// Get pointer to layer L's weights in mmap (NULL if out of bounds) static float *mmap_layer_weights(MmapState *ms, int layer) { + if (!ms || layer < 0 || layer >= ms->header->n_layers) return NULL; + ModelDims d = mmap_dims(ms); return (float *)((char *)ms->base + ms->header->layer_weights_offset - + (size_t)layer * layer_weight_bytes(&(ModelDims){ - .dim = ms->header->dim, - .hidden_dim = ms->header->hidden_dim, - .n_heads = ms->header->n_heads, - .vocab_size = ms->header->vocab_size, - .seq_len = ms->header->seq_len - })); + + (size_t)layer * layer_weight_bytes(&d)); } -// Get pointer to layer L's adam state in mmap +// Get pointer to layer L's adam state in mmap (NULL if out of bounds) static float *mmap_layer_adam(MmapState *ms, int layer) { - ModelDims d = { - .dim = ms->header->dim, .hidden_dim = ms->header->hidden_dim, - .n_heads = ms->header->n_heads, .vocab_size = ms->header->vocab_size, - .seq_len = ms->header->seq_len - }; + if (!ms || layer < 0 || layer >= ms->header->n_layers) return NULL; + ModelDims d = mmap_dims(ms); return (float *)((char *)ms->base + ms->header->layer_adam_offset + (size_t)layer * layer_adam_bytes(&d)); } -// Get pointer to layer L's gradient accumulators in mmap +// Get pointer to layer L's gradient accumulators in mmap (NULL if out of bounds) static float *mmap_layer_grads(MmapState *ms, int layer) { - ModelDims d = { - .dim = 
ms->header->dim, .hidden_dim = ms->header->hidden_dim, - .n_heads = ms->header->n_heads, .vocab_size = ms->header->vocab_size, - .seq_len = ms->header->seq_len - }; + if (!ms || layer < 0 || layer >= ms->header->n_layers) return NULL; + ModelDims d = mmap_dims(ms); return (float *)((char *)ms->base + ms->header->layer_grads_offset + (size_t)layer * layer_gradient_bytes(&d)); } -// Get pointer to layer L's activation checkpoint in mmap +// Get pointer to layer L's activation checkpoint in mmap (NULL if out of bounds) static float *mmap_layer_acts(MmapState *ms, int layer) { - ModelDims d = { - .dim = ms->header->dim, .hidden_dim = ms->header->hidden_dim, - .n_heads = ms->header->n_heads, .vocab_size = ms->header->vocab_size, - .seq_len = ms->header->seq_len - }; + if (!ms || layer < 0 || layer >= ms->header->n_layers) return NULL; + ModelDims d = mmap_dims(ms); return (float *)((char *)ms->base + ms->header->layer_acts_offset + (size_t)layer * layer_activation_bytes(&d)); } @@ -438,7 +436,7 @@ static void pipeline_restore_from_mmap(PipelineScheduler *s, const MmapState *ms s->learning_rate = h->learning_rate; s->last_loss = h->last_loss; // Reset compile budget (new process after exec) - s->budget = budget_init(s->config.compile.compile_budget); + s->budget = budget_init(&s->config.compile); s->group_compiled = false; s->needs_restart = false; } diff --git a/training/test_pipeline_unit.c b/training/test_pipeline_unit.c index b641911..dcdd9e5 100644 --- a/training/test_pipeline_unit.c +++ b/training/test_pipeline_unit.c @@ -179,7 +179,7 @@ static void test_ckpt_all_saves_everything(void) { TEST("CKPT_ALL saves all layers"); ModelConfig cfg = model_config_stories110m(); PipelinePlan plan = compute_pipeline_plan(&cfg); - CheckpointManager cm = checkpoint_init(CKPT_ALL, &cfg, &plan); + CheckpointManager cm = checkpoint_init(CKPT_ALL, &cfg, &plan, 0); ASSERT_EQ(cm.n_checkpointed, 12, "12 layers saved"); for (int i = 0; i < 12; i++) { ASSERT_TRUE(checkpoint_should_save(&cm, i), "every layer saved"); @@ -195,7 +195,7 @@ static void test_ckpt_none_saves_minimum(void) { TEST("CKPT_NONE saves only layer 0"); ModelConfig cfg = model_config_stories110m(); PipelinePlan plan = compute_pipeline_plan(&cfg); - CheckpointManager cm = checkpoint_init(CKPT_NONE, &cfg, &plan); + CheckpointManager cm = checkpoint_init(CKPT_NONE, &cfg, &plan, 0); ASSERT_EQ(cm.n_checkpointed, 1, "only 1 layer saved"); ASSERT_TRUE(checkpoint_should_save(&cm, 0), "layer 0 saved"); ASSERT_TRUE(checkpoint_needs_recompute(&cm, 5), "layer 5 needs recompute"); @@ -208,7 +208,7 @@ static void test_ckpt_sqrt_interval(void) { TEST("CKPT_SQRT uses sqrt(N) interval"); ModelConfig cfg = model_config_llama_7b(); PipelinePlan plan = compute_pipeline_plan(&cfg); - CheckpointManager cm = checkpoint_init(CKPT_SQRT, &cfg, &plan); + CheckpointManager cm = checkpoint_init(CKPT_SQRT, &cfg, &plan, 0); int expected_interval = (int)sqrtf(32.0f); // 5 ASSERT_EQ(cm.interval, expected_interval, "interval = sqrt(32) = 5"); // Layer 0 always saved, then 5, 10, 15, 20, 25, 30, 31 @@ -225,7 +225,7 @@ static void test_ckpt_boundary(void) { TEST("CKPT_BOUNDARY saves group edges"); ModelConfig cfg = model_config_llama_7b(); PipelinePlan plan = compute_pipeline_plan(&cfg); - CheckpointManager cm = checkpoint_init(CKPT_BOUNDARY, &cfg, &plan); + CheckpointManager cm = checkpoint_init(CKPT_BOUNDARY, &cfg, &plan, 0); // First layer of each group + last layer overall for (int g = 0; g < plan.n_groups; g++) { ASSERT_TRUE(checkpoint_should_save(&cm, 
plan.groups[g].start_layer), @@ -247,9 +247,9 @@ static void test_ckpt_memory_savings(void) { ModelConfig cfg = model_config_llama_7b(); PipelinePlan plan = compute_pipeline_plan(&cfg); - CheckpointManager cm_all = checkpoint_init(CKPT_ALL, &cfg, &plan); - CheckpointManager cm_sqrt = checkpoint_init(CKPT_SQRT, &cfg, &plan); - CheckpointManager cm_none = checkpoint_init(CKPT_NONE, &cfg, &plan); + CheckpointManager cm_all = checkpoint_init(CKPT_ALL, &cfg, &plan, 0); + CheckpointManager cm_sqrt = checkpoint_init(CKPT_SQRT, &cfg, &plan, 0); + CheckpointManager cm_none = checkpoint_init(CKPT_NONE, &cfg, &plan, 0); size_t saved_sqrt = checkpoint_memory_saved(&cm_sqrt, &cfg.dims); size_t saved_none = checkpoint_memory_saved(&cm_none, &cfg.dims); @@ -269,7 +269,7 @@ static void test_ckpt_recompute_depth(void) { TEST("Recompute depth counts layers from nearest checkpoint"); ModelConfig cfg = model_config_llama_7b(); PipelinePlan plan = compute_pipeline_plan(&cfg); - CheckpointManager cm = checkpoint_init(CKPT_SQRT, &cfg, &plan); + CheckpointManager cm = checkpoint_init(CKPT_SQRT, &cfg, &plan, 0); // With interval=5: checkpoints at 0, 5, 10, 15, 20, 25, 30, 31 // Layer 3: nearest saved before = 0, depth = 3 ASSERT_EQ(checkpoint_recompute_depth(&cm, 3), 3, "depth from layer 0 to 3"); @@ -286,7 +286,7 @@ static void test_ckpt_out_of_bounds(void) { TEST("Checkpoint queries handle out-of-bounds gracefully"); ModelConfig cfg = model_config_stories110m(); PipelinePlan plan = compute_pipeline_plan(&cfg); - CheckpointManager cm = checkpoint_init(CKPT_ALL, &cfg, &plan); + CheckpointManager cm = checkpoint_init(CKPT_ALL, &cfg, &plan, 0); ASSERT_TRUE(!checkpoint_should_save(&cm, -1), "negative index returns false"); ASSERT_TRUE(!checkpoint_should_save(&cm, 100), "over-max index returns false"); checkpoint_free(&cm); @@ -294,6 +294,55 @@ static void test_ckpt_out_of_bounds(void) { PASS(); } +static void test_ckpt_every_n_custom_interval(void) { + TEST("CKPT_EVERY_N respects custom_interval parameter"); + ModelConfig cfg = model_config_llama_7b(); + PipelinePlan plan = compute_pipeline_plan(&cfg); + CheckpointManager cm3 = checkpoint_init(CKPT_EVERY_N, &cfg, &plan, 3); + CheckpointManager cm8 = checkpoint_init(CKPT_EVERY_N, &cfg, &plan, 8); + ASSERT_EQ(cm3.interval, 3, "interval=3 when custom_interval=3"); + ASSERT_EQ(cm8.interval, 8, "interval=8 when custom_interval=8"); + ASSERT_TRUE(cm3.n_checkpointed > cm8.n_checkpointed, + "shorter interval = more checkpoints"); + // Verify layer 0 and last layer always saved + ASSERT_TRUE(checkpoint_should_save(&cm3, 0), "layer 0 saved (interval=3)"); + ASSERT_TRUE(checkpoint_should_save(&cm3, 31), "last layer saved (interval=3)"); + ASSERT_TRUE(checkpoint_should_save(&cm8, 0), "layer 0 saved (interval=8)"); + ASSERT_TRUE(checkpoint_should_save(&cm8, 31), "last layer saved (interval=8)"); + checkpoint_free(&cm3); + checkpoint_free(&cm8); + pipeline_plan_free(&plan); + PASS(); +} + +static void test_ckpt_n_checkpointed_accuracy(void) { + TEST("n_checkpointed matches actual is_saved bit count"); + ModelConfig cfg = model_config_llama_7b(); + PipelinePlan plan = compute_pipeline_plan(&cfg); + CheckpointPolicy policies[] = {CKPT_ALL, CKPT_BOUNDARY, CKPT_SQRT, CKPT_EVERY_N, CKPT_NONE}; + for (int p = 0; p < 5; p++) { + CheckpointManager cm = checkpoint_init(policies[p], &cfg, &plan, 0); + int actual = 0; + for (int i = 0; i < cm.n_layers; i++) { + if (cm.is_saved[i]) actual++; + } + ASSERT_EQ(cm.n_checkpointed, actual, "n_checkpointed matches is_saved count"); + 
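// free this policy's manager before the next iteration re-initializes it +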
checkpoint_free(&cm); + } + pipeline_plan_free(&plan); + PASS(); +} + +static void test_dims_init_zero_heads(void) { + TEST("model_dims_init guards divide-by-zero on n_heads=0"); + ModelDims d = {.dim = 768, .n_heads = 0, .n_kv_heads = 0, .seq_len = 256}; + model_dims_init(&d); + ASSERT_EQ(d.head_dim, 0, "head_dim=0 when n_heads=0"); + ASSERT_EQ(d.kv_dim, 0, "kv_dim=0 when n_heads=0"); + ASSERT_EQ(d.score_ch, 0, "score_ch=0 when n_heads=0"); + PASS(); +} + // ===== FLOP estimation tests ===== static void test_flops_nonzero(void) { @@ -386,6 +435,9 @@ int main(void) { test_ckpt_memory_savings(); test_ckpt_recompute_depth(); test_ckpt_out_of_bounds(); + test_ckpt_every_n_custom_interval(); + test_ckpt_n_checkpointed_accuracy(); + test_dims_init_zero_heads(); printf("\n[FLOP estimation]\n"); test_flops_nonzero(); @@ -394,4 +446,3 @@ int main(void) { printf("\n=== Results: %d/%d passed ===\n", tests_passed, tests_run); return (tests_passed == tests_run) ? 0 : 1; } - diff --git a/training/train_pipeline.m b/training/train_pipeline.m index ef055b8..b0fc3a4 100644 --- a/training/train_pipeline.m +++ b/training/train_pipeline.m @@ -124,7 +124,7 @@ int main(int argc, char *argv[]) { printf("\n"); // Print checkpoint policy - CheckpointManager cm = checkpoint_init(ckpt_policy, &cfg, &plan); + CheckpointManager cm = checkpoint_init(ckpt_policy, &cfg, &plan, 0); checkpoint_print(&cm, &cfg.dims); printf("\n"); @@ -195,7 +195,7 @@ int main(int argc, char *argv[]) { printf(" [exec] Would restart process to reset compile budget\n"); printf(" Saving scheduler state to mmap, calling exec()\n"); // In dry-run, just reset the budget and continue - sched.budget = budget_init(cfg.compile.compile_budget); + sched.budget = budget_init(&cfg.compile); sched.needs_restart = false; break; @@ -255,4 +255,3 @@ int main(int argc, char *argv[]) { } return 0; } -
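Usage sketch (illustrative only, not part of the patches above): a minimal dry-run composing the pieces this series adds. Every identifier comes from model_config.h, pipeline.h, and gradient_checkpoint.h as defined in the diffs; the build line follows the test file's convention (cc -O2 -o demo demo.c -lm) and assumes pipeline.h compiles standalone here as it does for train_pipeline.m.

    #include <stdio.h>
    #include "model_config.h"
    #include "pipeline.h"            // budget_init / CompileBudget
    #include "gradient_checkpoint.h"

    int main(void) {
        ModelConfig cfg = model_config_llama_7b();            // 32 layers, dim=4096
        PipelinePlan plan = compute_pipeline_plan(&cfg);      // budget-sized layer groups
        CheckpointManager cm = checkpoint_init(CKPT_SQRT, &cfg, &plan, 0); // sqrt(N) spacing
        CompileBudget b = budget_init(&cfg.compile);          // runtime tracker, validated headroom
        printf("%d groups, %d checkpointed layers, budget %d (headroom %d)\n",
               plan.n_groups, cm.n_checkpointed, b.budget, b.headroom);
        checkpoint_free(&cm);
        pipeline_plan_free(&plan);
        return 0;
    }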