From 69afd7a9b25a917f42c3a3d0b35f9723d6af20fb Mon Sep 17 00:00:00 2001
From: Amithrajith Mamidala
Date: Fri, 25 Oct 2024 17:40:24 +0000
Subject: [PATCH] Fix tokens_per_batch to account for DP degree and
 gradient-accumulation microbatches

---
 axlearn/experiments/text/gpt/fuji.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/axlearn/experiments/text/gpt/fuji.py b/axlearn/experiments/text/gpt/fuji.py
index 0332de70d..f6f6a335a 100644
--- a/axlearn/experiments/text/gpt/fuji.py
+++ b/axlearn/experiments/text/gpt/fuji.py
@@ -101,7 +101,7 @@ def get_trainer_kwargs(
 ) -> Dict[str, Any]:
     """Construct default trainer kwargs given a model size."""
     # tokens_per_batch = 4 * (1024**2)  # 4M tokens.
-    tokens_per_batch = MAX_SEQUENCE_LENGTH[version]  # one sequence per batch
+    tokens_per_batch = MAX_SEQUENCE_LENGTH[version] * int((jax.device_count() / TRN_MODEL_AXIS_SIZE) * GRADIENT_ACCUMULATION_MICROBATCHES)  # one sequence per DP replica per microbatch
     max_step = TOTAL_TOKENS[version][model_size] // tokens_per_batch
     max_sequence_length = MAX_SEQUENCE_LENGTH[version]
     train_batch_size = tokens_per_batch // max_sequence_length
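
Note (editor's sketch, not part of the patch): the new expression sizes
tokens_per_batch to cover one full optimizer step, i.e. one sequence per
data-parallel replica per accumulation microbatch. The standalone Python
sketch below walks through that arithmetic. The constant values are
placeholders chosen for illustration, and device_count stands in for
jax.device_count() on a real accelerator mesh; the actual constants live
in fuji.py.

    # Placeholder values; the real constants are defined in fuji.py.
    MAX_SEQUENCE_LENGTH = {"v2": 4096}       # tokens per sequence
    TRN_MODEL_AXIS_SIZE = 8                  # devices per model-parallel group
    GRADIENT_ACCUMULATION_MICROBATCHES = 4   # microbatches per optimizer step
    device_count = 64                        # stand-in for jax.device_count()

    version = "v2"

    # Data-parallel degree: devices left over after the model axis is carved out.
    dp_degree = device_count // TRN_MODEL_AXIS_SIZE

    # One sequence per DP replica per microbatch, so tokens per optimizer
    # step scale with both the DP degree and the accumulation factor.
    tokens_per_batch = (
        MAX_SEQUENCE_LENGTH[version] * dp_degree * GRADIENT_ACCUMULATION_MICROBATCHES
    )
    train_batch_size = tokens_per_batch // MAX_SEQUENCE_LENGTH[version]

    assert tokens_per_batch == 131072  # 4096 * 8 * 4
    assert train_batch_size == dp_degree * GRADIENT_ACCUMULATION_MICROBATCHES  # 32

With these placeholder numbers, a hypothetical 1B-token budget would yield
1_000_000_000 // 131072 = 7629 optimizer steps, mirroring the max_step
computation in the hunk above. Before the fix, tokens_per_batch counted only
a single sequence, so max_step was overstated by the DP-degree-times-
accumulation factor.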