From 400c25ecc51067895dead428c3319afd9107186d Mon Sep 17 00:00:00 2001 From: Ali Malek Date: Tue, 12 May 2020 20:00:51 +0100 Subject: [PATCH] Quick solution for division by Zero `torch.svd` cannot handle infinties. closes #13, closes #97 --- photo_wct.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/photo_wct.py b/photo_wct.py index e274b18..2068de6 100644 --- a/photo_wct.py +++ b/photo_wct.py @@ -128,8 +128,13 @@ def __wct_core(self, cont_feat, styl_feat): iden = torch.eye(cFSize[0]) # .double() if self.is_cuda: iden = iden.cuda() - - contentConv = torch.mm(cont_feat, cont_feat.t()).div(cFSize[1] - 1) + iden + + #To avoid division by zero + small_num = 0.00001 + cFdivisor = cFSize[1] - 1. + if cFdivisor == 0.: + cFdivisor = small_num + contentConv = torch.mm(cont_feat, cont_feat.t()).div(cFdivisor) + iden # del iden c_u, c_e, c_v = torch.svd(contentConv, some=False) # c_e2, c_v = torch.eig(contentConv, True) @@ -144,7 +149,12 @@ def __wct_core(self, cont_feat, styl_feat): sFSize = styl_feat.size() s_mean = torch.mean(styl_feat, 1) styl_feat = styl_feat - s_mean.unsqueeze(1).expand_as(styl_feat) - styleConv = torch.mm(styl_feat, styl_feat.t()).div(sFSize[1] - 1) + #To avoid division by zero + sFdivisor = sFSize[1] - 1. + if sFdivisor == 0.: + sFdivisor = small_num + + styleConv = torch.mm(styl_feat, styl_feat.t()).div(sFdivisor) s_u, s_e, s_v = torch.svd(styleConv, some=False) k_s = sFSize[0] @@ -168,4 +178,4 @@ def is_cuda(self): return next(self.parameters()).is_cuda def forward(self, *input): - pass \ No newline at end of file + pass