-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
287 lines (234 loc) · 9.97 KB
/
utils.py
File metadata and controls
287 lines (234 loc) · 9.97 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
import os
import shutil
import math
import csv
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torch
from torchvision import transforms
from torchvision.utils import save_image
from pytorch_msssim import ssim as evaluator_SSIM
from pytorch_msssim import ssim_matlab as calc_ssim
##########################
# Training Helper Functions for making main.py clean
##########################
def save_checkpoint(args, state, is_best, epoch):
    """Persist a training checkpoint, keeping only the newest epoch file.

    Args:
        state: checkpoint dict to serialize (epoch, state_dicts, ...).
        is_best: True when this checkpoint has the minimum validation
            loss; it is then also copied to '<modelName>_best.pth'.
        epoch: current epoch number, embedded in the file name.
    """
    # Make sure the destination directory exists.
    makedirs(args.save_path)
    current_path = os.path.join(
        args.save_path, '{}_epoch{}.pth'.format(args.modelName, epoch))
    previous_path = os.path.join(
        args.save_path, '{}_epoch{}.pth'.format(args.modelName, epoch - 1))
    # Drop last epoch's file so only the most recent checkpoint is kept.
    remove_file(previous_path)
    torch.save(state, current_path)
    if is_best:
        # Keep a stable copy of the best-performing checkpoint.
        best_path = os.path.join(
            args.save_path, '{}_best.pth'.format(args.modelName))
        shutil.copyfile(current_path, best_path)
def load_checkpoint(args, checpoint_path, model, optimizer, scheduler):
    """Restore training state from a checkpoint file.

    Args:
        checpoint_path: path of the checkpoint file to load (parameter
            name kept as-is, misspelling included, so keyword callers
            keep working).
        model: model whose parameters are restored in place.
        optimizer: optimizer from the previous training run.
        scheduler: scheduler from the previous training run.

    Returns:
        The same (model, optimizer, scheduler) triple, now restored.
    """
    print('Loading checkpoint from %s' %checpoint_path)
    checkpoint = torch.load(checpoint_path)
    # Resume bookkeeping fields on the shared args namespace.
    args.start_epoch = checkpoint['epoch']
    args.best_psnr = checkpoint['best_psnr']
    args.lr = checkpoint['lr']
    # Restore the state dict of every training component.
    for component, key in ((model, 'state_dict'),
                           (scheduler, 'scheduler'),
                           (optimizer, 'optimizer')):
        component.load_state_dict(checkpoint[key])
    return model, optimizer, scheduler
def load_dataset(datasetName, datasetPath, batch_size, val_batch_size, num_workers):
    """Instantiate the train/val splits and DataLoaders for a named dataset.

    Returns:
        (train_set, val_set, train_loader, val_loader)

    Raises:
        NotImplementedError: when datasetName has no registered loader.
    """
    # Dataset modules are imported lazily so only the selected one is loaded.
    if datasetName == 'UCF101':
        from datasets.ucf101 import ucf101
        train_set = ucf101.UCF101(root=datasetPath + 'train1', is_training=True)
        val_set = ucf101.UCF101(root=datasetPath + 'test1', is_training=False)
    elif datasetName == 'Vimeo_90K':
        from datasets.vimeo_90K.vimeo_90K import Vimeo_90K
        train_set = Vimeo_90K(root=datasetPath, is_training=True)
        val_set = Vimeo_90K(root=datasetPath, is_training=False)
    elif datasetName == 'VimeoSepTuplet':
        from datasets.vimeo_90K.vimeo_MVFI import VimeoSepTuplet_MVFI
        train_set = VimeoSepTuplet_MVFI(root=datasetPath, is_training=True)
        val_set = VimeoSepTuplet_MVFI(root=datasetPath, is_training=False)
    elif datasetName == 'MultiVFI':
        from datasets.multi.multi import Multi
        train_set = Multi(root=datasetPath, is_training=True)
        val_set = Multi(root=datasetPath, is_training=False)
    else:
        raise NotImplementedError('Training / Testing for this dataset is not implemented')

    def _make_loader(dataset, size, shuffle):
        # drop_last keeps every batch full-sized in both phases.
        return torch.utils.data.DataLoader(dataset, pin_memory=True,
                                           batch_size=size,
                                           num_workers=num_workers,
                                           shuffle=shuffle, drop_last=True)

    return (train_set, val_set,
            _make_loader(train_set, batch_size, True),
            _make_loader(val_set, val_batch_size, False))
