From 752a3be81a4b152e6ca505f119f3a881a14a34ce Mon Sep 17 00:00:00 2001 From: Claude Date: Tue, 3 Mar 2026 00:54:46 +0000 Subject: [PATCH 01/21] Add Project Scope & Intent notice to README Weave in scope notice near the top covering project intent, what it is/isn't, hype clarification, maintenance expectations, and fork encouragement. Consolidate private API disclaimer with existing disclaimer section to avoid duplication. https://claude.ai/code/session_01NNL4MVEY1aKp19eGHTYJUv --- README.md | 56 ++++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 55 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index d2c7bb2..ce3df1f 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,56 @@ Training neural networks directly on Apple's Neural Engine (ANE) via reverse-engineered private APIs. No CoreML training APIs, no Metal, no GPU — pure ANE compute. +## Project Scope & Intent + +I'm genuinely grateful for all the attention this project has received — I never expected a weekend research hack to blow up like this. Thank you to everyone who starred, forked, ran benchmarks on their own hardware, and shared the work. It means a lot. + +That said, I want to set clear expectations about what this project is and isn't. + +This is a **research project**, not a production framework. + +The goal was to demonstrate that **training on the Apple Neural Engine — and potentially other NPUs — is possible**, and that the barrier has always been software support, not hardware capability. The ANE is a remarkably capable piece of silicon that Apple restricts to inference-only use through CoreML. This project bypasses that restriction using reverse-engineered private APIs to show what's possible when you give the hardware a chance. 
+ +### What this project is + +- A proof of concept for ANE training via `_ANEClient` and `_ANECompiler` private APIs +- A set of benchmarks documenting real ANE performance characteristics (throughput, power, SRAM behavior) +- A reference for anyone exploring direct ANE access outside CoreML +- Research code that I update when I find something interesting + +### What this project is not + +- A maintained framework or library +- A replacement for CoreML, MLX, llama.cpp, or any production inference stack +- A path to training large models on consumer hardware (yet) + +### On the hype + +Some coverage of this project has overstated its implications. To be clear: + +- Training works, but utilization is low (~2-3% of peak) with significant engineering challenges remaining +- Many element-wise operations still fall back to CPU +- This does **not** replace GPU training for anything beyond small research models today + +The honest results — including all limitations — are documented in the accompanying articles: +- [Part 1: Reverse Engineering](https://maderix.substack.com/p/inside-the-m4-apple-neural-engine) +- [Part 2: Benchmarks](https://maderix.substack.com/p/inside-the-m4-apple-neural-engine-615) + +### On maintenance + +I don't intend to grow this into a large community project. My focus is on original research (compiler infrastructure for edge AI optimization), and maintaining an open-source framework takes time away from that. + +That said: +- I'll keep pushing updates when I discover something interesting +- Bug fixes and benchmark contributions (especially on hardware I don't own) are welcome +- Feature requests will likely go unaddressed — but feel free to fork + +### Fork it, build on it + +This is MIT licensed for a reason. Everyone now has access to AI-assisted development tools that can adapt and extend code in hours. If this project is useful to you — take it, modify it, build something better. If you do something cool with it, I'd love to hear about it. 
+ +--- + ## What This Is A from-scratch implementation of transformer training (forward + backward pass) running on the ANE in Apple Silicon. The ANE is a 15.8 TFLOPS (M4) inference accelerator that Apple does not expose for training. This project reverse-engineers the `_ANEClient` / `_ANECompiler` private APIs and the MIL (Model Intermediate Language) format to run custom compute graphs — including backpropagation — directly on ANE hardware. @@ -104,8 +154,12 @@ No external dependencies. Uses only system frameworks + private ANE APIs resolve ## Disclaimer -This project is independent research into Apple Neural Engine architecture. It uses undocumented APIs discovered through runtime introspection for research and educational purposes under fair use and interoperability provisions (see *Sega v. Accolade*, 1992; DMCA §1201(f)). No Apple proprietary code or binaries are included in this repository. This project is not affiliated with or endorsed by Apple Inc. Use at your own risk. +This project uses Apple's private, undocumented APIs (`_ANEClient`, `_ANECompiler`, `_ANEInMemoryModelDescriptor`). These APIs are not covered by any public stability guarantee and may change or break with any macOS update. This is independent research into Apple Neural Engine architecture, using APIs discovered through runtime introspection for research and educational purposes under fair use and interoperability provisions (see *Sega v. Accolade*, 1992; DMCA §1201(f)). No Apple proprietary code or binaries are included in this repository. This project is not affiliated with or endorsed by Apple Inc. Use at your own risk. 
## License MIT — see [LICENSE](LICENSE) + +--- + +*Built by a human + Claude, one weekend at a time.* From 2b3b7ae5ccf072774b9b8f5a2036b89fed75aa39 Mon Sep 17 00:00:00 2001 From: tastyheadphones Date: Tue, 3 Mar 2026 11:42:42 +0900 Subject: [PATCH 02/21] Fix token sampling underflow on short datasets --- training/train_large.m | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/training/train_large.m b/training/train_large.m index e58ce08..e33f2eb 100644 --- a/training/train_large.m +++ b/training/train_large.m @@ -274,11 +274,17 @@ int main(int argc, char *argv[]) { int data_fd = open(DATA_PATH, O_RDONLY); if (data_fd < 0) { printf("Cannot open %s\n", DATA_PATH); return 1; } struct stat st; fstat(data_fd, &st); - size_t data_len = st.st_size; - uint16_t *token_data = (uint16_t*)mmap(NULL, data_len, PROT_READ, MAP_PRIVATE, data_fd, 0); - if (token_data == MAP_FAILED) { printf("mmap failed\n"); return 1; } - size_t n_tokens = data_len / 2; - printf("Token data: %zu tokens (%.1f MB)\n", n_tokens, data_len/1e6); + size_t data_len = st.st_size; + uint16_t *token_data = (uint16_t*)mmap(NULL, data_len, PROT_READ, MAP_PRIVATE, data_fd, 0); + if (token_data == MAP_FAILED) { printf("mmap failed\n"); return 1; } + size_t n_tokens = data_len / 2; + if (n_tokens <= (size_t)(SEQ + 1)) { + printf("Token data too short: need at least %d tokens, got %zu\n", SEQ + 2, n_tokens); + munmap(token_data, data_len); + close(data_fd); + return 1; + } + printf("Token data: %zu tokens (%.1f MB)\n", n_tokens, data_len/1e6); // Gradient buffers shared across layers (reused each step) float *dy = (float*)malloc(SEQ*DIM*4); // gradient flowing backward From ebac5dd73f9f3f9f59df6dc3a735f3ba1171c095 Mon Sep 17 00:00:00 2001 From: Vipul Date: Tue, 3 Mar 2026 02:04:36 -0500 Subject: [PATCH 03/21] Python Bridge+Memory leak fix+More functions --- bridge/Makefile | 17 + bridge/ane_bridge.h | 87 +++++ bridge/ane_bridge.m | 328 +++++++++++++++++ bridge/libane_bridge.dylib 
| Bin 0 -> 54480 bytes training/Makefile | 14 +- training/README.md | 64 +++- training/ane_classifier.h | 102 ++++++ training/ane_rmsnorm_bwd.h | 78 ++++ training/download_data.sh | 91 +++++ training/test_classifier.m | 255 +++++++++++++ training/test_rmsnorm_bwd.m | 123 +++++++ training/train_large_ane.m | 695 ++++++++++++++++++++++++++++++++++++ 12 files changed, 1847 insertions(+), 7 deletions(-) create mode 100644 bridge/Makefile create mode 100644 bridge/ane_bridge.h create mode 100644 bridge/ane_bridge.m create mode 100755 bridge/libane_bridge.dylib create mode 100644 training/ane_classifier.h create mode 100644 training/ane_rmsnorm_bwd.h create mode 100755 training/download_data.sh create mode 100644 training/test_classifier.m create mode 100644 training/test_rmsnorm_bwd.m create mode 100644 training/train_large_ane.m diff --git a/bridge/Makefile b/bridge/Makefile new file mode 100644 index 0000000..753d749 --- /dev/null +++ b/bridge/Makefile @@ -0,0 +1,17 @@ +CC = xcrun clang +CFLAGS = -O2 -Wall -Wno-deprecated-declarations -fobjc-arc -fPIC +FRAMEWORKS = -framework Foundation -framework IOSurface -ldl +TARGET = libane_bridge.dylib + +all: $(TARGET) + +$(TARGET): ane_bridge.m ane_bridge.h + $(CC) $(CFLAGS) -dynamiclib -o $@ ane_bridge.m $(FRAMEWORKS) + +test: test_bridge.m ane_bridge.h $(TARGET) + $(CC) $(CFLAGS) -o test_bridge test_bridge.m -L. 
-lane_bridge $(FRAMEWORKS) + +clean: + rm -f $(TARGET) test_bridge + +.PHONY: all clean test diff --git a/bridge/ane_bridge.h b/bridge/ane_bridge.h new file mode 100644 index 0000000..3e8ff47 --- /dev/null +++ b/bridge/ane_bridge.h @@ -0,0 +1,87 @@ +// ane_bridge.h — C-callable bridge to ANE private APIs for Python ctypes +// Wraps _ANEInMemoryModel via private AppleNeuralEngine.framework + +#ifndef ANE_BRIDGE_H +#define ANE_BRIDGE_H + +#include +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +// Opaque kernel handle +typedef struct ANEKernelHandle ANEKernelHandle; + +// Initialize ANE runtime (load private framework, resolve classes) +// Returns 0 on success, -1 on failure +int ane_bridge_init(void); + +// Compile a MIL program with weight blobs into an ANE kernel +// mil_text: UTF-8 MIL program text +// mil_len: length of MIL text +// weight_data: raw weight blob (can be NULL) +// weight_len: length of weight blob +// n_inputs: number of input tensors +// input_sizes: array of byte sizes for each input +// n_outputs: number of output tensors +// output_sizes: array of byte sizes for each output +// Returns kernel handle or NULL on failure +ANEKernelHandle *ane_bridge_compile(const char *mil_text, size_t mil_len, + const uint8_t *weight_data, size_t weight_len, + int n_inputs, const size_t *input_sizes, + int n_outputs, const size_t *output_sizes); + +// Compile with multiple named weight files (for transformer kernels) +// weight_names: array of weight file paths (e.g. 
"@model_path/weights/wq.bin") +// weight_datas: array of weight data pointers +// weight_lens: array of weight data lengths +// n_weights: number of weight files +ANEKernelHandle *ane_bridge_compile_multi_weights( + const char *mil_text, size_t mil_len, + const char **weight_names, const uint8_t **weight_datas, + const size_t *weight_lens, int n_weights, + int n_inputs, const size_t *input_sizes, + int n_outputs, const size_t *output_sizes); + +// Evaluate (run) a compiled kernel on ANE +// Returns true on success +bool ane_bridge_eval(ANEKernelHandle *kernel); + +// Write data to kernel input tensor +void ane_bridge_write_input(ANEKernelHandle *kernel, int idx, + const void *data, size_t bytes); + +// Read data from kernel output tensor +void ane_bridge_read_output(ANEKernelHandle *kernel, int idx, + void *data, size_t bytes); + +// Free a compiled kernel and all associated resources +void ane_bridge_free(ANEKernelHandle *kernel); + +// Get compile count (for exec() restart budgeting) +int ane_bridge_get_compile_count(void); + +// Reset compile count +void ane_bridge_reset_compile_count(void); + +// Build a weight blob in ANE format (128-byte header + fp16 data) +// src: float32 weights [rows x cols] +// Returns allocated buffer and sets out_len. Caller must free(). 
+uint8_t *ane_bridge_build_weight_blob(const float *src, int rows, int cols, + size_t *out_len); + +// Build a transposed weight blob in ANE format +uint8_t *ane_bridge_build_weight_blob_transposed(const float *src, int rows, int cols, + size_t *out_len); + +// Free a blob allocated by ane_bridge_build_weight_blob* +void ane_bridge_free_blob(void *ptr); + +#ifdef __cplusplus +} +#endif + +#endif // ANE_BRIDGE_H diff --git a/bridge/ane_bridge.m b/bridge/ane_bridge.m new file mode 100644 index 0000000..2b27ddc --- /dev/null +++ b/bridge/ane_bridge.m @@ -0,0 +1,328 @@ +// ane_bridge.m — Objective-C implementation of ANE bridge for Python ctypes +// Wraps _ANEInMemoryModel private APIs into C-callable functions + +#import +#import +#import +#import +#import +#include "ane_bridge.h" + +// --- Private class references --- +static Class g_ANEDesc = nil; +static Class g_ANEInMem = nil; +static Class g_ANEReq = nil; +static Class g_ANEIO = nil; +static bool g_initialized = false; +static int g_compile_count = 0; + +// --- Kernel handle struct --- +struct ANEKernelHandle { + id model; // _ANEInMemoryModel + IOSurfaceRef *ioInputs; + IOSurfaceRef *ioOutputs; + id request; // _ANERequest + NSString *tmpDir; + int nInputs, nOutputs; + size_t *inputBytes; + size_t *outputBytes; +}; + +// --- Public API --- + +int ane_bridge_init(void) { + if (g_initialized) return 0; + + void *handle = dlopen( + "/System/Library/PrivateFrameworks/AppleNeuralEngine.framework/AppleNeuralEngine", + RTLD_NOW); + if (!handle) { + fprintf(stderr, "ane_bridge: Failed to load AppleNeuralEngine.framework\n"); + return -1; + } + + g_ANEDesc = NSClassFromString(@"_ANEInMemoryModelDescriptor"); + g_ANEInMem = NSClassFromString(@"_ANEInMemoryModel"); + g_ANEReq = NSClassFromString(@"_ANERequest"); + g_ANEIO = NSClassFromString(@"_ANEIOSurfaceObject"); + + if (!g_ANEDesc || !g_ANEInMem || !g_ANEReq || !g_ANEIO) { + fprintf(stderr, "ane_bridge: Failed to resolve ANE private classes\n"); + return -1; + } + + 
g_initialized = true; + g_compile_count = 0; + return 0; +} + +static IOSurfaceRef create_surface(size_t bytes) { + return IOSurfaceCreate((__bridge CFDictionaryRef)@{ + (id)kIOSurfaceWidth: @(bytes), + (id)kIOSurfaceHeight: @1, + (id)kIOSurfaceBytesPerElement: @1, + (id)kIOSurfaceBytesPerRow: @(bytes), + (id)kIOSurfaceAllocSize: @(bytes), + (id)kIOSurfacePixelFormat: @0 + }); +} + +ANEKernelHandle *ane_bridge_compile_multi_weights( + const char *mil_text, size_t mil_len, + const char **weight_names, const uint8_t **weight_datas, + const size_t *weight_lens, int n_weights, + int n_inputs, const size_t *input_sizes, + int n_outputs, const size_t *output_sizes) +{ + @autoreleasepool { + if (!g_initialized) { + fprintf(stderr, "ane_bridge: Not initialized\n"); + return NULL; + } + + NSData *milData = [NSData dataWithBytes:mil_text length:mil_len]; + NSError *e = nil; + + // Build weight dictionary + NSMutableDictionary *wdict = [NSMutableDictionary dictionary]; + for (int i = 0; i < n_weights; i++) { + NSString *name = [NSString stringWithUTF8String:weight_names[i]]; + NSData *data = [NSData dataWithBytes:weight_datas[i] length:weight_lens[i]]; + wdict[name] = @{@"offset": @0, @"data": data}; + } + + id desc = ((id(*)(Class,SEL,id,id,id))objc_msgSend)( + g_ANEDesc, @selector(modelWithMILText:weights:optionsPlist:), + milData, wdict.count > 0 ? 
wdict : nil, nil); + if (!desc) { + fprintf(stderr, "ane_bridge: modelWithMILText failed\n"); + return NULL; + } + + id mdl = ((id(*)(Class,SEL,id))objc_msgSend)( + g_ANEInMem, @selector(inMemoryModelWithDescriptor:), desc); + if (!mdl) { + fprintf(stderr, "ane_bridge: inMemoryModelWithDescriptor failed\n"); + return NULL; + } + + // Pre-populate temp dir + id hx = ((id(*)(id,SEL))objc_msgSend)(mdl, @selector(hexStringIdentifier)); + NSString *td = [NSTemporaryDirectory() stringByAppendingPathComponent:hx]; + NSFileManager *fm = [NSFileManager defaultManager]; + [fm createDirectoryAtPath:[td stringByAppendingPathComponent:@"weights"] + withIntermediateDirectories:YES attributes:nil error:nil]; + [milData writeToFile:[td stringByAppendingPathComponent:@"model.mil"] atomically:YES]; + + for (int i = 0; i < n_weights; i++) { + NSString *name = [NSString stringWithUTF8String:weight_names[i]]; + // Extract filename from path like "@model_path/weights/wq.bin" -> "weights/wq.bin" + NSString *relPath = name; + if ([name hasPrefix:@"@model_path/"]) { + relPath = [name substringFromIndex:12]; + } + NSString *fullPath = [td stringByAppendingPathComponent:relPath]; + NSString *dir = [fullPath stringByDeletingLastPathComponent]; + [fm createDirectoryAtPath:dir withIntermediateDirectories:YES attributes:nil error:nil]; + NSData *data = [NSData dataWithBytes:weight_datas[i] length:weight_lens[i]]; + [data writeToFile:fullPath atomically:YES]; + } + + // Compile + if (!((BOOL(*)(id,SEL,unsigned int,id,NSError**))objc_msgSend)( + mdl, @selector(compileWithQoS:options:error:), 21, @{}, &e)) { + fprintf(stderr, "ane_bridge: ANE compile failed: %s\n", + e ? 
[[e description] UTF8String] : "unknown"); + [fm removeItemAtPath:td error:nil]; + return NULL; + } + + // Load (with one retry after a brief pause for ANE slot reclamation) + BOOL loaded = ((BOOL(*)(id,SEL,unsigned int,id,NSError**))objc_msgSend)( + mdl, @selector(loadWithQoS:options:error:), 21, @{}, &e); + if (!loaded) { + fprintf(stderr, "ane_bridge: ANE load failed (retrying in 100ms): %s\n", + e ? [[e description] UTF8String] : "unknown"); + usleep(100000); // 100ms + e = nil; + loaded = ((BOOL(*)(id,SEL,unsigned int,id,NSError**))objc_msgSend)( + mdl, @selector(loadWithQoS:options:error:), 21, @{}, &e); + } + if (!loaded) { + fprintf(stderr, "ane_bridge: ANE load failed after retry: %s\n", + e ? [[e description] UTF8String] : "unknown"); + [fm removeItemAtPath:td error:nil]; + return NULL; + } + + g_compile_count++; + + // Create kernel handle + ANEKernelHandle *k = (ANEKernelHandle *)calloc(1, sizeof(ANEKernelHandle)); + k->model = mdl; + k->tmpDir = td; + k->nInputs = n_inputs; + k->nOutputs = n_outputs; + k->inputBytes = (size_t *)malloc(n_inputs * sizeof(size_t)); + k->outputBytes = (size_t *)malloc(n_outputs * sizeof(size_t)); + memcpy(k->inputBytes, input_sizes, n_inputs * sizeof(size_t)); + memcpy(k->outputBytes, output_sizes, n_outputs * sizeof(size_t)); + + // Create IOSurfaces + k->ioInputs = (IOSurfaceRef *)malloc(n_inputs * sizeof(IOSurfaceRef)); + k->ioOutputs = (IOSurfaceRef *)malloc(n_outputs * sizeof(IOSurfaceRef)); + for (int i = 0; i < n_inputs; i++) + k->ioInputs[i] = create_surface(input_sizes[i]); + for (int i = 0; i < n_outputs; i++) + k->ioOutputs[i] = create_surface(output_sizes[i]); + + // Build request + NSMutableArray *wIns = [NSMutableArray arrayWithCapacity:n_inputs]; + NSMutableArray *iIdx = [NSMutableArray arrayWithCapacity:n_inputs]; + for (int i = 0; i < n_inputs; i++) { + [wIns addObject:((id(*)(Class,SEL,IOSurfaceRef))objc_msgSend)( + g_ANEIO, @selector(objectWithIOSurface:), k->ioInputs[i])]; + [iIdx addObject:@(i)]; + } + 
NSMutableArray *wOuts = [NSMutableArray arrayWithCapacity:n_outputs]; + NSMutableArray *oIdx = [NSMutableArray arrayWithCapacity:n_outputs]; + for (int i = 0; i < n_outputs; i++) { + [wOuts addObject:((id(*)(Class,SEL,IOSurfaceRef))objc_msgSend)( + g_ANEIO, @selector(objectWithIOSurface:), k->ioOutputs[i])]; + [oIdx addObject:@(i)]; + } + k->request = ((id(*)(Class,SEL,id,id,id,id,id,id,id))objc_msgSend)( + g_ANEReq, + @selector(requestWithInputs:inputIndices:outputs:outputIndices:weightsBuffer:perfStats:procedureIndex:), + wIns, iIdx, wOuts, oIdx, nil, nil, @0); + + return k; + } +} + +ANEKernelHandle *ane_bridge_compile(const char *mil_text, size_t mil_len, + const uint8_t *weight_data, size_t weight_len, + int n_inputs, const size_t *input_sizes, + int n_outputs, const size_t *output_sizes) { + if (weight_data && weight_len > 0) { + const char *name = "@model_path/weights/weight.bin"; + return ane_bridge_compile_multi_weights( + mil_text, mil_len, + &name, &weight_data, &weight_len, 1, + n_inputs, input_sizes, + n_outputs, output_sizes); + } else { + return ane_bridge_compile_multi_weights( + mil_text, mil_len, + NULL, NULL, NULL, 0, + n_inputs, input_sizes, + n_outputs, output_sizes); + } +} + +bool ane_bridge_eval(ANEKernelHandle *kernel) { + @autoreleasepool { + if (!kernel || !kernel->model) return false; + NSError *e = nil; + return ((BOOL(*)(id,SEL,unsigned int,id,id,NSError**))objc_msgSend)( + kernel->model, @selector(evaluateWithQoS:options:request:error:), + 21, @{}, kernel->request, &e); + } +} + +void ane_bridge_write_input(ANEKernelHandle *kernel, int idx, + const void *data, size_t bytes) { + if (!kernel || idx < 0 || idx >= kernel->nInputs) return; + IOSurfaceLock(kernel->ioInputs[idx], 0, NULL); + memcpy(IOSurfaceGetBaseAddress(kernel->ioInputs[idx]), data, bytes); + IOSurfaceUnlock(kernel->ioInputs[idx], 0, NULL); +} + +void ane_bridge_read_output(ANEKernelHandle *kernel, int idx, + void *data, size_t bytes) { + if (!kernel || idx < 0 || idx >= 
kernel->nOutputs) return; + IOSurfaceLock(kernel->ioOutputs[idx], kIOSurfaceLockReadOnly, NULL); + memcpy(data, IOSurfaceGetBaseAddress(kernel->ioOutputs[idx]), bytes); + IOSurfaceUnlock(kernel->ioOutputs[idx], kIOSurfaceLockReadOnly, NULL); +} + +void ane_bridge_free(ANEKernelHandle *kernel) { + @autoreleasepool { + if (!kernel) return; + NSError *e = nil; + if (kernel->model) { + ((BOOL(*)(id,SEL,unsigned int,NSError**))objc_msgSend)( + kernel->model, @selector(unloadWithQoS:error:), 21, &e); + } + for (int i = 0; i < kernel->nInputs; i++) + if (kernel->ioInputs[i]) CFRelease(kernel->ioInputs[i]); + for (int i = 0; i < kernel->nOutputs; i++) + if (kernel->ioOutputs[i]) CFRelease(kernel->ioOutputs[i]); + if (kernel->tmpDir) { + [[NSFileManager defaultManager] removeItemAtPath:kernel->tmpDir error:nil]; + } + free(kernel->ioInputs); + free(kernel->ioOutputs); + free(kernel->inputBytes); + free(kernel->outputBytes); + + // Explicitly nil Objective-C objects to trigger ARC release before freeing struct + kernel->model = nil; + kernel->request = nil; + kernel->tmpDir = nil; + + free(kernel); + } +} + +int ane_bridge_get_compile_count(void) { + return g_compile_count; +} + +void ane_bridge_reset_compile_count(void) { + g_compile_count = 0; +} + +uint8_t *ane_bridge_build_weight_blob(const float *src, int rows, int cols, + size_t *out_len) { + int wsize = rows * cols * 2; // fp16 + int total = 128 + wsize; + uint8_t *buf = (uint8_t *)calloc(total, 1); + + // ANE blob header + buf[0] = 0x01; buf[4] = 0x02; + buf[64] = 0xEF; buf[65] = 0xBE; buf[66] = 0xAD; buf[67] = 0xDE; + buf[68] = 0x01; + *(uint32_t*)(buf + 72) = wsize; + *(uint32_t*)(buf + 80) = 128; + + // Convert float32 -> float16 + _Float16 *fp16 = (_Float16 *)(buf + 128); + for (int i = 0; i < rows * cols; i++) { + fp16[i] = (_Float16)src[i]; + } + + *out_len = total; + return buf; +} + +uint8_t *ane_bridge_build_weight_blob_transposed(const float *src, int rows, int cols, + size_t *out_len) { + int wsize = rows 
* cols * 2; + int total = 128 + wsize; + uint8_t *buf = (uint8_t *)calloc(total, 1); + + buf[0] = 0x01; buf[4] = 0x02; + buf[64] = 0xEF; buf[65] = 0xBE; buf[66] = 0xAD; buf[67] = 0xDE; + buf[68] = 0x01; + *(uint32_t*)(buf + 72) = wsize; + *(uint32_t*)(buf + 80) = 128; + + _Float16 *fp16 = (_Float16 *)(buf + 128); + for (int i = 0; i < rows; i++) + for (int j = 0; j < cols; j++) + fp16[j * rows + i] = (_Float16)src[i * cols + j]; + + *out_len = total; + return buf; +} diff --git a/bridge/libane_bridge.dylib b/bridge/libane_bridge.dylib new file mode 100755 index 0000000000000000000000000000000000000000..72acc32e285cd688d74b74f5735e8b6cfcae02e6 GIT binary patch literal 54480 zcmeHw4RjR8m2S+rVKtc$F&0rt~OaTAH2E{QwlK2t7iX_1CIv!dxEve0DM)dRm z0t~~CW5+mN#`Y$Dc1|E*zeT;PEKW!^tnD?ioh%zCYvW{{-F>kUHft-vFToMUGyGWZ zyWKUTRtwO}%Xw#a&sHDq+`4sd-MaPN>guX#&EfTP|NGNnjJX)Dd{8MUUBK7_teC3U z7|;ecW5HnKyt^9Bs(~NUK#_4d57k6)WnRIcu58e`0Os_!&}Zq|ODs+@y6wNbEZHwm z`?#KHis!1F{(e|tsY=a}F-Ef8pO5>)*v%^#j5XaC>UZk&*FDa%OL8Y9)A(}Np!h1gqtzykS3l;^J+IC9N{`WnRgFap7LG79_3?2Aksj zcmH62J!ntYbc>mv_Hky@omJB~7lMMpCAG^I&#S4mW*cKixP`r$EKW1z>m0!bp zj79Y-l5FSt-Ha+f z$Qt7{2xzQXHnDX5lXn_)_QcS7x*9~jtlX3*40>aMsv1y=1N72qb5GH+NOuk!TzlbT)}_s@`i-gEQgd!xev_Dso~x7 zKO4M1{w0`tE;cg}*JdJVMAlGdRN2yK#x=0rNq8C9j-D_x0`^u`f2{-hBfP%Ot}veJ zn5JAAv%t)OHtEQIYKhcE4r=)k-Ym&#-!{`ON#%?>+Gr*JI^e4bza=$ z?!35zr5fO>bzk)yjO{i|>H6EPt;IpM_Y0{20)acXA7_WBI9Ls9S-$z1y>P`Kh0w4)^YEY+5o$H%ggSH#_vu7AZ3Y{J?f8 z)8k@?j({kRDWKMUo;hSceuV4r2*fo7W#2q0<8etD4EG@GzB6?qKUt64%&B{WrCBNf z-SzPE3i^f5kry{rVeEmu?w-QJI}QY?@R};L47gwBY@3FVB<|? 
zJO#!lo_yoHC(k(N@fy7zk8#H1Ha_;ajI$nQnDGRpj1Tb~@Qh7=AN{0JPci$x4NS+e z+j>tGbY_Cy&h$){sYJc>=ThcfPnq$aXOuzMUi4gR`~bZB7M3A@?0-m^i@g3I_yzbo z-@p!UEnsUG=d;G12I=s}9^bT2Jf%kZ_X|4BHY4EkD^mLQJj>rcw4?Z_FZY8|Is|%y z`+OQa!(%UW6{O5DY%550PMmG}@_oB7b-**n@|%(}`(AM!+wXB7s{xM3xsQjzdp%=; z2{XDqEf2IBdh$CB+-#rpr9S>V z*^V-v<36TyAKkzt+y1A#-8gWX8{_lX+Vo5*^I9QmynuOv=W1#9=Lwq%V-NQ;7Myao zpUY+7gjM(LWH)9&T$RvuKZ<*EF6Pn$Qo3g@%k)&^9;9+L<}~+_2Ts_~T%d71L+ubd z#tc3N_YLu}#ve>xVVpL3nejVt3^?`80``&L%Nk8gF`u`)rOXcQn`}M+5q=i7g(<@N zJn-^o`ferMw!4z`Te? z#D8uv%VgW0taf8AjKQ2PF{tf#U~kT&#Zu-ulz|Om9c%m*)x*ab;ClWBDT96C;B~Ft zo!aKL;{F=!DQ?R5XA%3LF?v0z6M2^HcG!-EZPxx;n|%rH!Ly^K%u`tBDUN@JEzS4c z;1t^~@Fzjt9+tsV{EZ2c7vm{0TFkmq<9=`)vr6&o*!{_T_(JX|ezb+ZQsYO^lfTX2 zH174_J3%>rr^#QL(O}k%GE{KvT&d23;l%(_3iUL; z5n?*)Mn3``KRYZph?g33%TETh1m9%vk_>%6InyDN}mI>IVBmSvta>&lH%)X_bVyv|w|G83f?t&=qVV(19N6Q0qK%Z*{C zTw)Yt+vHQ5a>E0e#&HJsak7YI&Uq@P(X-_$7i&C&+M3#}PWJAIa-K zMm;_2=Ad0Eh+=xp!x|}{nAfMvM@bpPm3evNrt}Gn<89cJzYo9(Cu>c_dIOo(oqqvO z4yQgU49XqZ{yJw%&n!HDx$V!vHx;l{&{J{h#D%`b-ig2p@%19Fgzayjr}jSsC;UlA z_3!cepMcYP^U3Qqom=y*xXXFmV|m=~Zrp&l#~Odb+Z+bB;x02|ISQHLHo&d8%Z(Bq z_jhu(6{hWr+;$ImIpW^xE;r*QoY*$(r@&ofurC3p=Q-U!y_hRz%y`zGM2`O*xa1?o z0?@oul7ak+3K)&TaG>4vd#;!rJdN*rD$*0i#3SWv`XwPX-(S; ze>BEF1E+g&C3pzL*Pw$FUd=Ptw42SkQez?bC)j^{@M+&@;CRrs4&`ib;5P5TrUK=@ z9qwaiR}~;GANG)iofon0Q(tBtFa~pFF#&n?+Nl%p%~@bEB*CH#aaV-=F1vv3JS({3`c34&0n`H*(v_ z+;()o?G|o30-WZ$+5hLbT?KE4In}*Cuak5vg?aTJOT9OdrC!B)N4WmPm0j=3jI$`y zygp;{65|ix2X=gUP+m_k%4;d|S_}LZBCiWNfv1_vSJ9@=<2rV6lc)3ID=hUJ?8kes zAK!n~+|KPdFN8qA{M2u;MtuNG-}kV??}k`v|3r4|nQC@yGGe21LRRO-yr*Y| z7li#+{ri~fA9xCl^B4NQadA_1C;Bt{?t6tDv!9nJC(k3E2F$Bg%rQE%;EZ;9!nIN+ zj2!l%Z=AVLhp=b(cVJ?z`#$rze-Q1cJ&n_iSW3|Ev#6uBD|Dsab(x*QbV@s zsGT#G&+}Ml4xC-(!u-J5mCpsja~0bEqL3Y;`MMf8S&f`bK@MCKtbEYiqq1eakoyh_RYKS3>*a+U+4w`?(rd(#^fzy0H zv%CuTTp%~^2|tFt49x*5(;1w~>EH{nbLPN`=I>#K}-Y?^Ou+pQ=n;t7X;!YWESLzh*-+*gNS9Rgmu6ZM*t_5QPT@zX-40p9o zShgBm0^bL|9sE7;SHLH1yR+;1ug&e6>T;zPx!kEW+t^1ZF7`E!ca@#;W1b%`+VTf{ 
zAJIs@k7`>U>nyrzSCRLsU8b$e_nJTNM|Uvt{Q$eLW(0eV_DR%MLYqRbZ`V}#G>Uw? zj=FuXN%N!^*pZ3X`6si@zC3GBr6XT}7wLFh!nhQ9)kr%9ao4lPDu^wf= zfA|FVIpIBJv%8|aV|;_p;~r54yyuvBmzj8P1m5_r^pbdw0p8^{y!(LnDV!Gx=R)RA z{nk@DZO?GFw)^40##5vBj9sE^dTeRomdBQ@-SXHq!22NZPHoxpixZyVuf4lv>o4B( zjGl&kmSS9~-xsn}v3ta>?}MhIJl^Fwg_lc>d(g*Kz~`^_I!+rNmQU0#UEPvN$^73~jy;cs7>|__7Ew*xa(lgvRj&{jRU*o*Z z-lz9&3|qNVCvV>AWvOSGdsooqJr%%sUcoxB&dg&Ca#f9-RUvl)>_>iV<`Uo4oi2VD z>z;f3TV}b;R>nP?tH*tuD_?Ibk8@bEmDm$W~rvE7#k~D{bYwZRLAy7(y=%p|2T2zipM`Ci=KghWMy+;uX2A0jMjV@K%Vp_-2SXhbFDDjY{w(Bu%u+FThSJoyJyva6O zW?!W&Z@Nzj>6eK@Q{u76I>nE!{O#5_{h^2)k1O$_0Zo?0biW!^bybe28pK2c%8TSt^XTM%iTUzxvGab%oQzJ|= zi@|nTZ=E@yPl`VjYikF(ym_U6GKP?du8GFhM+e!O=v#65r)Y|MwH2I#+O)~vSp zRI9N)anymmk=(2+n%`_`8QVkSfzPz~j3zb8tQXN$9mu#64WqL&`v4+Z6JH3sIS1Oto(@2|$>MW!w zn6NdfhQwp?1a#`Gay^>VMiF+K5>{<`Rf$*1I{I!(=qPeelve7bs)H&cO0)$oSXgP6 z6A^u>9K}m;4Wcz85dsUVAqoX=3~fd7U=1c-*&NdrD;@DlGsa4$tdk=Ngy9^CW=-O3 z9IbM^4)3AW4VA1_*dX(H!E6YyP`EW=rWqNYq5|LPG|}Y6<~yWl20BBZYA7+zr8V+ z4JgeCR@G6XL=+u!aETl@r$%lbatjKre?p=iT$oMgYq2)oQ6g$bmk({drs_&#Y(8#7 zOsQC#8j>RsOj?HTj2Zn|4skz-Ex`4@71A1v@s372=(JzFCbk zkn0}2NdMt5=ALym^Blkn{VB!FD;2Z6XGXAmyi39{j(O21aW~HXUc42jvbY2S0s(=5 zKtLcM5D*9m1Ox&C0fB%(Kp-Fx5C{ka1Ofs9fq+0jARrJB2nYlO0{=-6C{EHlH2eY( ze|~LodM`$@^GX5^dd!R08CE5Wh6cQewX-I&gu88BwIh* z@8|S8N2XHfj~<4lU*%J*RZ z9uvS>c&FNbNhnS+#p2{&_^DuiP{C1ey~{Nf;zJ-H5D*9m1Ox&C0fB%(Kp-Fx5C{ka z1Ofs9fq+0jARrJB2nYlO0s;YnfIvVXAP^7;2m}NI0s(=5KtLcM5D*9m1Ox&C0fB%( zKp-Fx5C{ka1Ofs9fq+0jARrJB2nYlO0s;YnfIvVXAP^7;{5L~jsrduS_z`t0Xqbz! 
zso-F_%gamU;IneIjFHSF{4euTePE*#g=JItqFR^gGa}pkfd6i~~&v-2qzSVcvT{kAhwXy#qQ2 zD)lmNjhE#;0@?@q8OWW-^6mg7LEp_|uCsZJ1(!8cN91^Xz7}h1&^0yMLRF1QTYF5C zwT>E9Q$l)7>tMmF)k^{m4Z%q;nXg8arE*kmQM5t2Wr?eiRzb9)yLM`s~0M2OX~o`I(35*nIF^I&+-=3PNvQuc-t`8#ilH+CuHXJ`s;7 z3Yy27?h6Iw1QMb|6aZBhi$&^Ums2F-h|p}(7H?@#q6o#5c;lR;%PK(K~ted*5$#KYDiaOQOvJFs-+2CZi*-YO_MLHtg$r@Vk-8_ zvnU^0B7{^05ma?7iu6QvYvP&nyk2QmG$k5B(2fb6HLi{73MMt5!*ol#f^t*|Hfd_O zMG3Np=$`km4{yQ8LpQmCx+X{C?XkELW~CnX%aN!~PICp@5)oYut~cjXoZaYQ>1ojH zoPm$iv+!}W7$52B__%;Y#w~|~v4q~9(Ak&biXIHb+S+l0fhRC;??sy%8S^zI)JT~7 z3^qk#O>7J|WV4iYa)iw@r_rL8K^~T?QB`M+n!-Zn6bq_R>T3s$E7ZZ6-?Wp7d&|m7vdq?f`uObSG#YXg+8GXd#I1u|=T8pe3NCpk<(}AKFje0=gCC>}xLSBv1fU z1*!(sfT-`;pgEu`#X$YoNB+pjC_>i^iSnPY!7#Gzp5{n{0_djDTG;r-``|DeN#^Azx(b@!h=orKKe zTrh4O?)b@pA3Y|mXq-PgdjPbbNhZnR@baenaLTN7w2K_(H1y#gc??wW$gEJ=ivbZ3@h=Frnb|G$JxwwkQvw{r=XrKc;c1KF>SHZ5wA4I z?F^w+j@M~Qv%0~lh?<9S@&g1euUrw0t1VF_yd)NFaaw6gTWpqZ)JDU~Al%HQzcEHDBIZb}O{KFt&A}ENo>X0xBlxxfN9vYfVA;GH zCBBy3ceCT zT?BL&DRB*F?9JVNHydV&Lw%OCXDuRXQ~yJBMz{kqOPqXgy8P6iYiDjb`;BK$jotI~ z8_)gImpvoj|E(fV|4jMf>2sGS-g>UDefYC)&oA2Y{6F5jxcKPXPp{kcy}56eY<{e+ z@yTlsfA7GZ$?uiEw(h3J*n{b{Cr|Fuo?P_<(tensor x) {\n", DIM, SEQ]; + [m appendString:@CONV_CONST]; + [m appendFormat:@" tensor We = const()[name=string(\"We\"), " + "val=tensor(BLOBFILE(path=string(\"@model_path/weights/embed.bin\"), offset=uint64(64)))];\n", + VOCAB, DIM, VOCAB, DIM]; + [m appendFormat:@" tensor out = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=We,x=x)[name=string(\"cls\")];\n", VOCAB, SEQ]; + [m appendString:@" } -> (out);\n}\n"]; + return m; +} + +// ============================================================ +// Classifier backward: dx = embed^T @ dlogits +// ANE rejects conv with 32000 input channels. +// Use matmul instead: reshape dlogits to [1, VOCAB, SEQ], +// bake embed^T as [1, DIM, VOCAB], matmul → [1, DIM, SEQ], +// reshape back to [1, DIM, 1, SEQ]. 
+// ============================================================ +static NSString *gen_classifier_bwd(void) { + NSMutableString *m = [NSMutableString string]; + [m appendString:MIL_HDR]; + [m appendFormat:@" func main(tensor dl) {\n", VOCAB, SEQ]; + // Reshape dlogits from [1, VOCAB, 1, SEQ] to [1, VOCAB, SEQ] + [m appendFormat:@" tensor sh3 = const()[name=string(\"sh3\"), val=tensor([1,%d,%d])];\n", VOCAB, SEQ]; + [m appendFormat:@" tensor dl3 = reshape(shape=sh3,x=dl)[name=string(\"rdl\")];\n", VOCAB, SEQ]; + // embed_t as baked constant [1, DIM, VOCAB] + [m appendFormat:@" tensor Wet = const()[name=string(\"Wet\"), " + "val=tensor(BLOBFILE(path=string(\"@model_path/weights/embed_t.bin\"), offset=uint64(64)))];\n", + DIM, VOCAB, DIM, VOCAB]; + // matmul: [1, DIM, VOCAB] @ [1, VOCAB, SEQ] -> [1, DIM, SEQ] + [m appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"]; + [m appendFormat:@" tensor dx3 = matmul(transpose_x=bF,transpose_y=bF,x=Wet,y=dl3)[name=string(\"mm\")];\n", DIM, SEQ]; + // Reshape back to [1, DIM, 1, SEQ] + [m appendFormat:@" tensor sh4 = const()[name=string(\"sh4\"), val=tensor([1,%d,1,%d])];\n", DIM, SEQ]; + [m appendFormat:@" tensor out = reshape(shape=sh4,x=dx3)[name=string(\"out\")];\n", DIM, SEQ]; + [m appendString:@" } -> (out);\n}\n"]; + return m; +} + +// ============================================================ +// Softmax over VOCAB dimension (channel axis) for cross-entropy +// Input: logits [1, VOCAB, 1, SEQ] +// Output: probs [1, VOCAB, 1, SEQ] +// +// softmax(x, axis=1) = exp(x - max(x)) / sum(exp(x - max(x))) +// +// Note: After getting probs from ANE, the NLL loss + gradient +// (prob[target] -= 1.0) are done on CPU since they need target indexing. 
// ============================================================
// Emit MIL for a channel-axis (axis=1) softmax over VOCAB.
// NOTE(review): appendFormat varargs without matching %d below —
// stripped tensor<...> type templates, same as the other generators.
// ============================================================
static NSString *gen_softmax_vocab(void) {
    NSMutableString *m = [NSMutableString string];
    [m appendString:MIL_HDR];
    [m appendFormat:@" func main(tensor x) {\n", VOCAB, SEQ];
    // axis=1 selects the channel (VOCAB) dimension.
    [m appendString:@" int32 ax = const()[name=string(\"ax\"), val=int32(1)];\n"];
    [m appendFormat:@" tensor out = softmax(axis=ax,x=x)[name=string(\"sm\")];\n", VOCAB, SEQ];
    [m appendString:@" } -> (out);\n}\n"];
    return m;
}

// ============================================================
// Final RMSNorm on ANE (replaces CPU rmsnorm for final layer)
// Input: x [1, DIM, 1, SEQ]
// Baked: rms_final weights [DIM]
// Output: xn [1, DIM, 1, SEQ]
//
// Computes x * 1/sqrt(mean(x^2) + eps) * w entirely in MIL ops:
// square → reduce_sum over channels → scale by 1/DIM → +eps →
// pow(-0.5) → broadcast-multiply back onto x → multiply by weights.
// ============================================================
static NSString *gen_final_rmsnorm(void) {
    float invd = 1.0f/(float)DIM;  // 1/DIM, baked as an fp16 constant below
    NSMutableString *m = [NSMutableString string];
    [m appendString:MIL_HDR];
    [m appendFormat:@" func main(tensor x) {\n", DIM, SEQ];
    // sq = x * x
    [m appendFormat:@" tensor sq = mul(x=x,y=x)[name=string(\"sq\")];\n", DIM, SEQ];
    [m appendFormat:@" tensor rax = const()[name=string(\"rax\"), val=tensor([1])];\n"];
    [m appendFormat:@" bool kd = const()[name=string(\"kd\"), val=bool(true)];\n"];
    // ss = sum over channel axis, keepdims → [1,1,1,SEQ]
    [m appendFormat:@" tensor ss = reduce_sum(x=sq,axes=rax,keep_dims=kd)[name=string(\"ss\")];\n", SEQ];
    [m appendFormat:@" fp16 invd = const()[name=string(\"invd\"), val=fp16(%f)];\n", invd];
    // ss2 = mean(x^2) = ss / DIM
    [m appendFormat:@" tensor ss2 = mul(x=ss,y=invd)[name=string(\"ss2\")];\n", SEQ];
    [m appendFormat:@" fp16 eps = const()[name=string(\"eps\"), val=fp16(0.00001)];\n"];
    [m appendFormat:@" tensor ss3 = add(x=ss2,y=eps)[name=string(\"ss3\")];\n", SEQ];
    [m appendFormat:@" fp16 nhalf = const()[name=string(\"nhalf\"), val=fp16(-0.5)];\n"];
    // rrms = (mean(x^2)+eps)^(-1/2)
    [m appendFormat:@" tensor rrms = pow(x=ss3,y=nhalf)[name=string(\"rrms\")];\n", SEQ];
    // xr = x * rrms (broadcast over channels)
    [m appendFormat:@" tensor xr = mul(x=x,y=rrms)[name=string(\"xr\")];\n", DIM, SEQ];
    // rw = baked RMSNorm scale weights (offset 64 skips the blob header).
    [m appendFormat:@" tensor rw = const()[name=string(\"rw\"), val=tensor(BLOBFILE(path=string(\"@model_path/weights/rms_w.bin\"), offset=uint64(64)))];\n", DIM, DIM];
    [m appendFormat:@" tensor out = mul(x=xr,y=rw)[name=string(\"out\")];\n", DIM, SEQ];
    [m appendString:@" } -> (out);\n}\n"];
    return m;
}
diff --git a/training/ane_rmsnorm_bwd.h b/training/ane_rmsnorm_bwd.h
new file mode 100644
index 0000000..eb51896
--- /dev/null
+++ b/training/ane_rmsnorm_bwd.h
@@ -0,0 +1,78 @@
// ane_rmsnorm_bwd.h — MIL generator for RMSNorm backward on ANE
// Replaces CPU rmsnorm_bwd() from stories_cpu_ops.h
//
// RMSNorm forward: xn = x * rrms * w, where rrms = 1/sqrt(mean(x²) + eps)
// RMSNorm backward: dx = w * rrms * (dy - x * sum(dy*w*x) * invd * rrms²)
//
// Input: concat(dy, x) as [1, 2*DIM, 1, SEQ]
// Baked: RMSNorm weights w [1, DIM, 1, 1] as BLOBFILE
// Output: dx [1, DIM, 1, SEQ]
//
// Note: dw (weight gradient) stays on CPU — it requires reduce_sum over SEQ
// and accumulation across steps, which is cheap and better done on CPU.
#pragma once
#include "stories_mil.h"

// Generate MIL for RMSNorm backward
// Input: concat(dy, x) [1, 2*DIM, 1, SEQ]
// Baked weights: rms_w [DIM] — the RMSNorm scale weights
// Output: dx [1, DIM, 1, SEQ]
//
// Implements dx = (dy*w - x * (sum(dy*w*x)/DIM) * rrms²) * rrms, which is
// algebraically the formula in the file header with rrms factored out.
// NOTE(review): appendFormat varargs without matching %d placeholders —
// the MIL tensor<dtype,[shape]> type templates were stripped from this
// copy of the file; restore before regenerating kernels.
static NSString *gen_rmsnorm_bwd(void) {
    float invd = 1.0f / (float)DIM;  // 1/DIM for the mean
    NSMutableString *m = [NSMutableString string];
    [m appendString:MIL_HDR];

    // Input: concat of dy and x along channel dimension
    [m appendFormat:@" func main(tensor inp) {\n", 2*DIM, SEQ];

    // Slice out dy [1, DIM, 1, SEQ] and x [1, DIM, 1, SEQ]
    [m appendFormat:@" tensor sz = const()[name=string(\"sz\"), val=tensor([1,%d,1,%d])];\n", DIM, SEQ];
    [m appendString:@" tensor b0 = const()[name=string(\"b0\"), val=tensor([0,0,0,0])];\n"];
    [m appendFormat:@" tensor dy = slice_by_size(x=inp,begin=b0,size=sz)[name=string(\"sdy\")];\n", DIM, SEQ];
    [m appendFormat:@" tensor b1 = const()[name=string(\"b1\"), val=tensor([0,%d,0,0])];\n", DIM];
    [m appendFormat:@" tensor x = slice_by_size(x=inp,begin=b1,size=sz)[name=string(\"sx\")];\n", DIM, SEQ];

    // Step 1: Compute rrms = 1/sqrt(mean(x²) + eps)
    // sq = x * x
    [m appendFormat:@" tensor sq = mul(x=x,y=x)[name=string(\"sq\")];\n", DIM, SEQ];
    // ss = sum(sq, axis=1, keepdims=true) → [1,1,1,SEQ]
    [m appendFormat:@" tensor rax = const()[name=string(\"rax\"), val=tensor([1])];\n"];
    [m appendFormat:@" bool kd = const()[name=string(\"kd\"), val=bool(true)];\n"];
    [m appendFormat:@" tensor ss = reduce_sum(x=sq,axes=rax,keep_dims=kd)[name=string(\"ss\")];\n", SEQ];
    // ss2 = ss * invd + eps
    [m appendFormat:@" fp16 invd = const()[name=string(\"invd\"), val=fp16(%f)];\n", invd];
    [m appendFormat:@" tensor ss2 = mul(x=ss,y=invd)[name=string(\"ss2\")];\n", SEQ];
    [m appendFormat:@" fp16 eps = const()[name=string(\"eps\"), val=fp16(0.00001)];\n"];
    [m appendFormat:@" tensor ss3 = add(x=ss2,y=eps)[name=string(\"ss3\")];\n", SEQ];
    // rrms = pow(ss3, -0.5) → [1,1,1,SEQ]
    [m appendFormat:@" fp16 nhalf = const()[name=string(\"nhalf\"), val=fp16(-0.5)];\n"];
    [m appendFormat:@" tensor rrms = pow(x=ss3,y=nhalf)[name=string(\"rrms\")];\n", SEQ];

    // Step 2: Load RMSNorm weights w [1, DIM, 1, 1]
    [m appendFormat:@" tensor w = const()[name=string(\"w\"), val=tensor(BLOBFILE(path=string(\"@model_path/weights/rms_w.bin\"), offset=uint64(64)))];\n", DIM, DIM];

    // Step 3: Compute dot = sum(dy * w * x, axis=1) * invd * rrms²
    // dyw = dy * w → [1, DIM, 1, SEQ]
    [m appendFormat:@" tensor dyw = mul(x=dy,y=w)[name=string(\"dyw\")];\n", DIM, SEQ];
    // dywx = dyw * x → [1, DIM, 1, SEQ]
    [m appendFormat:@" tensor dywx = mul(x=dyw,y=x)[name=string(\"dywx\")];\n", DIM, SEQ];
    // dot_sum = sum(dywx, axis=1, keepdims=true) → [1,1,1,SEQ]
    [m appendFormat:@" tensor dot_sum = reduce_sum(x=dywx,axes=rax,keep_dims=kd)[name=string(\"ds\")];\n", SEQ];
    // dot_scaled = dot_sum * invd → [1,1,1,SEQ]
    [m appendFormat:@" tensor dot_sc = mul(x=dot_sum,y=invd)[name=string(\"dsc\")];\n", SEQ];
    // rrms_sq = rrms * rrms → [1,1,1,SEQ]
    [m appendFormat:@" tensor rrms2 = mul(x=rrms,y=rrms)[name=string(\"rr2\")];\n", SEQ];
    // coeff = dot_scaled * rrms_sq → [1,1,1,SEQ]
    [m appendFormat:@" tensor coeff = mul(x=dot_sc,y=rrms2)[name=string(\"cof\")];\n", SEQ];

    // Step 4: dx = (dy * w - x * coeff) * rrms
    // x_coeff = x * coeff → [1, DIM, 1, SEQ]
    [m appendFormat:@" tensor xc = mul(x=x,y=coeff)[name=string(\"xc\")];\n", DIM, SEQ];
    // diff = dyw - xc → [1, DIM, 1, SEQ]
    [m appendFormat:@" tensor diff = sub(x=dyw,y=xc)[name=string(\"dif\")];\n", DIM, SEQ];
    // dx = diff * rrms → [1, DIM, 1, SEQ]
    [m appendFormat:@" tensor out = mul(x=diff,y=rrms)[name=string(\"out\")];\n", DIM, SEQ];

    [m appendString:@" } -> (out);\n}\n"];
    return m;
}
diff --git a/training/download_data.sh b/training/download_data.sh
new file mode 100755
index 0000000..2d27d96
--- /dev/null
+++ b/training/download_data.sh
@@ -0,0 +1,91 @@
#!/bin/bash
# Download pretokenized TinyStories data for ANE 
training
# Format: flat uint16 token IDs (Llama2 BPE, 32K vocab)
# Source: enio/TinyStories on HuggingFace (pretokenized with karpathy/llama2.c)
#
# The tar.gz contains data00.bin..data49.bin (50 shards).
# We extract only data00.bin and rename it to tinystories_data00.bin.

set -e

SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
OUTPUT="$SCRIPT_DIR/tinystories_data00.bin"

# Idempotent: if the shard is already present, report its size and stop.
# (stat -f%z is macOS/BSD; stat -c%s is the GNU coreutils fallback.)
if [ -f "$OUTPUT" ]; then
    SIZE=$(stat -f%z "$OUTPUT" 2>/dev/null || stat -c%s "$OUTPUT" 2>/dev/null)
    TOKENS=$((SIZE / 2))
    echo "$OUTPUT already exists ($TOKENS tokens, $(echo "scale=1; $SIZE/1000000" | bc) MB)"
    exit 0
fi

TAR_URL="https://huggingface.co/datasets/enio/TinyStories/resolve/main/tok32000/TinyStories_tok32000.tar.gz?download=true"
TAR_FILE="$SCRIPT_DIR/TinyStories_tok32000.tar.gz"

echo "=== TinyStories Data Download ==="
echo "Downloading pretokenized TinyStories (32K vocab, ~993 MB)..."
echo " Source: enio/TinyStories on HuggingFace"
echo " This will take a few minutes depending on your connection."
echo ""

# Download the tar.gz (curl preferred, wget fallback)
if [ ! -f "$TAR_FILE" ]; then
    if command -v curl &>/dev/null; then
        curl -L --progress-bar -o "$TAR_FILE" "$TAR_URL"
    elif command -v wget &>/dev/null; then
        wget --show-progress -O "$TAR_FILE" "$TAR_URL"
    else
        echo "Error: need curl or wget"
        exit 1
    fi
else
    echo "Tar file already downloaded, skipping..."
fi

# Verify it's actually a gzip file (not an error page)
if ! file "$TAR_FILE" | grep -q "gzip"; then
    echo "Error: Downloaded file is not a valid gzip archive."
    echo "Content: $(head -c 100 "$TAR_FILE")"
    rm -f "$TAR_FILE"
    exit 1
fi

echo ""
echo "Extracting data00.bin from archive..."

# List what's in the archive to find the right path
DATA_FILE=$(tar tzf "$TAR_FILE" 2>/dev/null | grep 'data00\.bin' | head -1)
if [ -z "$DATA_FILE" ]; then
    echo "Error: data00.bin not found in archive. 
Contents:"
    tar tzf "$TAR_FILE" | head -20
    exit 1
fi
echo " Found: $DATA_FILE"

# Extract just data00.bin
tar xzf "$TAR_FILE" -C "$SCRIPT_DIR" "$DATA_FILE"

# Move to expected location (might be in a subdirectory)
EXTRACTED="$SCRIPT_DIR/$DATA_FILE"
if [ "$EXTRACTED" != "$OUTPUT" ]; then
    mv "$EXTRACTED" "$OUTPUT"
    # Clean up any extracted subdirectories
    rmdir "$(dirname "$EXTRACTED")" 2>/dev/null || true
fi

# Clean up tar.gz to save disk space
echo "Cleaning up archive..."
rm -f "$TAR_FILE"

SIZE=$(stat -f%z "$OUTPUT" 2>/dev/null || stat -c%s "$OUTPUT" 2>/dev/null)
TOKENS=$((SIZE / 2))
echo ""
echo "Done: $OUTPUT"
echo " $TOKENS tokens ($(echo "scale=1; $SIZE/1000000" | bc) MB)"

# Sanity check: decode the first 10 little-endian uint16 tokens.
python3 -c "
import struct
with open('$OUTPUT', 'rb') as f:
    tokens = struct.unpack('<10H', f.read(20))
    print(f'First 10 tokens: {tokens}')
" 2>/dev/null || true
diff --git a/training/test_classifier.m b/training/test_classifier.m
new file mode 100644
index 0000000..363e46e
--- /dev/null
+++ b/training/test_classifier.m
@@ -0,0 +1,255 @@
// test_classifier.m — Test classifier matmul (32000 channels) and softmax on ANE
// This tests the riskiest operations: VOCAB-sized conv and softmax
// Build: xcrun clang -O2 -framework Foundation -framework IOSurface \
//        -framework CoreML -framework Accelerate -ldl -lobjc \
//        -o test_classifier test_classifier.m
#include "ane_classifier.h"
#include "stories_cpu_ops.h"

int main(void) {
  @autoreleasepool {
    setbuf(stdout, NULL);
    ane_init();
    mach_timebase_info(&g_tb);

    printf("=== Test: Classifier + Softmax on ANE ===\n");
    printf("DIM=%d SEQ=%d VOCAB=%d\n\n", DIM, SEQ, VOCAB);

    // ======== Test 1: Final RMSNorm ========
    // Compare gen_final_rmsnorm() on ANE against the CPU rmsnorm() reference.
    printf("--- Test 1: Final RMSNorm on ANE ---\n");
    {
      float *x = (float*)malloc(DIM * SEQ * 4);
      float *w = (float*)malloc(DIM * 4);
      float *out_cpu = (float*)malloc(DIM * SEQ * 4);
      float *out_ane = (float*)malloc(DIM * SEQ * 4);
      srand48(42);
      for (int i 
= 0; i < DIM * SEQ; i++) x[i] = (float)(drand48() * 2 - 1);
      for (int i = 0; i < DIM; i++) w[i] = (float)(drand48() * 0.5 + 0.75);

      // CPU reference result for comparison.
      rmsnorm(out_cpu, x, w, DIM, SEQ);

      Kern *kern = compile_kern_mil_w(gen_final_rmsnorm(), (@{
        @"@model_path/weights/rms_w.bin": @{@"offset":@0, @"data":build_blob(w, 1, DIM)},
      }), DIM*SEQ*2, DIM*SEQ*2);

      if (!kern) { printf("FAIL: Final RMSNorm compile failed\n"); return 1; }
      printf("Compile OK\n");

      io_write_fp16(kern->ioIn, x, DIM, SEQ);
      ane_eval(kern);
      io_read_fp16(kern->ioOut, out_ane, 0, DIM, SEQ);

      // Element-wise max abs error vs CPU (fp16 round-trip sets the 0.05 bar).
      float max_err = 0;
      for (int i = 0; i < DIM*SEQ; i++) {
        float e = fabsf(out_cpu[i] - out_ane[i]);
        if (e > max_err) max_err = e;
      }
      printf("Max error: %.6f %s\n\n", max_err, max_err < 0.05 ? "PASS ✅" : "FAIL ❌");
      free_kern(kern);
      free(x); free(w); free(out_cpu); free(out_ane);
    }

    // ======== Test 2: Classifier forward (32000-channel conv) ========
    printf("--- Test 2: Classifier Forward (VOCAB=%d channel conv) ---\n", VOCAB);
    {
      float *x_final = (float*)malloc(DIM * SEQ * 4);
      float *embed = (float*)malloc((size_t)VOCAB * DIM * 4);
      float *logits_cpu = (float*)malloc((size_t)VOCAB * SEQ * 4);
      float *logits_ane = (float*)malloc((size_t)VOCAB * SEQ * 4);

      srand48(123);
      for (int i = 0; i < DIM * SEQ; i++) x_final[i] = (float)(drand48() * 2 - 1) * 0.1f;
      for (size_t i = 0; i < (size_t)VOCAB * DIM; i++) embed[i] = (float)(drand48() * 2 - 1) * 0.02f;

      // CPU reference: logits = embed @ x_final
      // logits[v, t] = sum_d embed[v,d] * x_final[d,t]
      // embed is [VOCAB, DIM] row-major, x_final is [DIM, SEQ] channel-first
      uint64_t t0 = mach_absolute_time();
      cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
          VOCAB, SEQ, DIM, 1.0f,
          embed, DIM, x_final, SEQ, 0.0f, logits_cpu, SEQ);
      uint64_t t1 = mach_absolute_time();
      printf("CPU cblas_sgemm: %.2f ms\n", tb_ms(t1-t0));

      // ANE: build weight blob for embed [VOCAB, DIM]
      printf("Building embed blob (%.1f MB fp16)...\n", (float)VOCAB*DIM*2/1e6);
      NSData *embed_blob = build_blob(embed, VOCAB, DIM);

      printf("Compiling classifier kernel...\n");
      t0 = mach_absolute_time();
      Kern *cls = compile_kern_mil_w(gen_classifier_fwd(), (@{
        @"@model_path/weights/embed.bin": @{@"offset":@0, @"data":embed_blob},
      }), DIM*SEQ*2, VOCAB*SEQ*2);
      t1 = mach_absolute_time();

      // Compile failure is itself an interesting result here: it would
      // show the 32000-channel layer exceeds what the ANE compiler takes.
      if (!cls) {
        printf("FAIL: Classifier compile failed (32000 channels too large for ANE)\n");
        printf("This confirms tiling is needed.\n\n");
      } else {
        printf("Compile OK in %.0f ms (compiles=%d)\n", tb_ms(t1-t0), g_compile_count);

        io_write_fp16(cls->ioIn, x_final, DIM, SEQ);
        t0 = mach_absolute_time();
        ane_eval(cls);
        t1 = mach_absolute_time();
        printf("ANE eval: %.2f ms\n", tb_ms(t1-t0));

        // Read back and compare (sample — full read would be 32000*256*4 = 32MB)
        io_read_fp16(cls->ioOut, logits_ane, 0, VOCAB, SEQ);

        float max_err = 0, sum_err = 0;
        int cnt = 0;
        for (int v = 0; v < VOCAB; v++) {
          for (int t = 0; t < SEQ; t++) {
            int idx = v*SEQ + t;
            float e = fabsf(logits_cpu[idx] - logits_ane[idx]);
            sum_err += e;
            cnt++;
            if (e > max_err) max_err = e;
          }
        }
        // NOTE(review): the 1.0 absolute-error threshold is loose; it
        // tolerates fp16 accumulation error over a DIM-long dot product.
        printf("Max error: %.6f Mean error: %.6f %s\n",
            max_err, sum_err/cnt, max_err < 1.0 ? "PASS ✅" : "FAIL ❌");

        // Benchmark
        int N = 10;
        t0 = mach_absolute_time();
        for (int i = 0; i < N; i++) ane_eval(cls);
        t1 = mach_absolute_time();
        printf("Benchmark: %d evals in %.2f ms (%.2f ms/eval)\n\n", N, tb_ms(t1-t0), tb_ms(t1-t0)/N);
        free_kern(cls);
      }
      free(x_final); free(embed); free(logits_cpu); free(logits_ane);
    }

    // ======== Test 3: Softmax over VOCAB dimension ========
    printf("--- Test 3: Softmax over VOCAB=%d ---\n", VOCAB);
    {
      float *logits = (float*)malloc((size_t)VOCAB * SEQ * 4);
      float *probs_cpu = (float*)malloc((size_t)VOCAB * SEQ * 4);
      float *probs_ane = (float*)malloc((size_t)VOCAB * SEQ * 4);

      srand48(999);
      for (size_t i = 0; i < (size_t)VOCAB * SEQ; i++)
        logits[i] = (float)(drand48() * 10 - 5);

      // CPU reference softmax (per position, over vocab)
      // logits is [VOCAB, SEQ] channel-first
      uint64_t t0 = mach_absolute_time();
      for (int t = 0; t < SEQ; t++) {
        // Max-subtraction for numerical stability.
        float maxv = -1e30f;
        for (int v = 0; v < VOCAB; v++) {
          float val = logits[v*SEQ+t];
          if (val > maxv) maxv = val;
        }
        float sum = 0;
        for (int v = 0; v < VOCAB; v++) {
          probs_cpu[v*SEQ+t] = expf(logits[v*SEQ+t] - maxv);
          sum += probs_cpu[v*SEQ+t];
        }
        for (int v = 0; v < VOCAB; v++) probs_cpu[v*SEQ+t] /= sum;
      }
      uint64_t t1 = mach_absolute_time();
      printf("CPU softmax: %.2f ms\n", tb_ms(t1-t0));

      printf("Compiling softmax kernel...\n");
      int sm_bytes = VOCAB * SEQ * 2;  // fp16 in and out are the same size
      Kern *sm = compile_kern_mil_w(gen_softmax_vocab(), @{}, sm_bytes, sm_bytes);

      if (!sm) {
        printf("FAIL: Softmax compile failed\n\n");
      } else {
        printf("Compile OK\n");

        io_write_fp16(sm->ioIn, logits, VOCAB, SEQ);
        t0 = mach_absolute_time();
        ane_eval(sm);
        t1 = mach_absolute_time();
        printf("ANE eval: %.2f ms\n", tb_ms(t1-t0));

        io_read_fp16(sm->ioOut, probs_ane, 0, VOCAB, SEQ);

        // Check: probs should sum to ~1.0 per position
        // (only the first 4 positions are checked, to keep the test fast)
        float max_err = 0;
        for (int t = 0; t < 4; t++) {
          float sum_cpu = 0, sum_ane = 0;
          for (int v = 0; v < VOCAB; v++) {
            sum_cpu += probs_cpu[v*SEQ+t];
            sum_ane += probs_ane[v*SEQ+t];
            float e = fabsf(probs_cpu[v*SEQ+t] - probs_ane[v*SEQ+t]);
            if (e > max_err) max_err = e;
          }
          printf(" pos %d: CPU sum=%.4f ANE sum=%.4f\n", t, sum_cpu, sum_ane);
        }
        printf("Max error (first 4 positions): %.6f %s\n",
            max_err, max_err < 0.01 ? "PASS ✅" : "FAIL ❌");

        int N = 10;
        t0 = mach_absolute_time();
        for (int i = 0; i < N; i++) ane_eval(sm);
        t1 = mach_absolute_time();
        printf("Benchmark: %d evals in %.2f ms (%.2f ms/eval)\n\n", N, tb_ms(t1-t0), tb_ms(t1-t0)/N);
        free_kern(sm);
      }
      free(logits); free(probs_cpu); free(probs_ane);
    }

    // ======== Test 4: Classifier backward ========
    printf("--- Test 4: Classifier Backward (DIM=%d from VOCAB=%d) ---\n", DIM, VOCAB);
    {
      float *dlogits = (float*)malloc((size_t)VOCAB * SEQ * 4);
      float *embed = (float*)malloc((size_t)VOCAB * DIM * 4);
      float *dx_cpu = (float*)malloc(DIM * SEQ * 4);
      float *dx_ane = (float*)malloc(DIM * SEQ * 4);

      srand48(456);
      for (size_t i = 0; i < (size_t)VOCAB * SEQ; i++) dlogits[i] = (float)(drand48() * 2 - 1) * 0.01f;
      for (size_t i = 0; i < (size_t)VOCAB * DIM; i++) embed[i] = (float)(drand48() * 2 - 1) * 0.02f;

      // CPU: dx = embed^T @ dlogits
      uint64_t t0 = mach_absolute_time();
      cblas_sgemm(CblasRowMajor, CblasTrans, CblasNoTrans,
          DIM, SEQ, VOCAB, 1.0f,
          embed, DIM, dlogits, SEQ, 0.0f, dx_cpu, SEQ);
      uint64_t t1 = mach_absolute_time();
      printf("CPU cblas_sgemm: %.2f ms\n", tb_ms(t1-t0));

      // Build transposed embed blob
      NSData *embed_t_blob = build_blob_t(embed, VOCAB, DIM);

      printf("Compiling classifier backward...\n");
      Kern *clsb = compile_kern_mil_w(gen_classifier_bwd(), (@{
        @"@model_path/weights/embed_t.bin": @{@"offset":@0, @"data":embed_t_blob},
      }), VOCAB*SEQ*2, DIM*SEQ*2);

      if (!clsb) {
        printf("FAIL: Classifier backward compile failed\n\n");
      } else {
        printf("Compile OK\n");

        io_write_fp16(clsb->ioIn, dlogits, VOCAB, SEQ);
        t0 = mach_absolute_time();
        ane_eval(clsb);
        t1 = mach_absolute_time();
        printf("ANE eval: %.2f ms\n", tb_ms(t1-t0));

        io_read_fp16(clsb->ioOut, dx_ane, 0, DIM, SEQ);

        float max_err = 0, sum_err = 0;
        for (int i = 0; i < DIM*SEQ; i++) {
          float e = fabsf(dx_cpu[i] - dx_ane[i]);
          sum_err += e;
          if (e > max_err) max_err = e;
        }
        printf("Max error: %.6f Mean error: %.6f %s\n\n",
            max_err, sum_err/(DIM*SEQ), max_err < 1.0 ? "PASS ✅" : "FAIL ❌");
        free_kern(clsb);
      }
      free(dlogits); free(embed); free(dx_cpu); free(dx_ane);
    }

    printf("=== All tests complete ===\n");
    printf("Total ANE compiles used: %d\n", g_compile_count);
    return 0;
  }
}
diff --git a/training/test_rmsnorm_bwd.m b/training/test_rmsnorm_bwd.m
new file mode 100644
index 0000000..9014e53
--- /dev/null
+++ b/training/test_rmsnorm_bwd.m
@@ -0,0 +1,123 @@
// test_rmsnorm_bwd.m — Test RMSNorm backward ANE kernel vs CPU reference
// Build: xcrun clang -O2 -framework Foundation -framework IOSurface \
//        -framework CoreML -framework Accelerate -ldl -lobjc \
//        -o test_rmsnorm_bwd test_rmsnorm_bwd.m
#include "ane_rmsnorm_bwd.h"
#include "stories_cpu_ops.h"

