Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 46 additions & 36 deletions training/Makefile
Original file line number Diff line number Diff line change
@@ -1,36 +1,46 @@
# Toolchain: clang resolved through xcrun (Xcode command-line tools).
CC = xcrun clang
# -fobjc-arc: the Objective-C trainers rely on ARC memory management.
CFLAGS = -O2 -Wall -Wno-deprecated-declarations -fobjc-arc
FRAMEWORKS = -framework Foundation -framework CoreML -framework IOSurface
LDFLAGS = $(FRAMEWORKS) -ldl

# Headers the large-model trainer depends on.
HEADERS_LARGE = stories_config.h stories_io.h stories_mil.h stories_cpu_ops.h

# Main trainer; relinks when any core header changes.
train: train.m ane_runtime.h ane_mil_gen.h model.h forward.h backward.h
	$(CC) $(CFLAGS) -o $@ train.m $(LDFLAGS)

# Large-model trainer; Accelerate supplies the CPU BLAS fallbacks.
train_large: train_large.m $(HEADERS_LARGE)
	$(CC) $(CFLAGS) -o $@ train_large.m $(LDFLAGS) -framework Accelerate

# Small standalone probe binaries for ANE behavior experiments.
PROBES = test_weight_reload test_perf_stats test_qos_sweep test_ane_advanced

test_weight_reload: test_weight_reload.m
	$(CC) $(CFLAGS) -o $@ $< $(LDFLAGS)

test_perf_stats: test_perf_stats.m
	$(CC) $(CFLAGS) -o $@ $< $(LDFLAGS)

test_qos_sweep: test_qos_sweep.m
	$(CC) $(CFLAGS) -o $@ $< $(LDFLAGS)

test_ane_advanced: test_ane_advanced.m
	$(CC) $(CFLAGS) -o $@ $< $(LDFLAGS)

# Convenience target: build every probe.
probes: $(PROBES)

tokenize:
	python3 tokenize.py

clean:
	rm -f train train_large $(PROBES)

.PHONY: clean tokenize probes
# Toolchain: clang resolved through xcrun (Xcode command-line tools).
CC = xcrun clang
# -fobjc-arc: the Objective-C trainers rely on ARC memory management.
CFLAGS = -O2 -Wall -Wno-deprecated-declarations -fobjc-arc
FRAMEWORKS = -framework Foundation -framework CoreML -framework IOSurface
LDFLAGS = $(FRAMEWORKS) -ldl

# Headers the large-model trainer depends on.
HEADERS_LARGE = stories_config.h stories_io.h stories_mil.h stories_cpu_ops.h
# Headers shared by the pipeline trainer and its unit test.
HEADERS_PIPELINE = model_config.h pipeline.h gradient_checkpoint.h

train: train.m ane_runtime.h ane_mil_gen.h model.h forward.h backward.h
	$(CC) $(CFLAGS) -o $@ train.m $(LDFLAGS)

train_large: train_large.m $(HEADERS_LARGE)
	$(CC) $(CFLAGS) -o $@ train_large.m $(LDFLAGS) -framework Accelerate

# Pipeline trainer (simulation mode — no ANE calls).
train_pipeline: train_pipeline.m $(HEADERS_PIPELINE)
	$(CC) $(CFLAGS) -o $@ train_pipeline.m $(LDFLAGS) -framework Accelerate

# Live variant: same source built with -DANE_LIVE.
# NOTE(review): deliberately writes the `train_pipeline` binary (not $@),
# so this target always rebuilds — consider marking it .PHONY. Confirm intent.
train_pipeline_live: train_pipeline.m $(HEADERS_PIPELINE) $(HEADERS_LARGE)
	$(CC) $(CFLAGS) -DANE_LIVE -o train_pipeline train_pipeline.m $(LDFLAGS) -framework Accelerate
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice addition of the train_pipeline targets! This will make it easier to experiment with pipelining. Consider adding a brief comment in the Makefile explaining the difference between train_pipeline and train_pipeline_live.

# Small standalone probe binaries for ANE behavior experiments.
PROBES = test_weight_reload test_perf_stats test_qos_sweep test_ane_advanced

test_weight_reload: test_weight_reload.m
	$(CC) $(CFLAGS) -o $@ $< $(LDFLAGS)

test_perf_stats: test_perf_stats.m
	$(CC) $(CFLAGS) -o $@ $< $(LDFLAGS)

test_qos_sweep: test_qos_sweep.m
	$(CC) $(CFLAGS) -o $@ $< $(LDFLAGS)

test_ane_advanced: test_ane_advanced.m
	$(CC) $(CFLAGS) -o $@ $< $(LDFLAGS)

# Convenience target: build every probe.
probes: $(PROBES)

tokenize:
	python3 tokenize.py

# Pure-C unit test for the pipeline/checkpoint headers; built with plain cc
# (no Objective-C, no frameworks) so it runs anywhere.
test_pipeline_unit: test_pipeline_unit.c $(HEADERS_PIPELINE)
	cc -O2 -Wall -o $@ $< -lm

