Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 85 additions & 0 deletions scripts/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
---
license: mit
datasets:
- lennart-finke/SimpleStories
language:
- en
tags:
- small-language-model
- story-generation
- text-generation
- efficient-nlp
- distilled-models
---

# SimpleStories Model Family
The SimpleStories models are a tiny model family created for interpretability research, trained on the [SimpleStories dataset](https://huggingface.co/datasets/SimpleStories/SimpleStories). This is the second iteration of the model family.


**Paper:** https://arxiv.org/abs/2504.09184
**Training code:** https://github.com/simple-stories/simple_stories_train
**Training checkpoints:** https://wandb.ai/finke/simplestories-v2

## Usage

```python
import torch
from transformers import AutoTokenizer, LlamaForCausalLM


MODEL_SIZE = "35M"
model_path = "SimpleStories/SimpleStories-V2-{}".format(MODEL_SIZE)

tokenizer = AutoTokenizer.from_pretrained(model_path)
model = LlamaForCausalLM.from_pretrained(model_path)
model.to("cuda")
model.eval()

prompt = "The curious cat looked at the"

inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
input_ids = inputs.input_ids.to("cuda")

eos_token_id = 1

with torch.no_grad():
output_ids = model.generate(
input_ids=input_ids,
max_new_tokens=400,
temperature=0.7,
do_sample=True,
eos_token_id=eos_token_id
)

output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
print(f"\nGenerated text:\n{output_text}")

```

## Model Variants

| Model Name | n_params | n_layers | d_model | n_heads | n_ctx | d_vocab |
|------------|----------|----------|---------|---------|-------|---------|
| SimpleStories-35M | 35 million | 12 | 512 | 8 | 512 | 4019 |
| SimpleStories-30M | 30 million | 10 | 512 | 8 | 512 | 4019 |
| SimpleStories-11M | 11 million | 6 | 384 | 6 | 512 | 4019 |
| SimpleStories-5M | 5 million | 6 | 256 | 4 | 512 | 4019 |
| SimpleStories-1.25M | 1.25 million | 4 | 128 | 4 | 512 | 4019 |


## Dataset

The SimpleStories dataset is a collection of short stories generated by state-of-the-art language models. It features:

- Story annotation with high-level concepts: theme, topic, style, etc.
- Higher semantic and syntactic diversity through seeded story generation
- Generated by 2024 models
- Several NLP-metrics pre-computed to aid filtering
- ASCII-only guarantee for the English dataset


## Key improvements from previous version
- Improved evaluation scores due to an increased number of training epochs
- Pruning and optimization of the tokenizer, reducing the vocabulary size from 4096 to 4019
- Model training checkpoints are stored periodically in wandb for further research

132 changes: 37 additions & 95 deletions scripts/push_to_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,13 @@
"""

import argparse
import io
import json
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Any

import torch
import yaml
from huggingface_hub import HfApi
from tokenizers import Tokenizer
from transformers import PreTrainedModel

Expand All @@ -33,6 +30,7 @@
convert_llama_to_llama_for_causal_lm,
)
from simple_stories_train.models.model_configs import MODEL_CONFIGS
from simple_stories_train.tokenizer import convert_to_hf_tokenizer


@dataclass
Expand Down Expand Up @@ -175,111 +173,46 @@ def convert_to_hf_model(custom_model: Llama | GPT2) -> PreTrainedModel:
return hf_model


def _resolve_tokenizer_path(final_cfg_path: Path) -> Path | None:
"""Try to resolve a tokenizer file path from the final_config.yaml next to the checkpoint.
def find_saved_tokenizer(output_dir: Path) -> Path | None:
"""Find the saved tokenizer in the training output directory."""
tokenizer_path = output_dir / "tokenizer.json"
if tokenizer_path.exists():
return tokenizer_path

Returns absolute path to the tokenizer json if it can be found, otherwise None.

TODO: Save the tokenizer when training the model.
"""
try:
with final_cfg_path.open("r") as f:
data: dict[str, Any] = yaml.safe_load(f)
except Exception:
return None

train_ds_cfg = data.get("train_dataset_config", {}) or {}
tokenizer_rel: str | None = train_ds_cfg.get("tokenizer_file_path")
if not tokenizer_rel or not isinstance(tokenizer_rel, str):
return None

# As a last resort, if the file name matches a known tokenizer in the repo, use it
known_default = Path("simple_stories_train/tokenizer/simplestories-tokenizer.json")
if known_default.is_file():
return known_default.resolve()

return None


def upload_tokenizer_to_hub(
def convert_and_upload_tokenizer(
    repo_id: str,
    token: str | None,
    model_max_length: int,
    output_dir: Path,
) -> None:
    """Convert the raw trained tokenizer to HF format and upload it to the Hub.

    Best-effort: if no tokenizer file is found in ``output_dir``, or the file
    fails to load, a message is printed and the upload is skipped instead of
    raising, so the model push is not aborted.

    Args:
        repo_id: Target Hub repository id (e.g. ``"org/model-name"``).
        token: Hugging Face auth token; ``None`` falls back to the cached login.
        model_max_length: Context length recorded on the HF tokenizer.
        output_dir: Training output directory containing ``tokenizer.json``.
    """
    tokenizer_path = find_saved_tokenizer(output_dir)
    if tokenizer_path is None:
        print(f"No tokenizer found in {output_dir}, skipping tokenizer upload")
        return

    print(f"Found tokenizer at {tokenizer_path}")

    try:
        raw_tokenizer = Tokenizer.from_file(str(tokenizer_path))
    except Exception as e:
        # Deliberately broad: the tokenizer upload is optional and must not
        # crash the script after the model itself has been pushed.
        print(f"Failed to load tokenizer from {tokenizer_path}: {e}")
        return

    hf_tokenizer = convert_to_hf_tokenizer(raw_tokenizer, model_max_length)
    hf_tokenizer.push_to_hub(
        repo_id=repo_id,
        repo_type="model",
        token=token,
        commit_message="Upload tokenizer",
    )
    print(f"Tokenizer uploaded to {repo_id}")


def push_model_to_hub(
Expand All @@ -306,6 +239,9 @@ def optionally_upload_readme(repo_id: str, token: str | None, readme_path: Path
return
if not readme_path.exists():
raise FileNotFoundError(f"README file not found: {readme_path}")

from huggingface_hub import HfApi

api = HfApi()
api.upload_file(
path_or_fileobj=str(readme_path),
Expand All @@ -325,7 +261,7 @@ def main() -> None:
model_id, config = load_config_from_checkpoint_dir(args.checkpoint_path)
custom_model = load_custom_model(args.checkpoint_path, model_id, config)

# Convert and push
# Convert and push model
hf_model = convert_to_hf_model(custom_model)
push_model_to_hub(
hf_model=hf_model,
Expand All @@ -336,22 +272,28 @@ def main() -> None:
commit_message=args.commit_message,
)

# Upload tokenizer artifacts (minimal set)
model_max_len: int | None = None
# Get model max length
model_max_len = 1024 # default
if isinstance(config, LlamaConfig):
model_max_len = config.n_ctx
elif isinstance(config, GPT2Config):
model_max_len = config.block_size
upload_tokenizer_to_hub(

# Convert and upload tokenizer
# The models are stored inside checkpoints folder and tokenizer is saved outside
output_dir = args.checkpoint_path.parent.parent
convert_and_upload_tokenizer(
repo_id=args.repo_id,
token=args.token,
model_max_length=model_max_len,
checkpoint_path=args.checkpoint_path,
output_dir=output_dir,
)

# Optional README
optionally_upload_readme(args.repo_id, args.token, args.model_card_readme)

print(f"Successfully uploaded model and tokenizer to {args.repo_id}")


if __name__ == "__main__":
torch.set_grad_enabled(False)
Expand Down
13 changes: 13 additions & 0 deletions simple_stories_train/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from tokenizers.processors import TemplateProcessing
from tokenizers.trainers import WordPieceTrainer
from tqdm import tqdm
from transformers import PreTrainedTokenizerFast

OUT_DIR = Path("tokenizer")

Expand Down Expand Up @@ -238,6 +239,18 @@ def get_special_token_ids(tokenizer: Tokenizer) -> set[int]:
return special_token_ids


def convert_to_hf_tokenizer(
    tokenizer: Tokenizer, model_max_length: int
) -> PreTrainedTokenizerFast:
    """Wrap a raw ``tokenizers.Tokenizer`` in a Hugging Face fast tokenizer.

    Args:
        tokenizer: The trained raw tokenizer to wrap.
        model_max_length: Maximum sequence length, recorded as
            ``model_max_length`` on the returned tokenizer.

    Returns:
        A ``PreTrainedTokenizerFast`` backed by the same underlying
        tokenizer object, with ``[UNK]``/``[EOS]`` registered as special
        tokens.
    """
    hf_tokenizer = PreTrainedTokenizerFast(
        tokenizer_object=tokenizer,
        unk_token="[UNK]",
        eos_token="[EOS]",
        # The vocabulary has no dedicated PAD token, so reuse UNK for padding.
        pad_token="[UNK]",
        model_max_length=model_max_length,
    )

    return hf_tokenizer


if __name__ == "__main__":
vocab_size = 4096
dataset_name = "SimpleStories/SimpleStories"
Expand Down
4 changes: 4 additions & 0 deletions simple_stories_train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,10 @@ def get_lr(it: int) -> float:
if config.intermediate_checkpoints:
save_model(checkpoints_dir, raw_model, step=0, wandb_project=config.wandb_project)

# save the accompanying tokenizer
train_tokenizer.save(str(output_dir / "tokenizer.json"))
print0(f"Tokenizer saved to {output_dir / 'tokenizer.json'}")

if device == "cuda":
torch.cuda.reset_peak_memory_stats()
timings: list[float] = []
Expand Down
16 changes: 15 additions & 1 deletion tests/test_tokenizer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Simple test for tokenizer pruning functionality."""

from simple_stories_train.tokenizer import prune_tokenizer, train_tokenizer
from simple_stories_train.tokenizer import convert_to_hf_tokenizer, prune_tokenizer, train_tokenizer


def create_test_tokenizer():
Expand Down Expand Up @@ -94,3 +94,17 @@ def test_unk_for_unknown_words():
unk_id_pruned = pruned.token_to_id("[UNK]")
encoded_pruned = pruned.encode("antidisestablishmentarianism")
assert unk_id_pruned in encoded_pruned.ids


def test_convert_to_hf_tokenize():
    """Token IDs from the HF-wrapped tokenizer must match the raw tokenizer's."""
    raw_tokenizer = create_test_tokenizer()
    wrapped = convert_to_hf_tokenizer(raw_tokenizer, model_max_length=512)

    for sample in ("hello world", "hello there", "antidisestablishmentarianism"):
        raw_ids = raw_tokenizer.encode(sample).ids
        wrapped_ids = wrapped.encode(sample)
        assert raw_ids == wrapped_ids, (
            f"Token IDs differ for '{sample}': {raw_ids} vs {wrapped_ids}"
        )