Fix the issue of the code to allow bitwise determinism #7

Open
nosyndicate wants to merge 1 commit into main from force_determinism

Conversation

@nosyndicate
Owner

See title

Copilot AI review requested due to automatic review settings on February 24, 2026 06:20

Copilot AI left a comment


Pull request overview

This PR aims to make training/checkpoint resume behavior bitwise-deterministic by persisting RNG state, syncing epoch state across sampler/dataset/collator, and removing nondeterminism from multi-worker data loading and MLM masking.

Changes:

  • Capture/restore Python/NumPy/Torch (and CUDA) RNG state in SimpleTrainer checkpoints and sync dataset/collator epoch from the ResumableSampler.
  • Add deterministic worker seeding + a seeded DataLoader generator via DataLoaderConfig.seed.
  • Make MLM masking deterministic per (sample_idx, epoch, seed) and propagate sample_idx from datasets; add tests covering determinism/resume.
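
The masking scheme in the last bullet can be sketched as follows. This is a minimal illustration, not the PR's actual code: the helper name and the mixing constants are hypothetical, and only the idea of deriving each sample's mask from (seed, epoch, sample_idx) alone is taken from the PR description.

```python
import torch


def deterministic_mask(
    sample_idx: torch.Tensor,
    epoch: int,
    seed: int,
    seq_len: int,
    mlm_probability: float = 0.15,
) -> torch.Tensor:
    """Boolean MLM mask that depends only on (seed, epoch, sample_idx)."""
    masks = []
    for idx in sample_idx.tolist():
        # A private generator per sample keeps the mask independent of the
        # global RNG state and of batch composition; the mixing constants
        # below are arbitrary, not the PR's actual scheme.
        g = torch.Generator()
        g.manual_seed((seed * 1_000_003 + epoch * 7_919 + idx) % (2**63 - 1))
        masks.append(torch.rand(seq_len, generator=g) < mlm_probability)
    return torch.stack(masks)
```

Because the generator is re-seeded from the triple every call, the same sample gets the same mask on a resumed run, while a different epoch yields a fresh mask.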

Reviewed changes

Copilot reviewed 11 out of 11 changed files in this pull request and generated 7 comments.

File Description
src/embedding_trainer/training/trainer.py Saves/restores RNG state in checkpoints; syncs dataset/collator epoch from sampler during training/resume.
src/embedding_trainer/data/loader.py Adds seed to loader config; always attaches a generator; adds worker init seeding.
src/embedding_trainer/data/sampler.py Exposes epoch property; adds input validation to advance().
src/embedding_trainer/data/collators/mlm.py Makes masking deterministic using (seed, epoch, sample_idx); adds set_epoch.
src/embedding_trainer/data/datasets/flat_tokens.py Adds sample_idx to returned samples to stabilize masking across shuffling/resume.
scripts/train.py Wires training seed into collator and dataloaders; initializes collator epoch from sampler.
tests/training/test_trainer.py New end-to-end determinism tests for RNG restore, epoch sync, and resume equivalence (CPU/CUDA).
tests/data/test_loader.py Tests that generator/worker init seeding are configured and deterministic.
tests/data/test_sampler.py Adds tests for epoch property and advance() validation errors.
tests/data/datasets/test_flat_tokens.py Validates sample_idx presence/dtype/shape.
tests/data/collators/test_mlm.py Updates samples to include sample_idx; adds deterministic masking tests.


+ 1_013_904_223
) % modulus
x = (x * 48_271) % modulus
return x.to(torch.float64) / float(modulus)

Copilot AI Feb 24, 2026


_deterministic_uniform() converts to float64 and divides to get probabilities. For “bitwise determinism” across devices/backends, consider avoiding floating-point here entirely (e.g., keep values as int64 and compare against an integer threshold derived from mlm_probability). This removes any dependency on float division/rounding differences.

Suggested change
return x.to(torch.float64) / float(modulus)
return x  # keep int64; compare against int(mlm_probability * modulus) at the call site
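
One way to realize the integer-threshold comparison the comment describes, as a sketch: `MODULUS`, `lcg_step`, and `integer_mask` are assumed names built from the LCG constants visible in the excerpt above, not the PR's actual API.

```python
import torch

MODULUS = 2_147_483_647  # 2**31 - 1; assumed from the LCG shown above


def lcg_step(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for the PR's integer LCG update.
    x = (x + 1_013_904_223) % MODULUS
    return (x * 48_271) % MODULUS


def integer_mask(x: torch.Tensor, mlm_probability: float) -> torch.Tensor:
    # Exact integer comparison: no float division or rounding is involved,
    # so the result is bit-identical on every device and backend.
    threshold = int(mlm_probability * MODULUS)
    return lcg_step(x) < threshold
```

Keeping everything in int64 until the final comparison removes the float64 division entirely, which is what "bitwise determinism across backends" actually requires.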

Comment on lines +389 to +441
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
full_dir = tmp_path / "full_cuda"
full_dir.mkdir(parents=True, exist_ok=True)
full_trainer, full_recorder = _make_mlm_trainer(
full_dir,
max_steps=total_steps,
seed=seed,
num_workers=num_workers,
save_every=split_step,
device=device,
)
full_trainer.train()
ckpt = full_dir / f"step_{split_step}.pt"
assert ckpt.exists()
full_model = _clone_state_dict(full_trainer.model.state_dict())
full_opt = _clone_state_dict(full_trainer.optimizer.state_dict())
full_sched = _clone_state_dict(full_trainer.scheduler.state_dict()) # type: ignore[union-attr]
full_sampler = full_trainer.sampler.state_dict() # type: ignore[union-attr]

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
resume_dir = tmp_path / "resume_cuda"
resume_dir.mkdir(parents=True, exist_ok=True)
stage2_trainer, stage2_recorder = _make_mlm_trainer(
resume_dir,
max_steps=total_steps,
seed=seed,
num_workers=num_workers,
device=device,
)
stage2_trainer.load_checkpoint(ckpt)
stage2_trainer.train()

assert stage2_recorder.losses == full_recorder.losses[split_step:]
_assert_nested_equal(stage2_trainer.model.state_dict(), full_model)
_assert_nested_equal(stage2_trainer.optimizer.state_dict(), full_opt)
_assert_nested_equal(
stage2_trainer.scheduler.state_dict(), # type: ignore[union-attr]
full_sched,
)
assert stage2_trainer.sampler.state_dict() == full_sampler # type: ignore[union-attr]

Copilot AI Feb 24, 2026


This test mutates global CUDA backend flags (cudnn deterministic/benchmark and TF32 toggles) and global RNG seeds without restoring them, which can leak into other tests in the same session. Consider saving the previous values/states and restoring them in a finally block (or via a fixture) to keep tests isolated.

Suggested change
# Save previous backend flags and RNG states to restore after the test.
prev_cudnn_deterministic = torch.backends.cudnn.deterministic
prev_cudnn_benchmark = torch.backends.cudnn.benchmark
prev_cuda_matmul_allow_tf32 = torch.backends.cuda.matmul.allow_tf32
prev_cudnn_allow_tf32 = torch.backends.cudnn.allow_tf32
python_state = random.getstate()
numpy_state = np.random.get_state()
torch_state = torch.get_rng_state()
cuda_states = torch.cuda.get_rng_state_all()
try:
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
full_dir = tmp_path / "full_cuda"
full_dir.mkdir(parents=True, exist_ok=True)
full_trainer, full_recorder = _make_mlm_trainer(
full_dir,
max_steps=total_steps,
seed=seed,
num_workers=num_workers,
save_every=split_step,
device=device,
)
full_trainer.train()
ckpt = full_dir / f"step_{split_step}.pt"
assert ckpt.exists()
full_model = _clone_state_dict(full_trainer.model.state_dict())
full_opt = _clone_state_dict(full_trainer.optimizer.state_dict())
full_sched = _clone_state_dict(full_trainer.scheduler.state_dict()) # type: ignore[union-attr]
full_sampler = full_trainer.sampler.state_dict() # type: ignore[union-attr]
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
resume_dir = tmp_path / "resume_cuda"
resume_dir.mkdir(parents=True, exist_ok=True)
stage2_trainer, stage2_recorder = _make_mlm_trainer(
resume_dir,
max_steps=total_steps,
seed=seed,
num_workers=num_workers,
device=device,
)
stage2_trainer.load_checkpoint(ckpt)
stage2_trainer.train()
assert stage2_recorder.losses == full_recorder.losses[split_step:]
_assert_nested_equal(stage2_trainer.model.state_dict(), full_model)
_assert_nested_equal(stage2_trainer.optimizer.state_dict(), full_opt)
_assert_nested_equal(
stage2_trainer.scheduler.state_dict(), # type: ignore[union-attr]
full_sched,
)
assert stage2_trainer.sampler.state_dict() == full_sampler # type: ignore[union-attr]
finally:
# Restore backend flags and RNG states to avoid leaking into other tests.
torch.backends.cudnn.deterministic = prev_cudnn_deterministic
torch.backends.cudnn.benchmark = prev_cudnn_benchmark
torch.backends.cuda.matmul.allow_tf32 = prev_cuda_matmul_allow_tf32
torch.backends.cudnn.allow_tf32 = prev_cudnn_allow_tf32
random.setstate(python_state)
np.random.set_state(numpy_state)
torch.set_rng_state(torch_state)
for device_idx, state in enumerate(cuda_states):
torch.cuda.set_rng_state(state, device_idx)
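
The save/restore pattern in this suggestion can also be factored into a reusable context manager that a pytest fixture could simply wrap. This is a sketch under the assumption that Python's `random`, NumPy, and Torch are the only RNGs in play; the name is hypothetical.

```python
import contextlib
import random

import numpy as np
import torch


@contextlib.contextmanager
def preserved_determinism_state():
    """Snapshot global backend flags and RNG states; restore them on exit
    so a test cannot leak settings into the rest of the session."""
    flags = (
        torch.backends.cudnn.deterministic,
        torch.backends.cudnn.benchmark,
        torch.backends.cuda.matmul.allow_tf32,
        torch.backends.cudnn.allow_tf32,
    )
    py_state = random.getstate()
    np_state = np.random.get_state()
    torch_state = torch.get_rng_state()
    cuda_states = torch.cuda.get_rng_state_all() if torch.cuda.is_available() else []
    try:
        yield
    finally:
        (torch.backends.cudnn.deterministic,
         torch.backends.cudnn.benchmark,
         torch.backends.cuda.matmul.allow_tf32,
         torch.backends.cudnn.allow_tf32) = flags
        random.setstate(py_state)
        np.random.set_state(np_state)
        torch.set_rng_state(torch_state)
        for device_idx, state in enumerate(cuda_states):
            torch.cuda.set_rng_state(state, device_idx)
```

A `with preserved_determinism_state():` around the test body then replaces the long inline try/finally, and several tests can share it.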

Comment on lines 93 to +98
self._start_index += n
if self._start_index > self._num_samples:
raise ValueError(
"ResumableSampler.advance moved past end of epoch "
f"({self._start_index} > {self._num_samples})."
)

Copilot AI Feb 24, 2026


ResumableSampler.advance() mutates _start_index before validating the new value. If this raises (e.g., n too large), the sampler is left in an invalid state, which can make debugging and error recovery harder. Consider validating the proposed new index before assignment (or rolling back on error).

Suggested change
self._start_index += n
if self._start_index > self._num_samples:
raise ValueError(
"ResumableSampler.advance moved past end of epoch "
f"({self._start_index} > {self._num_samples})."
)
new_index = self._start_index + n
if new_index > self._num_samples:
raise ValueError(
"ResumableSampler.advance moved past end of epoch "
f"({new_index} > {self._num_samples})."
)
self._start_index = new_index

Comment on lines 19 to 63
@dataclass
class DataLoaderConfig:
"""Configuration for DataLoader."""

batch_size: int = 32
num_workers: int = 4
prefetch_factor: int = 2
pin_memory: bool = True
drop_last: bool = True
shuffle: bool = True
seed: int = 0

def __post_init__(self) -> None:
if self.batch_size < 1:
raise ValueError(f"batch_size must be >= 1, got {self.batch_size}")
if self.num_workers < 0:
raise ValueError(f"num_workers must be >= 0, got {self.num_workers}")
if self.num_workers > 0 and self.prefetch_factor < 1:
raise ValueError(
"prefetch_factor must be >= 1 when num_workers > 0, "
f"got {self.prefetch_factor}"
)


def _build_loader_kwargs(
config: DataLoaderConfig,
collator: CollatorProtocol | Callable[[list[dict[str, Any]]], Any],
) -> dict[str, Any]:
"""Build common DataLoader kwargs."""
loader_generator = torch.Generator()
loader_generator.manual_seed(config.seed)

loader_kwargs: dict[str, Any] = {
"batch_size": config.batch_size,
"collate_fn": collator,
"num_workers": config.num_workers,
"pin_memory": config.pin_memory and torch.cuda.is_available(),
"drop_last": config.drop_last,
"generator": loader_generator,
}

if config.num_workers > 0:
loader_kwargs["prefetch_factor"] = config.prefetch_factor
loader_kwargs["worker_init_fn"] = _build_worker_init_fn(config.seed)


Copilot AI Feb 24, 2026


DataLoaderConfig.seed defaults to 0 and _build_loader_kwargs always attaches a seeded torch.Generator. This changes default behavior: callers that previously relied on non-deterministic shuffling (by not providing a seed) will now get the same shuffle order every run. Consider making seed Optional (default None) and only passing a generator/worker_init_fn when a seed is explicitly provided; also ensure the same seed is applied consistently in distributed sampling (DistributedSampler has its own seed parameter).

Comment on lines +62 to +76
loader_kwargs["worker_init_fn"] = _build_worker_init_fn(config.seed)

return loader_kwargs


def _seed_worker(worker_id: int, seed: int) -> None:
worker_seed = seed + worker_id
random.seed(worker_seed)
np.random.seed(worker_seed % (2**32))
torch.manual_seed(worker_seed)


def _build_worker_init_fn(seed: int) -> Callable[[int], None]:
"""Build a worker seeding function for deterministic worker-local RNG."""
return partial(_seed_worker, seed=seed)

Copilot AI Feb 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The custom worker_init_fn seeds workers with seed + worker_id, which is identical across epochs/iterations and also ignores DataLoader’s base seed (derived from the loader generator). This can unintentionally make any per-sample randomness inside datasets/collators repeat every epoch. A more deterministic + epoch-varying approach is to derive the seed from torch.initial_seed() inside the worker (then seed Python/NumPy from that).

Suggested change
loader_kwargs["worker_init_fn"] = _build_worker_init_fn()
return loader_kwargs
def _seed_worker(worker_id: int) -> None:
"""Seed Python, NumPy, and PyTorch RNGs for a worker based on torch.initial_seed()."""
worker_seed = torch.initial_seed()
random.seed(worker_seed)
np.random.seed(worker_seed % (2**32))
torch.manual_seed(worker_seed)
def _build_worker_init_fn() -> Callable[[int], None]:
"""Build a worker seeding function for deterministic worker-local RNG."""
return _seed_worker

Comment on lines 159 to 167
def save_checkpoint(self, path: str | Path) -> None:
checkpoint = {
"model_state_dict": self.model.state_dict(),
"optimizer_state_dict": self.optimizer.state_dict(),
"global_step": self.global_step,
"cumulative_loss_sum": self.cumulative_loss_sum,
"cumulative_loss_steps": self.cumulative_loss_steps,
**self._capture_rng_state(),
}

Copilot AI Feb 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

save_checkpoint()/load_checkpoint restore Python/NumPy/Torch RNG, but they don’t capture/restore the DataLoader generator state (DataLoaderConfig now always attaches one). If training uses DataLoader shuffling without ResumableSampler, resuming from a checkpoint can still diverge in data order. Consider persisting loader.generator.get_state() (when present) and restoring it on load to fully support bitwise-deterministic resume.

Comment on lines +208 to +223
def _capture_rng_state(self) -> dict[str, Any]:
numpy_state = np.random.get_state()
state: dict[str, Any] = {
"python_rng_state": random.getstate(),
"numpy_rng_state": {
"bit_generator": numpy_state[0],
"state": numpy_state[1].tolist(),
"pos": int(numpy_state[2]),
"has_gauss": int(numpy_state[3]),
"cached_gaussian": float(numpy_state[4]),
},
"torch_rng_state": torch.get_rng_state(),
}
if torch.cuda.is_available():
state["torch_cuda_rng_state_all"] = torch.cuda.get_rng_state_all()
return state

Copilot AI Feb 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_capture_rng_state() calls torch.cuda.get_rng_state_all() whenever CUDA is available, even if the trainer/device is CPU. This can initialize a CUDA context and introduce overhead or failures in CPU-only workflows on GPU hosts. Consider capturing/restoring CUDA RNG only when actually training on CUDA (e.g., self.device starts with 'cuda' or model parameters are on CUDA).
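
A guarded capture along the lines this comment suggests could look like this sketch; the helper name is hypothetical.

```python
from typing import Any

import torch


def capture_cuda_rng_if_used(device: str) -> dict[str, Any]:
    """Capture CUDA RNG state only when training actually runs on CUDA, so
    a CPU-only run on a GPU host never initializes a CUDA context."""
    state: dict[str, Any] = {}
    if device.startswith("cuda") and torch.cuda.is_available():
        state["torch_cuda_rng_state_all"] = torch.cuda.get_rng_state_all()
    return state
```

The symmetric restore path would apply the same `device.startswith("cuda")` guard before calling `torch.cuda.set_rng_state`.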
