From 29ebe0fc90469fa499adcc0b5fef0bca06e649f3 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Wed, 13 Aug 2025 11:38:38 +0000 Subject: [PATCH 01/13] Reset train dataloader when depleted --- simple_stories_train/train_config.yaml | 4 +-- simple_stories_train/train_llama.py | 45 ++++++++++++-------------- 2 files changed, 23 insertions(+), 26 deletions(-) diff --git a/simple_stories_train/train_config.yaml b/simple_stories_train/train_config.yaml index 2c506d8..a30ef75 100644 --- a/simple_stories_train/train_config.yaml +++ b/simple_stories_train/train_config.yaml @@ -31,7 +31,7 @@ learning_rate: 1e-4 learning_rate_decay_frac: 0.1 weight_decay: 0.1 grad_clip: 1.0 -val_loss_every: 50 +val_loss_every: 100 val_max_steps: 20 -sample_every: 100 +sample_every: 1000 intermediate_checkpoints: false diff --git a/simple_stories_train/train_llama.py b/simple_stories_train/train_llama.py index 9ee63fc..11fb8e0 100644 --- a/simple_stories_train/train_llama.py +++ b/simple_stories_train/train_llama.py @@ -128,6 +128,7 @@ class Config(BaseModel): val_max_steps: NonNegativeInt = Field( 20, description="Max number of batches to use for validation" ) + train_log_every: NonNegativeInt = Field(100, description="How often to log train loss?") sample_every: NonNegativeInt = Field(0, description="How often to sample from the model?") tensorcores: bool = Field(True, description="Use TensorCores?") device: str | None = Field(None, description="Device to use. If None, will autodetect.") @@ -243,7 +244,7 @@ def main(config_path_or_obj: Path | str | Config | None = None, **kwargs: Any) - ddp_rank=ddp_rank, ddp_world_size=ddp_world_size, ) - train_loader = iter(train_loader) # Is this the right way to sample from a Pytorch DataLoader? 
+ train_iter = iter(train_loader) val_loader, _ = create_data_loader( dataset_config=config.val_dataset_config, @@ -253,7 +254,6 @@ def main(config_path_or_obj: Path | str | Config | None = None, **kwargs: Any) - ddp_rank=ddp_rank, ddp_world_size=ddp_world_size, ) - val_loader = iter(val_loader) # Is this the right way to sample from a Pytorch DataLoader? # ------------------------------------------------------------------------- # main training loop @@ -311,19 +311,15 @@ def get_lr(it: int) -> float: if device == "cuda": torch.cuda.reset_peak_memory_stats() - train_loader_depleted = False timings = [] generations = [] for step in range(1, config.num_iterations + 1): - t0 = time.time() last_step = step == config.num_iterations # once in a while evaluate the validation dataset if config.val_loss_every > 0 and (step % config.val_loss_every == 0 or last_step): model.eval() - val_loader_iter = iter( - val_loader - ) # By creating the iterator anew, we sample the same data each time + val_loader_iter = iter(val_loader) with torch.no_grad(): val_loss = 0.0 for _ in range(config.val_max_steps): @@ -373,7 +369,7 @@ def get_lr(it: int) -> float: # but also after the very last iteration. so we loop for step <= num_iterations # instead of just < num_iterations (one extra due to <=), only to do # the validation/sampling one last time, and then we break right here as we're done. - if last_step or train_loader_depleted: + if last_step: break # --------------- TRAINING SECTION BEGIN ----------------- @@ -384,15 +380,16 @@ def get_lr(it: int) -> float: lossf = Tensor([0.0]).to( device ) # for getting the mean loss (as simple float) over the accumulation steps + t0 = time.time() for micro_step in range(grad_accum_steps): # fetch a batch try: - bat = next(train_loader)["input_ids"].to(torch.int) + bat = next(train_iter)["input_ids"].to(torch.int) except StopIteration: - # No more batches. Break so we can sync existing gradients and exit. - print0("No more batches in train_loader. 
Ending training now.") - train_loader_depleted = True - break + # reset the train_loader + print0("Depleted train_loader, resetting for next epoch") + train_iter = iter(train_loader) + bat = next(train_iter)["input_ids"].to(torch.int) x = bat.view(B, T)[:, :-1] # inputs y = bat.view(B, T)[:, 1:] # targets @@ -435,15 +432,16 @@ def get_lr(it: int) -> float: torch.mps.synchronize() elif device == "cuda": torch.cuda.synchronize() - # time and print - t1 = time.time() - # the 0th iteration is often an outlier (much slower) => skip logging it - tokens_per_second = grad_accum_steps * ddp_world_size * B * T / (t1 - t0) - norm_str = f"norm {norm:.4f}" if norm is not None else "" - print0( - f"step {step:4d}/{config.num_iterations} | train loss {lossf:.6f} | {norm_str} | " - f"lr {lr:.2e} | ({(t1 - t0) * 1000:.2f} ms | {tokens_per_second:.0f} tok/s)" - ) + if step % config.train_log_every == 0: + # time and print + t1 = time.time() + # the 0th iteration is often an outlier (much slower) => skip logging it + tokens_per_second = grad_accum_steps * ddp_world_size * B * T / (t1 - t0) + norm_str = f"norm {norm:.4f}" if norm is not None else "" + print0( + f"step {step:4d}/{config.num_iterations} | train loss {lossf:.6f} | {norm_str} | " + f"lr {lr:.2e} | ({(t1 - t0) * 1000:.2f} ms | {tokens_per_second:.0f} tok/s)" + ) # log to wandb if config.wandb_project is not None and master_process: log_metrics(step, {"train_loss": lossf, "lr": lr}) @@ -458,13 +456,12 @@ def get_lr(it: int) -> float: and ( (config.intermediate_checkpoints and is_checkpoint_step(step)) or step == config.num_iterations - 1 - or train_loader_depleted ) ): save_model(checkpoints_dir, raw_model, step=step, wandb_project=config.wandb_project) # keep track of smooth timings, last 20 iterations - if step > 1 and (step > config.num_iterations - 20 or train_loader_depleted): + if step > 1 and (step > config.num_iterations - 20): timings.append(t1 - t0) # print the average of the last 20 timings, to get something 
smooth-ish From 1a5440adbbd00d055b640bd4210c0fb05ad03f56 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Wed, 13 Aug 2025 13:58:26 +0000 Subject: [PATCH 02/13] Fix pyright errors --- simple_stories_train/models/llama.py | 41 ++++++++++++++++++---------- simple_stories_train/train_llama.py | 13 ++++----- simple_stories_train/utils.py | 3 +- 3 files changed, 33 insertions(+), 24 deletions(-) diff --git a/simple_stories_train/models/llama.py b/simple_stories_train/models/llama.py index ac73134..432932e 100644 --- a/simple_stories_train/models/llama.py +++ b/simple_stories_train/models/llama.py @@ -37,6 +37,10 @@ class LlamaConfig(BaseModel): class CausalSelfAttention(nn.Module): + bias: Tensor + rotary_sin: Tensor + rotary_cos: Tensor + def __init__(self, config: LlamaConfig): super().__init__() assert config.n_embd % config.n_head == 0 @@ -103,7 +107,7 @@ def get_offset_position_ids( self, past_kv_pos_offset: int, attention_mask: Int[Tensor, "batch offset_pos"], - ): # Changed return type hint + ) -> Int[Tensor, "batch pos"]: shifted_position_ids = attention_mask.cumsum(dim=1) - 1 position_ids = shifted_position_ids.masked_fill(shifted_position_ids < 0, 0) return position_ids[:, past_kv_pos_offset:].long() # Ensure long type for indexing @@ -283,22 +287,29 @@ class Llama(nn.Module): def __init__(self, config: LlamaConfig): super().__init__() self.config = config - self.transformer = nn.ModuleDict( - dict( - wte=nn.Embedding(config.vocab_size, config.n_embd), - h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]), - rms_f=LlamaRMSNorm(config.n_embd), - ) + self.wte: nn.Embedding = nn.Embedding(config.vocab_size, config.n_embd) + _blocks: list[Block] = [Block(config) for _ in range(config.n_layer)] + self.h_torch: nn.ModuleList = nn.ModuleList(_blocks) + # Keep a typed Python list view for static type checking/iteration + self.h: list[Block] = _blocks + self.rms_f: LlamaRMSNorm = LlamaRMSNorm(config.n_embd) + self.transformer: nn.ModuleDict = nn.ModuleDict( + 
{ + "wte": self.wte, + "h": self.h_torch, + "rms_f": self.rms_f, + } ) self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) self.lm_head.LLMC_SKIP_INIT = True # type: ignore - self.transformer.wte.weight = self.lm_head.weight + # Tie embeddings and lm_head weights + self.wte.weight = self.lm_head.weight # type: ignore[reportAttributeAccessIssue] self.init_rng = torch.Generator() self.init_rng.manual_seed(42) self.apply(self._init_weights) - def _init_weights(self, module: nn.Module): + def _init_weights(self, module: nn.Module) -> None: if isinstance(module, nn.Linear): std = ( 0.02 @@ -317,18 +328,18 @@ def forward( self, idx: Float[Tensor, "batch pos"], targets: Float[Tensor, "batch pos vocab"] | None = None, - return_logits=True, - ) -> tuple[Float[Tensor, "batch pos"] | None, Float[Tensor, ""] | None]: + return_logits: bool = True, + ) -> tuple[Float[Tensor, "batch pos vocab"] | None, Float[Tensor, ""] | None]: device = idx.device b, t = idx.size() assert t <= self.config.block_size, ( f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" ) - tok_emb = self.transformer.wte(idx) + tok_emb = self.wte(idx) x = tok_emb - for block in self.transformer.h: + for block in self.h: x = block(x) - x = self.transformer.rms_f(x) + x = self.rms_f(x) logits = self.lm_head(x) loss = None if targets is not None: @@ -382,7 +393,7 @@ def from_pretrained( model.load_state_dict(state_dict, strict=False) # Regenerate rotary_sin and rotary_cos for each attention layer - for layer_idx, block in enumerate(model.transformer.h): + for block in model.h: attn = block.attn sin, cos = attn.calculate_sin_cos_rotary( rotary_dim=attn.rotary_dim, diff --git a/simple_stories_train/train_llama.py b/simple_stories_train/train_llama.py index 11fb8e0..3f90c2a 100644 --- a/simple_stories_train/train_llama.py +++ b/simple_stories_train/train_llama.py @@ -47,7 +47,6 @@ PositiveInt, model_validator, ) -from torch import Tensor from torch.distributed 
import destroy_process_group, init_process_group from torch.nn.parallel import DistributedDataParallel as DDP @@ -265,6 +264,8 @@ def main(config_path_or_obj: Path | str | Config | None = None, **kwargs: Any) - model: nn.Module = DDP(model, device_ids=[ddp_local_rank]) raw_model = model.module if ddp else model # always contains the "raw" unwrapped model + assert isinstance(raw_model, Llama) + # init the optimizer optimizer = raw_model.configure_optimizers( weight_decay=config.weight_decay, @@ -377,9 +378,8 @@ def get_lr(it: int) -> float: optimizer.zero_grad(set_to_none=True) # micro-batch loop where we do gradient accumulation to reach desired total batch size - lossf = Tensor([0.0]).to( - device - ) # for getting the mean loss (as simple float) over the accumulation steps + lossf = torch.tensor([0.0], device=device) + # for getting the mean loss (as simple float) over the accumulation steps t0 = time.time() for micro_step in range(grad_accum_steps): # fetch a batch @@ -432,10 +432,9 @@ def get_lr(it: int) -> float: torch.mps.synchronize() elif device == "cuda": torch.cuda.synchronize() + + t1 = time.time() if step % config.train_log_every == 0: - # time and print - t1 = time.time() - # the 0th iteration is often an outlier (much slower) => skip logging it tokens_per_second = grad_accum_steps * ddp_world_size * B * T / (t1 - t0) norm_str = f"norm {norm:.4f}" if norm is not None else "" print0( diff --git a/simple_stories_train/utils.py b/simple_stories_train/utils.py index b5ba2b7..e38cebe 100644 --- a/simple_stories_train/utils.py +++ b/simple_stories_train/utils.py @@ -39,8 +39,7 @@ def save_config(save_dir: Path, config_dict: dict[str, Any]) -> None: def save_model( save_dir: Path, model: nn.Module, step: int, wandb_project: str | None = None ) -> None: - # Get the underlying model if it's DDP-wrapped - state_dict = model.module.state_dict() if hasattr(model, "module") else model.state_dict() + state_dict = model.state_dict() model_file = save_dir / 
f"model_step_{step}.pt" torch.save(state_dict, model_file) From d0fa3fd683f1dd279633f7eace766d46bd2f3ea8 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Wed, 13 Aug 2025 14:17:36 +0000 Subject: [PATCH 03/13] Cast instead of isinstance --- simple_stories_train/train_llama.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/simple_stories_train/train_llama.py b/simple_stories_train/train_llama.py index 3f90c2a..20c4f49 100644 --- a/simple_stories_train/train_llama.py +++ b/simple_stories_train/train_llama.py @@ -27,7 +27,7 @@ from contextlib import nullcontext from datetime import datetime from pathlib import Path -from typing import Any, Literal, Self +from typing import Any, Literal, Self, cast import fire import numpy as np @@ -264,7 +264,7 @@ def main(config_path_or_obj: Path | str | Config | None = None, **kwargs: Any) - model: nn.Module = DDP(model, device_ids=[ddp_local_rank]) raw_model = model.module if ddp else model # always contains the "raw" unwrapped model - assert isinstance(raw_model, Llama) + raw_model = cast(Llama, raw_model) # init the optimizer optimizer = raw_model.configure_optimizers( From 7407ff3b205bdbda97329a854b0365a18702300b Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Wed, 13 Aug 2025 14:18:36 +0000 Subject: [PATCH 04/13] Update pinned torch version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 000a9aa..523d14a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ description = "Training of small LM models for SimpleStories" requires-python = ">=3.11" readme = "README.md" dependencies = [ - "torch<2.6.0", + "torch>=2.6.0", "torchvision", "pydantic", "wandb", From 5b7070372ce440d8962d6dc5b71710ba485439bd Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Wed, 13 Aug 2025 16:03:35 +0000 Subject: [PATCH 05/13] Factor out gpt2 and make general train.py --- simple_stories_train/models/gpt2.py | 360 ++++++ 
simple_stories_train/models/model_configs.py | 35 +- .../{train_llama.py => train.py} | 190 ++- simple_stories_train/train_config.yaml | 4 +- simple_stories_train/train_gpt2.py | 1042 ----------------- 5 files changed, 462 insertions(+), 1169 deletions(-) create mode 100644 simple_stories_train/models/gpt2.py rename simple_stories_train/{train_llama.py => train.py} (69%) delete mode 100644 simple_stories_train/train_gpt2.py diff --git a/simple_stories_train/models/gpt2.py b/simple_stories_train/models/gpt2.py new file mode 100644 index 0000000..e4b5e3f --- /dev/null +++ b/simple_stories_train/models/gpt2.py @@ -0,0 +1,360 @@ +import inspect +import math +from typing import Any +from typing import cast as _cast + +import torch +import torch.nn as nn +from jaxtyping import Float +from pydantic import BaseModel, ConfigDict +from torch import Tensor +from torch.distributed.optim import ZeroRedundancyOptimizer +from torch.nn import functional as F + +from simple_stories_train.utils import print0 + + +class GPT2Config(BaseModel): + model_config = ConfigDict(extra="forbid", frozen=True) + block_size: int = 1024 + vocab_size: int = 50257 + n_layer: int = 12 + n_head: int = 12 + n_embd: int = 768 + flash_attention: bool = True + + +class NewGELU(nn.Module): + def forward(self, input: Float[Tensor, "... dim"]) -> Float[Tensor, "... 
dim"]: + return ( + 0.5 + * input + * ( + 1.0 + + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))) + ) + ) + + +class CausalSelfAttention(nn.Module): + bias: Tensor + + def __init__(self, config: GPT2Config): + super().__init__() + assert config.n_embd % config.n_head == 0 + self.n_head = config.n_head + self.n_embd = config.n_embd + self.flash_attention = config.flash_attention + # key, query, value projections for all heads, but in a batch + self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd) + # output projection + self.c_proj = nn.Linear(config.n_embd, config.n_embd) + self.c_proj.LLMC_RESIDUAL_SCALE_FLAG = True # type: ignore + # not really a 'bias', more of a mask, but following the OpenAI/HF naming though + self.register_buffer( + "bias", + torch.tril(torch.ones(config.block_size, config.block_size)).view( + 1, 1, config.block_size, config.block_size + ), + persistent=False, + ) + + def forward( + self, + x: Float[Tensor, "batch pos d_model"], + ) -> Float[Tensor, "batch pos d_model"]: + B, T, C = x.size() + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + qkv = self.c_attn(x) + q, k, v = qkv.split(self.n_embd, dim=2) + k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) + q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) + v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) + if self.flash_attention: + # use PyTorch SDPA + y = F.scaled_dot_product_attention( + q, + k, + v, + is_causal=True, + ) + else: + # manual implementation of attention + # this materializes the large (T,T) matrix for all the queries and keys + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf")) + att = F.softmax(att, dim=-1) + y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) + y = ( + y.transpose(1, 2).contiguous().view(B, T, C) + ) # re-assemble all head outputs side 
by side + y = self.c_proj(y) + return y + + +class MLP(nn.Module): + def __init__(self, config: GPT2Config): + super().__init__() + self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd) + self.gelu = NewGELU() + self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd) + self.c_proj.LLMC_RESIDUAL_SCALE_FLAG = True # type: ignore + + def forward(self, x: Float[Tensor, "... dim"]) -> Float[Tensor, "... dim"]: + x = self.c_fc(x) + x = self.gelu(x) + x = self.c_proj(x) + return x + + +class Block(nn.Module): + def __init__(self, config: GPT2Config): + super().__init__() + self.ln_1 = nn.LayerNorm(config.n_embd) + self.attn = CausalSelfAttention(config) + self.ln_2 = nn.LayerNorm(config.n_embd) + self.mlp = MLP(config) + + def forward( + self, + x: Float[Tensor, "batch pos d_model"], + ) -> Float[Tensor, "batch pos d_model"]: + x = x + self.attn(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +class GPT2(nn.Module): + def __init__(self, config: GPT2Config): + super().__init__() + self.config = config + + self.wte: nn.Embedding = nn.Embedding(config.vocab_size, config.n_embd) + self.wpe: nn.Embedding = nn.Embedding(config.block_size, config.n_embd) + _blocks: list[Block] = [Block(config) for _ in range(config.n_layer)] + self.h_torch: nn.ModuleList = nn.ModuleList(_blocks) + self.h: list[Block] = _blocks + self.ln_f: nn.LayerNorm = nn.LayerNorm(config.n_embd) + self.transformer: nn.ModuleDict = nn.ModuleDict( + { + "wte": self.wte, + "wpe": self.wpe, + "h": self.h_torch, + "ln_f": self.ln_f, + } + ) + self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + self.lm_head.LLMC_SKIP_INIT = True # type: ignore + self.wte.weight = self.lm_head.weight # type: ignore[reportAttributeAccessIssue] + + # init all weights, use a torch rng object to be very careful + self.init_rng = torch.Generator() + self.init_rng.manual_seed(42) + self.apply(self._init_weights) + + def _init_weights(self, module: nn.Module) -> None: + if isinstance(module, 
nn.Linear): + std = ( + 0.02 + if not hasattr(module, "LLMC_RESIDUAL_SCALE_FLAG") + else 0.02 / math.sqrt(2 * self.config.n_layer) + ) + if not hasattr(module, "LLMC_SKIP_INIT"): + torch.nn.init.normal_(module.weight, mean=0.0, std=std, generator=self.init_rng) + if getattr(module, "bias", None) is not None: + torch.nn.init.zeros_(module.bias) # type: ignore[arg-type] + elif isinstance(module, nn.Embedding): + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02, generator=self.init_rng) + + def forward( + self, + idx: Float[Tensor, "batch pos"], + targets: Float[Tensor, "batch pos vocab"] | None = None, + return_logits: bool = True, + ) -> tuple[ + Float[Tensor, "batch pos vocab"] | None, + Float[Tensor, ""] | None, + ]: + device = idx.device + b, t = idx.size() + assert t <= self.config.block_size, ( + f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" + ) + pos = torch.arange(0, t, dtype=torch.long, device=device) + + tok_emb = self.wte(idx) # (b, t, n_embd) + pos_emb = self.wpe(pos) # (t, n_embd) + x = tok_emb + pos_emb + + for block in self.h: + x = block(x) + x = self.ln_f(x) + + logits: Tensor = self.lm_head(x) + loss: Tensor | None + if targets is not None: + loss = F.cross_entropy( + logits.view(-1, logits.size(-1)), + targets.view(-1), + ignore_index=-1, + ) + else: + loss = None + + out_logits: Tensor | None = logits + if not return_logits: + out_logits = None + + return out_logits, loss + + @classmethod + def from_pretrained(cls, model_type: str) -> "GPT2": + """Loads pretrained GPT-2 model weights from Hugging Face.""" + assert model_type in {"gpt2", "gpt2-medium", "gpt2-large", "gpt2-xl"} + from transformers import GPT2LMHeadModel # type: ignore + + print0(f"loading weights from pretrained gpt: {model_type}") + config_args = { + "gpt2": dict(n_layer=12, n_head=12, n_embd=768), + "gpt2-medium": dict(n_layer=24, n_head=16, n_embd=1024), + "gpt2-large": dict(n_layer=36, n_head=20, n_embd=1280), + "gpt2-xl": 
dict(n_layer=48, n_head=25, n_embd=1600), + }[model_type] + config_args["vocab_size"] = 50257 + config_args["block_size"] = 1024 + config = GPT2Config(**_cast(dict[str, Any], config_args)) + model = GPT2(config) + + sd = model.state_dict() + sd_keys = [k for k in sd if not k.endswith(".attn.bias")] # discard this mask / buffer + + model_hf = GPT2LMHeadModel.from_pretrained(model_type) + sd_hf = model_hf.state_dict() + sd_keys_hf = [ + k for k in sd_hf if not (k.endswith(".attn.masked_bias") or k.endswith(".attn.bias")) + ] + transposed = [ + "attn.c_attn.weight", + "attn.c_proj.weight", + "mlp.c_fc.weight", + "mlp.c_proj.weight", + ] + # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear + # this means that we have to transpose these weights when we import them + assert len(sd_keys_hf) == len(sd_keys), ( + f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}" + ) + for k in sd_keys_hf: + if any(k.endswith(w) for w in transposed): + assert sd_hf[k].shape[::-1] == sd[k].shape + with torch.no_grad(): + sd[k].copy_(sd_hf[k].t()) + else: + assert sd_hf[k].shape == sd[k].shape + with torch.no_grad(): + sd[k].copy_(sd_hf[k]) + + return model + + def configure_optimizers( + self, + weight_decay: float, + learning_rate: float, + betas: tuple[float, float], + device_type: str, + zero_stage: int, + ) -> torch.optim.Optimizer: + # start with all of the candidate parameters + param_dict = {pn: p for pn, p in self.named_parameters()} + # filter out those that do not require grad + param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad} + # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no. 
+ decay_params = [p for _, p in param_dict.items() if p.dim() >= 2] + nodecay_params = [p for _, p in param_dict.items() if p.dim() < 2] + optim_groups = [ + {"params": decay_params, "weight_decay": weight_decay}, + {"params": nodecay_params, "weight_decay": 0.0}, + ] + num_decay_params = sum(p.numel() for p in decay_params) + num_nodecay_params = sum(p.numel() for p in nodecay_params) + print0( + f"num decayed parameter tensors: {len(decay_params)}, " + f"with {num_decay_params:,} parameters" + ) + print0( + f"num non-decayed parameter tensors: {len(nodecay_params)}, " + f"with {num_nodecay_params:,} parameters" + ) + # Create AdamW optimizer and use the fused version if it is available + fused_available = "fused" in inspect.signature(torch.optim.AdamW).parameters + use_fused = fused_available and device_type == "cuda" + print0(f"using fused AdamW: {use_fused}") + if zero_stage == 1: + print0("using ZeroRedundancyOptimizer") + optim_group = optim_groups[0] + optimizer: torch.optim.Optimizer = ZeroRedundancyOptimizer( # type: ignore[assignment] + **optim_group, # type: ignore[arg-type] + optimizer_class=torch.optim.AdamW, + lr=learning_rate, + betas=betas, + fused=use_fused, + ) + optimizer.add_param_group(optim_groups[1]) # type: ignore[arg-type] + else: + print0("using regular AdamW") + optimizer = torch.optim.AdamW( + optim_groups, lr=learning_rate, betas=betas, fused=use_fused + ) + return optimizer + + @torch.no_grad() + def generate( + self, + idx: Float[Tensor, "... pos"], + max_new_tokens: int, + temperature: float = 1.0, + top_k: int | None = None, + eos_token_id: int | None = None, + ) -> Float[Tensor, "... 
pos"]: + # Keep track of whether input was 1D and ensure input has batch dimension + is_1d = idx.dim() == 1 + if is_1d: + idx = idx.unsqueeze(0) + + batch_size = idx.size(0) + not_completed = torch.ones(batch_size, dtype=torch.bool, device=idx.device) + + for _ in range(max_new_tokens): + if not not_completed.any(): + break + + idx_cond = ( + idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size :] + ) + logits, _ = self(idx_cond) + assert logits is not None + logits = logits[:, -1, :] + if temperature > 0: + logits = logits / temperature + if top_k is not None: + v, _ = torch.topk(logits, min(top_k, logits.size(-1))) + logits[logits < v[:, [-1]]] = -float("Inf") + probs = F.softmax(logits, dim=-1) + else: + probs = torch.zeros_like(logits) + probs.scatter_(1, logits.argmax(dim=-1, keepdim=True), 1.0) + idx_next = torch.multinomial(probs, num_samples=1) + + if eos_token_id is not None: + not_completed = not_completed & (idx_next[:, -1] != eos_token_id) + update_mask = not_completed.unsqueeze(-1) + idx_next = torch.where( + update_mask, idx_next, torch.full_like(idx_next, eos_token_id) + ) + + idx = torch.cat((idx, idx_next), dim=1) + + if is_1d: + idx = idx.squeeze(0) + + return idx diff --git a/simple_stories_train/models/model_configs.py b/simple_stories_train/models/model_configs.py index c577837..66dd480 100644 --- a/simple_stories_train/models/model_configs.py +++ b/simple_stories_train/models/model_configs.py @@ -1,9 +1,11 @@ +from simple_stories_train.models.gpt2 import GPT2Config from simple_stories_train.models.llama import LlamaConfig MODEL_CONFIGS = { - "d2": LlamaConfig( + # Llama debug/dev sizes + "llama-d2": LlamaConfig( block_size=1024, - vocab_size=50257, # TODO: Make this depend on the tokenizer vocab size + vocab_size=50257, n_layer=2, n_head=2, n_embd=12, @@ -11,7 +13,7 @@ n_key_value_heads=2 // 2, flash_attention=True, ), - "d12": LlamaConfig( + "llama-d12": LlamaConfig( block_size=1024, vocab_size=50257, n_layer=12, 
@@ -21,7 +23,7 @@ n_key_value_heads=12 // 4, flash_attention=True, ), - "d24": LlamaConfig( + "llama-d24": LlamaConfig( block_size=1024, vocab_size=50257, n_layer=24, @@ -31,7 +33,7 @@ n_key_value_heads=16 // 4, flash_attention=True, ), - "d36": LlamaConfig( + "llama-d36": LlamaConfig( block_size=1024, vocab_size=50257, n_layer=36, @@ -41,7 +43,7 @@ n_key_value_heads=20 // 4, flash_attention=True, ), - "d48": LlamaConfig( + "llama-d48": LlamaConfig( block_size=1024, vocab_size=50257, n_layer=48, @@ -51,8 +53,8 @@ n_key_value_heads=25 // 4, flash_attention=True, ), - # SimpleStories Model Configs - "1.25M": LlamaConfig( + # SimpleStories Llama presets + "llama-1.25M": LlamaConfig( block_size=512, vocab_size=4019, n_layer=4, @@ -64,7 +66,7 @@ n_key_value_heads=2, flash_attention=True, ), - "5M": LlamaConfig( + "llama-5M": LlamaConfig( block_size=512, vocab_size=4019, n_layer=6, @@ -76,7 +78,7 @@ n_key_value_heads=2, flash_attention=True, ), - "11M": LlamaConfig( + "llama-11M": LlamaConfig( block_size=512, vocab_size=4019, n_layer=6, @@ -88,7 +90,7 @@ n_key_value_heads=2, flash_attention=True, ), - "30M": LlamaConfig( + "llama-30M": LlamaConfig( block_size=512, vocab_size=4019, n_layer=10, @@ -100,7 +102,7 @@ n_key_value_heads=2, flash_attention=True, ), - "35M": LlamaConfig( + "llama-35M": LlamaConfig( block_size=512, vocab_size=4019, n_layer=12, @@ -112,4 +114,13 @@ n_key_value_heads=2, flash_attention=True, ), + # GPT-2 presets + "gpt2-1.25M": GPT2Config( + block_size=512, + vocab_size=4019, + n_layer=4, + n_head=4, + n_embd=128, + flash_attention=True, + ), } diff --git a/simple_stories_train/train_llama.py b/simple_stories_train/train.py similarity index 69% rename from simple_stories_train/train_llama.py rename to simple_stories_train/train.py index 20c4f49..6cbc279 100644 --- a/simple_stories_train/train_llama.py +++ b/simple_stories_train/train.py @@ -1,23 +1,18 @@ """ -Training script. Currently only supports models with the Llama architecture. 
+Unified training script for multiple model families (Llama, GPT-2). Usage: +```bash +python -m simple_stories_train.train [PATH/TO/CONFIG.yaml] [--key1 value1 --key2 value2 ...] ``` -python train_llama.py [PATH/TO/CONFIG.yaml] [--key1 value1 --key2 value2 ...] -``` -where -- `PATH/TO/CONFIG.yaml` contains the training config. If no path is provided, a default config -will be used. -- `--key1 value1 --key2 value2 ...` override values in the config. Note that if you wish to update a -nested value, you must use dotted notation (e.g. `--train_dataset_config.name my_dataset`). - -If running on CPU, you may need to set `--compile=False`. +- PATH/TO/CONFIG.yaml contains the training config. If no path is provided, a default config will be used. +- Override values with dotted notation for nested keys (e.g., --train_dataset_config.name my_dataset). -To run on multiple GPUs, use -``` -torchrun --standalone --nproc_per_node=N train_llama.py ... +To run on multiple GPUs: +```bash +torchrun --standalone --nproc_per_node=N -m simple_stories_train.train ... ``` -where `N` is the number of GPUs to use. +where N is the number of GPUs. """ import math @@ -51,6 +46,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP from simple_stories_train.dataloaders import DatasetConfig, create_data_loader +from simple_stories_train.models.gpt2 import GPT2 from simple_stories_train.models.llama import Llama from simple_stories_train.models.model_configs import MODEL_CONFIGS from simple_stories_train.utils import ( @@ -64,6 +60,11 @@ save_model, ) +FAMILY_TO_MODEL: dict[str, type[nn.Module]] = { + "llama": Llama, + "gpt2": GPT2, +} + class Config(BaseModel): model_config = ConfigDict(extra="forbid", frozen=True) @@ -99,18 +100,15 @@ class Config(BaseModel): output_dir: Path = Field( REPO_ROOT / "out", description="Directory to write logs and checkpoints" ) - model_name: str = Field( - "d2", - description=f"Name of the model to train (one of {tuple(MODEL_CONFIGS.keys())}). 
" - "Currently only supports models with the Llama architecture.", + model_id: str = Field( + "llama-d2", + description=f"Model to train (one of {tuple(MODEL_CONFIGS.keys())}).", ) batch_size: PositiveInt = Field(4, description="Batch size") total_batch_size: PositiveInt = Field( 4096, description="Number of batch_size * sequence_length before updating gradients" - ) # TODO: Rename/reconfigure - num_iterations: PositiveInt = Field( - 50, description="Number of gradient accumulation steps" - ) # TODO: Allow for None and deplete the (streaming) dataset + ) + num_iterations: PositiveInt = Field(50, description="Number of gradient accumulation steps") inference_only: bool = Field(False, description="If True, don't update gradients") learning_rate: PositiveFloat = Field(1e-4, description="Learning rate") warmup_iters: NonNegativeInt = Field( @@ -143,9 +141,8 @@ class Config(BaseModel): @model_validator(mode="after") def validate_model(self) -> Self: - # Check that the model name is valid - if self.model_name not in MODEL_CONFIGS: - raise ValueError(f"Model {self.model_name} not in {tuple(MODEL_CONFIGS.keys())}") + if self.model_id not in MODEL_CONFIGS: + raise ValueError(f"model_id {self.model_id} not in {tuple(MODEL_CONFIGS.keys())}") return self @@ -158,9 +155,8 @@ def main(config_path_or_obj: Path | str | Config | None = None, **kwargs: Any) - T = config.train_dataset_config.n_ctx # set up DDP (distributed data parallel). torchrun sets this env variable - ddp = int(os.environ.get("RANK", -1)) != -1 # is this a ddp run? 
+ ddp = int(os.environ.get("RANK", -1)) != -1 if ddp: - # use of DDP atm demands CUDA, we set the device appropriately according to rank assert torch.cuda.is_available(), "for now i think we need CUDA for DDP" init_process_group(backend="nccl") ddp_rank = int(os.environ["RANK"]) @@ -168,7 +164,7 @@ def main(config_path_or_obj: Path | str | Config | None = None, **kwargs: Any) - ddp_world_size = int(os.environ["WORLD_SIZE"]) device = f"cuda:{ddp_local_rank}" torch.cuda.set_device(device) - master_process = ddp_rank == 0 # this process will do logging, checkpointing etc. + master_process = ddp_rank == 0 zero_stage = config.zero_stage else: ddp_rank = 0 @@ -176,12 +172,9 @@ def main(config_path_or_obj: Path | str | Config | None = None, **kwargs: Any) - zero_stage = 0 ddp_world_size = 1 master_process = True - # select the device if config.device: - # provided explicitly by the user device = config.device else: - # attempt to autodetect the device device = "cpu" if torch.cuda.is_available(): device = "cuda" @@ -190,7 +183,7 @@ def main(config_path_or_obj: Path | str | Config | None = None, **kwargs: Any) - print(f"using device: {device}") device_type = "cuda" if "cuda" in device else "cpu" - # calculate gradient accumulation from the desired total batch size and the current run configuration + # gradient accumulation tokens_per_fwdbwd = B * T * ddp_world_size assert config.total_batch_size % tokens_per_fwdbwd == 0, ( f"Mismatch between batch size and tokens {config.total_batch_size} % {tokens_per_fwdbwd} != 0" @@ -199,10 +192,12 @@ def main(config_path_or_obj: Path | str | Config | None = None, **kwargs: Any) - print0(f"total desired batch size: {config.total_batch_size}") print0(f"=> calculated gradient accumulation steps: {grad_accum_steps}") - # set up a context manager following the desired dtype and device - ptdtype = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16}[ - config.dtype - ] + # dtype context + ptdtype = { + "float32": 
torch.float32, + "bfloat16": torch.bfloat16, + "float16": torch.float16, + }[config.dtype] ctx = ( torch.amp.autocast(device_type=device_type, dtype=ptdtype) # type: ignore if device_type == "cuda" @@ -214,26 +209,30 @@ def main(config_path_or_obj: Path | str | Config | None = None, **kwargs: Any) - if torch.cuda.is_available(): torch.cuda.manual_seed(42) - # set the torch precision mode to use TensorFloat32 (TF32) for matmuls - # docs https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html + # TF32 if config.tensorcores: torch.set_float32_matmul_precision("high") - # init (and write) the tokenizer - # enc: tiktoken.core.Encoding = tiktoken.get_encoding("gpt2") - - model_config = MODEL_CONFIGS[config.model_name] - model = Llama(model_config) + # Instantiate model + model_config = MODEL_CONFIGS[config.model_id] + family = config.model_id.split("-", 1)[0] + if family not in FAMILY_TO_MODEL: + raise ValueError(f"Unknown model family {family} from model_id {config.model_id}") + model_ctor = FAMILY_TO_MODEL[family] + model: nn.Module = model_ctor(model_config) model.train() model.to(device) if config.compile: if device_type == "cpu": - warnings.warn("compile may not be compatible with cpu, use `--compile=False` if issues") + warnings.warn( + "compile may not be compatible with cpu, use `--compile=False` if issues", + stacklevel=1, + ) if hasattr(torch_inductor_config, "coordinate_descent_tuning"): - torch_inductor_config.coordinate_descent_tuning = True # suggested by @Chillee + torch_inductor_config.coordinate_descent_tuning = True print0("compiling the model...") - model: nn.Module = torch.compile(model) # type: ignore[reportArgumentType] + model = cast(nn.Module, torch.compile(model)) # type: ignore[reportArgumentType] train_loader, train_tokenizer = create_data_loader( dataset_config=config.train_dataset_config, @@ -254,20 +253,17 @@ def main(config_path_or_obj: Path | str | Config | None = None, **kwargs: Any) - 
ddp_world_size=ddp_world_size, ) - # ------------------------------------------------------------------------- - # main training loop + # logging if config.wandb_project is not None and master_process: wandb.init(project=config.wandb_project, config=config.model_dump(mode="json")) - # here we wrap model into DDP container + # DDP wrap if ddp: - model: nn.Module = DDP(model, device_ids=[ddp_local_rank]) - raw_model = model.module if ddp else model # always contains the "raw" unwrapped model - - raw_model = cast(Llama, raw_model) + model = DDP(model, device_ids=[ddp_local_rank]) + raw_model: nn.Module = model.module if ddp else model # type: ignore[attr-defined] - # init the optimizer - optimizer = raw_model.configure_optimizers( + # optimizer + optimizer = raw_model.configure_optimizers( # type: ignore[attr-defined] weight_decay=config.weight_decay, learning_rate=config.learning_rate, betas=(0.9, 0.95), @@ -275,22 +271,19 @@ def main(config_path_or_obj: Path | str | Config | None = None, **kwargs: Any) - zero_stage=zero_stage, ) - # learning rate decay scheduler (cosine with warmup) + # lr schedule def get_lr(it: int) -> float: min_lr = config.learning_rate * config.learning_rate_decay_frac - # 1) linear warmup for warmup_iters steps if it < config.warmup_iters: return config.learning_rate * (it + 1) / config.warmup_iters - # 2) if it > lr_decay_iters, return min learning rate if it > config.num_iterations: return min_lr - # 3) in between, use cosine decay down to min learning rate decay_ratio = (it - config.warmup_iters) / (config.num_iterations - config.warmup_iters) assert 0 <= decay_ratio <= 1 - coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff starts at 1 and goes to 0 + coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) return min_lr + coeff * (config.learning_rate - min_lr) - # create the logging directory if it does not exist + # IO dirs logfile = None checkpoints_dir = None output_dir = None @@ -299,11 +292,8 @@ def get_lr(it: int) -> float: 
output_dir = Path(config.output_dir) / f"{timestamp}" output_dir.mkdir(parents=True, exist_ok=True) logfile = output_dir / "main.log" - # create the log file "main.log" inside it, and wipe it clean with open(logfile, "w") as f: pass - - # set our checkpoints directory and save off the initilized model checkpoints_dir = output_dir / "checkpoints" checkpoints_dir.mkdir(parents=True, exist_ok=True) save_config(checkpoints_dir, config_dict=config.model_dump(mode="json")) @@ -312,12 +302,12 @@ def get_lr(it: int) -> float: if device == "cuda": torch.cuda.reset_peak_memory_stats() - timings = [] - generations = [] + timings: list[float] = [] + generations: list[list[Any]] = [] for step in range(1, config.num_iterations + 1): last_step = step == config.num_iterations - # once in a while evaluate the validation dataset + # validation if config.val_loss_every > 0 and (step % config.val_loss_every == 0 or last_step): model.eval() val_loader_iter = iter(val_loader) @@ -327,107 +317,85 @@ def get_lr(it: int) -> float: try: bat = next(val_loader_iter)["input_ids"].to(torch.int) except StopIteration: - # No more batches, end the loop break - x = bat.view(B, T)[:, :-1] # inputs - y = bat.view(B, T)[:, 1:] # targets + x = bat.view(B, T)[:, :-1] + y = bat.view(B, T)[:, 1:] x, y = x.to(device), y.to(device) _, loss = model(x, y, return_logits=False) - val_loss += loss.item() + val_loss += float(loss.item()) if loss is not None else 0.0 val_loss /= config.val_max_steps - # log to wandb if config.wandb_project is not None and master_process: log_metrics(step, {"val_loss": val_loss}) - # log to console and to file print0(f"val loss {val_loss}") if master_process and logfile is not None: with open(logfile, "a") as f: - f.write("s:%d tel:%f\n" % (step, val_loss)) + f.write(f"s:{step} tel:{val_loss}\n") - # once in a while perform model inference on the master process + # sample generations if ( config.sample_every > 0 and (step % config.sample_every == 0 or last_step) ) and 
master_process: model.eval() - # before we end, let's also do one round of inference - # we'll kick off the generation with "<|endoftext|>", which designates the start of a - # new sequence start_ids = [train_tokenizer.token_to_id("[EOS]")] xg = torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...] max_new_tokens = 32 temperature = 1.0 top_k = 40 - yg = raw_model.generate(xg, max_new_tokens, temperature=temperature, top_k=top_k) + yg = cast(Any, raw_model).generate( + xg, max_new_tokens, temperature=temperature, top_k=top_k + ) print0("---------------") print0(train_tokenizer.decode(yg[0].tolist())) print0("---------------") - # log to wandb if config.wandb_project is not None and master_process: generations.append([step, train_tokenizer.decode(yg[0].tolist())]) log_generations(step, generations) - # bit confusing: we want to make sure to eval and sample on 0th iteration - # but also after the very last iteration. so we loop for step <= num_iterations - # instead of just < num_iterations (one extra due to <=), only to do - # the validation/sampling one last time, and then we break right here as we're done. 
if last_step: break - # --------------- TRAINING SECTION BEGIN ----------------- + # training model.train() optimizer.zero_grad(set_to_none=True) - - # micro-batch loop where we do gradient accumulation to reach desired total batch size lossf = torch.tensor([0.0], device=device) - # for getting the mean loss (as simple float) over the accumulation steps t0 = time.time() for micro_step in range(grad_accum_steps): - # fetch a batch try: bat = next(train_iter)["input_ids"].to(torch.int) except StopIteration: - # reset the train_loader print0("Depleted train_loader, resetting for next epoch") train_iter = iter(train_loader) bat = next(train_iter)["input_ids"].to(torch.int) - x = bat.view(B, T)[:, :-1] # inputs - y = bat.view(B, T)[:, 1:] # targets + x = bat.view(B, T)[:, :-1] + y = bat.view(B, T)[:, 1:] x, y = x.to(device), y.to(device) if ddp: # we want only the last micro-step to sync grads in a DDP model # the official way to do this is with model.no_sync(), but that is a # context manager that bloats the code, so we just toggle this variable - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 # type: ignore - # forward pass + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 # type: ignore[attr-defined] with ctx: _, loss = model(x, y, return_logits=False) # we have to scale the loss to account for gradient accumulation, # because the gradients just add on each successive backward(). 
# addition of gradients corresponds to a SUM in the objective, but # instead of a SUM we want MEAN, so we scale the loss here - loss = loss / grad_accum_steps - lossf += loss.detach() # keep track of the mean loss - - # backward pass + loss = loss / grad_accum_steps # type: ignore[operator] + lossf += loss.detach() # type: ignore[operator] if not config.inference_only: - loss.backward() + loss.backward() # type: ignore[arg-type] if ddp: dist.all_reduce(lossf, op=dist.ReduceOp.AVG) - lossf = lossf.item() + lossf_value = float(lossf.item()) norm = None if config.grad_clip is not None: norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip) - # determine and set the learning rate for this iteration lr = get_lr(step) for param_group in optimizer.param_groups: param_group["lr"] = lr - # step the optimizer optimizer.step() - # --------------- TRAINING SECTION END ------------------- - # everything that follows now is just diagnostics, prints, logging, etc. - # wait on the CPU for all device work to end so we get accurate per-iteration timings below if device == "mps": torch.mps.synchronize() elif device == "cuda": @@ -438,16 +406,14 @@ def get_lr(it: int) -> float: tokens_per_second = grad_accum_steps * ddp_world_size * B * T / (t1 - t0) norm_str = f"norm {norm:.4f}" if norm is not None else "" print0( - f"step {step:4d}/{config.num_iterations} | train loss {lossf:.6f} | {norm_str} | " + f"step {step:4d}/{config.num_iterations} | train loss {lossf_value:.6f} | {norm_str} | " f"lr {lr:.2e} | ({(t1 - t0) * 1000:.2f} ms | {tokens_per_second:.0f} tok/s)" ) - # log to wandb if config.wandb_project is not None and master_process: - log_metrics(step, {"train_loss": lossf, "lr": lr}) - # log to logile + log_metrics(step, {"train_loss": lossf_value, "lr": lr}) if master_process and logfile is not None: with open(logfile, "a") as f: - f.write("step:%d loss:%f\n" % (step, lossf)) + f.write(f"step:{step} loss:{lossf_value}\n") if ( checkpoints_dir is not None @@ 
-459,17 +425,13 @@ def get_lr(it: int) -> float: ): save_model(checkpoints_dir, raw_model, step=step, wandb_project=config.wandb_project) - # keep track of smooth timings, last 20 iterations if step > 1 and (step > config.num_iterations - 20): timings.append(t1 - t0) - # print the average of the last 20 timings, to get something smooth-ish timings = timings[-20:] print0(f"final {len(timings)} iters avg: {np.mean(timings) * 1000:.3f}ms") print0(f"peak memory consumption: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB") - # ------------------------------------------------------------------------- - # clean up nice if ddp: destroy_process_group() diff --git a/simple_stories_train/train_config.yaml b/simple_stories_train/train_config.yaml index a30ef75..9fc9ad3 100644 --- a/simple_stories_train/train_config.yaml +++ b/simple_stories_train/train_config.yaml @@ -1,3 +1,4 @@ +# wandb_project: spd train_dataset_config: name: SimpleStories/SimpleStories is_tokenized: false @@ -16,7 +17,8 @@ val_dataset_config: n_ctx: 512 seed: 0 column_name: story -model_name: 1.25M +# model_id: llama-1.25M +model_id: gpt2-1.25M # 1 GPU batch_size: 64 total_batch_size: 32768 # 64 * 512 diff --git a/simple_stories_train/train_gpt2.py b/simple_stories_train/train_gpt2.py deleted file mode 100644 index 4636f5b..0000000 --- a/simple_stories_train/train_gpt2.py +++ /dev/null @@ -1,1042 +0,0 @@ -# type: ignore -# TODO: add type hints -""" -Training code for use with the SimpleStories dataset and model suite. 
- -References: -1) the official GPT-2 TensorFlow implementation released by OpenAI: -https://github.com/openai/gpt-2/blob/master/src/model.py -2) huggingface/transformers PyTorch implementation: -https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py - -Example launches to only benchmark the speed of bfloat16 compiled GPU training: -1 GPU: -python train_gpt2.py --write_tensors=0 --num_iterations=50 --sequence_length=1024 --compile=1 --tensorcores=1 --dtype=bfloat16 -you can also turn on flash-attention by appending --flash=1 -4 GPU: -torchrun --standalone --nproc_per_node=4 train_gpt2.py --write_tensors=0 --num_iterations=50 --sequence_length=1024 --compile=1 --tensorcores=1 --dtype=bfloat16 - -This implementation is verbatim from -- llm.c, licensed under MIT ((c) 2024 Andrei Karpathy). - -MIT License: -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. 
-""" - -import glob -import inspect -import math -import os -import struct -from contextlib import nullcontext -from dataclasses import dataclass -from pathlib import Path - -import numpy as np -import torch -import torch._inductor.config as config -import torch.distributed as dist -import torch.nn as nn -from torch.distributed import destroy_process_group, init_process_group -from torch.distributed.optim import ZeroRedundancyOptimizer -from torch.nn import functional as F -from torch.nn.parallel import DistributedDataParallel as DDP - -from simple_stories_train.utils import is_checkpoint_step, print0, save_model_and_config - -# ----------------------------------------------------------------------------- -# PyTorch nn.Module definitions for the GPT-2 model - - -class NewGELU(nn.Module): - """Careful there are a few versions of GeLU, this one is the exact one used by OpenAI""" - - def forward(self, input): - return ( - 0.5 - * input - * ( - 1.0 - + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))) - ) - ) - - -# using a global to toggle flash-attention -FLASH = 0 - - -class CausalSelfAttention(nn.Module): - def __init__(self, config): - super().__init__() - assert config.n_embd % config.n_head == 0 - # key, query, value projections for all heads, but in a batch - self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd) - # output projection - self.c_proj = nn.Linear(config.n_embd, config.n_embd) - self.c_proj.LLMC_RESIDUAL_SCALE_FLAG = 1 - # regularization - self.n_head = config.n_head - self.n_embd = config.n_embd - # not really a 'bias', more of a mask, but following the OpenAI/HF naming though - self.register_buffer( - "bias", - torch.tril(torch.ones(config.block_size, config.block_size)).view( - 1, 1, config.block_size, config.block_size - ), - ) - - def forward(self, x): - B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) - # calculate query, key, values for all heads in batch and move head 
forward to be the batch dim - qkv = self.c_attn(x) - q, k, v = qkv.split(self.n_embd, dim=2) - k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) - q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) - v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) - if FLASH: - # flashattention - y = F.scaled_dot_product_attention(q, k, v, is_causal=True) - else: - # manual implementation of attention - # this materializes the large (T,T) matrix for all the queries and keys - att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) - att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf")) - att = F.softmax(att, dim=-1) - y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) - y = ( - y.transpose(1, 2).contiguous().view(B, T, C) - ) # re-assemble all head outputs side by side - # output projection - y = self.c_proj(y) - return y - - -class MLP(nn.Module): - def __init__(self, config): - super().__init__() - self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd) - self.gelu = NewGELU() - self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd) - self.c_proj.LLMC_RESIDUAL_SCALE_FLAG = 1 - - def forward(self, x): - x = self.c_fc(x) - x = self.gelu(x) - x = self.c_proj(x) - return x - - -class Block(nn.Module): - def __init__(self, config): - super().__init__() - self.ln_1 = nn.LayerNorm(config.n_embd) - self.attn = CausalSelfAttention(config) - self.ln_2 = nn.LayerNorm(config.n_embd) - self.mlp = MLP(config) - - def forward(self, x): - x = x + self.attn(self.ln_1(x)) - x = x + self.mlp(self.ln_2(x)) - return x - - -# ----------------------------------------------------------------------------- -# The main GPT-2 model - - -@dataclass -class GPTConfig: - block_size: int = 1024 - vocab_size: int = 50257 - n_layer: int = 12 - n_head: int = 12 - n_embd: int = 768 - - -class GPT(nn.Module): - def __init__(self, config): - super().__init__() - self.config = config - - 
self.transformer = nn.ModuleDict( - dict( - wte=nn.Embedding(config.vocab_size, config.n_embd), - wpe=nn.Embedding(config.block_size, config.n_embd), - h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]), - ln_f=nn.LayerNorm(config.n_embd), - ) - ) - self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) - self.lm_head.LLMC_SKIP_INIT = 1 # don't init this one, we will tie weights - self.transformer.wte.weight = ( - self.lm_head.weight - ) # https://paperswithcode.com/method/weight-tying - - # init all weights, use a torch rng object to be very careful - self.init_rng = torch.Generator() - self.init_rng.manual_seed(42) - self.apply(self._init_weights) - - def _init_weights(self, module): - if isinstance(module, nn.Linear): - # apply special scaled init to the residual projections, per GPT-2 paper - std = ( - 0.02 - if not hasattr(module, "LLMC_RESIDUAL_SCALE_FLAG") - else 0.02 / math.sqrt(2 * self.config.n_layer) - ) - # we want to skip initializing lm_head, which shares parameters with wte - # and wte was already initialized down below during the Embedding init - if not hasattr(module, "LLMC_SKIP_INIT"): - torch.nn.init.normal_(module.weight, mean=0.0, std=std, generator=self.init_rng) - if module.bias is not None: - torch.nn.init.zeros_(module.bias) - elif isinstance(module, nn.Embedding): - torch.nn.init.normal_(module.weight, mean=0.0, std=0.02, generator=self.init_rng) - - def forward(self, idx, targets=None, return_logits=True): - device = idx.device - b, t = idx.size() - assert t <= self.config.block_size, ( - f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" - ) - pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t) - - # forward the GPT model itself - tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) - pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd) - x = tok_emb + pos_emb - - for block in self.transformer.h: - x = 
block(x) - x = self.transformer.ln_f(x) - - if targets is not None: - # if we are given some desired targets also calculate the loss - logits = self.lm_head(x) - loss = F.cross_entropy( - logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1 - ) - else: - # inference-time mini-optimization: only forward the lm_head on the very last position - logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim - loss = None - - # there are performance reasons why not returning logits is prudent, if not needed - if not return_logits: - logits = None - - return logits, loss - - @classmethod - def from_pretrained(cls, model_type): - """Loads pretrained GPT-2 model weights from huggingface""" - assert model_type in {"gpt2", "gpt2-medium", "gpt2-large", "gpt2-xl"} - from transformers import GPT2LMHeadModel - - print("loading weights from pretrained gpt: %s" % model_type) - - # n_layer, n_head and n_embd are determined from model_type - config_args = { - "gpt2": dict(n_layer=12, n_head=12, n_embd=768), # 124M params - "gpt2-medium": dict(n_layer=24, n_head=16, n_embd=1024), # 350M params - "gpt2-large": dict(n_layer=36, n_head=20, n_embd=1280), # 774M params - "gpt2-xl": dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params - }[model_type] - config_args["vocab_size"] = 50257 # always 50257 for GPT model checkpoints - config_args["block_size"] = 1024 # always 1024 for GPT model checkpoints - # create a from-scratch initialized minGPT model - config = GPTConfig(**config_args) - model = GPT(config) - sd = model.state_dict() - sd_keys = sd.keys() - sd_keys = [ - k for k in sd_keys if not k.endswith(".attn.bias") - ] # discard this mask / buffer, not a param - - # init a huggingface/transformers model - model_hf = GPT2LMHeadModel.from_pretrained(model_type) - sd_hf = model_hf.state_dict() - - # copy while ensuring all of the parameters are aligned and match in names and shapes - sd_keys_hf = sd_hf.keys() - sd_keys_hf = [ - k for k in 
sd_keys_hf if not k.endswith(".attn.masked_bias") - ] # ignore these, just a buffer - sd_keys_hf = [ - k for k in sd_keys_hf if not k.endswith(".attn.bias") - ] # same, just the mask (buffer) - transposed = [ - "attn.c_attn.weight", - "attn.c_proj.weight", - "mlp.c_fc.weight", - "mlp.c_proj.weight", - ] - # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear - # this means that we have to transpose these weights when we import them - assert len(sd_keys_hf) == len(sd_keys), ( - f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}" - ) - for k in sd_keys_hf: - if any(k.endswith(w) for w in transposed): - # special treatment for the Conv1D weights we need to transpose - assert sd_hf[k].shape[::-1] == sd[k].shape - with torch.no_grad(): - sd[k].copy_(sd_hf[k].t()) - else: - # vanilla copy over the other parameters - assert sd_hf[k].shape == sd[k].shape - with torch.no_grad(): - sd[k].copy_(sd_hf[k]) - - return model - - def configure_optimizers(self, weight_decay, learning_rate, betas, device_type, zero_stage): - # start with all of the candidate parameters - param_dict = {pn: p for pn, p in self.named_parameters()} - # filter out those that do not require grad - param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad} - # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no. - # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't. 
- decay_params = [p for n, p in param_dict.items() if p.dim() >= 2] - nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2] - optim_groups = [ - {"params": decay_params, "weight_decay": weight_decay}, - {"params": nodecay_params, "weight_decay": 0.0}, - ] - num_decay_params = sum(p.numel() for p in decay_params) - num_nodecay_params = sum(p.numel() for p in nodecay_params) - print0( - f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters" - ) - print0( - f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters" - ) - # Create AdamW optimizer and use the fused version if it is available - fused_available = "fused" in inspect.signature(torch.optim.AdamW).parameters - use_fused = fused_available and device_type == "cuda" - print0(f"using fused AdamW: {use_fused}") - if zero_stage == 1: - print0("using ZeroRedundancyOptimizer") - optimizer = ZeroRedundancyOptimizer( - **optim_groups[0], - optimizer_class=torch.optim.AdamW, - lr=learning_rate, - betas=betas, - fused=use_fused, - ) - optimizer.add_param_group(optim_groups[1]) - else: - print0("using regular AdamW") - optimizer = torch.optim.AdamW( - optim_groups, lr=learning_rate, betas=betas, fused=use_fused - ) - return optimizer - - @torch.no_grad() - def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None): - """ - Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete - the sequence max_new_tokens times, feeding the predictions back into the model each time. - Most likely you'll want to make sure to be in model.eval() mode of operation for this. 
- """ - for _ in range(max_new_tokens): - # if the sequence context is growing too long we must crop it at block_size - idx_cond = ( - idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size :] - ) - # forward the model to get the logits for the index in the sequence - logits, _ = self(idx_cond) - # pluck the logits at the final step and scale by desired temperature - logits = logits[:, -1, :] / temperature - # optionally crop the logits to only the top k options - if top_k is not None: - v, _ = torch.topk(logits, min(top_k, logits.size(-1))) - logits[logits < v[:, [-1]]] = -float("Inf") - # apply softmax to convert logits to (normalized) probabilities - probs = F.softmax(logits, dim=-1) - # sample from the distribution - idx_next = torch.multinomial(probs, num_samples=1) - # append sampled index to the running sequence and continue - idx = torch.cat((idx, idx_next), dim=1) - - return idx - - -# ----------------------------------------------------------------------------- -# Our own simple Distributed Data Loader - - -def _peek_data_shard(filename): - # only reads the header, returns header data - with open(filename, "rb") as f: - # first read the header, which is 256 int32 integers (4 bytes each) - header = np.frombuffer(f.read(256 * 4), dtype=np.int32) - if header[0] != 20240520: - print("ERROR: magic number mismatch in the data .bin file!") - print("---> HINT: Are you passing in a correct file with --input_bin?") - print( - "---> HINT: Dataset encoding changed recently, re-run data prepro or refer again to README" - ) - print("---> HINT: For example re-run: `python dev/data/tinyshakespeare.py`, then re-try") - exit(1) - assert header[1] == 1, "unsupported version" - ntok = header[2] # number of tokens (claimed) - return ntok # for now just return the number of tokens - - -def _load_data_shard(filename): - with open(filename, "rb") as f: - # first read the header, which is 256 int32 integers (4 bytes each) - header = 
np.frombuffer(f.read(256 * 4), dtype=np.int32) - assert header[0] == 20240520, "magic number mismatch in the data .bin file" - assert header[1] == 1, "unsupported version" - ntok = header[2] # number of tokens (claimed) - # the rest of it are tokens, stored as uint16 - tokens = np.frombuffer(f.read(), dtype=np.uint16) - assert len(tokens) == ntok, "number of tokens read does not match header?" - return tokens - - -class DistributedDataLoader: - def __init__(self, filename_pattern, B, T, process_rank, num_processes): - self.process_rank = process_rank - self.num_processes = num_processes - self.B = B - self.T = T - - # glob files that match the pattern - self.files = sorted(glob.glob(filename_pattern)) - assert len(self.files) > 0, ( - f"did not find any files that match the pattern {filename_pattern}" - ) - - # load and validate all data shards, count number of tokens in total - ntok_total = 0 - for fname in self.files: - shard_ntok = _peek_data_shard(fname) - assert shard_ntok >= num_processes * B * T + 1 - ntok_total += shard_ntok - self.ntok_total = ntok_total - print0(f"DataLoader: total number of tokens: {ntok_total:,} across {len(self.files)} files") - - # kick things off - self.current_shard = None - self.reset() - - def reset(self): - # we're being a bit clever here: if we already had shard 0 loaded, - # then don't do the work to reload it, just reset the pointer - if self.current_shard != 0: - self.current_shard = 0 - self.tokens = _load_data_shard(self.files[self.current_shard]) - self.current_position = self.process_rank * self.B * self.T - - def advance(self): # advance to next data shard - self.current_shard = (self.current_shard + 1) % len(self.files) - self.current_position = self.process_rank * self.B * self.T - self.tokens = _load_data_shard(self.files[self.current_shard]) - - def next_batch(self): - B = self.B - T = self.T - buf = self.tokens[self.current_position : self.current_position + B * T + 1] - buf = torch.tensor(buf.astype(np.int32), 
dtype=torch.long) - x = (buf[:-1]).view(B, T) # inputs - y = (buf[1:]).view(B, T) # targets - # advance the start pointer in current shard - self.current_position += B * T * self.num_processes - # if loading the next batch would be out of bounds advance the shard - if self.current_position + (B * T * self.num_processes + 1) > len(self.tokens): - self.advance() - return x, y - - -# ----------------------------------------------------------------------------- -# Python -> C bridge utilities for saving params/grads/activations to .bin files - - -def write_fp32(tensor, file): - t = tensor.detach().cpu().to(torch.float32) - b = t.numpy().tobytes() - file.write(b) - - -def write_bf16(tensor, file): - t = tensor.detach().cpu().to(torch.bfloat16) - # numpy doesn't have bf16 datatype so we have to trick it - t = t.view(torch.int16) # trick: reinterpret as int16 - b = t.numpy().tobytes() - file.write(b) - - -def write_tensors(model_tensors, L, file, dtype): - # writes the GPT-2 model's weights to a binary file - assert dtype in {"float32", "bfloat16"} - write_fun = write_fp32 if dtype == "float32" else write_bf16 - write_fun(model_tensors["transformer.wte.weight"], file) # (V, C) - write_fun(model_tensors["transformer.wpe.weight"], file) # (T, C) - for i in range(L): # (L, C) - write_fun(model_tensors[f"transformer.h.{i}.ln_1.weight"], file) - for i in range(L): # (L, C) - write_fun(model_tensors[f"transformer.h.{i}.ln_1.bias"], file) - for i in range(L): # (L, 3C, C) - write_fun(model_tensors[f"transformer.h.{i}.attn.c_attn.weight"], file) - for i in range(L): # (L, 3C) - write_fun(model_tensors[f"transformer.h.{i}.attn.c_attn.bias"], file) - for i in range(L): # (L, C, C) - write_fun(model_tensors[f"transformer.h.{i}.attn.c_proj.weight"], file) - for i in range(L): # (L, C) - write_fun(model_tensors[f"transformer.h.{i}.attn.c_proj.bias"], file) - for i in range(L): # (L, C) - write_fun(model_tensors[f"transformer.h.{i}.ln_2.weight"], file) - for i in range(L): # (L, C) - 
write_fun(model_tensors[f"transformer.h.{i}.ln_2.bias"], file) - for i in range(L): # (L, 4C, C) - write_fun(model_tensors[f"transformer.h.{i}.mlp.c_fc.weight"], file) - for i in range(L): # (L, 4C) - write_fun(model_tensors[f"transformer.h.{i}.mlp.c_fc.bias"], file) - for i in range(L): # (L, C, 4C) - write_fun(model_tensors[f"transformer.h.{i}.mlp.c_proj.weight"], file) - for i in range(L): # (L, C) - write_fun(model_tensors[f"transformer.h.{i}.mlp.c_proj.bias"], file) - write_fun(model_tensors["transformer.ln_f.weight"], file) # (C, ) - write_fun(model_tensors["transformer.ln_f.bias"], file) # (C, ) - - -@torch.no_grad() -def pad_vocab(tensor, multiple=128, value=0): - """ - The dimension of the vocab size in GPT-2 is 50,257 - which is unfortunately a very unfriendly number for a lot of - matrix operations on the GPU. So we pad it to the nearest - friendlier multiple, e.g. 50,304 if multiple=128 when we - export the weights into C land. This is a NOOP algorithmically - and is only done to make the tensor operations more efficient. 
- """ - assert tensor.ndim == 2 - V, C = tensor.shape - assert V == 50257, "just being defensive here" - # calculate padded vocab size by rounding up to nearest multiple - Vp = ((V + multiple - 1) // multiple) * multiple - # pad the tensor - pad_rows = Vp - V - padded = tensor if pad_rows == 0 else F.pad(tensor, (0, 0, 0, pad_rows), value=value) - assert padded.shape == (Vp, C) - return padded - - -def write_model(model, filename, dtype): - # everything we need to instantiate the model - # 1) header is: version int, GPTConfig ints, padding to 1024 bytes - assert dtype in {"float32", "bfloat16"} # float16 todo maybe later - version = { - "float32": 3, # 3: all tensors are fp32, padded vocab - "bfloat16": 5, # 5: all tensors are bf16, padded vocab - }[dtype] - header = torch.zeros(256, dtype=torch.int32) - header[0] = 20240326 # magic - header[1] = version # checkpoint version - header[2] = model.config.block_size - header[3] = model.config.vocab_size - header[4] = model.config.n_layer - header[5] = model.config.n_head - header[6] = model.config.n_embd - # 2) the parameters follow the header - params = {name: param.cpu() for name, param in model.named_parameters()} - # pad the vocab to a multiple of 128 here at export, for efficiency in C - wte = params["transformer.wte.weight"] # (V, C) - wte_padded = pad_vocab(wte) # (Vp, C) - params["transformer.wte.weight"] = wte_padded # (Vp, C) - print(f"padded vocab size from {wte.size(0)} to {wte_padded.size(0)}") - header[7] = wte_padded.size(0) # padded vocab size store in header - # now write to file - with open(filename, "wb") as file: - file.write(header.numpy().tobytes()) # header - write_tensors(params, model.config.n_layer, file, dtype) # params - print(f"wrote {filename}") - - -def write_state(model, x, y, logits, loss, filename): - # the state is used for debugging. 
- # it contains information about the input, logits, loss, and the parameter gradients - # this can be used for checking the computation correctness in C - header = torch.zeros(256, dtype=torch.int32) - header[0] = 20240327 # magic - header[1] = 2 # run state version = 2 (1 -> 2 for padded vocab changes) - header[2] = x.size(0) # batch size of the batch, B - header[3] = x.size(1) # temporal extent of the batch, T - grads = {name: param.grad.cpu() for name, param in model.named_parameters()} - # pad the vocab grads here as well, to mirror write_model - wte_grad = grads["transformer.wte.weight"] # (V, C) - wte_grad_padded = pad_vocab(wte_grad, value=0) # (Vp, C) # TODO later maybe pad with nan? - grads["transformer.wte.weight"] = wte_grad_padded # (Vp, C) - print( - f"padded vocab size in reference grads from {wte_grad.size(0)} to {wte_grad_padded.size(0)}" - ) - with open(filename, "wb") as file: - # header - file.write(header.numpy().tobytes()) - # input x - file.write(x.cpu().numpy().astype("int32").tobytes()) # (B, T) - # targets y - file.write(y.cpu().numpy().astype("int32").tobytes()) # (B, T) - # logits (result of the model forward pass) - write_fp32(logits.cpu(), file) - # loss (single float, result of the cross entropy loss) - write_fp32(loss.cpu(), file) - # gradients - write_tensors(grads, model.config.n_layer, file, "float32") - print(f"wrote {filename}") - - -def write_tokenizer(enc, filename): - n = enc.max_token_value + 1 - header = torch.zeros(256, dtype=torch.int32) - header[0] = 20240328 # magic - header[1] = 2 # tokenizer version = 2 (1 -> 2: includes EOT token) - header[2] = n # number of tokens - header[3] = enc.eot_token # EOT token - with open(filename, "wb") as file: - file.write(header.numpy().tobytes()) - for i in range(n): - b = enc.decode_bytes([i]) - length = len(b) - assert length < 256, f"Token length exceeds 255: {length}" - file.write(struct.pack(" C bridge - parser.add_argument("--write_tensors", type=int, default=1, help="write 
tensors to disk") - args = parser.parse_args() - - # args error checking and convenience variables - B, T = args.batch_size, args.sequence_length - assert 1 <= T <= 1024 - assert args.dtype in {"float32", "float16", "bfloat16"} - assert args.model in { - "gpt2", - "gpt2-medium", - "gpt2-large", - "gpt2-xl", - "d12", - "d24", - "d36", - "d48", - } - - # set up DDP (distributed data parallel). torchrun sets this env variable - ddp = int(os.environ.get("RANK", -1)) != -1 # is this a ddp run? - if ddp: - # use of DDP atm demands CUDA, we set the device appropriately according to rank - assert torch.cuda.is_available(), "for now i think we need CUDA for DDP" - init_process_group(backend="nccl") - ddp_rank = int(os.environ["RANK"]) - ddp_local_rank = int(os.environ["LOCAL_RANK"]) - ddp_world_size = int(os.environ["WORLD_SIZE"]) - device = f"cuda:{ddp_local_rank}" - torch.cuda.set_device(device) - master_process = ddp_rank == 0 # this process will do logging, checkpointing etc. - seed_offset = 0 # each process gets the exact same seed - zero_stage = args.zero_stage - else: - ddp_rank = 0 - ddp_local_rank = 0 - zero_stage = 0 - ddp_world_size = 1 - master_process = True - seed_offset = 0 - # select the device - if args.device: - # provided explicitly by the user - device = args.device - else: - # attempt to autodetect the device - device = "cpu" - if torch.cuda.is_available(): - device = "cuda" - elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): - device = "mps" - print(f"using device: {device}") - device_type = "cuda" if "cuda" in device else "cpu" - - # calculate gradient accumulation from the desired total batch size and the current run configuration - tokens_per_fwdbwd = B * T * ddp_world_size - assert args.total_batch_size % tokens_per_fwdbwd == 0 - grad_accum_steps = args.total_batch_size // tokens_per_fwdbwd - print0(f"total desired batch size: {args.total_batch_size}") - print0(f"=> calculated gradient accumulation steps: 
{grad_accum_steps}") - - # set up a context manager following the desired dtype and device - ptdtype = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16}[ - args.dtype - ] - ctx = ( - torch.amp.autocast(device_type=device_type, dtype=ptdtype) - if device_type == "cuda" - else nullcontext() - ) - - # rng / reproducibility - torch.manual_seed(42) - if torch.cuda.is_available(): - torch.cuda.manual_seed(42) - - # set the torch precision mode to use TensorFloat32 (TF32) for matmuls - # docs https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html - if args.tensorcores: - torch.set_float32_matmul_precision("high") - - # turn on/off flash attention - assert args.flash in {0, 1} - FLASH = args.flash - - # init (and write) the tokenizer - enc = tiktoken.get_encoding("gpt2") - if master_process and args.write_tensors: # tokenizer is technically not tensors but ok - write_tokenizer(enc, "gpt2_tokenizer.bin") - - # init the model, either from scratch or from OpenAI pretrained checkpoint - if args.model[0] == "d": - # from scratch (random weights) - model_config = { - "d12": GPTConfig(block_size=1024, vocab_size=50257, n_layer=12, n_head=12, n_embd=768), - "d24": GPTConfig(block_size=1024, vocab_size=50257, n_layer=24, n_head=16, n_embd=1024), - "d36": GPTConfig(block_size=1024, vocab_size=50257, n_layer=36, n_head=20, n_embd=1280), - "d48": GPTConfig(block_size=1024, vocab_size=50257, n_layer=48, n_head=25, n_embd=1600), - }[args.model] - model = GPT(model_config) - else: - # load the GPT-2 model weights - model = GPT.from_pretrained(args.model) - model.train() - model.to(device) - if args.compile: - if hasattr(config, "coordinate_descent_tuning"): - config.coordinate_descent_tuning = True # suggested by @Chillee - print0("compiling the model...") - model = torch.compile(model) - - # ------------------------------------------------------------------------- - # Our own version of a simple DistributedDataLoader - - # load 
tokens - train_loader = DistributedDataLoader(args.input_bin, B, T, ddp_rank, ddp_world_size) - val_loader = None - if args.input_val_bin: - val_loader = DistributedDataLoader(args.input_val_bin, B, T, ddp_rank, ddp_world_size) - - # ------------------------------------------------------------------------- - # PyTorch -> C bridge: save some weights and state for C to load later as reference - - # do one forward pass to generate ground truth for our C tests - if master_process and args.write_tensors and (not args.inference_only): - x, y = train_loader.next_batch() - x, y = x.to(device), y.to(device) - logits, loss = model(x, y) - loss.backward() - # save model params, in both float32 and bfloat16 - model_to_size = { - "gpt2": "124M", - "gpt2-medium": "355M", - "gpt2-large": "774M", - "gpt2-xl": "1558M", - } - model_to_size.update({f"d{d}": f"d{d}" for d in [12, 24, 36, 48]}) - model_size_str = model_to_size[args.model] # e.g. "124M", or "d12" - write_model(model, f"gpt2_{model_size_str}.bin", dtype="float32") - write_model(model, f"gpt2_{model_size_str}_bf16.bin", dtype="bfloat16") - # save x, y, logits, loss, and parameter gradients, for debugging C - # always store these in fp32 to have an accurate reference (?) 
- write_state(model, x, y, logits, loss, f"gpt2_{model_size_str}_debug_state.bin") - # reset the train_loader for the optimization below - train_loader.reset() - - # ------------------------------------------------------------------------- - # main training loop - - # here we wrap model into DDP container - if ddp: - model = DDP(model, device_ids=[ddp_local_rank]) - raw_model = model.module if ddp else model # always contains the "raw" unwrapped model - - # init the optimizer - optimizer = raw_model.configure_optimizers( - weight_decay=args.weight_decay, - learning_rate=args.learning_rate, - betas=(0.9, 0.95), - device_type=device, - zero_stage=zero_stage, - ) - - # learning rate decay scheduler (cosine with warmup) - def get_lr(it): - min_lr = args.learning_rate * args.learning_rate_decay_frac - # 1) linear warmup for warmup_iters steps - if it < args.warmup_iters: - return args.learning_rate * (it + 1) / args.warmup_iters - # 2) if it > lr_decay_iters, return min learning rate - if it > args.num_iterations: - return min_lr - # 3) in between, use cosine decay down to min learning rate - decay_ratio = (it - args.warmup_iters) / (args.num_iterations - args.warmup_iters) - assert 0 <= decay_ratio <= 1 - coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff starts at 1 and goes to 0 - return min_lr + coeff * (args.learning_rate - min_lr) - - # create the logging directory if it does not exist - logfile = None - checkpoints_dir = None - output_dir = None - if args.output_dir: - output_dir = Path(args.output_dir) - output_dir.mkdir(parents=True, exist_ok=True) - logfile = output_dir / "main.log" - # create the log file "main.log" inside it, and wipe it clean - with open(logfile, "w") as f: - pass - - # set our checkpoints directory and save off the initilized model - checkpoints_dir = output_dir / "checkpoints" - checkpoints_dir.mkdir(parents=True, exist_ok=True) - save_model_and_config(checkpoints_dir, raw_model, step=0) - - if device == "cuda": - 
torch.cuda.reset_peak_memory_stats() - timings = [] - norm = -1.0 # dummy value to print in inference-only mode - for step in range(1, args.num_iterations + 1): - t0 = time.time() - last_step = step == args.num_iterations - - # once in a while evaluate the validation dataset - if (args.val_loss_every > 0 and (step % args.val_loss_every == 0 or last_step)) and ( - val_loader is not None - ): - model.eval() - val_loader.reset() - with torch.no_grad(): - val_loss = 0.0 - for _ in range(args.val_max_steps): - x, y = val_loader.next_batch() - x, y = x.to(device), y.to(device) - _, loss = model(x, y, return_logits=False) - val_loss += loss.item() - val_loss /= args.val_max_steps - # log to console and to file - print0(f"val loss {val_loss}") - if master_process and logfile is not None: - with open(logfile, "a") as f: - f.write("s:%d tel:%f\n" % (step, val_loss)) - - # once in a while perform model inference on the master process - if ( - args.sample_every > 0 and (step % args.sample_every == 0 or last_step) - ) and master_process: - model.eval() - # before we end, let's also do one round of inference - # we'll kick off the generation with "<|endoftext|>", which designates the start of a new sequence - start_ids = [enc.eot_token] - xg = torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...] - max_new_tokens = 32 - temperature = 1.0 - top_k = 40 - yg = raw_model.generate(xg, max_new_tokens, temperature=temperature, top_k=top_k) - print0("---------------") - print0(enc.decode(yg[0].tolist())) - print0("---------------") - - # bit confusing: we want to make sure to eval and sample on 0th iteration - # but also after the very last iteration. so we loop for step <= num_iterations - # instead of just < num_iterations (one extra due to <=), only to do - # the validation/sampling one last time, and then we break right here as we're done. 
- if last_step: - break - - # --------------- TRAINING SECTION BEGIN ----------------- - model.train() - optimizer.zero_grad(set_to_none=True) - # if we are trying to overfit a single batch, we reset the loader here - if args.overfit_single_batch: - train_loader.reset() - # micro-batch loop where we do gradient accumulation to reach desired total batch size - lossf = 0.0 # for getting the mean loss (as simple float) over the accumulation steps - for micro_step in range(grad_accum_steps): - # fetch a batch - x, y = train_loader.next_batch() - x, y = x.to(device), y.to(device) - if ddp: - # we want only the last micro-step to sync grads in a DDP model - # the official way to do this is with model.no_sync(), but that is a - # context manager that bloats the code, so we just toggle this variable - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - # forward pass - with ctx: - _, loss = model(x, y, return_logits=False) - # we have to scale the loss to account for gradient accumulation, - # because the gradients just add on each successive backward(). - # addition of gradients corresponds to a SUM in the objective, but - # instead of a SUM we want MEAN, so we scale the loss here - loss = loss / grad_accum_steps - lossf += loss.detach() # keep track of the mean loss - # backward pass - if not args.inference_only: - loss.backward() - if ddp: - dist.all_reduce(lossf, op=dist.ReduceOp.AVG) - lossf = lossf.item() - norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip) - # determine and set the learning rate for this iteration - lr = get_lr(step) - for param_group in optimizer.param_groups: - param_group["lr"] = lr - # step the optimizer - optimizer.step() - # --------------- TRAINING SECTION END ------------------- - # everything that follows now is just diagnostics, prints, logging, etc. 
- - # wait on the CPU for all device work to end so we get accurate per-iteration timings below - if device == "mps": - torch.mps.synchronize() - elif device == "cuda": - torch.cuda.synchronize() - # time and print - t1 = time.time() - # the 0th iteration is often an outlier (much slower) => skip logging it - tokens_per_second = grad_accum_steps * ddp_world_size * B * T / (t1 - t0) - print0( - f"step {step:4d}/{args.num_iterations} | train loss {lossf:.6f} | norm {norm:.4f} | lr {lr:.2e} | ({(t1 - t0) * 1000:.2f} ms | {tokens_per_second:.0f} tok/s)" - ) - # log to logile - if master_process and logfile is not None: - with open(logfile, "a") as f: - f.write("s:%d trl:%f\n" % (step, lossf)) - - if checkpoints_dir is not None and is_checkpoint_step(step): - save_model_and_config(checkpoints_dir, raw_model, step=step) - - # keep track of smooth timings, last 20 iterations - if step > 1 and step > args.num_iterations - 20: - timings.append(t1 - t0) - - # print the average of the last 20 timings, to get something smooth-ish - timings = timings[-20:] - print0(f"final {len(timings)} iters avg: {np.mean(timings) * 1000:.3f}ms") - print0(f"peak memory consumption: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB") - - # ------------------------------------------------------------------------- - # clean up nice - if ddp: - destroy_process_group() From 4042ab443d617b7dd233518704040e17f750d9e2 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Wed, 13 Aug 2025 16:27:06 +0000 Subject: [PATCH 06/13] Prefix wandb run name with model_id --- simple_stories_train/train.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/simple_stories_train/train.py b/simple_stories_train/train.py index 6cbc279..47359b0 100644 --- a/simple_stories_train/train.py +++ b/simple_stories_train/train.py @@ -255,7 +255,8 @@ def main(config_path_or_obj: Path | str | Config | None = None, **kwargs: Any) - # logging if config.wandb_project is not None and master_process: - 
wandb.init(project=config.wandb_project, config=config.model_dump(mode="json")) + run = wandb.init(project=config.wandb_project, config=config.model_dump(mode="json")) + run.name = f"{config.model_id}-{run.name}" # DDP wrap if ddp: @@ -435,6 +436,9 @@ def get_lr(it: int) -> float: if ddp: destroy_process_group() + if config.wandb_project is not None and master_process: + wandb.finish() + if __name__ == "__main__": fire.Fire(main) From b1c6b50c0be94a7aeaaa0af28dada747126f0577 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Wed, 13 Aug 2025 17:31:30 +0000 Subject: [PATCH 07/13] Create gpt2 hf converters --- simple_stories_train/convert_to_hf.py | 110 --------- simple_stories_train/models/gpt2.py | 140 +++++++++++ simple_stories_train/models/llama.py | 217 ++++++++++++------ tests/test_gpt2_hf_converters.py | 73 ++++++ ...ibility.py => test_llama_hf_converters.py} | 2 +- tests/test_llama_implementation.py | 6 +- 6 files changed, 366 insertions(+), 182 deletions(-) delete mode 100644 simple_stories_train/convert_to_hf.py create mode 100644 tests/test_gpt2_hf_converters.py rename tests/{test_hf_compatibility.py => test_llama_hf_converters.py} (98%) diff --git a/simple_stories_train/convert_to_hf.py b/simple_stories_train/convert_to_hf.py deleted file mode 100644 index bfae729..0000000 --- a/simple_stories_train/convert_to_hf.py +++ /dev/null @@ -1,110 +0,0 @@ -""" -This script demonstrates how to convert our custom model to a HuggingFace-compatible model. -""" - -from transformers import LlamaConfig as HFLlamaConfig -from transformers import LlamaForCausalLM - -from simple_stories_train.models.llama import Llama, LlamaConfig -from simple_stories_train.models.model_configs import MODEL_CONFIGS - -# pyright: reportAttributeAccessIssue=false -# pyright: reportIndexIssue=false - - -def convert_llama_to_llama_for_causal_lm(custom_model: Llama) -> LlamaForCausalLM: - """Convert Llama model to HuggingFace format. 
- - Args: - custom_model: The custom Llama model to convert - - Returns: - The converted HuggingFace model - """ - model_config = custom_model.config - - # Create a matching HuggingFace configuration - hf_config = HFLlamaConfig( - vocab_size=model_config.vocab_size, - hidden_size=model_config.n_embd, - intermediate_size=model_config.n_intermediate, - num_hidden_layers=model_config.n_layer, - num_attention_heads=model_config.n_head, - num_key_value_heads=model_config.n_key_value_heads, - hidden_act="silu", - max_position_embeddings=2048, - rms_norm_eps=model_config.rms_norm_eps, - tie_word_embeddings=True, - ) - - hf_model = LlamaForCausalLM(hf_config) - - hf_model.model.embed_tokens.weight.data = custom_model.transformer.wte.weight.data - - for i in range(model_config.n_layer): - # RMSNorm 1 - hf_model.model.layers[i].input_layernorm.weight.data = custom_model.transformer.h[ - i - ].rms_1.weight.data - - # Attention weights - # Query projection - hf_model.model.layers[i].self_attn.q_proj.weight.data = custom_model.transformer.h[ - i - ].attn.q_attn.weight.data - - # Key and Value are combined in your model but separate in HF model - kv_weight = custom_model.transformer.h[i].attn.kv_attn.weight.data - kv_dim = kv_weight.shape[0] // 2 - - # Split KV weights for HF model - hf_model.model.layers[i].self_attn.k_proj.weight.data = kv_weight[:kv_dim, :] - hf_model.model.layers[i].self_attn.v_proj.weight.data = kv_weight[kv_dim:, :] - - # Output projection - hf_model.model.layers[i].self_attn.o_proj.weight.data = custom_model.transformer.h[ - i - ].attn.c_proj.weight.data - - # RMSNorm 2 - hf_model.model.layers[i].post_attention_layernorm.weight.data = custom_model.transformer.h[ - i - ].rms_2.weight.data - - # MLP layers - hf_model.model.layers[i].mlp.gate_proj.weight.data = custom_model.transformer.h[ - i - ].mlp.gate_proj.weight.data - hf_model.model.layers[i].mlp.up_proj.weight.data = custom_model.transformer.h[ - i - ].mlp.up_proj.weight.data - 
hf_model.model.layers[i].mlp.down_proj.weight.data = custom_model.transformer.h[ - i - ].mlp.down_proj.weight.data - - # 3. Final layer norm - hf_model.model.norm.weight.data = custom_model.transformer.rms_f.weight.data - - # 4. LM head - hf_model.lm_head.weight.data = custom_model.lm_head.weight.data - - # Set model to eval mode - hf_model.eval() - - return hf_model - - -if __name__ == "__main__": - # Example usage: Load a custom model and convert it - model_id = "llama-1.25M" # Change this to convert different model sizes - model_size = model_id.split("-")[1] - model_config = MODEL_CONFIGS[model_id] - assert isinstance(model_config, LlamaConfig) - custom_model = Llama.from_pretrained(f"SimpleStories/SimpleStories-{model_size}", model_config) - custom_model.eval() - - # Convert the model - hf_model = convert_llama_to_llama_for_causal_lm(custom_model) - - # Uncomment to save the converted model - # hf_model.save_pretrained(f"converted_hf_model_{model_size}") diff --git a/simple_stories_train/models/gpt2.py b/simple_stories_train/models/gpt2.py index e4b5e3f..d042902 100644 --- a/simple_stories_train/models/gpt2.py +++ b/simple_stories_train/models/gpt2.py @@ -2,6 +2,7 @@ import math from typing import Any from typing import cast as _cast +from typing import cast as t_cast import torch import torch.nn as nn @@ -10,9 +11,13 @@ from torch import Tensor from torch.distributed.optim import ZeroRedundancyOptimizer from torch.nn import functional as F +from transformers import GPT2Config as HFGPT2Config +from transformers import GPT2LMHeadModel from simple_stories_train.utils import print0 +# pyright: reportAttributeAccessIssue=false, reportIndexIssue=false + class GPT2Config(BaseModel): model_config = ConfigDict(extra="forbid", frozen=True) @@ -358,3 +363,138 @@ def generate( idx = idx.squeeze(0) return idx + + +def convert_hf_gpt2_to_gpt2(hf_model: GPT2LMHeadModel) -> "GPT2": + """Convert a HuggingFace GPT2LMHeadModel to our custom GPT2. 
+ + Args: + hf_model: HuggingFace GPT2LMHeadModel instance + + Returns: + Our custom GPT2 model with weights copied from the HF model + """ + hf_config = hf_model.config + config = GPT2Config( + block_size=hf_config.n_ctx, + vocab_size=hf_config.vocab_size, + n_layer=hf_config.n_layer, + n_head=hf_config.n_head, + n_embd=hf_config.n_embd, + flash_attention=True, + ) + model = GPT2(config) + + # Embeddings + with torch.no_grad(): + model.wte.weight.copy_(hf_model.transformer.wte.weight) + model.wpe.weight.copy_(hf_model.transformer.wpe.weight) + + # Blocks + for i in range(config.n_layer): + custom_block = model.h[i] + hf_block = hf_model.transformer.h[i] + + # Layer norms + with torch.no_grad(): + custom_block.ln_1.weight.copy_(t_cast(Tensor, hf_block.ln_1.weight)) + custom_block.ln_1.bias.copy_(t_cast(Tensor, hf_block.ln_1.bias)) + custom_block.ln_2.weight.copy_(t_cast(Tensor, hf_block.ln_2.weight)) + custom_block.ln_2.bias.copy_(t_cast(Tensor, hf_block.ln_2.bias)) + + # Attention (transpose HF Conv1D weights to Linear) + with torch.no_grad(): + custom_block.attn.c_attn.weight.copy_(t_cast(Tensor, hf_block.attn.c_attn.weight).T) + custom_block.attn.c_attn.bias.copy_(t_cast(Tensor, hf_block.attn.c_attn.bias)) + + with torch.no_grad(): + custom_block.attn.c_proj.weight.copy_(t_cast(Tensor, hf_block.attn.c_proj.weight).T) + custom_block.attn.c_proj.bias.copy_(t_cast(Tensor, hf_block.attn.c_proj.bias)) + + # MLP (transpose HF Conv1D weights to Linear) + with torch.no_grad(): + custom_block.mlp.c_fc.weight.copy_(t_cast(Tensor, hf_block.mlp.c_fc.weight).T) + custom_block.mlp.c_fc.bias.copy_(t_cast(Tensor, hf_block.mlp.c_fc.bias)) + + with torch.no_grad(): + custom_block.mlp.c_proj.weight.copy_(t_cast(Tensor, hf_block.mlp.c_proj.weight).T) + custom_block.mlp.c_proj.bias.copy_(t_cast(Tensor, hf_block.mlp.c_proj.bias)) + + # Final ln_f + with torch.no_grad(): + model.ln_f.weight.copy_(hf_model.transformer.ln_f.weight) + model.ln_f.bias.copy_(hf_model.transformer.ln_f.bias) 
+ + # LM head + with torch.no_grad(): + model.lm_head.weight.copy_(hf_model.lm_head.weight) + + return model + + +def convert_gpt2_to_hf_gpt2(custom_model: GPT2) -> GPT2LMHeadModel: + """Convert custom GPT-2 model to HuggingFace GPT2LMHeadModel. + + Args: + custom_model: The custom GPT-2 model to convert + + Returns: + The converted HuggingFace GPT2LMHeadModel + """ + model_config: GPT2Config = custom_model.config + + hf_config = HFGPT2Config( + vocab_size=model_config.vocab_size, + n_positions=model_config.block_size, + n_ctx=model_config.block_size, + n_layer=model_config.n_layer, + n_head=model_config.n_head, + n_embd=model_config.n_embd, + activation_function="gelu_new", + n_inner=None, + layer_norm_epsilon=1e-5, + # Tie embeddings and lm_head as our implementation does + tie_word_embeddings=True, + ) + + hf_model = GPT2LMHeadModel(hf_config) + + # Embeddings + hf_model.transformer.wte.weight.data = custom_model.wte.weight.data + hf_model.transformer.wpe.weight.data = custom_model.wpe.weight.data + + # Transformer blocks + for i in range(model_config.n_layer): + custom_block = custom_model.h[i] + hf_block = hf_model.transformer.h[i] + + # LayerNorms + hf_block.ln_1.weight.data = custom_block.ln_1.weight.data + hf_block.ln_1.bias.data = custom_block.ln_1.bias.data + hf_block.ln_2.weight.data = custom_block.ln_2.weight.data + hf_block.ln_2.bias.data = custom_block.ln_2.bias.data + + # Attention projections: HF uses Conv1D (weight shape [in, out]) + # Our Linear weights are [out, in], so transpose when copying to HF + hf_block.attn.c_attn.weight.data = custom_block.attn.c_attn.weight.data.t().contiguous() + hf_block.attn.c_attn.bias.data = custom_block.attn.c_attn.bias.data + + hf_block.attn.c_proj.weight.data = custom_block.attn.c_proj.weight.data.t().contiguous() + hf_block.attn.c_proj.bias.data = custom_block.attn.c_proj.bias.data + + # MLP projections + hf_block.mlp.c_fc.weight.data = custom_block.mlp.c_fc.weight.data.t().contiguous() + 
hf_block.mlp.c_fc.bias.data = custom_block.mlp.c_fc.bias.data + + hf_block.mlp.c_proj.weight.data = custom_block.mlp.c_proj.weight.data.t().contiguous() + hf_block.mlp.c_proj.bias.data = custom_block.mlp.c_proj.bias.data + + # Final LayerNorm + hf_model.transformer.ln_f.weight.data = custom_model.ln_f.weight.data + hf_model.transformer.ln_f.bias.data = custom_model.ln_f.bias.data + + # LM head + hf_model.lm_head.weight.data = custom_model.lm_head.weight.data + + hf_model.eval() + return hf_model diff --git a/simple_stories_train/models/llama.py b/simple_stories_train/models/llama.py index 6d686e5..0804537 100644 --- a/simple_stories_train/models/llama.py +++ b/simple_stories_train/models/llama.py @@ -10,6 +10,7 @@ from torch import Tensor from torch.distributed.optim import ZeroRedundancyOptimizer from torch.nn import functional as F +from transformers import LlamaConfig as HFLlamaConfig from transformers import LlamaForCausalLM from simple_stories_train.utils import print0 @@ -17,73 +18,6 @@ # pyright: reportAttributeAccessIssue=false, reportIndexIssue=false -def convert_llama_for_causal_lm_to_llama(hf_model: LlamaForCausalLM): - # Create a matching custom Llama configuration - hf_config = hf_model.config - - model_config = LlamaConfig( - vocab_size=hf_config.vocab_size, - n_layer=hf_config.num_hidden_layers, - n_head=hf_config.num_attention_heads, - n_embd=hf_config.hidden_size, - n_intermediate=hf_config.intermediate_size, - rotary_dim=hf_config.hidden_size // hf_config.num_attention_heads, # Assuming head_dim - n_key_value_heads=hf_config.num_key_value_heads, - ) - - model = Llama(model_config) - - # Convert embeddings - model.transformer.wte.weight.data = hf_model.model.embed_tokens.weight.data - - for i in range(hf_config.num_hidden_layers): - # RMSNorm 1 - model.transformer.h[i].rms_1.weight.data = hf_model.model.layers[ - i - ].input_layernorm.weight.data - - # Attention weights - model.transformer.h[i].attn.q_attn.weight.data = hf_model.model.layers[ - i - 
].self_attn.q_proj.weight.data - - # Key and Value projections - combine separate HF weights into single KV weight - k_weight = cast(Tensor, hf_model.model.layers[i].self_attn.k_proj.weight.data) - v_weight = cast(Tensor, hf_model.model.layers[i].self_attn.v_proj.weight.data) - kv_combined = torch.cat([k_weight, v_weight], dim=0) - - model.transformer.h[i].attn.kv_attn.weight.data = kv_combined - - # Output projection - model.transformer.h[i].attn.c_proj.weight.data = hf_model.model.layers[ - i - ].self_attn.o_proj.weight.data - - # RMSNorm 2 - model.transformer.h[i].rms_2.weight.data = hf_model.model.layers[ - i - ].post_attention_layernorm.weight.data - - # MLP layers - model.transformer.h[i].mlp.gate_proj.weight.data = hf_model.model.layers[ - i - ].mlp.gate_proj.weight.data - model.transformer.h[i].mlp.up_proj.weight.data = hf_model.model.layers[ - i - ].mlp.up_proj.weight.data - model.transformer.h[i].mlp.down_proj.weight.data = hf_model.model.layers[ - i - ].mlp.down_proj.weight.data - - # Final layer norm - model.transformer.rms_f.weight.data = hf_model.model.norm.weight.data - - # LM head - model.lm_head.weight.data = hf_model.lm_head.weight.data - - return model - - class LlamaConfig(BaseModel): model_config = ConfigDict(extra="forbid", frozen=True) block_size: int = 1024 @@ -592,3 +526,152 @@ def generate( idx = idx.squeeze(0) return idx + + +def convert_llama_for_causal_lm_to_llama(hf_model: LlamaForCausalLM) -> Llama: + # Create a matching custom Llama configuration + hf_config = hf_model.config + + model_config = LlamaConfig( + vocab_size=hf_config.vocab_size, + n_layer=hf_config.num_hidden_layers, + n_head=hf_config.num_attention_heads, + n_embd=hf_config.hidden_size, + n_intermediate=hf_config.intermediate_size, + rotary_dim=hf_config.hidden_size // hf_config.num_attention_heads, # Assuming head_dim + n_key_value_heads=hf_config.num_key_value_heads, + ) + + model = Llama(model_config) + + # Convert embeddings + model.transformer.wte.weight.data = 
hf_model.model.embed_tokens.weight.data + + for i in range(hf_config.num_hidden_layers): + # RMSNorm 1 + model.transformer.h[i].rms_1.weight.data = hf_model.model.layers[ + i + ].input_layernorm.weight.data + + # Attention weights + model.transformer.h[i].attn.q_attn.weight.data = hf_model.model.layers[ + i + ].self_attn.q_proj.weight.data + + # Key and Value projections - combine separate HF weights into single KV weight + k_weight = cast(Tensor, hf_model.model.layers[i].self_attn.k_proj.weight.data) + v_weight = cast(Tensor, hf_model.model.layers[i].self_attn.v_proj.weight.data) + kv_combined = torch.cat([k_weight, v_weight], dim=0) + + model.transformer.h[i].attn.kv_attn.weight.data = kv_combined + + # Output projection + model.transformer.h[i].attn.c_proj.weight.data = hf_model.model.layers[ + i + ].self_attn.o_proj.weight.data + + # RMSNorm 2 + model.transformer.h[i].rms_2.weight.data = hf_model.model.layers[ + i + ].post_attention_layernorm.weight.data + + # MLP layers + model.transformer.h[i].mlp.gate_proj.weight.data = hf_model.model.layers[ + i + ].mlp.gate_proj.weight.data + model.transformer.h[i].mlp.up_proj.weight.data = hf_model.model.layers[ + i + ].mlp.up_proj.weight.data + model.transformer.h[i].mlp.down_proj.weight.data = hf_model.model.layers[ + i + ].mlp.down_proj.weight.data + + # Final layer norm + model.transformer.rms_f.weight.data = hf_model.model.norm.weight.data + + # LM head + model.lm_head.weight.data = hf_model.lm_head.weight.data + + return model + + +def convert_llama_to_llama_for_causal_lm(custom_model: Llama) -> LlamaForCausalLM: + """Convert Llama model to HuggingFace format. 
+ + Args: + custom_model: The custom Llama model to convert + + Returns: + The converted HuggingFace model + """ + model_config = custom_model.config + + # Create a matching HuggingFace configuration + hf_config = HFLlamaConfig( + vocab_size=model_config.vocab_size, + hidden_size=model_config.n_embd, + intermediate_size=model_config.n_intermediate, + num_hidden_layers=model_config.n_layer, + num_attention_heads=model_config.n_head, + num_key_value_heads=model_config.n_key_value_heads, + hidden_act="silu", + max_position_embeddings=2048, + rms_norm_eps=model_config.rms_norm_eps, + tie_word_embeddings=True, + ) + + hf_model = LlamaForCausalLM(hf_config) + + hf_model.model.embed_tokens.weight.data = custom_model.transformer.wte.weight.data + + for i in range(model_config.n_layer): + # RMSNorm 1 + hf_model.model.layers[i].input_layernorm.weight.data = custom_model.transformer.h[ + i + ].rms_1.weight.data + + # Attention weights + # Query projection + hf_model.model.layers[i].self_attn.q_proj.weight.data = custom_model.transformer.h[ + i + ].attn.q_attn.weight.data + + # Key and Value are combined in your model but separate in HF model + kv_weight = custom_model.transformer.h[i].attn.kv_attn.weight.data + kv_dim = kv_weight.shape[0] // 2 + + # Split KV weights for HF model + hf_model.model.layers[i].self_attn.k_proj.weight.data = kv_weight[:kv_dim, :] + hf_model.model.layers[i].self_attn.v_proj.weight.data = kv_weight[kv_dim:, :] + + # Output projection + hf_model.model.layers[i].self_attn.o_proj.weight.data = custom_model.transformer.h[ + i + ].attn.c_proj.weight.data + + # RMSNorm 2 + hf_model.model.layers[i].post_attention_layernorm.weight.data = custom_model.transformer.h[ + i + ].rms_2.weight.data + + # MLP layers + hf_model.model.layers[i].mlp.gate_proj.weight.data = custom_model.transformer.h[ + i + ].mlp.gate_proj.weight.data + hf_model.model.layers[i].mlp.up_proj.weight.data = custom_model.transformer.h[ + i + ].mlp.up_proj.weight.data + 
hf_model.model.layers[i].mlp.down_proj.weight.data = custom_model.transformer.h[ + i + ].mlp.down_proj.weight.data + + # 3. Final layer norm + hf_model.model.norm.weight.data = custom_model.transformer.rms_f.weight.data + + # 4. LM head + hf_model.lm_head.weight.data = custom_model.lm_head.weight.data + + # Set model to eval mode + hf_model.eval() + + return hf_model diff --git a/tests/test_gpt2_hf_converters.py b/tests/test_gpt2_hf_converters.py new file mode 100644 index 0000000..9d112f7 --- /dev/null +++ b/tests/test_gpt2_hf_converters.py @@ -0,0 +1,73 @@ +""" +Tests for GPT-2 conversion between custom implementation and HuggingFace. +""" + +import torch +from jaxtyping import Int +from torch import Tensor +from transformers import GPT2Config as HFGPT2Config +from transformers import GPT2LMHeadModel + +from simple_stories_train.models.gpt2 import ( + GPT2, + GPT2Config, + convert_gpt2_to_hf_gpt2, + convert_hf_gpt2_to_gpt2, +) + + +@torch.inference_mode() +def test_convert_gpt2_to_hf_gpt2() -> None: + """Validate custom -> HF conversion produces identical logits.""" + # Small config for speed + config = GPT2Config(block_size=64, vocab_size=50257, n_layer=2, n_head=2, n_embd=128) + custom_model = GPT2(config) + custom_model.eval() + + hf_model = convert_gpt2_to_hf_gpt2(custom_model) + hf_model.eval() + + # Random input ids within vocab range + batch_size = 2 + seq_len = 16 + inputs: Int[Tensor, "batch pos"] = torch.randint(0, config.vocab_size, (batch_size, seq_len)) + + hf_logits = hf_model(input_ids=inputs).logits # type: ignore[arg-type] + custom_logits, _ = custom_model(inputs) + + assert custom_logits is not None + torch.testing.assert_close(hf_logits, custom_logits, rtol=1e-5, atol=1e-5) + + +@torch.inference_mode() +def test_convert_hf_gpt2_to_gpt2() -> None: + """Validate HF -> custom conversion produces identical logits.""" + # Construct a tiny HF GPT-2 config/model to avoid network downloads + + hf_config = HFGPT2Config( + vocab_size=50257, + 
n_positions=64, + n_ctx=64, + n_layer=2, + n_head=2, + n_embd=128, + activation_function="gelu_new", + tie_word_embeddings=True, + ) + hf_model = GPT2LMHeadModel(hf_config) + hf_model.eval() + + custom_model = convert_hf_gpt2_to_gpt2(hf_model) + custom_model.eval() + + batch_size = 2 + seq_len = 16 + inputs: Int[Tensor, "batch pos"] = torch.randint( + 0, hf_model.config.vocab_size, (batch_size, seq_len) + ) + + hf_logits = hf_model(input_ids=inputs).logits # type: ignore[arg-type] + custom_logits, _ = custom_model(inputs) + + assert custom_logits is not None + torch.testing.assert_close(hf_logits, custom_logits, rtol=1e-5, atol=1e-5) diff --git a/tests/test_hf_compatibility.py b/tests/test_llama_hf_converters.py similarity index 98% rename from tests/test_hf_compatibility.py rename to tests/test_llama_hf_converters.py index 8c73b9e..7e19a85 100644 --- a/tests/test_hf_compatibility.py +++ b/tests/test_llama_hf_converters.py @@ -7,11 +7,11 @@ from tokenizers import Tokenizer from transformers import AutoTokenizer, LlamaForCausalLM -from simple_stories_train.convert_to_hf import convert_llama_to_llama_for_causal_lm from simple_stories_train.models.llama import ( Llama, LlamaConfig, convert_llama_for_causal_lm_to_llama, + convert_llama_to_llama_for_causal_lm, ) from simple_stories_train.models.model_configs import MODEL_CONFIGS diff --git a/tests/test_llama_implementation.py b/tests/test_llama_implementation.py index 1782cd7..b436f75 100644 --- a/tests/test_llama_implementation.py +++ b/tests/test_llama_implementation.py @@ -1,9 +1,7 @@ import tempfile from pathlib import Path -from typing import cast import torch -from torch import Tensor from transformers import LlamaConfig as HFLlamaConfig from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb @@ -69,8 +67,8 @@ def test_rotary_embedding_implementation() -> None: q_custom = q_hf.detach().clone() k_custom = k_hf.detach().clone() - custom_cos = cast(Tensor, 
custom_implementation.rotary_cos)[position_ids].to(q_custom.dtype) - custom_sin = cast(Tensor, custom_implementation.rotary_sin)[position_ids].to(q_custom.dtype) + custom_cos = custom_implementation.rotary_cos[position_ids].to(q_custom.dtype) + custom_sin = custom_implementation.rotary_sin[position_ids].to(q_custom.dtype) q_custom_rot, k_custom_rot = custom_implementation.apply_rotary_pos_emb( q_custom, k_custom, custom_cos, custom_sin From ed62db81de47f1b511b940b7232a01748ea2c377 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Wed, 13 Aug 2025 17:54:09 +0000 Subject: [PATCH 08/13] Create push_to_hf --- scripts/push_to_hf.py | 225 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 225 insertions(+) create mode 100644 scripts/push_to_hf.py diff --git a/scripts/push_to_hf.py b/scripts/push_to_hf.py new file mode 100644 index 0000000..79b1193 --- /dev/null +++ b/scripts/push_to_hf.py @@ -0,0 +1,225 @@ +from __future__ import annotations + +import argparse +import os +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import torch +import yaml +from huggingface_hub import HfApi +from transformers import PreTrainedModel + +from simple_stories_train.models.gpt2 import ( + GPT2, + GPT2Config, + convert_gpt2_to_hf_gpt2, +) +from simple_stories_train.models.llama import ( + Llama, + LlamaConfig, + convert_llama_to_llama_for_causal_lm, +) +from simple_stories_train.models.model_configs import MODEL_CONFIGS + + +@dataclass +class PushArgs: + checkpoint_path: Path + repo_id: str + token: str | None + private: bool + revision: str | None + commit_message: str | None + model_card_readme: Path | None + + +def parse_args() -> PushArgs: + parser = argparse.ArgumentParser( + description=( + "Load a local custom Llama checkpoint, convert to Hugging Face format, and push to the Hub." 
+ ) + ) + parser.add_argument( + "--checkpoint-path", + type=str, + required=True, + help="Path to the local .pt checkpoint saved via torch.save(model.state_dict(), ...)", + ) + parser.add_argument( + "--repo-id", + type=str, + required=True, + help="Destination repository in the form 'username/repo_name' or 'org/repo_name'", + ) + parser.add_argument( + "--token", + type=str, + default=None, + help="Hugging Face token. If omitted, will use HF_TOKEN env var if present.", + ) + parser.add_argument( + "--private", + action="store_true", + help="Create the Hub repo as private (default: public).", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + help="Optional branch name on the Hub (e.g., 'main').", + ) + parser.add_argument( + "--commit-message", + type=str, + default=None, + help="Commit message to use when pushing to the Hub.", + ) + parser.add_argument( + "--model-card-readme", + type=str, + default=None, + help="Optional path to a README.md to upload as the model card.", + ) + + ns = parser.parse_args() + token = ns.token or os.environ.get("HF_TOKEN") + + return PushArgs( + checkpoint_path=Path(ns.checkpoint_path).expanduser().resolve(), + repo_id=ns.repo_id, + token=token, + private=bool(ns.private), + revision=ns.revision, + commit_message=ns.commit_message, + model_card_readme=( + Path(ns.model_card_readme).expanduser().resolve() if ns.model_card_readme else None + ), + ) + + +def load_config_from_checkpoint_dir(checkpoint_path: Path) -> tuple[str, LlamaConfig | GPT2Config]: + """Load model config by reading model_id from final_config.yaml adjacent to the checkpoint. + + Returns (model_id, config) where config is one of LlamaConfig or GPT2Config. 
+ """ + final_cfg_path = checkpoint_path.parent / "final_config.yaml" + if not final_cfg_path.exists(): + raise FileNotFoundError( + f"Could not find 'final_config.yaml' next to checkpoint at {final_cfg_path}" + ) + + with final_cfg_path.open("r") as f: + data: dict[str, Any] = yaml.safe_load(f) + + model_id = data.get("model_id") + if not isinstance(model_id, str): + raise ValueError("'model_id' missing or invalid in final_config.yaml") + + preset = MODEL_CONFIGS.get(model_id) + if preset is None: + raise ValueError( + f"Unknown model_id '{model_id}' in final_config.yaml. Available: {tuple(MODEL_CONFIGS.keys())}" + ) + + # Optionally override context length from training config if present + train_ds_cfg = data.get("train_dataset_config", {}) or {} + n_ctx_override = train_ds_cfg.get("n_ctx") + if isinstance(n_ctx_override, int) and n_ctx_override > 0: + if isinstance(preset, LlamaConfig): + return model_id, preset.model_copy( + update={"n_ctx": n_ctx_override, "block_size": n_ctx_override} + ) + if isinstance(preset, GPT2Config): + return model_id, preset.model_copy(update={"block_size": n_ctx_override}) + return model_id, preset + + +def load_custom_model( + checkpoint_path: Path, model_id: str, config: LlamaConfig | GPT2Config +) -> Llama | GPT2: + # Llama requires special loader to rebuild rotary buffers + if isinstance(config, LlamaConfig): + model = Llama.from_pretrained(str(checkpoint_path), config=config, strict=True) + else: + # GPT-2: regular state_dict load + state_dict = torch.load(str(checkpoint_path), weights_only=True, map_location="cpu") + # Strip DDP prefixes if present + state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()} + model = GPT2(config) + model.load_state_dict(state_dict, strict=True) + model.eval() + return model + + +def convert_to_hf_model(custom_model: Llama | GPT2) -> PreTrainedModel: + if isinstance(custom_model, Llama): + hf_model = convert_llama_to_llama_for_causal_lm(custom_model) + else: + hf_model = 
convert_gpt2_to_hf_gpt2(custom_model) + hf_model.eval() + return hf_model + + +def push_model_to_hub( + hf_model: PreTrainedModel, + repo_id: str, + token: str | None, + private: bool, + revision: str | None, + commit_message: str | None, +) -> None: + # Call via the class to satisfy certain linters complaining about 'self' + hf_model.__class__.push_to_hub( + hf_model, + repo_id=repo_id, + private=private, + token=token, + commit_message=commit_message, + revision=revision, + ) + + +def optionally_upload_readme(repo_id: str, token: str | None, readme_path: Path | None) -> None: + if readme_path is None: + return + if not readme_path.exists(): + raise FileNotFoundError(f"README file not found: {readme_path}") + api = HfApi() + api.upload_file( + path_or_fileobj=str(readme_path), + path_in_repo="README.md", + repo_id=repo_id, + repo_type="model", + token=token, + ) + + +def main() -> None: + args = parse_args() + + if not args.checkpoint_path.exists(): + raise FileNotFoundError(f"Checkpoint not found at {args.checkpoint_path}") + + model_id, config = load_config_from_checkpoint_dir(args.checkpoint_path) + custom_model = load_custom_model(args.checkpoint_path, model_id, config) + + # Convert and push + hf_model = convert_to_hf_model(custom_model) + push_model_to_hub( + hf_model=hf_model, + repo_id=args.repo_id, + token=args.token, + private=args.private, + revision=args.revision, + commit_message=args.commit_message, + ) + + # Optional README + optionally_upload_readme(args.repo_id, args.token, args.model_card_readme) + + +if __name__ == "__main__": + torch.set_grad_enabled(False) + main() From 77ecfe7b10020bb3bb01f8f944898bfa9cc655d8 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Wed, 13 Aug 2025 18:55:36 +0000 Subject: [PATCH 09/13] Upload tokenizer to hf too --- scripts/push_to_hf.py | 139 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 136 insertions(+), 3 deletions(-) diff --git a/scripts/push_to_hf.py b/scripts/push_to_hf.py index 79b1193..cc08e78 
100644 --- a/scripts/push_to_hf.py +++ b/scripts/push_to_hf.py @@ -1,6 +1,16 @@ -from __future__ import annotations +""" +Usage: +```bash +python scripts/push_to_hf.py \ + --checkpoint-path /path/to/checkpoint.pt \ + --repo-id your-username/your-repo \ + --token $HF_TOKEN +``` +""" import argparse +import io +import json import os from dataclasses import dataclass from pathlib import Path @@ -9,6 +19,7 @@ import torch import yaml from huggingface_hub import HfApi +from tokenizers import Tokenizer from transformers import PreTrainedModel from simple_stories_train.models.gpt2 import ( @@ -38,7 +49,8 @@ class PushArgs: def parse_args() -> PushArgs: parser = argparse.ArgumentParser( description=( - "Load a local custom Llama checkpoint, convert to Hugging Face format, and push to the Hub." + "Load a local custom Llama checkpoint, convert to Hugging Face format, and " + "push to the Hub." ) ) parser.add_argument( @@ -120,7 +132,8 @@ def load_config_from_checkpoint_dir(checkpoint_path: Path) -> tuple[str, LlamaCo preset = MODEL_CONFIGS.get(model_id) if preset is None: raise ValueError( - f"Unknown model_id '{model_id}' in final_config.yaml. Available: {tuple(MODEL_CONFIGS.keys())}" + f"Unknown model_id '{model_id}' in final_config.yaml." + f" Available: {tuple(MODEL_CONFIGS.keys())}" ) # Optionally override context length from training config if present @@ -162,6 +175,113 @@ def convert_to_hf_model(custom_model: Llama | GPT2) -> PreTrainedModel: return hf_model +def _resolve_tokenizer_path(final_cfg_path: Path) -> Path | None: + """Try to resolve a tokenizer file path from the final_config.yaml next to the checkpoint. + + Returns absolute path to the tokenizer json if it can be found, otherwise None. + + TODO: Save the tokenizer when training the model. 
+ """ + try: + with final_cfg_path.open("r") as f: + data: dict[str, Any] = yaml.safe_load(f) + except Exception: + return None + + train_ds_cfg = data.get("train_dataset_config", {}) or {} + tokenizer_rel: str | None = train_ds_cfg.get("tokenizer_file_path") + if not tokenizer_rel or not isinstance(tokenizer_rel, str): + return None + + # As a last resort, if the file name matches a known tokenizer in the repo, use it + known_default = Path("simple_stories_train/tokenizer/simplestories-tokenizer.json") + if known_default.is_file(): + return known_default.resolve() + + return None + + +def upload_tokenizer_to_hub( + repo_id: str, + token: str | None, + model_max_length: int | None, + checkpoint_path: Path, +) -> None: + """Upload tokenizer artifacts (minimal set) to the Hub model repo. + + Uploads: + - tokenizer.json (raw Tokenizers file) + - tokenizer_config.json (minimal, includes eos/unk tokens and max length if known) + """ + final_cfg_path = checkpoint_path.parent / "final_config.yaml" + tokenizer_path = _resolve_tokenizer_path(final_cfg_path) + if tokenizer_path is None or not tokenizer_path.exists(): + # Nothing to upload + return + + api = HfApi() + + # Upload tokenizer.json (rename if needed) + api.upload_file( + path_or_fileobj=str(tokenizer_path), + path_in_repo="tokenizer.json", + repo_id=repo_id, + repo_type="model", + token=token, + ) + + # Build tokenizer_config.json matching desired structure + # Discover IDs for special tokens from the tokenizer file + unk_token = "[UNK]" + eos_token = "[EOS]" + added_tokens_decoder: dict[str, dict[str, Any]] = {} + + try: + tk: Tokenizer = Tokenizer.from_file(str(tokenizer_path)) + unk_id = tk.token_to_id(unk_token) + eos_id = tk.token_to_id(eos_token) + except Exception: + unk_id = None + eos_id = None + + def _entry(content: str) -> dict[str, Any]: + return { + "content": content, + "lstrip": False, + "normalized": False, + "rstrip": False, + "single_word": False, + "special": True, + } + + if isinstance(unk_id, 
int): + added_tokens_decoder[str(unk_id)] = _entry(unk_token) + if isinstance(eos_id, int): + added_tokens_decoder[str(eos_id)] = _entry(eos_token) + + # Use HF's sentinel for unlimited length to mirror common configs + unlimited_len = int(1e30) + + cfg: dict[str, Any] = { + "added_tokens_decoder": added_tokens_decoder, + "clean_up_tokenization_spaces": False, + "eos_token": eos_token, + "extra_special_tokens": {}, + "model_max_length": unlimited_len, + "tokenizer_class": "PreTrainedTokenizerFast", + "unk_token": unk_token, + } + + cfg_bytes = json.dumps(cfg, indent=2).encode("utf-8") + api.upload_file( + path_or_fileobj=io.BytesIO(cfg_bytes), + path_in_repo="tokenizer_config.json", + repo_id=repo_id, + repo_type="model", + token=token, + ) + + def push_model_to_hub( hf_model: PreTrainedModel, repo_id: str, @@ -216,6 +336,19 @@ def main() -> None: commit_message=args.commit_message, ) + # Upload tokenizer artifacts (minimal set) + model_max_len: int | None = None + if isinstance(config, LlamaConfig): + model_max_len = config.n_ctx + elif isinstance(config, GPT2Config): + model_max_len = config.block_size + upload_tokenizer_to_hub( + repo_id=args.repo_id, + token=args.token, + model_max_length=model_max_len, + checkpoint_path=args.checkpoint_path, + ) + # Optional README optionally_upload_readme(args.repo_id, args.token, args.model_card_readme) From 334fa0757616619b033ca0eb52cf74039d1d0294 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Thu, 14 Aug 2025 13:36:17 +0000 Subject: [PATCH 10/13] Refactor gpt conversions --- simple_stories_train/models/gpt2.py | 192 ++++++++++++---------------- 1 file changed, 84 insertions(+), 108 deletions(-) diff --git a/simple_stories_train/models/gpt2.py b/simple_stories_train/models/gpt2.py index d042902..f8812fa 100644 --- a/simple_stories_train/models/gpt2.py +++ b/simple_stories_train/models/gpt2.py @@ -1,8 +1,6 @@ import inspect import math -from typing import Any -from typing import cast as _cast -from typing import cast as 
t_cast +from typing import Any, Literal, cast import torch import torch.nn as nn @@ -69,7 +67,8 @@ def forward( x: Float[Tensor, "batch pos d_model"], ) -> Float[Tensor, "batch pos d_model"]: B, T, C = x.size() - # calculate query, key, values for all heads in batch and move head forward to be the batch dim + # calculate q, k, v for all heads in batch + # move head dimension forward to be the batch dimension qkv = self.c_attn(x) q, k, v = qkv.split(self.n_embd, dim=2) k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) @@ -227,7 +226,7 @@ def from_pretrained(cls, model_type: str) -> "GPT2": }[model_type] config_args["vocab_size"] = 50257 config_args["block_size"] = 1024 - config = GPT2Config(**_cast(dict[str, Any], config_args)) + config = GPT2Config(**cast(dict[str, Any], config_args)) model = GPT2(config) sd = model.state_dict() @@ -244,7 +243,7 @@ def from_pretrained(cls, model_type: str) -> "GPT2": "mlp.c_fc.weight", "mlp.c_proj.weight", ] - # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear + # openai checkpoints use a "Conv1D" module; we use a vanilla Linear # this means that we have to transpose these weights when we import them assert len(sd_keys_hf) == len(sd_keys), ( f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}" @@ -365,71 +364,88 @@ def generate( return idx -def convert_hf_gpt2_to_gpt2(hf_model: GPT2LMHeadModel) -> "GPT2": - """Convert a HuggingFace GPT2LMHeadModel to our custom GPT2. 
- - Args: - hf_model: HuggingFace GPT2LMHeadModel instance +def _build_mapping( + direction: Literal["custom_to_hf", "hf_to_custom"], n_layer: int +) -> list[tuple[str, str, bool]]: + base_pairs: list[tuple[str, str, bool]] = [ + ("wte.weight", "transformer.wte.weight", False), + ("wpe.weight", "transformer.wpe.weight", False), + ("ln_f.weight", "transformer.ln_f.weight", False), + ("ln_f.bias", "transformer.ln_f.bias", False), + ("lm_head.weight", "lm_head.weight", False), + ] + + layer_pairs: list[tuple[str, str, bool]] = [] + for i in range(n_layer): + c_prefix = f"h_torch.{i}." + h_prefix = f"transformer.h.{i}." + layer_pairs.extend( + [ + (f"{c_prefix}ln_1.weight", f"{h_prefix}ln_1.weight", False), + (f"{c_prefix}ln_1.bias", f"{h_prefix}ln_1.bias", False), + (f"{c_prefix}ln_2.weight", f"{h_prefix}ln_2.weight", False), + (f"{c_prefix}ln_2.bias", f"{h_prefix}ln_2.bias", False), + (f"{c_prefix}attn.c_attn.weight", f"{h_prefix}attn.c_attn.weight", True), + (f"{c_prefix}attn.c_attn.bias", f"{h_prefix}attn.c_attn.bias", False), + (f"{c_prefix}attn.c_proj.weight", f"{h_prefix}attn.c_proj.weight", True), + (f"{c_prefix}attn.c_proj.bias", f"{h_prefix}attn.c_proj.bias", False), + (f"{c_prefix}mlp.c_fc.weight", f"{h_prefix}mlp.c_fc.weight", True), + (f"{c_prefix}mlp.c_fc.bias", f"{h_prefix}mlp.c_fc.bias", False), + (f"{c_prefix}mlp.c_proj.weight", f"{h_prefix}mlp.c_proj.weight", True), + (f"{c_prefix}mlp.c_proj.bias", f"{h_prefix}mlp.c_proj.bias", False), + ] + ) - Returns: - Our custom GPT2 model with weights copied from the HF model - """ - hf_config = hf_model.config - config = GPT2Config( - block_size=hf_config.n_ctx, - vocab_size=hf_config.vocab_size, - n_layer=hf_config.n_layer, - n_head=hf_config.n_head, - n_embd=hf_config.n_embd, - flash_attention=True, - ) - model = GPT2(config) + mapping = base_pairs + layer_pairs + if direction == "custom_to_hf": + return mapping + return [(dst, src, transpose) for (src, dst, transpose) in mapping] - # Embeddings - with 
torch.no_grad(): - model.wte.weight.copy_(hf_model.transformer.wte.weight) - model.wpe.weight.copy_(hf_model.transformer.wpe.weight) - # Blocks - for i in range(config.n_layer): - custom_block = model.h[i] - hf_block = hf_model.transformer.h[i] +def _resolve_tensor(module: nn.Module, path: str) -> Tensor: + """Get tensor from module by path. - # Layer norms - with torch.no_grad(): - custom_block.ln_1.weight.copy_(t_cast(Tensor, hf_block.ln_1.weight)) - custom_block.ln_1.bias.copy_(t_cast(Tensor, hf_block.ln_1.bias)) - custom_block.ln_2.weight.copy_(t_cast(Tensor, hf_block.ln_2.weight)) - custom_block.ln_2.bias.copy_(t_cast(Tensor, hf_block.ln_2.bias)) + E.g. _resolve_tensor(module, "transformer.h.0.attn.c_attn.weight") + will return the weight tensor for the first attention layer. + """ - # Attention (transpose HF Conv1D weights to Linear) - with torch.no_grad(): - custom_block.attn.c_attn.weight.copy_(t_cast(Tensor, hf_block.attn.c_attn.weight).T) - custom_block.attn.c_attn.bias.copy_(t_cast(Tensor, hf_block.attn.c_attn.bias)) + obj: Any = module + for part in path.split("."): + obj = obj[int(part)] if part.isdigit() else getattr(obj, part) + assert isinstance(obj, Tensor) + return obj - with torch.no_grad(): - custom_block.attn.c_proj.weight.copy_(t_cast(Tensor, hf_block.attn.c_proj.weight).T) - custom_block.attn.c_proj.bias.copy_(t_cast(Tensor, hf_block.attn.c_proj.bias)) - # MLP (transpose HF Conv1D weights to Linear) - with torch.no_grad(): - custom_block.mlp.c_fc.weight.copy_(t_cast(Tensor, hf_block.mlp.c_fc.weight).T) - custom_block.mlp.c_fc.bias.copy_(t_cast(Tensor, hf_block.mlp.c_fc.bias)) +@torch.inference_mode() +def _copy_by_mapping(src: nn.Module, dst: nn.Module, mapping: list[tuple[str, str, bool]]) -> None: + for src_path, dst_path, transpose in mapping: + src_tensor = _resolve_tensor(src, src_path) + dst_tensor = _resolve_tensor(dst, dst_path) + tensor_to_copy = src_tensor.t().contiguous() if transpose else src_tensor + 
dst_tensor.copy_(tensor_to_copy) - with torch.no_grad(): - custom_block.mlp.c_proj.weight.copy_(t_cast(Tensor, hf_block.mlp.c_proj.weight).T) - custom_block.mlp.c_proj.bias.copy_(t_cast(Tensor, hf_block.mlp.c_proj.bias)) - # Final ln_f - with torch.no_grad(): - model.ln_f.weight.copy_(hf_model.transformer.ln_f.weight) - model.ln_f.bias.copy_(hf_model.transformer.ln_f.bias) +def convert_hf_gpt2_to_gpt2(hf_model: GPT2LMHeadModel) -> GPT2: + """Convert a HuggingFace GPT2LMHeadModel to our custom GPT2. - # LM head - with torch.no_grad(): - model.lm_head.weight.copy_(hf_model.lm_head.weight) + Args: + hf_model: HuggingFace GPT2LMHeadModel instance - return model + Returns: + Our custom GPT2 model with weights copied from the HF model + """ + custom_config = GPT2Config( + block_size=hf_model.config.n_ctx, + vocab_size=hf_model.config.vocab_size, + n_layer=hf_model.config.n_layer, + n_head=hf_model.config.n_head, + n_embd=hf_model.config.n_embd, + flash_attention=True, + ) + custom_model = GPT2(custom_config) + mapping = _build_mapping("hf_to_custom", custom_model.config.n_layer) + _copy_by_mapping(src=hf_model, dst=custom_model, mapping=mapping) + return custom_model def convert_gpt2_to_hf_gpt2(custom_model: GPT2) -> GPT2LMHeadModel: @@ -441,60 +457,20 @@ def convert_gpt2_to_hf_gpt2(custom_model: GPT2) -> GPT2LMHeadModel: Returns: The converted HuggingFace GPT2LMHeadModel """ - model_config: GPT2Config = custom_model.config - hf_config = HFGPT2Config( - vocab_size=model_config.vocab_size, - n_positions=model_config.block_size, - n_ctx=model_config.block_size, - n_layer=model_config.n_layer, - n_head=model_config.n_head, - n_embd=model_config.n_embd, + vocab_size=custom_model.config.vocab_size, + n_positions=custom_model.config.block_size, + n_ctx=custom_model.config.block_size, + n_layer=custom_model.config.n_layer, + n_head=custom_model.config.n_head, + n_embd=custom_model.config.n_embd, activation_function="gelu_new", n_inner=None, layer_norm_epsilon=1e-5, - # Tie 
embeddings and lm_head as our implementation does tie_word_embeddings=True, ) - hf_model = GPT2LMHeadModel(hf_config) - - # Embeddings - hf_model.transformer.wte.weight.data = custom_model.wte.weight.data - hf_model.transformer.wpe.weight.data = custom_model.wpe.weight.data - - # Transformer blocks - for i in range(model_config.n_layer): - custom_block = custom_model.h[i] - hf_block = hf_model.transformer.h[i] - - # LayerNorms - hf_block.ln_1.weight.data = custom_block.ln_1.weight.data - hf_block.ln_1.bias.data = custom_block.ln_1.bias.data - hf_block.ln_2.weight.data = custom_block.ln_2.weight.data - hf_block.ln_2.bias.data = custom_block.ln_2.bias.data - - # Attention projections: HF uses Conv1D (weight shape [in, out]) - # Our Linear weights are [out, in], so transpose when copying to HF - hf_block.attn.c_attn.weight.data = custom_block.attn.c_attn.weight.data.t().contiguous() - hf_block.attn.c_attn.bias.data = custom_block.attn.c_attn.bias.data - - hf_block.attn.c_proj.weight.data = custom_block.attn.c_proj.weight.data.t().contiguous() - hf_block.attn.c_proj.bias.data = custom_block.attn.c_proj.bias.data - - # MLP projections - hf_block.mlp.c_fc.weight.data = custom_block.mlp.c_fc.weight.data.t().contiguous() - hf_block.mlp.c_fc.bias.data = custom_block.mlp.c_fc.bias.data - - hf_block.mlp.c_proj.weight.data = custom_block.mlp.c_proj.weight.data.t().contiguous() - hf_block.mlp.c_proj.bias.data = custom_block.mlp.c_proj.bias.data - - # Final LayerNorm - hf_model.transformer.ln_f.weight.data = custom_model.ln_f.weight.data - hf_model.transformer.ln_f.bias.data = custom_model.ln_f.bias.data - - # LM head - hf_model.lm_head.weight.data = custom_model.lm_head.weight.data - + mapping = _build_mapping("custom_to_hf", custom_model.config.n_layer) + _copy_by_mapping(src=custom_model, dst=hf_model, mapping=mapping) hf_model.eval() return hf_model From 49cd09afceb31dafc672701504d00d449bb5b01a Mon Sep 17 00:00:00 2001 From: chandanms Date: Thu, 14 Aug 2025 23:30:46 +0200 
Subject: [PATCH 11/13] Fixed linter issues --- simple_stories_train/tokenizer.py | 11 +-- tests/test_tokenizer.py | 114 +++++++++++++++++++----------- 2 files changed, 75 insertions(+), 50 deletions(-) diff --git a/simple_stories_train/tokenizer.py b/simple_stories_train/tokenizer.py index e7f899c..860e105 100644 --- a/simple_stories_train/tokenizer.py +++ b/simple_stories_train/tokenizer.py @@ -18,10 +18,6 @@ OUT_DIR = Path("tokenizer") -# Define common affixes for special handling based on morphological analysis of the dataset -COMMON_PREFIXES = ["un", "re"] -COMMON_SUFFIXES = ["ed", "ing", "ly", "er", "ness"] - def clean_dataset(dataset_name: str, column_name: str) -> Generator[str, None, None]: """ @@ -121,12 +117,7 @@ def train_tokenizer(data: Generator[str, None, None], vocab_size: int) -> Tokeni tokenizer = create_tokenizer(vocab_size) special_tokens = ["[UNK]", "[EOS]"] - affixes = COMMON_PREFIXES + COMMON_SUFFIXES - - # Train the tokenizer - trainer = WordPieceTrainer( - vocab_size=vocab_size, special_tokens=special_tokens, initial_alphabet=affixes - ) + trainer = WordPieceTrainer(vocab_size=vocab_size, special_tokens=special_tokens) tokenizer.train_from_iterator(data, trainer=trainer) print("Tokenizer training completed") diff --git a/tests/test_tokenizer.py b/tests/test_tokenizer.py index 0a9e47e..dc4a6c6 100644 --- a/tests/test_tokenizer.py +++ b/tests/test_tokenizer.py @@ -2,29 +2,7 @@ from collections.abc import Generator -from tokenizers import Tokenizer -from tokenizers.models import WordPiece -from tokenizers.normalizers import Lowercase -from tokenizers.pre_tokenizers import Whitespace -from tokenizers.trainers import WordPieceTrainer - -from simple_stories_train.tokenizer import prune_tokenizer - - -# Create tokenizer once for all tests -def setup_tokenizer(): - tokenizer = Tokenizer(WordPiece(unk_token="[UNK]")) # type: ignore - tokenizer.normalizer = Lowercase() # type: ignore - tokenizer.pre_tokenizer = Whitespace() # type: ignore - - # High 
vocab size with minimal data ensures many unused tokens - trainer = WordPieceTrainer(vocab_size=200, special_tokens=["[UNK]", "[EOS]"]) - tokenizer.train_from_iterator(["hello world", "hello there", "world peace"], trainer=trainer) - return tokenizer - - -# Global tokenizer for reuse -TEST_TOKENIZER = setup_tokenizer() +from simple_stories_train.tokenizer import prune_tokenizer, train_tokenizer def create_test_data_iterator(test_data: list[str]) -> Generator[str, None, None]: @@ -32,43 +10,99 @@ def create_test_data_iterator(test_data: list[str]) -> Generator[str, None, None yield from test_data +def create_test_tokenizer(): + """Create a fresh tokenizer for testing.""" + train_data = ["hello world", "hello there", "world peace", "simple stories"] + train_iter = create_test_data_iterator(train_data) + return train_tokenizer(train_iter, vocab_size=200) + + def test_special_tokens_preserved(): - test_data = ["hello world"] - data_iter = create_test_data_iterator(test_data) + """Verify special tokens exist in both original and pruned tokenizers.""" + tokenizer = create_test_tokenizer() - pruned = prune_tokenizer(data_iter, TEST_TOKENIZER) - vocab = pruned.get_vocab() + vocab_orig = tokenizer.get_vocab() + assert "[UNK]" in vocab_orig and "[EOS]" in vocab_orig - assert "[UNK]" in vocab and "[EOS]" in vocab - assert vocab["[UNK]"] in [0, 1] and vocab["[EOS]"] in [0, 1] + test_data = ["hello world"] + data_iter = create_test_data_iterator(test_data) + pruned = prune_tokenizer(data_iter, tokenizer) + vocab_pruned = pruned.get_vocab() + assert "[UNK]" in vocab_pruned and "[EOS]" in vocab_pruned + assert vocab_pruned["[UNK]"] in [0, 1] and vocab_pruned["[EOS]"] in [0, 1] def test_unused_tokens_removed(): - original_size = len(TEST_TOKENIZER.get_vocab()) - test_data = ["hello"] # Very limited data + """Verify pruning removes unused tokens from vocabulary.""" + tokenizer = create_test_tokenizer() + + original_size = len(tokenizer.get_vocab()) + test_data = ["hello"] data_iter 
= create_test_data_iterator(test_data) - pruned = prune_tokenizer(data_iter, TEST_TOKENIZER) + pruned = prune_tokenizer(data_iter, tokenizer) assert len(pruned.get_vocab()) < original_size def test_functionality_preserved(): + """Verify encode/decode functionality works before and after pruning.""" + tokenizer = create_test_tokenizer() + + encoded_orig = tokenizer.encode("hello world") + decoded_orig = tokenizer.decode(encoded_orig.ids) + assert decoded_orig == "hello world" + test_data = ["hello world"] data_iter = create_test_data_iterator(test_data) + pruned = prune_tokenizer(data_iter, tokenizer) + encoded_pruned = pruned.encode("hello world") + decoded_pruned = pruned.decode(encoded_pruned.ids) + assert decoded_pruned == "hello world" - pruned = prune_tokenizer(data_iter, TEST_TOKENIZER) - encoded = pruned.encode("hello world") - decoded = pruned.decode(encoded.ids) - assert decoded == "hello world" +def test_sequential_ids(): + """Verify pruned tokenizer has sequential token IDs starting from 0.""" + tokenizer = create_test_tokenizer() + token_ids_orig = sorted(tokenizer.get_vocab().values()) + assert token_ids_orig[0] >= 0 + assert len(token_ids_orig) == len(set(token_ids_orig)) -def test_sequential_ids(): test_data = ["hello"] data_iter = create_test_data_iterator(test_data) + pruned = prune_tokenizer(data_iter, tokenizer) + token_ids_pruned = sorted(pruned.get_vocab().values()) + assert token_ids_pruned == list(range(len(token_ids_pruned))) + + +def test_eos_appended(): + """Verify EOS token is appended as last token before and after pruning.""" + tokenizer = create_test_tokenizer() + + eos_id_orig = tokenizer.token_to_id("[EOS]") + encoded_orig = tokenizer.encode("hello world") + assert encoded_orig.ids[-1] == eos_id_orig + + test_data = ["hello world"] + data_iter = create_test_data_iterator(test_data) + pruned = prune_tokenizer(data_iter, tokenizer) + eos_id_pruned = pruned.token_to_id("[EOS]") + encoded_pruned = pruned.encode("hello world") + assert 
encoded_pruned.ids[-1] == eos_id_pruned + - pruned = prune_tokenizer(data_iter, TEST_TOKENIZER) - token_ids = sorted(pruned.get_vocab().values()) +def test_unk_for_unknown_words(): + """Verify UNK token is used for unknown words before and after pruning.""" + tokenizer = create_test_tokenizer() - assert token_ids == list(range(len(token_ids))) + unk_id_orig = tokenizer.token_to_id("[UNK]") + encoded_orig = tokenizer.encode("antidisestablishmentarianism") + assert unk_id_orig in encoded_orig.ids + + test_data = ["hello world"] + data_iter = create_test_data_iterator(test_data) + pruned = prune_tokenizer(data_iter, tokenizer) + unk_id_pruned = pruned.token_to_id("[UNK]") + encoded_pruned = pruned.encode("antidisestablishmentarianism") + assert unk_id_pruned in encoded_pruned.ids From 04b655870be35935bb21194033e3b89e9ee92f16 Mon Sep 17 00:00:00 2001 From: chandanms Date: Fri, 15 Aug 2025 19:43:33 +0200 Subject: [PATCH 12/13] Made the tests more strict for verifying existence of EOS and UNK tokens; Made the data to tokenizer training iterable. --- simple_stories_train/tokenizer.py | 6 +++--- tests/test_tokenizer.py | 34 ++++++++++--------------------- 2 files changed, 14 insertions(+), 26 deletions(-) diff --git a/simple_stories_train/tokenizer.py b/simple_stories_train/tokenizer.py index 860e105..1126a15 100644 --- a/simple_stories_train/tokenizer.py +++ b/simple_stories_train/tokenizer.py @@ -2,7 +2,7 @@ This file is inspired from Nix Goldowsky-Dill's adaption of the tokenizer in https://github.com/juand-r/tiny_tokenizer. 
""" -from collections.abc import Generator +from collections.abc import Generator, Iterable from pathlib import Path from datasets import DatasetDict, IterableDatasetDict, load_dataset @@ -97,7 +97,7 @@ def create_tokenizer(vocab_size: int) -> Tokenizer: return tokenizer -def train_tokenizer(data: Generator[str, None, None], vocab_size: int) -> Tokenizer: +def train_tokenizer(data: Iterable[str], vocab_size: int) -> Tokenizer: """ Train the tokenizer with the specified vocabulary size and cleaned data. @@ -143,7 +143,7 @@ def save_tokenizer(tokenizer: Tokenizer, tokenizer_name: str) -> str: return tokenizer_path -def prune_tokenizer(data: Generator[str, None, None], tokenizer: Tokenizer) -> Tokenizer: +def prune_tokenizer(data: Iterable[str], tokenizer: Tokenizer) -> Tokenizer: """ Prune tokenizer by removing unused tokens and reordering IDs sequentially. Note: [UNK] token is handled automatically by WordPiece constructor, diff --git a/tests/test_tokenizer.py b/tests/test_tokenizer.py index dc4a6c6..62d4dce 100644 --- a/tests/test_tokenizer.py +++ b/tests/test_tokenizer.py @@ -1,35 +1,29 @@ """Simple test for tokenizer pruning functionality.""" -from collections.abc import Generator - from simple_stories_train.tokenizer import prune_tokenizer, train_tokenizer -def create_test_data_iterator(test_data: list[str]) -> Generator[str, None, None]: - """Create iterator from test data.""" - yield from test_data - - def create_test_tokenizer(): """Create a fresh tokenizer for testing.""" train_data = ["hello world", "hello there", "world peace", "simple stories"] - train_iter = create_test_data_iterator(train_data) - return train_tokenizer(train_iter, vocab_size=200) + return train_tokenizer(iter(train_data), vocab_size=200) def test_special_tokens_preserved(): - """Verify special tokens exist in both original and pruned tokenizers.""" + """Verify special tokens exist and are unique in both original and pruned tokenizers.""" tokenizer = create_test_tokenizer() vocab_orig = 
tokenizer.get_vocab() assert "[UNK]" in vocab_orig and "[EOS]" in vocab_orig + assert vocab_orig["[UNK]"] in [0, 1] and vocab_orig["[EOS]"] in [0, 1] + assert vocab_orig["[UNK]"] != vocab_orig["[EOS]"] test_data = ["hello world"] - data_iter = create_test_data_iterator(test_data) - pruned = prune_tokenizer(data_iter, tokenizer) + pruned = prune_tokenizer(iter(test_data), tokenizer) vocab_pruned = pruned.get_vocab() assert "[UNK]" in vocab_pruned and "[EOS]" in vocab_pruned assert vocab_pruned["[UNK]"] in [0, 1] and vocab_pruned["[EOS]"] in [0, 1] + assert vocab_pruned["[UNK]"] != vocab_pruned["[EOS]"] def test_unused_tokens_removed(): @@ -38,9 +32,8 @@ def test_unused_tokens_removed(): original_size = len(tokenizer.get_vocab()) test_data = ["hello"] - data_iter = create_test_data_iterator(test_data) - pruned = prune_tokenizer(data_iter, tokenizer) + pruned = prune_tokenizer(iter(test_data), tokenizer) assert len(pruned.get_vocab()) < original_size @@ -54,8 +47,7 @@ def test_functionality_preserved(): assert decoded_orig == "hello world" test_data = ["hello world"] - data_iter = create_test_data_iterator(test_data) - pruned = prune_tokenizer(data_iter, tokenizer) + pruned = prune_tokenizer(iter(test_data), tokenizer) encoded_pruned = pruned.encode("hello world") decoded_pruned = pruned.decode(encoded_pruned.ids) assert decoded_pruned == "hello world" @@ -70,8 +62,7 @@ def test_sequential_ids(): assert len(token_ids_orig) == len(set(token_ids_orig)) test_data = ["hello"] - data_iter = create_test_data_iterator(test_data) - pruned = prune_tokenizer(data_iter, tokenizer) + pruned = prune_tokenizer(iter(test_data), tokenizer) token_ids_pruned = sorted(pruned.get_vocab().values()) assert token_ids_pruned == list(range(len(token_ids_pruned))) @@ -85,8 +76,7 @@ def test_eos_appended(): assert encoded_orig.ids[-1] == eos_id_orig test_data = ["hello world"] - data_iter = create_test_data_iterator(test_data) - pruned = prune_tokenizer(data_iter, tokenizer) + pruned = 
prune_tokenizer(iter(test_data), tokenizer) eos_id_pruned = pruned.token_to_id("[EOS]") encoded_pruned = pruned.encode("hello world") assert encoded_pruned.ids[-1] == eos_id_pruned @@ -95,14 +85,12 @@ def test_eos_appended(): def test_unk_for_unknown_words(): """Verify UNK token is used for unknown words before and after pruning.""" tokenizer = create_test_tokenizer() - unk_id_orig = tokenizer.token_to_id("[UNK]") encoded_orig = tokenizer.encode("antidisestablishmentarianism") assert unk_id_orig in encoded_orig.ids test_data = ["hello world"] - data_iter = create_test_data_iterator(test_data) - pruned = prune_tokenizer(data_iter, tokenizer) + pruned = prune_tokenizer(iter(test_data), tokenizer) unk_id_pruned = pruned.token_to_id("[UNK]") encoded_pruned = pruned.encode("antidisestablishmentarianism") assert unk_id_pruned in encoded_pruned.ids From 6cb28de322aa6ea5dd19af0b0e17fa96a56ad770 Mon Sep 17 00:00:00 2001 From: chandanms Date: Fri, 15 Aug 2025 22:20:29 +0200 Subject: [PATCH 13/13] Updated the readme file --- simple_stories_train/README.md | 58 ++++++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) create mode 100644 simple_stories_train/README.md diff --git a/simple_stories_train/README.md b/simple_stories_train/README.md new file mode 100644 index 0000000..24a4919 --- /dev/null +++ b/simple_stories_train/README.md @@ -0,0 +1,58 @@ +# simple_stories_train + +Training framework for small language models using SimpleStories, a large-scale synthetic dataset of over 2 million short stories in simple language. 
+ +**Paper:** [Parameterized Synthetic Text Generation with SimpleStories](https://arxiv.org/abs/2504.09184) +**Models & Dataset:** [🤗 SimpleStories on Hugging Face](https://huggingface.co/SimpleStories) + +## Installation + +From the root of the repository, run one of + +```bash +make install-dev # To install the package, dev requirements and pre-commit hooks +make install # To just install the package (runs `pip install -e .`) +``` + +## Development + +Suggested extensions and settings for VSCode are provided in `.vscode/`. To use the suggested +settings, copy `.vscode/settings-example.json` to `.vscode/settings.json`. + +There are various `make` commands that may be helpful + +```bash +make check # Run pre-commit on all files (i.e. pyright, ruff linter, and ruff formatter) +make type # Run pyright on all files +make format # Run ruff linter and formatter on all files +make test # Run tests that aren't marked `slow` +make test-all # Run all tests +``` + +## Usage + +### Training a model +```bash +python -m simple_stories_train.train [PATH/TO/CONFIG.yaml] [--key1 value1 --key2 value2 ...] +``` +where +- `PATH/TO/CONFIG.yaml` contains the training config. If no path is provided, a default config will be used. +- `--key1 value1 --key2 value2 ...` override values in the config. Note that if you wish to update a + nested value, you must use dotted notation (e.g. `--train_dataset_config.name my_dataset`). + +If running on CPU, you may need to set `--compile=False`. + +To run on multiple GPUs, use +``` +torchrun --standalone --nproc_per_node=N -m simple_stories_train.train ... +``` +where `N` is the number of GPUs to use. + +### Logging with Weights & Biases +To track training with Weights & Biases, you can set the WANDB_PROJECT and WANDB_API_KEY variables in +`.env`. API keys can be obtained from your [Weights & Biases account settings](https://wandb.ai/settings). 
+ +## Acknowledgments + +- Training script is based on the efficient [train_gpt2.py](https://github.com/karpathy/llm.c/blob/master/train_gpt2.py) in [llm.c](https://github.com/karpathy/llm.c) (licensed under MIT ((c) 2024 Andrej Karpathy)) +- Some model architecture implementations are based on [TransformerLens](https://github.com/TransformerLensOrg/TransformerLens) (licensed under MIT ((c) 2022 TransformerLensOrg))