Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,14 @@ def main(cfg: DictConfig) -> None:
)

collator_class = COLLATOR_REGISTRY.get(cfg.collator.name)
collator = collator_class(MLMCollatorConfig.from_tokenizer(tokenizer))
collator = collator_class(
MLMCollatorConfig.from_tokenizer(
tokenizer,
seed=cfg.training.seed,
mlm_probability=cfg.collator.mlm_probability,
epoch=0,
)
)

task = TASK_REGISTRY.get(cfg.task.name)()
model = MODEL_REGISTRY.get(cfg.model.name)(
Expand All @@ -86,6 +93,8 @@ def main(cfg: DictConfig) -> None:
callbacks.append(cb_cls(**cb_kwargs))

sampler = ResumableSampler(dataset, seed=cfg.training.seed)
if hasattr(collator, "set_epoch"):
collator.set_epoch(sampler.epoch)

loader = create_dataloader(
dataset=dataset,
Expand All @@ -96,6 +105,7 @@ def main(cfg: DictConfig) -> None:
prefetch_factor=cfg.training.prefetch_factor,
pin_memory=torch.cuda.is_available(),
drop_last=True,
seed=cfg.training.seed,
),
sampler=sampler,
)
Expand All @@ -118,6 +128,7 @@ def main(cfg: DictConfig) -> None:
pin_memory=torch.cuda.is_available(),
drop_last=False,
shuffle=False,
seed=cfg.training.seed,
),
)

Expand Down
105 changes: 93 additions & 12 deletions src/embedding_trainer/data/collators/mlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,27 @@ class MLMCollatorConfig(BaseCollatorConfig):
"""Configuration for MLM collator."""

mlm_probability: float = 0.30 # Probability of replacing with [MASK]
seed: int = 0
epoch: int = 0

@staticmethod
def from_tokenizer(
    tokenizer: PreTrainedTokenizerBase,
    *,
    seed: int = 0,
    mlm_probability: float = 0.30,
    epoch: int = 0,
) -> MLMCollatorConfig:
    """Build an MLMCollatorConfig from a Hugging Face tokenizer.

    Args:
        tokenizer: Source of the special-token ids and vocabulary size.
        seed: Base seed for deterministic masking.
        mlm_probability: Per-token probability of masking.
        epoch: Initial epoch value for deterministic masking.
    """
    # Gather the token ids the collator needs from the tokenizer.
    special_token_ids = {
        "pad_token_id": tokenizer.pad_token_id,
        "cls_token_id": tokenizer.cls_token_id,
        "sep_token_id": tokenizer.sep_token_id,
        "mask_token_id": tokenizer.mask_token_id,
    }
    return MLMCollatorConfig(
        vocab_size=tokenizer.vocab_size,
        seed=seed,
        mlm_probability=mlm_probability,
        epoch=epoch,
        **special_token_ids,
    )


Expand All @@ -44,6 +55,12 @@ class MLMCollator(BaseCollator):
def __init__(self, config: MLMCollatorConfig) -> None:
    """Construct the collator, caching masking/determinism settings."""
    super().__init__(config)
    # Cache the settings used by deterministic masking as attributes.
    self.mlm_probability, self.seed, self.epoch = (
        config.mlm_probability,
        config.seed,
        config.epoch,
    )

def set_epoch(self, epoch: int) -> None:
    """Record the current training epoch for deterministic masking.

    Mirrors the ``set_epoch`` convention used elsewhere in the data stack
    (datasets/samplers) so the trainer can keep epoch-dependent RNG in sync.
    """
    self.epoch = epoch

def __call__(self, samples: list[dict[str, Tensor]]) -> PreTokenizedBatch:
"""
Expand All @@ -57,6 +74,7 @@ def __call__(self, samples: list[dict[str, Tensor]]) -> PreTokenizedBatch:
"""
# Extract input_ids
input_ids_list = [sample["input_ids"] for sample in samples]
sample_indices = self._get_sample_indices(samples)

# Pad sequences
input_ids = self._pad_sequences(input_ids_list, self.pad_token_id)
Expand All @@ -65,7 +83,7 @@ def __call__(self, samples: list[dict[str, Tensor]]) -> PreTokenizedBatch:
attention_mask = self._create_attention_mask(input_ids)

# Apply MLM masking
input_ids, labels = self._mask_tokens(input_ids)
input_ids, labels = self._mask_tokens(input_ids, sample_indices)