clean:
	rm -f train train_large train_pipeline test_pipeline_unit $(PROBES)

.PHONY: clean tokenize probes
165 changes: 165 additions & 0 deletions training/gradient_checkpoint.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
// gradient_checkpoint.h — Activation checkpointing for deep models
// Trades compute for memory: recompute forward activations during backward
// instead of storing all layers' activations simultaneously
#pragma once
#include <math.h>    // sqrtf — CKPT_SQRT interval
#include <stdbool.h> // bool — is_saved bitmap
#include <stdio.h>   // printf — checkpoint_print
#include <stdlib.h>  // calloc/free — bitmap lifetime
#include "model_config.h"

// ===== Checkpoint policies =====

// Which layers keep their forward activations resident. Each policy trades
// recompute cost during backward against peak activation memory.
typedef enum {
    CKPT_ALL,      // save all layers' activations (current behavior)
    CKPT_BOUNDARY, // save only group boundary activations, recompute within group
    CKPT_SQRT,     // save every √N layers (optimal memory/compute tradeoff)
    CKPT_EVERY_N,  // save every N-th layer (configurable)
    CKPT_NONE      // save nothing, recompute everything (max memory savings)
} CheckpointPolicy;

// Per-model checkpoint plan. Owns the is_saved bitmap (heap-allocated by
// checkpoint_init, released by checkpoint_free).
typedef struct {
    CheckpointPolicy policy;
    int interval;         // for CKPT_EVERY_N: save every N layers
    int n_layers;         // total layers in model
    int n_groups;         // layer groups in pipeline
    int layers_per_group; // layers per group (from pipeline plan)
    // Derived
    int n_checkpointed;   // how many layers have saved activations
    bool *is_saved;       // per-layer: true if activation is saved (not recomputed)
} CheckpointManager;

// ===== Initialization =====

// Build a checkpoint plan for the given policy.
// custom_interval: stride for CKPT_EVERY_N (pass 0 for the default of 4;
// ignored by every other policy).
// Returns a manager that owns is_saved; release with checkpoint_free().
// On allocation failure or an empty model, is_saved is NULL and
// n_checkpointed is 0 — callers should treat that as "nothing saved".
static CheckpointManager checkpoint_init(CheckpointPolicy policy, const ModelConfig *cfg,
                                         const PipelinePlan *plan, int custom_interval) {
    CheckpointManager cm = {0};
    cm.policy = policy;
    cm.n_layers = cfg->dims.n_layers;
    cm.n_groups = plan->n_groups;
    // NOTE(review): assumes equal-sized groups — group 0 stands in for all.
    cm.layers_per_group = (plan->n_groups > 0) ? plan->groups[0].n_layers : cfg->dims.n_layers;

    if (cm.n_layers <= 0) return cm; // degenerate model: nothing to checkpoint

    cm.is_saved = (bool *)calloc((size_t)cm.n_layers, sizeof(bool));
    if (!cm.is_saved) return cm; // OOM: report an empty plan instead of crashing

    const int last = cm.n_layers - 1; // every policy pins the final layer

    switch (policy) {
    case CKPT_ALL:
        // Save every layer (matches the pre-checkpointing behavior).
        for (int i = 0; i < cm.n_layers; i++) cm.is_saved[i] = true;
        break;

    case CKPT_BOUNDARY:
        // Save the first layer of each pipeline group, plus the last layer.
        for (int g = 0; g < plan->n_groups; g++) {
            int s = plan->groups[g].start_layer;
            if (s >= 0 && s < cm.n_layers) cm.is_saved[s] = true; // ignore bad plans
        }
        cm.is_saved[last] = true;
        break;

    case CKPT_SQRT: {
        // Save every ~√N-th layer — the classic memory/compute sweet spot.
        int interval = (int)sqrtf((float)cm.n_layers);
        if (interval < 1) interval = 1;
        cm.interval = interval;
        for (int i = 0; i < cm.n_layers; i += interval) cm.is_saved[i] = true;
        cm.is_saved[last] = true;
        break;
    }

    case CKPT_EVERY_N:
        // Caller-supplied stride; 0 selects the default of 4.
        cm.interval = (custom_interval > 0) ? custom_interval : 4;
        for (int i = 0; i < cm.n_layers; i += cm.interval) cm.is_saved[i] = true;
        cm.is_saved[last] = true;
        break;

    case CKPT_NONE:
        // Keep only the model input's activation; everything else recomputes.
        cm.is_saved[0] = true;
        break;

    default:
        break; // unknown policy: leave nothing saved
    }

    // Count actual saved layers — single source of truth, no fragile arithmetic
    // (interval formulas double-count when the last layer lands on a boundary).
    cm.n_checkpointed = 0;
    for (int i = 0; i < cm.n_layers; i++) {
        if (cm.is_saved[i]) cm.n_checkpointed++;
    }

    return cm;
}