int main(void) {
  @autoreleasepool {
    setbuf(stdout, NULL);
    ane_init();
    mach_timebase_info(&g_tb);

    printf("=== Test: RMSNorm Backward on ANE ===\n");
    printf("DIM=%d SEQ=%d\n\n", DIM, SEQ);

    // Allocate test data
    float *x = (float*)malloc(DIM * SEQ * 4);
    float *dy = (float*)malloc(DIM * SEQ * 4);
    float *w = (float*)malloc(DIM * 4);
    float *dx_cpu = (float*)calloc(DIM * SEQ, 4);
    float *dw_cpu = (float*)calloc(DIM, 4);
    float *dx_ane = (float*)malloc(DIM * SEQ * 4);

    // Random init (channel-first [DIM, SEQ])
    srand48(42);
    for (int i = 0; i < DIM * SEQ; i++) {
      x[i] = (float)(drand48() * 2 - 1) * 0.5f;
      dy[i] = (float)(drand48() * 2 - 1) * 0.1f;
    }
    for (int i = 0; i < DIM; i++) {
      w[i] = (float)(drand48() * 0.5 + 0.75); // close to 1.0
    }

    // === CPU Reference ===
    uint64_t t0 = 
mach_absolute_time();
    rmsnorm_bwd(dx_cpu, dw_cpu, dy, x, w, DIM, SEQ);
    uint64_t t1 = mach_absolute_time();
    printf("CPU rmsnorm_bwd: %.2f ms\n", tb_ms(t1 - t0));

    // === ANE Kernel ===
    printf("Compiling ANE rmsnorm_bwd kernel...\n");
    NSString *mil = gen_rmsnorm_bwd();

    // Build weight blob for RMSNorm weights
    NSData *rms_blob = build_blob(w, 1, DIM);

    int in_bytes = 2 * DIM * SEQ * 2; // concat(dy, x) in fp16
    int out_bytes = DIM * SEQ * 2; // dx in fp16

    Kern *kern = compile_kern_mil_w(mil, (@{
      @"@model_path/weights/rms_w.bin": @{@"offset":@0, @"data":rms_blob},
    }), in_bytes, out_bytes);

    if (!kern) {
      printf("FAIL: ANE kernel compilation failed!\n");
      return 1;
    }
    printf("Compile OK (compiles=%d)\n", g_compile_count);

    // Write input: concat(dy, x) into ioIn
    // dy goes at channel offset 0, x goes at channel offset DIM
    io_write_fp16_at(kern->ioIn, 0, dy, DIM, SEQ);
    io_write_fp16_at(kern->ioIn, DIM, x, DIM, SEQ);

    // Evaluate
    t0 = mach_absolute_time();
    ane_eval(kern);
    t1 = mach_absolute_time();
    printf("ANE eval: %.3f ms\n", tb_ms(t1 - t0));

    // Read output
    io_read_fp16(kern->ioOut, dx_ane, 0, DIM, SEQ);

    // === Compare ===
    // Track max error and its location for easier debugging on failure.
    float max_err = 0, sum_err = 0;
    int max_i = 0, max_j = 0;
    for (int i = 0; i < DIM; i++) {
      for (int j = 0; j < SEQ; j++) {
        int idx = i * SEQ + j;
        float err = fabsf(dx_cpu[idx] - dx_ane[idx]);
        sum_err += err;
        if (err > max_err) {
          max_err = err;
          max_i = i; max_j = j;
        }
      }
    }
    float mean_err = sum_err / (DIM * SEQ);

    printf("\n=== Results ===\n");
    printf("Max absolute error: %.6f at [%d,%d] (CPU=%.6f ANE=%.6f)\n",
        max_err, max_i, max_j, dx_cpu[max_i*SEQ+max_j], dx_ane[max_i*SEQ+max_j]);
    printf("Mean absolute error: %.6f\n", mean_err);

    // Sample outputs
    printf("\nSample dx values (first 4 channels, first 4 positions):\n");
    printf("%-6s %-12s %-12s %-10s\n", "Idx", "CPU", "ANE", "Error");
    for (int i = 0; i < 4 && i < DIM; i++) {
      for (int j = 0; j < 4 && j < SEQ; j++) {
        int idx = i * SEQ + j;
        printf("[%d,%d] %-12.6f %-12.6f %-10.6f\n",
            i, j, dx_cpu[idx], dx_ane[idx], fabsf(dx_cpu[idx] - dx_ane[idx]));
      }
    }

    // Benchmark: multiple evals
    int N = 100;
    t0 = mach_absolute_time();
    for (int i = 0; i < N; i++) ane_eval(kern);
    t1 = mach_absolute_time();
    printf("\nBenchmark: %d evals in %.2f ms (%.3f ms/eval)\n",
        N, tb_ms(t1-t0), tb_ms(t1-t0)/N);

    // Pass/fail — exit status reflects the result for CI use.
    bool pass = max_err < 0.05f && mean_err < 0.01f;
    printf("\n%s (threshold: max<0.05, mean<0.01)\n", pass ? "PASS ✅" : "FAIL ❌");

    free_kern(kern);
    free(x); free(dy); free(w); free(dx_cpu); free(dw_cpu); free(dx_ane);
    return pass ? 0 : 1;
  }
}
diff --git a/training/train_large_ane.m b/training/train_large_ane.m
new file mode 100644
index 0000000..d7a99ef
--- /dev/null
+++ b/training/train_large_ane.m
@@ -0,0 +1,695 @@
// train_large_ane.m — Stories110M training with CPU ops offloaded to ANE
// Based on train_large.m but moves these operations from CPU to ANE:
// 1. Final RMSNorm (was CPU vDSP) → ANE kernel
// 2. Classifier forward embed@x (was CPU cblas) → ANE 32000-ch conv
// 3. Cross-entropy softmax (was CPU vDSP) → ANE softmax kernel
// 4. 
RMSNorm backward (was CPU vDSP) → ANE kernel
// Still on CPU: dW gradients (parallel via GCD), Adam optimizer (needs weight mutation),
//   classifier backward (ANE matmul slower than cblas for this shape),
//   NLL loss + gradient (needs target indexing)
//
// Build: make train_large_ane
// Run: ./train_large_ane [--resume] [--steps N] [--lr F]
#include "stories_io.h"
#include "stories_mil.h"
#include "stories_cpu_ops.h"
#include "ane_rmsnorm_bwd.h"
#include "ane_classifier.h"

