From 5b221cc03b4cae6b40190e3067ebc7b215e7ed33 Mon Sep 17 00:00:00 2001
From: xuzhemin <757583912@qq.com>
Date: Mon, 20 Apr 2026 17:13:37 +0800
Subject: [PATCH] align training recipe and add validation logging

Add a float16 GradScaler path for single-GPU training, include periodic
validation loss/perplexity estimation, and update README training notes
to match the current optimizer and precision behavior.

Made-with: Cursor
---
 README.md                   |  5 +-
 training/3b_fine_web_edu.py | 97 +++++++++++++++++++++++++++++++++++--
 2 files changed, 95 insertions(+), 7 deletions(-)

diff --git a/README.md b/README.md
index b1d9984..a8cadc8 100644
--- a/README.md
+++ b/README.md
@@ -151,12 +151,13 @@ Key design choices:
 
 | Feature | Detail |
 |---|---|
-| Optimizer | Muon for 2D weight matrices, AdamW for embeddings/norms |
+| Optimizer | AdamW |
 | Dataset | `HuggingFaceFW/fineweb-edu` (`sample-10BT` by default, swap to `sample-100BT` or `default` for full run) |
 | Tokenizer | `openai/gpt-oss-20b` via `MythosTokenizer` |
 | Parallelism | PyTorch DDP via `torchrun`, sharded streaming dataset |
-| Precision | bfloat16 on H100/A100, float16 + GradScaler on older GPUs |
+| Precision | bfloat16 when supported; float16 + GradScaler on older single-GPU cards |
 | Schedule | Linear warmup (2000 steps) → cosine decay |
+| Validation | Periodic val loss + perplexity reporting during training |
 | Target | 30B tokens (~Chinchilla-adjusted for looped architecture) |
 
 ---
diff --git a/training/3b_fine_web_edu.py b/training/3b_fine_web_edu.py
index 215381d..794618b 100644
--- a/training/3b_fine_web_edu.py
+++ b/training/3b_fine_web_edu.py
@@ -40,12 +40,21 @@ class FineWebEduDataset(IterableDataset):
 
-    def __init__(self, encoding, seq_len: int, subset: str, rank: int, world_size: int):
+    def __init__(
+        self,
+        encoding,
+        seq_len: int,
+        subset: str,
+        rank: int,
+        world_size: int,
+        shard_offset: int = 0,
+    ):
         self.encoding = encoding
         self.seq_len = seq_len
         self.subset = subset
         self.rank = rank
         self.world_size = world_size
+        self.shard_offset = shard_offset
 
     def __iter__(self):
         worker = get_worker_info()
@@ -53,7 +62,7 @@ def __iter__(self):
         worker_id = worker.id if worker else 0
 
         total_shards = self.world_size * num_workers
-        shard_index = self.rank * num_workers + worker_id
+        shard_index = (self.rank * num_workers + worker_id + self.shard_offset) % total_shards
 
         ds = load_dataset(
             "HuggingFaceFW/fineweb-edu",
@@ -88,6 +97,48 @@ def get_lr(step: int, warmup: int, total: int, max_lr: float, min_lr: float) ->
     return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * decay))
 
 
+@torch.no_grad()
+def evaluate_loss(
+    model: nn.Module,
+    val_loader: DataLoader,
+    steps: int,
+    vocab_size: int,
+    amp_ctx,
+    ddp: bool,
+    device: str,
+    local_rank: int,
+) -> float:
+    """Estimate validation loss over a small number of micro-batches."""
+    model_was_training = model.training
+    model.eval()
+
+    val_iter = iter(val_loader)
+    losses = []
+    for _ in range(steps):
+        try:
+            x, y = next(val_iter)
+        except StopIteration:
+            val_iter = iter(val_loader)
+            x, y = next(val_iter)
+
+        x = x.to(device if not ddp else f"cuda:{local_rank}", non_blocking=True)
+        y = y.to(device if not ddp else f"cuda:{local_rank}", non_blocking=True)
+
+        with amp_ctx:
+            logits = model(x)
+            loss = nn.functional.cross_entropy(logits.view(-1, vocab_size), y.view(-1))
+        losses.append(loss.detach())
+
+    loss_tensor = torch.stack(losses).mean()
+    if ddp:
+        dist.all_reduce(loss_tensor, op=dist.ReduceOp.SUM)
+        loss_tensor = loss_tensor / dist.get_world_size()
+
+    if model_was_training:
+        model.train()
+    return float(loss_tensor.item())
+
+
 # ---------------------------------------------------------------------------
 # Main
 # ---------------------------------------------------------------------------
@@ -138,6 +189,8 @@ def main():
     wd = 0.1
     log_every = 10
     ckpt_every = 1000
+    val_every = 200
+    val_steps = 20
     ckpt_dir = "checkpoints"
     dataset_subset = "sample-10BT"  # → sample-100BT or "default" for full run
 
@@ -194,12 +247,20 @@ def main():
     optimizer = torch.optim.AdamW(
         model.parameters(), lr=lr, weight_decay=wd, betas=(0.9, 0.95), fused=True
     )
+    use_grad_scaler = (amp_dtype == torch.float16) and ("cuda" in device) and (not ddp)
+    scaler = torch.amp.GradScaler("cuda", enabled=use_grad_scaler)
 
     # ------------------------------------------------------------------
     # Dataset + DataLoader
     # ------------------------------------------------------------------
     dataset = FineWebEduDataset(encoding, seq_len, dataset_subset, rank, world_size)
     loader = DataLoader(dataset, batch_size=micro_batch, num_workers=4, pin_memory=True)
+    val_dataset = FineWebEduDataset(
+        encoding, seq_len, dataset_subset, rank, world_size, shard_offset=1
+    )
+    val_loader = DataLoader(
+        val_dataset, batch_size=micro_batch, num_workers=2, pin_memory=True
+    )
 
     # ------------------------------------------------------------------
     # Training loop
@@ -242,11 +303,20 @@ def main():
                 )
 
             loss = loss / grad_accum
-            loss.backward()
+            if scaler.is_enabled():
+                scaler.scale(loss).backward()
+            else:
+                loss.backward()
             loss_accum += loss.item()
 
-        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
-        optimizer.step()
+        if scaler.is_enabled():
+            scaler.unscale_(optimizer)
+            nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
+            scaler.step(optimizer)
+            scaler.update()
+        else:
+            nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
+            optimizer.step()
 
         step += 1
         if master and step % log_every == 0:
@@ -260,6 +330,23 @@ def main():
             )
             t0 = time.perf_counter()
 
+        if step % val_every == 0:
+            val_loss = evaluate_loss(
+                model=model,
+                val_loader=val_loader,
+                steps=val_steps,
+                vocab_size=vocab_size,
+                amp_ctx=amp_ctx,
+                ddp=ddp,
+                device=device,
+                local_rank=local_rank,
+            )
+            if master:
+                val_ppl = math.exp(min(val_loss, 20.0))
+                print(
+                    f"validation | step {step:6d}/{total_steps} | val_loss {val_loss:.4f} | val_ppl {val_ppl:.2f}"
+                )
+
         if master and step % ckpt_every == 0:
             path = os.path.join(ckpt_dir, f"step_{step:07d}.pt")
             if ddp: