From c0a253680b6b469023e701aa82ea1978e09899d5 Mon Sep 17 00:00:00 2001 From: Onur Gungor Date: Mon, 24 Feb 2025 22:19:54 +0100 Subject: [PATCH] Bugfix: batch_size_warmup_scheduler was taking too long or was impossible for real world max_batch_size values --- src/sequence_packer.py | 33 +++++++++++---------------------- 1 file changed, 11 insertions(+), 22 deletions(-) diff --git a/src/sequence_packer.py b/src/sequence_packer.py index e937865a..b02f35d9 100644 --- a/src/sequence_packer.py +++ b/src/sequence_packer.py @@ -35,30 +35,19 @@ def __init__( else: self.warmup_tokens = warmup_tokens self.warmup_tokens = math.ceil(self.warmup_tokens / world_size) - self._step_thresholds = self._calculate_step_thresholds() - - def _calculate_step_thresholds(self): - total_batch_sizes = sum(range(self.min_batch_size, self.max_batch_size)) - steps_per_unit = self.warmup_tokens / total_batch_sizes - - thresholds = [] - cumsum = 0 - for batch_size in range(self.min_batch_size, self.max_batch_size): - cumsum += batch_size - steps = math.ceil(steps_per_unit * cumsum) - thresholds.append(steps) - return thresholds - - def __call__(self, current_step: int) -> int: - if current_step >= self.warmup_tokens: + self.tokens_per_batch_size = self._calculate_tokens_per_batch_size() + + def _calculate_tokens_per_batch_size(self): + total_batch_sizes = (self.max_batch_size-1)*(self.max_batch_size)/2 - (self.min_batch_size-1)*(self.min_batch_size)/2 + tokens_per_batch_size = self.warmup_tokens / total_batch_sizes + return tokens_per_batch_size + + def __call__(self, current_token_count: int) -> int: + if current_token_count >= self.warmup_tokens: return self.max_batch_size + how_many_batch_sizes = int(current_token_count // self.tokens_per_batch_size) + return min(self.min_batch_size + how_many_batch_sizes, self.max_batch_size)
class SequencePackerBatchOutputTuple(NamedTuple):