diff --git a/.gitignore b/.gitignore
index 7158e06..8c990cc 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,7 +4,7 @@
 *.pkl
 output/*
-train_log/*
+# train_log/*
 
 *.mp4
 test/
 
@@ -12,3 +12,5 @@ test/
 
 *.npz
 *.zip
+
+.DS_Store
\ No newline at end of file
diff --git a/inference_img.py b/inference_img.py
index cee947e..70d4dfd 100644
--- a/inference_img.py
+++ b/inference_img.py
@@ -6,7 +6,7 @@
 import warnings
 warnings.filterwarnings("ignore")
 
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
 torch.set_grad_enabled(False)
 if torch.cuda.is_available():
     torch.backends.cudnn.enabled = True
diff --git a/inference_img_SR.py b/inference_img_SR.py
index 4ecf2ac..cc6ccd5 100644
--- a/inference_img_SR.py
+++ b/inference_img_SR.py
@@ -6,7 +6,7 @@
 import warnings
 warnings.filterwarnings("ignore")
 
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
 torch.set_grad_enabled(False)
 if torch.cuda.is_available():
     torch.backends.cudnn.enabled = True
diff --git a/inference_video.py b/inference_video.py
index 1069fc1..f9a23ed 100644
--- a/inference_video.py
+++ b/inference_video.py
@@ -81,7 +81,7 @@ def transferAudio(sourceVideo, targetVideo):
 if not args.img is None:
     args.png = True
 
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
 torch.set_grad_enabled(False)
 if torch.cuda.is_available():
     torch.backends.cudnn.enabled = True
diff --git a/inference_video_enhance.py b/inference_video_enhance.py
index d3076cd..1203899 100644
--- a/inference_video_enhance.py
+++ b/inference_video_enhance.py
@@ -67,7 +67,7 @@ def transferAudio(sourceVideo, targetVideo):
 if not args.img is None:
     args.png = True
 
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
 torch.set_grad_enabled(False)
 if torch.cuda.is_available():
     torch.backends.cudnn.enabled = True
diff --git a/model/loss.py b/model/loss.py
index 72e5de6..ee84ab5 100644
--- a/model/loss.py
+++ b/model/loss.py
@@ -4,7 +4,7 @@
 import torch.nn.functional as F
 import torchvision.models as models
 
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
 
 
 class EPE(nn.Module):
diff --git a/model/pytorch_msssim/__init__.py b/model/pytorch_msssim/__init__.py
index a4d3032..df91a66 100644
--- a/model/pytorch_msssim/__init__.py
+++ b/model/pytorch_msssim/__init__.py
@@ -3,7 +3,7 @@
 from math import exp
 import numpy as np
 
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
 
 def gaussian(window_size, sigma):
     gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
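Note: every device line in this patch uses the same chained fallback, preferring CUDA, then Apple's Metal Performance Shaders (MPS), then CPU. A minimal sketch of the pattern, assuming a PyTorch build new enough (>= 1.12) to expose torch.backends.mps:

    import torch

    # Prefer CUDA, then MPS (Apple Silicon), then CPU; tensors and modules
    # created with device=... then follow whichever backend was selected.
    device = torch.device(
        "cuda" if torch.cuda.is_available()
        else "mps" if torch.backends.mps.is_available()
        else "cpu"
    )
    x = torch.randn(1, 3, 64, 64, device=device)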
diff --git a/model/warplayer.py b/model/warplayer.py
index 21b0b90..fe13aa3 100644
--- a/model/warplayer.py
+++ b/model/warplayer.py
@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn
 
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
 
 backwarp_tenGrid = {}
 
@@ -19,4 +19,12 @@ def warp(tenInput, tenFlow):
                          tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1)
 
     g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
-    return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True)
+
+    pd = 'border'
+
+    # mps does not support 'border' padding mode, use 'zeros' instead
+    if tenInput.device.type == "mps":
+        pd = 'zeros'
+        g = g.clamp(-1, 1)
+
+    return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode=pd, align_corners=True)
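The warp change above works around a grid_sample limitation: padding_mode='border' is not implemented on the MPS backend. Clamping the normalized sampling grid to [-1, 1] keeps every coordinate inside the image, so with align_corners=True the 'zeros' mode produces the same samples 'border' would for out-of-range flow. A small self-check of that equivalence, run on CPU just for illustration:

    import torch
    import torch.nn.functional as F

    inp = torch.arange(16.).view(1, 1, 4, 4)
    g = torch.full((1, 4, 4, 2), 1.5)  # deliberately out-of-range grid
    out_border = F.grid_sample(inp, g, mode='bilinear', padding_mode='border', align_corners=True)
    out_zeros = F.grid_sample(inp, g.clamp(-1, 1), mode='bilinear', padding_mode='zeros', align_corners=True)
    assert torch.allclose(out_border, out_zeros)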
diff --git a/train_log/IFNet_HDv3.py b/train_log/IFNet_HDv3.py
new file mode 100644
index 0000000..6601c9c
--- /dev/null
+++ b/train_log/IFNet_HDv3.py
@@ -0,0 +1,169 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from model.warplayer import warp
+# from train_log.refine import *
+
+device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
+
+def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
+    return nn.Sequential(
+        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
+                  padding=padding, dilation=dilation, bias=True),
+        nn.LeakyReLU(0.2, True)
+    )
+
+def conv_bn(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
+    return nn.Sequential(
+        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
+                  padding=padding, dilation=dilation, bias=False),
+        nn.BatchNorm2d(out_planes),
+        nn.LeakyReLU(0.2, True)
+    )
+
+class Head(nn.Module):
+    def __init__(self):
+        super(Head, self).__init__()
+        self.cnn0 = nn.Conv2d(3, 16, 3, 2, 1)
+        self.cnn1 = nn.Conv2d(16, 16, 3, 1, 1)
+        self.cnn2 = nn.Conv2d(16, 16, 3, 1, 1)
+        self.cnn3 = nn.ConvTranspose2d(16, 16, 4, 2, 1)
+        self.relu = nn.LeakyReLU(0.2, True)
+
+    def forward(self, x, feat=False):
+        x0 = self.cnn0(x)
+        x = self.relu(x0)
+        x1 = self.cnn1(x)
+        x = self.relu(x1)
+        x2 = self.cnn2(x)
+        x = self.relu(x2)
+        x3 = self.cnn3(x)
+        if feat:
+            return [x0, x1, x2, x3]
+        return x3
+
+class ResConv(nn.Module):
+    def __init__(self, c, dilation=1):
+        super(ResConv, self).__init__()
+        self.conv = nn.Conv2d(c, c, 3, 1, dilation, dilation=dilation, groups=1)
+        self.beta = nn.Parameter(torch.ones((1, c, 1, 1)), requires_grad=True)
+        self.relu = nn.LeakyReLU(0.2, True)
+
+    def forward(self, x):
+        return self.relu(self.conv(x) * self.beta + x)
+
+class IFBlock(nn.Module):
+    def __init__(self, in_planes, c=64):
+        super(IFBlock, self).__init__()
+        self.conv0 = nn.Sequential(
+            conv(in_planes, c//2, 3, 2, 1),
+            conv(c//2, c, 3, 2, 1),
+        )
+        self.convblock = nn.Sequential(
+            ResConv(c),
+            ResConv(c),
+            ResConv(c),
+            ResConv(c),
+            ResConv(c),
+            ResConv(c),
+            ResConv(c),
+            ResConv(c),
+        )
+        self.lastconv = nn.Sequential(
+            nn.ConvTranspose2d(c, 4*13, 4, 2, 1),
+            nn.PixelShuffle(2)
+        )
+
+    def forward(self, x, flow=None, scale=1):
+        x = F.interpolate(x, scale_factor=1. / scale, mode="bilinear", align_corners=False)
+        if flow is not None:
+            flow = F.interpolate(flow, scale_factor=1. / scale, mode="bilinear", align_corners=False) * 1. / scale
+            x = torch.cat((x, flow), 1)
+        feat = self.conv0(x)
+        feat = self.convblock(feat)
+        tmp = self.lastconv(feat)
+        tmp = F.interpolate(tmp, scale_factor=scale, mode="bilinear", align_corners=False)
+        flow = tmp[:, :4] * scale
+        mask = tmp[:, 4:5]
+        feat = tmp[:, 5:]
+        return flow, mask, feat
+
+class IFNet(nn.Module):
+    def __init__(self):
+        super(IFNet, self).__init__()
+        self.block0 = IFBlock(7+32, c=192)
+        self.block1 = IFBlock(8+4+8+32, c=128)
+        self.block2 = IFBlock(8+4+8+32, c=96)
+        self.block3 = IFBlock(8+4+8+32, c=64)
+        self.block4 = IFBlock(8+4+8+32, c=32)
+        self.encode = Head()
+
+        # not used during inference
+        '''
+        self.teacher = IFBlock(8+4+8+3+32, c=64)
+        self.caltime = nn.Sequential(
+            nn.Conv2d(32+9, 8, 3, 2, 1),
+            nn.LeakyReLU(0.2, True),
+            nn.Conv2d(32, 64, 3, 2, 1),
+            nn.LeakyReLU(0.2, True),
+            nn.Conv2d(64, 64, 3, 1, 1),
+            nn.LeakyReLU(0.2, True),
+            nn.Conv2d(64, 64, 3, 1, 1),
+            nn.LeakyReLU(0.2, True),
+            nn.Conv2d(64, 1, 3, 1, 1),
+            nn.Sigmoid()
+        )
+        '''
+
+    # default scale_list has five entries to match the five blocks below
+    def forward(self, x, timestep=0.5, scale_list=[16, 8, 4, 2, 1], training=False, fastmode=True, ensemble=False):
+        if training == False:
+            channel = x.shape[1] // 2
+            img0 = x[:, :channel]
+            img1 = x[:, channel:]
+        if not torch.is_tensor(timestep):
+            timestep = (x[:, :1].clone() * 0 + 1) * timestep
+        else:
+            timestep = timestep.repeat(1, 1, img0.shape[2], img0.shape[3])
+        f0 = self.encode(img0[:, :3])
+        f1 = self.encode(img1[:, :3])
+        flow_list = []
+        merged = []
+        mask_list = []
+        warped_img0 = img0
+        warped_img1 = img1
+        flow = None
+        mask = None
+        loss_cons = 0
+        block = [self.block0, self.block1, self.block2, self.block3, self.block4]
+        for i in range(5):
+            if flow is None:
+                flow, mask, feat = block[i](torch.cat((img0[:, :3], img1[:, :3], f0, f1, timestep), 1), None, scale=scale_list[i])
+                if ensemble:
+                    print("warning: ensemble is not supported since RIFEv4.21")
+            else:
+                wf0 = warp(f0, flow[:, :2])
+                wf1 = warp(f1, flow[:, 2:4])
+                fd, m0, feat = block[i](torch.cat((warped_img0[:, :3], warped_img1[:, :3], wf0, wf1, timestep, mask, feat), 1), flow, scale=scale_list[i])
+                if ensemble:
+                    print("warning: ensemble is not supported since RIFEv4.21")
+                else:
+                    mask = m0
+                flow = flow + fd
+            mask_list.append(mask)
+            flow_list.append(flow)
+            warped_img0 = warp(img0, flow[:, :2])
+            warped_img1 = warp(img1, flow[:, 2:4])
+            merged.append((warped_img0, warped_img1))
+        mask = torch.sigmoid(mask)
+        merged[4] = (warped_img0 * mask + warped_img1 * (1 - mask))
+        if not fastmode:
+            print('contextnet is removed')
+            '''
+            c0 = self.contextnet(img0, flow[:, :2])
+            c1 = self.contextnet(img1, flow[:, 2:4])
+            tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
+            res = tmp[:, :3] * 2 - 1
+            merged[4] = torch.clamp(merged[4] + res, 0, 1)
+            '''
+        return flow_list, mask_list[4], merged
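IFNet_HDv3 is the RIFE v4.25 coarse-to-fine interpolator: five IFBlocks refine a 4-channel flow and a blend mask from scale 16 down to scale 1, and merged[4] is the final interpolated frame. A minimal shape-level usage sketch, assuming frames in [0, 1] with sides divisible by 64 (random weights here; real output needs the flownet.pkl checkpoint loaded via RIFE_HDv3 below):

    import torch
    from model.warplayer import device  # cuda / mps / cpu, as selected above
    from train_log.IFNet_HDv3 import IFNet

    net = IFNet().to(device).eval()
    img0 = torch.rand(1, 3, 256, 256, device=device)
    img1 = torch.rand(1, 3, 256, 256, device=device)
    with torch.no_grad():
        flow_list, mask, merged = net(torch.cat((img0, img1), 1), timestep=0.5)
    mid_frame = merged[4]  # blend of the two warped inputs at the finest level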
diff --git a/train_log/RIFE_HDv3.py b/train_log/RIFE_HDv3.py
new file mode 100644
index 0000000..ee49158
--- /dev/null
+++ b/train_log/RIFE_HDv3.py
@@ -0,0 +1,89 @@
+import torch
+import torch.nn as nn
+import numpy as np
+from torch.optim import AdamW
+import torch.optim as optim
+import itertools
+from model.warplayer import warp
+from torch.nn.parallel import DistributedDataParallel as DDP
+from train_log.IFNet_HDv3 import *
+import torch.nn.functional as F
+from model.loss import *
+
+device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
+
+class Model:
+    def __init__(self, local_rank=-1):
+        self.flownet = IFNet()
+        self.device()
+        self.optimG = AdamW(self.flownet.parameters(), lr=1e-6, weight_decay=1e-4)
+        self.epe = EPE()
+        self.version = 4.25
+        # self.vgg = VGGPerceptualLoss().to(device)
+        self.sobel = SOBEL()
+        if local_rank != -1:
+            self.flownet = DDP(self.flownet, device_ids=[local_rank], output_device=local_rank)
+
+    def train(self):
+        self.flownet.train()
+
+    def eval(self):
+        self.flownet.eval()
+
+    def device(self):
+        self.flownet.to(device)
+
+    def load_model(self, path, rank=0):
+        def convert(param):
+            if rank == -1:
+                return {
+                    k.replace("module.", ""): v
+                    for k, v in param.items()
+                    if "module." in k
+                }
+            else:
+                return param
+        if rank <= 0:
+            if torch.cuda.is_available():
+                self.flownet.load_state_dict(convert(torch.load('{}/flownet.pkl'.format(path))), False)
+            else:
+                self.flownet.load_state_dict(convert(torch.load('{}/flownet.pkl'.format(path), map_location='cpu')), False)
+
+    def save_model(self, path, rank=0):
+        if rank == 0:
+            torch.save(self.flownet.state_dict(), '{}/flownet.pkl'.format(path))
+
+    def inference(self, img0, img1, timestep=0.5, scale=1.0):
+        imgs = torch.cat((img0, img1), 1)
+        scale_list = [16/scale, 8/scale, 4/scale, 2/scale, 1/scale]
+        flow, mask, merged = self.flownet(imgs, timestep, scale_list)
+        return merged[-1]
+
+    def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None):
+        for param_group in self.optimG.param_groups:
+            param_group['lr'] = learning_rate
+        img0 = imgs[:, :3]
+        img1 = imgs[:, 3:]
+        if training:
+            self.train()
+        else:
+            self.eval()
+        scale = [16, 8, 4, 2, 1]
+        flow, mask, merged = self.flownet(torch.cat((imgs, gt), 1), scale_list=scale, training=training)
+        loss_l1 = (merged[-1] - gt).abs().mean()
+        loss_smooth = self.sobel(flow[-1], flow[-1]*0).mean()
+        loss_cons = 0  # consistency (distillation) loss is not computed in this release
+        # loss_vgg = self.vgg(merged[-1], gt)
+        if training:
+            self.optimG.zero_grad()
+            loss_G = loss_l1 + loss_cons + loss_smooth * 0.1
+            loss_G.backward()
+            self.optimG.step()
+        else:
+            flow_teacher = flow[2]
+        return merged[-1], {
+            'mask': mask,
+            'flow': flow[-1][:, :2],
+            'loss_l1': loss_l1,
+            'loss_cons': loss_cons,
+            'loss_smooth': loss_smooth,
+        }
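RIFE_HDv3.Model is the thin wrapper the inference scripts drive: load_model reads flownet.pkl (falling back to a CPU map_location when CUDA is absent) and inference builds the five-entry scale_list from the user's scale factor. A hedged usage sketch, mirroring how inference_img.py calls it and assuming the v4.25 checkpoint sits in train_log/:

    import torch
    from model.warplayer import device
    from train_log.RIFE_HDv3 import Model

    model = Model()
    model.load_model('train_log')  # expects train_log/flownet.pkl
    model.eval()
    img0 = torch.rand(1, 3, 256, 256, device=device)  # stand-in frames in [0, 1]
    img1 = torch.rand(1, 3, 256, 256, device=device)
    with torch.no_grad():
        mid = model.inference(img0, img1, timestep=0.5)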
diff --git a/train_log/refine.py b/train_log/refine.py
new file mode 100644
index 0000000..41b648e
--- /dev/null
+++ b/train_log/refine.py
@@ -0,0 +1,90 @@
+import torch
+import torch.nn as nn
+import numpy as np
+from torch.optim import AdamW
+import torch.optim as optim
+import itertools
+from model.warplayer import warp
+from torch.nn.parallel import DistributedDataParallel as DDP
+import torch.nn.functional as F
+
+device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
+
+def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
+    return nn.Sequential(
+        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
+                  padding=padding, dilation=dilation, bias=True),
+        nn.LeakyReLU(0.2, True)
+    )
+
+def conv_woact(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
+    return nn.Sequential(
+        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
+                  padding=padding, dilation=dilation, bias=True),
+    )
+
+def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
+    return nn.Sequential(
+        torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1, bias=True),
+        nn.LeakyReLU(0.2, True)
+    )
+
+class Conv2(nn.Module):
+    def __init__(self, in_planes, out_planes, stride=2):
+        super(Conv2, self).__init__()
+        self.conv1 = conv(in_planes, out_planes, 3, stride, 1)
+        self.conv2 = conv(out_planes, out_planes, 3, 1, 1)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.conv2(x)
+        return x
+
+c = 16
+class Contextnet(nn.Module):
+    def __init__(self):
+        super(Contextnet, self).__init__()
+        self.conv1 = Conv2(3, c)
+        self.conv2 = Conv2(c, 2*c)
+        self.conv3 = Conv2(2*c, 4*c)
+        self.conv4 = Conv2(4*c, 8*c)
+
+    def forward(self, x, flow):
+        x = self.conv1(x)
+        flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5
+        f1 = warp(x, flow)
+        x = self.conv2(x)
+        flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5
+        f2 = warp(x, flow)
+        x = self.conv3(x)
+        flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5
+        f3 = warp(x, flow)
+        x = self.conv4(x)
+        flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5
+        f4 = warp(x, flow)
+        return [f1, f2, f3, f4]
+
+class Unet(nn.Module):
+    def __init__(self):
+        super(Unet, self).__init__()
+        self.down0 = Conv2(17, 2*c)
+        self.down1 = Conv2(4*c, 4*c)
+        self.down2 = Conv2(8*c, 8*c)
+        self.down3 = Conv2(16*c, 16*c)
+        self.up0 = deconv(32*c, 8*c)
+        self.up1 = deconv(16*c, 4*c)
+        self.up2 = deconv(8*c, 2*c)
+        self.up3 = deconv(4*c, c)
+        self.conv = nn.Conv2d(c, 3, 3, 1, 1)
+
+    def forward(self, img0, img1, warped_img0, warped_img1, mask, flow, c0, c1):
+        s0 = self.down0(torch.cat((img0, img1, warped_img0, warped_img1, mask, flow), 1))
+        s1 = self.down1(torch.cat((s0, c0[0], c1[0]), 1))
+        s2 = self.down2(torch.cat((s1, c0[1], c1[1]), 1))
+        s3 = self.down3(torch.cat((s2, c0[2], c1[2]), 1))
+        x = self.up0(torch.cat((s3, c0[3], c1[3]), 1))
+        x = self.up1(torch.cat((x, s2), 1))
+        x = self.up2(torch.cat((x, s1), 1))
+        x = self.up3(torch.cat((x, s0), 1))
+        x = self.conv(x)
+        return torch.sigmoid(x)
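One detail worth calling out in Contextnet above: flow is bilinearly downsampled and also multiplied by 0.5 at every pyramid level. Flow vectors are measured in pixels, so halving the spatial resolution must halve the displacement magnitudes too, or the warp would overshoot. The SAFA code below uses the same convention. In isolation:

    import torch
    import torch.nn.functional as F

    flow = torch.randn(1, 2, 64, 64)  # per-pixel displacements, in pixels
    # Half the resolution -> half the pixel displacement magnitudes.
    flow_half = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5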
diff --git a/train_log_SAFA/flownet.py b/train_log_SAFA/flownet.py
new file mode 100644
index 0000000..ec7f270
--- /dev/null
+++ b/train_log_SAFA/flownet.py
@@ -0,0 +1,171 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torchvision import models
+from train_log_SAFA.warplayer import warp
+from train_log_SAFA.head import Head
+
+device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
+
+def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1, groups=1):
+    return nn.Sequential(
+        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
+                  padding=padding, dilation=dilation, bias=True, groups=groups),
+        nn.LeakyReLU(0.2, True)
+    )
+
+def conv_bn(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
+    return nn.Sequential(
+        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
+                  padding=padding, dilation=dilation, bias=False),
+        nn.BatchNorm2d(out_planes),
+        nn.LeakyReLU(0.2, True)
+    )
+
+class Resblock(nn.Module):
+    def __init__(self, c, dilation=1):
+        super(Resblock, self).__init__()
+        self.conv = nn.Sequential(
+            nn.Conv2d(c, 2*c, 3, 2, dilation, dilation=dilation, groups=1),
+            nn.LeakyReLU(0.2, True),
+            nn.ConvTranspose2d(2*c, c, 4, 2, 1)
+        )
+        self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
+        self.prelu = nn.LeakyReLU(0.2, True)
+
+    def forward(self, x):
+        y = self.conv(x)
+        return self.prelu(y * self.beta + x)
+
+class RoundSTE(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x):
+        y = torch.bernoulli(x)
+        return y
+
+    @staticmethod
+    def backward(ctx, grad):
+        return grad
+
+class RecurrentBlock(nn.Module):
+    def __init__(self, c, dilation=1, depth=6):
+        super(RecurrentBlock, self).__init__()
+        self.conv_stem = conv(3*c+6+1, c, 3, 1, 1, groups=1)
+        self.conv_backbone = torch.nn.ModuleList([])
+        self.depth = depth
+        for i in range(depth):
+            self.conv_backbone.append(Resblock(c, dilation))
+
+    def forward(self, x, i0, i1, flow, timestep, convflow, getscale):
+        flow_down = F.interpolate(flow, scale_factor=0.5, mode="bilinear")
+        i0 = warp(i0, flow_down[:, :2] * 0.5)
+        i1 = warp(i1, flow_down[:, 2:4] * 0.5)
+        x = torch.cat((x, flow_down, i0, i1, timestep), 1)
+        scale = RoundSTE.apply(getscale(x)).unsqueeze(2).unsqueeze(3)
+        feat = 0
+        if scale.shape[0] != 1 or (scale[:, 0:1].mean() > 0.5 and scale[:, 1:2].mean() > 0.5):
+            x0 = self.conv_stem(x)
+            for i in range(self.depth):
+                x0 = self.conv_backbone[i](x0)
+            feat = feat + x0 * scale[:, 0:1] * scale[:, 1:2]
+
+        if scale.shape[0] != 1 or (scale[:, 0:1].mean() < 0.5 and scale[:, 1:2].mean() > 0.5):
+            x1 = self.conv_stem(F.interpolate(x, scale_factor=0.5, mode="bilinear"))
+            for i in range(self.depth):
+                x1 = self.conv_backbone[i](x1)
+            feat = feat + F.interpolate(x1, scale_factor=2.0, mode="bilinear") * (1 - scale[:, 0:1]) * scale[:, 1:2]
+
+        if scale.shape[0] != 1 or scale[:, 1:2].mean() < 0.5:
+            x2 = self.conv_stem(F.interpolate(x, scale_factor=0.25, mode="bilinear"))
+            for i in range(self.depth):
+                x2 = self.conv_backbone[i](x2)
+            feat = feat + F.interpolate(x2, scale_factor=4.0, mode="bilinear") * (1 - scale[:, 1:2])
+        return feat, convflow(feat) + flow, i0, i1, scale
+
+class Flownet(nn.Module):
+    def __init__(self, block_num, c=64):
+        super(Flownet, self).__init__()
+        self.convimg = nn.Sequential(
+            nn.Conv2d(3, 32, 3, 2, 1),
+            nn.LeakyReLU(0.2, True),
+            nn.Conv2d(32, 32, 3, 1, 1),
+            nn.LeakyReLU(0.2, True),
+            nn.Conv2d(32, 32, 3, 1, 1),
+            nn.LeakyReLU(0.2, True),
+            nn.Conv2d(32, c, 3, 1, 1),
+        )
+        self.convblock = torch.nn.ModuleList([])
+        self.block_num = block_num
+        self.convflow = nn.Sequential(
+            nn.Conv2d(c, 4*6, 3, 1, 1),
+            nn.PixelShuffle(2)
+        )
+        self.getscale = nn.Sequential(
+            conv(3*c+6+1, c, 1, 1, 0),
+            conv(c, c, 1, 2, 0),
+            nn.AdaptiveAvgPool2d((1, 1)),
+            nn.Flatten(),
+            nn.Linear(c, 2),
+            nn.Sigmoid()
+        )
+        for i in range(self.block_num):
+            self.convblock.append(RecurrentBlock(c, 1, 2))
+
+    def extract_feat(self, x):
+        i0 = self.convimg(x[:, :3])
+        i1 = self.convimg(x[:, 3:6])
+        return i0, i1
+
+    def forward(self, i0, i1, feat, timestep, flow):
+        flow_list = []
+        feat_list = []
+        scale_list = []
+        for i in range(self.block_num):
+            feat, flow, w0, w1, scale = self.convblock[i](feat, i0, i1, flow, timestep, self.convflow, self.getscale)
+            flow_list.append(flow)
+            feat_list.append(feat)
+            scale_list.append(scale)
+        return flow_list, feat_list, torch.cat(scale_list, 1)
+
+class SAFA(nn.Module):
+    def __init__(self):
+        super(SAFA, self).__init__()
+        c = 96
+        self.block = Flownet(4, c=c)
+        self.shuffle = conv(2*c, c, 3, 1, 1, groups=1)
+        self.lastconv0 = nn.Sequential(
+            conv(4*c, c, 3, 1, 1),
+            Resblock(c),
+            Resblock(c),
+            Resblock(c),
+            Resblock(c),
+            Resblock(c),
+            Resblock(c),
+            Resblock(c),
+            Resblock(c),
+        )
+        self.lastconv1 = nn.Sequential(
+            conv(5*c, 2*c, 3, 1, 1),
+            nn.Conv2d(2*c, 12, 3, 1, 1),
+            nn.PixelShuffle(2),
+        )
+
+    def inference(self, lowres, timestep=None):
+        merged = []
+        i0, i1 = self.block.extract_feat(lowres)
+        timestep = (lowres[:, :1] * 0).detach()
+        timestep = F.interpolate(timestep, scale_factor=0.5, mode="bilinear")
+        for i in range(2):
+            if i == 1:
+                tmp = i0
+                i0 = i1
+                i1 = tmp
+            feat = self.shuffle(torch.cat((i0, i1), 1))
+            flow_list, feat_list, soft_scale = self.block(i0, i1, feat, timestep, (lowres[:, :6] * 0).detach())
+            flow_sum = flow_list[-1]
+            flow_down = F.interpolate(flow_sum, scale_factor=0.5, mode="bilinear")
+            w1 = warp(i1, flow_down[:, 2:4] * 0.5)
+            lastfeat = torch.cat((i0, w1, feat_list[-1], feat_list[-3]), 1)
+            res = self.lastconv1(torch.cat((self.lastconv0(lastfeat), lastfeat), 1))
+            merged.append(torch.clamp(res, 0, 1))
+        return merged
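RoundSTE above is a straight-through estimator: the forward pass draws a hard 0/1 Bernoulli sample from the gate probabilities, while the backward pass lets gradients through unchanged, which is what allows RecurrentBlock to pick a processing resolution per input and still train the gate. A self-contained sketch of the trick:

    import torch

    class STE(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            return torch.bernoulli(x)  # hard {0, 1} sample

        @staticmethod
        def backward(ctx, grad):
            return grad                # identity gradient, as if forward were y = x

    p = torch.full((4, 2), 0.7, requires_grad=True)
    gate = STE.apply(p)
    gate.sum().backward()
    assert torch.all(p.grad == 1)      # gradient flowed through the hard sample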
diff --git a/train_log_SAFA/head.py b/train_log_SAFA/head.py
new file mode 100644
index 0000000..6b782ce
--- /dev/null
+++ b/train_log_SAFA/head.py
@@ -0,0 +1,63 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torchvision import models
+from model.warplayer import *
+
+device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
+
+def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
+    return nn.Sequential(
+        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
+                  padding=padding, dilation=dilation, bias=True),
+        nn.PReLU(out_planes)
+    )
+
+def conv_bn(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
+    return nn.Sequential(
+        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
+                  padding=padding, dilation=dilation, bias=False),
+        nn.BatchNorm2d(out_planes),
+        nn.PReLU(out_planes)
+    )
+
+class MeanShift(nn.Conv2d):
+    def __init__(self, data_mean, data_std, data_range=1, norm=True):
+        """norm (bool): normalize/denormalize the stats"""
+        c = len(data_mean)
+        super(MeanShift, self).__init__(c, c, kernel_size=1)
+        std = torch.Tensor(data_std).to(device)
+        self.weight.data = torch.eye(c).view(c, c, 1, 1).to(device)
+        if norm:
+            self.weight.data.div_(std.view(c, 1, 1, 1))
+            self.bias.data = -1 * data_range * torch.Tensor(data_mean).to(device)
+            self.bias.data.div_(std)
+        else:
+            self.weight.data.mul_(std.view(c, 1, 1, 1))
+            self.bias.data = data_range * torch.Tensor(data_mean).to(device)
+        self.requires_grad = False
+
+class Head(nn.Module):
+    def __init__(self, c):
+        super(Head, self).__init__()
+        model = models.resnet18(pretrained=False)
+        self.cnn0 = nn.Sequential(*nn.ModuleList(model.children())[:3])
+        self.cnn1 = nn.Sequential(
+            *list(model.children())[3:5],
+        )
+        self.cnn2 = nn.Sequential(
+            *list(model.children())[5:6],
+        )
+        self.out0 = nn.Conv2d(64, c, 1, 1, 0)
+        self.out1 = nn.Conv2d(64, c, 1, 1, 0)
+        self.out2 = nn.Conv2d(128, c, 1, 1, 0)
+        self.normalize = MeanShift([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], norm=True).to(device)
+
+    def forward(self, x):
+        x = self.normalize(x)
+        f0 = self.cnn0(x)
+        f1 = self.cnn1(f0)
+        f2 = self.cnn2(f1)
+        f0 = self.out0(f0)
+        f1 = F.interpolate(self.out1(f1), scale_factor=2.0, mode="bilinear")
+        f2 = F.interpolate(self.out2(f2), scale_factor=4.0, mode="bilinear")
+        return f0 + f1 + f2
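MeanShift folds the ImageNet normalization into a fixed 1x1 convolution, so the preprocessing travels with the module to whatever device Head lives on. With norm=True the layer computes the standard (x - mean) / std; the broadcastable equivalent:

    import torch

    mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
    x = torch.rand(1, 3, 8, 8)
    x_norm = (x - mean) / std  # what the 1x1 conv's weight and bias encode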
diff --git a/train_log_SAFA/model.py b/train_log_SAFA/model.py
new file mode 100644
index 0000000..26fb21e
--- /dev/null
+++ b/train_log_SAFA/model.py
@@ -0,0 +1,82 @@
+import torch
+import torch.nn as nn
+import numpy as np
+from torch.optim import AdamW
+import torch.optim as optim
+import itertools
+from train_log_SAFA.warplayer import warp
+from torch.nn.parallel import DistributedDataParallel as DDP
+from train_log_SAFA.flownet import *
+import torch.nn.functional as F
+
+device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
+
+class Model:
+    def __init__(self, local_rank=-1):
+        self.flownet = SAFA()
+        self.optimG = AdamW(self.flownet.parameters())
+        self.device()
+        if local_rank != -1:
+            self.flownet = DDP(self.flownet, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)
+
+    def train(self):
+        self.flownet.train()
+
+    def eval(self):
+        self.flownet.eval()
+
+    def device(self):
+        self.flownet.to(device)
+
+    def inference(self, i0, i1, timestep):
+        return self.flownet.inference(torch.cat((i0, i1), 1), timestep)
+
+    def load_model(self, path, rank=0):
+        def convert(param):
+            return {
+                k.replace("module.", ""): v
+                for k, v in param.items()
+                if "module." in k
+            }
+
+        if device == torch.device('cpu'):
+            self.flownet.load_state_dict(convert(torch.load('{}/flownet.pkl'.format(path), map_location=torch.device('cpu'))))
+        elif device == torch.device('mps'):
+            self.flownet.load_state_dict(convert(torch.load('{}/flownet.pkl'.format(path), map_location=torch.device('mps'))))
+        elif rank <= 0:
+            self.flownet.load_state_dict(convert(torch.load('{}/flownet.pkl'.format(path))))
+
+    def save_model(self, path, rank=0):
+        if rank == 0:
+            torch.save(self.flownet.state_dict(), '{}/flownet.pkl'.format(path))
+
+    def update(self, imgs, lowres, learning_rate=0, timestep=0.5, mul=1, training=True):
+        for param_group in self.optimG.param_groups:
+            param_group['lr'] = learning_rate
+        img0 = imgs[:, :3]
+        img1 = imgs[:, -3:]
+        if training:
+            self.train()
+            for m in self.flownet.modules():
+                if isinstance(m, nn.BatchNorm2d):
+                    m.eval()
+        else:
+            self.eval()
+        flow, scale, merged = self.flownet(lowres, timestep=timestep, training=training)
+        loss_l1 = 0
+        for i in range(3):
+            loss_l1 += (imgs[:, i*3:i*3+3] - merged[i]).abs().mean()
+        if training:
+            self.optimG.zero_grad()
+            loss_G = loss_l1
+            loss_G.backward()
+            torch.nn.utils.clip_grad_norm_(self.flownet.parameters(), 1.0)
+            self.optimG.step()
+        return merged, {
+            'scale': scale,
+            'flow': flow[:, :2],
+            'loss_l1': loss_l1,
+        }
diff --git a/train_log_SAFA/warplayer.py b/train_log_SAFA/warplayer.py
new file mode 100644
index 0000000..e8d353f
--- /dev/null
+++ b/train_log_SAFA/warplayer.py
@@ -0,0 +1,28 @@
+import torch
+import torch.nn as nn
+from torch.autograd import Variable
+import os
+import numpy as np
+
+device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
+
+backwarp_tenGrid = {}
+
+def warp(tenInput, tenFlow, mode='bilinear'):
+    k = (str(tenFlow.device), str(tenFlow.size()))
+    if k not in backwarp_tenGrid:
+        tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3]).view(1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
+        tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2]).view(1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
+        backwarp_tenGrid[k] = torch.cat([tenHorizontal, tenVertical], 1).to(device)
+
+    tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
+                         tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1)
+
+    g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
+
+    pd = 'border'
+
+    # mps does not support 'border' padding mode, use 'zeros' instead
+    if tenInput.device.type == "mps":
+        pd = 'zeros'
+        g = g.clamp(-1, 1)
+
+    return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode=mode, padding_mode=pd, align_corners=True)
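Both warplayer variants cache the identity sampling grid in backwarp_tenGrid, keyed by device and tensor size, so the linspace grids are built once per resolution rather than once per frame. A minimal sketch of the same caching idea, with hypothetical names:

    import torch

    _grid_cache = {}

    def base_grid(n, h, w, device):
        key = (str(device), n, h, w)
        if key not in _grid_cache:
            gx = torch.linspace(-1.0, 1.0, w, device=device).view(1, 1, 1, w).expand(n, 1, h, w)
            gy = torch.linspace(-1.0, 1.0, h, device=device).view(1, 1, h, 1).expand(n, 1, h, w)
            _grid_cache[key] = torch.cat([gx, gy], 1)  # (n, 2, h, w) identity grid
        return _grid_cache[key]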