-
Notifications
You must be signed in to change notification settings - Fork 0
Pipeline scaffolding for multi-group ANE training #1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
e0b1b27
32b5c72
1ccf13e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,36 +1,46 @@ | ||
| CC = xcrun clang | ||
| CFLAGS = -O2 -Wall -Wno-deprecated-declarations -fobjc-arc | ||
| FRAMEWORKS = -framework Foundation -framework CoreML -framework IOSurface | ||
| LDFLAGS = $(FRAMEWORKS) -ldl | ||
|
|
||
| HEADERS_LARGE = stories_config.h stories_io.h stories_mil.h stories_cpu_ops.h | ||
|
|
||
| train: train.m ane_runtime.h ane_mil_gen.h model.h forward.h backward.h | ||
| $(CC) $(CFLAGS) -o $@ train.m $(LDFLAGS) | ||
|
|
||
| train_large: train_large.m $(HEADERS_LARGE) | ||
| $(CC) $(CFLAGS) -o $@ train_large.m $(LDFLAGS) -framework Accelerate | ||
|
|
||
| PROBES = test_weight_reload test_perf_stats test_qos_sweep test_ane_advanced | ||
|
|
||
| test_weight_reload: test_weight_reload.m | ||
| $(CC) $(CFLAGS) -o $@ $< $(LDFLAGS) | ||
|
|
||
| test_perf_stats: test_perf_stats.m | ||
| $(CC) $(CFLAGS) -o $@ $< $(LDFLAGS) | ||
|
|
||
| test_qos_sweep: test_qos_sweep.m | ||
| $(CC) $(CFLAGS) -o $@ $< $(LDFLAGS) | ||
|
|
||
| test_ane_advanced: test_ane_advanced.m | ||
| $(CC) $(CFLAGS) -o $@ $< $(LDFLAGS) | ||
|
|
||
| probes: $(PROBES) | ||
|
|
||
| tokenize: | ||
| python3 tokenize.py | ||
|
|
||
| clean: | ||
| rm -f train train_large $(PROBES) | ||
|
|
||
| .PHONY: clean tokenize probes | ||
CC = xcrun clang
CFLAGS = -O2 -Wall -Wno-deprecated-declarations -fobjc-arc
FRAMEWORKS = -framework Foundation -framework CoreML -framework IOSurface
LDFLAGS = $(FRAMEWORKS) -ldl

HEADERS_LARGE = stories_config.h stories_io.h stories_mil.h stories_cpu_ops.h
HEADERS_PIPELINE = model_config.h pipeline.h gradient_checkpoint.h

train: train.m ane_runtime.h ane_mil_gen.h model.h forward.h backward.h
	$(CC) $(CFLAGS) -o $@ train.m $(LDFLAGS)

train_large: train_large.m $(HEADERS_LARGE)
	$(CC) $(CFLAGS) -o $@ train_large.m $(LDFLAGS) -framework Accelerate

train_pipeline: train_pipeline.m $(HEADERS_PIPELINE)
	$(CC) $(CFLAGS) -o $@ train_pipeline.m $(LDFLAGS) -framework Accelerate

# Same source as train_pipeline, compiled with -DANE_LIVE (presumably enables
# real ANE dispatch instead of the offline path — confirm in train_pipeline.m).
# NOTE: this rule writes the binary `train_pipeline`, not `train_pipeline_live`,
# so it must be .PHONY — otherwise make never sees its output and dependency
# tracking is meaningless.
train_pipeline_live: train_pipeline.m $(HEADERS_PIPELINE) $(HEADERS_LARGE)
	$(CC) $(CFLAGS) -DANE_LIVE -o train_pipeline train_pipeline.m $(LDFLAGS) -framework Accelerate

PROBES = test_weight_reload test_perf_stats test_qos_sweep test_ane_advanced

test_weight_reload: test_weight_reload.m
	$(CC) $(CFLAGS) -o $@ $< $(LDFLAGS)

test_perf_stats: test_perf_stats.m
	$(CC) $(CFLAGS) -o $@ $< $(LDFLAGS)

test_qos_sweep: test_qos_sweep.m
	$(CC) $(CFLAGS) -o $@ $< $(LDFLAGS)

test_ane_advanced: test_ane_advanced.m
	$(CC) $(CFLAGS) -o $@ $< $(LDFLAGS)

probes: $(PROBES)

tokenize:
	python3 tokenize.py

# Plain cc (no Objective-C, no frameworks): the unit test is pure C.
test_pipeline_unit: test_pipeline_unit.c $(HEADERS_PIPELINE)
	cc -O2 -Wall -o $@ $< -lm

clean:
	rm -f train train_large train_pipeline test_pipeline_unit $(PROBES)

