diff --git a/.gitignore b/.gitignore index 2596db4..230277f 100644 --- a/.gitignore +++ b/.gitignore @@ -428,3 +428,9 @@ flycheck_*.el # network security /network-security.data + +# Re-include the committed experimental subpackage - the global +# experimental / experimental/ patterns above are for ignored +# research scratch dirs, not for the packaged module we ship. +!open_mythos/experimental/ +!open_mythos/experimental/** diff --git a/open_mythos/experimental/__init__.py b/open_mythos/experimental/__init__.py new file mode 100644 index 0000000..05c51bd --- /dev/null +++ b/open_mythos/experimental/__init__.py @@ -0,0 +1,42 @@ +""" +Experimental / research-line modules. + +These components are **not** part of the canonical OpenMythos architecture +(Prelude / Recurrent / Coda with MLA + DeepSeek-MoE) exposed at the package +root. They live here to be importable for research without polluting the +public API surface, and their contracts (names, signatures, behavior) are +explicitly unstable. + +Included: + - MoDA (Mixture-of-Depths Attention): a depth-aware attention variant + that attends across layer depth in addition to sequence position, + fused with DeepSeek-style MoE FFNs. + +Stability: **no guarantees**. Import at your own risk; APIs here may change +or disappear in any commit. Do not build production training configs that +depend on this subpackage. +""" + +from open_mythos.experimental.moda import ( + DeepSeekExpert, + DeepSeekGate, + DeepSeekMoE, + MoDAAttention, + MoDABlock, + MoDAConfig, + MoDAModel, + RMSNorm, + RotaryEmbedding, +) + +__all__ = [ + "MoDAConfig", + "MoDAModel", + "MoDABlock", + "MoDAAttention", + "DeepSeekExpert", + "DeepSeekGate", + "DeepSeekMoE", + "RMSNorm", + "RotaryEmbedding", +] diff --git a/open_mythos/moda.py b/open_mythos/experimental/moda.py similarity index 94% rename from open_mythos/moda.py rename to open_mythos/experimental/moda.py index e662d61..c549218 100644 --- a/open_mythos/moda.py +++ b/open_mythos/experimental/moda.py @@ -1063,72 +1063,3 @@ def extra_repr(self) -> str: ) -# # --------------------------------------------------------------------------- -# # Smoke test -# # --------------------------------------------------------------------------- - -# if __name__ == "__main__": -# torch.manual_seed(42) -# device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -# print(f"Device: {device}") - -# # Tiny config: 4 layers, 8 routed experts, top-2 -# cfg = MoDAConfig( -# vocab_size=512, -# d_model=128, -# n_layers=4, -# n_heads_q=4, -# n_heads_kv=2, -# head_dim=32, -# max_seq_len=64, -# # MoE: 2 shared + 8 routed, activate top-2 -# # (2+2)*64 = 256 ≈ equivalent to dense SwiGLU hidden~256 -# n_shared_experts=2, -# n_routed_experts=8, -# n_activated_experts=2, -# expert_hidden_dim=64, -# moe_balance_alpha=0.01, -# moe_score_func="softmax", -# ) - -# model = MoDAModel(cfg).to(device) -# print(f"Parameters: {model.num_parameters():,}") -# print(model) - -# B, T = 2, 32 -# input_ids = torch.randint(0, cfg.vocab_size, (B, T), device=device) -# labels = torch.randint(0, cfg.vocab_size, (B, T), device=device) - -# logits, loss = model(input_ids, labels) -# assert logits.shape == (B, T, cfg.vocab_size) -# print(f"Logits shape : {logits.shape}") -# print(f"Loss (LM + balance): {loss.item():.4f}") - -# loss.backward() - -# # Verify gradients -# last_writes = { -# f"blocks.{cfg.n_layers - 1}.k_write.weight", -# f"blocks.{cfg.n_layers - 1}.v_write.weight", -# } -# missing = [ -# name -# for name, p in model.named_parameters() -# if p.grad is None 
and name not in last_writes -# ] -# if missing: -# print(f"WARNING — unexpected missing gradients: {missing}") -# else: -# print("All parameters received gradients (excluding last-block writes).") - -# # Spot-check: MoE gate weights must receive gradients (through balance loss P_i) -# gate0_grad = model.blocks[0].moe.gate.weight.grad -# assert gate0_grad is not None, "blocks[0].moe.gate.weight has no gradient!" -# print(f"blocks[0].moe.gate.weight grad norm : {gate0_grad.norm().item():.6f}") - -# # Spot-check: depth write projections gradient flows from layer ≥ 1 depth reads -# k0_grad = model.blocks[0].k_write.weight.grad -# assert k0_grad is not None, "blocks[0].k_write.weight has no gradient!" -# print(f"blocks[0].k_write.weight grad norm : {k0_grad.norm().item():.6f}") - -# print("Smoke test passed.") diff --git a/open_mythos/main.py b/open_mythos/main.py index 10de093..4f87b70 100644 --- a/open_mythos/main.py +++ b/open_mythos/main.py @@ -72,6 +72,125 @@ class MythosConfig: max_output_tokens: int = 4096 # Dropout (set 0.0 to disable; 0.1 is standard for pretraining) dropout: float = 0.0 + # Aux-loss-free load-balancing (DeepSeek-V3). The router_bias is nudged + # toward under-used experts after each training step by `bias_update_speed`. + # Set bias_update_speed=0 to freeze the bias (pure unbiased routing). + bias_update_speed: float = 1e-3 + # Loop-index embedding base frequency (separate from sequence-RoPE theta + # because the recurrence depth is on a different scale than token position). + loop_rope_theta: float = 10000.0 + # Initial gain for the LTIInjection input gate B. Smaller = gentler input + # injection at the start of training; 0.1 is the published default. + lti_b_init: float = 0.1 + # Weight init stddev for Linear and Embedding layers. + init_std: float = 0.02 + + def __post_init__(self) -> None: + """ + Validate hyperparameters at construction time so misconfigurations + surface with a clear error message instead of a cryptic shape failure + deep inside the forward pass. 
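As an illustration of the fail-fast contract described above (the field values below are arbitrary, and all other fields are assumed to keep their dataclass defaults):

```python
from open_mythos.main import MythosConfig

try:
    MythosConfig(dim=65, n_heads=4)  # 65 is not divisible by 4
except ValueError as err:
    print(err)  # "dim (65) must be divisible by n_heads (4)"
```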
+ """ + # --- attention type --- + if self.attn_type not in ("gqa", "mla"): + raise ValueError( + f"attn_type must be 'gqa' or 'mla', got {self.attn_type!r}" + ) + + # --- core dims --- + if self.dim <= 0 or self.n_heads <= 0: + raise ValueError( + f"dim ({self.dim}) and n_heads ({self.n_heads}) must be positive" + ) + if self.dim % self.n_heads != 0: + raise ValueError( + f"dim ({self.dim}) must be divisible by n_heads ({self.n_heads})" + ) + + # --- GQA head grouping --- + if self.attn_type == "gqa": + if self.n_kv_heads <= 0: + raise ValueError(f"n_kv_heads must be positive, got {self.n_kv_heads}") + if self.n_heads % self.n_kv_heads != 0: + raise ValueError( + f"n_heads ({self.n_heads}) must be divisible by " + f"n_kv_heads ({self.n_kv_heads}) for GQA" + ) + + # --- MLA RoPE head dim must be even (complex-pair reshape) --- + if self.attn_type == "mla": + if self.qk_rope_head_dim <= 0 or self.qk_rope_head_dim % 2 != 0: + raise ValueError( + f"qk_rope_head_dim ({self.qk_rope_head_dim}) must be a " + "positive even integer for MLA" + ) + if self.kv_lora_rank <= 0 or self.q_lora_rank <= 0: + raise ValueError( + "kv_lora_rank and q_lora_rank must be positive for MLA" + ) + + # --- GQA head_dim must be even too (complex-pair RoPE) --- + head_dim = self.dim // self.n_heads + if self.attn_type == "gqa" and head_dim % 2 != 0: + raise ValueError( + f"head_dim=dim/n_heads ({head_dim}) must be even for GQA " + "so RoPE's complex reshape succeeds" + ) + + # --- loop layers --- + if self.max_loop_iters <= 0: + raise ValueError( + f"max_loop_iters must be positive, got {self.max_loop_iters}" + ) + if self.prelude_layers < 0 or self.coda_layers < 0: + raise ValueError("prelude_layers and coda_layers must be >= 0") + + # --- loop_dim (sinusoidal index embedding) must be even --- + loop_dim = self.dim // 8 + if loop_dim % 2 != 0: + raise ValueError( + f"cfg.dim // 8 ({loop_dim}) must be even for the loop-index " + f"embedding's sin/cos split (got dim={self.dim})" + ) + + # --- MoE --- + if self.n_experts <= 0 or self.expert_dim <= 0: + raise ValueError("n_experts and expert_dim must be positive") + if self.n_experts_per_tok <= 0: + raise ValueError("n_experts_per_tok must be positive") + if self.n_experts_per_tok > self.n_experts: + raise ValueError( + f"n_experts_per_tok ({self.n_experts_per_tok}) cannot exceed " + f"n_experts ({self.n_experts})" + ) + if self.n_shared_experts < 0: + raise ValueError("n_shared_experts must be >= 0") + + # --- ACT threshold must be in (0, 1] --- + if not 0.0 < self.act_threshold <= 1.0: + raise ValueError( + f"act_threshold must be in (0, 1], got {self.act_threshold}" + ) + + # --- dropout / init / bias speed sanity --- + if not 0.0 <= self.dropout < 1.0: + raise ValueError(f"dropout must be in [0, 1), got {self.dropout}") + if self.bias_update_speed < 0.0: + raise ValueError( + f"bias_update_speed must be >= 0, got {self.bias_update_speed}" + ) + if self.init_std <= 0.0: + raise ValueError(f"init_std must be positive, got {self.init_std}") + + # --- LoRA --- + if self.lora_rank <= 0: + raise ValueError(f"lora_rank must be positive, got {self.lora_rank}") + + # --- RoPE / sequence --- + if self.max_seq_len <= 0: + raise ValueError(f"max_seq_len must be positive, got {self.max_seq_len}") + if self.vocab_size <= 0: + raise ValueError(f"vocab_size must be positive, got {self.vocab_size}") # --------------------------------------------------------------------------- @@ -221,6 +340,12 @@ def forward( k = self.wk(x).view(B, T, self.n_kv_heads, self.head_dim) v = 
self.wv(x).view(B, T, self.n_kv_heads, self.head_dim) + # Defensively slice freqs_cis to the current T. OpenMythos.forward + # already pre-slices with a start_pos offset; callers that skip + # that step (unit tests, ad-hoc scripts) would otherwise crash in + # apply_rope's broadcast shape check. + if freqs_cis.shape[0] != T: + freqs_cis = freqs_cis[:T] q = apply_rope(q, freqs_cis) k = apply_rope(k, freqs_cis) @@ -242,7 +367,11 @@ def forward( attn = torch.matmul(q, k.transpose(-2, -1)) * scale if mask is not None: attn = attn + mask - attn = self.attn_drop(F.softmax(attn, dim=-1)) + # Upcast softmax to fp32: bf16 softmax loses precision on the tail, + # collapsing attention toward uniform or one-hot at long sequence + # lengths. Cost is tiny; numerical stability win is real. + attn = F.softmax(attn, dim=-1, dtype=torch.float32).to(q.dtype) + attn = self.attn_drop(attn) out = torch.matmul(attn, v) out = out.transpose(1, 2).contiguous().view(B, T, -1) return self.wo(out) @@ -340,6 +469,10 @@ def forward( """ B, T, _ = x.shape + # Defensive slice — see the matching comment in GQAttention.forward. + if freqs_cis.shape[0] != T: + freqs_cis = freqs_cis[:T] + # Q c_q = self.q_norm(self.q_down(x)) q_nope = self.q_up_nope(c_q).view(B, T, self.n_heads, self.qk_nope_dim) @@ -384,7 +517,10 @@ def forward( attn = torch.matmul(q, k.transpose(-2, -1)) * scale if mask is not None: attn = attn + mask - attn = self.attn_drop(F.softmax(attn, dim=-1)) + # Upcast softmax to fp32 for numerical stability under bf16/fp16 + # autocast (see GQAttention for rationale). + attn = F.softmax(attn, dim=-1, dtype=torch.float32).to(q.dtype) + attn = self.attn_drop(attn) out = torch.matmul(attn, v) # (B, H, T, v_dim) out = out.transpose(1, 2).contiguous().view(B, T, -1) return self.wo(out) @@ -453,8 +589,20 @@ def __init__(self, cfg: MythosConfig): self.topk = cfg.n_experts_per_tok self.router = nn.Linear(cfg.dim, cfg.n_experts, bias=False) - # load-balancing bias adjusted externally during training; not a gradient param + # Aux-loss-free bias (DeepSeek-V3): shifts expert SELECTION so that + # under-used experts get picked more, but never enters the gradient + # path (topk_scores come from the unbiased softmax). Updated in-place + # by OpenMythos.update_router_biases() after each optimizer step. self.register_buffer("router_bias", torch.zeros(cfg.n_experts)) + # Per-step load counter. Accumulates across microbatches during + # gradient accumulation; all-reduced across ranks inside update_bias() + # under FSDP/DDP, then zeroed. Non-persistent: rebuilt from scratch + # on every resume, not saved to checkpoints (router_bias is persisted). + self.register_buffer( + "expert_load", + torch.zeros(cfg.n_experts, dtype=torch.float32), + persistent=False, + ) self.routed_experts = nn.ModuleList( [Expert(cfg.dim, cfg.expert_dim) for _ in range(cfg.n_experts)] @@ -476,34 +624,108 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: """ B, T, D = x.shape flat = x.view(B * T, D) + N = B * T + # --- Router ----------------------------------------------------- # Aux-loss-free load balancing (DeepSeek-V3): the bias shifts only the - # selection of which experts fire so underused experts are picked more, - # but the gating weights come from unbiased softmax scores so the bias - # never shows up in the gradient. 
- logits = self.router(flat) # (B*T, n_experts), unbiased - scores = F.softmax(logits, dim=-1) - _, topk_idx = (logits + self.router_bias).topk(self.topk, dim=-1) - topk_scores = scores.gather(-1, topk_idx) - topk_scores = topk_scores / topk_scores.sum(dim=-1, keepdim=True) # renorm - - # routed expert dispatch (token-level scatter) + # SELECTION of which experts fire so underused experts are picked more, + # but the gating weights come from the unbiased softmax so the bias + # never enters the gradient. Softmax upcast to fp32 for numerical + # stability under bf16/fp16 autocast. + logits = self.router(flat) # (N, E) + scores = F.softmax(logits, dim=-1, dtype=torch.float32).to(flat.dtype) + _, topk_idx = (logits.float() + self.router_bias).topk(self.topk, dim=-1) + topk_scores = scores.gather(-1, topk_idx) # (N, K) + # clamp_min keeps the renorm denominator out of zero even when every + # selected expert's unbiased score underflows to ~0 in bf16/fp16. + denom = topk_scores.sum(dim=-1, keepdim=True).clamp_min(1e-9) + topk_scores = topk_scores / denom + + # Track expert load for the aux-loss-free bias update. Accumulates + # across microbatches; flushed by update_bias() post-optimizer-step. + if self.training: + with torch.no_grad(): + load = torch.bincount( + topk_idx.reshape(-1), minlength=self.n_experts + ).to(self.expert_load.dtype) + self.expert_load.add_(load) + + # --- Vectorized dispatch --------------------------------------- + # One iteration per expert (not per (topk, expert) pair): gather + # the token slice routed to expert e, run a single dense matmul, + # scatter-add back weighted by the gate. Replaces the original + # O(topk * n_experts) Python loop with O(n_experts) dense ops. + flat_idx = topk_idx.reshape(-1) # (N*K,) + flat_scores = topk_scores.reshape(-1, 1) # (N*K, 1) + token_ids = ( + torch.arange(N, device=flat.device) + .unsqueeze(-1) + .expand(N, self.topk) + .reshape(-1) + ) # (N*K,) + + # Stable-sort by expert so each expert's tokens are contiguous. + sort_order = flat_idx.argsort(stable=True) + sorted_experts = flat_idx[sort_order] + sorted_tokens = token_ids[sort_order] + sorted_scores = flat_scores[sort_order] + + # Count tokens per expert; cumulative sum gives slice boundaries. + counts = torch.bincount(sorted_experts, minlength=self.n_experts) + offsets = torch.cat( + [ + torch.zeros(1, dtype=counts.dtype, device=counts.device), + counts.cumsum(0), + ] + ) + out = torch.zeros_like(flat) - for i in range(self.topk): - expert_ids = topk_idx[:, i] - token_scores = topk_scores[:, i].unsqueeze(-1) - for eid in range(self.n_experts): - mask = expert_ids == eid - if not mask.any(): - continue - out[mask] += token_scores[mask] * self.routed_experts[eid](flat[mask]) - - # shared experts always fire for every token + for eid in range(self.n_experts): + start = int(offsets[eid]) + end = int(offsets[eid + 1]) + if start == end: + continue + tok_slice = sorted_tokens[start:end] + weight_slice = sorted_scores[start:end] + expert_out = self.routed_experts[eid](flat[tok_slice]) + out.index_add_(0, tok_slice, weight_slice * expert_out) + + # Shared experts always fire for every token. for shared in self.shared_experts: out = out + shared(flat) return out.view(B, T, D) + @torch.no_grad() + def update_bias(self, speed: float) -> None: + """ + Aux-loss-free load-balancing update (DeepSeek-V3, Eq. 16). + + After each optimizer step, nudge `router_bias` toward under-used experts + so that subsequent forward passes route more tokens to them. 
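A small worked instance of this update, using the same load vector as the unit test added in tests/test_moe_router.py (all numbers illustrative):

```python
import torch

expert_load = torch.tensor([0.0, 20.0, 20.0, 100.0])  # accumulated per-step loads
speed = 1e-2
mean = expert_load.sum() / 4                           # 35.0
bias_delta = torch.sign(mean - expert_load) * speed
# bias_delta == [+0.01, +0.01, +0.01, -0.01]: every expert below the mean
# is nudged up and the over-used one down, each by exactly `speed`.
```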
Uses the + sign of (mean_load - expert_load) so the update magnitude is bounded + per step — the speed constant alone controls how aggressive balancing + is, independent of the imbalance magnitude. + + Under FSDP/DDP, callers MUST all-reduce `expert_load` across ranks + before calling this method (see OpenMythos.update_router_biases). + After applying the update `expert_load` is zeroed. + + Args: + speed -- per-step update magnitude. DeepSeek-V3 uses ~1e-3. + """ + if speed == 0.0: + self.expert_load.zero_() + return + total = self.expert_load.sum() + if total <= 0: + # No training forward pass happened since last update; nothing to do. + return + mean = total / self.n_experts + direction = torch.sign(mean - self.expert_load).to(self.router_bias.dtype) + self.router_bias.add_(direction, alpha=speed) + self.expert_load.zero_() + # --------------------------------------------------------------------------- # Loop-index RoPE (differentiates recurrent block across iterations) @@ -673,15 +895,18 @@ class LTIInjection(nn.Module): even at high learning rates. """ - def __init__(self, dim: int): + def __init__(self, dim: int, b_init: float = 0.1): """ Args: - dim -- hidden state dimension; one scalar per channel for A and B + dim -- hidden state dimension; one scalar per channel for A and B + b_init -- initial magnitude of the input-gate B. Smaller keeps the + residual signal from the encoded input gentle at startup + so training can find a stable regime before B grows. """ super().__init__() self.log_A = nn.Parameter(torch.zeros(dim)) # log of A_continuous magnitude self.log_dt = nn.Parameter(torch.zeros(1)) # log of discretization step Δt - self.B = nn.Parameter(torch.ones(dim) * 0.1) + self.B = nn.Parameter(torch.ones(dim) * b_init) def get_A(self) -> torch.Tensor: """ @@ -693,8 +918,13 @@ def get_A(self) -> torch.Tensor: """ # Compute in log space to avoid 0 * inf = NaN when log_dt → -∞, log_A → +∞. # dt * A_c = -exp(log_dt) * exp(log_A) = -exp(log_dt + log_A) - # Clamp keeps the product finite in float32 for any gradient step size. - return torch.exp(-torch.exp((self.log_dt + self.log_A).clamp(-20, 20))) + # + # Lower clamp at -10 (not -20): when the sum saturates at -20, + # exp(-exp(-20)) ≈ 1 - 2e-9, which rounds to exactly 1.0 in float32 + # and breaks the strict ρ(A) < 1 invariant LTI depends on. At -10, + # exp(-exp(-10)) ≈ 1 - 4.5e-5 stays strictly below 1 with room to + # spare. Upper bound unchanged — it's only for finiteness. 
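A quick numeric check of the two boundary values cited in the comment above (float32):

```python
import torch

print(torch.exp(-torch.exp(torch.tensor(-20.0))).item())  # 1.0, rounds up to exactly one
print(torch.exp(-torch.exp(torch.tensor(-10.0))).item())  # ~0.9999546, strictly below one
```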
+ return torch.exp(-torch.exp((self.log_dt + self.log_A).clamp(-10, 20))) def forward( self, h: torch.Tensor, e: torch.Tensor, transformer_out: torch.Tensor @@ -786,7 +1016,7 @@ def __init__(self, cfg: MythosConfig): super().__init__() self.cfg = cfg self.block = TransformerBlock(cfg, use_moe=True) - self.injection = LTIInjection(cfg.dim) + self.injection = LTIInjection(cfg.dim, b_init=cfg.lti_b_init) self.act = ACTHalting(cfg.dim) self.lora = LoRAAdapter(cfg.dim, cfg.lora_rank, cfg.max_loop_iters) self.norm = RMSNorm(cfg.dim) @@ -827,7 +1057,9 @@ def forward( h_out = torch.zeros_like(h) for t in range(n_loops): - h_loop = loop_index_embedding(h, t, self.loop_dim) + h_loop = loop_index_embedding( + h, t, self.loop_dim, theta=self.cfg.loop_rope_theta + ) combined = self.norm(h_loop + e) cache_key = f"recurrent_loop_{t}" trans_out = self.block(combined, freqs_cis, mask, kv_cache, cache_key) @@ -930,12 +1162,25 @@ def __init__(self, cfg: MythosConfig): self._init_weights() def _init_weights(self) -> None: - """Initialize all linear and embedding weights with N(0, 0.02).""" + """ + Initialize all linear and embedding weights with N(0, cfg.init_std). + Router linear layers are initialized with a smaller std so expert + selection starts near-uniform and depends mostly on the aux-loss-free + router_bias during early training. + """ + std = self.cfg.init_std for m in self.modules(): if isinstance(m, nn.Linear): - nn.init.normal_(m.weight, std=0.02) + nn.init.normal_(m.weight, std=std) elif isinstance(m, nn.Embedding): - nn.init.normal_(m.weight, std=0.02) + nn.init.normal_(m.weight, std=std) + # Nudge router logits closer to uniform at init: with std=0.02 a + # 12288-dim router still has enough signal to spike single experts + # before training has shaped them. Small std keeps routing diffuse + # long enough for the bias update to find a balance. + for m in self.modules(): + if isinstance(m, MoEFFN): + nn.init.normal_(m.router.weight, std=std * 0.1) @staticmethod def _causal_mask(seq_len: int, device: torch.device) -> torch.Tensor: @@ -996,6 +1241,42 @@ def forward( return self.head(self.norm(x)) + @torch.no_grad() + def update_router_biases( + self, + speed: Optional[float] = None, + ddp: bool = False, + ) -> None: + """ + Advance the aux-loss-free router_bias for every MoE layer. + + Call this AFTER `optimizer.step()` each training step so the bias + update is not entangled with the weight gradients (the bias deliberately + sits outside the autograd graph). Under FSDP/DDP this all-reduces each + layer's `expert_load` across ranks before applying the update so every + rank converges to the same bias values without an extra checkpoint sync. + + Args: + speed -- per-step bias step size; falls back to cfg.bias_update_speed. + Pass 0.0 to freeze the bias (pure unbiased routing). + ddp -- True to enable the cross-rank all-reduce on expert_load. + """ + step = self.cfg.bias_update_speed if speed is None else speed + if step == 0.0: + # Even when frozen, flush counters so memory doesn't grow unbounded. 
+ for m in self.modules(): + if isinstance(m, MoEFFN): + m.expert_load.zero_() + return + for m in self.modules(): + if not isinstance(m, MoEFFN): + continue + if ddp and torch.distributed.is_initialized(): + torch.distributed.all_reduce( + m.expert_load, op=torch.distributed.ReduceOp.SUM + ) + m.update_bias(step) + @torch.no_grad() def generate( self, diff --git a/open_mythos/tokenizer.py b/open_mythos/tokenizer.py index fadb3a5..52188e2 100644 --- a/open_mythos/tokenizer.py +++ b/open_mythos/tokenizer.py @@ -1,64 +1,150 @@ +"""HuggingFace tokenizer wrapper with vocab sizing + EOS-aware encoding.""" + +from typing import Optional + from transformers import AutoTokenizer DEFAULT_MODEL_ID = "openai/gpt-oss-20b" +# Upper bound on a single document before we truncate. FineWeb-Edu has +# pathological outliers (single "documents" that are gigabytes of concatenated +# dumps); those stall DataLoader workers and occasionally OOM the tokenizer. +MAX_CHARS_PER_DOC = 4_000_000 # ~1M BPE tokens — well past any legitimate doc + class MythosTokenizer: """ HuggingFace tokenizer wrapper for OpenMythos. - Args: - model_id (str): The HuggingFace model ID or path to use with AutoTokenizer. - Defaults to "openai/gpt-oss-20b". + Key behavior: + - ``vocab_size`` returns ``len(self.tokenizer)`` (base vocab + added + specials), NOT the base model's nominal vocab size. This matches the ID + space reachable via ``encode``/``decode``, so ``nn.Embedding(vocab_size, ...)`` + sized from this property cannot index out of range. + - Optionally rounds up to ``vocab_multiple_of`` (default 128) so embedding + matrices align to tensor-core-friendly widths. + - ``encode`` defends against None / non-str / huge inputs so a single bad + sample does not crash the DataLoader worker. + - ``encode_with_eos`` appends the EOS token id so the document packer can + inject a boundary between concatenated docs, preventing the model from + learning spurious cross-document attention. - Attributes: - tokenizer: An instance of HuggingFace's AutoTokenizer. + Args: + model_id -- HF model id or local tokenizer path. + vocab_multiple_of -- round vocab_size up to this multiple (0 = off). Example: >>> tok = MythosTokenizer() >>> ids = tok.encode("Hello world") - >>> s = tok.decode(ids) + >>> text = tok.decode(ids) """ - def __init__(self, model_id: str = DEFAULT_MODEL_ID): + def __init__( + self, + model_id: str = DEFAULT_MODEL_ID, + vocab_multiple_of: int = 128, + ): """ Initialize the MythosTokenizer. Args: - model_id (str): HuggingFace model identifier or path to tokenizer files. + model_id -- HF model id or path to a tokenizer directory. + vocab_multiple_of -- if >0, round ``vocab_size`` up to this multiple + for tensor-core-friendly embedding widths. + 128 matches the Llama / DeepSeek convention. """ - self.tokenizer = AutoTokenizer.from_pretrained(model_id) + # Explicit trust_remote_code=False pins safe behavior across future + # transformers upgrades even if defaults shift. + self.tokenizer = AutoTokenizer.from_pretrained( + model_id, + trust_remote_code=False, + ) + self.vocab_multiple_of = vocab_multiple_of @property def vocab_size(self) -> int: """ - Return the size of the tokenizer vocabulary. + Number of token IDs the model embedding must cover. - Returns: - int: The number of unique tokens in the tokenizer vocabulary. + Returns ``len(self.tokenizer)`` (base vocab + added special tokens), + rounded up to ``vocab_multiple_of`` when enabled. 
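For instance, with the default ``vocab_multiple_of=128`` (the raw length below is hypothetical, not the gpt-oss-20b value):

```python
raw_len = 200_019                      # hypothetical len(self.tokenizer)
multiple = 128
padded = ((raw_len + multiple - 1) // multiple) * multiple
print(padded)                          # 200_064: the value this property reports
```

An ``nn.Embedding(padded, d_model)`` sized from that value therefore covers every id ``encode`` can emit, including added special tokens.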
``tokenizer.vocab_size`` + on HF returns the base vocab and silently excludes added specials, + which is a common source of CUDA ``device-side assert`` on long runs. """ - return self.tokenizer.vocab_size + true_size = len(self.tokenizer) + m = self.vocab_multiple_of + if m and m > 1: + return ((true_size + m - 1) // m) * m + return true_size - def encode(self, text: str) -> list[int]: + @property + def eos_token_id(self) -> Optional[int]: + """ + Token id used as a document boundary. Prefers ``eos_token_id``, falls + back to ``bos_token_id``, then to the first defined special id, then + to None. The trainer's doc packer injects this between concatenated + samples so the model never sees a cross-document boundary without a + marker at train time. """ - Encode input text into a list of token IDs. + tid = self.tokenizer.eos_token_id + if tid is not None: + return int(tid) + tid = self.tokenizer.bos_token_id + if tid is not None: + return int(tid) + added = self.tokenizer.all_special_ids + if added: + return int(added[0]) + return None + + def encode( + self, + text: str, + *, + add_special_tokens: bool = False, + max_chars: int = MAX_CHARS_PER_DOC, + ) -> list[int]: + """ + Encode text to a list of token ids. + + Rejects None / non-str silently (returns ``[]``) so a single malformed + sample does not kill the DataLoader worker. Oversized documents are + truncated at ``max_chars`` characters before tokenization. Args: - text (str): The input text string to tokenize. + text -- source text; non-strings are treated as empty. + add_special_tokens -- defaults to False for pretraining packing + (the packer injects EOS explicitly). + max_chars -- hard character cap; pass a very large value + if you genuinely have legitimate huge docs. Returns: - list[int]: List of integer token IDs representing the input text. + List of integer token ids. """ - return self.tokenizer.encode(text, add_special_tokens=False) + if not isinstance(text, str) or not text: + return [] + if max_chars and len(text) > max_chars: + text = text[:max_chars] + return self.tokenizer.encode(text, add_special_tokens=add_special_tokens) - def decode(self, token_ids: list[int]) -> str: + def encode_with_eos(self, text: str) -> list[int]: """ - Decode a list of token IDs back into a text string. + Encode ``text`` and append the EOS token id (when defined). - Args: - token_ids (list[int]): A list of integer token IDs to decode. + Used by the document-packer path so concatenated documents are + delimited by a boundary token rather than flowing into each other. + Falls back to plain ``encode`` when the tokenizer has no EOS. + """ + ids = self.encode(text) + if not ids: + return ids + eos = self.eos_token_id + if eos is not None: + ids.append(eos) + return ids - Returns: - str: Decoded string representation of the token IDs. + def decode(self, token_ids: list[int]) -> str: + """ + Decode ``token_ids`` back to a string, stripping special tokens. """ return self.tokenizer.decode(token_ids, skip_special_tokens=True) diff --git a/pyproject.toml b/pyproject.toml index 1d9f720..0d820e6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,9 +38,20 @@ classifiers = [ [tool.poetry.dependencies] python = ">=3.10,<4.0" -torch = "2.11.0" -transformers = ">=4.40.0" -datasets = ">=2.18.0" +# Library constraints. torch 2.3 is the floor for the FSDP APIs we use +# (ShardedGradScaler in torch.distributed.fsdp.sharded_grad_scaler, +# torch.amp.autocast with device_type=, full_state_dict rank0_only=). 
+# Pretraining runs should pin exact versions via training/requirements.txt. +torch = ">=2.3.0,<3.0.0" +transformers = ">=4.40.0,<5.0.0" +datasets = ">=2.18.0,<4.0.0" + + +[tool.poetry.group.training.dependencies] +# Only required for the pretraining scripts in training/, not for using +# the library for inference / architecture research. +numpy = ">=1.26,<3.0" +loguru = ">=0.7.3,<1.0.0" [tool.poetry.group.lint.dependencies] diff --git a/requirements.txt b/requirements.txt index 3b01619..109f65b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,9 @@ -torch>=2.1.0 -transformers>=4.40.0 -datasets>=2.18.0 -pytest>=7.0.0 +# Runtime requirements for using `open_mythos` as a library. +# These are broad ranges; a training run should use training/requirements.txt, +# which pins exact versions for reproducibility across nodes. +torch>=2.3.0,<3.0.0 +transformers>=4.40.0,<5.0.0 +datasets>=2.18.0,<4.0.0 + +# dev / test tooling +pytest>=8.1.1,<10.0.0 diff --git a/tests/test_config_validation.py b/tests/test_config_validation.py new file mode 100644 index 0000000..d64d4b4 --- /dev/null +++ b/tests/test_config_validation.py @@ -0,0 +1,100 @@ +""" +Validation of `MythosConfig.__post_init__`. + +The `__post_init__` guards exist so a bad hyperparameter combination fails +at *config construction* time instead of halfway through a multi-day +pretraining run. Every test here pokes one axis at a time so regressions +in the validator pinpoint exactly which rule broke. +""" + +import pytest + +from open_mythos.main import MythosConfig + + +def _base(**overrides) -> dict: + """ + Minimal kwargs that produce a valid config. Overrides mutate one axis; + other tests rely on this being a known-good baseline. + """ + cfg = dict( + vocab_size=128, + dim=64, + n_heads=4, + n_kv_heads=2, + max_seq_len=32, + max_loop_iters=2, + prelude_layers=1, + coda_layers=1, + attn_type="gqa", + n_experts=4, + n_shared_experts=1, + n_experts_per_tok=2, + expert_dim=32, + ) + cfg.update(overrides) + return cfg + + +def test_baseline_is_valid() -> None: + MythosConfig(**_base()) + + +def test_attn_type_invalid_raises() -> None: + with pytest.raises(ValueError, match="attn_type"): + MythosConfig(**_base(attn_type="linear")) + + +def test_dim_not_divisible_by_n_heads_raises() -> None: + with pytest.raises(ValueError, match="dim"): + MythosConfig(**_base(dim=65, n_heads=4)) + + +def test_gqa_divisibility_raises() -> None: + # 5 Q heads / 2 KV groups → 5 not divisible by 2 + with pytest.raises(ValueError, match="n_heads"): + MythosConfig(**_base(n_heads=5, n_kv_heads=2, dim=80)) + + +def test_mla_odd_rope_head_dim_raises() -> None: + with pytest.raises(ValueError, match=r"qk_rope_head_dim"): + MythosConfig( + **_base( + attn_type="mla", + kv_lora_rank=16, + q_lora_rank=16, + qk_rope_head_dim=15, # odd + qk_nope_head_dim=16, + v_head_dim=16, + ) + ) + + +def test_moe_experts_per_tok_exceeds_n_experts_raises() -> None: + with pytest.raises(ValueError, match="experts_per_tok"): + MythosConfig(**_base(n_experts=4, n_experts_per_tok=5)) + + +def test_dropout_out_of_range_raises() -> None: + with pytest.raises(ValueError, match="dropout"): + MythosConfig(**_base(dropout=1.5)) + + +def test_negative_bias_update_speed_raises() -> None: + with pytest.raises(ValueError, match="bias_update_speed"): + MythosConfig(**_base(bias_update_speed=-1e-3)) + + +def test_non_positive_init_std_raises() -> None: + with pytest.raises(ValueError, match="init_std"): + MythosConfig(**_base(init_std=0.0)) + + +def test_non_positive_vocab_raises() -> None: 
+ with pytest.raises(ValueError, match="vocab_size"): + MythosConfig(**_base(vocab_size=0)) + + +def test_non_positive_max_seq_len_raises() -> None: + with pytest.raises(ValueError, match="max_seq_len"): + MythosConfig(**_base(max_seq_len=0)) diff --git a/tests/test_moe_router.py b/tests/test_moe_router.py new file mode 100644 index 0000000..d5fdfd3 --- /dev/null +++ b/tests/test_moe_router.py @@ -0,0 +1,226 @@ +""" +Tests for the MoE router: aux-loss-free bias update + vectorized dispatch. + +Two invariants we guard here: + +1. **Router bias actually moves.** ``MoEFFN.router_bias`` is a non-persistent + buffer updated by ``update_bias`` after each optimizer step. The + upstream version never wired this up — router_bias stayed at zeros + forever and balancing was silently inert. These tests would catch + that regression. + +2. **Vectorized dispatch matches the reference loop.** The per-expert + argsort + index_add_ path replaces an O(topk * n_experts) Python + loop. The two paths must produce byte-identical (modulo tiny fp + error) outputs for the same inputs, otherwise the optimization is a + silent semantics change. +""" + +import torch + +from open_mythos.main import MoEFFN, MythosConfig, OpenMythos + + +def _moe_cfg(**overrides) -> MythosConfig: + cfg = dict( + vocab_size=128, + dim=64, + n_heads=4, + n_kv_heads=2, + max_seq_len=32, + max_loop_iters=1, + prelude_layers=1, + coda_layers=1, + attn_type="gqa", + n_experts=8, + n_shared_experts=1, + n_experts_per_tok=2, + expert_dim=32, + bias_update_speed=1e-2, # larger than default so 1 step is visible + ) + cfg.update(overrides) + return MythosConfig(**cfg) + + +# --------------------------------------------------------------------------- +# Router-bias update +# --------------------------------------------------------------------------- + + +def test_router_bias_is_a_buffer() -> None: + """Must be a buffer (non-persistent) — not a trainable Parameter.""" + cfg = _moe_cfg() + moe = MoEFFN(cfg) + assert "router_bias" in dict(moe.named_buffers()) + # expert_load is also a non-persistent buffer, not a param. + assert "expert_load" in dict(moe.named_buffers()) + # And neither should be in the Parameter list (they must not be trained). + params = {name for name, _ in moe.named_parameters()} + assert "router_bias" not in params + assert "expert_load" not in params + + +def test_expert_load_accumulates_on_forward() -> None: + cfg = _moe_cfg() + moe = MoEFFN(cfg).train() + x = torch.randn(2, 4, cfg.dim) + + assert float(moe.expert_load.sum()) == 0.0 + _ = moe(x) + # N*K tokens were routed (batch*seq*topk). + expected_total = 2 * 4 * cfg.n_experts_per_tok + assert int(moe.expert_load.sum().item()) == expected_total + + +def test_expert_load_not_accumulated_in_eval() -> None: + cfg = _moe_cfg() + moe = MoEFFN(cfg).eval() + x = torch.randn(2, 4, cfg.dim) + _ = moe(x) + assert float(moe.expert_load.sum()) == 0.0 + + +def test_update_bias_shifts_toward_underused_experts() -> None: + """ + Construct a synthetic load vector where one expert is underused and + one is overused, then verify update_bias pushes the bias for the + underused expert up and the overused one down. + """ + cfg = _moe_cfg(n_experts=4) + moe = MoEFFN(cfg) + + # Load: expert 0 underused (0 tokens), expert 3 overused (100 tokens), + # rest neutral. + moe.expert_load.copy_(torch.tensor([0.0, 20.0, 20.0, 100.0])) + speed = 1e-2 + before = moe.router_bias.clone() + moe.update_bias(speed) + + # Mean = (0+20+20+100)/4 = 35. Direction = sign(mean - load). 
+ # expert 0: sign(35 - 0) = +1 → bias[0] increases by speed + # expert 3: sign(35 - 100) = -1 → bias[3] decreases by speed + assert (moe.router_bias[0] - before[0]).item() > 0 + assert (moe.router_bias[3] - before[3]).item() < 0 + + +def test_update_bias_zeroes_load_after_step() -> None: + cfg = _moe_cfg() + moe = MoEFFN(cfg) + moe.expert_load.fill_(5.0) + moe.update_bias(1e-2) + assert float(moe.expert_load.sum()) == 0.0 + + +def test_update_bias_noop_when_speed_zero() -> None: + cfg = _moe_cfg() + moe = MoEFFN(cfg) + moe.expert_load.fill_(5.0) + before = moe.router_bias.clone() + moe.update_bias(0.0) + # Bias unchanged ... + assert torch.equal(moe.router_bias, before) + # ... but load still flushed so memory doesn't grow unbounded. + assert float(moe.expert_load.sum()) == 0.0 + + +def test_update_bias_magnitude_bounded_by_speed() -> None: + """ + Sign-based update means per-step delta is exactly `speed` per expert + regardless of how large the imbalance is — the whole point of the + design, as opposed to a direct delta that explodes on spiky loads. + """ + cfg = _moe_cfg(n_experts=4) + moe = MoEFFN(cfg) + + moe.expert_load.copy_(torch.tensor([0.0, 0.0, 0.0, 1_000_000.0])) + speed = 1e-2 + before = moe.router_bias.clone() + moe.update_bias(speed) + + diff = (moe.router_bias - before).abs().max().item() + assert abs(diff - speed) < 1e-6 + + +# --------------------------------------------------------------------------- +# End-to-end: OpenMythos.update_router_biases drives every MoE layer +# --------------------------------------------------------------------------- + + +def test_openmythos_update_router_biases_walks_all_moe_layers() -> None: + cfg = _moe_cfg() + model = OpenMythos(cfg) + + # Fabricate imbalance into every MoE layer. + moes = [m for m in model.modules() if isinstance(m, MoEFFN)] + assert len(moes) > 0 + + # Make expert 0 look underused across the whole model. + for m in moes: + m.expert_load.copy_(torch.tensor([0.0] + [10.0] * (cfg.n_experts - 1))) + + before = [m.router_bias.clone() for m in moes] + model.update_router_biases(ddp=False) + + for b_before, m in zip(before, moes): + # Expert 0's bias must have moved up. + assert (m.router_bias[0] - b_before[0]).item() > 0 + # Load flushed. + assert float(m.expert_load.sum()) == 0.0 + + +# --------------------------------------------------------------------------- +# Vectorized dispatch semantic equivalence to the naive loop. +# --------------------------------------------------------------------------- + + +def _naive_moe_forward(moe: MoEFFN, x: torch.Tensor) -> torch.Tensor: + """ + Reference implementation of the routed-expert dispatch using the + obvious (and slow) double loop: for each (token, k) pair, find the + selected expert and add its weighted contribution to the output. + + Deliberately verbose so it's easy to read as a spec. Only used here + to cross-check the fast implementation; never called by the model. 
+ """ + import torch.nn.functional as F + + B, T, D = x.shape + flat = x.view(B * T, D) + + logits = moe.router(flat) + scores = F.softmax(logits, dim=-1, dtype=torch.float32).to(flat.dtype) + _, topk_idx = (logits.float() + moe.router_bias).topk(moe.topk, dim=-1) + topk_scores = scores.gather(-1, topk_idx) + denom = topk_scores.sum(dim=-1, keepdim=True).clamp_min(1e-9) + topk_scores = topk_scores / denom + + out = torch.zeros_like(flat) + N = flat.shape[0] + for n in range(N): + for k in range(moe.topk): + eid = int(topk_idx[n, k]) + w = topk_scores[n, k] + out[n] = out[n] + w * moe.routed_experts[eid](flat[n : n + 1]).squeeze(0) + + for shared in moe.shared_experts: + out = out + shared(flat) + + return out.view(B, T, D) + + +def test_vectorized_matches_naive_dispatch() -> None: + """ + The fast path uses argsort + index_add_; the reference path iterates + per token per k. Same inputs, same (eval-mode) state → same outputs + within fp tolerance. + """ + cfg = _moe_cfg(n_experts=6, n_experts_per_tok=3, dim=32, expert_dim=16) + moe = MoEFFN(cfg).eval() # eval so expert_load isn't mutated + + torch.manual_seed(0) + x = torch.randn(2, 5, cfg.dim) + + fast = moe(x) + slow = _naive_moe_forward(moe, x) + + torch.testing.assert_close(fast, slow, rtol=1e-5, atol=1e-5) diff --git a/tests/test_tokenizer.py b/tests/test_tokenizer.py index fab7533..d087825 100644 --- a/tests/test_tokenizer.py +++ b/tests/test_tokenizer.py @@ -1,74 +1,115 @@ +"""Tokenizer tests — cover the documented contract of MythosTokenizer.""" + import pytest + from open_mythos.tokenizer import MythosTokenizer @pytest.fixture(scope="module") -def tokenizer(): - tok = MythosTokenizer() - print(f"\nLoaded tokenizer: {tok.tokenizer.name_or_path}") - return tok +def tokenizer() -> MythosTokenizer: + return MythosTokenizer() -def test_loads(tokenizer): +def test_loads(tokenizer: MythosTokenizer) -> None: assert tokenizer is not None - print(f"Tokenizer: {tokenizer}") -def test_vocab_size(tokenizer): - size = tokenizer.vocab_size - print(f"Vocab size: {size:,}") - assert size > 0 +def test_vocab_size_positive(tokenizer: MythosTokenizer) -> None: + assert tokenizer.vocab_size > 0 + + +def test_vocab_size_covers_len(tokenizer: MythosTokenizer) -> None: + """ + ``vocab_size`` must be >= ``len(tokenizer.tokenizer)``. + + This is the core invariant: an ``nn.Embedding(vocab_size, dim)`` sized + from this property cannot index out of range for any token the + tokenizer can emit, including added special tokens that HF's + ``tokenizer.vocab_size`` silently excludes. 
+ """ + assert tokenizer.vocab_size >= len(tokenizer.tokenizer) + + +def test_vocab_size_rounded_to_multiple(tokenizer: MythosTokenizer) -> None: + """Default ``vocab_multiple_of=128`` → vocab_size is a multiple of 128.""" + assert tokenizer.vocab_size % 128 == 0 + + +def test_vocab_multiple_of_zero_disables_rounding() -> None: + """Passing ``vocab_multiple_of=0`` returns the raw ``len(tokenizer)``.""" + tok = MythosTokenizer(vocab_multiple_of=0) + assert tok.vocab_size == len(tok.tokenizer) -def test_encode_returns_list_of_ints(tokenizer): +def test_encode_returns_list_of_ints(tokenizer: MythosTokenizer) -> None: ids = tokenizer.encode("Hello, world!") - print(f"encode('Hello, world!') → {ids}") assert isinstance(ids, list) assert all(isinstance(i, int) for i in ids) assert len(ids) > 0 -def test_encode_empty_string(tokenizer): - ids = tokenizer.encode("") - print(f"encode('') → {ids}") - assert isinstance(ids, list) +def test_encode_ids_within_vocab(tokenizer: MythosTokenizer) -> None: + """Every emitted id must be < vocab_size (otherwise embeddings crash).""" + ids = tokenizer.encode("Any reasonable text with punctuation, 123.") + vocab = tokenizer.vocab_size + assert all(0 <= i < vocab for i in ids) + + +def test_encode_empty_string(tokenizer: MythosTokenizer) -> None: + assert tokenizer.encode("") == [] + + +def test_encode_none_returns_empty(tokenizer: MythosTokenizer) -> None: + # Defense against a corrupt dataset sample killing the loader worker. + assert tokenizer.encode(None) == [] # type: ignore[arg-type] + + +def test_encode_non_str_returns_empty(tokenizer: MythosTokenizer) -> None: + assert tokenizer.encode(12345) == [] # type: ignore[arg-type] + assert tokenizer.encode(["list"]) == [] # type: ignore[arg-type] + + +def test_encode_truncates_oversized_input(tokenizer: MythosTokenizer) -> None: + """Hard character cap keeps the tokenizer from eating giant docs.""" + huge = "a" * 10_000_000 + ids = tokenizer.encode(huge, max_chars=1000) + # 1000 chars of 'a' tokenize to something bounded — certainly much less + # than if we had tokenized the full 10M-char input. + assert len(ids) < 5000 -def test_decode_returns_string(tokenizer): +def test_decode_returns_string(tokenizer: MythosTokenizer) -> None: ids = tokenizer.encode("Hello, world!") text = tokenizer.decode(ids) - print(f"decode({ids}) → '{text}'") assert isinstance(text, str) -def test_roundtrip(tokenizer): +def test_roundtrip(tokenizer: MythosTokenizer) -> None: original = "The quick brown fox jumps over the lazy dog." ids = tokenizer.encode(original) recovered = tokenizer.decode(ids) - print(f"original: '{original}'") - print(f"token ids: {ids}") - print(f"recovered: '{recovered}'") assert original in recovered or recovered in original -def test_encode_long_text(tokenizer): - text = "OpenMythos is a recurrent depth transformer. 
" * 100 - ids = tokenizer.encode(text) - print(f"Long text ({len(text)} chars) → {len(ids)} tokens") - assert len(ids) > 100 +def test_eos_token_id_is_int_or_none(tokenizer: MythosTokenizer) -> None: + eos = tokenizer.eos_token_id + assert eos is None or isinstance(eos, int) -def test_custom_model_id(): - tok = MythosTokenizer(model_id="openai/gpt-oss-20b") - print(f"Custom model_id vocab size: {tok.vocab_size:,}") - assert tok.vocab_size > 0 +def test_encode_with_eos_appends_eos_token(tokenizer: MythosTokenizer) -> None: + """``encode_with_eos`` appends EOS when one is defined, else plain encode.""" + ids = tokenizer.encode_with_eos("Hello, world!") + base = tokenizer.encode("Hello, world!") + eos = tokenizer.eos_token_id + if eos is not None: + assert ids == base + [eos] + else: + assert ids == base -def test_vocab_size_consistent(tokenizer): - outer = tokenizer.vocab_size - inner = tokenizer.tokenizer.vocab_size - print(f"vocab_size property: {outer:,} | inner tokenizer.vocab_size: {inner:,}") - assert outer == inner +def test_encode_with_eos_on_empty_is_empty(tokenizer: MythosTokenizer) -> None: + """No EOS appended on empty input — packer would otherwise emit a lone EOS.""" + assert tokenizer.encode_with_eos("") == [] if __name__ == "__main__": diff --git a/training/3b_fine_web_edu.py b/training/3b_fine_web_edu.py index e980302..467ef0d 100644 --- a/training/3b_fine_web_edu.py +++ b/training/3b_fine_web_edu.py @@ -7,32 +7,118 @@ Multi-GPU: torchrun --nproc_per_node=$(python -c "import torch; print(torch.cuda.device_count())") training/3b_fine_web_edu.py + +Hardening vs. the upstream reference trainer: + - All RNGs seeded (python / numpy / torch / cuda) and the seed is + checkpointed so a resume reproduces the same data stream given the + same shard position. + - Checkpoint carries RNG state, AMP scaler state, and torch + cuda + versions — enough to survive a node swap mid-run. + - Graceful SIGTERM / SIGINT handling: drains the current microbatch + loop, saves a final atomic checkpoint, then tears down NCCL. No + half-written `.tmp` files, no hang. + - NaN / Inf loss guard per microstep. A bad batch zeros its grad + contribution; if the whole accumulation window is bad we skip + optimizer.step() but still advance the step counter so LR schedule + and logging stay monotonic. + - ShardedGradScaler on the fp16 path (Volta / Pascal); bf16 path + (Ampere+) runs with FSDP MixedPrecision and no scaler, which is + the officially supported combination. + - File logging with loguru rotation (100 MB + 7-day retention) so + long runs don't fill the disk with a single multi-GB log. + - EOS injected between packed documents via + `encoding.encode_with_eos()` so the model never sees a + cross-document attention window without a boundary marker. + - `model.update_router_biases(ddp=ddp)` called after every + successful `optimizer.step()` — this is what actually drives the + aux-loss-free load balancing; without it, `router_bias` stays at + zeros forever. + - Micro-batch loss accumulated on-device; single `.item()` per step. 
""" -import os import math +import os +import random +import signal +import sys import time +from contextlib import nullcontext +from dataclasses import asdict +from typing import Optional + +import numpy as np import torch -import torch.nn as nn import torch.distributed as dist +import torch.nn as nn +from datasets import load_dataset from loguru import logger from torch.distributed.fsdp import ( + FullStateDictConfig, FullyShardedDataParallel as FSDP, - ShardingStrategy, MixedPrecision, - FullStateDictConfig, + ShardingStrategy, StateDictType, ) +from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler from torch.distributed.fsdp.wrap import ModuleWrapPolicy -from torch.utils.data import IterableDataset, DataLoader, get_worker_info -from contextlib import nullcontext - -from datasets import load_dataset +from torch.utils.data import DataLoader, IterableDataset, get_worker_info from open_mythos import OpenMythos -from open_mythos.main import TransformerBlock, RecurrentBlock -from open_mythos.variants import mythos_3b +from open_mythos.main import RecurrentBlock, TransformerBlock from open_mythos.tokenizer import MythosTokenizer +from open_mythos.variants import mythos_3b + + +# --------------------------------------------------------------------------- +# Determinism / RNG +# --------------------------------------------------------------------------- + + +DEFAULT_SEED = 1337 + + +def seed_everything(seed: int) -> None: + """ + Seed every RNG we touch so two runs with the same seed draw the same + microbatches (given an unchanged dataset shard order). + + This is not a promise of bit-exact model outputs — cuBLAS / cuDNN + kernels pick different algorithms depending on workspace allocator + state, and FSDP all-reduce order is non-deterministic under NCCL. + What it does buy is a reproducible *data path* and a reproducible + initialization — the two things that actually cause "why did my + loss curve change" regressions. + """ + os.environ["PYTHONHASHSEED"] = str(seed) + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + +def snapshot_rng() -> dict: + """Collect every RNG's state so a checkpoint can resume reproducibly.""" + return { + "python": random.getstate(), + "numpy": np.random.get_state(), + "torch": torch.get_rng_state(), + "cuda": torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None, + } + + +def restore_rng(state: Optional[dict]) -> None: + """Inverse of `snapshot_rng`; tolerates missing keys from older ckpts.""" + if not state: + return + if "python" in state: + random.setstate(state["python"]) + if "numpy" in state: + np.random.set_state(state["numpy"]) + if "torch" in state: + torch.set_rng_state(state["torch"]) + if state.get("cuda") is not None and torch.cuda.is_available(): + torch.cuda.set_rng_state_all(state["cuda"]) # --------------------------------------------------------------------------- @@ -50,16 +136,28 @@ class FineWebEduDataset(IterableDataset): `(rank, worker_id)` deterministically owns one shard of the global stream. That gives disjoint coverage without any cross-process coordination. + Documents are encoded with `encode_with_eos` so concatenated docs have + an EOS token between them — the model sees explicit boundaries instead + of silently attending across unrelated documents (which inflates loss + and teaches spurious long-range dependencies). + Streaming datasets are not seekable, so a resumed run re-enters its shard from the beginning. 
Acceptable at pretraining scale: the chance of re-playing the same tokens before the run ends is negligible versus the cost of a true resumable loader. """ - def __init__(self, encoding, seq_len: int, subset: str, rank: int, world_size: int): + def __init__( + self, + encoding: MythosTokenizer, + seq_len: int, + subset: str, + rank: int, + world_size: int, + ): """ Args: - encoding -- tokenizer exposing `.encode(str) -> list[int]` + encoding -- tokenizer exposing ``encode_with_eos(str) -> list[int]`` seq_len -- context length; every yielded pair has this many tokens subset -- FineWeb-Edu config name (e.g. "sample-10BT", "default") rank -- global rank of this process within the distributed job @@ -73,10 +171,10 @@ def __init__(self, encoding, seq_len: int, subset: str, rank: int, world_size: i def __iter__(self): """ - Yield `(input_ids, target_ids)` tensors of length `seq_len` forever. + Yield ``(input_ids, target_ids)`` tensors of length ``seq_len`` forever. Inputs and targets are shifted by one for next-token prediction — - `target[i] == input[i + 1]`. Documents are concatenated into a rolling + ``target[i] == input[i + 1]``. Documents are concatenated into a rolling buffer and sliced into fixed-length chunks, packing short docs together and splitting long ones. This keeps every step at the same shape, which under FSDP avoids recompute from variable-length inputs and @@ -96,9 +194,13 @@ def __iter__(self): streaming=True, ).shard(num_shards=total_shards, index=shard_index) - buf = [] + buf: list[int] = [] for sample in ds: - buf.extend(self.encoding.encode(sample["text"])) + text = sample.get("text") + ids = self.encoding.encode_with_eos(text) if text else [] + if not ids: + continue + buf.extend(ids) while len(buf) >= self.seq_len + 1: chunk = buf[: self.seq_len + 1] buf = buf[self.seq_len + 1 :] @@ -115,13 +217,13 @@ def __iter__(self): def get_lr(step: int, warmup: int, total: int, max_lr: float, min_lr: float) -> float: """ - Linear warmup → half-cosine decay to `min_lr`. + Linear warmup → half-cosine decay to ``min_lr``. Standard language-model pretraining schedule. The warmup phase prevents Adam's second-moment estimate from collapsing to a huge LR in the first few steps when gradients are noisy. The cosine tail lets the model make small, increasingly conservative updates near the end of training rather - than crashing to `min_lr` at a fixed step. + than crashing to ``min_lr`` at a fixed step. Behavior by region: step < warmup → linear ramp 0 → max_lr @@ -129,22 +231,12 @@ def get_lr(step: int, warmup: int, total: int, max_lr: float, min_lr: float) -> step ≥ total → clamped at min_lr (safety for off-by-one step counters at the end of training) - - Args: - step -- current global optimizer step (0-indexed) - warmup -- number of warmup steps before cosine decay begins - total -- step at which the cosine reaches `min_lr` - max_lr -- peak learning rate reached at the end of warmup - min_lr -- floor learning rate at and after `total` steps - - Returns: - Scalar learning rate for this step. 
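Spot values with illustrative settings (warmup=2000, total=20000, max_lr=3e-4, min_lr=3e-5; the run's actual constants live in main()):

```python
get_lr(step=0,     warmup=2000, total=20000, max_lr=3e-4, min_lr=3e-5)  # 0.0
get_lr(step=1000,  warmup=2000, total=20000, max_lr=3e-4, min_lr=3e-5)  # 1.5e-4, halfway up the ramp
get_lr(step=11000, warmup=2000, total=20000, max_lr=3e-4, min_lr=3e-5)  # 1.65e-4, cosine midpoint
get_lr(step=30000, warmup=2000, total=20000, max_lr=3e-4, min_lr=3e-5)  # 3e-5, clamped floor
```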
""" if step < warmup: - return max_lr * step / warmup + return max_lr * step / max(1, warmup) if step >= total: return min_lr - decay = (step - warmup) / (total - warmup) + decay = (step - warmup) / max(1, total - warmup) return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * decay)) @@ -153,37 +245,37 @@ def get_lr(step: int, warmup: int, total: int, max_lr: float, min_lr: float) -> # --------------------------------------------------------------------------- +CKPT_SUFFIX = ".pt" +CKPT_PREFIX = "step_" + + def _list_ckpts(ckpt_dir: str) -> list[str]: """ - Return checkpoint paths in `ckpt_dir` sorted oldest → newest. + Return checkpoint paths in ``ckpt_dir`` sorted oldest → newest. - Relies on the zero-padded `step_{0000000}.pt` filename convention so + Relies on the zero-padded ``step_{0000000}.pt`` filename convention so lexicographic sort matches chronological order. Changing the filename format elsewhere without updating the pad width would silently break - both `keep_last` pruning and resume-latest on startup, since both pick + both ``keep_last`` pruning and resume-latest on startup, since both pick the last element of this list. - - Args: - ckpt_dir -- directory to scan; missing directory returns [] - - Returns: - Sorted list of absolute paths to matching checkpoint files. """ if not os.path.isdir(ckpt_dir): return [] return sorted( os.path.join(ckpt_dir, f) for f in os.listdir(ckpt_dir) - if f.startswith("step_") and f.endswith(".pt") + if f.startswith(CKPT_PREFIX) and f.endswith(CKPT_SUFFIX) ) def save_checkpoint( - model, - optimizer, + model: nn.Module, + optimizer: torch.optim.Optimizer, + scaler: Optional[ShardedGradScaler], step: int, cfg, vocab_size: int, + seed: int, ckpt_dir: str, ddp: bool, master: bool, @@ -200,22 +292,9 @@ def save_checkpoint( truncated .pt file. Non-master ranks participate in the FSDP gather (otherwise the collective would hang) but exit before touching disk. - Args: - model -- FSDP-wrapped (ddp=True) or raw (ddp=False) model - optimizer -- the optimizer whose state should round-trip with the model - step -- global step number; encoded zero-padded into the filename - cfg -- model config object; saved so downstream eval can - reconstruct the model without re-importing the variant - vocab_size -- tokenizer vocab size at train time; saved for sanity-check - on load against a (possibly updated) tokenizer - ckpt_dir -- directory to write into; created if missing - ddp -- True if FSDP path; False for single-GPU / CPU - master -- whether this rank writes to disk (rank 0 only) - keep_last -- number of most-recent checkpoints to retain; older ones - are unlinked after a successful write - - Returns: - None. Writes to disk as a side effect on master rank. + The checkpoint also carries the RNG state snapshot, the AMP scaler + state (when applicable), and the training seed — enough for a fresh + process on a different node to resume with deterministic data order. 
""" if ddp: with FSDP.state_dict_type( @@ -233,20 +312,35 @@ def save_checkpoint( return os.makedirs(ckpt_dir, exist_ok=True) - final_path = os.path.join(ckpt_dir, f"step_{step:07d}.pt") + final_path = os.path.join(ckpt_dir, f"{CKPT_PREFIX}{step:07d}{CKPT_SUFFIX}") tmp_path = final_path + ".tmp" - torch.save( - { - "step": step, - "model": model_state, - "optimizer": optim_state, - "cfg": cfg, - "vocab_size": vocab_size, - }, - tmp_path, - ) + + payload = { + "step": step, + "model": model_state, + "optimizer": optim_state, + "cfg": cfg, + "vocab_size": vocab_size, + "seed": seed, + "rng": snapshot_rng(), + "scaler": scaler.state_dict() if scaler is not None else None, + "torch_version": torch.__version__, + "cuda_version": torch.version.cuda, + } + torch.save(payload, tmp_path) os.replace(tmp_path, final_path) + # Best-effort fsync of the directory so the rename is durable across + # a crash/power-loss — torch.save already fsyncs the file itself. + try: + dir_fd = os.open(ckpt_dir, os.O_DIRECTORY) + try: + os.fsync(dir_fd) + finally: + os.close(dir_fd) + except OSError as exc: + logger.warning(f"Directory fsync failed (non-fatal) on {ckpt_dir}: {exc}") + for old in _list_ckpts(ckpt_dir)[:-keep_last]: try: os.remove(old) @@ -256,30 +350,27 @@ def save_checkpoint( logger.success(f"Checkpoint saved → {final_path}") -def load_checkpoint(model, optimizer, path: str, ddp: bool) -> int: +def load_checkpoint( + model: nn.Module, + optimizer: torch.optim.Optimizer, + scaler: Optional[ShardedGradScaler], + path: str, + ddp: bool, +) -> int: """ - Restore model + optimizer from disk, returning the step to resume at. - - Every rank reads the file (`rank0_only=False` on load) so FSDP has access - to the full state on each rank — the complement to the `rank0_only=True` - save path. Must mirror save's single-context pattern; splitting the model - and optimizer loads across two `state_dict_type` blocks has historically - produced optimizer state bound to the wrong shard shapes. - - `weights_only=False` is required because the checkpoint contains the - pickled `cfg` dataclass — flip to `weights_only=True` only if you + Restore model + optimizer + scaler + RNG from disk, returning the step + to resume at. + + Every rank reads the file (``rank0_only=False`` on load) so FSDP has + access to the full state on each rank — the complement to the + ``rank0_only=True`` save path. Must mirror save's single-context + pattern; splitting the model and optimizer loads across two + ``state_dict_type`` blocks has historically produced optimizer state + bound to the wrong shard shapes. + + ``weights_only=False`` is required because the checkpoint contains the + pickled ``cfg`` dataclass — flip to ``weights_only=True`` only if you separate config out. - - Args: - model -- same FSDP-wrapped or raw model used during save - optimizer -- freshly constructed optimizer to be filled in-place - path -- absolute path to a `step_{N:07d}.pt` file produced by - `save_checkpoint` - ddp -- whether the model is FSDP-wrapped; must match the save run - - Returns: - The step number the checkpoint was taken at; the caller advances the - training loop from this value. 
""" ckpt = torch.load(path, map_location="cpu", weights_only=False) @@ -300,15 +391,100 @@ def load_checkpoint(model, optimizer, path: str, ddp: bool) -> int: model.load_state_dict(ckpt["model"]) optimizer.load_state_dict(ckpt["optimizer"]) + if scaler is not None and ckpt.get("scaler") is not None: + scaler.load_state_dict(ckpt["scaler"]) + + restore_rng(ckpt.get("rng")) + return int(ckpt["step"]) +# --------------------------------------------------------------------------- +# Signal handling +# --------------------------------------------------------------------------- + + +class ShutdownFlag: + """ + Cooperative shutdown request set by SIGTERM / SIGINT handlers. + + Using a class attribute (rather than a bare global) because loguru's + multiprocessing fork tracker sometimes re-imports modules in child + processes, which would reset a module-level global. A class attribute + on a module-level singleton survives that. + """ + + requested = False + + @classmethod + def request(cls, signum: int, frame) -> None: + """ + Signal handler: mark the flag and log once. The training loop + polls this between microsteps and exits cleanly with a final + atomic checkpoint. A second signal falls through to default + handling (hard kill) — so a stuck rank can always be force-killed. + """ + if cls.requested: + logger.warning(f"Second signal {signum} received — falling through") + signal.signal(signum, signal.SIG_DFL) + os.kill(os.getpid(), signum) + return + cls.requested = True + logger.warning(f"Signal {signum} received — draining to final checkpoint") + + +def install_signal_handlers() -> None: + """ + Install SIGTERM + SIGINT handlers on the main rank / main thread only. + + Data-loader worker processes inherit these handlers too, which is + harmless: they don't have a model to checkpoint, and the flag they + would flip lives in a different address space. + """ + signal.signal(signal.SIGTERM, ShutdownFlag.request) + signal.signal(signal.SIGINT, ShutdownFlag.request) + + +# --------------------------------------------------------------------------- +# Logging +# --------------------------------------------------------------------------- + + +def configure_logging(log_dir: str, master: bool, rank: int) -> None: + """ + Attach a rotating file sink to loguru on every rank. + + Rank 0 keeps the default stderr sink; non-master ranks silence stderr + to avoid interleaved output chaos, but still write a per-rank file log + so post-mortem debugging sees every rank's view. 100 MB per file and + 7-day retention is enough to cover a multi-week pretraining run while + keeping total log volume bounded. + """ + os.makedirs(log_dir, exist_ok=True) + + if not master: + # Replace stderr with a null sink so non-master ranks don't spam the + # terminal; per-rank file sink still captures everything. + logger.remove() + + logger.add( + os.path.join(log_dir, f"train.rank{rank}.log"), + rotation="100 MB", + retention="7 days", + compression="gz", + enqueue=True, # thread-safe + signal-safe + level="INFO", + backtrace=True, + diagnose=False, # diagnose=True can leak tensor contents into logs + ) + + # --------------------------------------------------------------------------- # Main # --------------------------------------------------------------------------- -def main(): +def main() -> None: """ End-to-end pretraining entry point. @@ -318,23 +494,6 @@ def main(): (FSDP re-flattens parameters, so an optimizer built on the unwrapped model would track stale param objects). 
Resume then loads state into the already-constructed optimizer in-place. - - Lifecycle: - 1. Initialize torch.distributed (NCCL) if launched under torchrun. - 2. Build tokenizer → derive vocab_size. - 3. Construct OpenMythos with the 3B variant config. - 4. Wrap in FSDP with FULL_SHARD + bf16/fp16 mixed precision (multi-GPU) - or move to device + autocast (single-GPU). - 5. Build fused AdamW on (possibly sharded) parameters. - 6. Resume from the latest checkpoint in `ckpt_dir` if one exists. - 7. Stream FineWeb-Edu through grad-accumulation microbatches with - cosine LR schedule, per-step logging, and periodic checkpoints. - 8. Write a final checkpoint if the last save wasn't aligned to - `ckpt_every`, then barrier + tear down the process group. - - All hyperparameters are literal constants in this function by design — - pretraining runs are long-lived and each run pins exact settings; a - CLI/config layer is deliberately avoided to keep the file self-auditable. """ # ------------------------------------------------------------------ # Distributed init @@ -354,19 +513,42 @@ def main(): master = rank == 0 + log_dir = "logs" + configure_logging(log_dir, master, rank) + install_signal_handlers() + if master: logger.info( - f"GPUs: {torch.cuda.device_count()} | World size: {world_size} | Device: {device}" + f"GPUs: {torch.cuda.device_count()} | " + f"World size: {world_size} | Device: {device}" ) + logger.info(f"Torch: {torch.__version__} | CUDA: {torch.version.cuda}") + + # ------------------------------------------------------------------ + # RNG seeding — uses per-rank offset so each rank's in-process RNG + # starts from a different state (torch.manual_seed is local), but the + # numeric seed stored in the checkpoint is the same base value. + # ------------------------------------------------------------------ + seed = int(os.environ.get("OPENMYTHOS_SEED", DEFAULT_SEED)) + seed_everything(seed + rank) # ------------------------------------------------------------------ # Tokenizer # ------------------------------------------------------------------ encoding = MythosTokenizer() vocab_size = encoding.vocab_size + eos_id = encoding.eos_token_id if master: - logger.info(f"Tokenizer: gpt-oss-20b | Vocab size: {vocab_size:,}") + logger.info( + f"Tokenizer: gpt-oss-20b | Vocab size: {vocab_size:,} | " + f"EOS id: {eos_id}" + ) + if eos_id is None: + logger.warning( + "Tokenizer has no EOS token — documents will be concatenated " + "without a boundary marker (not ideal for pretraining quality)" + ) # ------------------------------------------------------------------ # Hyperparameters @@ -380,6 +562,7 @@ def main(): warmup_steps = 2000 lr = 3e-4 wd = 0.1 + grad_clip = 1.0 log_every = 10 ckpt_every = 1000 ckpt_dir = "checkpoints" @@ -387,8 +570,10 @@ def main(): if master: logger.info( - f"seq_len={seq_len} | micro_batch={micro_batch} | grad_accum={grad_accum} | " - f"global_batch_tokens={global_batch_tok:,} | total_steps={total_steps:,}" + f"seq_len={seq_len} | micro_batch={micro_batch} | " + f"grad_accum={grad_accum} | " + f"global_batch_tokens={global_batch_tok:,} | " + f"total_steps={total_steps:,} | seed={seed}" ) # ------------------------------------------------------------------ @@ -397,9 +582,17 @@ def main(): cfg = mythos_3b() cfg.vocab_size = vocab_size cfg.max_seq_len = seq_len + # Re-validate after mutating vocab_size / max_seq_len — the default + # variant values pass, but an operator who edits them at the CLI + # gets a clean error here instead of a mid-step crash. 
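    # (Dataclasses run __post_init__ exactly once, at construction time; the plain
    # attribute assignments above bypass whatever validation it performs, which is
    # why it has to be re-invoked by hand here. The specific checks live in the
    # config class, not in this file.)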
+ cfg.__post_init__() + + if master: + logger.info(f"Config: {asdict(cfg)}") bf16_ok = torch.cuda.is_available() and torch.cuda.is_bf16_supported() amp_dtype = torch.bfloat16 if bf16_ok else torch.float16 + use_scaler = (amp_dtype == torch.float16) and torch.cuda.is_available() model = OpenMythos(cfg) @@ -417,6 +610,7 @@ def main(): auto_wrap_policy=wrap_policy, device_id=local_rank, ) + amp_ctx = nullcontext() else: model = model.to(device) amp_ctx = ( @@ -425,33 +619,37 @@ def main(): else nullcontext() ) - # FSDP handles its own mixed precision; only need autocast for single-GPU - amp_ctx = nullcontext() if ddp else amp_ctx # type: ignore[possibly-undefined] - if master: n_params = sum(p.numel() for p in model.parameters()) - logger.info(f"Parameters: {n_params:,} | AMP dtype: {amp_dtype}") + logger.info( + f"Parameters: {n_params:,} | AMP dtype: {amp_dtype} | " + f"Scaler: {'on' if use_scaler else 'off'}" + ) # ------------------------------------------------------------------ - # Optimizer + # Optimizer + GradScaler # ------------------------------------------------------------------ optimizer = torch.optim.AdamW( model.parameters(), lr=lr, weight_decay=wd, betas=(0.9, 0.95), fused=True ) + # ShardedGradScaler = FSDP-aware unscale / all-reduce inf detection. + # Only needed when training in fp16; bf16 has enough dynamic range to + # forgo loss scaling entirely. On CPU / single-GPU we still use the + # sharded variant — it degenerates to a normal GradScaler but keeps + # the call sites identical. + scaler = ShardedGradScaler(enabled=use_scaler) if use_scaler else None + # ------------------------------------------------------------------ # Resume from latest checkpoint (if any) # ------------------------------------------------------------------ - # Streaming datasets are not resumable by position, so re-iterating from - # the beginning is accepted — at pretraining scale the loss of dataset - # position is negligible vs. the cost of discarded training steps. 
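Resuming restores the model, optimizer, scaler, and RNG streams, but not the streaming dataset's position. If losing that position ever became unacceptable, one low-tech option — not part of this change, and assuming the stream is deterministic for a fixed seed — is to drain already-consumed batches before handing the iterator to the loop:

import itertools

def fast_forward(loader, batches_consumed: int):
    # Wasteful but simple: re-tokenize and discard everything the previous
    # run already trained on, so the resumed run continues from fresh data.
    it = iter(loader)
    for _ in itertools.islice(it, batches_consumed):
        pass
    return it

# e.g. data_iter = fast_forward(loader, start_step * grad_accum)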
start_step = 0 existing_ckpts = _list_ckpts(ckpt_dir) if existing_ckpts: latest = existing_ckpts[-1] if master: logger.info(f"Resuming from checkpoint: {latest}") - start_step = load_checkpoint(model, optimizer, latest, ddp) + start_step = load_checkpoint(model, optimizer, scaler, latest, ddp) if master: logger.success(f"Resumed at step {start_step}") @@ -459,7 +657,13 @@ def main(): # Dataset + DataLoader # ------------------------------------------------------------------ dataset = FineWebEduDataset(encoding, seq_len, dataset_subset, rank, world_size) - loader = DataLoader(dataset, batch_size=micro_batch, num_workers=4, pin_memory=True) + loader = DataLoader( + dataset, + batch_size=micro_batch, + num_workers=4, + pin_memory=True, + persistent_workers=True, + ) # ------------------------------------------------------------------ # Training loop @@ -471,80 +675,205 @@ def main(): data_iter = iter(loader) t0 = time.perf_counter() step = start_step - - while step < total_steps: - cur_lr = get_lr(step, warmup_steps, total_steps, lr, lr * 0.1) - for g in optimizer.param_groups: - g["lr"] = cur_lr - - optimizer.zero_grad() - loss_accum = 0.0 - - for micro_step in range(grad_accum): - try: - x, y = next(data_iter) - except StopIteration: - data_iter = iter(loader) - x, y = next(data_iter) - - x = x.to(device if not ddp else f"cuda:{local_rank}", non_blocking=True) - y = y.to(device if not ddp else f"cuda:{local_rank}", non_blocking=True) - - sync = ( - nullcontext() - if (not ddp or micro_step == grad_accum - 1) - else model.no_sync() - ) - with sync, amp_ctx: - logits = model(x) - loss = nn.functional.cross_entropy( - logits.view(-1, vocab_size), y.view(-1) + target_device = f"cuda:{local_rank}" if ddp else device + skipped_steps = 0 + + try: + while step < total_steps: + if ShutdownFlag.requested: + if master: + logger.warning( + f"Shutdown requested at step {step}, breaking loop" + ) + break + + cur_lr = get_lr(step, warmup_steps, total_steps, lr, lr * 0.1) + for g in optimizer.param_groups: + g["lr"] = cur_lr + + optimizer.zero_grad(set_to_none=True) + + # Accumulate micro-batch loss on-device; single .item() at the + # end of the accumulation window keeps the GPU fully async. + loss_accum = torch.zeros((), device=target_device, dtype=torch.float32) + bad_microsteps = 0 + + for micro_step in range(grad_accum): + try: + x, y = next(data_iter) + except StopIteration: + data_iter = iter(loader) + x, y = next(data_iter) + + x = x.to(target_device, non_blocking=True) + y = y.to(target_device, non_blocking=True) + + sync = ( + nullcontext() + if (not ddp or micro_step == grad_accum - 1) + else model.no_sync() ) - loss = loss / grad_accum - - loss.backward() - loss_accum += loss.item() + with sync, amp_ctx: + logits = model(x) + loss = nn.functional.cross_entropy( + logits.view(-1, vocab_size), y.view(-1) + ) + loss = loss / grad_accum + + # NaN / Inf guard: a single corrupt sample (or a spike in + # the MoE routing) can produce a non-finite loss. Propagating + # it into backward() poisons every param's grad and — worse + # — the Adam second-moment buffers, which is unrecoverable + # without rewinding to a prior checkpoint. Detect it here + # and skip the backward entirely. 
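                    # One subtlety worth noting for multi-rank runs: this finite-ness check is
                    # per-rank. The microstep that runs *outside* no_sync() drives FSDP's
                    # gradient reduce-scatter, so if only some ranks skip backward here the
                    # remaining ranks can stall in that collective until the NCCL timeout.
                    # A common hardening step (sketch only, not part of this change) is to
                    # agree across ranks first:
                    #
                    #     bad = (~torch.isfinite(loss)).float()
                    #     if ddp:
                    #         dist.all_reduce(bad, op=dist.ReduceOp.MAX)
                    #     if bool(bad.item()):
                    #         ...  # every rank skips this microstep together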
+ if not torch.isfinite(loss): + bad_microsteps += 1 + logger.warning( + f"non-finite loss at step {step}.{micro_step} " + f"(value={loss.item()}) — skipping backward for this microstep" + ) + continue + + if scaler is not None: + scaler.scale(loss).backward() + else: + loss.backward() + + loss_accum = loss_accum + loss.detach() + + # If every microstep was non-finite, skip the optimizer step + # entirely — grads are either empty (set_to_none=True above) or + # untouched, and advancing Adam on them would taint the moment + # buffers. LR/step still tick so the schedule stays monotonic. + if bad_microsteps == grad_accum: + skipped_steps += 1 + step += 1 + if master: + logger.error( + f"step {step}: ALL {grad_accum} microsteps non-finite — " + f"skipping optimizer.step() " + f"(total skipped: {skipped_steps})" + ) + continue + + # FSDP shards parameters, so `nn.utils.clip_grad_norm_` would clip + # against each rank's local norm and miss the cross-shard gather. + # FSDP.clip_grad_norm_ computes the true global norm and returns it. + if scaler is not None: + # Unscale in-place so clip_grad_norm_ sees true-magnitude grads. + scaler.unscale_(optimizer) + + if ddp: + grad_norm = model.clip_grad_norm_(grad_clip) + else: + grad_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip) + + # One more non-finite check, this time on the global grad norm — + # an Inf/NaN here means clip didn't rescue us and stepping would + # taint Adam state. ShardedGradScaler.step() handles this for the + # fp16 path, but we enforce it uniformly. + if not torch.isfinite(grad_norm): + skipped_steps += 1 + step += 1 + if master: + logger.error( + f"step {step}: non-finite grad_norm={float(grad_norm)} — " + f"skipping optimizer.step() (total skipped: {skipped_steps})" + ) + if scaler is not None: + scaler.update() # keep the scaler bookkeeping consistent + continue + + if scaler is not None: + scaler.step(optimizer) + scaler.update() + else: + optimizer.step() + + # Drive the aux-loss-free router bias update. Without this call, + # the router_bias buffer stays at zeros forever and the balancing + # mechanism is inert — a silent correctness bug. `ddp=ddp` asks + # the OpenMythos wrapper to all-reduce expert_load across ranks + # before each local update. + base_model = model.module if ddp and hasattr(model, "module") else model + if hasattr(base_model, "update_router_biases"): + base_model.update_router_biases(ddp=ddp) + + step += 1 + + if master and step % log_every == 0: + dt = time.perf_counter() - t0 + tok_per_sec = global_batch_tok * log_every / max(dt, 1e-9) + tokens_seen = step * global_batch_tok + loss_val = float(loss_accum.detach().item()) + logger.info( + f"step {step:6d}/{total_steps} | loss {loss_val:.4f} " + f"| gnorm {float(grad_norm):.2f} | lr {cur_lr:.2e} " + f"| {tok_per_sec / 1e6:.2f}M tok/s " + f"| {tokens_seen / 1e9:.1f}B tokens seen " + f"| skipped={skipped_steps}" + ) + t0 = time.perf_counter() + + if step % ckpt_every == 0: + save_checkpoint( + model, + optimizer, + scaler, + step, + cfg, + vocab_size, + seed, + ckpt_dir, + ddp, + master, + ) + except Exception: + # Log the traceback from every rank so a crash from one rank is + # recoverable post-mortem; then re-raise so the exit code is non-zero. + logger.exception(f"Training crashed at step {step}") + raise + finally: + # Final save covers two cases: normal completion when + # `step` isn't aligned to ckpt_every, *and* SIGTERM-driven early + # exit. Either way, we want the most recent state on disk. 
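        # (The update_router_biases call in the loop above is the load-balancing hook
        # for the aux-loss-free router. Its implementation lives on the OpenMythos
        # model and is not part of this diff; a DeepSeek-V3-style update is typically
        # of the form sketched below — names and the gamma value are illustrative only:
        #
        #     @torch.no_grad()
        #     def update_router_biases(self, ddp: bool = False, gamma: float = 1e-3):
        #         load = self.expert_load  # tokens routed per expert since last update
        #         if ddp:
        #             dist.all_reduce(load, op=dist.ReduceOp.SUM)
        #         # Nudge the selection bias down for overloaded experts and up for
        #         # underloaded ones; the bias shifts top-k selection, not gate weights.
        #         self.router_bias += gamma * torch.sign(load.mean() - load)
        #         load.zero_()
        # )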
+ try: + if step > start_step and ( + step % ckpt_every != 0 or ShutdownFlag.requested + ): + save_checkpoint( + model, + optimizer, + scaler, + step, + cfg, + vocab_size, + seed, + ckpt_dir, + ddp, + master, + ) + except Exception: + logger.exception("Final checkpoint save failed") - # FSDP shards parameters, so `nn.utils.clip_grad_norm_` would clip - # against each rank's local norm and miss the cross-shard gather. - # FSDP.clip_grad_norm_ computes the true global norm and returns it. if ddp: - grad_norm = model.clip_grad_norm_(1.0) - else: - grad_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) - optimizer.step() - step += 1 - - if master and step % log_every == 0: - dt = time.perf_counter() - t0 - tok_per_sec = global_batch_tok * log_every / dt - tokens_seen = step * global_batch_tok - logger.info( - f"step {step:6d}/{total_steps} | loss {loss_accum:.4f} " - f"| gnorm {float(grad_norm):.2f} | lr {cur_lr:.2e} " - f"| {tok_per_sec / 1e6:.2f}M tok/s " - f"| {tokens_seen / 1e9:.1f}B tokens seen" - ) - t0 = time.perf_counter() - - if step % ckpt_every == 0: - save_checkpoint( - model, optimizer, step, cfg, vocab_size, ckpt_dir, ddp, master - ) - - # Final checkpoint — total_steps may not be divisible by ckpt_every, so - # without this the tail of the run is lost if the schedule doesn't align. - if step > start_step and step % ckpt_every != 0: - save_checkpoint(model, optimizer, step, cfg, vocab_size, ckpt_dir, ddp, master) - - if ddp: - # Barrier so no rank exits while another is still finishing its - # checkpoint gather — avoids NCCL "process group destroyed" noise. - dist.barrier() - dist.destroy_process_group() + # Barrier so no rank exits while another is still finishing its + # checkpoint gather — avoids NCCL "process group destroyed" noise. + try: + dist.barrier() + finally: + dist.destroy_process_group() if master: - logger.success("Training complete.") + if ShutdownFlag.requested: + logger.warning(f"Training stopped early at step {step} by signal") + else: + logger.success(f"Training complete at step {step}") + + # Non-zero exit on SIGTERM so a supervisor (k8s, slurm) sees the job + # as interrupted rather than successfully completed. + if ShutdownFlag.requested: + sys.exit(130) if __name__ == "__main__": diff --git a/training/requirements.txt b/training/requirements.txt index e3348c5..a86c519 100644 --- a/training/requirements.txt +++ b/training/requirements.txt @@ -1,4 +1,24 @@ -torch>=2.11.0 -datasets>=3.6.0 -loguru>=0.7.3 -open-mythos \ No newline at end of file +# Pretraining-run requirements. +# +# Pinned to exact patch versions for *training* reproducibility across +# nodes and weeks. Library consumers should use the top-level +# requirements.txt, which uses compatible ranges instead. +# +# When bumping torch here, also bump pyproject.toml's torch constraint +# and rerun `python -m open_mythos` smoke tests to confirm the new +# version still exposes: +# - torch.distributed.fsdp.ShardingStrategy.FULL_SHARD +# - torch.distributed.fsdp.sharded_grad_scaler.ShardedGradScaler +# - torch.amp.autocast(device_type=..., dtype=...) + +--extra-index-url https://download.pytorch.org/whl/cu124 + +torch==2.11.0 +transformers==4.46.3 +datasets==3.2.0 +numpy>=1.26,<3.0 +loguru==0.7.3 + +# install the local package in editable mode when running from a clone: +# pip install -r training/requirements.txt +# pip install -e .
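#
# Typical multi-GPU launch once the above are installed — the training script's
# path below is illustrative (it is not named in this hunk); OPENMYTHOS_SEED is
# optional and falls back to the in-code default seed:
#   OPENMYTHOS_SEED=1234 torchrun --standalone --nproc_per_node=8 training/<pretrain_script>.py
# Single-GPU / CPU debugging runs can invoke the script directly with plain python.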