// Dispose of the per-layer saved-activation bitmap; the pointer is cleared
// so a repeated free on the same manager is a harmless no-op.
static void checkpoint_free(CheckpointManager *cm) {
    bool *bitmap = cm->is_saved;
    cm->is_saved = NULL;
    free(bitmap);
}

// ===== Query functions =====

// True when this layer's activations are stored during the forward pass.
// Out-of-range indices are treated as "not saved".
static bool checkpoint_should_save(const CheckpointManager *cm, int layer_idx) {
    bool in_range = (layer_idx >= 0) && (layer_idx < cm->n_layers);
    return in_range && cm->is_saved[layer_idx];
}

// A layer whose activation was not saved must be re-run forward during backward.
static bool checkpoint_needs_recompute(const CheckpointManager *cm, int layer_idx) {
    if (checkpoint_should_save(cm, layer_idx)) {
        return false;
    }
    return true;
}

// Find the nearest saved checkpoint before this layer (for recompute starting point)
static int checkpoint_nearest_saved_before(const CheckpointManager *cm, int layer_idx) {
for (int i = layer_idx; i >= 0; i--) {
if (cm->is_saved[i]) return i;
}
return 0; // fallback to first layer
}

// Number of layers that must be re-run forward to reach layer_idx from its
// restart checkpoint (0 when layer_idx itself is saved).
static int checkpoint_recompute_depth(const CheckpointManager *cm, int layer_idx) {
    int restart = checkpoint_nearest_saved_before(cm, layer_idx);
    int depth = layer_idx - restart;
    return depth;
}

// ===== Memory estimation =====

// Bytes consumed by the saved activations alone.
static size_t checkpoint_saved_memory(const CheckpointManager *cm, const ModelDims *d) {
    size_t per_layer = layer_activation_bytes(d);
    size_t saved_layers = (size_t)cm->n_checkpointed;
    return saved_layers * per_layer;
}

// Bytes reclaimed relative to saving every layer's activations.
static size_t checkpoint_memory_saved(const CheckpointManager *cm, const ModelDims *d) {
    size_t full_cost = (size_t)cm->n_layers * layer_activation_bytes(d);
    return full_cost - checkpoint_saved_memory(cm, d);
}

// Fraction of extra forward FLOPs incurred by recompute, in [0.0, 1.0]:
// 0.0 means everything is saved (CKPT_ALL); values near 1.0 mean almost
// every layer is re-run during backward.
static double checkpoint_recompute_overhead(const CheckpointManager *cm) {
    if (cm->n_layers <= 0) return 0.0; // avoid 0.0/0.0 → NaN on a degenerate model
    int recomputed = cm->n_layers - cm->n_checkpointed;
    return (double)recomputed / (double)cm->n_layers;
}

// ===== Pretty-print =====

// Human-readable tag for a policy value (used by checkpoint_print and logs).
static const char *checkpoint_policy_name(CheckpointPolicy p) {
    if (p == CKPT_ALL)      return "ALL";
    if (p == CKPT_BOUNDARY) return "BOUNDARY";
    if (p == CKPT_SQRT)     return "SQRT";
    if (p == CKPT_EVERY_N)  return "EVERY_N";
    if (p == CKPT_NONE)     return "NONE";
    return "UNKNOWN";
}

// Dump the policy, memory tradeoff, and saved-layer map to stdout.
static void checkpoint_print(const CheckpointManager *cm, const ModelDims *d) {
    // Hoist the repeated size computations into locals.
    double all_bytes = (double)cm->n_layers * layer_activation_bytes(d);
    double saved_bytes = (double)checkpoint_saved_memory(cm, d);
    double freed_bytes = (double)checkpoint_memory_saved(cm, d);

    printf("=== Checkpoint Policy: %s ===\n", checkpoint_policy_name(cm->policy));
    printf(" %d/%d layers checkpointed", cm->n_checkpointed, cm->n_layers);
    bool has_interval = (cm->policy == CKPT_SQRT) || (cm->policy == CKPT_EVERY_N);
    if (has_interval) {
        printf(" (interval=%d)", cm->interval);
    }
    printf("\n");
    printf(" Activation memory: %.1fMB (saved) / %.1fMB (all)\n",
           saved_bytes / 1e6, all_bytes / 1e6);
    printf(" Memory savings: %.1fMB (%.0f%%)\n",
           freed_bytes / 1e6, 100.0 * freed_bytes / all_bytes);
    printf(" Recompute overhead: %.0f%% extra forward FLOPs\n",
           100.0 * checkpoint_recompute_overhead(cm));
    printf(" Saved layers: ");
    int layer = 0;
    while (layer < cm->n_layers) {
        if (cm->is_saved[layer]) printf("%d ", layer);
        layer++;
    }
    printf("\n");
}
Loading