Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -45,45 +45,47 @@
#val_list = '/wrk/yuzitong/DONOTREMOVE/CVPRW2020/4@3_dev_res.txt'


def contrast_depth_conv(input):
def contrast_depth_conv(input, kernel_filter):
''' compute contrast depth in both of (out, label) '''
'''
input 32x32
output 8x32x32
'''


kernel_filter_list =[
[[1,0,0],[0,-1,0],[0,0,0]], [[0,1,0],[0,-1,0],[0,0,0]], [[0,0,1],[0,-1,0],[0,0,0]],
[[0,0,0],[1,-1,0],[0,0,0]], [[0,0,0],[0,-1,1],[0,0,0]],
[[0,0,0],[0,-1,0],[1,0,0]], [[0,0,0],[0,-1,0],[0,1,0]], [[0,0,0],[0,-1,0],[0,0,1]]
]

kernel_filter = np.array(kernel_filter_list, np.float32)

kernel_filter = torch.from_numpy(kernel_filter.astype(np.float)).float().cuda()
# weights (in_channel, out_channel, kernel, kernel)
kernel_filter = kernel_filter.unsqueeze(dim=1)

input = input.unsqueeze(dim=1).expand(input.shape[0], 8, input.shape[1],input.shape[2])

contrast_depth = F.conv2d(input, weight=kernel_filter, groups=8) # depthwise conv

return contrast_depth


class Contrast_depth_loss(nn.Module): # Pearson range [-1, 1] so if < 0, abs|loss| ; if >0, 1- loss
def __init__(self):
super(Contrast_depth_loss,self).__init__()

kernel_filter_list =[
[[1,0,0],[0,-1,0],[0,0,0]], [[0,1,0],[0,-1,0],[0,0,0]], [[0,0,1],[0,-1,0],[0,0,0]],
[[0,0,0],[1,-1,0],[0,0,0]], [[0,0,0],[0,-1,1],[0,0,0]],
[[0,0,0],[0,-1,0],[1,0,0]], [[0,0,0],[0,-1,0],[0,1,0]], [[0,0,0],[0,-1,0],[0,0,1]]
]

kernel_filter = np.array(kernel_filter_list, np.float32)
kernel_filter = torch.from_numpy(kernel_filter.astype(np.float)).float().cuda()
kernel_filter = kernel_filter.unsqueeze(dim=1)

self.kernel_filter = kernel_filter

return


def forward(self, out, label):
'''
compute contrast depth in both of (out, label),
then get the loss of them
tf.atrous_convd match tf-versions: 1.4
'''
contrast_out = contrast_depth_conv(out)
contrast_label = contrast_depth_conv(label)
contrast_out = contrast_depth_conv(out, self.kernel_filter)
contrast_label = contrast_depth_conv(label, self.kernel_filter)


criterion_MSE = nn.MSELoss().cuda()
Expand Down