diff --git a/axlearn/experiments/text/gpt/fuji.py b/axlearn/experiments/text/gpt/fuji.py index 0332de70d..f6f6a335a 100644 --- a/axlearn/experiments/text/gpt/fuji.py +++ b/axlearn/experiments/text/gpt/fuji.py @@ -101,7 +101,7 @@ def get_trainer_kwargs( ) -> Dict[str, Any]: """Construct default trainer kwargs given a model size.""" # tokens_per_batch = 4 * (1024**2) # 4M tokens. - tokens_per_batch = MAX_SEQUENCE_LENGTH[version] # one sequence per batch + tokens_per_batch = MAX_SEQUENCE_LENGTH[version] * int((jax.device_count()/TRN_MODEL_AXIS_SIZE)*GRADIENT_ACCUMULATION_MICROBATCHES) # one sequence per batch max_step = TOTAL_TOKENS[version][model_size] // tokens_per_batch max_sequence_length = MAX_SEQUENCE_LENGTH[version] train_batch_size = tokens_per_batch // max_sequence_length