diff --git a/bert/bert_trainer.py b/bert/bert_trainer.py index 7479151..c10536a 100644 --- a/bert/bert_trainer.py +++ b/bert/bert_trainer.py @@ -105,7 +105,7 @@ def train_epoch(self, model, data_loader, loss_fn, optimizer, device, scheduler, scheduler.step() optimizer.zero_grad() - avg_train_avg = float(correct_predictions) / float(len(data_loader))*10 + avg_train_avg = float(correct_predictions) / float(len(data_loader.dataset))*100 # Calculate the average loss over all of the batches. avg_train_loss = total_train_loss / len(data_loader) @@ -173,7 +173,7 @@ def eval_model(self, model, data_loader, loss_fn, device, algorithm='transformer total_eval_loss += loss.item() # Report the final accuracy for this validation run. - avg_val_accuracy = float(total_eval_accuracy) / float(len(data_loader))*10 + avg_val_accuracy = float(total_eval_accuracy) / float(len(data_loader.dataset))*100 # Calculate the average loss over all of the batches. avg_val_loss = total_eval_loss / len(data_loader)