diff --git a/scripts/train.py b/scripts/train.py index f35e204..4b1433a 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -66,7 +66,14 @@ def main(cfg: DictConfig) -> None: ) collator_class = COLLATOR_REGISTRY.get(cfg.collator.name) - collator = collator_class(MLMCollatorConfig.from_tokenizer(tokenizer)) + collator = collator_class( + MLMCollatorConfig.from_tokenizer( + tokenizer, + seed=cfg.training.seed, + mlm_probability=cfg.collator.mlm_probability, + epoch=0, + ) + ) task = TASK_REGISTRY.get(cfg.task.name)() model = MODEL_REGISTRY.get(cfg.model.name)( @@ -86,6 +93,8 @@ def main(cfg: DictConfig) -> None: callbacks.append(cb_cls(**cb_kwargs)) sampler = ResumableSampler(dataset, seed=cfg.training.seed) + if hasattr(collator, "set_epoch"): + collator.set_epoch(sampler.epoch) loader = create_dataloader( dataset=dataset, @@ -96,6 +105,7 @@ def main(cfg: DictConfig) -> None: prefetch_factor=cfg.training.prefetch_factor, pin_memory=torch.cuda.is_available(), drop_last=True, + seed=cfg.training.seed, ), sampler=sampler, ) @@ -118,6 +128,7 @@ def main(cfg: DictConfig) -> None: pin_memory=torch.cuda.is_available(), drop_last=False, shuffle=False, + seed=cfg.training.seed, ), ) diff --git a/src/embedding_trainer/data/collators/mlm.py b/src/embedding_trainer/data/collators/mlm.py index 468af60..70ac527 100644 --- a/src/embedding_trainer/data/collators/mlm.py +++ b/src/embedding_trainer/data/collators/mlm.py @@ -18,9 +18,17 @@ class MLMCollatorConfig(BaseCollatorConfig): """Configuration for MLM collator.""" mlm_probability: float = 0.30 # Probability of replacing with [MASK] + seed: int = 0 + epoch: int = 0 @staticmethod - def from_tokenizer(tokenizer: PreTrainedTokenizerBase) -> MLMCollatorConfig: + def from_tokenizer( + tokenizer: PreTrainedTokenizerBase, + *, + seed: int = 0, + mlm_probability: float = 0.30, + epoch: int = 0, + ) -> MLMCollatorConfig: """Create config from a Hugging Face Tokenizer.""" return MLMCollatorConfig( pad_token_id=tokenizer.pad_token_id, @@ -28,6 +36,9 @@ def from_tokenizer(tokenizer: PreTrainedTokenizerBase) -> MLMCollatorConfig: sep_token_id=tokenizer.sep_token_id, mask_token_id=tokenizer.mask_token_id, vocab_size=tokenizer.vocab_size, + seed=seed, + mlm_probability=mlm_probability, + epoch=epoch, ) @@ -44,6 +55,12 @@ class MLMCollator(BaseCollator): def __init__(self, config: MLMCollatorConfig) -> None: super().__init__(config) self.mlm_probability = config.mlm_probability + self.seed = config.seed + self.epoch = config.epoch + + def set_epoch(self, epoch: int) -> None: + """Set epoch used by deterministic masking.""" + self.epoch = epoch def __call__(self, samples: list[dict[str, Tensor]]) -> PreTokenizedBatch: """ @@ -57,6 +74,7 @@ def __call__(self, samples: list[dict[str, Tensor]]) -> PreTokenizedBatch: """ # Extract input_ids input_ids_list = [sample["input_ids"] for sample in samples] + sample_indices = self._get_sample_indices(samples) # Pad sequences input_ids = self._pad_sequences(input_ids_list, self.pad_token_id) @@ -65,7 +83,7 @@ def __call__(self, samples: list[dict[str, Tensor]]) -> PreTokenizedBatch: attention_mask = self._create_attention_mask(input_ids) # Apply MLM masking - input_ids, labels = self._mask_tokens(input_ids) + input_ids, labels = self._mask_tokens(input_ids, sample_indices) return { "input_ids": input_ids, @@ -76,6 +94,7 @@ def __call__(self, samples: list[dict[str, Tensor]]) -> PreTokenizedBatch: def _mask_tokens( self, input_ids: torch.Tensor, + sample_indices: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """ Apply MLM masking to input_ids. @@ -90,22 +109,84 @@ def _mask_tokens( labels = input_ids.clone() input_ids = input_ids.clone() - # Create probability matrix for masking - probability_matrix = torch.full(input_ids.shape, self.mlm_probability) - - # Get special tokens mask and set their probability to 0 + if sample_indices is None: + sample_indices = torch.arange( + input_ids.shape[0], + device=input_ids.device, + dtype=torch.int64, + ) + else: + sample_indices = sample_indices.to( + device=input_ids.device, dtype=torch.int64 + ) + if ( + sample_indices.ndim != 1 + or sample_indices.shape[0] != input_ids.shape[0] + ): + raise ValueError( + "sample_indices must be a 1D tensor with one entry per batch row." + ) + + # Get special tokens mask special_tokens_mask = self._get_special_tokens_mask(input_ids) - probability_matrix.masked_fill_(special_tokens_mask, 0.0) - # Also don't mask padding + # Also don't mask padding. padding_mask = input_ids == self.pad_token_id - probability_matrix.masked_fill_(padding_mask, 0.0) - - # Sample which tokens to mask - masked_indices = torch.bernoulli(probability_matrix).bool() + eligible_mask = ~(special_tokens_mask | padding_mask) + + if self.mlm_probability <= 0.0: + masked_indices = torch.zeros_like(input_ids, dtype=torch.bool) + elif self.mlm_probability >= 1.0: + masked_indices = eligible_mask + else: + p = self._deterministic_uniform( + sample_indices, input_ids.shape[1], input_ids.device + ) + masked_indices = (p < self.mlm_probability) & eligible_mask # Set labels to -100 for non-masked tokens (will be ignored in loss) labels[~masked_indices] = LABEL_IGNORE_ID input_ids[masked_indices] = self.mask_token_id # Default to [MASK] return input_ids, labels + + def _get_sample_indices(self, samples: list[dict[str, Tensor]]) -> torch.Tensor: + sample_indices: list[int] = [] + for row_idx, sample in enumerate(samples): + raw_idx = sample.get("sample_idx") + if raw_idx is None: + sample_indices.append(row_idx) + continue + if isinstance(raw_idx, torch.Tensor): + if raw_idx.numel() != 1: + raise ValueError( + "sample_idx tensor must contain a single scalar value." + ) + sample_indices.append(int(raw_idx.item())) + continue + sample_indices.append(int(raw_idx)) + return torch.tensor(sample_indices, dtype=torch.int64) + + def _deterministic_uniform( + self, + sample_indices: torch.Tensor, + seq_len: int, + device: torch.device, + ) -> torch.Tensor: + """Generate deterministic pseudo-random values in [0, 1).""" + modulus = 2_147_483_647 # 2^31 - 1, prime. + + sample_grid = sample_indices.unsqueeze(1) + position_grid = torch.arange( + seq_len, device=device, dtype=torch.int64 + ).unsqueeze(0) + + x = ( + sample_grid * 1_048_583 + + position_grid * 1_308_049 + + (self.seed % modulus) * 4_447_961 + + (self.epoch % modulus) * 22_695_477 + + 1_013_904_223 + ) % modulus + x = (x * 48_271) % modulus + return x.to(torch.float64) / float(modulus) diff --git a/src/embedding_trainer/data/datasets/flat_tokens.py b/src/embedding_trainer/data/datasets/flat_tokens.py index 396f5f6..0df2dc9 100644 --- a/src/embedding_trainer/data/datasets/flat_tokens.py +++ b/src/embedding_trainer/data/datasets/flat_tokens.py @@ -243,6 +243,7 @@ def __getitem__(self, idx: int) -> dict[str, Tensor]: return { # Convert to int32 tensor for PyTorch (torch.uint16 has limited support) "input_ids": torch.from_numpy(seq_tokens.astype(np.int32)), + "sample_idx": torch.tensor(idx, dtype=torch.int64), } def set_epoch(self, epoch: int) -> None: diff --git a/src/embedding_trainer/data/loader.py b/src/embedding_trainer/data/loader.py index f5a2b66..48982c0 100644 --- a/src/embedding_trainer/data/loader.py +++ b/src/embedding_trainer/data/loader.py @@ -2,10 +2,13 @@ from __future__ import annotations +import random from collections.abc import Callable from dataclasses import dataclass +from functools import partial from typing import Any +import numpy as np import torch from torch.utils.data import DataLoader, Dataset, Sampler from torch.utils.data.distributed import DistributedSampler @@ -23,6 +26,7 @@ class DataLoaderConfig: pin_memory: bool = True drop_last: bool = True shuffle: bool = True + seed: int = 0 def __post_init__(self) -> None: if self.batch_size < 1: @@ -41,20 +45,37 @@ def _build_loader_kwargs( collator: CollatorProtocol | Callable[[list[dict[str, Any]]], Any], ) -> dict[str, Any]: """Build common DataLoader kwargs.""" + loader_generator = torch.Generator() + loader_generator.manual_seed(config.seed) + loader_kwargs: dict[str, Any] = { "batch_size": config.batch_size, "collate_fn": collator, "num_workers": config.num_workers, "pin_memory": config.pin_memory and torch.cuda.is_available(), "drop_last": config.drop_last, + "generator": loader_generator, } if config.num_workers > 0: loader_kwargs["prefetch_factor"] = config.prefetch_factor + loader_kwargs["worker_init_fn"] = _build_worker_init_fn(config.seed) return loader_kwargs +def _seed_worker(worker_id: int, seed: int) -> None: + worker_seed = seed + worker_id + random.seed(worker_seed) + np.random.seed(worker_seed % (2**32)) + torch.manual_seed(worker_seed) + + +def _build_worker_init_fn(seed: int) -> Callable[[int], None]: + """Build a worker seeding function for deterministic worker-local RNG.""" + return partial(_seed_worker, seed=seed) + + def create_dataloader( dataset: DatasetProtocol | Dataset[Any], collator: CollatorProtocol | Callable[[list[dict[str, Any]]], Any], diff --git a/src/embedding_trainer/data/sampler.py b/src/embedding_trainer/data/sampler.py index 3a41c91..b650e38 100644 --- a/src/embedding_trainer/data/sampler.py +++ b/src/embedding_trainer/data/sampler.py @@ -41,6 +41,11 @@ def _generate_permutation(self) -> None: def __len__(self) -> int: return self._num_samples + @property + def epoch(self) -> int: + """Current sampler epoch.""" + return self._epoch + def __iter__(self) -> Iterator[int]: # Yield remaining indices from the current epoch. # No state mutation here — advance() and start_new_epoch() handle that. @@ -83,4 +88,11 @@ def start_new_epoch(self) -> None: def advance(self, n: int) -> None: """Advance the position by *n* samples (call once per batch).""" + if n < 0: + raise ValueError(f"advance expects non-negative n, got {n}") self._start_index += n + if self._start_index > self._num_samples: + raise ValueError( + "ResumableSampler.advance moved past end of epoch " + f"({self._start_index} > {self._num_samples})." + ) diff --git a/src/embedding_trainer/training/trainer.py b/src/embedding_trainer/training/trainer.py index 2c6f14c..a7fbd07 100644 --- a/src/embedding_trainer/training/trainer.py +++ b/src/embedding_trainer/training/trainer.py @@ -1,9 +1,12 @@ from __future__ import annotations import math +import random import re from pathlib import Path +from typing import Any +import numpy as np import torch from embedding_trainer.core import BaseCallback, BaseTask, EvalOutput, TrainOutput @@ -55,6 +58,7 @@ def train(self) -> TrainOutput: callback.on_train_begin(self) while self.global_step < self.max_steps: + self._sync_data_epoch() for batch in self.loader: if self.global_step >= self.max_steps: break @@ -159,6 +163,7 @@ def save_checkpoint(self, path: str | Path) -> None: "global_step": self.global_step, "cumulative_loss_sum": self.cumulative_loss_sum, "cumulative_loss_steps": self.cumulative_loss_steps, + **self._capture_rng_state(), } if self.scheduler is not None: checkpoint["scheduler_state_dict"] = self.scheduler.state_dict() @@ -182,6 +187,8 @@ def load_checkpoint(self, path: str | Path) -> None: self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"]) if self.sampler is not None and "sampler_state_dict" in checkpoint: self.sampler.load_state_dict(checkpoint["sampler_state_dict"]) + self._restore_rng_state(checkpoint) + self._sync_data_epoch() @staticmethod def latest_checkpoint(checkpoint_dir: str | Path) -> Path | None: @@ -197,3 +204,76 @@ def latest_checkpoint(checkpoint_dir: str | Path) -> Path | None: if best is None or step > best[0]: best = (step, p) return best[1] if best is not None else None + + def _capture_rng_state(self) -> dict[str, Any]: + numpy_state = np.random.get_state() + state: dict[str, Any] = { + "python_rng_state": random.getstate(), + "numpy_rng_state": { + "bit_generator": numpy_state[0], + "state": numpy_state[1].tolist(), + "pos": int(numpy_state[2]), + "has_gauss": int(numpy_state[3]), + "cached_gaussian": float(numpy_state[4]), + }, + "torch_rng_state": torch.get_rng_state(), + } + if torch.cuda.is_available(): + state["torch_cuda_rng_state_all"] = torch.cuda.get_rng_state_all() + return state + + def _restore_rng_state(self, checkpoint: dict[str, Any]) -> None: + if "python_rng_state" in checkpoint: + random.setstate(checkpoint["python_rng_state"]) + + numpy_state = checkpoint.get("numpy_rng_state") + if numpy_state is not None: + if isinstance(numpy_state, dict): + np_state = ( + numpy_state["bit_generator"], + np.asarray(numpy_state["state"], dtype=np.uint32), + int(numpy_state["pos"]), + int(numpy_state["has_gauss"]), + float(numpy_state["cached_gaussian"]), + ) + else: + # Backward-compatible fallback for tuple/list payloads. + np_state = ( + numpy_state[0], + np.asarray(numpy_state[1], dtype=np.uint32), + int(numpy_state[2]), + int(numpy_state[3]), + float(numpy_state[4]), + ) + np.random.set_state(np_state) + + if "torch_rng_state" in checkpoint: + torch_rng_state = checkpoint["torch_rng_state"] + if isinstance(torch_rng_state, torch.Tensor): + torch_rng_state = torch_rng_state.detach().to( + device="cpu", dtype=torch.uint8 + ) + torch.set_rng_state(torch_rng_state) + + if torch.cuda.is_available() and "torch_cuda_rng_state_all" in checkpoint: + cuda_states = checkpoint["torch_cuda_rng_state_all"] + normalized_states = [] + for state in cuda_states: + if isinstance(state, torch.Tensor): + normalized_states.append( + state.detach().to(device="cpu", dtype=torch.uint8) + ) + else: + normalized_states.append(state) + torch.cuda.set_rng_state_all(normalized_states) + + def _sync_data_epoch(self) -> None: + if self.sampler is None: + return + epoch = self.sampler.epoch + dataset = getattr(self.loader, "dataset", None) + if dataset is not None and hasattr(dataset, "set_epoch"): + dataset.set_epoch(epoch) + collate_fn = getattr(self.loader, "collate_fn", None) + if collate_fn is not None and hasattr(collate_fn, "set_epoch"): + collate_fn.set_epoch(epoch) diff --git a/tests/data/collators/test_mlm.py b/tests/data/collators/test_mlm.py index bd18b0c..27d16be 100644 --- a/tests/data/collators/test_mlm.py +++ b/tests/data/collators/test_mlm.py @@ -38,8 +38,13 @@ def collator(mlm_config: MLMCollatorConfig) -> MLMCollator: return MLMCollator(mlm_config) -def make_sample(token_ids: list[int]) -> dict[str, torch.Tensor]: - return {"input_ids": torch.tensor(token_ids, dtype=torch.long)} +def make_sample( + token_ids: list[int], *, sample_idx: int = 0 +) -> dict[str, torch.Tensor]: + return { + "input_ids": torch.tensor(token_ids, dtype=torch.long), + "sample_idx": torch.tensor(sample_idx, dtype=torch.int64), + } # --------------------------------------------------------------------------- @@ -127,7 +132,6 @@ def test_special_tokens_never_masked(self, collator: MLMCollator) -> None: assert (labels == LABEL_IGNORE_ID).all().item() def test_labels_minus_100_for_unmasked(self, collator: MLMCollator) -> None: - torch.manual_seed(0) input_ids = self._make_input([CLS_ID, 10, 20, 30, 40, SEP_ID, PAD_ID, PAD_ID]) _, labels = collator._mask_tokens(input_ids) # Positions not replaced by [MASK] must have label LABEL_IGNORE_ID @@ -136,6 +140,31 @@ def test_labels_minus_100_for_unmasked(self, collator: MLMCollator) -> None: if not masked_positions[0, pos]: assert labels[0, pos].item() == LABEL_IGNORE_ID + def test_deterministic_for_same_sample_and_epoch( + self, collator: MLMCollator + ) -> None: + input_ids = self._make_input([CLS_ID] + list(range(10, 74)) + [SEP_ID]) + sample_idx = torch.tensor([123], dtype=torch.int64) + + masked_1, labels_1 = collator._mask_tokens(input_ids, sample_idx) + masked_2, labels_2 = collator._mask_tokens(input_ids, sample_idx) + + assert torch.equal(masked_1, masked_2) + assert torch.equal(labels_1, labels_2) + + def test_epoch_changes_mask_pattern(self, collator: MLMCollator) -> None: + input_ids = self._make_input([CLS_ID] + list(range(10, 266)) + [SEP_ID]) + sample_idx = torch.tensor([55], dtype=torch.int64) + + collator.set_epoch(0) + _, labels_epoch0 = collator._mask_tokens(input_ids, sample_idx) + collator.set_epoch(1) + _, labels_epoch1 = collator._mask_tokens(input_ids, sample_idx) + + masked0 = labels_epoch0 != LABEL_IGNORE_ID + masked1 = labels_epoch1 != LABEL_IGNORE_ID + assert not torch.equal(masked0, masked1) + def test_masked_positions_get_mask_token(self, collator: MLMCollator) -> None: # Force all eligible tokens to be masked high_prob_collator = MLMCollator( @@ -174,7 +203,6 @@ def test_labels_preserve_original_for_masked(self, collator: MLMCollator) -> Non assert labels[0, 2].item() == 20 def test_masking_probability(self, collator: MLMCollator) -> None: - torch.manual_seed(42) # Large sequence of eligible tokens to get stable statistics eligible_tokens = list(range(10, 90)) # 80 tokens, none are special input_ids = self._make_input([CLS_ID] + eligible_tokens + [SEP_ID]) diff --git a/tests/data/datasets/test_flat_tokens.py b/tests/data/datasets/test_flat_tokens.py index c7d83a0..6f50383 100644 --- a/tests/data/datasets/test_flat_tokens.py +++ b/tests/data/datasets/test_flat_tokens.py @@ -182,6 +182,7 @@ def test_getitem_returns_correct_tokens(self, tmp_path: Path) -> None: s, e = i * 512, (i + 1) * 512 expected = tokens[s:e].astype(np.int64) assert torch.equal(batch["input_ids"], torch.from_numpy(expected)) + assert batch["sample_idx"].item() == i def test_getitem_returns_correct_shape_and_dtype(self, tmp_path: Path) -> None: create_shard(tmp_path / "fineweb_train_000000.bin", num_tokens=1024) @@ -192,6 +193,7 @@ def test_getitem_returns_correct_shape_and_dtype(self, tmp_path: Path) -> None: batch = ds[0] assert batch["input_ids"].shape == (128,) assert batch["input_ids"].dtype == torch.int32 + assert batch["sample_idx"].dtype == torch.int64 def test_getitem_cross_shard_boundary(self, tmp_path: Path) -> None: """Sequences near shard boundaries should map to the correct shard.""" @@ -265,3 +267,5 @@ def test_works_with_dataloader(self, tmp_path: Path) -> None: assert batch["input_ids"].shape == (4, 128) assert batch["input_ids"].dtype == torch.int32 + assert batch["sample_idx"].shape == (4,) + assert batch["sample_idx"].dtype == torch.int64 diff --git a/tests/data/test_loader.py b/tests/data/test_loader.py index 97f9bf5..ffd5201 100644 --- a/tests/data/test_loader.py +++ b/tests/data/test_loader.py @@ -2,6 +2,9 @@ from __future__ import annotations +import random + +import numpy as np import pytest import torch from torch import Tensor @@ -72,6 +75,33 @@ def test_num_workers_zero_ignores_prefetch_factor(self) -> None: ) batch = next(iter(loader)) assert batch["input_ids"].shape == (4,) + assert loader.generator is not None + + def test_worker_init_fn_is_set_when_num_workers_positive(self) -> None: + ds = DummyDataset() + loader = create_dataloader( + ds, + collator, + DataLoaderConfig(num_workers=2, batch_size=4, seed=123), + ) + assert loader.worker_init_fn is not None + + def test_worker_init_fn_seeds_rng_deterministically(self) -> None: + ds = DummyDataset() + loader = create_dataloader( + ds, + collator, + DataLoaderConfig(num_workers=1, batch_size=4, seed=9876), + ) + assert loader.worker_init_fn is not None + + loader.worker_init_fn(3) + first = (random.random(), float(np.random.rand()), float(torch.rand(()))) + + loader.worker_init_fn(3) + second = (random.random(), float(np.random.rand()), float(torch.rand(()))) + + assert first == second class TestCreateDistributedDataLoader: diff --git a/tests/data/test_sampler.py b/tests/data/test_sampler.py index 977d53b..20271cc 100644 --- a/tests/data/test_sampler.py +++ b/tests/data/test_sampler.py @@ -2,6 +2,8 @@ from __future__ import annotations +import pytest + from embedding_trainer.data.sampler import ResumableSampler @@ -106,12 +108,14 @@ def test_epoch_advances_after_start_new_epoch(self) -> None: s = ResumableSampler(ds, seed=0) assert s.state_dict()["epoch"] == 0 + assert s.epoch == 0 # Exhaust epoch 0 and explicitly advance _ = list(s) s.start_new_epoch() assert s.state_dict()["epoch"] == 1 + assert s.epoch == 1 assert s.state_dict()["index"] == 0 def test_cross_epoch_permutations_differ(self) -> None: @@ -126,3 +130,15 @@ def test_cross_epoch_permutations_differ(self) -> None: assert epoch0 != epoch1 assert sorted(epoch0) == sorted(epoch1) + + def test_advance_rejects_negative_values(self) -> None: + ds = _FakeDataset(10) + s = ResumableSampler(ds, seed=0) + with pytest.raises(ValueError, match="non-negative"): + s.advance(-1) + + def test_advance_rejects_past_end_of_epoch(self) -> None: + ds = _FakeDataset(10) + s = ResumableSampler(ds, seed=0) + with pytest.raises(ValueError, match="past end of epoch"): + s.advance(11) diff --git a/tests/training/test_trainer.py b/tests/training/test_trainer.py new file mode 100644 index 0000000..86754bc --- /dev/null +++ b/tests/training/test_trainer.py @@ -0,0 +1,441 @@ +"""Tests for SimpleTrainer checkpoint determinism behavior.""" + +from __future__ import annotations + +import random +from pathlib import Path + +import numpy as np +import pytest +import torch +from torch import nn +from torch.utils.data import Dataset + +from embedding_trainer.core import BaseCallback, ModelOutput, TaskOutput +from embedding_trainer.core.base_task import BaseTask +from embedding_trainer.data.collators.mlm import MLMCollator, MLMCollatorConfig +from embedding_trainer.data.loader import DataLoaderConfig, create_dataloader +from embedding_trainer.data.sampler import ResumableSampler +from embedding_trainer.tasks.mlm import MaskedLanguageModelingTask +from embedding_trainer.training.trainer import SimpleTrainer + + +class _EpochDataset(Dataset[dict[str, torch.Tensor]]): + def __init__(self, size: int = 16) -> None: + self._size = size + self.epoch_calls: list[int] = [] + + def __len__(self) -> int: + return self._size + + def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: + return { + "input_ids": torch.tensor([float(idx)], dtype=torch.float32), + "attention_mask": torch.tensor([1], dtype=torch.long), + "labels": torch.tensor([0], dtype=torch.long), + } + + def set_epoch(self, epoch: int) -> None: + self.epoch_calls.append(epoch) + + +class _EpochCollator: + def __init__(self) -> None: + self.epoch_calls: list[int] = [] + + def __call__( + self, samples: list[dict[str, torch.Tensor]] + ) -> dict[str, torch.Tensor]: + return { + "input_ids": torch.stack([sample["input_ids"] for sample in samples]), + "attention_mask": torch.stack( + [sample["attention_mask"] for sample in samples] + ), + "labels": torch.stack([sample["labels"] for sample in samples]), + } + + def set_epoch(self, epoch: int) -> None: + self.epoch_calls.append(epoch) + + +class _DummyTask(BaseTask): + def compute_loss( + self, + model: torch.nn.Module, + batch: dict[str, torch.Tensor], + device: str = "cpu", + ) -> TaskOutput: + inputs = batch["input_ids"].to(device) + predictions = model(inputs) + loss = predictions.mean() + return TaskOutput(loss=loss) + + def get_metrics(self) -> dict[str, float]: + return {} + + +class _LossRecorder(BaseCallback): + def __init__(self) -> None: + self.losses: list[float] = [] + + def on_step_end(self, trainer: SimpleTrainer, step: int, loss: float) -> None: + self.losses.append(loss) + + +class _SyntheticMLMDataset(Dataset[dict[str, torch.Tensor]]): + def __init__(self, size: int = 128, seq_len: int = 16) -> None: + self._size = size + self._seq_len = seq_len + self.epoch = 0 + + def __len__(self) -> int: + return self._size + + def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: + values = [((idx * 37 + i * 13 + self.epoch * 17) % 50) + 4 for i in range(14)] + tokens = [1] + values + [2] + return { + "input_ids": torch.tensor(tokens, dtype=torch.long), + "sample_idx": torch.tensor(idx, dtype=torch.int64), + } + + def set_epoch(self, epoch: int) -> None: + self.epoch = epoch + + +class _TinyMLMModel(nn.Module): + def __init__(self, vocab_size: int = 64, hidden_size: int = 16) -> None: + super().__init__() + self.embedding = nn.Embedding(vocab_size, hidden_size) + self.dropout = nn.Dropout(0.2) + self.proj = nn.Linear(hidden_size, vocab_size) + + def forward( + self, input_ids: torch.Tensor, attention_mask: torch.Tensor | None = None + ) -> ModelOutput: + del attention_mask + x = self.embedding(input_ids) + x = self.dropout(x) + logits = self.proj(x) + return ModelOutput(logits=logits) + + +def _make_trainer( + tmp_path: Path, + *, + dataset: _EpochDataset, + collator: _EpochCollator, + sampler: ResumableSampler, +) -> SimpleTrainer: + loader = create_dataloader( + dataset=dataset, + collator=collator, + config=DataLoaderConfig(batch_size=4, num_workers=0, shuffle=False), + sampler=sampler, + ) + model = torch.nn.Linear(1, 1, bias=False) + optimizer = torch.optim.SGD(model.parameters(), lr=0.1) + return SimpleTrainer( + model=model, + task=_DummyTask(), + loader=loader, + optimizer=optimizer, + device="cpu", + max_steps=1, + checkpoint_dir=tmp_path, + save_every=None, + sampler=sampler, + ) + + +def _make_mlm_trainer( + checkpoint_dir: Path, + *, + max_steps: int, + seed: int, + num_workers: int, + save_every: int | None = None, + device: str = "cpu", +) -> tuple[SimpleTrainer, _LossRecorder]: + dataset = _SyntheticMLMDataset() + sampler = ResumableSampler(dataset, seed=seed) + collator = MLMCollator( + MLMCollatorConfig( + pad_token_id=0, + cls_token_id=1, + sep_token_id=2, + mask_token_id=3, + vocab_size=64, + mlm_probability=0.30, + seed=seed, + epoch=0, + ) + ) + collator.set_epoch(sampler.epoch) + loader = create_dataloader( + dataset=dataset, + collator=collator, + config=DataLoaderConfig( + batch_size=8, + num_workers=num_workers, + prefetch_factor=2, + drop_last=True, + shuffle=False, + seed=seed, + ), + sampler=sampler, + ) + model = _TinyMLMModel(vocab_size=64, hidden_size=16) + model.to(device) + optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3) + scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: 1.0) + recorder = _LossRecorder() + trainer = SimpleTrainer( + model=model, + task=MaskedLanguageModelingTask(), + loader=loader, + optimizer=optimizer, + device=device, + max_steps=max_steps, + callbacks=[recorder], + scheduler=scheduler, + checkpoint_dir=checkpoint_dir, + save_every=save_every, + sampler=sampler, + ) + return trainer, recorder + + +def _clone_state_dict(state: dict[str, object]) -> dict[str, object]: + cloned: dict[str, object] = {} + for key, value in state.items(): + if isinstance(value, torch.Tensor): + cloned[key] = value.detach().cpu().clone() + elif isinstance(value, dict): + cloned[key] = _clone_state_dict(value) + elif isinstance(value, list): + out: list[object] = [] + for item in value: + if isinstance(item, torch.Tensor): + out.append(item.detach().cpu().clone()) + elif isinstance(item, dict): + out.append(_clone_state_dict(item)) + else: + out.append(item) + cloned[key] = out + else: + cloned[key] = value + return cloned + + +def _assert_nested_equal(left: object, right: object) -> None: + if isinstance(left, torch.Tensor) and isinstance(right, torch.Tensor): + assert left.dtype == right.dtype + assert left.shape == right.shape + assert torch.equal(left.detach().cpu(), right.detach().cpu()) + return + if isinstance(left, dict) and isinstance(right, dict): + assert left.keys() == right.keys() + for key in left: + _assert_nested_equal(left[key], right[key]) + return + if isinstance(left, list) and isinstance(right, list): + assert len(left) == len(right) + for li, ri in zip(left, right, strict=True): + _assert_nested_equal(li, ri) + return + assert left == right + + +class TestSimpleTrainerCheckpoint: + def test_load_checkpoint_restores_rng_states(self, tmp_path: Path) -> None: + random.seed(111) + np.random.seed(222) + torch.manual_seed(333) + + dataset = _EpochDataset() + collator = _EpochCollator() + sampler = ResumableSampler(dataset, seed=7) + trainer = _make_trainer( + tmp_path, + dataset=dataset, + collator=collator, + sampler=sampler, + ) + + python_state = random.getstate() + numpy_state = np.random.get_state() + torch_state = torch.get_rng_state() + + py_rng = random.Random() + py_rng.setstate(python_state) + expected_python = py_rng.random() + + np_rng = np.random.RandomState() + np_rng.set_state(numpy_state) + expected_numpy = float(np_rng.rand()) + + torch_rng = torch.Generator() + torch_rng.set_state(torch_state.clone()) + expected_torch = float(torch.rand((), generator=torch_rng)) + + ckpt = tmp_path / "rng.pt" + trainer.save_checkpoint(ckpt) + + random.seed(999) + np.random.seed(999) + torch.manual_seed(999) + + trainer.load_checkpoint(ckpt) + assert random.random() == expected_python + assert float(np.random.rand()) == expected_numpy + assert float(torch.rand(())) == expected_torch + + def test_load_checkpoint_syncs_dataset_and_collator_epoch( + self, tmp_path: Path + ) -> None: + dataset = _EpochDataset() + collator = _EpochCollator() + sampler = ResumableSampler(dataset, seed=21) + sampler.start_new_epoch() + sampler.start_new_epoch() + + trainer = _make_trainer( + tmp_path, + dataset=dataset, + collator=collator, + sampler=sampler, + ) + ckpt = tmp_path / "epoch.pt" + trainer.save_checkpoint(ckpt) + + resumed_dataset = _EpochDataset() + resumed_collator = _EpochCollator() + resumed_sampler = ResumableSampler(resumed_dataset, seed=21) + resumed_trainer = _make_trainer( + tmp_path, + dataset=resumed_dataset, + collator=resumed_collator, + sampler=resumed_sampler, + ) + + resumed_trainer.load_checkpoint(ckpt) + + assert resumed_sampler.epoch == 2 + assert resumed_dataset.epoch_calls[-1] == 2 + assert resumed_collator.epoch_calls[-1] == 2 + + def test_resume_matches_uninterrupted_with_multiworker( + self, tmp_path: Path + ) -> None: + total_steps = 12 + split_step = 6 + seed = 2026 + num_workers = 2 + + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + full_dir = tmp_path / "full" + full_dir.mkdir(parents=True, exist_ok=True) + full_trainer, full_recorder = _make_mlm_trainer( + full_dir, + max_steps=total_steps, + seed=seed, + num_workers=num_workers, + save_every=split_step, + ) + full_trainer.train() + ckpt = full_dir / f"step_{split_step}.pt" + assert ckpt.exists() + full_model = _clone_state_dict(full_trainer.model.state_dict()) + full_opt = _clone_state_dict(full_trainer.optimizer.state_dict()) + full_sched = _clone_state_dict(full_trainer.scheduler.state_dict()) # type: ignore[union-attr] + full_sampler = full_trainer.sampler.state_dict() # type: ignore[union-attr] + + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + resume_dir = tmp_path / "resume" + resume_dir.mkdir(parents=True, exist_ok=True) + stage2_trainer, stage2_recorder = _make_mlm_trainer( + resume_dir, + max_steps=total_steps, + seed=seed, + num_workers=num_workers, + ) + stage2_trainer.load_checkpoint(ckpt) + stage2_trainer.train() + + assert stage2_recorder.losses == full_recorder.losses[split_step:] + _assert_nested_equal(stage2_trainer.model.state_dict(), full_model) + _assert_nested_equal(stage2_trainer.optimizer.state_dict(), full_opt) + _assert_nested_equal( + stage2_trainer.scheduler.state_dict(), # type: ignore[union-attr] + full_sched, + ) + assert stage2_trainer.sampler.state_dict() == full_sampler # type: ignore[union-attr] + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") + def test_resume_matches_uninterrupted_with_multiworker_cuda( + self, tmp_path: Path + ) -> None: + total_steps = 12 + split_step = 6 + seed = 2026 + num_workers = 2 + device = "cuda" + + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.allow_tf32 = False + + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + full_dir = tmp_path / "full_cuda" + full_dir.mkdir(parents=True, exist_ok=True) + full_trainer, full_recorder = _make_mlm_trainer( + full_dir, + max_steps=total_steps, + seed=seed, + num_workers=num_workers, + save_every=split_step, + device=device, + ) + full_trainer.train() + ckpt = full_dir / f"step_{split_step}.pt" + assert ckpt.exists() + full_model = _clone_state_dict(full_trainer.model.state_dict()) + full_opt = _clone_state_dict(full_trainer.optimizer.state_dict()) + full_sched = _clone_state_dict(full_trainer.scheduler.state_dict()) # type: ignore[union-attr] + full_sampler = full_trainer.sampler.state_dict() # type: ignore[union-attr] + + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + resume_dir = tmp_path / "resume_cuda" + resume_dir.mkdir(parents=True, exist_ok=True) + stage2_trainer, stage2_recorder = _make_mlm_trainer( + resume_dir, + max_steps=total_steps, + seed=seed, + num_workers=num_workers, + device=device, + ) + stage2_trainer.load_checkpoint(ckpt) + stage2_trainer.train() + + assert stage2_recorder.losses == full_recorder.losses[split_step:] + _assert_nested_equal(stage2_trainer.model.state_dict(), full_model) + _assert_nested_equal(stage2_trainer.optimizer.state_dict(), full_opt) + _assert_nested_equal( + stage2_trainer.scheduler.state_dict(), # type: ignore[union-attr] + full_sched, + ) + assert stage2_trainer.sampler.state_dict() == full_sampler # type: ignore[union-attr]