experiments/run_finetuning.py (17 changes: 16 additions & 1 deletion)
@@ -53,7 +53,7 @@ def get_optimizer_and_scheduler(model, train_dataset, config):
class CustomTrainer(Trainer):
    def __init__(self, *args, train_loader=None, test_loader=None, **kwargs):
        super().__init__(*args, **kwargs)
-        self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=self.model.config.pad_token_id)
+        self.loss_fn = torch.nn.CrossEntropyLoss()
        self.train_loader = train_loader
        self.test_loader = test_loader

@@ -62,6 +62,21 @@ def get_train_dataloader(self) -> DataLoader:

    def get_eval_dataloader(self, _) -> DataLoader:
        return self.test_loader

+    def compute_loss(self, model, inputs, return_outputs=False):
Collaborator:
As this is most of the calculation for ppl, can we make this one and our gpu_utils implementation consistent, and add a test? Better still would be to call a common subroutine from here and from the ppl calc implementation. NB Pashmina's ppl refactoring PR should go in first.

Contributor Author:
Yes, the perplexity calculation should have the same changes; good idea to combine them.
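
As a rough illustration of the common-subroutine idea discussed above, a masked token-level cross-entropy helper that both compute_loss and the perplexity calculation could call might look like the sketch below. The name masked_lm_loss, the label masking via ignore_index=-100, and the assumed shapes are illustrative only, not the actual gpu_utils implementation.

import torch
import torch.nn.functional as F

def masked_lm_loss(logits, labels, attention_mask):
    # Hypothetical shared helper: mean next-token cross-entropy over non-padding targets.
    # logits: (batch, seq_len, vocab); labels, attention_mask: (batch, seq_len).
    # Shift so that the logits at position t predict the token at position t + 1.
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    shift_mask = attention_mask[..., 1:].contiguous()

    # Exclude padding by pointing ignored targets at ignore_index instead of zeroing logits.
    shift_labels = shift_labels.masked_fill(shift_mask == 0, -100)

    return F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=-100,
    )

With a helper like this, compute_loss would return masked_lm_loss(outputs.logits, labels, attention_mask), and the ppl code path could report torch.exp() of the same value, which would keep the two implementations consistent as suggested.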

+        labels = inputs.pop('labels')
+        attention_mask = inputs["attention_mask"]
+        outputs = model(**inputs)
+        labels = labels[..., 1:].contiguous()
Collaborator:

Nit: we should double check whether contiguous is required here, as we don't use it in our ppl implementation.
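
On the contiguous nit: after slicing along the sequence dimension the tensors become non-contiguous views, and .view() on a non-contiguous tensor raises a RuntimeError whenever the batch dimension is greater than 1, so either .contiguous() or .reshape() is needed before flattening for the loss. A quick standalone check (the shapes here are arbitrary):

import torch

logits = torch.randn(4, 10, 32)       # (batch, seq_len, vocab)
sliced = logits[..., :-1, :]          # drop last position -> non-contiguous view
print(sliced.is_contiguous())         # False

try:
    sliced.view(-1, sliced.size(-1))  # raises: size/stride incompatible with view
except RuntimeError as err:
    print("view failed:", err)

flat = sliced.contiguous().view(-1, sliced.size(-1))  # works after .contiguous()
flat = sliced.reshape(-1, sliced.size(-1))            # .reshape() copies only when needed

If the ppl implementation flattens with .reshape() (or never flattens at all), that would explain why it gets away without .contiguous().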

+        logits = outputs.logits[..., :-1, :].contiguous()
+        attention_mask = attention_mask[..., :-1].contiguous()
+
+        # ignore padding tokens when computing the loss
+        logits = logits * attention_mask.unsqueeze(-1)
+
+        loss = self.loss_fn(logits.view(-1, logits.shape[-1]), labels.view(-1))
+
+        return (loss, outputs) if return_outputs else loss


def argparser():