diff --git a/llmfoundry/models/hf/model_wrapper.py b/llmfoundry/models/hf/model_wrapper.py
index 58bd11a55f..192cc690d9 100644
--- a/llmfoundry/models/hf/model_wrapper.py
+++ b/llmfoundry/models/hf/model_wrapper.py
@@ -14,6 +14,7 @@
 from torchmetrics import Metric
 from transformers import PreTrainedTokenizerBase
 from transformers.utils.generic import ModelOutput
+from torch.utils.flop_counter import FlopCounterMode
 
 from llmfoundry.models.hf.hf_fsdp import prepare_hf_model_for_fsdp
 
@@ -62,6 +63,7 @@ def __init__(self,
         self.z_loss = float(z_loss)
         if self.z_loss < 0.0:
             raise ValueError(f'z_loss(={z_loss}) cannot be negative.')
+        self._fwd_flops = None
 
         # Note: We need to add the FSDP related attributes to the model AFTER the super init,
         # so that the (possible) embedding resizing doesn't destroy them
@@ -77,7 +79,15 @@ def forward(self, batch: Mapping):
             batch = {
                 k: v for k, v in batch.items() if k in self.model_forward_args
             }
-            output = self.model(**batch)  # type: ignore (thirdparty)
+            if self._fwd_flops:
+                output = self.model(**batch)  # type: ignore (thirdparty)
+            else:  # count flops and save
+                bs = batch['input_ids'].shape[0]
+                flop_counter = FlopCounterMode(display=False)
+                with flop_counter:
+                    output = self.model(**batch)
+                self._fwd_flops = flop_counter.get_total_flops() / bs
+
         else:
             raise ValueError(
                 'Unexpected batch type. Expected a dictionary with keys corresponding to the inputs to the forward function of the Huggingface model'
@@ -106,3 +116,7 @@ def loss(self, outputs: ModelOutput, batch: Mapping):
         else:
             outputs[0] += z_loss
         return outputs[0]
+
+    def flops_per_batch(self, batch: Mapping) -> int:
+        bs = batch['input_ids'].shape[0]
+        return self._fwd_flops * 3 * bs  # approximately 1x for fwd + 2x for bwd
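
For context, below is a minimal standalone sketch (not part of the patch) of the pattern this diff introduces: run one forward pass under torch's FlopCounterMode, cache the per-sample forward FLOPs, and estimate per-batch FLOPs as roughly 3x the forward cost (1x forward + 2x backward). The toy Linear module and batch shape are hypothetical placeholders for the wrapped HuggingFace model; only FlopCounterMode and get_total_flops() come from the patch itself, and a PyTorch version that ships torch.utils.flop_counter is assumed.

import torch
from torch.utils.flop_counter import FlopCounterMode

# Hypothetical stand-ins for the wrapped HF model and its input batch.
model = torch.nn.Linear(1024, 1024)
batch = torch.randn(8, 1024)  # batch size 8

# Count FLOPs for a single forward pass, as the patched forward() does on its first call.
flop_counter = FlopCounterMode(display=False)
with flop_counter:
    model(batch)

# Cache per-sample forward FLOPs, then report ~1x fwd + 2x bwd per batch,
# mirroring flops_per_batch() in the diff.
fwd_flops_per_sample = flop_counter.get_total_flops() / batch.shape[0]
flops_per_batch = fwd_flops_per_sample * 3 * batch.shape[0]
print(f'fwd flops/sample: {fwd_flops_per_sample:,.0f}, flops/batch: {flops_per_batch:,.0f}')

A flops_per_batch hook like the one added here is the kind of value that throughput/MFU monitoring on the Composer side (e.g. a SpeedMonitor-style callback) can consume, which appears to be the motivation for exposing it on the wrapper.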