.PHONY: clean tokenize probes train_pipeline_live
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,165 @@ | ||
| // gradient_checkpoint.h — Activation checkpointing for deep models | ||
| // Trades compute for memory: recompute forward activations during backward | ||
| // instead of storing all layers' activations simultaneously | ||
| #pragma once | ||
| #include "model_config.h" | ||
|
|
||
| // ===== Checkpoint policies ===== | ||
|
|
||
// Checkpoint policies: which layers keep their forward activations resident.
// Everything not saved is recomputed from the nearest earlier checkpoint
// during the backward pass.
typedef enum {
    CKPT_ALL,       // save all layers' activations (current behavior)
    CKPT_BOUNDARY,  // save only group boundary activations, recompute within group
    CKPT_SQRT,      // save every √N layers (classic sublinear-memory tradeoff)
    CKPT_EVERY_N,   // save every N-th layer (interval configurable at init)
    CKPT_NONE       // save nothing but layer 0, recompute everything (max memory savings)
} CheckpointPolicy;
|
|
||
// Per-model checkpoint bookkeeping. Create with checkpoint_init(), release
// with checkpoint_free(). `is_saved` is heap-owned (one bool per layer).
typedef struct {
    CheckpointPolicy policy;
    int interval;          // save stride; set for CKPT_SQRT and CKPT_EVERY_N, unused otherwise
    int n_layers;          // total layers in model
    int n_groups;          // layer groups in pipeline
    int layers_per_group;  // layers per group (from pipeline plan)
    // Derived
    int n_checkpointed;    // how many layers have saved activations (counted, not computed)
    bool *is_saved;        // per-layer: true if activation is saved (not recomputed)
} CheckpointManager;
|
|
||
| // ===== Initialization ===== | ||
|
|
||
| // custom_interval: used for CKPT_EVERY_N (pass 0 for default=4, ignored for other policies) | ||
| static CheckpointManager checkpoint_init(CheckpointPolicy policy, const ModelConfig *cfg, | ||
| const PipelinePlan *plan, int custom_interval) { | ||
| CheckpointManager cm = {0}; | ||
| cm.policy = policy; | ||
| cm.n_layers = cfg->dims.n_layers; | ||
| cm.n_groups = plan->n_groups; | ||
| cm.layers_per_group = (plan->n_groups > 0) ? plan->groups[0].n_layers : cfg->dims.n_layers; | ||
| cm.is_saved = (bool *)calloc(cfg->dims.n_layers, sizeof(bool)); | ||
|
|
||
| switch (policy) { | ||
| case CKPT_ALL: | ||
| for (int i = 0; i < cm.n_layers; i++) cm.is_saved[i] = true; | ||
| break; | ||
|
|
||
| case CKPT_BOUNDARY: | ||
| for (int g = 0; g < plan->n_groups; g++) { | ||
| cm.is_saved[plan->groups[g].start_layer] = true; | ||
| } | ||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Solid implementation of checkpoint policies. The use of sqrt for interval in CKPT_SQRT is clever for balancing memory and compute. Consider adding a reference or comment explaining the optimality proof for this strategy. |
||
| cm.is_saved[cm.n_layers - 1] = true; | ||
| break; | ||
|
|
||
| case CKPT_SQRT: { | ||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Bug: Fix: count the actual set bits, or guard with |
||
| int interval = (int)sqrtf((float)cm.n_layers); | ||
| if (interval < 1) interval = 1; | ||
| cm.interval = interval; | ||
| for (int i = 0; i < cm.n_layers; i += interval) cm.is_saved[i] = true; | ||
| cm.is_saved[cm.n_layers - 1] = true; | ||
| break; | ||
| } | ||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Bug: Example with 32 layers, interval=5: loop saves 0,5,10,15,20,25,30 → 7 layers. Then layer 31 is saved → 8 total. But Same issue applies to Simplest fix: after the switch, count the actual saved layers: cm.n_checkpointed = 0;
for (int i = 0; i < cm.n_layers; i++)
if (cm.is_saved[i]) cm.n_checkpointed++;This would also fix the CKPT_BOUNDARY issue. |
||
|
|
||
| case CKPT_EVERY_N: | ||
| cm.interval = (custom_interval > 0) ? custom_interval : 4; | ||
| for (int i = 0; i < cm.n_layers; i += cm.interval) cm.is_saved[i] = true; | ||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The comment says "Caller should set cm.interval before using" but the function immediately hardcodes
|
||
| cm.is_saved[cm.n_layers - 1] = true; | ||
| break; | ||
|
|
||
| case CKPT_NONE: | ||
| cm.is_saved[0] = true; | ||
| break; | ||
| } | ||
|
|
||
| // Count actual saved layers — single source of truth, no fragile arithmetic | ||
| cm.n_checkpointed = 0; | ||
| for (int i = 0; i < cm.n_layers; i++) { | ||
| if (cm.is_saved[i]) cm.n_checkpointed++; | ||
| } | ||
|
|
||
| return cm; | ||
| } | ||
|
|
||
| static void checkpoint_free(CheckpointManager *cm) { | ||
| free(cm->is_saved); | ||
| cm->is_saved = NULL; | ||
| } | ||
|
|
||
| // ===== Query functions ===== | ||
|
|
||
| // Should we save this layer's activations during forward pass? | ||
| static bool checkpoint_should_save(const CheckpointManager *cm, int layer_idx) { | ||
| if (layer_idx < 0 || layer_idx >= cm->n_layers) return false; | ||
| return cm->is_saved[layer_idx]; | ||
| } | ||
|
|
||
| // Does this layer need forward recompute during backward pass? | ||
| static bool checkpoint_needs_recompute(const CheckpointManager *cm, int layer_idx) { | ||
| return !checkpoint_should_save(cm, layer_idx); | ||
| } | ||
|
|
||
| // Find the nearest saved checkpoint before this layer (for recompute starting point) | ||
| static int checkpoint_nearest_saved_before(const CheckpointManager *cm, int layer_idx) { | ||
| for (int i = layer_idx; i >= 0; i--) { | ||
| if (cm->is_saved[i]) return i; | ||
| } | ||
| return 0; // fallback to first layer | ||
| } | ||
|
|
||
| // How many layers need recompute between the nearest checkpoint and this layer? | ||
| static int checkpoint_recompute_depth(const CheckpointManager *cm, int layer_idx) { | ||
| int saved = checkpoint_nearest_saved_before(cm, layer_idx); | ||
| return layer_idx - saved; | ||
| } | ||
|
|
||
| // ===== Memory estimation ===== | ||
|
|
||
| // Memory for saved activations only (bytes) | ||
| static size_t checkpoint_saved_memory(const CheckpointManager *cm, const ModelDims *d) { | ||
| return (size_t)cm->n_checkpointed * layer_activation_bytes(d); | ||
| } | ||
|
|
||
| // Memory savings vs. saving all layers (bytes) | ||
| static size_t checkpoint_memory_saved(const CheckpointManager *cm, const ModelDims *d) { | ||
| size_t all = (size_t)cm->n_layers * layer_activation_bytes(d); | ||
| size_t used = checkpoint_saved_memory(cm, d); | ||
| return all - used; | ||
| } | ||
|
|
||
| // Extra forward FLOPs due to recompute (fraction of 1.0) | ||
| static double checkpoint_recompute_overhead(const CheckpointManager *cm) { | ||
| int recomputed = cm->n_layers - cm->n_checkpointed; | ||
| return (double)recomputed / (double)cm->n_layers; | ||
| } | ||
|
|
||
| // ===== Pretty-print ===== | ||
|
|
||
| static const char *checkpoint_policy_name(CheckpointPolicy p) { | ||
| switch (p) { | ||
| case CKPT_ALL: return "ALL"; | ||
| case CKPT_BOUNDARY: return "BOUNDARY"; | ||
| case CKPT_SQRT: return "SQRT"; | ||
| case CKPT_EVERY_N: return "EVERY_N"; | ||
| case CKPT_NONE: return "NONE"; | ||
| default: return "UNKNOWN"; | ||
| } | ||
| } | ||
|
|
||
// Print a human-readable summary of the checkpoint plan: policy and interval,
// how many layers are saved, memory used vs. saving everything, recompute
// overhead, and the exact indices of the saved layers.
static void checkpoint_print(const CheckpointManager *cm, const ModelDims *d) {
    printf("=== Checkpoint Policy: %s ===\n", checkpoint_policy_name(cm->policy));
    printf(" %d/%d layers checkpointed", cm->n_checkpointed, cm->n_layers);
    // interval is only meaningful for the stride-based policies
    if (cm->policy == CKPT_SQRT || cm->policy == CKPT_EVERY_N)
        printf(" (interval=%d)", cm->interval);
    printf("\n");
    printf(" Activation memory: %.1fMB (saved) / %.1fMB (all)\n",
        checkpoint_saved_memory(cm, d) / 1e6,
        (double)cm->n_layers * layer_activation_bytes(d) / 1e6);
    // NOTE(review): percentage divides by n_layers * layer_activation_bytes —
    // assumes both are nonzero; confirm callers never print an empty model.
    printf(" Memory savings: %.1fMB (%.0f%%)\n",
        checkpoint_memory_saved(cm, d) / 1e6,
        100.0 * checkpoint_memory_saved(cm, d) / ((double)cm->n_layers * layer_activation_bytes(d)));
    printf(" Recompute overhead: %.0f%% extra forward FLOPs\n",
        100.0 * checkpoint_recompute_overhead(cm));
    printf(" Saved layers: ");
    for (int i = 0; i < cm->n_layers; i++) {
        if (cm->is_saved[i]) printf("%d ", i);
    }
    printf("\n");
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice addition of the train_pipeline targets! This will make it easier to experiment with pipelining. Consider adding a brief comment in the Makefile explaining the difference between train_pipeline and train_pipeline_live.