Fix the issue of the code to allow bit-wise determinism #7
Fix the issue of the code to allow bit-wise determinism #7 — nosyndicate wants to merge 1 commit into main from
Conversation
There was a problem hiding this comment.
Pull request overview
This PR aims to make training/checkpoint resume behavior bitwise-deterministic by persisting RNG state, syncing epoch state across sampler/dataset/collator, and removing nondeterminism from multi-worker data loading and MLM masking.
Changes:
- Capture/restore Python/NumPy/Torch (and CUDA) RNG state in
`SimpleTrainer` checkpoints and sync dataset/collator epoch from the `ResumableSampler`. - Add deterministic worker seeding + a seeded DataLoader generator via
`DataLoaderConfig.seed`. - Make MLM masking deterministic per
`(sample_idx, epoch, seed)` and propagate `sample_idx` from datasets; add tests covering determinism/resume.
Reviewed changes
Copilot reviewed 11 out of 11 changed files in this pull request and generated 7 comments.
Show a summary per file
| File | Description |
|---|---|
src/embedding_trainer/training/trainer.py |
Saves/restores RNG state in checkpoints; syncs dataset/collator epoch from sampler during training/resume. |
src/embedding_trainer/data/loader.py |
Adds seed to loader config; always attaches a generator; adds worker init seeding. |
src/embedding_trainer/data/sampler.py |
Exposes epoch property; adds input validation to advance(). |
src/embedding_trainer/data/collators/mlm.py |
Makes masking deterministic using (seed, epoch, sample_idx); adds set_epoch. |
src/embedding_trainer/data/datasets/flat_tokens.py |
Adds sample_idx to returned samples to stabilize masking across shuffling/resume. |
scripts/train.py |
Wires training seed into collator and dataloaders; initializes collator epoch from sampler. |
tests/training/test_trainer.py |
New end-to-end determinism tests for RNG restore, epoch sync, and resume equivalence (CPU/CUDA). |
tests/data/test_loader.py |
Tests that generator/worker init seeding are configured and deterministic. |
tests/data/test_sampler.py |
Adds tests for epoch property and advance() validation errors. |
tests/data/datasets/test_flat_tokens.py |
Validates sample_idx presence/dtype/shape. |
tests/data/collators/test_mlm.py |
Updates samples to include sample_idx; adds deterministic masking tests. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| + 1_013_904_223 | ||
| ) % modulus | ||
| x = (x * 48_271) % modulus | ||
| return x.to(torch.float64) / float(modulus) |
There was a problem hiding this comment.
_deterministic_uniform() converts to float64 and divides to get probabilities. For “bitwise determinism” across devices/backends, consider avoiding floating-point here entirely (e.g., keep values as int64 and compare against an integer threshold derived from mlm_probability). This removes any dependency on float division/rounding differences.
| return x.to(torch.float64) / float(modulus) | |
| inv_modulus = 1.0 / float(modulus) | |
| return x.to(torch.float64) * inv_modulus |
| torch.backends.cudnn.deterministic = True | ||
| torch.backends.cudnn.benchmark = False | ||
| torch.backends.cuda.matmul.allow_tf32 = False | ||
| torch.backends.cudnn.allow_tf32 = False | ||
|
|
||
| random.seed(seed) | ||
| np.random.seed(seed) | ||
| torch.manual_seed(seed) | ||
| torch.cuda.manual_seed(seed) | ||
| torch.cuda.manual_seed_all(seed) | ||
| full_dir = tmp_path / "full_cuda" | ||
| full_dir.mkdir(parents=True, exist_ok=True) | ||
| full_trainer, full_recorder = _make_mlm_trainer( | ||
| full_dir, | ||
| max_steps=total_steps, | ||
| seed=seed, | ||
| num_workers=num_workers, | ||
| save_every=split_step, | ||
| device=device, | ||
| ) | ||
| full_trainer.train() | ||
| ckpt = full_dir / f"step_{split_step}.pt" | ||
| assert ckpt.exists() | ||
| full_model = _clone_state_dict(full_trainer.model.state_dict()) | ||
| full_opt = _clone_state_dict(full_trainer.optimizer.state_dict()) | ||
| full_sched = _clone_state_dict(full_trainer.scheduler.state_dict()) # type: ignore[union-attr] | ||
| full_sampler = full_trainer.sampler.state_dict() # type: ignore[union-attr] | ||
|
|
||
| random.seed(seed) | ||
| np.random.seed(seed) | ||
| torch.manual_seed(seed) | ||
| torch.cuda.manual_seed(seed) | ||
| torch.cuda.manual_seed_all(seed) | ||
| resume_dir = tmp_path / "resume_cuda" | ||
| resume_dir.mkdir(parents=True, exist_ok=True) | ||
| stage2_trainer, stage2_recorder = _make_mlm_trainer( | ||
| resume_dir, | ||
| max_steps=total_steps, | ||
| seed=seed, | ||
| num_workers=num_workers, | ||
| device=device, | ||
| ) | ||
| stage2_trainer.load_checkpoint(ckpt) | ||
| stage2_trainer.train() | ||
|
|
||
| assert stage2_recorder.losses == full_recorder.losses[split_step:] | ||
| _assert_nested_equal(stage2_trainer.model.state_dict(), full_model) | ||
| _assert_nested_equal(stage2_trainer.optimizer.state_dict(), full_opt) | ||
| _assert_nested_equal( | ||
| stage2_trainer.scheduler.state_dict(), # type: ignore[union-attr] | ||
| full_sched, | ||
| ) | ||
| assert stage2_trainer.sampler.state_dict() == full_sampler # type: ignore[union-attr] |
There was a problem hiding this comment.
This test mutates global CUDA backend flags (cudnn deterministic/benchmark and TF32 toggles) and global RNG seeds without restoring them, which can leak into other tests in the same session. Consider saving the previous values/states and restoring them in a finally block (or via a fixture) to keep tests isolated.
| torch.backends.cudnn.deterministic = True | |
| torch.backends.cudnn.benchmark = False | |
| torch.backends.cuda.matmul.allow_tf32 = False | |
| torch.backends.cudnn.allow_tf32 = False | |
| random.seed(seed) | |
| np.random.seed(seed) | |
| torch.manual_seed(seed) | |
| torch.cuda.manual_seed(seed) | |
| torch.cuda.manual_seed_all(seed) | |
| full_dir = tmp_path / "full_cuda" | |
| full_dir.mkdir(parents=True, exist_ok=True) | |
| full_trainer, full_recorder = _make_mlm_trainer( | |
| full_dir, | |
| max_steps=total_steps, | |
| seed=seed, | |
| num_workers=num_workers, | |
| save_every=split_step, | |
| device=device, | |
| ) | |
| full_trainer.train() | |
| ckpt = full_dir / f"step_{split_step}.pt" | |
| assert ckpt.exists() | |
| full_model = _clone_state_dict(full_trainer.model.state_dict()) | |
| full_opt = _clone_state_dict(full_trainer.optimizer.state_dict()) | |
| full_sched = _clone_state_dict(full_trainer.scheduler.state_dict()) # type: ignore[union-attr] | |
| full_sampler = full_trainer.sampler.state_dict() # type: ignore[union-attr] | |
| random.seed(seed) | |
| np.random.seed(seed) | |
| torch.manual_seed(seed) | |
| torch.cuda.manual_seed(seed) | |
| torch.cuda.manual_seed_all(seed) | |
| resume_dir = tmp_path / "resume_cuda" | |
| resume_dir.mkdir(parents=True, exist_ok=True) | |
| stage2_trainer, stage2_recorder = _make_mlm_trainer( | |
| resume_dir, | |
| max_steps=total_steps, | |
| seed=seed, | |
| num_workers=num_workers, | |
| device=device, | |
| ) | |
| stage2_trainer.load_checkpoint(ckpt) | |
| stage2_trainer.train() | |
| assert stage2_recorder.losses == full_recorder.losses[split_step:] | |
| _assert_nested_equal(stage2_trainer.model.state_dict(), full_model) | |
| _assert_nested_equal(stage2_trainer.optimizer.state_dict(), full_opt) | |
| _assert_nested_equal( | |
| stage2_trainer.scheduler.state_dict(), # type: ignore[union-attr] | |
| full_sched, | |
| ) | |
| assert stage2_trainer.sampler.state_dict() == full_sampler # type: ignore[union-attr] | |
| # Save previous backend flags and RNG states to restore after the test. | |
| prev_cudnn_deterministic = torch.backends.cudnn.deterministic | |
| prev_cudnn_benchmark = torch.backends.cudnn.benchmark | |
| prev_cuda_matmul_allow_tf32 = torch.backends.cuda.matmul.allow_tf32 | |
| prev_cudnn_allow_tf32 = torch.backends.cudnn.allow_tf32 | |
| python_state = random.getstate() | |
| numpy_state = np.random.get_state() | |
| torch_state = torch.get_rng_state() | |
| cuda_states = torch.cuda.get_rng_state_all() | |
| try: | |
| torch.backends.cudnn.deterministic = True | |
| torch.backends.cudnn.benchmark = False | |
| torch.backends.cuda.matmul.allow_tf32 = False | |
| torch.backends.cudnn.allow_tf32 = False | |
| random.seed(seed) | |
| np.random.seed(seed) | |
| torch.manual_seed(seed) | |
| torch.cuda.manual_seed(seed) | |
| torch.cuda.manual_seed_all(seed) | |
| full_dir = tmp_path / "full_cuda" | |
| full_dir.mkdir(parents=True, exist_ok=True) | |
| full_trainer, full_recorder = _make_mlm_trainer( | |
| full_dir, | |
| max_steps=total_steps, | |
| seed=seed, | |
| num_workers=num_workers, | |
| save_every=split_step, | |
| device=device, | |
| ) | |
| full_trainer.train() | |
| ckpt = full_dir / f"step_{split_step}.pt" | |
| assert ckpt.exists() | |
| full_model = _clone_state_dict(full_trainer.model.state_dict()) | |
| full_opt = _clone_state_dict(full_trainer.optimizer.state_dict()) | |
| full_sched = _clone_state_dict(full_trainer.scheduler.state_dict()) # type: ignore[union-attr] | |
| full_sampler = full_trainer.sampler.state_dict() # type: ignore[union-attr] | |
| random.seed(seed) | |
| np.random.seed(seed) | |
| torch.manual_seed(seed) | |
| torch.cuda.manual_seed(seed) | |
| torch.cuda.manual_seed_all(seed) | |
| resume_dir = tmp_path / "resume_cuda" | |
| resume_dir.mkdir(parents=True, exist_ok=True) | |
| stage2_trainer, stage2_recorder = _make_mlm_trainer( | |
| resume_dir, | |
| max_steps=total_steps, | |
| seed=seed, | |
| num_workers=num_workers, | |
| device=device, | |
| ) | |
| stage2_trainer.load_checkpoint(ckpt) | |
| stage2_trainer.train() | |
| assert stage2_recorder.losses == full_recorder.losses[split_step:] | |
| _assert_nested_equal(stage2_trainer.model.state_dict(), full_model) | |
| _assert_nested_equal(stage2_trainer.optimizer.state_dict(), full_opt) | |
| _assert_nested_equal( | |
| stage2_trainer.scheduler.state_dict(), # type: ignore[union-attr] | |
| full_sched, | |
| ) | |
| assert stage2_trainer.sampler.state_dict() == full_sampler # type: ignore[union-attr] | |
| finally: | |
| # Restore backend flags and RNG states to avoid leaking into other tests. | |
| torch.backends.cudnn.deterministic = prev_cudnn_deterministic | |
| torch.backends.cudnn.benchmark = prev_cudnn_benchmark | |
| torch.backends.cuda.matmul.allow_tf32 = prev_cuda_matmul_allow_tf32 | |
| torch.backends.cudnn.allow_tf32 = prev_cudnn_allow_tf32 | |
| random.setstate(python_state) | |
| np.random.set_state(numpy_state) | |
| torch.set_rng_state(torch_state) | |
| for device_idx, state in enumerate(cuda_states): | |
| torch.cuda.set_rng_state(state, device_idx) |
| self._start_index += n | ||
| if self._start_index > self._num_samples: | ||
| raise ValueError( | ||
| "ResumableSampler.advance moved past end of epoch " | ||
| f"({self._start_index} > {self._num_samples})." | ||
| ) |
There was a problem hiding this comment.
ResumableSampler.advance() mutates _start_index before validating the new value. If this raises (e.g., n too large), the sampler is left in an invalid state, which can make debugging and error recovery harder. Consider validating the proposed new index before assignment (or rolling back on error).
| self._start_index += n | |
| if self._start_index > self._num_samples: | |
| raise ValueError( | |
| "ResumableSampler.advance moved past end of epoch " | |
| f"({self._start_index} > {self._num_samples})." | |
| ) | |
| new_index = self._start_index + n | |
| if new_index > self._num_samples: | |
| raise ValueError( | |
| "ResumableSampler.advance moved past end of epoch " | |
| f"({new_index} > {self._num_samples})." | |
| ) | |
| self._start_index = new_index |
| @dataclass | ||
| class DataLoaderConfig: | ||
| """Configuration for DataLoader.""" | ||
|
|
||
| batch_size: int = 32 | ||
| num_workers: int = 4 | ||
| prefetch_factor: int = 2 | ||
| pin_memory: bool = True | ||
| drop_last: bool = True | ||
| shuffle: bool = True | ||
| seed: int = 0 | ||
|
|
||
| def __post_init__(self) -> None: | ||
| if self.batch_size < 1: | ||
| raise ValueError(f"batch_size must be >= 1, got {self.batch_size}") | ||
| if self.num_workers < 0: | ||
| raise ValueError(f"num_workers must be >= 0, got {self.num_workers}") | ||
| if self.num_workers > 0 and self.prefetch_factor < 1: | ||
| raise ValueError( | ||
| "prefetch_factor must be >= 1 when num_workers > 0, " | ||
| f"got {self.prefetch_factor}" | ||
| ) | ||
|
|
||
|
|
||
| def _build_loader_kwargs( | ||
| config: DataLoaderConfig, | ||
| collator: CollatorProtocol | Callable[[list[dict[str, Any]]], Any], | ||
| ) -> dict[str, Any]: | ||
| """Build common DataLoader kwargs.""" | ||
| loader_generator = torch.Generator() | ||
| loader_generator.manual_seed(config.seed) | ||
|
|
||
| loader_kwargs: dict[str, Any] = { | ||
| "batch_size": config.batch_size, | ||
| "collate_fn": collator, | ||
| "num_workers": config.num_workers, | ||
| "pin_memory": config.pin_memory and torch.cuda.is_available(), | ||
| "drop_last": config.drop_last, | ||
| "generator": loader_generator, | ||
| } | ||
|
|
||
| if config.num_workers > 0: | ||
| loader_kwargs["prefetch_factor"] = config.prefetch_factor | ||
| loader_kwargs["worker_init_fn"] = _build_worker_init_fn(config.seed) | ||
|
|
There was a problem hiding this comment.
DataLoaderConfig.seed defaults to 0 and _build_loader_kwargs always attaches a seeded torch.Generator. This changes default behavior: callers that previously relied on non-deterministic shuffling (by not providing a seed) will now get the same shuffle order every run. Consider making seed Optional (default None) and only passing a generator/worker_init_fn when a seed is explicitly provided; also ensure the same seed is applied consistently in distributed sampling (DistributedSampler has its own seed parameter).
| loader_kwargs["worker_init_fn"] = _build_worker_init_fn(config.seed) | ||
|
|
||
| return loader_kwargs | ||
|
|
||
|
|
||
| def _seed_worker(worker_id: int, seed: int) -> None: | ||
| worker_seed = seed + worker_id | ||
| random.seed(worker_seed) | ||
| np.random.seed(worker_seed % (2**32)) | ||
| torch.manual_seed(worker_seed) | ||
|
|
||
|
|
||
| def _build_worker_init_fn(seed: int) -> Callable[[int], None]: | ||
| """Build a worker seeding function for deterministic worker-local RNG.""" | ||
| return partial(_seed_worker, seed=seed) |
There was a problem hiding this comment.
The custom worker_init_fn seeds workers with seed + worker_id, which is identical across epochs/iterations and also ignores DataLoader’s base seed (derived from the loader generator). This can unintentionally make any per-sample randomness inside datasets/collators repeat every epoch. A more deterministic + epoch-varying approach is to derive the seed from torch.initial_seed() inside the worker (then seed Python/NumPy from that).
| loader_kwargs["worker_init_fn"] = _build_worker_init_fn(config.seed) | |
| return loader_kwargs | |
| def _seed_worker(worker_id: int, seed: int) -> None: | |
| worker_seed = seed + worker_id | |
| random.seed(worker_seed) | |
| np.random.seed(worker_seed % (2**32)) | |
| torch.manual_seed(worker_seed) | |
| def _build_worker_init_fn(seed: int) -> Callable[[int], None]: | |
| """Build a worker seeding function for deterministic worker-local RNG.""" | |
| return partial(_seed_worker, seed=seed) | |
| loader_kwargs["worker_init_fn"] = _build_worker_init_fn() | |
| return loader_kwargs | |
| def _seed_worker(worker_id: int) -> None: | |
| """Seed Python, NumPy, and PyTorch RNGs for a worker based on torch.initial_seed().""" | |
| worker_seed = torch.initial_seed() | |
| random.seed(worker_seed) | |
| np.random.seed(worker_seed % (2**32)) | |
| torch.manual_seed(worker_seed) | |
| def _build_worker_init_fn() -> Callable[[int], None]: | |
| """Build a worker seeding function for deterministic worker-local RNG.""" | |
| return _seed_worker |
| def save_checkpoint(self, path: str | Path) -> None: | ||
| checkpoint = { | ||
| "model_state_dict": self.model.state_dict(), | ||
| "optimizer_state_dict": self.optimizer.state_dict(), | ||
| "global_step": self.global_step, | ||
| "cumulative_loss_sum": self.cumulative_loss_sum, | ||
| "cumulative_loss_steps": self.cumulative_loss_steps, | ||
| **self._capture_rng_state(), | ||
| } |
There was a problem hiding this comment.
save_checkpoint()/load_checkpoint restore Python/NumPy/Torch RNG, but they don’t capture/restore the DataLoader generator state (DataLoaderConfig now always attaches one). If training uses DataLoader shuffling without ResumableSampler, resuming from a checkpoint can still diverge in data order. Consider persisting loader.generator.get_state() (when present) and restoring it on load to fully support bitwise-deterministic resume.
| def _capture_rng_state(self) -> dict[str, Any]: | ||
| numpy_state = np.random.get_state() | ||
| state: dict[str, Any] = { | ||
| "python_rng_state": random.getstate(), | ||
| "numpy_rng_state": { | ||
| "bit_generator": numpy_state[0], | ||
| "state": numpy_state[1].tolist(), | ||
| "pos": int(numpy_state[2]), | ||
| "has_gauss": int(numpy_state[3]), | ||
| "cached_gaussian": float(numpy_state[4]), | ||
| }, | ||
| "torch_rng_state": torch.get_rng_state(), | ||
| } | ||
| if torch.cuda.is_available(): | ||
| state["torch_cuda_rng_state_all"] = torch.cuda.get_rng_state_all() | ||
| return state |
There was a problem hiding this comment.
_capture_rng_state() calls torch.cuda.get_rng_state_all() whenever CUDA is available, even if the trainer/device is CPU. This can initialize a CUDA context and introduce overhead or failures in CPU-only workflows on GPU hosts. Consider capturing/restoring CUDA RNG only when actually training on CUDA (e.g., self.device starts with 'cuda' or model parameters are on CUDA).
See title