#define CKPT_PATH "ane_stories110M_ckpt.bin"
#define MODEL_PATH "../../assets/models/stories110M.bin"
#define DATA_PATH "tinystories_data00.bin"

// ===== Weight loading from llama2.c format =====
// Reads the llama2.c binary layout: config header, then each weight
// class grouped across all layers. Returns false on open/config mismatch.
// NOTE(review): fread return values are unchecked throughout — a short
// or corrupt file silently yields partial weights. Worth hardening.
static bool load_pretrained(LayerWeights *lw, float *rms_final, float *embed, const char *path) {
    FILE *f = fopen(path, "rb");
    if (!f) { printf("Cannot open %s\n", path); return false; }
    Llama2Config cfg;
    fread(&cfg, sizeof(cfg), 1, f);
    printf(" Model config: dim=%d hidden=%d layers=%d heads=%d vocab=%d seq=%d\n",
        cfg.dim, cfg.hidden_dim, cfg.n_layers, cfg.n_heads, abs(cfg.vocab_size), cfg.seq_len);
    if (cfg.dim != DIM || cfg.hidden_dim != HIDDEN || cfg.n_layers != NLAYERS) {
        printf(" ERROR: Config mismatch!\n"); fclose(f); return false;
    }
    // Negative vocab_size in llama2.c format signals a separate (unshared)
    // classifier matrix; positive means embeddings are shared with the classifier.
    int V = abs(cfg.vocab_size);
    bool shared = cfg.vocab_size > 0;
    fread(embed, 4, V * DIM, f);
    for (int L = 0; L < NLAYERS; L++) fread(lw[L].rms_att, 4, DIM, f);
    for (int L = 0; L < NLAYERS; L++) fread(lw[L].Wq, 4, WQ_SZ, f);
    for (int L = 0; L < NLAYERS; L++) fread(lw[L].Wk, 4, WQ_SZ, f);
    for (int L = 0; L < NLAYERS; L++) fread(lw[L].Wv, 4, WQ_SZ, f);
    for (int L = 0; L < NLAYERS; L++) fread(lw[L].Wo, 4, WO_SZ, f);
    for (int L = 0; L < NLAYERS; L++) fread(lw[L].rms_ffn, 4, DIM, f);
    for (int L = 0; L < NLAYERS; L++) fread(lw[L].W1, 4, W1_SZ, f);
    for (int L = 0; L < NLAYERS; L++) fread(lw[L].W2, 4, W2_SZ, f);
    for (int L = 0; L < NLAYERS; L++) fread(lw[L].W3, 4, W3_SZ, f);
    fread(rms_final, 4, DIM, f);
    fclose(f);
    printf(" Loaded pretrained weights (%s)\n", shared ? "shared embed/cls" : "separate cls");
    return true;
}