return {
"input_ids": input_ids,
Expand All @@ -76,6 +94,7 @@ def __call__(self, samples: list[dict[str, Tensor]]) -> PreTokenizedBatch:
def _mask_tokens(
self,
input_ids: torch.Tensor,
sample_indices: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Apply MLM masking to input_ids.
Expand All @@ -90,22 +109,84 @@ def _mask_tokens(
labels = input_ids.clone()
input_ids = input_ids.clone()

# Create probability matrix for masking
probability_matrix = torch.full(input_ids.shape, self.mlm_probability)

# Get special tokens mask and set their probability to 0
if sample_indices is None:
sample_indices = torch.arange(
input_ids.shape[0],
device=input_ids.device,
dtype=torch.int64,
)
else:
sample_indices = sample_indices.to(
device=input_ids.device, dtype=torch.int64
)
if (
sample_indices.ndim != 1
or sample_indices.shape[0] != input_ids.shape[0]
):
raise ValueError(
"sample_indices must be a 1D tensor with one entry per batch row."
)

# Get special tokens mask
special_tokens_mask = self._get_special_tokens_mask(input_ids)
probability_matrix.masked_fill_(special_tokens_mask, 0.0)

# Also don't mask padding
# Also don't mask padding.
padding_mask = input_ids == self.pad_token_id
probability_matrix.masked_fill_(padding_mask, 0.0)

# Sample which tokens to mask
masked_indices = torch.bernoulli(probability_matrix).bool()
eligible_mask = ~(special_tokens_mask | padding_mask)

if self.mlm_probability <= 0.0:
masked_indices = torch.zeros_like(input_ids, dtype=torch.bool)
elif self.mlm_probability >= 1.0:
masked_indices = eligible_mask
else:
p = self._deterministic_uniform(
sample_indices, input_ids.shape[1], input_ids.device
)
masked_indices = (p < self.mlm_probability) & eligible_mask

# Set labels to -100 for non-masked tokens (will be ignored in loss)
labels[~masked_indices] = LABEL_IGNORE_ID
input_ids[masked_indices] = self.mask_token_id # Default to [MASK]

return input_ids, labels

def _get_sample_indices(self, samples: list[dict[str, Tensor]]) -> torch.Tensor:
sample_indices: list[int] = []
for row_idx, sample in enumerate(samples):
raw_idx = sample.get("sample_idx")
if raw_idx is None:
sample_indices.append(row_idx)
continue
if isinstance(raw_idx, torch.Tensor):
if raw_idx.numel() != 1:
raise ValueError(
"sample_idx tensor must contain a single scalar value."
)
sample_indices.append(int(raw_idx.item()))
continue
sample_indices.append(int(raw_idx))
return torch.tensor(sample_indices, dtype=torch.int64)

def _deterministic_uniform(
self,
sample_indices: torch.Tensor,
seq_len: int,
device: torch.device,
) -> torch.Tensor:
"""Generate deterministic pseudo-random values in [0, 1)."""
modulus = 2_147_483_647 # 2^31 - 1, prime.

sample_grid = sample_indices.unsqueeze(1)
position_grid = torch.arange(
seq_len, device=device, dtype=torch.int64
).unsqueeze(0)

x = (
sample_grid * 1_048_583
+ position_grid * 1_308_049
+ (self.seed % modulus) * 4_447_961
+ (self.epoch % modulus) * 22_695_477
+ 1_013_904_223
) % modulus
x = (x * 48_271) % modulus
return x.to(torch.float64) / float(modulus)
Copy link

Copilot AI Feb 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_deterministic_uniform() converts to float64 and divides to get probabilities. For “bitwise determinism” across devices/backends, consider avoiding floating-point here entirely (e.g., keep values as int64 and compare against an integer threshold derived from mlm_probability). This removes any dependency on float division/rounding differences.

Suggested change
return x.to(torch.float64) / float(modulus)
inv_modulus = 1.0 / float(modulus)
return x.to(torch.float64) * inv_modulus

Copilot uses AI. Check for mistakes.
1 change: 1 addition & 0 deletions src/embedding_trainer/data/datasets/flat_tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,7 @@ def __getitem__(self, idx: int) -> dict[str, Tensor]:
return {
# Convert to int32 tensor for PyTorch (torch.uint16 has limited support)
"input_ids": torch.from_numpy(seq_tokens.astype(np.int32)),
"sample_idx": torch.tensor(idx, dtype=torch.int64),
}

def set_epoch(self, epoch: int) -> None:
Expand Down
21 changes: 21 additions & 0 deletions src/embedding_trainer/data/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@

from __future__ import annotations

import random
from collections.abc import Callable
from dataclasses import dataclass
from functools import partial
from typing import Any

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset, Sampler
from torch.utils.data.distributed import DistributedSampler
Expand All @@ -23,6 +26,7 @@ class DataLoaderConfig:
pin_memory: bool = True
drop_last: bool = True
shuffle: bool = True
seed: int = 0

def __post_init__(self) -> None:
if self.batch_size < 1:
Expand All @@ -41,20 +45,37 @@ def _build_loader_kwargs(
collator: CollatorProtocol | Callable[[list[dict[str, Any]]], Any],
) -> dict[str, Any]:
"""Build common DataLoader kwargs."""
loader_generator = torch.Generator()
loader_generator.manual_seed(config.seed)

loader_kwargs: dict[str, Any] = {
"batch_size": config.batch_size,
"collate_fn": collator,
"num_workers": config.num_workers,
"pin_memory": config.pin_memory and torch.cuda.is_available(),
"drop_last": config.drop_last,
"generator": loader_generator,
}

if config.num_workers > 0:
loader_kwargs["prefetch_factor"] = config.prefetch_factor
loader_kwargs["worker_init_fn"] = _build_worker_init_fn(config.seed)

return loader_kwargs


def _seed_worker(worker_id: int, seed: int) -> None:
worker_seed = seed + worker_id
random.seed(worker_seed)
np.random.seed(worker_seed % (2**32))
torch.manual_seed(worker_seed)


def _build_worker_init_fn(seed: int) -> Callable[[int], None]:
    """Build a worker seeding function for deterministic worker-local RNG.

    Returns a callable suitable for DataLoader's ``worker_init_fn``; it
    receives the worker id and seeds that worker's RNGs from ``seed``.
    """
    # functools.partial (unlike a lambda) is picklable, which matters when
    # workers are started with the "spawn" start method.
    return partial(_seed_worker, seed=seed)
Comment on lines +62 to +76
Copy link

Copilot AI Feb 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The custom worker_init_fn seeds workers with seed + worker_id, which is identical across epochs/iterations and also ignores DataLoader’s base seed (derived from the loader generator). This can unintentionally make any per-sample randomness inside datasets/collators repeat every epoch. A more deterministic + epoch-varying approach is to derive the seed from torch.initial_seed() inside the worker (then seed Python/NumPy from that).

Suggested change
loader_kwargs["worker_init_fn"] = _build_worker_init_fn(config.seed)
return loader_kwargs
def _seed_worker(worker_id: int, seed: int) -> None:
worker_seed = seed + worker_id
random.seed(worker_seed)
np.random.seed(worker_seed % (2**32))
torch.manual_seed(worker_seed)
def _build_worker_init_fn(seed: int) -> Callable[[int], None]:
"""Build a worker seeding function for deterministic worker-local RNG."""
return partial(_seed_worker, seed=seed)
loader_kwargs["worker_init_fn"] = _build_worker_init_fn()
return loader_kwargs
def _seed_worker(worker_id: int) -> None:
"""Seed Python, NumPy, and PyTorch RNGs for a worker based on torch.initial_seed()."""
worker_seed = torch.initial_seed()
random.seed(worker_seed)
np.random.seed(worker_seed % (2**32))
torch.manual_seed(worker_seed)
def _build_worker_init_fn() -> Callable[[int], None]:
"""Build a worker seeding function for deterministic worker-local RNG."""
return _seed_worker

Copilot uses AI. Check for mistakes.


def create_dataloader(
dataset: DatasetProtocol | Dataset[Any],
collator: CollatorProtocol | Callable[[list[dict[str, Any]]], Any],
Expand Down
12 changes: 12 additions & 0 deletions src/embedding_trainer/data/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ def _generate_permutation(self) -> None:
def __len__(self) -> int:
return self._num_samples

@property
def epoch(self) -> int:
    """Current epoch counter (read-only).

    Exposed so callers can propagate the sampler's epoch to other
    components (e.g. a collator's ``set_epoch``) and keep epoch-dependent
    determinism in sync.
    """
    return self._epoch

def __iter__(self) -> Iterator[int]:
# Yield remaining indices from the current epoch.
# No state mutation here — advance() and start_new_epoch() handle that.
Expand Down Expand Up @@ -83,4 +88,11 @@ def start_new_epoch(self) -> None:

def advance(self, n: int) -> None:
    """Advance the position by *n* samples (call once per batch).

    Validates the move *before* mutating state, so a failed call leaves
    the sampler unchanged and recoverable.

    Args:
        n: Non-negative number of samples just consumed.

    Raises:
        ValueError: If ``n`` is negative, or if advancing would move past
            the end of the current epoch.
    """
    if n < 0:
        raise ValueError(f"advance expects non-negative n, got {n}")
    new_index = self._start_index + n
    if new_index > self._num_samples:
        raise ValueError(
            "ResumableSampler.advance moved past end of epoch "
            f"({new_index} > {self._num_samples})."
        )
    # Commit only after both checks pass.
    self._start_index = new_index
Comment on lines 93 to +98
Copy link

Copilot AI Feb 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ResumableSampler.advance() mutates _start_index before validating the new value. If this raises (e.g., n too large), the sampler is left in an invalid state, which can make debugging and error recovery harder. Consider validating the proposed new index before assignment (or rolling back on error).

Suggested change
self._start_index += n
if self._start_index > self._num_samples:
raise ValueError(
"ResumableSampler.advance moved past end of epoch "
f"({self._start_index} > {self._num_samples})."
)
new_index = self._start_index + n
if new_index > self._num_samples:
raise ValueError(
"ResumableSampler.advance moved past end of epoch "
f"({new_index} > {self._num_samples})."
)
self._start_index = new_index

Copilot uses AI. Check for mistakes.
Loading