From b9a12fee0ca83d0b95e535cef0f2ed0c643694f1 Mon Sep 17 00:00:00 2001
From: hajix
Date: Sun, 21 May 2023 12:03:19 +0330
Subject: [PATCH 1/6] Increase ONNX compatibility

---
 model.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/model.py b/model.py
index 040f41f..0515454 100644
--- a/model.py
+++ b/model.py
@@ -75,7 +75,8 @@ def __init__(self, in_chan, out_chan, *args, **kwargs):
 
     def forward(self, x):
         feat = self.conv(x)
-        atten = F.avg_pool2d(feat, feat.size()[2:])
+        # atten = F.avg_pool2d(feat, feat.size()[2:])
+        atten = F.adaptive_avg_pool2d(feat, output_size=(1, 1))
         atten = self.conv_atten(atten)
         atten = self.bn_atten(atten)
         atten = self.sigmoid_atten(atten)
@@ -108,7 +109,8 @@ def forward(self, x):
         H16, W16 = feat16.size()[2:]
         H32, W32 = feat32.size()[2:]
 
-        avg = F.avg_pool2d(feat32, feat32.size()[2:])
+        # avg = F.avg_pool2d(feat32, feat32.size()[2:])
+        avg = F.adaptive_avg_pool2d(feat32, output_size=(1, 1))
         avg = self.conv_avg(avg)
         avg_up = F.interpolate(avg, (H32, W32), mode='nearest')
 
@@ -200,7 +202,8 @@ def __init__(self, in_chan, out_chan, *args, **kwargs):
     def forward(self, fsp, fcp):
         fcat = torch.cat([fsp, fcp], dim=1)
         feat = self.convblk(fcat)
-        atten = F.avg_pool2d(feat, feat.size()[2:])
+        # atten = F.avg_pool2d(feat, feat.size()[2:])
+        atten = F.adaptive_avg_pool2d(feat, output_size=(1, 1))
         atten = self.conv1(atten)
         atten = self.relu(atten)
         atten = self.conv2(atten)

From 4ce05ed3c1661a847d965c1b73c30d95a562b0cc Mon Sep 17 00:00:00 2001
From: hajix
Date: Sun, 21 May 2023 12:05:31 +0330
Subject: [PATCH 2/6] Update device setting style

---
 test.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/test.py b/test.py
index 76c4f56..d5ea294 100644
--- a/test.py
+++ b/test.py
@@ -13,6 +13,9 @@
 import torchvision.transforms as transforms
 import cv2
 
+device = torch.device('cuda')
+
+
 def vis_parsing_maps(im, parsing_anno, stride, save_im=False, save_path='vis_results/parsing_map_on_im.jpg'):
     # Colors for all 20 parts
     part_colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0],
@@ -55,7+58,7 @@ def evaluate(respth='./res/test_res', dspth='./data', cp='model_final_diss.pth')
 
     n_classes = 19
     net = BiSeNet(n_classes=n_classes)
-    net.cuda()
+    net.to(device)
     save_pth = osp.join('res/cp', cp)
     net.load_state_dict(torch.load(save_pth))
     net.eval()
@@ -70,7 +73,7 @@ def evaluate(respth='./res/test_res', dspth='./data', cp='model_final_diss.pth')
             image = img.resize((512, 512), Image.BILINEAR)
             img = to_tensor(image)
             img = torch.unsqueeze(img, 0)
-            img = img.cuda()
+            img = img.to(device)
             out = net(img)[0]
             parsing = out.squeeze(0).cpu().numpy().argmax(0)
             # print(parsing)

From 739b60d2edf7699c3084c3c906c07e485360547a Mon Sep 17 00:00:00 2001
From: hajix
Date: Sun, 21 May 2023 12:05:55 +0330
Subject: [PATCH 3/6] Improve lint

---
 test.py | 6 +-----
 1 file changed, 1 insertion(+), 5 deletions(-)

diff --git a/test.py b/test.py
index d5ea294..bdda00c 100644
--- a/test.py
+++ b/test.py
@@ -51,6 +51,7 @@ def vis_parsing_maps(im, parsing_anno, stride, save_im=False, save_path='vis_res
 
     # return vis_im
 
+
 def evaluate(respth='./res/test_res', dspth='./data', cp='model_final_diss.pth'):
 
     if not os.path.exists(respth):
@@ -82,11 +83,6 @@ def evaluate(respth='./res/test_res', dspth='./data', cp='model_final_diss.pth')
             vis_parsing_maps(image, parsing, stride=1, save_im=True, save_path=osp.join(respth, image_path))
 
 
-
-
-
-
-
 if __name__ == "__main__":
     evaluate(dspth='/home/zll/data/CelebAMask-HQ/test-img', cp='79999_iter.pth')
 
From afdd535c702f481c64dff486e5299818da011baf Mon Sep 17 00:00:00 2001
From: hajix
Date: Sun, 21 May 2023 12:06:24 +0330
Subject: [PATCH 4/6] Eliminate absolute addressing

---
 test.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/test.py b/test.py
index bdda00c..2971f1c 100644
--- a/test.py
+++ b/test.py
@@ -84,6 +84,4 @@ def evaluate(respth='./res/test_res', dspth='./data', cp='model_final_diss.pth')
 
 
 if __name__ == "__main__":
-    evaluate(dspth='/home/zll/data/CelebAMask-HQ/test-img', cp='79999_iter.pth')
-
-
+    evaluate(dspth='test_image_folder', cp='79999_iter.pth')

From 193e4afdab504217b6acdaeee77ee1005702d9bc Mon Sep 17 00:00:00 2001
From: hajix
Date: Sun, 21 May 2023 12:42:59 +0330
Subject: [PATCH 5/6] Save model in torchscript and ONNX formats

---
 test.py | 23 +++++++++++++++++++++++
 1 file changed, 23 insertions(+)

diff --git a/test.py b/test.py
index 2971f1c..36f8136 100644
--- a/test.py
+++ b/test.py
@@ -68,6 +68,29 @@ def evaluate(respth='./res/test_res', dspth='./data', cp='model_final_diss.pth')
         transforms.ToTensor(),
         transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
     ])
+    # save script module
+    jit_model = torch.jit.script(net)
+    jit_model.save('/tmp/face_parsing.pt')
+    # save onnx format
+    batch_size = 32
+    x = torch.randn(batch_size, 3, 512, 512, requires_grad=True).to(device)
+    torch.onnx.export(
+        net,
+        x,
+        '/tmp/face_parsing.onnx',
+        export_params=True,
+        opset_version=13,
+        do_constant_folding=True,
+        input_names=['input'],
+        output_names=['output', 'output16', 'output32'],
+        dynamic_axes={
+            'input' : {0: 'batch_size'},
+            'output' : {0: 'batch_size'},
+            'output16': {0: 'batch_size'},
+            'output32': {0: 'batch_size'},
+        },
+    )
+
     with torch.no_grad():
         for image_path in os.listdir(dspth):
             img = Image.open(osp.join(dspth, image_path))

From f8f88a54d5be2e22d45f91677c836706b93ae855 Mon Sep 17 00:00:00 2001
From: hajix
Date: Sun, 21 May 2023 13:36:14 +0330
Subject: [PATCH 6/6] Add requirements

---
 requirements.py | 29 +++++++++++++++++++++++++++++
 1 file changed, 29 insertions(+)
 create mode 100644 requirements.py

diff --git a/requirements.py b/requirements.py
new file mode 100644
index 0000000..f9d991b
--- /dev/null
+++ b/requirements.py
@@ -0,0 +1,29 @@
+asttokens==2.2.1
+backcall==0.2.0
+certifi==2023.5.7
+charset-normalizer==3.1.0
+decorator==5.1.1
+executing==1.2.0
+idna==3.4
+ipython==8.13.2
+jedi==0.18.2
+matplotlib-inline==0.1.6
+numpy==1.24.3
+opencv-python==4.7.0.72
+parso==0.8.3
+pexpect==4.8.0
+pickleshare==0.7.5
+Pillow==9.5.0
+prompt-toolkit==3.0.38
+ptyprocess==0.7.0
+pure-eval==0.2.2
+Pygments==2.15.1
+requests==2.30.0
+six==1.16.0
+stack-data==0.6.2
+torch==1.12.1+cu116
+torchvision==0.13.1+cu116
+traitlets==5.9.0
+typing_extensions==4.5.0
+urllib3==2.0.2
+wcwidth==0.2.6
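
Note on PATCH 1/6: the swap from F.avg_pool2d(feat, feat.size()[2:]) to
F.adaptive_avg_pool2d(feat, output_size=(1, 1)) keeps the math identical (both compute a
global average over the spatial dimensions) while avoiding a kernel size derived from the
runtime tensor shape, which is typically what breaks ONNX export when the spatial size is
dynamic. A minimal, self-contained equivalence check (illustrative only, not part of the
patches):

    import torch
    import torch.nn.functional as F

    feat = torch.randn(4, 256, 16, 32)                      # (N, C, H, W) dummy feature map
    old = F.avg_pool2d(feat, feat.size()[2:])                # kernel size taken from the input shape
    new = F.adaptive_avg_pool2d(feat, output_size=(1, 1))    # fixed 1x1 target, exports as a global average pool
    assert torch.allclose(old, new, atol=1e-6)               # same result, shape (4, 256, 1, 1)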
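
Note on PATCH 5/6: a quick way to sanity-check the exported /tmp/face_parsing.onnx against
the PyTorch network. This is only a sketch, not part of the series; it assumes onnxruntime
is installed (neither onnx nor onnxruntime is pinned in the requirements above) and that
the checkpoint sits at res/cp/79999_iter.pth as in test.py:

    import numpy as np
    import onnxruntime as ort
    import torch

    from model import BiSeNet

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    net = BiSeNet(n_classes=19)
    net.load_state_dict(torch.load('res/cp/79999_iter.pth', map_location=device))
    net.to(device).eval()

    x = torch.randn(2, 3, 512, 512, device=device)           # batch of 2 works via dynamic_axes

    with torch.no_grad():
        torch_out = net(x)[0].cpu().numpy()                  # first head ('output' in the export)

    sess = ort.InferenceSession('/tmp/face_parsing.onnx', providers=['CPUExecutionProvider'])
    onnx_out = sess.run(['output'], {'input': x.cpu().numpy()})[0]

    print('max abs diff:', np.abs(torch_out - onnx_out).max())

A maximum difference on the order of 1e-4 or smaller is what one would expect for an FP32
export; anything much larger suggests the export did not capture the graph faithfully.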