// ===== Compile one layer's kernels =====
// Bakes the layer's weights (and transposes, for backward) into five ANE
// kernels. Returns true only if every compile succeeded.
static bool compile_layer_kernels(LayerKernels *lk, LayerWeights *w) {
    lk->fwdAttn = compile_kern_mil_w(gen_sdpa_fwd_taps(), (@{
        @"@model_path/weights/rms1.bin": @{@"offset":@0, @"data":build_blob(w->rms_att,1,DIM)},
        @"@model_path/weights/wq.bin": @{@"offset":@0, @"data":build_blob(w->Wq,DIM,DIM)},
        @"@model_path/weights/wk.bin": @{@"offset":@0, @"data":build_blob(w->Wk,DIM,DIM)},
        @"@model_path/weights/wv.bin": @{@"offset":@0, @"data":build_blob(w->Wv,DIM,DIM)},
        @"@model_path/weights/wo.bin": @{@"offset":@0, @"data":build_blob(w->Wo,DIM,DIM)},
        @"@model_path/weights/mask.bin": @{@"offset":@0, @"data":get_mask_blob()},
    }), DIM*SEQ*2, 6*DIM*SEQ*2);
    lk->fwdFFN = compile_kern_mil_w(gen_ffn_fwd_taps(), (@{
        @"@model_path/weights/rms2.bin": @{@"offset":@0, @"data":build_blob(w->rms_ffn,1,DIM)},
        @"@model_path/weights/w1.bin": @{@"offset":@0, @"data":build_blob(w->W1,HIDDEN,DIM)},
        @"@model_path/weights/w3.bin": @{@"offset":@0, @"data":build_blob(w->W3,HIDDEN,DIM)},
        @"@model_path/weights/w2.bin": @{@"offset":@0, @"data":build_blob(w->W2,DIM,HIDDEN)},
    }), DIM*SEQ*2, (2*DIM+3*HIDDEN)*SEQ*2);
    lk->ffnBwd = compile_kern_mil_w(gen_ffn_bwd(), (@{
        @"@model_path/weights/w2t.bin": @{@"offset":@0, @"data":build_blob_t(w->W2,DIM,HIDDEN)},
        @"@model_path/weights/w1t.bin": @{@"offset":@0, @"data":build_blob_t(w->W1,HIDDEN,DIM)},
        @"@model_path/weights/w3t.bin": @{@"offset":@0, @"data":build_blob_t(w->W3,HIDDEN,DIM)},
    }), (DIM+2*HIDDEN)*SEQ*2, (DIM+2*HIDDEN)*SEQ*2);
    lk->sdpaBwd1 = compile_kern_mil_w(gen_sdpa_bwd1(), (@{
        @"@model_path/weights/mask.bin": @{@"offset":@0, @"data":get_mask_blob()},
        @"@model_path/weights/wot.bin": @{@"offset":@0, @"data":build_blob_t(w->Wo,DIM,DIM)},
    }), 4*DIM*SEQ*2, (DIM+2*SCORE_CH)*SEQ*2);
    lk->qkvBwd = compile_kern_mil_w(gen_qkvb(), (@{
        @"@model_path/weights/wqt.bin": @{@"offset":@0, @"data":build_blob_t(w->Wq,DIM,DIM)},
        @"@model_path/weights/wkt.bin": @{@"offset":@0, @"data":build_blob_t(w->Wk,DIM,DIM)},
        @"@model_path/weights/wvt.bin": @{@"offset":@0, @"data":build_blob_t(w->Wv,DIM,DIM)},
    }), 3*DIM*SEQ*2, DIM*SEQ*2);
    return lk->fwdAttn && lk->fwdFFN && lk->ffnBwd && lk->sdpaBwd1 && lk->qkvBwd;
}

// SDPA backward stage 2 — no baked weights.
static Kern *compile_sdpa_bwd2(void) {
    return compile_kern_mil_w(gen_sdpa_bwd2(), @{},
        (2*SCORE_CH+2*DIM)*SEQ*2, 2*DIM*SEQ*2);
}

// NEW: Compile RMSNorm backward kernels (one per layer pair: attn + ffn)
static Kern *compile_rmsnorm_bwd_kern(const float *rms_w) {
    return compile_kern_mil_w(gen_rmsnorm_bwd(), (@{
        @"@model_path/weights/rms_w.bin": @{@"offset":@0, @"data":build_blob(rms_w, 1, DIM)},
    }), 2*DIM*SEQ*2, DIM*SEQ*2);
}

// NEW: Compile classifier forward kernel
static Kern *compile_classifier_fwd(const float *embed) {
    return compile_kern_mil_w(gen_classifier_fwd(), (@{
        @"@model_path/weights/embed.bin": @{@"offset":@0, @"data":build_blob(embed, VOCAB, DIM)},
    }), DIM*SEQ*2, VOCAB*SEQ*2);
}

// NEW: Compile final RMSNorm kernel
static Kern *compile_final_rmsnorm_kern(const float *rms_w) {
    return compile_kern_mil_w(gen_final_rmsnorm(), (@{
        @"@model_path/weights/rms_w.bin": @{@"offset":@0, @"data":build_blob(rms_w, 1, DIM)},
    }), DIM*SEQ*2, DIM*SEQ*2);
}

