diff --git a/src/sequence_packer.py b/src/sequence_packer.py
index e937865a..b02f35d9 100644
--- a/src/sequence_packer.py
+++ b/src/sequence_packer.py
@@ -35,30 +35,19 @@ def __init__(
         else:
             self.warmup_tokens = warmup_tokens
         self.warmup_tokens = math.ceil(self.warmup_tokens / world_size)
-        self._step_thresholds = self._calculate_step_thresholds()
-
-    def _calculate_step_thresholds(self):
-        total_batch_sizes = sum(range(self.min_batch_size, self.max_batch_size))
-        steps_per_unit = self.warmup_tokens / total_batch_sizes
-
-        thresholds = []
-        cumsum = 0
-        for batch_size in range(self.min_batch_size, self.max_batch_size):
-            cumsum += batch_size
-            steps = math.ceil(steps_per_unit * cumsum)
-            thresholds.append(steps)
-        return thresholds
-
-    def __call__(self, current_step: int) -> int:
-        if current_step >= self.warmup_tokens:
+        self.tokens_per_batch_size = self._calculate_tokens_per_batch_size()
+
+    def _calculate_tokens_per_batch_size(self):
+        total_batch_sizes = (self.max_batch_size - 1) * (self.max_batch_size) / 2 - (self.min_batch_size - 1) * (self.min_batch_size) / 2
+        tokens_per_batch_size = self.warmup_tokens / total_batch_sizes
+        return tokens_per_batch_size
+
+    def __call__(self, current_token_count: int) -> int:
+        if current_token_count >= self.warmup_tokens:
             return self.max_batch_size
-        for i, threshold in enumerate(self._step_thresholds):
-            if current_step < threshold:
-                return self.min_batch_size + i
-
-        # should never hit this, but just in case
-        return self.max_batch_size
+        how_many_batch_sizes = current_token_count // self.tokens_per_batch_size
+        return self.min_batch_size + how_many_batch_sizes
 
 
 class SequencePackerBatchOutputTuple(NamedTuple):
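
The closed form introduced in `_calculate_tokens_per_batch_size` is the triangular-number identity for `sum(range(self.min_batch_size, self.max_batch_size))`, the denominator used on the removed side. A minimal standalone sketch checking that the two forms agree (the bounds and local names `min_bs`/`max_bs` below are arbitrary illustrations, not values from the diff):

```python
# Check that the closed form added in _calculate_tokens_per_batch_size
# equals the sum(range(...)) expression it replaces.
for min_bs, max_bs in [(1, 8), (4, 32), (16, 96)]:
    closed_form = (max_bs - 1) * max_bs // 2 - (min_bs - 1) * min_bs // 2
    assert closed_form == sum(range(min_bs, max_bs))
```

With `__call__` keyed to a running token count rather than a step index, the scheduler advances the batch size once every `tokens_per_batch_size` tokens via a single floor division, instead of scanning the precomputed thresholds used before.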