From 1036d17e6275a836e40755f8587d41479d1556d1 Mon Sep 17 00:00:00 2001 From: Erick7451 <40013722+Erick7451@users.noreply.github.com> Date: Sun, 14 Jul 2019 13:20:17 -0500 Subject: [PATCH 1/2] Update train.py --- train.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/train.py b/train.py index 5ea5459..3c0a735 100644 --- a/train.py +++ b/train.py @@ -50,8 +50,13 @@ def main(): + parallel = False model = FaceNetModel(embedding_size = args.embedding_size, num_classes = args.num_classes).to(device) + if torch.cuda.device() > 1: + model = nn.DataParallel(model) + parallel = True + optimizer = optim.Adam(model.parameters(), lr = args.learning_rate) scheduler = lr_scheduler.StepLR(optimizer, step_size = 50, gamma = 0.1) @@ -124,9 +129,14 @@ def train_valid(model, optimizer, scheduler, epoch, dataloaders, data_size): pos_hard_cls = pos_cls[hard_triplets].to(device) neg_hard_cls = neg_cls[hard_triplets].to(device) - anc_img_pred = model.forward_classifier(anc_hard_img).to(device) - pos_img_pred = model.forward_classifier(pos_hard_img).to(device) - neg_img_pred = model.forward_classifier(neg_hard_img).to(device) + if parallel: + anc_img_pred = model.module.forward_classifier(anc_hard_img).to(device) + pos_img_pred = model.module.forward_classifier(pos_hard_img).to(device) + neg_img_pred = model.module.forward_classifier(neg_hard_img).to(device) + else: + anc_img_pred = model.forward_classifier(anc_hard_img).to(device) + pos_img_pred = model.forward_classifier(pos_hard_img).to(device) + neg_img_pred = model.forward_classifier(neg_hard_img).to(device) triplet_loss = TripletLoss(args.margin).forward(anc_hard_embed, pos_hard_embed, neg_hard_embed).to(device) From 73c0667de4e3d8916e671d9602f211b66440c8b7 Mon Sep 17 00:00:00 2001 From: Erick Platero <40013722+Erick7451@users.noreply.github.com> Date: Tue, 6 Aug 2019 19:33:21 -0500 Subject: [PATCH 2/2] Update README.md --- README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/README.md b/README.md index ab64e46..802c36a 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,9 @@ - Pytorch implementation of the paper: "FaceNet: A Unified Embedding for Face Recognition and Clustering". - Training of network is done using triplet loss. +# Difference from Main Repository +- This repository takes advantage of Pytorch's DataParallel capacities to experience a much faster training time. + # How to train/validate model - Download vggface2 (for training) and lfw (for validation) datasets.