diff --git a/experiments/run_finetuning.py b/experiments/run_finetuning.py
index e39dbd91..59a29b7a 100644
--- a/experiments/run_finetuning.py
+++ b/experiments/run_finetuning.py
@@ -53,7 +53,7 @@ def get_optimizer_and_scheduler(model, train_dataset, config):
 class CustomTrainer(Trainer):
     def __init__(self, *args, train_loader=None, test_loader=None, **kwargs):
         super().__init__(*args, **kwargs)
-        self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=self.model.config.pad_token_id)
+        self.loss_fn = torch.nn.CrossEntropyLoss()
         self.train_loader = train_loader
         self.test_loader = test_loader
 
@@ -62,6 +62,21 @@ def get_train_dataloader(self) -> DataLoader:
 
     def get_eval_dataloader(self, _) -> DataLoader:
         return self.test_loader
+
+    def compute_loss(self, model, inputs, return_outputs=False):
+        labels = inputs.pop('labels')
+        attention_mask = inputs["attention_mask"]
+        outputs = model(**inputs)
+        labels = labels[..., 1:].contiguous()
+        logits = outputs.logits[..., :-1, :].contiguous()
+        attention_mask = attention_mask[..., :-1].contiguous()
+
+        # ignore padding tokens when computing the loss
+        logits = logits * attention_mask.unsqueeze(-1)
+
+        loss = self.loss_fn(logits.view(-1, logits.shape[-1]), labels.view(-1))
+
+        return (loss, outputs) if return_outputs else loss
 
 
 def argparser():
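
For reference, a minimal standalone sketch (not part of the patch) of the shift-and-mask loss that the new compute_loss performs; the batch size, sequence length, vocabulary size, and dummy tensors below are illustrative assumptions only, standing in for outputs.logits, labels, and attention_mask from the Trainer inputs.

# Standalone sketch: reproduces the next-token shift and mask-based masking
# from CustomTrainer.compute_loss on dummy tensors (shapes are assumptions).
import torch

batch, seq_len, vocab = 2, 5, 11
logits = torch.randn(batch, seq_len, vocab)          # stand-in for outputs.logits
labels = torch.randint(0, vocab, (batch, seq_len))   # stand-in for the label ids
attention_mask = torch.tensor([[1, 1, 1, 1, 0],
                               [1, 1, 1, 0, 0]])     # 0 marks padding positions

# Shift so that position t predicts token t+1 (causal LM objective).
shift_labels = labels[..., 1:].contiguous()
shift_logits = logits[..., :-1, :].contiguous()
shift_mask = attention_mask[..., :-1].contiguous()

# Zero the logits at padded positions, as in the patch. Zeroed logits give a
# uniform softmax there, so each padded position contributes a constant
# log(vocab) to the mean loss rather than being excluded from it.
masked_logits = shift_logits * shift_mask.unsqueeze(-1)

loss_fn = torch.nn.CrossEntropyLoss()
loss = loss_fn(masked_logits.view(-1, vocab), shift_labels.view(-1))
print(loss.item())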