5 changes: 3 additions & 2 deletions README.md
@@ -151,12 +151,13 @@ Key design choices:

 | Feature | Detail |
 |---|---|
-| Optimizer | Muon for 2D weight matrices, AdamW for embeddings/norms |
+| Optimizer | AdamW |
 | Dataset | `HuggingFaceFW/fineweb-edu` (`sample-10BT` by default, swap to `sample-100BT` or `default` for full run) |
 | Tokenizer | `openai/gpt-oss-20b` via `MythosTokenizer` |
 | Parallelism | PyTorch DDP via `torchrun`, sharded streaming dataset |
-| Precision | bfloat16 on H100/A100, float16 + GradScaler on older GPUs |
+| Precision | bfloat16 when supported; float16 + GradScaler on single-GPU older cards |
 | Schedule | Linear warmup (2000 steps) → cosine decay |
+| Validation | Periodic val loss + perplexity reporting during training |
 | Target | 30B tokens (~Chinchilla-adjusted for looped architecture) |

 ---
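Note on the Precision row: the dtype pick happens at startup, outside the hunks shown below. A minimal sketch of that selection, assuming `torch.cuda.is_bf16_supported()` as the capability check (the diff below only shows the GradScaler half of this logic):

```python
import torch

# Prefer bfloat16 where the GPU supports it (Ampere/Hopper class); otherwise
# fall back to float16, which needs a GradScaler to avoid gradient underflow.
if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
    amp_dtype = torch.bfloat16
else:
    amp_dtype = torch.float16
amp_ctx = torch.autocast(device_type="cuda", dtype=amp_dtype)
```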
97 changes: 92 additions & 5 deletions training/3b_fine_web_edu.py
@@ -40,20 +40,29 @@


 class FineWebEduDataset(IterableDataset):
-    def __init__(self, encoding, seq_len: int, subset: str, rank: int, world_size: int):
+    def __init__(
+        self,
+        encoding,
+        seq_len: int,
+        subset: str,
+        rank: int,
+        world_size: int,
+        shard_offset: int = 0,
+    ):
         self.encoding = encoding
         self.seq_len = seq_len
         self.subset = subset
         self.rank = rank
         self.world_size = world_size
+        self.shard_offset = shard_offset

     def __iter__(self):
         worker = get_worker_info()
         num_workers = worker.num_workers if worker else 1
         worker_id = worker.id if worker else 0

         total_shards = self.world_size * num_workers
-        shard_index = self.rank * num_workers + worker_id
+        shard_index = (self.rank * num_workers + worker_id + self.shard_offset) % total_shards

         ds = load_dataset(
             "HuggingFaceFW/fineweb-edu",
@@ -88,6 +97,48 @@ def get_lr(step: int, warmup: int, total: int, max_lr: float, min_lr: float) -> float:
     return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * decay))


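Only the cosine tail of `get_lr` is visible in this hunk. For reference, a sketch of the full schedule consistent with the visible return line and the README's warmup → cosine row; the warmup branch is an assumption, since the diff does not show it:

```python
import math

def get_lr(step: int, warmup: int, total: int, max_lr: float, min_lr: float) -> float:
    # Linear warmup: ramp from 0 to max_lr over the first `warmup` steps.
    if step < warmup:
        return max_lr * (step + 1) / warmup
    # Cosine decay: anneal from max_lr down to min_lr over the rest of training.
    decay = (step - warmup) / (total - warmup)
    return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * decay))
```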
+@torch.no_grad()
+def evaluate_loss(
+    model: nn.Module,
+    val_loader: DataLoader,
+    steps: int,
+    vocab_size: int,
+    amp_ctx,
+    ddp: bool,
+    device: str,
+    local_rank: int,
+) -> float:
+    """Estimate validation loss over a small number of micro-batches."""
+    model_was_training = model.training
+    model.eval()
+
+    val_iter = iter(val_loader)
+    losses = []
+    for _ in range(steps):
+        try:
+            x, y = next(val_iter)
+        except StopIteration:
+            val_iter = iter(val_loader)
+            x, y = next(val_iter)
+
+        x = x.to(device if not ddp else f"cuda:{local_rank}", non_blocking=True)
+        y = y.to(device if not ddp else f"cuda:{local_rank}", non_blocking=True)
+
+        with amp_ctx:
+            logits = model(x)
+            loss = nn.functional.cross_entropy(logits.view(-1, vocab_size), y.view(-1))
+        losses.append(loss.detach())
+
+    loss_tensor = torch.stack(losses).mean()
+    if ddp:
+        dist.all_reduce(loss_tensor, op=dist.ReduceOp.SUM)
+        loss_tensor = loss_tensor / dist.get_world_size()
+
+    if model_was_training:
+        model.train()
+    return float(loss_tensor.item())
+
+
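Note that `evaluate_loss` is a collective operation under DDP: the `all_reduce` means every rank must enter the function together, which is why the call site later in the diff has no `master` guard. SUM followed by division by world size is the cross-rank mean; on the NCCL backend the same two lines could be written as a single call:

```python
# Drop-in equivalent of the SUM-then-divide pair above (NCCL backend only):
dist.all_reduce(loss_tensor, op=dist.ReduceOp.AVG)
```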
 # ---------------------------------------------------------------------------
 # Main
 # ---------------------------------------------------------------------------
@@ -138,6 +189,8 @@ def main():
     wd = 0.1
     log_every = 10
     ckpt_every = 1000
+    val_every = 200
+    val_steps = 20
     ckpt_dir = "checkpoints"
     dataset_subset = "sample-10BT"  # → sample-100BT or "default" for full run

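For a sense of scale, the cost of one validation pass per rank works out as below; `micro_batch` and `seq_len` are defined earlier in `main()`, outside this hunk, so the numbers here are illustrative only:

```python
# Illustrative values; the script's real micro_batch/seq_len are not shown here.
val_steps, micro_batch, seq_len = 20, 8, 2048
tokens_per_val_pass = val_steps * micro_batch * seq_len  # 327,680 tokens per rank
```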
@@ -194,12 +247,20 @@
     optimizer = torch.optim.AdamW(
         model.parameters(), lr=lr, weight_decay=wd, betas=(0.9, 0.95), fused=True
     )
+    use_grad_scaler = (amp_dtype == torch.float16) and ("cuda" in device) and (not ddp)
+    scaler = torch.amp.GradScaler("cuda", enabled=use_grad_scaler)

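Why float16 only: bfloat16 shares float32's exponent range, so loss scaling buys nothing there. Also worth knowing when reading the branches later in this diff: a disabled `GradScaler` degrades to pass-throughs, so the two paths could in principle be collapsed into one. A self-contained illustration:

```python
import torch

scaler = torch.amp.GradScaler("cuda", enabled=False)
param = torch.zeros(1, requires_grad=True)
opt = torch.optim.SGD([param], lr=0.1)

loss = (param * 2.0).sum()
scaler.scale(loss).backward()  # scale() returns loss unchanged when disabled
scaler.step(opt)               # falls through to opt.step()
scaler.update()                # no-op
```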
     # ------------------------------------------------------------------
     # Dataset + DataLoader
     # ------------------------------------------------------------------
     dataset = FineWebEduDataset(encoding, seq_len, dataset_subset, rank, world_size)
     loader = DataLoader(dataset, batch_size=micro_batch, num_workers=4, pin_memory=True)
+    val_dataset = FineWebEduDataset(
+        encoding, seq_len, dataset_subset, rank, world_size, shard_offset=1
+    )
+    val_loader = DataLoader(
+        val_dataset, batch_size=micro_batch, num_workers=2, pin_memory=True
+    )

     # ------------------------------------------------------------------
     # Training loop
@@ -242,11 +303,20 @@
             )
             loss = loss / grad_accum

-            loss.backward()
+            if scaler.is_enabled():
+                scaler.scale(loss).backward()
+            else:
+                loss.backward()
             loss_accum += loss.item()

-        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
-        optimizer.step()
+        if scaler.is_enabled():
+            scaler.unscale_(optimizer)
+            nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
+            scaler.step(optimizer)
+            scaler.update()
+        else:
+            nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
+            optimizer.step()
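The ordering inside the scaler branch matters; a summary of the contract:

```python
# 1. scaler.unscale_(optimizer)  -> gradients return to true magnitude
# 2. clip_grad_norm_(...)        -> clipping then sees real gradient norms
# 3. scaler.step(optimizer)      -> skips the update if any grad is inf/nan
# 4. scaler.update()             -> grows or shrinks the loss scale for next step
```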
         step += 1

         if master and step % log_every == 0:
@@ -260,6 +330,23 @@
             )
             t0 = time.perf_counter()

+        if step % val_every == 0:
+            val_loss = evaluate_loss(
+                model=model,
+                val_loader=val_loader,
+                steps=val_steps,
+                vocab_size=vocab_size,
+                amp_ctx=amp_ctx,
+                ddp=ddp,
+                device=device,
+                local_rank=local_rank,
+            )
+            if master:
+                val_ppl = math.exp(min(val_loss, 20.0))
+                print(
+                    f"validation | step {step:6d}/{total_steps} | val_loss {val_loss:.4f} | val_ppl {val_ppl:.2f}"
+                )
+
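The `min(val_loss, 20.0)` clamp only guards the print: perplexity is `exp` of the mean cross-entropy, and it keeps the reported number in a readable range early in training. For intuition (the 100k vocab here is an example, not the script's actual tokenizer size):

```python
import math

print(math.exp(20.0))     # ~4.85e8, the largest perplexity the clamp will report
print(math.log(100_000))  # ~11.5: loss of a uniform model over a 100k vocab
```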
         if master and step % ckpt_every == 0:
             path = os.path.join(ckpt_dir, f"step_{step:07d}.pt")
             if ddp:
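For context on the `ddp`, `rank`, `local_rank`, and `world_size` names used throughout: the script's bootstrap sits above the first hunk and is not shown in this diff. Under `torchrun` the standard pattern looks like the following sketch (an assumption, not a quote from the file):

```python
import os

import torch
import torch.distributed as dist

# torchrun sets RANK / LOCAL_RANK / WORLD_SIZE for every process it spawns.
ddp = int(os.environ.get("RANK", -1)) != -1
if ddp:
    dist.init_process_group(backend="nccl")
    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    torch.cuda.set_device(local_rank)
else:
    rank, local_rank, world_size = 0, 0, 1
```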