From e030ffb2137ed4b0423c5285cb367f3809197b7c Mon Sep 17 00:00:00 2001
From: tastyheadphones
Date: Wed, 4 Mar 2026 22:21:44 +0900
Subject: [PATCH] Guard short token datasets in ANE and dynamic trainers

---
 training/train_large_ane.m        |  6 ++++++
 training/training_dynamic/train.m | 18 ++++++++++++------
 2 files changed, 18 insertions(+), 6 deletions(-)

diff --git a/training/train_large_ane.m b/training/train_large_ane.m
index 25e9160..64325e5 100644
--- a/training/train_large_ane.m
+++ b/training/train_large_ane.m
@@ -284,6 +284,12 @@ int main(int argc, char *argv[]) {
     uint16_t *token_data = (uint16_t*)mmap(NULL, data_len, PROT_READ, MAP_PRIVATE, data_fd, 0);
     if (token_data == MAP_FAILED) { printf("mmap failed\n"); return 1; }
     size_t n_tokens = data_len / 2;
+    if (n_tokens <= (size_t)(SEQ + 1)) {
+        printf("Token data too short: need at least %d tokens, got %zu\n", SEQ + 2, n_tokens);
+        munmap(token_data, data_len);
+        close(data_fd);
+        return 1;
+    }
     printf("Token data: %zu tokens (%.1f MB)\n", n_tokens, data_len/1e6);
 
     // Gradient buffers
diff --git a/training/training_dynamic/train.m b/training/training_dynamic/train.m
index 412c4d8..afe9686 100644
--- a/training/training_dynamic/train.m
+++ b/training/training_dynamic/train.m
@@ -335,12 +335,18 @@ int main(int argc, char *argv[]) {
     // mmap token data
     int data_fd = open(DATA_PATH, O_RDONLY);
     if (data_fd < 0) { printf("Cannot open %s\n", DATA_PATH); return 1; }
-    struct stat st; fstat(data_fd, &st);
-    size_t data_len = st.st_size;
-    uint16_t *token_data = (uint16_t*)mmap(NULL, data_len, PROT_READ, MAP_PRIVATE, data_fd, 0);
-    if (token_data == MAP_FAILED) { printf("mmap failed\n"); return 1; }
-    size_t n_tokens = data_len / 2;
-    printf("Token data: %zu tokens (%.1f MB)\n", n_tokens, data_len/1e6);
+    struct stat st; fstat(data_fd, &st);
+    size_t data_len = st.st_size;
+    uint16_t *token_data = (uint16_t*)mmap(NULL, data_len, PROT_READ, MAP_PRIVATE, data_fd, 0);
+    if (token_data == MAP_FAILED) { printf("mmap failed\n"); return 1; }
+    size_t n_tokens = data_len / 2;
+    if (n_tokens <= (size_t)(SEQ + 1)) {
+        printf("Token data too short: need at least %d tokens, got %zu\n", SEQ + 2, n_tokens);
+        munmap(token_data, data_len);
+        close(data_fd);
+        return 1;
+    }
+    printf("Token data: %zu tokens (%.1f MB)\n", n_tokens, data_len/1e6);
 
     // Vocab compaction: map 32K sparse vocab → ~9K compact
     VocabMap vm = vocab_map_build(token_data, n_tokens, VOCAB);