##########################
# PSNR and SSIM Calculation
##########################
class EvalPSNR(object):
    """Peak Signal to Noise Ratio for images with range [0, 255]."""

    def __init__(self):
        self.name = "PSNR"

    @staticmethod
    def __call__(img1, img2):
        # The small epsilon keeps log10 finite for identical images.
        err = (img1 - img2).pow(2).mean() + 1e-8
        return 20 * torch.log10(255.0 / err.sqrt())
def calc_psnr(pred, gt):
diff = (pred - gt).pow(2).mean() + 1e-8
return -10 * math.log10(diff)
##########################
# Evaluations
##########################
class AverageMeter(object):
    """Tracks the latest value plus a running sum/count average."""

    def __init__(self):
        self.reset()

    def reset(self):
        # val: most recent value; avg: sum/count over all updates so far.
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record `val`, weighted as `n` observations."""
        self.val = val
        self.count += n
        self.sum += n * val
        self.avg = self.sum / self.count
def init_losses(loss_str):
    """Parse a 'w1*name1+w2*name2' loss spec into {name: AverageMeter}.

    A 'total' meter for the combined loss is always appended.
    """
    names = []
    for term in loss_str.split('+'):
        # Each term is 'weight*name'; only the name keys the meter.
        _, loss_name = term.split('*')
        names.append(loss_name)
    return {name: AverageMeter() for name in names + ['total']}
def init_meters(loss_str, reset_loss=True):
    """Create the (losses, psnr meter, ssim meter) triple for one epoch.

    When reset_loss is False, loss_str is assumed to already be a dict
    of meters and is passed through unchanged.
    """
    losses = init_losses(loss_str) if reset_loss else loss_str
    return losses, AverageMeter(), AverageMeter()
def quantize(img, rgb_range=255.):
    """Rescale img from [0, rgb_range] to [0, 255] and round to whole values."""
    scale = 255. / rgb_range
    return (img * scale).round()
def eval_metrics(im_pred, im_gt, psnrs, ssims):
    """Accumulate per-image PSNR/SSIM of a batch into the given meters.

    PSNR is computed one image at a time because sum(log) != log(sum).
    """
    for idx in range(im_gt.size()[0]):
        psnrs.update(calc_psnr(im_pred[idx], im_gt[idx]))
        # SSIM expects a batch dimension and inputs clamped to [0, 1].
        pred_i = im_pred[idx].unsqueeze(0).clamp(0, 1)
        gt_i = im_gt[idx].unsqueeze(0).clamp(0, 1)
        ssims.update(calc_ssim(pred_i, gt_i, val_range=1.))
def save_metrics(args, save_path, epoch, loss, psnr, ssim, lr, time, mode='train'):
    """Log epoch metrics: buffer train stats, flush CSV rows on validation.

    Train and val metrics are buffered separately because they cannot be
    written on the same CSV row at the same time.
    """
    if mode == 'train':
        # Stash the train row on args; it is written out during the val pass.
        args.loss_list_data.extend([lr, epoch, loss, time])
    if mode == 'val':
        with open(os.path.join(save_path, 'val_PSNR_SSIM.csv'), 'a', newline='') as f:
            csv.writer(f).writerow([epoch, round(psnr, 5), round(ssim, 5)])
        with open(os.path.join(save_path, 'loss.csv'), 'a', newline='') as f:
            csv.writer(f).writerow(args.loss_list_data)
        args.loss_list_data = []  # reset the buffered train row
##########################
# ETC
##########################
def makedirs(path):
    """Create path, including intermediate directories, if it does not exist.

    Uses exist_ok=True instead of an exists()-then-create pair, which is
    race-free when several processes create the same directory at once.
    """
    os.makedirs(path, exist_ok=True)
def remove_file(path):
    """Delete path if it is an existing file; ignore a missing file.

    EAFP: attempting the remove and suppressing FileNotFoundError avoids
    the race between an exists() check and the actual removal.
    """
    try:
        os.remove(path)
    except FileNotFoundError:
        pass
def count_network_parameters(model):
    """Return the number of trainable (requires_grad) parameters.

    Uses Tensor.numel() so the result is a plain Python int; the previous
    np.prod(p.size()) returned numpy scalars and would yield a float 1.0
    for a 0-d parameter (np.prod of an empty tuple).
    """
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
# Tensorboard
def log_tensorboard(writer, losses, psnr, ssim, lr, epoch, mode='train'):
    """Write one epoch's metrics to a TensorBoard-style writer.

    losses maps loss names to meter objects (only .avg is read). The
    learning rate is logged for mode='train' only.
    """
    scalars = [('Loss/%s/%s' % (name, mode), meter.avg)
               for name, meter in losses.items()]
    scalars.append(('PSNR/%s' % mode, psnr))
    scalars.append(('SSIM/%s' % mode, ssim))
    if mode == 'train':
        scalars.append(('lr', lr))
    for tag, value in scalars:
        writer.add_scalar(tag, value, epoch)
###########################
###### VISUALIZATIONS
###########################
def distribution_pixels(args, im_pred, im_gt, mode="train", title_name="Distribution of normalized pixels"):
    """Plot side-by-side pixel-value histograms of prediction vs ground truth.

    The figure is saved to '<save_path>/<mode>.png' and then shown.

    Bug fix: suptitle is now set BEFORE savefig; previously the title was
    added after saving, so it never appeared in the saved image.
    """
    # Detach to NumPy arrays for matplotlib.
    pred_np = im_pred.cpu().detach().numpy()
    gt_np = im_gt.cpu().detach().numpy()

    # Prediction histogram (density-normalized).
    plt.subplot(1, 2, 1)
    plt.hist(pred_np.ravel(), bins=50, density=True)
    plt.xlabel("pixel values")
    plt.ylabel("relative frequency")
    plt.title("prediction image")

    # Ground-truth histogram.
    plt.subplot(1, 2, 2)
    plt.hist(gt_np.ravel(), bins=50, density=True)
    plt.xlabel("pixel values")
    plt.title("ground truth image")

    # Title must be set before saving so it is included in the PNG.
    plt.suptitle(title_name)
    plt.savefig(os.path.join(args.save_path, mode + ".png"))
    plt.show()
def save_image(img, path):
    """Save a [0, 1] image tensor to disk via PIL.

    NOTE(review): this shadows torchvision.utils.save_image imported at
    the top of the module; within this module the PIL-based version wins.

    Args:
        img: torch Tensor of size (H, W) for grayscale or (C, H, W) for RGB.
        path: destination file path.

    Raises:
        ValueError: for tensors that are neither 2-D nor 3-D. (Previously
            the else branch was `pass`, so `im.save` crashed with an
            UnboundLocalError instead of a clear error.)
    """
    q_im = quantize(img.data.mul(255))
    if len(img.size()) == 2:  # grayscale image
        im = Image.fromarray(q_im.cpu().numpy().astype(np.uint8), 'L')
    elif len(img.size()) == 3:  # channel-first color image -> HWC for PIL
        im = Image.fromarray(q_im.permute(1, 2, 0).cpu().numpy().astype(np.uint8), 'RGB')
    else:
        raise ValueError('save_image expects a 2-D or 3-D tensor, got %d dims' % len(img.size()))
    im.save(path)
def save_batch_images(args, ims_pred, ims_gt, epoch_path, frame_idx):
    """Save every predicted / ground-truth image pair of a batch.

    Files are written to epoch_path as
    '<counter>_im<frame_idx>_out.png' and '<counter>_im<frame_idx>_gt.png'.

    Bug fix: args.out_counter now advances once per image inside the
    loop; previously it was bumped only once per batch, so with batch
    size > 1 every pair after the first overwrote the same two files.
    """
    # Make sure the epoch output directory exists,
    # e.g. ./Results/<run_id>/Result_Images/Epoch_0
    makedirs(epoch_path)
    for j in range(ims_pred.size(0)):
        pred_name = str(args.out_counter) + '_im' + str(frame_idx) + '_out.png'
        gt_name = str(args.out_counter) + '_im' + str(frame_idx) + '_gt.png'
        save_image(ims_pred[j, :, :, :], os.path.join(epoch_path, pred_name))
        save_image(ims_gt[j, :, :, :], os.path.join(epoch_path, gt_name))
        args.out_counter += 1  # unique counter per saved image pair
    # NOTE(review): reference-image export below was never wired up
    # (input1/input2/reverse_normalize are undefined here); kept as a
    # sketch of the intended feature.
    """
    if not args.create_reference_images and epoch_path.split('\\')[2].split('_')[1] == 0:
        reference_path = os.path.join(args.save_path, args.reference_folder)
        makedirs(reference_path)
        input1_name = str(args.out_counter) + '_1.png',
        input2_name = str(args.out_counter) + '_2.png'
        save_image(reverse_normalize(input1[j, :, :, :]), os.path.join(reference_path, input1_name))
        save_image(reverse_normalize(input2[j, :, :, :]), os.path.join(reference_path, input2_name))
    """