// NEW: Compile softmax kernel (no weights)
static Kern *compile_softmax_kern(void) {
    return compile_kern_mil_w(gen_softmax_vocab(), @{}, VOCAB*SEQ*2, VOCAB*SEQ*2);
}

// Free and NULL a layer's five kernels (safe to call on partial compiles;
// free_kern is assumed NULL-tolerant — the NULLing guards double-free).
static void free_layer_kernels(LayerKernels *lk) {
    free_kern(lk->fwdAttn); free_kern(lk->fwdFFN); free_kern(lk->ffnBwd);
    free_kern(lk->sdpaBwd1); free_kern(lk->qkvBwd);
    lk->fwdAttn = lk->fwdFFN = lk->ffnBwd = lk->sdpaBwd1 = lk->qkvBwd = NULL;
}

// ===== Checkpoint save/load (same as train_large.m) =====
// Serializes weights + Adam moments + cumulative timing stats behind a
// versioned header so --resume can restore optimizer state exactly.
// NOTE(review): fopen result is not NULL-checked and fwrite returns are
// ignored — a full disk or bad path would crash or truncate silently.
static void save_checkpoint(const char *path, int step, int total_steps, float lr, float loss,
        double cc, double ct, double cw, int cs, int cb, int adam_t,
        LayerWeights *lw, LayerAdam *la, float *rms_final, AdamState *arms_final,
        float *embed, AdamState *aembed) {
    FILE *f = fopen(path, "wb");
    CkptHdr h = {0};
    h.magic = 0x424C5A54; h.version = 2;
    h.step = step; h.total_steps = total_steps;
    h.n_layers = NLAYERS; h.vocab_size = VOCAB; h.dim = DIM;
    h.hidden_dim = HIDDEN; h.n_heads = HEADS; h.seq_len = SEQ;
    h.lr = lr; h.loss = loss;
    h.cum_compile = cc; h.cum_train = ct; h.cum_wall = cw;
    h.cum_steps = cs; h.cum_batches = cb; h.adam_t = adam_t;
    fwrite(&h, sizeof(h), 1, f);
    for (int L = 0; L < NLAYERS; L++) {
        // Per-layer weights, then Adam first/second moments in the same order.
        fwrite(lw[L].Wq,4,WQ_SZ,f); fwrite(lw[L].Wk,4,WQ_SZ,f);
        fwrite(lw[L].Wv,4,WQ_SZ,f); fwrite(lw[L].Wo,4,WO_SZ,f);
        fwrite(lw[L].W1,4,W1_SZ,f); fwrite(lw[L].W2,4,W2_SZ,f); fwrite(lw[L].W3,4,W3_SZ,f);
        fwrite(lw[L].rms_att,4,DIM,f); fwrite(lw[L].rms_ffn,4,DIM,f);
        fwrite(la[L].Wq.m,4,WQ_SZ,f); fwrite(la[L].Wq.v,4,WQ_SZ,f);
        fwrite(la[L].Wk.m,4,WQ_SZ,f); fwrite(la[L].Wk.v,4,WQ_SZ,f);
        fwrite(la[L].Wv.m,4,WQ_SZ,f); fwrite(la[L].Wv.v,4,WQ_SZ,f);
        fwrite(la[L].Wo.m,4,WO_SZ,f); fwrite(la[L].Wo.v,4,WO_SZ,f);
        fwrite(la[L].W1.m,4,W1_SZ,f); fwrite(la[L].W1.v,4,W1_SZ,f);
        fwrite(la[L].W2.m,4,W2_SZ,f); fwrite(la[L].W2.v,4,W2_SZ,f);
        fwrite(la[L].W3.m,4,W3_SZ,f); fwrite(la[L].W3.v,4,W3_SZ,f);
        fwrite(la[L].rms_att.m,4,DIM,f); fwrite(la[L].rms_att.v,4,DIM,f);
        fwrite(la[L].rms_ffn.m,4,DIM,f); fwrite(la[L].rms_ffn.v,4,DIM,f);
    }
    fwrite(rms_final,4,DIM,f);
    fwrite(arms_final->m,4,DIM,f); fwrite(arms_final->v,4,DIM,f);
    fwrite(embed,4,VOCAB*DIM,f);
    fwrite(aembed->m,4,VOCAB*DIM,f); fwrite(aembed->v,4,VOCAB*DIM,f);
    fclose(f);
}

