From 204194180b19380e84301e0a0e033a19341520c5 Mon Sep 17 00:00:00 2001 From: Yogesh Chandrasekharuni Date: Mon, 13 Jun 2022 13:01:27 +0530 Subject: [PATCH] end2end CPU support --- main_end2end.py | 2 +- src/approaches/train_audio2landmark.py | 11 +++++++++-- src/approaches/train_image_translation.py | 5 ++++- src/models/model_audio2landmark.py | 5 ++++- 4 files changed, 18 insertions(+), 5 deletions(-) diff --git a/main_end2end.py b/main_end2end.py index 64d09ef..89b2a5e 100644 --- a/main_end2end.py +++ b/main_end2end.py @@ -68,7 +68,7 @@ ''' STEP 1: preprocess input single image ''' img =cv2.imread('examples/' + opt_parser.jpg) -predictor = face_alignment.FaceAlignment(face_alignment.LandmarksType._3D, device='cuda', flip_input=True) +predictor = face_alignment.FaceAlignment(face_alignment.LandmarksType._3D, device="cuda" if torch.cuda.is_available() else "cpu", flip_input=True) shapes = predictor.get_landmarks(img) if (not shapes or len(shapes) != 1): print('Cannot detect face landmarks. Exit.') diff --git a/src/approaches/train_audio2landmark.py b/src/approaches/train_audio2landmark.py index 57fc602..e5305e1 100644 --- a/src/approaches/train_audio2landmark.py +++ b/src/approaches/train_audio2landmark.py @@ -55,7 +55,10 @@ def __init__(self, opt_parser, jpg_shape=None): print('G: Running on {}, total num params = {:.2f}M'.format(device, get_n_params(self.G)/1.0e6)) model_dict = self.G.state_dict() - ckpt = torch.load(opt_parser.load_a2l_G_name) + if device.type == "cpu": + ckpt = torch.load(opt_parser.load_a2l_G_name, map_location=torch.device("cpu")) + else: + ckpt = torch.load(opt_parser.load_a2l_G_name) pretrained_dict = {k: v for k, v in ckpt['G'].items() if k.split('.')[0] not in ['comb_mlp']} model_dict.update(pretrained_dict) self.G.load_state_dict(model_dict) @@ -68,7 +71,11 @@ def __init__(self, opt_parser, jpg_shape=None): in_size=80, use_prior_net=True, bidirectional=False, drop_out=0.5) - ckpt = torch.load(opt_parser.load_a2l_C_name) + if device.type == "cpu": + ckpt = torch.load(opt_parser.load_a2l_C_name, map_location=torch.device("cpu")) + else: + ckpt = torch.load(opt_parser.load_a2l_C_name) + self.C.load_state_dict(ckpt['model_g_face_id']) # self.C.load_state_dict(ckpt['C']) print('======== LOAD PRETRAINED FACE ID MODEL {} ========='.format(opt_parser.load_a2l_C_name)) diff --git a/src/approaches/train_image_translation.py b/src/approaches/train_image_translation.py index dc79d19..25027ba 100644 --- a/src/approaches/train_image_translation.py +++ b/src/approaches/train_image_translation.py @@ -42,7 +42,10 @@ def __init__(self, opt_parser, single_test=False): self.G = ResUnetGenerator(input_nc=6, output_nc=3, num_downs=6, use_dropout=False) if (opt_parser.load_G_name != ''): - ckpt = torch.load(opt_parser.load_G_name) + if torch.cuda.is_available(): + ckpt = torch.load(opt_parser.load_G_name) + else: + ckpt = torch.load(opt_parser.load_G_name, map_location=torch.device("cpu")) try: self.G.load_state_dict(ckpt['G']) except: diff --git a/src/models/model_audio2landmark.py b/src/models/model_audio2landmark.py index 2f2ec80..566a3d8 100644 --- a/src/models/model_audio2landmark.py +++ b/src/models/model_audio2landmark.py @@ -242,7 +242,10 @@ def __init__(self, d_model, heads, dropout=0.1): self.attn_1 = MultiHeadAttention(heads, d_model) self.attn_2 = MultiHeadAttention(heads, d_model) - self.ff = FeedForward(d_model).cuda() + if device.type == "cpu": + self.ff = FeedForward(d_model) + else: + self.ff = FeedForward(d_model).cuda() def forward(self, x, e_outputs, src_mask, trg_mask): x2 = self.norm_1(x)