-
Notifications
You must be signed in to change notification settings - Fork 0
Fix the issue of the code to allow bit wise determinism #7
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -2,10 +2,13 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from __future__ import annotations | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import random | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from collections.abc import Callable | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from dataclasses import dataclass | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from functools import partial | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from typing import Any | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import numpy as np | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from torch.utils.data import DataLoader, Dataset, Sampler | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from torch.utils.data.distributed import DistributedSampler | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -23,6 +26,7 @@ class DataLoaderConfig: | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| pin_memory: bool = True | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| drop_last: bool = True | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| shuffle: bool = True | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| seed: int = 0 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def __post_init__(self) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if self.batch_size < 1: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -41,20 +45,37 @@ def _build_loader_kwargs( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| collator: CollatorProtocol | Callable[[list[dict[str, Any]]], Any], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) -> dict[str, Any]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Build common DataLoader kwargs.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| loader_generator = torch.Generator() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| loader_generator.manual_seed(config.seed) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| loader_kwargs: dict[str, Any] = { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "batch_size": config.batch_size, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "collate_fn": collator, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "num_workers": config.num_workers, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "pin_memory": config.pin_memory and torch.cuda.is_available(), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "drop_last": config.drop_last, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "generator": loader_generator, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if config.num_workers > 0: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| loader_kwargs["prefetch_factor"] = config.prefetch_factor | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| loader_kwargs["worker_init_fn"] = _build_worker_init_fn(config.seed) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return loader_kwargs | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _seed_worker(worker_id: int, seed: int) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| worker_seed = seed + worker_id | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| random.seed(worker_seed) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| np.random.seed(worker_seed % (2**32)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| torch.manual_seed(worker_seed) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _build_worker_init_fn(seed: int) -> Callable[[int], None]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Build a worker seeding function for deterministic worker-local RNG.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return partial(_seed_worker, seed=seed) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+62
to
+76
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| loader_kwargs["worker_init_fn"] = _build_worker_init_fn(config.seed) | |
| return loader_kwargs | |
| def _seed_worker(worker_id: int, seed: int) -> None: | |
| worker_seed = seed + worker_id | |
| random.seed(worker_seed) | |
| np.random.seed(worker_seed % (2**32)) | |
| torch.manual_seed(worker_seed) | |
| def _build_worker_init_fn(seed: int) -> Callable[[int], None]: | |
| """Build a worker seeding function for deterministic worker-local RNG.""" | |
| return partial(_seed_worker, seed=seed) | |
| loader_kwargs["worker_init_fn"] = _build_worker_init_fn() | |
| return loader_kwargs | |
| def _seed_worker(worker_id: int) -> None: | |
| """Seed Python, NumPy, and PyTorch RNGs for a worker based on torch.initial_seed().""" | |
| worker_seed = torch.initial_seed() | |
| random.seed(worker_seed) | |
| np.random.seed(worker_seed % (2**32)) | |
| torch.manual_seed(worker_seed) | |
| def _build_worker_init_fn() -> Callable[[int], None]: | |
| """Build a worker seeding function for deterministic worker-local RNG.""" | |
| return _seed_worker |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -41,6 +41,11 @@ def _generate_permutation(self) -> None: | |||||||||||||||||||||||||||
| def __len__(self) -> int: | ||||||||||||||||||||||||||||
| return self._num_samples | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| @property | ||||||||||||||||||||||||||||
| def epoch(self) -> int: | ||||||||||||||||||||||||||||
| """Current sampler epoch.""" | ||||||||||||||||||||||||||||
| return self._epoch | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| def __iter__(self) -> Iterator[int]: | ||||||||||||||||||||||||||||
| # Yield remaining indices from the current epoch. | ||||||||||||||||||||||||||||
| # No state mutation here — advance() and start_new_epoch() handle that. | ||||||||||||||||||||||||||||
|
|
@@ -83,4 +88,11 @@ def start_new_epoch(self) -> None: | |||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| def advance(self, n: int) -> None: | ||||||||||||||||||||||||||||
| """Advance the position by *n* samples (call once per batch).""" | ||||||||||||||||||||||||||||
| if n < 0: | ||||||||||||||||||||||||||||
| raise ValueError(f"advance expects non-negative n, got {n}") | ||||||||||||||||||||||||||||
| self._start_index += n | ||||||||||||||||||||||||||||
| if self._start_index > self._num_samples: | ||||||||||||||||||||||||||||
| raise ValueError( | ||||||||||||||||||||||||||||
| "ResumableSampler.advance moved past end of epoch " | ||||||||||||||||||||||||||||
| f"({self._start_index} > {self._num_samples})." | ||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||
|
Comment on lines
93
to
+98
|
||||||||||||||||||||||||||||
| self._start_index += n | |
| if self._start_index > self._num_samples: | |
| raise ValueError( | |
| "ResumableSampler.advance moved past end of epoch " | |
| f"({self._start_index} > {self._num_samples})." | |
| ) | |
| new_index = self._start_index + n | |
| if new_index > self._num_samples: | |
| raise ValueError( | |
| "ResumableSampler.advance moved past end of epoch " | |
| f"({new_index} > {self._num_samples})." | |
| ) | |
| self._start_index = new_index |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
_deterministic_uniform() converts to float64 and divides to get probabilities. For “bitwise determinism” across devices/backends, consider avoiding floating-point here entirely (e.g., keep values as int64 and compare against an integer threshold derived from mlm_probability). This removes any dependency on float division/rounding differences.