// Mirror of save_checkpoint: validates magic/version, then reads fields
// in the identical order. Returns false if the file is absent or not a
// v2 checkpoint. NOTE(review): fread returns unchecked, same caveat.
static bool load_checkpoint(const char *path, int *step, int *total_steps, float *lr, float *loss,
        double *cc, double *ct, double *cw, int *cs, int *cb, int *adam_t,
        LayerWeights *lw, LayerAdam *la, float *rms_final, AdamState *arms_final,
        float *embed, AdamState *aembed) {
    FILE *f = fopen(path, "rb");
    if (!f) return false;
    CkptHdr h;
    fread(&h, sizeof(h), 1, f);
    if (h.magic != 0x424C5A54 || h.version != 2) { fclose(f); return false; }
    *step = h.step; *total_steps = h.total_steps; *lr = h.lr; *loss = h.loss;
    *cc = h.cum_compile; *ct = h.cum_train; *cw = h.cum_wall;
    *cs = h.cum_steps; *cb = h.cum_batches; *adam_t = h.adam_t;
    for (int L = 0; L < NLAYERS; L++) {
        fread(lw[L].Wq,4,WQ_SZ,f); fread(lw[L].Wk,4,WQ_SZ,f);
        fread(lw[L].Wv,4,WQ_SZ,f); fread(lw[L].Wo,4,WO_SZ,f);
        fread(lw[L].W1,4,W1_SZ,f); fread(lw[L].W2,4,W2_SZ,f); fread(lw[L].W3,4,W3_SZ,f);
        fread(lw[L].rms_att,4,DIM,f); fread(lw[L].rms_ffn,4,DIM,f);
        fread(la[L].Wq.m,4,WQ_SZ,f); fread(la[L].Wq.v,4,WQ_SZ,f);
        fread(la[L].Wk.m,4,WQ_SZ,f); fread(la[L].Wk.v,4,WQ_SZ,f);
        fread(la[L].Wv.m,4,WQ_SZ,f); fread(la[L].Wv.v,4,WQ_SZ,f);
        fread(la[L].Wo.m,4,WO_SZ,f); fread(la[L].Wo.v,4,WO_SZ,f);
        fread(la[L].W1.m,4,W1_SZ,f); fread(la[L].W1.v,4,W1_SZ,f);
        fread(la[L].W2.m,4,W2_SZ,f); fread(la[L].W2.v,4,W2_SZ,f);
        fread(la[L].W3.m,4,W3_SZ,f); fread(la[L].W3.v,4,W3_SZ,f);
        fread(la[L].rms_att.m,4,DIM,f); fread(la[L].rms_att.v,4,DIM,f);
        fread(la[L].rms_ffn.m,4,DIM,f); fread(la[L].rms_ffn.v,4,DIM,f);
    }
    fread(rms_final,4,DIM,f);
    fread(arms_final->m,4,DIM,f); fread(arms_final->v,4,DIM,f);
    fread(embed,4,VOCAB*DIM,f);
    fread(aembed->m,4,VOCAB*DIM,f); fread(aembed->v,4,VOCAB*DIM,f);
    fclose(f);
    return true;
}

// ===== Main =====
// (Definition continues past the end of this view — left verbatim.)
int main(int argc, char *argv[]) {
    @autoreleasepool {
        setbuf(stdout, NULL);
        ane_init();
        mach_timebase_info(&g_tb);

        int total_steps = 10000;
        float lr = 3e-4f;
        float adam_b1=0.9f, adam_b2=0.999f, adam_eps=1e-8f;
        int adam_t = 0, start_step = 0;
        bool do_resume = false;
        for (int i=1; i