diff --git a/README.md b/README.md index 934d2e3..e2177ff 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,6 @@ +# Fred Lu's fork of "Stochastic Video Generation with a Learned Prior" +I modified the code for Emily Denton and Rob Fergus' paper [Stochastic Video Generation with a Learned Prior](https://arxiv.org/abs/1802.07687) to build a predictive model for the [Machine Common Sense](https://github.com/NextCenturyCorporation/MCS) project. The model observes plausible physics events and uses its predictions to detect whether an unseen physics event is plausible; if the event is deemed implausible, it also generates a heatmap of the implausible regions. This repo only contains the training code. The evaluation process has additional algorithms that are not yet included here. + # Stochastic Video Generation with a Learned Prior This is code for the paper [Stochastic Video Generation with a Learned Prior](https://arxiv.org/abs/1802.07687) by Emily Denton and Rob Fergus. See the [project page](https://sites.google.com/view/svglp/) for details and generated video sequences. 
diff --git a/_do_mcs_implausblility_test_gravity.py b/_do_mcs_implausblility_test_gravity.py new file mode 100644 index 0000000..fce2107 --- /dev/null +++ b/_do_mcs_implausblility_test_gravity.py @@ -0,0 +1,486 @@ +import glob + +import cv2 +import torch +import torch.optim as optim +import torch.nn as nn +import argparse +import os +import random +from torch.autograd import Variable +from torch.utils.data import DataLoader +import utils +import itertools +import progressbar +import numpy as np +import matplotlib.pyplot as plt +import json + +parser = argparse.ArgumentParser() +parser.add_argument('--lr', default=0.004, type=float, help='learning rate') +parser.add_argument('--beta1', default=0.9, type=float, help='momentum term for adam') +parser.add_argument('--batch_size', default=1, type=int, help='batch size') +parser.add_argument('--log_dir', default='logs/nonstochastic', help='base directory to save logs') +parser.add_argument('--model_dir', default='', help='base directory to save logs') +parser.add_argument('--name', default='', help='identifier for directory') +parser.add_argument('--data_root', default='data', help='root directory for data') +parser.add_argument('--optimizer', default='adam', help='optimizer to train with') +parser.add_argument('--niter', type=int, default=300, help='number of epochs to train for') +parser.add_argument('--seed', default=1, type=int, help='manual seed') +parser.add_argument('--epoch_size', type=int, default=1000, help='epoch size') +parser.add_argument('--image_width', type=int, default=64, help='the height / width of the input image to network') +parser.add_argument('--channels', default=1, type=int) +parser.add_argument('--use_edge_kernels', action='store_true') +parser.add_argument('--dataset', default='mcs_test', help='dataset to train with') +parser.add_argument('--mcs_task', default='GravitySupportEvaluation', help='mcs task') +parser.add_argument('--n_past', type=int, default=5, help='number of frames to condition 
on') +parser.add_argument('--n_future', type=int, default=55, help='number of frames to predict') +parser.add_argument('--n_eval', type=int, default=30, help='number of frames to predict at eval time') +parser.add_argument('--rnn_size', type=int, default=256, help='dimensionality of hidden layer') +parser.add_argument('--posterior_rnn_layers', type=int, default=1, help='number of layers') +parser.add_argument('--predictor_rnn_layers', type=int, default=2, help='number of layers') +parser.add_argument('--z_dim', type=int, default=10, help='dimensionality of z_t') +parser.add_argument('--g_dim', type=int, default=128, + help='dimensionality of encoder output vector and decoder input vector') +parser.add_argument('--beta', type=float, default=0.0001, help='weighting on KL to prior') +parser.add_argument('--model', default='dcgan', help='model type (dcgan | vgg)') +parser.add_argument('--data_threads', type=int, default=8, help='number of data loading threads') +parser.add_argument('--num_digits', type=int, default=2, help='number of digits for moving mnist') +parser.add_argument('--last_frame_skip', action='store_true', + help='if true, skip connections go between frame t and frame t+t rather than last ground truth frame') +opt = parser.parse_args() +BATCH_SIZE = opt.batch_size +saved_model = None +if opt.model_dir != '': + models = glob.glob(f'{opt.model_dir}/model_*.pth') + latest_model = sorted(models, key=lambda s: int(s[s.rfind('_e') + 2: s.rfind('.pth')]), reverse=True)[0] + print('Loading model ', latest_model) + saved_model = torch.load(latest_model) + niter = opt.niter + dataset = opt.dataset + mcs_task = opt.mcs_task + n_future = opt.n_future + opt = saved_model['opt'] + opt.batch_size = BATCH_SIZE + opt.niter = niter # update number of epochs to train for + opt.dataset = dataset + opt.mcs_task = mcs_task + opt.n_future = n_future +else: + raise ValueError("Please specify the model to load with the --model_dir argument") + +font = cv2.FONT_HERSHEY_SIMPLEX 
+bottomLeftCornerOfText = (30, 30)
+fontScale = 0.6
+fontColor = (0, 0, 0)
+thickness = 1
+
+print("Random Seed: ", opt.seed)
+random.seed(opt.seed)
+torch.manual_seed(opt.seed)
+torch.cuda.manual_seed_all(opt.seed)
+dtype = torch.cuda.FloatTensor
+
+import models.lstm as lstm_models
+
+# The recurrent networks are taken from the loaded checkpoint; only their
+# batch size is overridden with the value given on the command line.
+if opt.model_dir != '':
+    frame_predictor = saved_model['frame_predictor']
+    frame_predictor.batch_size = BATCH_SIZE
+    posterior = saved_model['posterior']
+    posterior.batch_size = BATCH_SIZE
+    prior = saved_model['prior']
+    prior.batch_size = BATCH_SIZE
+else:
+    raise ValueError('Please specify --model_dir')
+
+if opt.model == 'dcgan':
+    if opt.image_width == 64:
+        import models.dcgan_64 as model
+    elif opt.image_width == 128:
+        import models.dcgan_128 as model
+elif opt.model == 'vgg':
+    if opt.image_width == 64:
+        import models.vgg_64 as model
+    elif opt.image_width == 128:
+        import models.vgg_128 as model
+else:
+    raise ValueError('Unknown model: %s' % opt.model)
+
+if opt.model_dir != '':
+    decoder = saved_model['decoder']
+    encoder = saved_model['encoder']
+else:
+    raise ValueError("Please specify the model to load with the --model_dir argument")
+
+# --------- loss functions ------------------------------------
+mse_criterion = nn.MSELoss()
+
+
+def kl_criterion(mu, logvar):
+    """Return the KL divergence of N(mu, exp(logvar)) from a standard normal.
+
+    The divergence is summed over every element of ``mu`` / ``logvar`` and
+    then divided by ``opt.batch_size``.
+    """
+    # KL = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
+    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
+    KLD /= opt.batch_size
+    return KLD
+
+
+# --------- transfer to gpu ------------------------------------
+frame_predictor.cuda()
+posterior.cuda()
+# prior is called on CUDA tensors next to the other networks in
+# do_implasubility_test, so it has to be moved to the GPU like its siblings.
+prior.cuda()
+encoder.cuda()
+decoder.cuda()
+mse_criterion.cuda()
+
+print(opt)
+
+# --------- load a dataset ------------------------------------
+train_data, test_data = utils.load_dataset(opt, sequential=True, implausible=False)
+
+train_loader = DataLoader(train_data,
+                          num_workers=opt.data_threads,
+                          batch_size=opt.batch_size,
+                          shuffle=False,
+                          drop_last=True,
+                          pin_memory=True, )
+test_loader = DataLoader(test_data,
+                         num_workers=opt.data_threads,
+
batch_size=opt.batch_size, + shuffle=False, + drop_last=True, + pin_memory=True) + + +def get_training_batch(): + while True: + for sequence, labels in train_loader: + batch = utils.normalize_data(opt, dtype, sequence) + yield batch, labels + + +training_batch_generator = get_training_batch() + + +def get_testing_batch(): + while True: + for sequence, labels in test_loader: + batch = utils.normalize_data(opt, dtype, sequence) + yield batch, labels + + +testing_batch_generator = get_testing_batch() + + +# --------- plotting funtions ------------------------------------ +def plot(x, epoch): + nsample = 1 + gen_seq = [[] for _ in range(nsample)] + gt_seq = [x[i] for i in range(len(x))] + + h_seq = [encoder(x[i]) for i in range(opt.n_past)] + for s in range(nsample): + frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + gen_seq[s].append(x[0]) + x_in = x[0] + for i in range(1, opt.n_eval): + if opt.last_frame_skip or i < opt.n_past: + h, skip = h_seq[i - 1] + h = h.detach() + elif i < opt.n_past: + h, _ = h_seq[i - 1] + h = h.detach() + if i < opt.n_past: + frame_predictor(h) + x_in = x[i] + gen_seq[s].append(x_in) + else: + h = frame_predictor(h).detach() + x_in = decoder([h, skip]).detach() + gen_seq[s].append(x_in) + + to_plot = [] + gifs = [[] for t in range(opt.n_eval)] + nrow = min(opt.batch_size, 25) + for i in range(nrow): + # ground truth sequence + row = [] + for t in range(opt.n_eval): + row.append(gt_seq[t][i]) + to_plot.append(row) + + for s in range(nsample): + row = [] + for t in range(opt.n_eval): + row.append(gen_seq[s][t][i]) + to_plot.append(row) + for t in range(opt.n_eval): + row = [] + row.append(gt_seq[t][i]) + for s in range(nsample): + row.append(gen_seq[s][t][i]) + gifs[t].append(row) + + fname = '%s/gen/sample_%d.png' % (opt.log_dir, epoch) + utils.save_tensors_image(fname, to_plot) + + fname = '%s/gen/sample_%d.gif' % (opt.log_dir, epoch) + utils.save_gif(fname, gifs) + + +def plot_rec(x, 
epoch):
+    frame_predictor.hidden = frame_predictor.init_hidden()
+    posterior.hidden = posterior.init_hidden()
+    gen_seq = []
+    gen_seq.append(x[0])
+    x_in = x[0]
+    h_seq = [encoder(x[i]) for i in range(opt.n_past + opt.n_future)]
+    for i in range(1, opt.n_past + opt.n_future):
+        h_target = h_seq[i][0].detach()
+        if opt.last_frame_skip or i < opt.n_past:
+            h, skip = h_seq[i - 1]
+        else:
+            h, _ = h_seq[i - 1]
+        h = h.detach()
+        if i < opt.n_past:
+            frame_predictor(h)
+            gen_seq.append(x[i])
+        else:
+            h_pred = frame_predictor(h).detach()
+            x_pred = decoder([h_pred, skip]).detach()
+            gen_seq.append(x_pred)
+
+    to_plot = []
+    nrow = min(opt.batch_size, 25)
+    for i in range(nrow):
+        row = []
+        for t in range(opt.n_past + opt.n_future):
+            row.append(gen_seq[t][i])
+        to_plot.append(row)
+    fname = '%s/gen/rec_%d.png' % (opt.log_dir, epoch)
+    utils.save_tensors_image(fname, to_plot)
+
+
+def get_center_maximal_contour(img, draw=False):
+    """Return the (x, y) centre of the largest external contour in ``img``.
+
+    A three-dimensional ``img`` is first converted from BGR to grayscale.
+    The contour with the largest area is selected; its centroid is computed
+    from the image moments. If that contour has zero area (so no centroid
+    can be computed), its first point is returned instead.
+
+    Args:
+        img: image array passed to ``cv2.findContours``.
+        draw: when true, a small circle (value 150) is drawn at the returned
+            position on the grayscale image. For a two-dimensional ``img``
+            this modifies the caller's array in place; for a BGR input the
+            mark lands on the converted array, not on the caller's.
+
+    Returns:
+        A tuple of two Python ints ``(x, y)``, or ``None`` when no contour
+        is found.
+    """
+    if len(img.shape) == 3:  # bgr 2 gray
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+
+    # findContours returns (contours, hierarchy) in OpenCV 2.x/4.x but
+    # (image, contours, hierarchy) in 3.x; contours are always second to last.
+    contours = cv2.findContours(img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
+    if len(contours) == 0:
+        return None
+    max_contour = max(contours, key=cv2.contourArea)
+    if cv2.contourArea(max_contour) == 0:
+        # Degenerate contour: fall back to its first point, converted to
+        # plain ints so both branches return the same type.
+        px, py = max_contour[0][0]
+        center = (int(px), int(py))
+    else:
+        M = cv2.moments(max_contour)
+        center = (int(M['m10'] / M['m00']), int(M['m01'] / M['m00']))
+    if draw:
+        cv2.circle(img, center, 1, 150, 1)
+    return center
+
+
+def high_pass(frame, min=10):
+    """Suppress dim pixels in ``frame`` and smooth the result.
+
+    Pixels below ``min`` are set to zero on a copy of ``frame``. If the
+    thresholded copy sums to zero, the unthresholded frame is used instead,
+    so the input is never blanked out entirely. The chosen image is then
+    median-blurred with kernel size 5.
+
+    Args:
+        frame: image array; it is copied, not modified.
+        min: brightness threshold. The name shadows the builtin but is kept
+            so existing keyword callers continue to work.
+
+    Returns:
+        The blurred image.
+    """
+    fr = frame.copy()
+    fr[fr < min] = 0  # remove low brightness pixels
+    if np.sum(fr) != 0:
+        frame = fr
+
+    frame = cv2.medianBlur(frame, 5)
+    return frame
+
+
+def do_implasubility_test(z_residual_mean, z_residual_cov, thresh, visualize=True):
+    MOTION_THRESH = 0.001
+    epoch_size = 200 // opt.batch_size  # we have a total of 1000 videos
+
cov_inv = [np.linalg.pinv(cov, hermitian=True) for cov in + z_residual_cov] # we assume the covariance matrix is diagonal and do the inverse for each time + # cov_inv = [np.linalg.inv(np.diag(np.diag(cov))) for cov in z_residual_cov] + # cov_inv = [np.eye(32) for cov in z_residual_cov] + frame_predictor.eval() + posterior.eval() + encoder.eval() + decoder.eval() + progress = progressbar.ProgressBar(max_value=epoch_size).start() + confusion_matrix = [[0, 0], [0, 0]] + # for i in range(50): + # if i % 2 == 0: + # x = next(training_batch_generator) + # else: + # x = next(training_batch_generator_implausible) + for i in range(epoch_size): + h_residual_var = torch.tensor(np.zeros(opt.n_future, dtype=np.float32), requires_grad=False, + device=torch.device('cpu')) + scores = np.array([0 for i in range(opt.n_future)], dtype=np.float32) + progress.update(i + 1) + is_implausible = False + try: + x, labels = next(training_batch_generator) + x_diff = [torch.abs(x[t] - x[t - 1]).detach() for t in range(1, len(x))] + x_diff = [torch.mean(frame, dim=(1, 2, 3)) for frame in x_diff] # mean along C, H, W + motion_start_time = [] + + # find when motion starts (object is about to be dropped) + for batch in range(len(x_diff[0])): + started = False + for t in range(30, len(x_diff)): + if x_diff[t][batch] > MOTION_THRESH: + motion_start_time.append( + t + 1 + 1) # add one because x_diff starts at t=1, then add another frame + started = True + break + if not started: + motion_start_time.append(33) # default to frame 33 if no motion is found + frames = [frame.cpu() for frame in x] + # print(frames[0][0].numpy().dtype) + # print(frames[0][0]) + # quit() + except TypeError: + print('got None at i = {}, terminating'.format(i)) + break + h_posterior = [encoder(x[j]) for j in range(0, opt.n_past + opt.n_future)] + # print(h_posterior[0][0].size()) + last_pred = None + frame_predictor.hidden = frame_predictor.init_hidden() + prior.hidden = prior.init_hidden() + posterior.hidden = 
posterior.init_hidden() + + # start is the first frame that the prior model sees (so start + n_past is the first frame predicted) + if motion_start_time[0] is not None: + start = motion_start_time[0] - 4 + x_in = x[start] + x_out_seq = [x[start].cpu()] + for j in range(start + 1, opt.n_past + opt.n_future + 30): + h = encoder(x_in) + if opt.last_frame_skip or j < opt.n_past + start - 1: + h, skip = h + else: + h, _ = h + + if j < opt.n_past + start + 5: + z_t = posterior(h_posterior[j][0].detach()) + prior(h) + h_post = frame_predictor(torch.cat([h, z_t], 1)) + x_in = x[j] + x_out_seq.append(decoder([h_post, skip]).detach().cpu()) + else: + z_t_hat = prior(h) + h = frame_predictor(torch.cat([h, z_t_hat], 1)).detach() + x_in = decoder([h, skip]) + x_out_seq.append(x_in.detach().cpu()) + + if j >= opt.n_past + opt.n_future and torch.mean(torch.abs(x_out_seq[-1] - x_out_seq[-2]), + dim=(1, 2, 3)) <= MOTION_THRESH / 3: + # print('broke at j= ', j) + break + + # z_residual_scores_filtered = -0.25 * scores[:-2] + (0.5 + 0.5) * scores[1:-1] - 0.25 * scores[2:] + # print(motion_start_time[0], len(frames), len(x_out_seq)) + # print(h_residual_var) + background = frames[0][0][0].numpy().copy() + background = np.uint8(np.minimum(background / 1.2, 1.0) * 255.) + if visualize and motion_start_time[0] is not None: + k = 0 + j = start + source_center = None + pred_center = None + while not (j >= len(frames) and k >= len(x_out_seq)): + frame_cv2 = frames[min(j, len(frames) - 1)][0][0].numpy().copy() + frame_cv2 /= 1.2 + frame_cv2 = np.uint8(np.minimum(frame_cv2, 1.0) * 255.) 
+ source_diff = high_pass(cv2.absdiff(frame_cv2, background)) + source_center = get_center_maximal_contour(source_diff, draw=True) + cv2.imshow('source', cv2.resize(frame_cv2, (384, 384), interpolation=cv2.INTER_NEAREST)) + cv2.imshow('source diff', cv2.resize(source_diff, (384, 384), interpolation=cv2.INTER_NEAREST)) + + out_cv2 = x_out_seq[min(k, len(x_out_seq) - 1)][0][0].numpy().copy() + out_cv2 /= 1.2 + # out_cv2 /= np.max(out_cv2) + out_cv2 = np.uint8(np.minimum(out_cv2, 1.0) * 255.) + pred_diff = high_pass(cv2.absdiff(out_cv2, background)) + pred_center = get_center_maximal_contour(pred_diff, draw=True) + cv2.imshow('prediction from first 5 frames', + cv2.resize(out_cv2, (384, 384), interpolation=cv2.INTER_NEAREST)) + cv2.imshow('pred diff', + cv2.resize(pred_diff, (384, 384), interpolation=cv2.INTER_NEAREST)) + cv2.waitKey(0) + j += 1 + k += 1 + if not visualize: + frame_cv2 = frames[-1][0][0].numpy().copy() + frame_cv2 = np.uint8(np.minimum(frame_cv2 / 1.2, 1.0) * 255.) + source_diff = high_pass(cv2.absdiff(frame_cv2, background)) + source_center = get_center_maximal_contour(source_diff, draw=True) + out_cv2 = x_out_seq[-1][0][0].numpy().copy() + out_cv2 = np.uint8(np.minimum(out_cv2 / 1.2, 1.0) * 255.) 
+ pred_diff = high_pass(cv2.absdiff(out_cv2, background)) + pred_center = get_center_maximal_contour(pred_diff, draw=True) + + if (source_center is not None) and (pred_center is not None): + is_implausible = np.abs(source_center[-1] - pred_center[-1]) > thresh # a good thresh seems to be + + if is_implausible: + msg = 'implausible' + else: + msg = 'plausible' + # print('gt: ', labels, 'prediction: ', msg) + confusion_matrix[int(labels[0] == 'implausible')][int(is_implausible)] += 1 + if visualize: + cv2.waitKey(0) + + + # percentile = np.percentile(scores[76:152], 85.0) + # thresh = percentile * 1 + 0.5 + # spikes_idx = np.argwhere(z_residual_scores_filtered > thresh) + # spikes_idx = spikes_idx[ + # (spikes_idx >= 75) & (spikes_idx <= 150)] # ignore spikes near the start and end of video + # spikes = z_residual_scores_filtered[spikes_idx] + # msg = '' + # if len(spikes_idx) > 0: + # # we add n_past because the first n_past frames are not counted. Add 1 because of the filtering + # msg = 'thresh {:.1f} IMPLAUSIBLE spikes: '.format(thresh) + str( + # ['{:.2f}@{}'.format(z_residual_scores_filtered[k], k + opt.n_past + 1) for k in spikes_idx]) + # confusion_matrix[is_implausible][1] += 1 + # else: + # max_idx = np.argmax(z_residual_scores_filtered[75:151]) + 75 + # msg = 'thresh {:.1f} PLAUSIBLE max {:.2f}@{}'.format(thresh, z_residual_scores_filtered[max_idx], + # max_idx + opt.n_past + 1) + # confusion_matrix[is_implausible][0] += 1 + # + # print(msg) + # plt.savefig('implausibility_test.png') + # h_residual_var /= epoch_size # get the mean error vector per time + # h_residual_sd = torch.sqrt(h_residual_var) + print('Last i = {}'.format(i)) + + # H_err_cov = torch.tensor(np.zeros((opt.n_future, 128, 128), dtype=np.float32), requires_grad=False, + # device=torch.device('cuda:0')) + + # plot some stuff + return confusion_matrix + + +# f = open('new_mcs_stats_post.json', 'r') +# mcs_stats_dict = json.load(f) +mcs_stats_dict = {} +with open('new_mcs_stats_post.npy', 
'rb') as f: + mcs_stats_dict['mean'] = np.load(f) + mcs_stats_dict['cov'] = np.load(f) + +ROC_curve = {} +for thr in range(2, 11): + train_data, test_data = utils.load_dataset(opt, sequential=True, implausible=False) + + train_loader = DataLoader(train_data, + num_workers=opt.data_threads, + batch_size=opt.batch_size, + shuffle=False, + drop_last=True, + pin_memory=True, ) + training_batch_generator = get_training_batch() + + conf_mat = do_implasubility_test(np.array(mcs_stats_dict['mean']), np.array(mcs_stats_dict['cov']), thr, visualize=True) + ROC_curve[thr] = conf_mat +print(ROC_curve) diff --git a/_do_mcs_implausblility_test_posterior.py b/_do_mcs_implausblility_test_posterior.py new file mode 100644 index 0000000..35ec6cb --- /dev/null +++ b/_do_mcs_implausblility_test_posterior.py @@ -0,0 +1,393 @@ +import glob + +import cv2 +import torch +import torch.optim as optim +import torch.nn as nn +import argparse +import os +import random +from torch.autograd import Variable +from torch.utils.data import DataLoader +import utils +import itertools +import progressbar +import numpy as np +import matplotlib.pyplot as plt +import json + +parser = argparse.ArgumentParser() +parser.add_argument('--lr', default=0.004, type=float, help='learning rate') +parser.add_argument('--beta1', default=0.9, type=float, help='momentum term for adam') +parser.add_argument('--batch_size', default=1, type=int, help='batch size') +parser.add_argument('--log_dir', default='logs/nonstochastic', help='base directory to save logs') +parser.add_argument('--model_dir', default='', help='base directory to save logs') +parser.add_argument('--name', default='', help='identifier for directory') +parser.add_argument('--data_root', default='data', help='root directory for data') +parser.add_argument('--optimizer', default='adam', help='optimizer to train with') +parser.add_argument('--niter', type=int, default=300, help='number of epochs to train for') +parser.add_argument('--seed', default=1, 
type=int, help='manual seed') +parser.add_argument('--epoch_size', type=int, default=1000, help='epoch size') +parser.add_argument('--image_width', type=int, default=64, help='the height / width of the input image to network') +parser.add_argument('--channels', default=1, type=int) +parser.add_argument('--dataset', default='mcs', help='dataset to train with') +parser.add_argument('--mcs_task', default='SpatioTemporalContinuityTraining4', help='mcs task') +parser.add_argument('--n_past', type=int, default=5, help='number of frames to condition on') +parser.add_argument('--n_future', type=int, default=195, help='number of frames to predict') +parser.add_argument('--n_eval', type=int, default=30, help='number of frames to predict at eval time') +parser.add_argument('--rnn_size', type=int, default=256, help='dimensionality of hidden layer') +parser.add_argument('--posterior_rnn_layers', type=int, default=1, help='number of layers') +parser.add_argument('--predictor_rnn_layers', type=int, default=2, help='number of layers') +parser.add_argument('--z_dim', type=int, default=10, help='dimensionality of z_t') +parser.add_argument('--g_dim', type=int, default=128, + help='dimensionality of encoder output vector and decoder input vector') +parser.add_argument('--beta', type=float, default=0.0001, help='weighting on KL to prior') +parser.add_argument('--model', default='dcgan', help='model type (dcgan | vgg)') +parser.add_argument('--data_threads', type=int, default=8, help='number of data loading threads') +parser.add_argument('--num_digits', type=int, default=2, help='number of digits for moving mnist') +parser.add_argument('--last_frame_skip', action='store_true', + help='if true, skip connections go between frame t and frame t+t rather than last ground truth frame') +opt = parser.parse_args() +BATCH_SIZE = opt.batch_size +saved_model = None +if opt.model_dir != '': + models = glob.glob(f'{opt.model_dir}/model_*.pth') + latest_model = sorted(models, key=lambda s: 
int(s[s.rfind('_e') + 2: s.rfind('.pth')]), reverse=True)[0] + print('Loading model ', latest_model) + saved_model = torch.load(latest_model) + optimizer = opt.optimizer + model_dir = opt.model_dir + niter = opt.niter + opt = saved_model['opt'] + opt.niter = niter # update number of epochs to train for + opt.optimizer = optimizer + opt.model_dir = model_dir + opt.log_dir = '%s/continued' % opt.log_dir +else: + raise ValueError("Please specify the model to load with the --model_dir argument") + +print("Random Seed: ", opt.seed) +random.seed(opt.seed) +torch.manual_seed(opt.seed) +torch.cuda.manual_seed_all(opt.seed) +dtype = torch.cuda.FloatTensor + + +import models.lstm as lstm_models + +if opt.model_dir != '': + frame_predictor = saved_model['frame_predictor'] + frame_predictor.batch_size = BATCH_SIZE + posterior = saved_model['posterior'] + posterior.batch_size = BATCH_SIZE +else: + raise ValueError('Please specify --model_dir') + +if opt.model == 'dcgan': + if opt.image_width == 64: + import models.dcgan_64 as model + elif opt.image_width == 128: + import models.dcgan_128 as model +elif opt.model == 'vgg': + if opt.image_width == 64: + import models.vgg_64 as model + elif opt.image_width == 128: + import models.vgg_128 as model +else: + raise ValueError('Unknown model: %s' % opt.model) + +if opt.model_dir != '': + decoder = saved_model['decoder'] + encoder = saved_model['encoder'] +else: + raise ValueError("Please specify the model to load with the --model_dir argument") + +# --------- loss functions ------------------------------------ +mse_criterion = nn.MSELoss() + + +def kl_criterion(mu, logvar): + # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) + KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) + KLD /= opt.batch_size + return KLD + + +# --------- transfer to gpu ------------------------------------ +frame_predictor.cuda() +posterior.cuda() +encoder.cuda() +decoder.cuda() +mse_criterion.cuda() + +opt.batch_size = BATCH_SIZE +opt.epoch_size = 
1000 +opt.n_future = 195 +print(opt) + +# --------- load a dataset ------------------------------------ +train_data, test_data = utils.load_dataset(opt, sequential=True, implausible=True) + +train_loader = DataLoader(train_data, + num_workers=opt.data_threads, + batch_size=opt.batch_size, + drop_last=True, + pin_memory=True,) +test_loader = DataLoader(test_data, + num_workers=opt.data_threads, + batch_size=opt.batch_size, + shuffle=True, + drop_last=True, + pin_memory=True) + + +def get_training_batch(): + while True: + for sequence in train_loader: + batch = utils.normalize_data(opt, dtype, sequence) + yield batch + + +training_batch_generator = get_training_batch() +training_batch_generator_2 = get_training_batch() + + +def get_testing_batch(): + while True: + for sequence in test_loader: + batch = utils.normalize_data(opt, dtype, sequence) + yield batch + + +testing_batch_generator = get_testing_batch() + + +# --------- plotting funtions ------------------------------------ +def plot(x, epoch): + nsample = 1 + gen_seq = [[] for _ in range(nsample)] + gt_seq = [x[i] for i in range(len(x))] + + h_seq = [encoder(x[i]) for i in range(opt.n_past)] + for s in range(nsample): + frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + gen_seq[s].append(x[0]) + x_in = x[0] + for i in range(1, opt.n_eval): + if opt.last_frame_skip or i < opt.n_past: + h, skip = h_seq[i - 1] + h = h.detach() + elif i < opt.n_past: + h, _ = h_seq[i - 1] + h = h.detach() + if i < opt.n_past: + frame_predictor(h) + x_in = x[i] + gen_seq[s].append(x_in) + else: + h = frame_predictor(h).detach() + x_in = decoder([h, skip]).detach() + gen_seq[s].append(x_in) + + to_plot = [] + gifs = [[] for t in range(opt.n_eval)] + nrow = min(opt.batch_size, 25) + for i in range(nrow): + # ground truth sequence + row = [] + for t in range(opt.n_eval): + row.append(gt_seq[t][i]) + to_plot.append(row) + + for s in range(nsample): + row = [] + for t in 
range(opt.n_eval): + row.append(gen_seq[s][t][i]) + to_plot.append(row) + for t in range(opt.n_eval): + row = [] + row.append(gt_seq[t][i]) + for s in range(nsample): + row.append(gen_seq[s][t][i]) + gifs[t].append(row) + + fname = '%s/gen/sample_%d.png' % (opt.log_dir, epoch) + utils.save_tensors_image(fname, to_plot) + + fname = '%s/gen/sample_%d.gif' % (opt.log_dir, epoch) + utils.save_gif(fname, gifs) + + +def plot_rec(x, epoch): + frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + gen_seq = [] + gen_seq.append(x[0]) + x_in = x[0] + h_seq = [encoder(x[i]) for i in range(opt.n_past + opt.n_future)] + for i in range(1, opt.n_past + opt.n_future): + h_target = h_seq[i][0].detach() + if opt.last_frame_skip or i < opt.n_past: + h, skip = h_seq[i - 1] + else: + h, _ = h_seq[i - 1] + h = h.detach() + if i < opt.n_past: + frame_predictor(h) + gen_seq.append(x[i]) + else: + h_pred = frame_predictor(h).detach() + x_pred = decoder([h_pred, skip]).detach() + gen_seq.append(x_pred) + + to_plot = [] + nrow = min(opt.batch_size, 25) + for i in range(nrow): + row = [] + for t in range(opt.n_past + opt.n_future): + row.append(gen_seq[t][i]) + to_plot.append(row) + fname = '%s/gen/rec_%d.png' % (opt.log_dir, epoch) + utils.save_tensors_image(fname, to_plot) + + +def do_implasubility_test(h_residual_mean, h_residual_vars): + epoch_size = 1000 // opt.batch_size # we have a total of 1000 videos + cov_inv = [np.diag(1.0 / h_var)[np.newaxis, ...] 
for h_var in h_residual_vars] # we assume the covariance matrix is diagonal and do the inverse for each time + + frame_predictor.eval() + posterior.eval() + encoder.eval() + decoder.eval() + progress = progressbar.ProgressBar(max_value=epoch_size).start() + + + for i in range(epoch_size): + h_residual_var = torch.tensor(np.zeros(opt.n_future, dtype=np.float32), requires_grad=False, + device=torch.device('cpu')) + progress.update(i + 1) + try: + x = next(training_batch_generator) + frames = [frame.cpu() for frame in x] + except TypeError: + print('got None at i = {}, terminating'.format(i)) + break + h_posterior = [encoder(x[j]) for j in range(opt.n_past + opt.n_future)] + # print(h_posterior[0][0].size()) + last_pred = None + frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + start = 1 + for j in range(start, opt.n_past + opt.n_future): + h_target = h_posterior[j][0].detach() + if opt.last_frame_skip or j < opt.n_past + start - 1: + h, skip = h_posterior[j - 1] + else: + h = h_posterior[j - 1][0].detach() + # we predict h_t from h_{t-1} + h_prior_pred = frame_predictor(h.detach()) + h_posterior_pred = posterior(h_target.detach()) + + if j >= opt.n_past + start - 1: + # h_res = h_prior_pred - h_posterior[j][0].detach() # predicted h minus observed h + # squared_diff = torch.square(h_res - h_residual_mean[j - opt.n_past]) + # squared_diff = torch.mean(squared_diff, dim=1) # average squared residuals at time j over the dimensions of h + # squared_diff = torch.mean(squared_diff, dim=0) # average squared residuals at time j over the batch + # h_residual_var[j - opt.n_past] += squared_diff + + # residual = (h_prior_pred - h_posterior_pred).cpu().detach() + residual = (last_post_pred - h_posterior_pred).cpu().detach() + # err = (residual - h_residual_mean[j - opt.n_past]) + err = residual + err = np.square(err) # [batch, dim_feature] + # err = err / h_residual_vars[j - opt.n_past][np.newaxis, ...] 
# [batch, dim_feature] / [1, dim_feature] + err = torch.mean(err, axis=1) # -> [batch,] + # if len(err.shape) == 2: # [batch, dim_feature] + # err = err[..., np.newaxis] # make err into a vector [batch, dim_feature, 1] + # # print(cov_inv[j - opt.n_past].shape) + # # print(cov_inv[j - opt.n_past]) + # # print(err.shape) + # print(err) + # print(np.diag(cov_inv[j - opt.n_past][0])) + # quit() + # mahanlanobis_dist = np.matmul(cov_inv[j - opt.n_past], err).transpose(2, 1) + # # print(mahanlanobis_dist.shape) + # mahanlanobis_dist = np.matmul(mahanlanobis_dist, err) # [batch, 1, 1] + # # print(mahanlanobis_dist.shape) + # # print(np.sqrt(mahanlanobis_dist)) + # print(mahanlanobis_dist) + # quit() + # mahanlanobis_dist = torch.mean(mahanlanobis_dist) # scalar + + + # err = torch.mean(err, dim=1) # average errs at time j over the dimensions of h + # err = torch.mean(err, dim=0) # average errs at time j over the batch + # h_residual_var[j - opt.n_past] += err.detach() + h_residual_var[j - opt.n_past] += torch.mean(err, axis=0) + last_post_pred = h_posterior_pred.detach() + + h_residual_sd = torch.sqrt(h_residual_var).cpu() + + h_residual_sd_filtered = - h_residual_sd[:-2] + 2 * h_residual_sd[1:-1] - h_residual_sd[2:] + + + # print(h_residual_var) + for j in range(len(frames)): + frame = np.uint8(np.minimum(frames[j][0][0], 1) * 255) + cv2.imshow('frame', frame) + cv2.waitKey(15) + fig = plt.figure() + plt.ylim(0, 2.0) + plt.xlabel('Time') + plt.title("sqrt of average squared dimensional error") + plt.bar(np.arange(len(h_residual_sd)), h_residual_sd) + fig.canvas.draw() + img = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, + sep='') + img = img.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + + plt.xlabel('Time') + plt.title("filtered sqrt of average squared dimensional error") + plt.bar(np.arange(len(h_residual_sd_filtered)), h_residual_sd_filtered) + fig.canvas.draw() + img2 = 
np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, + sep='') + img2 = img2.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + img2 = cv2.cvtColor(img2, cv2.COLOR_RGB2BGR) + + cv2.imshow("plot", img) + cv2.imshow("plot2", img2) + + k = cv2.waitKey(0) + if k == ord('q'): + quit() + # plt.savefig('implausibility_test.png') + # h_residual_var /= epoch_size # get the mean error vector per time + # h_residual_sd = torch.sqrt(h_residual_var) + print('Last i = {}'.format(i)) + + + # H_err_cov = torch.tensor(np.zeros((opt.n_future, 128, 128), dtype=np.float32), requires_grad=False, + # device=torch.device('cuda:0')) + + # plot some stuff + h_residual_mean_norm = torch.norm(h_residual_mean, dim=1).cpu() + plt.subplot(2, 1, 1) + plt.xlabel('Time') + plt.tight_layout() + plt.title("Norm of the residual mean") + plt.bar(np.arange(len(h_residual_mean_norm)), h_residual_mean_norm) + + plt.subplot(2, 1, 2) + plt.xlabel('Time') + plt.tight_layout() + plt.title("Average dimensional sqrt(variance) of the residual") + plt.bar(np.arange(len(h_residual_sd.cpu())), h_residual_sd.cpu()) + plt.savefig('h_residual_mean.png') + + +f = open('mcs_stats_post.json', 'r') +mcs_stats_dict = json.load(f) +do_implasubility_test(np.array(mcs_stats_dict['mean']), np.array(mcs_stats_dict['vars'])) diff --git a/_do_mcs_stats_posterior.py b/_do_mcs_stats_posterior.py new file mode 100644 index 0000000..5ccec1c --- /dev/null +++ b/_do_mcs_stats_posterior.py @@ -0,0 +1,377 @@ +import glob + +import torch +import torch.optim as optim +import torch.nn as nn +import argparse +import os +import random +from torch.autograd import Variable +from torch.utils.data import DataLoader +import utils +import itertools +import progressbar +import numpy as np +import matplotlib.pyplot as plt +import json + +parser = argparse.ArgumentParser() +parser.add_argument('--lr', default=0.004, type=float, help='learning rate') +parser.add_argument('--beta1', default=0.9, type=float, help='momentum term for adam') 
+parser.add_argument('--batch_size', default=1, type=int, help='batch size') +parser.add_argument('--log_dir', default='logs/nonstochastic', help='base directory to save logs') +parser.add_argument('--model_dir', default='', help='base directory to save logs') +parser.add_argument('--name', default='', help='identifier for directory') +parser.add_argument('--data_root', default='data', help='root directory for data') +parser.add_argument('--optimizer', default='adam', help='optimizer to train with') +parser.add_argument('--niter', type=int, default=300, help='number of epochs to train for') +parser.add_argument('--seed', default=1, type=int, help='manual seed') +parser.add_argument('--epoch_size', type=int, default=1000, help='epoch size') +parser.add_argument('--image_width', type=int, default=64, help='the height / width of the input image to network') +parser.add_argument('--channels', default=1, type=int) +parser.add_argument('--dataset', default='mcs', help='dataset to train with') +parser.add_argument('--mcs_task', default='SpatioTemporalContinuityTraining4', help='mcs task') +parser.add_argument('--n_past', type=int, default=5, help='number of frames to condition on') +parser.add_argument('--n_future', type=int, default=195, help='number of frames to predict') +parser.add_argument('--n_eval', type=int, default=30, help='number of frames to predict at eval time') +parser.add_argument('--rnn_size', type=int, default=256, help='dimensionality of hidden layer') +parser.add_argument('--posterior_rnn_layers', type=int, default=1, help='number of layers') +parser.add_argument('--predictor_rnn_layers', type=int, default=2, help='number of layers') +parser.add_argument('--z_dim', type=int, default=10, help='dimensionality of z_t') +parser.add_argument('--g_dim', type=int, default=128, + help='dimensionality of encoder output vector and decoder input vector') +parser.add_argument('--beta', type=float, default=0.0001, help='weighting on KL to prior') 
+parser.add_argument('--model', default='dcgan', help='model type (dcgan | vgg)') +parser.add_argument('--data_threads', type=int, default=8, help='number of data loading threads') +parser.add_argument('--num_digits', type=int, default=2, help='number of digits for moving mnist') +parser.add_argument('--last_frame_skip', action='store_true', + help='if true, skip connections go between frame t and frame t+t rather than last ground truth frame') +opt = parser.parse_args() +BATCH_SIZE = opt.batch_size +saved_model = None +if opt.model_dir != '': + models = glob.glob(f'{opt.model_dir}/model_*.pth') + latest_model = sorted(models, key=lambda s: int(s[s.rfind('_e') + 2: s.rfind('.pth')]), reverse=True)[0] + print('Loading model ', latest_model) + saved_model = torch.load(latest_model) + model_dir = opt.model_dir + niter = opt.niter + opt = saved_model['opt'] + opt.niter = niter # update number of epochs to train for + opt.model_dir = model_dir + opt.log_dir = '%s/continued' % opt.log_dir +else: + raise ValueError("Please specify the model to load with the --model_dir argument") + +print("Random Seed: ", opt.seed) +random.seed(opt.seed) +torch.manual_seed(opt.seed) +torch.cuda.manual_seed_all(opt.seed) +dtype = torch.cuda.FloatTensor + + +import models.lstm as lstm_models + +if opt.model_dir != '': + frame_predictor = saved_model['frame_predictor'] + frame_predictor.batch_size = BATCH_SIZE + posterior = saved_model['posterior'] + posterior.batch_size = BATCH_SIZE +else: + frame_predictor = lstm_models.lstm(opt.g_dim, opt.g_dim, opt.rnn_size, opt.predictor_rnn_layers, opt.batch_size) + frame_predictor.apply(utils.init_weights) + +if opt.model == 'dcgan': + if opt.image_width == 64: + import models.dcgan_64 as model + elif opt.image_width == 128: + import models.dcgan_128 as model +elif opt.model == 'vgg': + if opt.image_width == 64: + import models.vgg_64 as model + elif opt.image_width == 128: + import models.vgg_128 as model +else: + raise ValueError('Unknown model: %s' 
def plot(x, epoch):
    """Save ground-truth and generated sequences as a PNG grid and a GIF.

    ``x`` is indexed per time step (``x[t]``) and each ``x[t]`` per batch row.
    The first ``opt.n_past`` frames are encoded and fed to ``frame_predictor``
    as conditioning; from step ``opt.n_past`` on, the predictor's own output
    is fed back in and decoded with the last stored skip connection.

    Output files are ``<opt.log_dir>/gen/sample_<epoch>.png`` and ``.gif``.
    """
    nsample = 1
    gen_seq = [[] for _ in range(nsample)]
    gt_seq = [x[t] for t in range(len(x))]

    # Only the conditioning frames are encoded up front.
    h_seq = [encoder(x[t]) for t in range(opt.n_past)]
    for s in range(nsample):
        frame_predictor.hidden = frame_predictor.init_hidden()
        posterior.hidden = posterior.init_hidden()
        gen_seq[s].append(x[0])
        for i in range(1, opt.n_eval):
            # The original also had an `elif i < opt.n_past` branch here; it
            # could never run because this condition already covers that case.
            if opt.last_frame_skip or i < opt.n_past:
                # NOTE(review): h_seq holds only opt.n_past entries, so with
                # last_frame_skip set this lookup raises IndexError once i
                # exceeds opt.n_past -- confirm the intended behaviour.
                h, skip = h_seq[i - 1]
                h = h.detach()
            if i < opt.n_past:
                # Conditioning step: advance the recurrent state and keep the
                # ground-truth frame.
                frame_predictor(h)
                gen_seq[s].append(x[i])
            else:
                # Prediction step: h is the previous prediction fed back in.
                h = frame_predictor(h).detach()
                gen_seq[s].append(decoder([h, skip]).detach())

    to_plot = []
    gifs = [[] for _ in range(opt.n_eval)]
    nrow = min(opt.batch_size, 25)
    for i in range(nrow):
        # ground truth sequence, then one row per generated sample
        to_plot.append([gt_seq[t][i] for t in range(opt.n_eval)])
        for s in range(nsample):
            to_plot.append([gen_seq[s][t][i] for t in range(opt.n_eval)])
        # per-time-step GIF frames: ground truth next to every sample
        for t in range(opt.n_eval):
            gifs[t].append([gt_seq[t][i]] + [gen_seq[s][t][i] for s in range(nsample)])

    fname = '%s/gen/sample_%d.png' % (opt.log_dir, epoch)
    utils.save_tensors_image(fname, to_plot)

    fname = '%s/gen/sample_%d.gif' % (opt.log_dir, epoch)
    utils.save_gif(fname, gifs)
def _future_residuals(x):
    """Yield ``(k, predictor_output - posterior_output)`` for each predicted step k of batch ``x``."""
    n_steps = opt.n_past + opt.n_future
    h_posterior = [encoder(x[j]) for j in range(n_steps)]
    frame_predictor.hidden = frame_predictor.init_hidden()
    posterior.hidden = posterior.init_hidden()
    for j in range(1, n_steps):
        h_target = h_posterior[j][0].detach()
        # Only the feature vector (element 0) of the encoder output is used
        # here; the skip connections are irrelevant because nothing is decoded.
        h = h_posterior[j - 1][0].detach()
        # we predict h_t from h_{t-1}
        h_prior_pred = frame_predictor(h).detach()
        h_posterior_pred = posterior(h_target).detach()
        if j >= opt.n_past:
            yield j - opt.n_past, h_prior_pred - h_posterior_pred


def do_stats():
    """Estimate per-time-step mean and variance of the latent residual.

    Two passes are made over the training data: the first accumulates the mean
    residual vector for every predicted step, the second the squared deviation
    from that mean (per dimension in ``vars``, averaged over dimensions in
    ``var``). The results are printed, plotted to ``post_h_residual.png`` and
    written to ``mcs_stats_post.json`` under the keys ``mean``, ``var`` and
    ``vars``.

    The module-level ``train_loader`` is rebound between the two passes so the
    second pass starts from the beginning of the dataset.
    """
    global train_loader

    epoch_size = 1000 // opt.batch_size  # we have a total of 1000 videos
    device = torch.device('cuda:0')

    frame_predictor.eval()
    encoder.eval()
    decoder.eval()
    progress = progressbar.ProgressBar(max_value=epoch_size).start()

    # NOTE(review): 128 is the hard-coded residual width; presumably it should
    # follow the model's feature dimension -- confirm before changing.
    h_residual_mean = torch.zeros((opt.n_future, 128), dtype=torch.float32, device=device)
    i = 0  # keeps the final 'Last i' print valid when epoch_size is 0
    n_done = 0
    # Nothing is optimised here, so skip building autograd graphs.
    with torch.no_grad():
        for i in range(epoch_size):
            progress.update(i + 1)
            try:
                x = next(training_batch_generator)
            except TypeError:
                print('got None at i = {}, terminating'.format(i))
                break
            for k, residual in _future_residuals(x):
                # average errors at the same time step over the batch
                h_residual_mean[k] += torch.mean(residual, dim=0)
            n_done += 1
    # Divide by the batches actually processed so an early break does not
    # bias the mean towards zero.
    h_residual_mean /= max(n_done, 1)  # get the mean error vector per time

    # restart training dataset
    train_loader = DataLoader(train_data,
                              num_workers=opt.data_threads,
                              batch_size=opt.batch_size,
                              drop_last=True,
                              pin_memory=True)
    h_residual_var = torch.zeros(opt.n_future, dtype=torch.float32, device=device)
    h_residual_vars = torch.zeros((opt.n_future, 128), dtype=torch.float32, device=device)

    n_done = 0
    with torch.no_grad():
        for i in range(epoch_size):
            progress.update(i + 1)
            try:
                x = next(training_batch_generator_2)
            except TypeError:
                print('got None at i = {}, terminating'.format(i))
                break
            for k, residual in _future_residuals(x):
                squared_err = torch.square(residual - h_residual_mean[k])
                squared_err = torch.mean(squared_err, dim=0)  # average errs at time k over the batch
                h_residual_vars[k] += squared_err
                squared_err = torch.mean(squared_err, dim=0)  # average errs at time k over the dimensions of h
                h_residual_var[k] += squared_err
            n_done += 1
    h_residual_var /= max(n_done, 1)
    h_residual_vars /= max(n_done, 1)
    h_residual_sd = torch.sqrt(h_residual_var)
    print('Last i = {}'.format(i))
    print('sd of h residual: ', h_residual_sd)
    print('var of h residual: ', h_residual_var)
    print('norm(mean of h residual)', torch.norm(h_residual_mean, dim=1))
    print('vars of h residual: ', h_residual_vars)

    # plot some stuff
    h_residual_mean_norm = torch.norm(h_residual_mean, dim=1).cpu()
    plt.subplot(2, 1, 1)
    plt.xlabel('Time')
    plt.tight_layout()
    plt.title("Norm of the residual mean")
    plt.bar(np.arange(len(h_residual_mean_norm)), h_residual_mean_norm)

    h_residual_sd_cpu = h_residual_sd.cpu()
    plt.subplot(2, 1, 2)
    plt.xlabel('Time')
    plt.tight_layout()
    plt.title("Average dimensional sqrt(variance) of the residual")
    plt.bar(np.arange(len(h_residual_sd_cpu)), h_residual_sd_cpu)
    plt.savefig('post_h_residual.png')

    stats_dict = {'mean': h_residual_mean.cpu().tolist(), 'var': h_residual_var.cpu().tolist(),
                  'vars': h_residual_vars.cpu().tolist()}
    # The context manager guarantees the file is flushed and closed.
    with open('mcs_stats_post.json', 'w') as f:
        json.dump(stats_dict, f)
help='identifier for directory') +parser.add_argument('--data_root', default='data', help='root directory for data') +parser.add_argument('--optimizer', default='adam', help='optimizer to train with') +parser.add_argument('--niter', type=int, default=40, help='number of epochs to train for') +parser.add_argument('--seed', default=1, type=int, help='manual seed') +parser.add_argument('--epoch_size', type=int, default=500, help='epoch size') +parser.add_argument('--image_width', type=int, default=64, help='the height / width of the input image to network') +parser.add_argument('--channels', default=1, type=int, help='number of channels for input images. ') +parser.add_argument('--use_edge_kernels', default=True, type=bool, help='whether to use edge kernels to reduce to 1 channel') +parser.add_argument('--dataset', default='mcs', help='dataset to train with') +parser.add_argument('--mcs_task', default='ObjectPermanenceTraining4', help='mcs task') +parser.add_argument('--n_past', type=int, default=5, help='number of frames to condition on') +parser.add_argument('--n_future', type=int, default=45, help='number of frames to predict') +parser.add_argument('--n_eval', type=int, default=50, help='number of frames to predict at eval time') +parser.add_argument('--start_min', type=int, default=65, help='min starting time for sampling sequence (0-indexed)') +parser.add_argument('--start_max', type=int, default=85, help='max starting time for sampling sequence (0-indexed)') +parser.add_argument('--sequence_stride', type=int, default=3, help='factor for sequence temporal subsampling (int)') +parser.add_argument('--rnn_size', type=int, default=256, help='dimensionality of hidden layer') +parser.add_argument('--prior_rnn_layers', type=int, default=1, help='number of layers') +parser.add_argument('--posterior_rnn_layers', type=int, default=1, help='number of layers') +parser.add_argument('--predictor_rnn_layers', type=int, default=2, help='number of layers') 
+parser.add_argument('--z_dim', type=int, default=32, help='dimensionality of z_t') +parser.add_argument('--g_dim', type=int, default=128, help='dimensionality of encoder output vector and decoder input vector') +parser.add_argument('--beta', type=float, default=0.0001, help='weighting on KL to prior') +parser.add_argument('--gamma', type=float, default=0.0001, help='weighting on h vs h posterior') +parser.add_argument('--model', default='vgg', help='model type (dcgan | vgg)') +parser.add_argument('--data_threads', type=int, default=12, help='number of data loading threads') +parser.add_argument('--num_digits', type=int, default=2, help='number of digits for moving mnist') +parser.add_argument('--last_frame_skip', action='store_true', help='if true, skip connections go between frame t and frame t+1 rather than last ground truth frame') + + +opt = parser.parse_args() +saved_model = None +if opt.model_dir != '': + models = glob.glob(f'{opt.model_dir}/model_*.pth') + latest_model = sorted(models, key=lambda s: int(s[s.rfind('_e') + 2: s.rfind('.pth')]), reverse=True)[0] + print('Loading model ', latest_model) + saved_model = torch.load(latest_model) + model_dir = opt.model_dir + niter = opt.niter + lr = opt.lr + batch_size = opt.batch_size + n_future = opt.n_future + n_eval = opt.n_eval + data_root = opt.data_root + opt = saved_model['opt'] + opt.niter = niter # update number of epochs to train for + opt.model_dir = model_dir + opt.n_future = n_future + opt.lr = lr + opt.batch_size = batch_size + opt.n_eval = n_eval + opt.data_root = data_root + opt.log_dir = '%s/continued_lr%s' % (opt.log_dir, opt.lr) +else: + name = 'model=%s%dx%d-rnn_size=%d-predictor-posterior-rnn_layers=%d-%d-n_past=%d-n_future=%d-lr=%.4f-g_dim=%d-z_dim=%d-last_frame_skip=%d-beta=%.7f-gamma=%.7f%s' % (opt.model, opt.image_width, opt.image_width, opt.rnn_size, opt.predictor_rnn_layers, opt.posterior_rnn_layers, opt.n_past, opt.n_future, opt.lr, opt.g_dim, opt.z_dim, opt.last_frame_skip, opt.beta, 
opt.gamma, opt.name) + if opt.dataset == 'smmnist': + opt.log_dir = '%s/%s-%d/%s' % (opt.log_dir, opt.dataset, opt.num_digits, name) + elif opt.dataset == 'mcs': + opt.log_dir = '%s/%s/%s/%s' % (opt.log_dir, opt.dataset, opt.mcs_task, name) + else: + opt.log_dir = '%s/%s/%s' % (opt.log_dir, opt.dataset, name) + +os.makedirs('%s/gen/' % opt.log_dir, exist_ok=True) +os.makedirs('%s/plots/' % opt.log_dir, exist_ok=True) +with open(os.path.join(opt.log_dir, 'opt.json'), 'w') as f: + opt2 = opt.__dict__.copy() + if isinstance(opt2['optimizer'], type): + opt2['optimizer'] = str(opt2['optimizer']) + json.dump(opt2, f, indent=2) + del opt2 + +print("Random Seed: ", opt.seed) +random.seed(opt.seed) +torch.manual_seed(opt.seed) +torch.cuda.manual_seed_all(opt.seed) +dtype = torch.cuda.FloatTensor + + +# ---------------- load the models ---------------- + +print(opt) + +# ---------------- optimizers ---------------- +if opt.optimizer == 'adam': + opt.optimizer = optim.Adam +elif opt.optimizer == 'rmsprop': + opt.optimizer = optim.RMSprop +elif opt.optimizer == 'sgd': + opt.optimizer = optim.SGD +elif isinstance(opt.optimizer, type): + pass +else: + raise ValueError('Unknown optimizer: %s' % opt.optimizer) + +import models.lstm as lstm_models +if opt.model_dir != '': + frame_predictor = saved_model['frame_predictor'] + frame_predictor.batch_size = opt.batch_size + prior = saved_model['prior'] + prior.batch_size = opt.batch_size + posterior = saved_model['posterior'] + posterior.batch_size = opt.batch_size +else: + frame_predictor = lstm_models.lstm(opt.g_dim + opt.z_dim, opt.g_dim, opt.rnn_size, opt.predictor_rnn_layers, opt.batch_size) + posterior = lstm_models.lstm(opt.g_dim, opt.z_dim, opt.rnn_size, opt.posterior_rnn_layers, opt.batch_size) + prior = lstm_models.lstm(opt.g_dim, opt.z_dim, opt.rnn_size, opt.prior_rnn_layers, opt.batch_size) + frame_predictor.apply(utils.init_weights) + posterior.apply(utils.init_weights) + prior.apply(utils.init_weights) + +if opt.model == 
def kl_criterion(mu, logvar):
    """Return -0.5 * sum(1 + logvar - mu^2 - exp(logvar)) divided by ``opt.batch_size``."""
    per_element = 1 + logvar - mu.pow(2) - logvar.exp()
    total = -0.5 * torch.sum(per_element)
    return total / opt.batch_size
def get_testing_batch():
    """Yield normalized test batches forever, re-iterating ``test_loader`` whenever it runs out."""
    while True:
        for clip in test_loader:
            yield utils.normalize_data(opt, dtype, clip)
row.append(gen_seq[s][t][i]) + gifs[t].append(row) + + fname = '%s/gen/sample_%d.png' % (opt.log_dir, epoch) + utils.save_tensors_image(fname, to_plot) + + fname = '%s/gen/sample_%d.gif' % (opt.log_dir, epoch) + utils.save_gif(fname, gifs) + + +def plot_rec(x, epoch): + gen_seq = [utils.torch_rgb_img_to_gray(x[0])] + gen_seq_post = [utils.torch_rgb_img_to_gray(x[0])] + + # prediction using posterior Z + frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + h = encoder(x[0]) + h = (h[0].detach(), h[1]) + for i in range(1, min(opt.n_eval, opt.n_past+opt.n_future)): + h_target = encoder(x[i]) + h_target = (h_target[0].detach(), h_target[1]) + if opt.last_frame_skip or i < opt.n_past: + h, skip = h + else: + h, _ = h + z_t = posterior(h_target[0]) + if i < opt.n_past: + frame_predictor(torch.cat([h, z_t], 1)) + gen_seq_post.append(utils.torch_rgb_img_to_gray(x[i])) + else: + h_pred = frame_predictor(torch.cat([h, z_t], 1)).detach() + x_pred = decoder([h_pred, skip]).detach() + gen_seq_post.append(x_pred) + h = h_target + + # prediction using prior Z + frame_predictor.hidden = frame_predictor.init_hidden() + prior.hidden = prior.init_hidden() + h = encoder(x[0]) + h = (h[0].detach(), h[1]) + for i in range(1, min(opt.n_eval, opt.n_past+opt.n_future)): + h_target = encoder(x[i]) + h_target = (h_target[0].detach(), h_target[1]) + if opt.last_frame_skip or i < opt.n_past: + h, skip = h + else: + h, _ = h + h = h.detach() + z_t_hat = prior(h) + if i < opt.n_past: + frame_predictor(torch.cat([h, z_t_hat], 1)) + gen_seq.append(utils.torch_rgb_img_to_gray(x[i])) + else: + h_pred = frame_predictor(torch.cat([h, z_t_hat], 1)).detach() + x_pred = decoder([h_pred, skip]).detach() + gen_seq.append(x_pred) + h = h_target + + to_plot = [] + nrow = min(opt.batch_size, 25) + x_gray = [utils.torch_rgb_img_to_gray(x[t]) for t in range(min(opt.n_eval, opt.n_past+opt.n_future))] + for i in range(nrow): + row_gt = [] + row_post = [] + row = [] 
def train(x):
    """Run one optimisation step on batch ``x`` and return the per-step losses.

    ``x`` is indexed per time step (``x[t]``); the original comment gives the
    layout as T x B x C x H x W. For every step the frame is reconstructed
    from the posterior latent and compared with the grayscale target, and the
    prior latent is pulled towards the (detached) posterior latent with weight
    ``opt.gamma``. Samples whose frame barely differs from the previous one
    (mean absolute change <= 5e-6) are down-weighted by a factor of 0.05 in
    both terms.

    Returns:
        Tuple ``(mse / N, mse_residual / N)`` as NumPy values, where
        ``N = opt.n_past + opt.n_future``.
    """
    for net in (frame_predictor, posterior, prior, encoder, decoder):
        net.zero_grad()

    # initialize the hidden state.
    frame_predictor.hidden = frame_predictor.init_hidden()
    posterior.hidden = posterior.init_hidden()
    prior.hidden = prior.init_hidden()

    mse = 0
    mse_residual = 0
    # x: T x B x C x H x W
    # Per-sample mean absolute change of each frame w.r.t. its predecessor;
    # index 0 is a placeholder that is never read.
    x_diff = [1]
    for t in range(1, len(x)):
        diff = torch.abs(x[t] - x[t - 1])
        x_diff.append(torch.mean(diff, dim=(1, 2, 3)).detach())  # mean over channels, width, and height

    h = encoder(x[0])
    for i in range(1, opt.n_past + opt.n_future):
        h_target = encoder(x[i])
        if opt.last_frame_skip or i < opt.n_past:
            h, skip = h
        else:
            h = h[0]

        z_t = posterior(h_target[0])
        z_t_hat = prior(h)
        h_pred = frame_predictor(torch.cat([h, z_t], 1))
        x_pred = decoder([h_pred, skip])
        gray_target_frame = utils.torch_rgb_img_to_gray(x[i])

        still_frames = x_diff[i] <= 5e-6
        sample_weight = torch.pow(0.05, still_frames).detach()  # one weight per sample
        weights = sample_weight[:, None, None, None]
        mse += mse_criterion(weights * x_pred, weights * gray_target_frame)

        # penalize prior for being far from posterior
        z_err = torch.square(z_t.detach() - z_t_hat)
        # Shape the per-sample weights to the latent's own rank. Reusing the
        # 4-D image weights on a (batch, z_dim) latent would broadcast to
        # (batch, 1, batch, z_dim) and mix weights across samples.
        z_weights = sample_weight.reshape(-1, *([1] * (z_err.dim() - 1)))
        mse_residual += opt.gamma * torch.mean(z_weights * z_err)
        h = h_target

    loss = mse + mse_residual
    loss.backward()

    frame_predictor_optimizer.step()
    posterior_optimizer.step()
    prior_optimizer.step()
    encoder_optimizer.step()
    decoder_optimizer.step()

    N = opt.n_past + opt.n_future
    return mse.data.cpu().numpy() / N, mse_residual.data.cpu().numpy() / N
loop ------------------------------------ +for epoch in range(opt.niter): + frame_predictor.train() + posterior.train() + prior.train() + encoder.train() + decoder.train() + epoch_mse = 0 + epoch_mse_residual = 0 + progress = progressbar.ProgressBar(max_value=opt.epoch_size).start() + # opt.epoch_size = 10 + for i in range(opt.epoch_size): + progress.update(i+1) + x = next(training_batch_generator) + + # train frame_predictor + mse, mse_residual = train(x) + epoch_mse += mse + epoch_mse_residual += mse_residual + + + progress.finish() + utils.clear_progressbar() + + print('[%02d] mse loss: %.5f | residual mse: %.20f (%d)' % (epoch, epoch_mse/opt.epoch_size, epoch_mse_residual/opt.epoch_size, epoch*opt.epoch_size*opt.batch_size)) + with open(os.path.join(opt.log_dir, 'loss.txt'), 'a') as f: + f.write('[%02d] mse loss: %.5f | residual mse: %.20f (%d)\n' % (epoch, epoch_mse/opt.epoch_size, epoch_mse_residual/opt.epoch_size, epoch*opt.epoch_size*opt.batch_size)) + + # plot some stuff + frame_predictor.eval() + posterior.eval() + prior.eval() + encoder.eval() + decoder.eval() + x = next(testing_batch_generator) + plot(x, epoch) + plot_rec(x, epoch) + + # save the model + torch.save({ + 'encoder': encoder, + 'decoder': decoder, + 'frame_predictor': frame_predictor, + 'posterior': posterior, + 'prior': prior, + 'opt': opt}, + '%s/model_e%02d.pth' % (opt.log_dir, epoch)) + if epoch % 10 == 0: + print('log dir: %s' % opt.log_dir) + + diff --git a/_train_baseline_shapeconstancy.py b/_train_baseline_shapeconstancy.py new file mode 100644 index 0000000..7d20951 --- /dev/null +++ b/_train_baseline_shapeconstancy.py @@ -0,0 +1,440 @@ +import glob + +import torch +import torch.optim as optim +import torch.nn as nn +import argparse +import os +import random +from torch.autograd import Variable +from torch.utils.data import DataLoader +import utils +import itertools +import progressbar +import numpy as np +import json + +parser = argparse.ArgumentParser() 
+parser.add_argument('--lr', default=0.0002, type=float, help='learning rate') +parser.add_argument('--beta1', default=0.9, type=float, help='momentum term for adam') +parser.add_argument('--batch_size', default=13, type=int, help='batch size') +parser.add_argument('--log_dir', default='logs/nonstochastic_posterior', help='base directory to save logs') +parser.add_argument('--model_dir', default='', help='base directory to save logs') +parser.add_argument('--name', default='', help='identifier for directory') +parser.add_argument('--data_root', default='data', help='root directory for data') +parser.add_argument('--optimizer', default='adam', help='optimizer to train with') +parser.add_argument('--niter', type=int, default=40, help='number of epochs to train for') +parser.add_argument('--seed', default=1, type=int, help='manual seed') +parser.add_argument('--epoch_size', type=int, default=500, help='epoch size') +parser.add_argument('--image_width', type=int, default=64, help='the height / width of the input image to network') +parser.add_argument('--channels', default=1, type=int, help='number of channels for input images. 
') +parser.add_argument('--use_edge_kernels', default=True, type=bool, help='whether to use edge kernels to reduce to 1 channel') +parser.add_argument('--dataset', default='mcs', help='dataset to train with') +parser.add_argument('--mcs_task', default='GravitySupportTraining', help='mcs task') +parser.add_argument('--n_past', type=int, default=5, help='number of frames to condition on') +parser.add_argument('--n_future', type=int, default=30, help='number of frames to predict') +parser.add_argument('--n_eval', type=int, default=35, help='number of frames to predict at eval time') +parser.add_argument('--start_min', type=int, default=79, help='min starting time for sampling sequence (0-indexed)') +parser.add_argument('--start_max', type=int, default=89, help='max starting time for sampling sequence (0-indexed)') +parser.add_argument('--sequence_stride', type=int, default=2, help='factor for sequence temporal subsampling (int)') +parser.add_argument('--rnn_size', type=int, default=256, help='dimensionality of hidden layer') +parser.add_argument('--prior_rnn_layers', type=int, default=1, help='number of layers') +parser.add_argument('--posterior_rnn_layers', type=int, default=1, help='number of layers') +parser.add_argument('--predictor_rnn_layers', type=int, default=2, help='number of layers') +parser.add_argument('--z_dim', type=int, default=32, help='dimensionality of z_t') +parser.add_argument('--g_dim', type=int, default=128, help='dimensionality of encoder output vector and decoder input vector') +parser.add_argument('--beta', type=float, default=0.0001, help='weighting on KL to prior') +parser.add_argument('--gamma', type=float, default=0.0001, help='weighting on h vs h posterior') +parser.add_argument('--model', default='vgg', help='model type (dcgan | vgg)') +parser.add_argument('--data_threads', type=int, default=12, help='number of data loading threads') +parser.add_argument('--num_digits', type=int, default=2, help='number of digits for moving mnist') 
parser.add_argument('--last_frame_skip', action='store_true', help='if true, skip connections go between frame t and frame t+1 rather than last ground truth frame')


# ---------------- parse options / resume from checkpoint ----------------
opt = parser.parse_args()
saved_model = None
if opt.model_dir != '':
    models = glob.glob(f'{opt.model_dir}/model_*.pth')
    if not models:
        # Fail with a readable message instead of an IndexError on an empty list.
        raise FileNotFoundError('No model_*.pth checkpoint found in %s' % opt.model_dir)
    # Checkpoints are written as model_e<epoch>.pth; resume from the highest epoch.
    latest_model = max(models, key=lambda s: int(s[s.rfind('_e') + 2: s.rfind('.pth')]))
    print('Loading model ', latest_model)
    # NOTE(review): torch.load unpickles whole module objects, which can run
    # arbitrary code -- only point --model_dir at checkpoints you trust.
    saved_model = torch.load(latest_model)
    # Options given on this command line that must win over the saved ones.
    model_dir = opt.model_dir
    niter = opt.niter
    lr = opt.lr
    batch_size = opt.batch_size
    n_future = opt.n_future
    n_eval = opt.n_eval
    data_root = opt.data_root
    # Every other option comes from the checkpoint so the architecture matches.
    opt = saved_model['opt']
    opt.niter = niter  # update number of epochs to train for
    opt.model_dir = model_dir
    opt.n_future = n_future
    opt.lr = lr
    opt.batch_size = batch_size
    opt.n_eval = n_eval
    opt.data_root = data_root
    opt.log_dir = '%s/continued_lr%s' % (opt.log_dir, opt.lr)
else:
    # A fresh run gets a log directory that encodes its hyper-parameters.
    name = 'model=%s%dx%d-rnn_size=%d-predictor-posterior-rnn_layers=%d-%d-n_past=%d-n_future=%d-lr=%.4f-g_dim=%d-z_dim=%d-last_frame_skip=%d-beta=%.7f-gamma=%.7f%s' % (opt.model, opt.image_width, opt.image_width, opt.rnn_size, opt.predictor_rnn_layers, opt.posterior_rnn_layers, opt.n_past, opt.n_future, opt.lr, opt.g_dim, opt.z_dim, opt.last_frame_skip, opt.beta, opt.gamma, opt.name)
    if opt.dataset == 'smmnist':
        opt.log_dir = '%s/%s-%d/%s' % (opt.log_dir, opt.dataset, opt.num_digits, name)
    elif opt.dataset == 'mcs':
        opt.log_dir = '%s/%s/%s/%s' % (opt.log_dir, opt.dataset, opt.mcs_task, name)
    else:
        opt.log_dir = '%s/%s/%s' % (opt.log_dir, opt.dataset, name)

# exist_ok=False makes a re-run into an existing log directory fail here;
# presumably deliberate so earlier results are not overwritten -- TODO confirm.
os.makedirs('%s/gen/' % opt.log_dir, exist_ok=False)
os.makedirs('%s/plots/' % opt.log_dir, exist_ok=False)
with open(os.path.join(opt.log_dir, 'opt.json'), 'w') as f:
    opt2 = opt.__dict__.copy()
    # After a resume, opt.optimizer is already a class, which json cannot encode.
    if isinstance(opt2['optimizer'], type):
        opt2['optimizer'] = str(opt2['optimizer'])
    json.dump(opt2, f, indent=2)
    del opt2

print("Random Seed: ", opt.seed)
random.seed(opt.seed)
torch.manual_seed(opt.seed)
torch.cuda.manual_seed_all(opt.seed)
dtype = torch.cuda.FloatTensor


# ---------------- load the models ----------------

print(opt)

# ---------------- optimizers ----------------
if opt.optimizer == 'adam':
    opt.optimizer = optim.Adam
elif opt.optimizer == 'rmsprop':
    opt.optimizer = optim.RMSprop
elif opt.optimizer == 'sgd':
    opt.optimizer = optim.SGD
elif isinstance(opt.optimizer, type):
    # resumed run: the saved options already hold the optimizer class
    pass
else:
    raise ValueError('Unknown optimizer: %s' % opt.optimizer)

import models.lstm as lstm_models
if opt.model_dir != '':
    frame_predictor = saved_model['frame_predictor']
    frame_predictor.batch_size = opt.batch_size
    prior = saved_model['prior']
    prior.batch_size = opt.batch_size
    posterior = saved_model['posterior']
    posterior.batch_size = opt.batch_size
else:
    frame_predictor = lstm_models.lstm(opt.g_dim + opt.z_dim, opt.g_dim, opt.rnn_size, opt.predictor_rnn_layers, opt.batch_size)
    posterior = lstm_models.lstm(opt.g_dim, opt.z_dim, opt.rnn_size, opt.posterior_rnn_layers, opt.batch_size)
    prior = lstm_models.lstm(opt.g_dim, opt.z_dim, opt.rnn_size, opt.prior_rnn_layers, opt.batch_size)
    frame_predictor.apply(utils.init_weights)
    posterior.apply(utils.init_weights)
    prior.apply(utils.init_weights)

if opt.model == 'dcgan':
    if opt.image_width == 64:
        import models.dcgan_64 as model
    elif opt.image_width == 128:
        import models.dcgan_128 as model
    else:
        # Previously `model` stayed unbound and surfaced later as a NameError.
        raise ValueError('Unsupported image_width: %s' % opt.image_width)
elif opt.model == 'vgg':
    if opt.image_width == 64:
        import models.vgg_64 as model
    elif opt.image_width == 128:
        import models.vgg_128 as model
    else:
        raise ValueError('Unsupported image_width: %s' % opt.image_width)
else:
    raise ValueError('Unknown model: %s' % opt.model)

if opt.model_dir != '':
    decoder = saved_model['decoder']
    encoder = saved_model['encoder']
else:
    encoder = model.encoder(opt.g_dim, opt.channels)
    # the decoder always emits one channel; targets are converted to gray in train()
    decoder = model.decoder(opt.g_dim, 1)
    encoder.apply(utils.init_weights)
    decoder.apply(utils.init_weights)


def _make_optimizer(module):
    """Instantiate ``opt.optimizer`` for ``module``'s parameters.

    ``betas`` is only passed to Adam: RMSprop and SGD have no such argument,
    and passing it to them raised a TypeError.
    """
    kwargs = {'lr': opt.lr}
    if issubclass(opt.optimizer, optim.Adam):
        kwargs['betas'] = (opt.beta1, 0.999)
    return opt.optimizer(module.parameters(), **kwargs)


frame_predictor_optimizer = _make_optimizer(frame_predictor)
posterior_optimizer = _make_optimizer(posterior)
prior_optimizer = _make_optimizer(prior)
encoder_optimizer = _make_optimizer(encoder)
decoder_optimizer = _make_optimizer(decoder)

# --------- loss functions ------------------------------------
mse_criterion = nn.MSELoss()


def kl_criterion(mu, logvar):
    """Return KL(N(mu, exp(logvar)) || N(0, I)) summed and divided by opt.batch_size.

    Not called anywhere in the visible part of this script.
    """
    # -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    KLD /= opt.batch_size
    return KLD


# --------- transfer to gpu ------------------------------------
frame_predictor.cuda()
posterior.cuda()
prior.cuda()
encoder.cuda()
decoder.cuda()
mse_criterion.cuda()

# --------- load a dataset ------------------------------------
train_data, test_data = utils.load_dataset(opt)

train_loader = DataLoader(train_data,
                          num_workers=opt.data_threads,
                          batch_size=opt.batch_size,
                          shuffle=True,
                          drop_last=True,
                          pin_memory=True,)
test_loader = DataLoader(test_data,
                         num_workers=opt.data_threads,
                         batch_size=opt.batch_size,
                         shuffle=True,
                         drop_last=True,
                         pin_memory=True)


def get_training_batch():
    """Yield normalized training batches forever, restarting the loader when exhausted."""
    while True:
        for sequence in train_loader:
            batch = utils.normalize_data(opt, dtype, sequence)
            yield batch


training_batch_generator = get_training_batch()


def get_testing_batch():
    """Yield normalized test batches forever, restarting the loader when exhausted."""
    while True:
        for sequence in test_loader:
            batch = utils.normalize_data(opt, dtype, sequence)
            yield batch


testing_batch_generator = get_testing_batch()


# --------- plotting functions ------------------------------------
def plot(x, epoch):
    """Generate opt.n_eval frames and save them next to the ground truth.

    Frames before opt.n_past are copied from ``x`` while the prior, posterior
    and predictor are stepped on them; from opt.n_past on, each frame is
    decoded from the predictor output using the prior's z and fed back in as
    the next input. Writes ``gen/sample_<epoch>.png`` and ``.gif`` under
    opt.log_dir.

    x: sequence indexed as x[t][i] (presumably time, then batch sample --
       TODO confirm against utils.normalize_data).
    """
    nsample = 1
    gen_seq = [[] for _ in range(nsample)]
    gt_seq = [utils.torch_rgb_img_to_gray(x[t]) for t in range(len(x))]

    for s in range(nsample):
        frame_predictor.hidden = frame_predictor.init_hidden()
        posterior.hidden = posterior.init_hidden()
        prior.hidden = prior.init_hidden()
        gen_seq[s].append(utils.torch_rgb_img_to_gray(x[0]))
        x_in = x[0]
        with torch.no_grad():
            for i in range(1, opt.n_eval):
                # a generated frame has one channel; replicate it if the encoder expects RGB
                if x_in.shape[1] == 1 and opt.channels == 3:
                    h = encoder(torch.cat(3*[x_in], dim=1))  # convert to RGB
                else:
                    h = encoder(x_in)
                if opt.last_frame_skip or i < opt.n_past:
                    h, skip = h
                else:
                    # keep the skip tensors from the last conditioning frame
                    h, _ = h
                h = h.detach()
                if i < opt.n_past:
                    h_target = encoder(x[i])
                    z_t = posterior(h_target[0].detach())
                    prior(h)  # called only to advance the prior's hidden state
                    frame_predictor(torch.cat([h, z_t], 1))
                    x_in = x[i]
                    gen_seq[s].append(utils.torch_rgb_img_to_gray(x_in))
                else:
                    z_t_hat = prior(h)
                    h = frame_predictor(torch.cat([h, z_t_hat], 1)).detach()
                    x_in = decoder([h, skip])
                    gen_seq[s].append(x_in)

    to_plot = []
    gifs = [[] for t in range(opt.n_eval)]
    nrow = min(opt.batch_size, 25)
    for i in range(nrow):
        # ground truth sequence
        to_plot.append([gt_seq[t][i] for t in range(opt.n_eval)])
        # generated sequence(s)
        for s in range(nsample):
            to_plot.append([gen_seq[s][t][i] for t in range(opt.n_eval)])
        # per-frame [ground truth, sample...] rows for the gif
        for t in range(opt.n_eval):
            row = [gt_seq[t][i]]
            for s in range(nsample):
                row.append(gen_seq[s][t][i])
            gifs[t].append(row)

    fname = '%s/gen/sample_%d.png' % (opt.log_dir, epoch)
    utils.save_tensors_image(fname, to_plot)

    fname = '%s/gen/sample_%d.gif' % (opt.log_dir, epoch)
    utils.save_gif(fname, gifs)


def plot_rec(x, epoch):
    """Save one-step predictions made with posterior z and with prior z.

    Every step is conditioned on the ground-truth previous frame. For each
    plotted sample, three rows are written to ``gen/rec_<epoch>.png``: ground
    truth, prediction with posterior z, prediction with prior z.
    """
    n_frames = min(opt.n_eval, opt.n_past+opt.n_future)
    gen_seq = [utils.torch_rgb_img_to_gray(x[0])]
    gen_seq_post = [utils.torch_rgb_img_to_gray(x[0])]

    # Plotting needs no gradients; without no_grad the encoder's skip tensors
    # kept an autograd graph alive for every frame.
    with torch.no_grad():
        # prediction using posterior Z
        frame_predictor.hidden = frame_predictor.init_hidden()
        posterior.hidden = posterior.init_hidden()
        h = encoder(x[0])
        h = (h[0].detach(), h[1])
        for i in range(1, n_frames):
            h_target = encoder(x[i])
            h_target = (h_target[0].detach(), h_target[1])
            if opt.last_frame_skip or i < opt.n_past:
                h, skip = h
            else:
                h, _ = h
            z_t = posterior(h_target[0])
            if i < opt.n_past:
                frame_predictor(torch.cat([h, z_t], 1))
                gen_seq_post.append(utils.torch_rgb_img_to_gray(x[i]))
            else:
                h_pred = frame_predictor(torch.cat([h, z_t], 1)).detach()
                x_pred = decoder([h_pred, skip]).detach()
                gen_seq_post.append(x_pred)
            h = h_target

        # prediction using prior Z
        frame_predictor.hidden = frame_predictor.init_hidden()
        prior.hidden = prior.init_hidden()
        h = encoder(x[0])
        h = (h[0].detach(), h[1])
        for i in range(1, n_frames):
            h_target = encoder(x[i])
            h_target = (h_target[0].detach(), h_target[1])
            if opt.last_frame_skip or i < opt.n_past:
                h, skip = h
            else:
                h, _ = h
            h = h.detach()
            z_t_hat = prior(h)
            if i < opt.n_past:
                frame_predictor(torch.cat([h, z_t_hat], 1))
                gen_seq.append(utils.torch_rgb_img_to_gray(x[i]))
            else:
                h_pred = frame_predictor(torch.cat([h, z_t_hat], 1)).detach()
                x_pred = decoder([h_pred, skip]).detach()
                gen_seq.append(x_pred)
            h = h_target

    to_plot = []
    nrow = min(opt.batch_size, 25)
    x_gray = [utils.torch_rgb_img_to_gray(x[t]) for t in range(n_frames)]
    for i in range(nrow):
        to_plot.append([x_gray[t][i] for t in range(n_frames)])
        to_plot.append([gen_seq_post[t][i] for t in range(n_frames)])
        to_plot.append([gen_seq[t][i] for t in range(n_frames)])
    fname = '%s/gen/rec_%d.png' % (opt.log_dir, epoch)
    utils.save_tensors_image(fname, to_plot)


# --------- training functions ------------------------------------
def train(x):
    """Run one optimization step on batch ``x``.

    For each frame after the first, the predictor receives the previous
    frame's encoding concatenated with the posterior's z of the target frame,
    and the decoded output is compared with the gray-scale target (MSE). The
    prior is pulled towards the detached posterior z, weighted by opt.gamma.

    Returns (reconstruction mse, prior residual) as numpy scalars, each
    divided by opt.n_past + opt.n_future.
    """
    frame_predictor.zero_grad()
    posterior.zero_grad()
    prior.zero_grad()
    encoder.zero_grad()
    decoder.zero_grad()

    # initialize the hidden state.
    frame_predictor.hidden = frame_predictor.init_hidden()
    posterior.hidden = posterior.init_hidden()
    prior.hidden = prior.init_hidden()

    mse = 0
    mse_residual = 0
    h = encoder(x[0])
    for i in range(1, opt.n_past+opt.n_future):
        h_target = encoder(x[i])
        if opt.last_frame_skip or i < opt.n_past:
            h, skip = h
        else:
            # keep the skip tensors from the last conditioning frame
            h = h[0]

        z_t = posterior(h_target[0])
        z_t_hat = prior(h)
        h_pred = frame_predictor(torch.cat([h, z_t], 1))
        x_pred = decoder([h_pred, skip])
        gray_target_frame = utils.torch_rgb_img_to_gray(x[i])
        mse += mse_criterion(x_pred, gray_target_frame)
        # penalize prior for being far from posterior
        mse_residual += opt.gamma * torch.mean(torch.square(z_t.detach() - z_t_hat))
        h = h_target

    loss = mse + mse_residual
    loss.backward()

    frame_predictor_optimizer.step()
    posterior_optimizer.step()
    prior_optimizer.step()
    encoder_optimizer.step()
    decoder_optimizer.step()

    N = opt.n_past+opt.n_future
    return mse.data.cpu().numpy()/N, mse_residual.data.cpu().numpy()/N


# --------- training loop ------------------------------------
for epoch in range(opt.niter):
    frame_predictor.train()
    posterior.train()
    prior.train()
    encoder.train()
    decoder.train()
    epoch_mse = 0
    epoch_mse_residual = 0
    progress = progressbar.ProgressBar(max_value=opt.epoch_size).start()
    for i in range(opt.epoch_size):
        progress.update(i+1)
        x = next(training_batch_generator)

        # train frame_predictor
        mse, mse_residual = train(x)
        epoch_mse += mse
        epoch_mse_residual += mse_residual

    progress.finish()
    utils.clear_progressbar()

    print('[%02d] mse loss: %.5f | residual mse: %.20f (%d)' % (epoch, epoch_mse/opt.epoch_size, epoch_mse_residual/opt.epoch_size, epoch*opt.epoch_size*opt.batch_size))
    with open(os.path.join(opt.log_dir, 'loss.txt'), 'a') as f:
        f.write('[%02d] mse loss: %.5f | residual mse: %.20f (%d)\n' % (epoch, epoch_mse/opt.epoch_size, epoch_mse_residual/opt.epoch_size, epoch*opt.epoch_size*opt.batch_size))

    # plot some stuff
    frame_predictor.eval()
    posterior.eval()
    prior.eval()
    encoder.eval()
    decoder.eval()
    x = next(testing_batch_generator)
    plot(x, epoch)
    plot_rec(x, epoch)

    # save the model
    torch.save({
        'encoder': encoder,
        'decoder': decoder,
        'frame_predictor': frame_predictor,
        'posterior': posterior,
        'prior': prior,
        'opt': opt},
        '%s/model_e%02d.pth' % (opt.log_dir, epoch))
    if epoch % 10 == 0:
        print('log dir: %s' % opt.log_dir)


# ---- patch metadata carried over verbatim from the original diff ----
# diff --git a/_train_svg_nonstochastic_posterior.py b/_train_svg_nonstochastic_posterior.py
# new file mode 100644
# index 0000000..fddd604
# --- /dev/null
# +++ b/_train_svg_nonstochastic_posterior.py
# @@ -0,0 +1,371 @@
import glob

import torch
import torch.optim as optim
import torch.nn as nn
import argparse
import os
import random
from torch.autograd import Variable
from torch.utils.data import DataLoader
import utils
import itertools
import progressbar
import numpy as np

parser = argparse.ArgumentParser()
parser.add_argument('--lr', default=0.008, type=float, help='learning rate')
parser.add_argument('--beta1', default=0.9, type=float, help='momentum term for adam')
parser.add_argument('--batch_size', default=128, type=int, help='batch size')
parser.add_argument('--log_dir', default='logs/nonstochastic_posterior', help='base directory to save logs')
# The resume code globs '<model_dir>/model_*.pth' and loads the highest epoch.
parser.add_argument('--model_dir', default='', help='directory containing model_e*.pth checkpoints to resume training from')
parser.add_argument('--name', default='', help='identifier for directory')
parser.add_argument('--data_root', default='data', help='root directory for data')
parser.add_argument('--optimizer', default='adam', help='optimizer to train with')
parser.add_argument('--niter', type=int, default=300, help='number of epochs to train for')
parser.add_argument('--seed', default=1, type=int, help='manual seed')
parser.add_argument('--epoch_size', type=int, default=400, help='epoch size')
parser.add_argument('--image_width', type=int, default=64, help='the height / width of the input image to network')
parser.add_argument('--channels', default=1, type=int)
parser.add_argument('--dataset', default='mcs', help='dataset to train with')
parser.add_argument('--mcs_task', default='SpatioTemporalContinuityTraining4', help='mcs task')
parser.add_argument('--n_past', type=int, default=5, help='number of frames to condition on')
parser.add_argument('--n_future', type=int, default=15, help='number of frames to predict')
parser.add_argument('--n_eval', type=int, default=30, help='number of frames to predict at eval time')
parser.add_argument('--rnn_size', type=int, default=256, help='dimensionality of hidden layer')
parser.add_argument('--posterior_rnn_layers', type=int, default=1, help='number of layers')
parser.add_argument('--predictor_rnn_layers', type=int, default=2, help='number of layers')
parser.add_argument('--z_dim', type=int, default=10, help='dimensionality of z_t')
parser.add_argument('--g_dim', type=int, default=128, help='dimensionality of encoder output vector and decoder input vector')
parser.add_argument('--beta', type=float, default=0.0001, help='weighting on KL to prior')
parser.add_argument('--gamma', type=float, default=0.0001, help='weighting on h vs h posterior')
parser.add_argument('--model', default='dcgan', help='model type (dcgan | vgg)')
parser.add_argument('--data_threads', type=int, default=12, help='number of data loading threads')
parser.add_argument('--num_digits', type=int, default=2, help='number of digits for moving mnist')
parser.add_argument('--last_frame_skip', action='store_true', help='if true, skip connections go between frame t and frame t+1 rather than last ground truth frame')


# ---------------- parse options / resume from checkpoint ----------------
opt = parser.parse_args()
saved_model = None
if opt.model_dir != '':
    models = glob.glob(f'{opt.model_dir}/model_*.pth')
    if not models:
        # Fail with a readable message instead of an IndexError on an empty list.
        raise FileNotFoundError('No model_*.pth checkpoint found in %s' % opt.model_dir)
    # Checkpoints are written as model_e<epoch>.pth; resume from the highest epoch.
    latest_model = max(models, key=lambda s: int(s[s.rfind('_e') + 2: s.rfind('.pth')]))
    print('Loading model ', latest_model)
    # NOTE(review): torch.load unpickles whole module objects, which can run
    # arbitrary code -- only point --model_dir at checkpoints you trust.
    saved_model = torch.load(latest_model)
    model_dir = opt.model_dir
    niter = opt.niter
    lr = opt.lr
    opt = saved_model['opt']
    opt.niter = niter  # update number of epochs to train for
    opt.model_dir = model_dir
    opt.lr = lr
    opt.log_dir = '%s/continued_lr%s' % (opt.log_dir, opt.lr)
else:
    name = 'model=%s%dx%d-rnn_size=%d-predictor-posterior-rnn_layers=%d-%d-n_past=%d-n_future=%d-lr=%.4f-g_dim=%d-z_dim=%d-last_frame_skip=%d-beta=%.7f-gamma=%.7f%s' % (opt.model, opt.image_width, opt.image_width, opt.rnn_size, opt.predictor_rnn_layers, opt.posterior_rnn_layers, opt.n_past, opt.n_future, opt.lr, opt.g_dim, opt.z_dim, opt.last_frame_skip, opt.beta, opt.gamma, opt.name)
    if opt.dataset == 'smmnist':
        opt.log_dir = '%s/%s-%d/%s' % (opt.log_dir, opt.dataset, opt.num_digits, name)
    elif opt.dataset == 'mcs':
        opt.log_dir = '%s/%s/%s/%s' % (opt.log_dir, opt.dataset, opt.mcs_task, name)
    else:
        opt.log_dir = '%s/%s/%s' % (opt.log_dir, opt.dataset, name)

os.makedirs('%s/gen/' % opt.log_dir, exist_ok=True)
os.makedirs('%s/plots/' % opt.log_dir, exist_ok=True)

print("Random Seed: ", opt.seed)
random.seed(opt.seed)
torch.manual_seed(opt.seed)
torch.cuda.manual_seed_all(opt.seed)
dtype = torch.cuda.FloatTensor


# ---------------- load the models ----------------

print(opt)

# ---------------- optimizers ----------------
if opt.optimizer == 'adam':
    opt.optimizer = optim.Adam
elif opt.optimizer == 'rmsprop':
    opt.optimizer = optim.RMSprop
elif opt.optimizer == 'sgd':
    opt.optimizer = optim.SGD
elif isinstance(opt.optimizer, type):
    # Resumed run: the saved options already hold the optimizer class. Without
    # this branch every resume ended in "Unknown optimizer".
    pass
else:
    raise ValueError('Unknown optimizer: %s' % opt.optimizer)

import models.lstm as lstm_models
if opt.model_dir != '':
    frame_predictor = saved_model['frame_predictor']
    posterior = saved_model['posterior']
else:
    frame_predictor = lstm_models.lstm(opt.g_dim, opt.g_dim, opt.rnn_size, opt.predictor_rnn_layers, opt.batch_size)
    frame_predictor.apply(utils.init_weights)
    # NOTE(review): the posterior is built with predictor_rnn_layers, so
    # --posterior_rnn_layers only affects the log directory name. Left as is
    # because changing it alters the default architecture -- confirm intent.
    posterior = lstm_models.lstm(opt.g_dim, opt.g_dim, opt.rnn_size, opt.predictor_rnn_layers, opt.batch_size)
    posterior.apply(utils.init_weights)

if opt.model == 'dcgan':
    if opt.image_width == 64:
        import models.dcgan_64 as model
    elif opt.image_width == 128:
        import models.dcgan_128 as model
    else:
        # Previously `model` stayed unbound and surfaced later as a NameError.
        raise ValueError('Unsupported image_width: %s' % opt.image_width)
elif opt.model == 'vgg':
    if opt.image_width == 64:
        import models.vgg_64 as model
    elif opt.image_width == 128:
        import models.vgg_128 as model
    else:
        raise ValueError('Unsupported image_width: %s' % opt.image_width)
else:
    raise ValueError('Unknown model: %s' % opt.model)

if opt.model_dir != '':
    decoder = saved_model['decoder']
    encoder = saved_model['encoder']
else:
    encoder = model.encoder(opt.g_dim, opt.channels)
    decoder = model.decoder(opt.g_dim, opt.channels)
    encoder.apply(utils.init_weights)
    decoder.apply(utils.init_weights)


def _make_optimizer(module):
    """Instantiate ``opt.optimizer`` for ``module``'s parameters.

    ``betas`` is only passed to Adam: RMSprop and SGD have no such argument,
    and passing it to them raised a TypeError.
    """
    kwargs = {'lr': opt.lr}
    if issubclass(opt.optimizer, optim.Adam):
        kwargs['betas'] = (opt.beta1, 0.999)
    return opt.optimizer(module.parameters(), **kwargs)


frame_predictor_optimizer = _make_optimizer(frame_predictor)
posterior_optimizer = _make_optimizer(posterior)
encoder_optimizer = _make_optimizer(encoder)
decoder_optimizer = _make_optimizer(decoder)

# --------- loss functions ------------------------------------
mse_criterion = nn.MSELoss()


def kl_criterion(mu, logvar):
    """Return KL(N(mu, exp(logvar)) || N(0, I)) summed and divided by opt.batch_size.

    Not called anywhere in the visible part of this script.
    """
    # -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    KLD /= opt.batch_size
    return KLD


# --------- transfer to gpu ------------------------------------
frame_predictor.cuda()
posterior.cuda()
encoder.cuda()
decoder.cuda()
mse_criterion.cuda()

# --------- load a dataset ------------------------------------
train_data, test_data = utils.load_dataset(opt)

train_loader = DataLoader(train_data,
                          num_workers=opt.data_threads,
                          batch_size=opt.batch_size,
                          shuffle=True,
                          drop_last=True,
                          pin_memory=True)
test_loader = DataLoader(test_data,
                         num_workers=opt.data_threads,
                         batch_size=opt.batch_size,
                         shuffle=True,
                         drop_last=True,
                         pin_memory=True)


def get_training_batch():
    """Yield normalized training batches forever, restarting the loader when exhausted."""
    while True:
        for sequence in train_loader:
            batch = utils.normalize_data(opt, dtype, sequence)
            yield batch


training_batch_generator = get_training_batch()


def get_testing_batch():
    """Yield normalized test batches forever, restarting the loader when exhausted."""
    while True:
        for sequence in test_loader:
            batch = utils.normalize_data(opt, dtype, sequence)
            yield batch


testing_batch_generator = get_testing_batch()


# --------- plotting functions ------------------------------------
def plot(x, epoch):
    """Generate opt.n_eval frames and save them next to the ground truth.

    Frames before opt.n_past are copied from ``x`` while the predictor is
    stepped on their encodings; from opt.n_past on, each frame is decoded
    from the predictor output and fed back in as the next input. Writes
    ``gen/sample_<epoch>.png`` and ``.gif`` under opt.log_dir.

    x: sequence indexed as x[t][i] (presumably time, then batch sample --
       TODO confirm against utils.normalize_data).
    """
    nsample = 1
    gen_seq = [[] for _ in range(nsample)]
    gt_seq = [x[i] for i in range(len(x))]

    for s in range(nsample):
        frame_predictor.hidden = frame_predictor.init_hidden()
        posterior.hidden = posterior.init_hidden()
        gen_seq[s].append(x[0])
        x_in = x[0]
        with torch.no_grad():
            for i in range(1, opt.n_eval):
                h = encoder(x_in)
                if opt.last_frame_skip or i < opt.n_past:
                    h, skip = h
                else:
                    # Use the encoding of the frame just generated and keep the
                    # skip tensors from the last conditioning frame. The old
                    # lookup h_seq[i-1] ran past its opt.n_past entries
                    # (IndexError) and ignored the generated frame.
                    h, _ = h
                h = h.detach()
                if i < opt.n_past:
                    frame_predictor(h)
                    x_in = x[i]
                    gen_seq[s].append(x_in)
                else:
                    h = frame_predictor(h).detach()
                    x_in = decoder([h, skip])
                    gen_seq[s].append(x_in)

    to_plot = []
    gifs = [[] for t in range(opt.n_eval)]
    nrow = min(opt.batch_size, 25)
    for i in range(nrow):
        # ground truth sequence
        to_plot.append([gt_seq[t][i] for t in range(opt.n_eval)])
        # generated sequence(s)
        for s in range(nsample):
            to_plot.append([gen_seq[s][t][i] for t in range(opt.n_eval)])
        # per-frame [ground truth, sample...] rows for the gif
        for t in range(opt.n_eval):
            row = [gt_seq[t][i]]
            for s in range(nsample):
                row.append(gen_seq[s][t][i])
            gifs[t].append(row)

    fname = '%s/gen/sample_%d.png' % (opt.log_dir, epoch)
    utils.save_tensors_image(fname, to_plot)

    fname = '%s/gen/sample_%d.gif' % (opt.log_dir, epoch)
    utils.save_gif(fname, gifs)


def plot_rec(x, epoch):
    """Save one-step predictions from the predictor and from the posterior.

    For each plotted sample, three rows are written to
    ``gen/rec_<epoch>.png``: ground truth, the decoded posterior output for
    the target frame, and the decoded predictor output for the previous frame.
    """
    n_frames = opt.n_past+opt.n_future
    frame_predictor.hidden = frame_predictor.init_hidden()
    posterior.hidden = posterior.init_hidden()

    gen_seq = [x[0]]
    gen_seq_post = [x[0]]
    # plotting needs no gradients
    with torch.no_grad():
        h_seq = [encoder(x[i]) for i in range(n_frames)]
        for i in range(1, n_frames):
            h_target = h_seq[i][0].detach()
            if opt.last_frame_skip or i < opt.n_past:
                h, skip = h_seq[i-1]
            else:
                h, _ = h_seq[i-1]
            h = h.detach()
            if i < opt.n_past:
                frame_predictor(h)
                gen_seq.append(x[i])
                gen_seq_post.append(x[i])
            else:
                h_pred = frame_predictor(h).detach()
                x_pred = decoder([h_pred, skip]).detach()
                h_posterior = posterior(h_target).detach()
                x_posterior = decoder([h_posterior, skip]).detach()
                gen_seq.append(x_pred)
                gen_seq_post.append(x_posterior)

    to_plot = []
    # i indexes batch samples, so it must stay below batch_size: the previous
    # bound min(batch_size * 3, 75) overran any batch smaller than 75. The cap
    # of 75 samples is kept so the default batch size plots what it did before.
    nrow = min(opt.batch_size, 25 * 3)
    for i in range(nrow):
        to_plot.append([x[t][i] for t in range(n_frames)])
        to_plot.append([gen_seq_post[t][i] for t in range(n_frames)])
        to_plot.append([gen_seq[t][i] for t in range(n_frames)])
    fname = '%s/gen/rec_%d.png' % (opt.log_dir, epoch)
    utils.save_tensors_image(fname, to_plot)


# --------- training functions ------------------------------------
def train(x):
    """Run one optimization step on batch ``x``.

    For each frame after the first, the predictor maps the previous frame's
    encoding, and the posterior maps the target frame's encoding, to a vector
    that is decoded and compared with the target frame (MSE). The predictor
    output is additionally pulled towards the detached posterior output,
    weighted by opt.gamma.

    Returns (predictor mse, posterior mse, predictor-posterior difference) as
    numpy scalars, each divided by opt.n_past + opt.n_future.
    """
    frame_predictor.zero_grad()
    posterior.zero_grad()
    encoder.zero_grad()
    decoder.zero_grad()

    # initialize the hidden state.
    frame_predictor.hidden = frame_predictor.init_hidden()
    posterior.hidden = posterior.init_hidden()

    h_seq = [encoder(x[i]) for i in range(opt.n_past+opt.n_future)]
    mse = 0
    mse_post = 0
    mse_diff_post = 0
    for i in range(1, opt.n_past+opt.n_future):
        h_target = h_seq[i][0]
        if opt.last_frame_skip or i < opt.n_past:
            h, skip = h_seq[i-1]
        else:
            # keep the skip tensors from the last conditioning frame
            h = h_seq[i-1][0]
        h_pred = frame_predictor(h)
        x_pred = decoder([h_pred, skip])
        h_posterior = posterior(h_target)
        x_posterior = decoder([h_posterior, skip])
        mse += mse_criterion(x_pred, x[i])
        mse_post += mse_criterion(x_posterior, x[i])
        mse_diff_post += opt.gamma * torch.mean(torch.square(h_posterior.detach() - h_pred))

    loss = mse + mse_post + mse_diff_post
    loss.backward()

    frame_predictor_optimizer.step()
    posterior_optimizer.step()
    encoder_optimizer.step()
    decoder_optimizer.step()

    N = opt.n_past+opt.n_future
    return mse.data.cpu().numpy()/N, mse_post.data.cpu().numpy()/N, mse_diff_post.data.cpu().numpy()/N


# --------- training loop ------------------------------------
for epoch in range(opt.niter):
    frame_predictor.train()
    posterior.train()
    encoder.train()
    decoder.train()
    epoch_mse = 0
    epoch_mse_posterior = 0
    epoch_posterior_diff = 0
    progress = progressbar.ProgressBar(max_value=opt.epoch_size).start()
    for i in range(opt.epoch_size):
        progress.update(i+1)
        x = next(training_batch_generator)

        # train frame_predictor
        mse, mse_posterior, posterior_diff = train(x)
        epoch_mse += mse
        epoch_mse_posterior += mse_posterior
        epoch_posterior_diff += posterior_diff

    progress.finish()
    utils.clear_progressbar()

    print('[%02d] mse loss: %.5f, %.5f posterior | posterior diff loss: %.20f (%d)' % (epoch, epoch_mse/opt.epoch_size, epoch_mse_posterior/opt.epoch_size, epoch_posterior_diff/opt.epoch_size, epoch*opt.epoch_size*opt.batch_size))

    # plot some stuff
    frame_predictor.eval()
    posterior.eval()
    encoder.eval()
    decoder.eval()
    x = next(testing_batch_generator)
    plot(x, epoch)
    plot_rec(x, epoch)

    # save the model
    torch.save({
        'encoder': encoder,
        'decoder': decoder,
        'frame_predictor': frame_predictor,
        'posterior': posterior,
        'opt': opt},
        '%s/model_e%02d.pth' % (opt.log_dir, epoch))
    if epoch % 10 == 0:
        print('log dir: %s' % opt.log_dir)


# ---- patch metadata carried over verbatim from the original diff ----
# diff --git a/data/convert_mcs.py b/data/convert_mcs.py
# new file mode 100644
# index 0000000..e436a8e
# --- /dev/null
# +++ b/data/convert_mcs.py
# @@ -0,0 +1,54 @@
import os
import glob
import argparse
from os import path
import subprocess
from tqdm import tqdm
import threading

parser = argparse.ArgumentParser()
parser.add_argument("-d", "--datadir", type=str, help="path to the mcs video dataset folder",
                    default="data/mcs_videos_1000")
parser.add_argument("-s", "--imsize", type=int, help="width of converted (square) images. default 64.", default=64)
args = parser.parse_args()

DATA_ROOT = args.datadir
IMSIZE = args.imsize

# Stop here when the input layout is wrong; previously the message was printed
# and the script went on to convert nothing and exit successfully.
if not path.exists(DATA_ROOT):
    raise SystemExit(f'directory "{DATA_ROOT}" does not exist! Check -d dataset argument')
elif not path.exists(path.join(DATA_ROOT, 'raw')):
    raise SystemExit("Training videos must be in [datadir]/raw/(task)/*.mp4, where task is the task to which "
                     "the video belongs")


def mp4_to_png_worker(path_to_task):
    """Convert every ``*.mp4`` of one task folder into numbered PNG frames.

    Frames are written by ffmpeg to
    ``DATA_ROOT/processed/<task>/<video>/<video>_%04d.png`` at IMSIZE x IMSIZE.
    The worker stops at the first video that ffmpeg fails on.

    path_to_task: task directory path ending in a path separator, as returned
        by the ``'*/'`` glob at the bottom of this script.
    """
    task_name = path.basename(path_to_task[:-1])  # remove the '/' at the end of path to a folder
    vids = glob.glob(path.join(path_to_task, '*.mp4'))
    vids_tqdm = tqdm(sorted(vids))
    vids_tqdm.set_description(f'Task {task_name}')
    for vid in vids_tqdm:
        sample_name = path.basename(vid)
        sample_name = sample_name[:sample_name.rfind('.')]  # get rid of extension name
        out_folder = path.join(DATA_ROOT, 'processed', task_name, sample_name)
        os.makedirs(out_folder, exist_ok=True)
        # An argument list with no shell, so quotes or other shell
        # metacharacters in a file name cannot break or alter the command.
        ffmpeg_cmd = ['ffmpeg', '-i', vid, '-s', f'{IMSIZE}x{IMSIZE}',
                      path.join(out_folder, sample_name + '_%04d.png')]

        try:
            subprocess.check_call(ffmpeg_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
        except (subprocess.CalledProcessError, FileNotFoundError):
            # FileNotFoundError: without a shell, a missing ffmpeg binary is
            # reported by Python itself, not through a non-zero exit status.
            print('convert_mcs.py: ffmpeg execution failed. Check that you have installed the ffmpeg executable '
                  'by typing "ffmpeg" in your terminal.')
            # quit() inside a thread only ended this worker; say so explicitly.
            return


tasks = glob.glob(path.join(DATA_ROOT, 'raw', '*/'))
threads = []
# one worker thread per task folder
for task in sorted(tasks):
    t = threading.Thread(target=mp4_to_png_worker, args=[task], daemon=True)
    t.start()
    threads.append(t)

for thread in threads:
    thread.join()

# ---- patch metadata carried over verbatim from the original diff ----
# diff --git a/data/mcs.py b/data/mcs.py
# new file mode 100644
# index 0000000..abf952e
# --- /dev/null
# +++ b/data/mcs.py
# @@ -0,0 +1,227 @@
import logging
import random
import os

import cv2
import numpy as np
from glob import glob
import torch
import scipy.misc
import imageio
from os import path


class MCS(object):
    """Frame-sequence sampler over MCS videos that were converted to PNG frames.

    Expects ``<data_root>/mcs_videos_1000/processed/<task>/<video>/<video>_NNNN.png``
    (``mcs_videos_test`` when ``test_set`` is true), with frame numbers
    starting at 0001.
    """

    def __init__(self, train, data_root, seq_len=20, image_size=64, task='ALL', sequential=None, implausible=False,
                 test_set=False, im_channels=1, use_edge_kernels=True, labels=False, start_min=None, start_max=None, sequence_stride=None,
                 reduce_static_frames=False, is_shape_constancy=False):
        """Index the available videos; no frames are read here.

        ``train`` is accepted but not used by this constructor. Raises OSError
        when the processed data directory is missing, and AssertionError when
        ``use_edge_kernels`` is combined with ``im_channels != 1``.
        """
        # if implausible is set to True, generates "fake" images by cutting out or repeating frames
        self.implausible = implausible
        if test_set:
            self.data_root = '%s/mcs_videos_test/processed/' % data_root
        else:
            self.data_root = '%s/mcs_videos_1000/processed/' % data_root
        if not os.path.exists(self.data_root):
            raise os.error('data/mcs.py: Data directory not found!')
        self.seq_len = seq_len
        self.image_size = image_size
        self.im_channels = im_channels
        self.labels = labels
        if use_edge_kernels:
            if im_channels != 1:
                raise AssertionError('Using edge kernels implies the output images are grayscale! Set im_channels to 1!')
        self.use_edge_kernels = use_edge_kernels
        self.start_min = start_min
        self.start_max = start_max
        self.sequence_stride = sequence_stride
        self.reduce_static_frames = reduce_static_frames
        # threshold on the mean absolute frame difference (scaled by 256), see get_sequence
        self.motion_threshold = 0.1
        self.is_shape_constancy = is_shape_constancy

        self.video_folder = {}
        self.len_video_folder = {}
        if task == 'ALL':
            self.tasks = [os.path.basename(folder) for folder in sorted(glob(path.join(self.data_root, '*')))]
        else:
            self.tasks = [task]

        for task in self.tasks:
            self.video_folder[task] = [path.basename(folder) for folder in
                                       sorted(glob(path.join(self.data_root, task, '*')))]
            self.len_video_folder[task] = len(self.video_folder[task])
        self.seed_set = False
        self.sequential = sequential  # if set to true, return videos in sequence

    def get_sequence(self, idx=None):
        """Load one frame sequence as a float array scaled to [0, 1].

        Without ``sequential`` a random task and video are chosen; with it,
        video number ``idx`` of the single configured task is used, and None
        is returned once ``idx`` is past the last video. The start frame is
        drawn at random between the configured bounds and frames are read
        every ``sequence_stride`` frames.

        Returns ``np.array(seq)``, or ``(np.array(seq), label)`` when
        ``self.labels`` is set; ``label`` is the text after the last ``_`` in
        the video folder name.

        Raises ValueError when the video is shorter than the requested
        sequence, or when static-frame reduction finds no static frame.
        """
        # sequence_stride defaults to None; max(1, None) raises TypeError on Python 3
        stride = max(1, self.sequence_stride or 1)
        if not self.sequential:
            task = random.choice(self.tasks)
            vid = random.choice(self.video_folder[task])
        else:
            assert len(self.tasks) == 1 and idx is not None
            task = self.tasks[0]
            if idx >= self.len_video_folder[task]:
                return None  # we've run out of videos
            vid = self.video_folder[task][idx]
        # number of files in the video folder
        num_frames = len(next(os.walk(path.join(self.data_root, task, vid)))[2])
        frame_path = path.join(self.data_root, task, vid, vid + '_')
        # computed for both sampling modes so `labels=True` never hits an unbound name
        label = str(os.path.basename(vid))
        label = label[label.rfind('_') + 1:]

        start_min = 0
        start_max = num_frames - 1 - (self.seq_len - 1) * stride
        if start_max < 0:
            raise ValueError("Number of frames in the dataset less than the desired sequence length!")
        if self.start_min:
            start_min = self.start_min
        if self.start_max:
            assert self.start_max <= start_max  # so the sequence doesn't start too late
            start_max = self.start_max
        assert start_min <= start_max
        start = random.randint(start_min, start_max)
        seq = []
        last_im = None
        first_movement = None
        first_static = None
        # choose a random subsequence of frames in the selected video;
        # static-frame reduction needs the whole rest of the video instead
        if not self.reduce_static_frames:
            end = start + self.seq_len * stride
        else:
            end = num_frames
        for i in range(start, end, stride):
            # i is 0-indexed so we need to add 1 to i
            fname = frame_path + f'{i + 1:04d}.png'
            im = imageio.imread(fname) / np.float32(255.)
            if self.im_channels == 1:  # convert to grayscale if specified
                if not self.use_edge_kernels:
                    # regular grayscale conversion
                    im = np.dot(im[..., :3], [0.299, 0.587, 0.114])[..., np.newaxis]
                else:
                    # Scharr gradient magnitude, averaged over the three colour channels
                    edge_map = np.zeros((self.image_size, self.image_size), dtype=np.float32)
                    ddepth = cv2.CV_32F
                    scale = 1
                    delta = 0
                    for channel in range(3):
                        color = im[..., channel]
                        grad_x = cv2.Scharr(color, ddepth, 1, 0, scale=scale, delta=delta,
                                            borderType=cv2.BORDER_DEFAULT)
                        grad_y = cv2.Scharr(color, ddepth, 0, 1, scale=scale, delta=delta,
                                            borderType=cv2.BORDER_DEFAULT)
                        edge_map += np.sqrt(grad_x**2 + grad_y**2)
                    edge_map /= 3  # 3 channels
                    edge_map /= 12  # to reduce magnitude
                    im = edge_map[..., np.newaxis]
            if self.reduce_static_frames and last_im is not None:
                motion_magnitude = np.mean(cv2.absdiff(im, last_im)) * 256
                if first_movement is None and motion_magnitude > self.motion_threshold:
                    first_movement = i
                elif first_movement is not None and first_static is None and motion_magnitude <= self.motion_threshold:
                    first_static = i
            last_im = im
            seq.append(im)
        if self.reduce_static_frames:
            if first_static is None:
                # otherwise max(start + k, None) below fails with an opaque TypeError
                raise ValueError('data/mcs.py: no static frame found after the first movement in %s' % vid)
            new_seq = []
            # NOTE(review): the slices below use frame numbers as indices into
            # seq, which only line up when stride == 1 -- confirm callers use
            # stride 1 together with reduce_static_frames.
            # The constants 105, 165, 144 and 200 look like frame numbers tied
            # to the layout of the MCS videos -- TODO confirm.
            if self.is_shape_constancy:
                first_static = min(105, max(start + 10, first_static))  # so we don't run into negative indices
                new_seq += seq[first_static - start - 10: first_static - start + 3]
                assert len(new_seq) == 13
            else:
                first_static = min(165, max(start + 7, first_static))  # so we don't run into negative indices
                new_seq += seq[first_static - start - 7: first_static - start + 3]
                assert len(new_seq) == 10

            if self.is_shape_constancy:
                new_seq += seq[144 - start: 144 - start + (self.seq_len - 13)]
            else:
                new_seq += seq[first_static - start + 3 + 12: first_static - start + 3 + 22]
                assert len(new_seq) == 20
                new_seq += seq[200 - start: 200 - start + (self.seq_len - 20)]
            assert len(new_seq) == self.seq_len
            seq = new_seq
        if self.labels:
            return np.array(seq), label
        else:
            return np.array(seq)

    def abnormalize_sequence(self, seq):
        """
        Takes a sequence and makes it implausible by cutting out and then repeating frames or
        repeating frames and then cutting out frames
        """
implausibility_type = random.randint(1, 3) + # start = random.randint(100, 140) + start = 115 + vid_len = len(seq) + duration = 7 + implausibility_type = 1 + if implausibility_type == 1: # object is invisible when/where it shouldn't be + no_object_frame = seq[30] + seq[start:start + duration] = no_object_frame + elif implausibility_type == 2: # object suddenly freezes then teleports where it would be + seq[start:start + duration] = seq[start] + elif implausibility_type == 3: # object jumps forward 10 frames and keeps moving + seq[start:vid_len - duration] = \ + seq[start + duration: vid_len] + elif implausibility_type == 0: + pass # do nothing if type if 0 + return seq + + def __getitem__(self, index): + if not self.seed_set: + self.seed_set = True + random.seed(index) + np.random.seed(index) + # torch.manual_seed(index) + if self.labels: + seq, labels = self.get_sequence(index) + else: + seq = self.get_sequence(index) + + if seq is not None: + if self.implausible: + seq = self.abnormalize_sequence(seq) + if self.labels: + return torch.from_numpy(seq), labels + else: + return torch.from_numpy(seq) + else: + return None + + def __len__(self): + return 5 * 1000 * 200 # approximate + +# if __name__ == '__main__': +# m = MCS(True, '/home/lol/Hub/svg/data/') +# s = m.__getitem__(0).cuda() +# print(s.device) diff --git a/do_mcs_implausblility_test.py b/do_mcs_implausblility_test.py new file mode 100644 index 0000000..91ea12b --- /dev/null +++ b/do_mcs_implausblility_test.py @@ -0,0 +1,384 @@ +import glob + +import cv2 +import torch +import torch.optim as optim +import torch.nn as nn +import argparse +import os +import random +from torch.autograd import Variable +from torch.utils.data import DataLoader +import utils +import itertools +import progressbar +import numpy as np +import matplotlib.pyplot as plt +import json + +parser = argparse.ArgumentParser() +parser.add_argument('--lr', default=0.004, type=float, help='learning rate') +parser.add_argument('--beta1', 
default=0.9, type=float, help='momentum term for adam') +parser.add_argument('--batch_size', default=1, type=int, help='batch size') +parser.add_argument('--log_dir', default='logs/nonstochastic', help='base directory to save logs') +parser.add_argument('--model_dir', default='', help='base directory to save logs') +parser.add_argument('--name', default='', help='identifier for directory') +parser.add_argument('--data_root', default='data', help='root directory for data') +parser.add_argument('--optimizer', default='adam', help='optimizer to train with') +parser.add_argument('--niter', type=int, default=300, help='number of epochs to train for') +parser.add_argument('--seed', default=1, type=int, help='manual seed') +parser.add_argument('--epoch_size', type=int, default=1000, help='epoch size') +parser.add_argument('--image_width', type=int, default=64, help='the height / width of the input image to network') +parser.add_argument('--channels', default=1, type=int) +parser.add_argument('--dataset', default='mcs', help='dataset to train with') +parser.add_argument('--mcs_task', default='SpatioTemporalContinuityTraining4', help='mcs task') +parser.add_argument('--n_past', type=int, default=5, help='number of frames to condition on') +parser.add_argument('--n_future', type=int, default=195, help='number of frames to predict') +parser.add_argument('--n_eval', type=int, default=30, help='number of frames to predict at eval time') +parser.add_argument('--rnn_size', type=int, default=256, help='dimensionality of hidden layer') +parser.add_argument('--posterior_rnn_layers', type=int, default=1, help='number of layers') +parser.add_argument('--predictor_rnn_layers', type=int, default=2, help='number of layers') +parser.add_argument('--z_dim', type=int, default=10, help='dimensionality of z_t') +parser.add_argument('--g_dim', type=int, default=128, + help='dimensionality of encoder output vector and decoder input vector') +parser.add_argument('--beta', type=float, 
default=0.0001, help='weighting on KL to prior') +parser.add_argument('--model', default='dcgan', help='model type (dcgan | vgg)') +parser.add_argument('--data_threads', type=int, default=8, help='number of data loading threads') +parser.add_argument('--num_digits', type=int, default=2, help='number of digits for moving mnist') +parser.add_argument('--last_frame_skip', action='store_true', + help='if true, skip connections go between frame t and frame t+t rather than last ground truth frame') +opt = parser.parse_args() +BATCH_SIZE = opt.batch_size +saved_model = None +if opt.model_dir != '': + models = glob.glob(f'{opt.model_dir}/model_*.pth') + latest_model = sorted(models, key=lambda s: int(s[s.rfind('_e') + 2: s.rfind('.pth')]), reverse=True)[0] + print('Loading model ', latest_model) + saved_model = torch.load(latest_model) + optimizer = opt.optimizer + model_dir = opt.model_dir + niter = opt.niter + opt = saved_model['opt'] + opt.niter = niter # update number of epochs to train for + opt.optimizer = optimizer + opt.model_dir = model_dir + opt.log_dir = '%s/continued' % opt.log_dir +else: + raise ValueError("Please specify the model to load with the --model_dir argument") + +print("Random Seed: ", opt.seed) +random.seed(opt.seed) +torch.manual_seed(opt.seed) +torch.cuda.manual_seed_all(opt.seed) +dtype = torch.cuda.FloatTensor + +import models.lstm as lstm_models + +if opt.model_dir != '': + frame_predictor = saved_model['frame_predictor'] + frame_predictor.batch_size = BATCH_SIZE +else: + raise ValueError('Please specify --model_dir') + +if opt.model == 'dcgan': + if opt.image_width == 64: + import models.dcgan_64 as model + elif opt.image_width == 128: + import models.dcgan_128 as model +elif opt.model == 'vgg': + if opt.image_width == 64: + import models.vgg_64 as model + elif opt.image_width == 128: + import models.vgg_128 as model +else: + raise ValueError('Unknown model: %s' % opt.model) + +if opt.model_dir != '': + decoder = saved_model['decoder'] + 
encoder = saved_model['encoder'] +else: + encoder = model.encoder(opt.g_dim, opt.channels) + decoder = model.decoder(opt.g_dim, opt.channels) + encoder.apply(utils.init_weights) + decoder.apply(utils.init_weights) + +# --------- loss functions ------------------------------------ +mse_criterion = nn.MSELoss() + + +def kl_criterion(mu, logvar): + # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) + KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) + KLD /= opt.batch_size + return KLD + + +# --------- transfer to gpu ------------------------------------ +frame_predictor.cuda() +encoder.cuda() +decoder.cuda() +mse_criterion.cuda() + +opt.batch_size = BATCH_SIZE +opt.epoch_size = 1000 +opt.n_future = 195 +print(opt) + +# --------- load a dataset ------------------------------------ +train_data, test_data = utils.load_dataset(opt, sequential=True, implausible=True) + +train_loader = DataLoader(train_data, + num_workers=opt.data_threads, + batch_size=opt.batch_size, + drop_last=True, + pin_memory=True,) +test_loader = DataLoader(test_data, + num_workers=opt.data_threads, + batch_size=opt.batch_size, + shuffle=True, + drop_last=True, + pin_memory=True) + + +def get_training_batch(): + while True: + for sequence in train_loader: + batch = utils.normalize_data(opt, dtype, sequence) + yield batch + + +training_batch_generator = get_training_batch() +training_batch_generator_2 = get_training_batch() + + +def get_testing_batch(): + while True: + for sequence in test_loader: + batch = utils.normalize_data(opt, dtype, sequence) + yield batch + + +testing_batch_generator = get_testing_batch() + + +# --------- plotting funtions ------------------------------------ +def plot(x, epoch): + nsample = 1 + gen_seq = [[] for _ in range(nsample)] + gt_seq = [x[i] for i in range(len(x))] + + h_seq = [encoder(x[i]) for i in range(opt.n_past)] + for s in range(nsample): + frame_predictor.hidden = frame_predictor.init_hidden() + gen_seq[s].append(x[0]) + x_in = x[0] + for i in 
range(1, opt.n_eval): + if opt.last_frame_skip or i < opt.n_past: + h, skip = h_seq[i - 1] + h = h.detach() + elif i < opt.n_past: + h, _ = h_seq[i - 1] + h = h.detach() + if i < opt.n_past: + frame_predictor(h) + x_in = x[i] + gen_seq[s].append(x_in) + else: + h = frame_predictor(h).detach() + x_in = decoder([h, skip]).detach() + gen_seq[s].append(x_in) + + to_plot = [] + gifs = [[] for t in range(opt.n_eval)] + nrow = min(opt.batch_size, 25) + for i in range(nrow): + # ground truth sequence + row = [] + for t in range(opt.n_eval): + row.append(gt_seq[t][i]) + to_plot.append(row) + + for s in range(nsample): + row = [] + for t in range(opt.n_eval): + row.append(gen_seq[s][t][i]) + to_plot.append(row) + for t in range(opt.n_eval): + row = [] + row.append(gt_seq[t][i]) + for s in range(nsample): + row.append(gen_seq[s][t][i]) + gifs[t].append(row) + + fname = '%s/gen/sample_%d.png' % (opt.log_dir, epoch) + utils.save_tensors_image(fname, to_plot) + + fname = '%s/gen/sample_%d.gif' % (opt.log_dir, epoch) + utils.save_gif(fname, gifs) + + +def plot_rec(x, epoch): + frame_predictor.hidden = frame_predictor.init_hidden() + gen_seq = [] + gen_seq.append(x[0]) + x_in = x[0] + h_seq = [encoder(x[i]) for i in range(opt.n_past + opt.n_future)] + for i in range(1, opt.n_past + opt.n_future): + h_target = h_seq[i][0].detach() + if opt.last_frame_skip or i < opt.n_past: + h, skip = h_seq[i - 1] + else: + h, _ = h_seq[i - 1] + h = h.detach() + if i < opt.n_past: + frame_predictor(h) + gen_seq.append(x[i]) + else: + h_pred = frame_predictor(h).detach() + x_pred = decoder([h_pred, skip]).detach() + gen_seq.append(x_pred) + + to_plot = [] + nrow = min(opt.batch_size, 25) + for i in range(nrow): + row = [] + for t in range(opt.n_past + opt.n_future): + row.append(gen_seq[t][i]) + to_plot.append(row) + fname = '%s/gen/rec_%d.png' % (opt.log_dir, epoch) + utils.save_tensors_image(fname, to_plot) + + +def do_implasubility_test(h_residual_mean, h_residual_vars): + epoch_size = 1000 // 
opt.batch_size # we have a total of 1000 videos + cov_inv = [np.diag(1.0 / h_var)[np.newaxis, ...] for h_var in h_residual_vars] # we assume the covariance matrix is diagonal and do the inverse for each time + + frame_predictor.eval() + encoder.eval() + decoder.eval() + progress = progressbar.ProgressBar(max_value=epoch_size).start() + + + for i in range(epoch_size): + h_residual_var = torch.tensor(np.zeros(opt.n_future, dtype=np.float32), requires_grad=False, + device=torch.device('cpu')) + progress.update(i + 1) + try: + x = next(training_batch_generator) + frames = [frame.cpu() for frame in x] + except TypeError: + print('got None at i = {}, terminating'.format(i)) + break + h_posterior = [encoder(x[j]) for j in range(opt.n_past + opt.n_future)] + # print(h_posterior[0][0].size()) + last_pred = None + frame_predictor.hidden = frame_predictor.init_hidden() + for j in range(1, opt.n_past + opt.n_future): + if opt.last_frame_skip or j < opt.n_past: + h, skip = h_posterior[j - 1] + else: + h = h_posterior[j - 1][0].detach() + # we predict h_t from h_{t-1} + h_prior_pred = frame_predictor(torch.cat([h], 1)).detach() + + if j >= opt.n_past: + # h_res = h_prior_pred - h_posterior[j][0].detach() # predicted h minus observed h + # squared_diff = torch.square(h_res - h_residual_mean[j - opt.n_past]) + # squared_diff = torch.mean(squared_diff, dim=1) # average squared residuals at time j over the dimensions of h + # squared_diff = torch.mean(squared_diff, dim=0) # average squared residuals at time j over the batch + # h_residual_var[j - opt.n_past] += squared_diff + + residual = (h_prior_pred - last_pred).cpu() + # err = (residual - h_residual_mean[j - opt.n_past]) + err = residual + err = np.square(err) # [batch, dim_feature] + err = err / h_residual_vars[j - opt.n_past][np.newaxis, ...] 
# [batch, dim_feature] / [1, dim_feature] + err = torch.sum(err, axis=1) # -> [batch,] + # if len(err.shape) == 2: # [batch, dim_feature] + # err = err[..., np.newaxis] # make err into a vector [batch, dim_feature, 1] + # # print(cov_inv[j - opt.n_past].shape) + # # print(cov_inv[j - opt.n_past]) + # # print(err.shape) + # print(err) + # print(np.diag(cov_inv[j - opt.n_past][0])) + # quit() + # mahanlanobis_dist = np.matmul(cov_inv[j - opt.n_past], err).transpose(2, 1) + # # print(mahanlanobis_dist.shape) + # mahanlanobis_dist = np.matmul(mahanlanobis_dist, err) # [batch, 1, 1] + # # print(mahanlanobis_dist.shape) + # # print(np.sqrt(mahanlanobis_dist)) + # print(mahanlanobis_dist) + # quit() + # mahanlanobis_dist = torch.mean(mahanlanobis_dist) # scalar + + + # err = torch.mean(err, dim=1) # average errs at time j over the dimensions of h + # err = torch.mean(err, dim=0) # average errs at time j over the batch + # h_residual_var[j - opt.n_past] += err.detach() + h_residual_var[j - opt.n_past] += torch.mean(err, axis=0) + last_pred = h_prior_pred.detach() + + h_residual_sd = torch.sqrt(h_residual_var).cpu() + + h_residual_sd_filtered = - h_residual_sd[:-2] + 2 * h_residual_sd[1:-1] - h_residual_sd[2:] + + + print(h_residual_var) + for j in range(len(frames)): + frame = np.uint8(np.minimum(frames[j][0][0], 1) * 255) + cv2.imshow('frame', frame) + cv2.waitKey(10) + fig = plt.figure() + # plt.ylim(0, 10.0) + plt.xlabel('Time') + plt.title("sqrt of average squared dimensional error") + plt.bar(np.arange(len(h_residual_sd)), h_residual_sd) + fig.canvas.draw() + img = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, + sep='') + img = img.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + + plt.xlabel('Time') + plt.title("filtered sqrt of average squared dimensional error") + plt.bar(np.arange(len(h_residual_sd_filtered)), h_residual_sd_filtered) + fig.canvas.draw() + img2 = np.fromstring(fig.canvas.tostring_rgb(), 
dtype=np.uint8, + sep='') + img2 = img2.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + img2 = cv2.cvtColor(img2, cv2.COLOR_RGB2BGR) + + cv2.imshow("plot", img) + cv2.imshow("plot2", img2) + + k = cv2.waitKey(0) + if k == ord('q'): + quit() + # plt.savefig('implausibility_test.png') + # h_residual_var /= epoch_size # get the mean error vector per time + # h_residual_sd = torch.sqrt(h_residual_var) + print('Last i = {}'.format(i)) + + + # H_err_cov = torch.tensor(np.zeros((opt.n_future, 128, 128), dtype=np.float32), requires_grad=False, + # device=torch.device('cuda:0')) + + # plot some stuff + h_residual_mean_norm = torch.norm(h_residual_mean, dim=1).cpu() + plt.subplot(2, 1, 1) + plt.xlabel('Time') + plt.tight_layout() + plt.title("Norm of the residual mean") + plt.bar(np.arange(len(h_residual_mean_norm)), h_residual_mean_norm) + + plt.subplot(2, 1, 2) + plt.xlabel('Time') + plt.tight_layout() + plt.title("Average dimensional sqrt(variance) of the residual") + plt.bar(np.arange(len(h_residual_sd.cpu())), h_residual_sd.cpu()) + plt.savefig('h_residual_mean.png') + + +f = open('mcs_stats.json', 'r') +mcs_stats_dict = json.load(f) +do_implasubility_test(np.array(mcs_stats_dict['mean']), np.array(mcs_stats_dict['vars'])) diff --git a/do_mcs_implausblility_test_gravity_v2.py b/do_mcs_implausblility_test_gravity_v2.py new file mode 100644 index 0000000..6c7c1c6 --- /dev/null +++ b/do_mcs_implausblility_test_gravity_v2.py @@ -0,0 +1,461 @@ +import glob +from typing import List + +import cv2 +import torch +import torch.optim as optim +import torch.nn as nn +import argparse +import os +import random + +from shapely.geometry import Polygon +from torch.autograd import Variable +from torch.utils.data import DataLoader +import utils +import itertools +import progressbar +import numpy as np +import matplotlib.pyplot as plt +import json + +parser = argparse.ArgumentParser() +parser.add_argument('--lr', default=0.004, type=float, help='learning rate') 
+parser.add_argument('--beta1', default=0.9, type=float, help='momentum term for adam') +parser.add_argument('--batch_size', default=1, type=int, help='batch size') +parser.add_argument('--log_dir', default='logs/nonstochastic', help='base directory to save logs') +parser.add_argument('--model_dir', default='', help='base directory to save logs') +parser.add_argument('--name', default='', help='identifier for directory') +parser.add_argument('--data_root', default='data', help='root directory for data') +parser.add_argument('--optimizer', default='adam', help='optimizer to train with') +parser.add_argument('--niter', type=int, default=300, help='number of epochs to train for') +parser.add_argument('--seed', default=1, type=int, help='manual seed') +parser.add_argument('--epoch_size', type=int, default=1000, help='epoch size') +parser.add_argument('--image_width', type=int, default=64, help='the height / width of the input image to network') +parser.add_argument('--channels', default=1, type=int) +parser.add_argument('--use_edge_kernels', action='store_true') +parser.add_argument('--dataset', default='mcs_test', help='dataset to train with') +parser.add_argument('--mcs_task', default='GravitySupportEvaluation', help='mcs task') +parser.add_argument('--n_past', type=int, default=5, help='number of frames to condition on') +parser.add_argument('--n_future', type=int, default=55, help='number of frames to predict') +parser.add_argument('--n_eval', type=int, default=30, help='number of frames to predict at eval time') +parser.add_argument('--rnn_size', type=int, default=256, help='dimensionality of hidden layer') +parser.add_argument('--posterior_rnn_layers', type=int, default=1, help='number of layers') +parser.add_argument('--predictor_rnn_layers', type=int, default=2, help='number of layers') +parser.add_argument('--z_dim', type=int, default=10, help='dimensionality of z_t') +parser.add_argument('--g_dim', type=int, default=128, + help='dimensionality of encoder 
output vector and decoder input vector') +parser.add_argument('--beta', type=float, default=0.0001, help='weighting on KL to prior') +parser.add_argument('--model', default='dcgan', help='model type (dcgan | vgg)') +parser.add_argument('--data_threads', type=int, default=8, help='number of data loading threads') +parser.add_argument('--num_digits', type=int, default=2, help='number of digits for moving mnist') +parser.add_argument('--last_frame_skip', action='store_true', + help='if true, skip connections go between frame t and frame t+t rather than last ground truth frame') +opt = parser.parse_args() +BATCH_SIZE = opt.batch_size +saved_model = None +if opt.model_dir != '': + models = glob.glob(f'{opt.model_dir}/model_*.pth') + latest_model = sorted(models, key=lambda s: int(s[s.rfind('_e') + 2: s.rfind('.pth')]), reverse=True)[0] + print('Loading model ', latest_model) + saved_model = torch.load(latest_model) + niter = opt.niter + dataset = opt.dataset + mcs_task = opt.mcs_task + n_future = opt.n_future + opt = saved_model['opt'] + opt.batch_size = BATCH_SIZE + opt.niter = niter # update number of epochs to train for + opt.dataset = dataset + opt.mcs_task = mcs_task + opt.n_future = n_future +else: + raise ValueError("Please specify the model to load with the --model_dir argument") + +font = cv2.FONT_HERSHEY_SIMPLEX +bottomLeftCornerOfText = (30, 30) +fontScale = 0.6 +fontColor = (0, 0, 0) +thickness = 1 + +print("Random Seed: ", opt.seed) +random.seed(opt.seed) +torch.manual_seed(opt.seed) +torch.cuda.manual_seed_all(opt.seed) +dtype = torch.cuda.FloatTensor + +import models.lstm as lstm_models + +if opt.model_dir != '': + frame_predictor = saved_model['frame_predictor'] + frame_predictor.batch_size = BATCH_SIZE + posterior = saved_model['posterior'] + posterior.batch_size = BATCH_SIZE + prior = saved_model['prior'] + prior.batch_size = BATCH_SIZE +else: + raise ValueError('Please specify --model_dir') + +if opt.model == 'dcgan': + if opt.image_width == 64: + 
import models.dcgan_64 as model + elif opt.image_width == 128: + import models.dcgan_128 as model +elif opt.model == 'vgg': + if opt.image_width == 64: + import models.vgg_64 as model + elif opt.image_width == 128: + import models.vgg_128 as model +else: + raise ValueError('Unknown model: %s' % opt.model) + +if opt.model_dir != '': + decoder = saved_model['decoder'] + encoder = saved_model['encoder'] +else: + raise ValueError("Please specify the model to load with the --model_dir argument") + +# --------- loss functions ------------------------------------ +mse_criterion = nn.MSELoss() + + +def kl_criterion(mu, logvar): + # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) + KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) + KLD /= opt.batch_size + return KLD + + +# --------- transfer to gpu ------------------------------------ +frame_predictor.cuda() +posterior.cuda() +encoder.cuda() +decoder.cuda() +mse_criterion.cuda() + +print(opt) + +# --------- load a dataset ------------------------------------ +train_data, test_data = utils.load_dataset(opt, sequential=True, implausible=False) + +train_loader = DataLoader(train_data, + num_workers=opt.data_threads, + batch_size=opt.batch_size, + shuffle=False, + drop_last=True, + pin_memory=True, ) +test_loader = DataLoader(test_data, + num_workers=opt.data_threads, + batch_size=opt.batch_size, + shuffle=False, + drop_last=True, + pin_memory=True) + + +def get_training_batch(): + while True: + for sequence, labels in train_loader: + batch = utils.normalize_data(opt, dtype, sequence) + yield batch, labels + + +training_batch_generator = get_training_batch() + + +def get_testing_batch(): + while True: + for sequence, labels in test_loader: + batch = utils.normalize_data(opt, dtype, sequence) + yield batch, labels + + +testing_batch_generator = get_testing_batch() + + +def get_center_maximal_contour(img, draw=False): + if len(img.shape) == 3: # bgr 2 gray + img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + + # Find 
contours + contours, hierarchy = cv2.findContours(img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + if len(contours) == 0: + return None + max_contour = max(contours, key=cv2.contourArea) + if cv2.contourArea(max_contour) == 0: + # print(max_contour) + # print(tuple(max_contour[0][0])) + if draw: + cv2.circle(img, tuple(max_contour[0][0]), 1, 150, 1) + return tuple(max_contour[0][0]) + M = cv2.moments(max_contour) + cx, cy = int(M['m10'] / M['m00']), int(M['m01'] / M['m00']) + if draw: + cv2.circle(img, (cx, cy), 1, 150, 1) + return (cx, cy) + + +def get_polygons_from_img(img, top_n_contours=2): + img = high_pass(img.copy(), min=9, rad=7) + contours, hierarchy = cv2.findContours(img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + contours = sorted(contours, key=cv2.contourArea, reverse=True) + contours = contours[:top_n_contours] + img = cv2.drawContours(img, contours, -1, 180, 1) + # contours = filter(lambda x: cv2.contourArea(x) > 0, contours) + polys = [] + for cnt in contours: + if cv2.contourArea(cnt) == 0: + continue + epsilon = 0.01 * cv2.arcLength(cnt, True) + poly_approx = cv2.approxPolyDP(cnt, epsilon, True) # shape: NUM_POINTS, 1, 2 (x and y locations) + poly_approx = np.squeeze(poly_approx, axis=1) # compress the second dimension -> (NUM_POINTS, 2) + try: + poly_approx = Polygon(poly_approx) + except ValueError: + print('too few vertices for polygon. 
original list: ', poly_approx) + idx = np.random.choice(list(range(len(poly_approx))), 3, replace=True) + poly_approx = Polygon(poly_approx[idx]) + print('resampled list: ', poly_approx) + # poly_approx -= get_center_of_mass(poly_approx) # center the COM at origin + c = poly_approx.centroid.coords[0] + cv2.circle(img, (round(c[0]), round(c[1])), 1, 150, 1) + polys.append(poly_approx) + return polys, img + + +def get_implausibility_score(poly_list1: List[Polygon], poly_list2: List[Polygon], thresh=5): + # thresh: when max of min distance > thresh, implausibility score > 0.5 + min_dists = [] + for p1 in poly_list1: + p1_centeroid = p1.centroid + min_dist = 100000000000 + for p2 in poly_list2: + min_dist = min(min_dist, p1_centeroid.distance(p2.centroid)) + min_dists.append(min_dist) + + poly_list1, poly_list2 = poly_list2, poly_list1 + for p1 in poly_list1: + p1_centeroid = p1.centroid + min_dist = 100000000000 + for p2 in poly_list2: + min_dist = min(min_dist, p1_centeroid.distance(p2.centroid)) + min_dists.append(min_dist) + if len(min_dists) == 0: + return 0 # if both the source and prediction has nothing going on, we say it's plausible + max_min_dist = max(min_dists) + # plausibility = 1 / (max(3, max_min_dist) - 2) # R+ -> [0,1] + # return 1 - plausibility + x = (max_min_dist - thresh) / 2 # this means that when MMD >= 4, we have implausibility score >= 0.5 and <0.5 if not. 
+ implausibility = 1 / (1 + np.exp(-x)) + return implausibility + + +def high_pass(frame, min=10, rad=5): + fr = frame.copy() + fr[fr < min] = 0 # remove low brightness pixels + if np.sum(fr) != 0: + frame = fr + + frame = cv2.medianBlur(frame, rad) + # frame = get_maximal_contour(frame) + return frame + + +def do_implasubility_test(z_residual_mean, z_residual_cov, thresh, visualize=True): + MOTION_THRESH = 0.001 + epoch_size = 200 // opt.batch_size # we have a total of 1000 videos + cov_inv = [np.linalg.pinv(cov, hermitian=True) for cov in + z_residual_cov] # we assume the covariance matrix is diagonal and do the inverse for each time + # cov_inv = [np.linalg.inv(np.diag(np.diag(cov))) for cov in z_residual_cov] + # cov_inv = [np.eye(32) for cov in z_residual_cov] + frame_predictor.eval() + posterior.eval() + encoder.eval() + decoder.eval() + progress = progressbar.ProgressBar(max_value=epoch_size).start() + confusion_matrix = [[0, 0], [0, 0]] + # for i in range(50): + # if i % 2 == 0: + # x = next(training_batch_generator) + # else: + # x = next(training_batch_generator_implausible) + for i in range(epoch_size): + h_residual_var = torch.tensor(np.zeros(opt.n_future, dtype=np.float32), requires_grad=False, + device=torch.device('cpu')) + scores = np.array([0 for i in range(opt.n_future)], dtype=np.float32) + progress.update(i + 1) + is_implausible = False + try: + x, labels = next(training_batch_generator) + x_diff = [torch.abs(x[t] - x[t - 1]).detach() for t in range(1, len(x))] + x_diff = [torch.mean(frame, dim=(1, 2, 3)) for frame in x_diff] # mean along C, H, W + motion_start_time = [] + + # find when motion starts (object is about to be dropped) + for batch in range(len(x_diff[0])): + started = False + for t in range(30, len(x_diff)): + if x_diff[t][batch] > MOTION_THRESH: + motion_start_time.append( + t + 1 + 1) # add one because x_diff starts at t=1, then add another frame + started = True + break + if not started: + motion_start_time.append(33) # default 
to frame 33 if no motion is found + frames = [frame.cpu() for frame in x] + # print(frames[0][0].numpy().dtype) + # print(frames[0][0]) + # quit() + except TypeError: + print('got None at i = {}, terminating'.format(i)) + break + h_posterior = [encoder(x[j]) for j in range(0, opt.n_past + opt.n_future)] + # print(h_posterior[0][0].size()) + last_pred = None + frame_predictor.hidden = frame_predictor.init_hidden() + prior.hidden = prior.init_hidden() + posterior.hidden = posterior.init_hidden() + + # start is the first frame that the prior model sees (so start + n_past is the first frame predicted) + if motion_start_time[0] is not None: + start = motion_start_time[0] - 4 + x_in = x[start] + x_out_seq = [x[start].cpu()] + for j in range(start + 1, opt.n_past + opt.n_future + 30): + h = encoder(x_in) + if opt.last_frame_skip or j < opt.n_past + start - 1: + h, skip = h + else: + h, _ = h + + if j < opt.n_past + start + 5: + z_t = posterior(h_posterior[j][0].detach()) + prior(h) + h_post = frame_predictor(torch.cat([h, z_t], 1)) + x_in = x[j] + x_out_seq.append(decoder([h_post, skip]).detach().cpu()) + else: + z_t_hat = prior(h) + h = frame_predictor(torch.cat([h, z_t_hat], 1)).detach() + x_in = decoder([h, skip]) + x_out_seq.append(x_in.detach().cpu()) + + if j >= opt.n_past + opt.n_future and torch.mean(torch.abs(x_out_seq[-1] - x_out_seq[-2]), + dim=(1, 2, 3)) <= MOTION_THRESH / 3: + # print('broke at j= ', j) + break + + # z_residual_scores_filtered = -0.25 * scores[:-2] + (0.5 + 0.5) * scores[1:-1] - 0.25 * scores[2:] + # print(motion_start_time[0], len(frames), len(x_out_seq)) + # print(h_residual_var) + background = frames[0][0][0].numpy().copy() + background = np.uint8(np.minimum(background / 1.2, 1.0) * 255.) 
+ if visualize and motion_start_time[0] is not None: + k = 0 + j = start + source_center = None + pred_center = None + while not (j >= len(frames) and k >= len(x_out_seq)): + frame_cv2 = frames[min(j, len(frames) - 1)][0][0].numpy().copy() + frame_cv2 /= 1.2 + frame_cv2 = np.uint8(np.minimum(frame_cv2, 1.0) * 255.) + source_diff = cv2.absdiff(frame_cv2, background) + polys_source, source_diff_marked = get_polygons_from_img(source_diff) + # source_center = get_center_maximal_contour(source_diff, draw=True) + cv2.imshow('source', cv2.resize(frame_cv2, (384, 384), interpolation=cv2.INTER_NEAREST)) + cv2.imshow('source diff', cv2.resize(source_diff_marked, (384, 384), interpolation=cv2.INTER_NEAREST)) + + out_cv2 = x_out_seq[min(k, len(x_out_seq) - 1)][0][0].numpy().copy() + out_cv2 /= 1.2 + # out_cv2 /= np.max(out_cv2) + out_cv2 = np.uint8(np.minimum(out_cv2, 1.0) * 255.) + pred_diff = cv2.absdiff(out_cv2, background) + polys_pred, pred_diff_marked = get_polygons_from_img(pred_diff) + # pred_center = get_center_maximal_contour(pred_diff, draw=True) + cv2.imshow('prediction from first 5 frames', + cv2.resize(out_cv2, (384, 384), interpolation=cv2.INTER_NEAREST)) + cv2.imshow('pred diff', + cv2.resize(pred_diff_marked, (384, 384), interpolation=cv2.INTER_NEAREST)) + src_pred_diff = cv2.absdiff(source_diff, pred_diff) + imp_score = get_implausibility_score(polys_source, polys_pred) + src_pred_diff = cv2.resize(src_pred_diff, (384, 384), interpolation=cv2.INTER_NEAREST) + cv2.putText(src_pred_diff, str(imp_score), (20, 20), font, fontScale, 128, thickness) + cv2.imshow('source-pred diff', src_pred_diff) + cv2.waitKey(0) + j += 1 + k += 1 + if not visualize: + frame_cv2 = frames[-1][0][0].numpy().copy() + frame_cv2 = np.uint8(np.minimum(frame_cv2 / 1.2, 1.0) * 255.) 
+ source_diff = high_pass(cv2.absdiff(frame_cv2, background)) + source_center = get_center_maximal_contour(source_diff, draw=True) + out_cv2 = x_out_seq[-1][0][0].numpy().copy() + out_cv2 = np.uint8(np.minimum(out_cv2 / 1.2, 1.0) * 255.) + pred_diff = high_pass(cv2.absdiff(out_cv2, background)) + pred_center = get_center_maximal_contour(pred_diff, draw=True) + + if (source_center is not None) and (pred_center is not None): + is_implausible = np.abs(source_center[-1] - pred_center[-1]) > thresh # a good thresh seems to be + + if is_implausible: + msg = 'implausible' + else: + msg = 'plausible' + # print('gt: ', labels, 'prediction: ', msg) + confusion_matrix[int(labels[0] == 'implausible')][int(is_implausible)] += 1 + if visualize: + cv2.waitKey(0) + + + # percentile = np.percentile(scores[76:152], 85.0) + # thresh = percentile * 1 + 0.5 + # spikes_idx = np.argwhere(z_residual_scores_filtered > thresh) + # spikes_idx = spikes_idx[ + # (spikes_idx >= 75) & (spikes_idx <= 150)] # ignore spikes near the start and end of video + # spikes = z_residual_scores_filtered[spikes_idx] + # msg = '' + # if len(spikes_idx) > 0: + # # we add n_past because the first n_past frames are not counted. 
Add 1 because of the filtering + # msg = 'thresh {:.1f} IMPLAUSIBLE spikes: '.format(thresh) + str( + # ['{:.2f}@{}'.format(z_residual_scores_filtered[k], k + opt.n_past + 1) for k in spikes_idx]) + # confusion_matrix[is_implausible][1] += 1 + # else: + # max_idx = np.argmax(z_residual_scores_filtered[75:151]) + 75 + # msg = 'thresh {:.1f} PLAUSIBLE max {:.2f}@{}'.format(thresh, z_residual_scores_filtered[max_idx], + # max_idx + opt.n_past + 1) + # confusion_matrix[is_implausible][0] += 1 + # + # print(msg) + # plt.savefig('implausibility_test.png') + # h_residual_var /= epoch_size # get the mean error vector per time + # h_residual_sd = torch.sqrt(h_residual_var) + print('Last i = {}'.format(i)) + + # H_err_cov = torch.tensor(np.zeros((opt.n_future, 128, 128), dtype=np.float32), requires_grad=False, + # device=torch.device('cuda:0')) + + # plot some stuff + return confusion_matrix + + +# f = open('new_mcs_stats_post.json', 'r') +# mcs_stats_dict = json.load(f) +mcs_stats_dict = {} +with open('new_mcs_stats_post.npy', 'rb') as f: + mcs_stats_dict['mean'] = np.load(f) + mcs_stats_dict['cov'] = np.load(f) + +ROC_curve = {} +for thr in range(2, 11): + train_data, test_data = utils.load_dataset(opt, sequential=True, implausible=False) + + train_loader = DataLoader(train_data, + num_workers=opt.data_threads, + batch_size=opt.batch_size, + shuffle=False, + drop_last=True, + pin_memory=True, ) + training_batch_generator = get_training_batch() + + conf_mat = do_implasubility_test(np.array(mcs_stats_dict['mean']), np.array(mcs_stats_dict['cov']), thr, visualize=True) + ROC_curve[thr] = conf_mat +print(ROC_curve) diff --git a/do_mcs_implausblility_test_posterior.py b/do_mcs_implausblility_test_posterior.py new file mode 100644 index 0000000..27a8ce4 --- /dev/null +++ b/do_mcs_implausblility_test_posterior.py @@ -0,0 +1,447 @@ +import glob + +import cv2 +import torch +import torch.optim as optim +import torch.nn as nn +import argparse +import os +import random +from 
torch.autograd import Variable +from torch.utils.data import DataLoader +import utils +import itertools +import progressbar +import numpy as np +import matplotlib.pyplot as plt +import json + +parser = argparse.ArgumentParser() +parser.add_argument('--lr', default=0.004, type=float, help='learning rate') +parser.add_argument('--beta1', default=0.9, type=float, help='momentum term for adam') +parser.add_argument('--batch_size', default=1, type=int, help='batch size') +parser.add_argument('--log_dir', default='logs/nonstochastic', help='base directory to save logs') +parser.add_argument('--model_dir', default='', help='base directory to save logs') +parser.add_argument('--name', default='', help='identifier for directory') +parser.add_argument('--data_root', default='data', help='root directory for data') +parser.add_argument('--optimizer', default='adam', help='optimizer to train with') +parser.add_argument('--niter', type=int, default=300, help='number of epochs to train for') +parser.add_argument('--seed', default=1, type=int, help='manual seed') +parser.add_argument('--epoch_size', type=int, default=1000, help='epoch size') +parser.add_argument('--image_width', type=int, default=64, help='the height / width of the input image to network') +parser.add_argument('--channels', default=1, type=int) +parser.add_argument('--use_edge_kernels', action='store_true') +parser.add_argument('--dataset', default='mcs', help='dataset to train with') +parser.add_argument('--mcs_task', default='SpatioTemporalContinuityTraining4', help='mcs task') +parser.add_argument('--n_past', type=int, default=5, help='number of frames to condition on') +parser.add_argument('--n_future', type=int, default=195, help='number of frames to predict') +parser.add_argument('--n_eval', type=int, default=30, help='number of frames to predict at eval time') +parser.add_argument('--rnn_size', type=int, default=256, help='dimensionality of hidden layer') +parser.add_argument('--posterior_rnn_layers', 
type=int, default=1, help='number of layers') +parser.add_argument('--predictor_rnn_layers', type=int, default=2, help='number of layers') +parser.add_argument('--z_dim', type=int, default=10, help='dimensionality of z_t') +parser.add_argument('--g_dim', type=int, default=128, + help='dimensionality of encoder output vector and decoder input vector') +parser.add_argument('--beta', type=float, default=0.0001, help='weighting on KL to prior') +parser.add_argument('--model', default='dcgan', help='model type (dcgan | vgg)') +parser.add_argument('--data_threads', type=int, default=8, help='number of data loading threads') +parser.add_argument('--num_digits', type=int, default=2, help='number of digits for moving mnist') +parser.add_argument('--last_frame_skip', action='store_true', + help='if true, skip connections go between frame t and frame t+t rather than last ground truth frame') +opt = parser.parse_args() +BATCH_SIZE = 1 +saved_model = None +if opt.model_dir != '': + models = glob.glob(f'{opt.model_dir}/model_*.pth') + latest_model = sorted(models, key=lambda s: int(s[s.rfind('_e') + 2: s.rfind('.pth')]), reverse=True)[0] + print('Loading model ', latest_model) + saved_model = torch.load(latest_model) + niter = opt.niter + dataset = opt.dataset + mcs_task = opt.mcs_task + n_future = opt.n_future + data_root = opt.data_root + opt = saved_model['opt'] + opt.batch_size = BATCH_SIZE + opt.niter = niter # update number of epochs to train for + opt.dataset = dataset + opt.mcs_task = mcs_task + opt.n_future = n_future + opt.data_root = data_root + opt.start_min = 0 + opt.start_max = None + frame_predictor = saved_model['frame_predictor'].cuda() + frame_predictor.batch_size = BATCH_SIZE + posterior = saved_model['posterior'].cuda() + posterior.batch_size = BATCH_SIZE + prior = saved_model['prior'].cuda() + prior.batch_size = BATCH_SIZE + decoder = saved_model['decoder'].cuda() + encoder = saved_model['encoder'].cuda() +else: + raise ValueError("Please specify the model 
to load with the --model_dir argument") + +font = cv2.FONT_HERSHEY_SIMPLEX +bottomLeftCornerOfText = (30, 30) +fontScale = 0.6 +fontColor = (0, 0, 0) +thickness = 1 + +print("Random Seed: ", opt.seed) +random.seed(opt.seed) +torch.manual_seed(opt.seed) +torch.cuda.manual_seed_all(opt.seed) +dtype = torch.cuda.FloatTensor + +import models.lstm as lstm_models + +if opt.model == 'dcgan': + if opt.image_width == 64: + import models.dcgan_64 as model + elif opt.image_width == 128: + import models.dcgan_128 as model +elif opt.model == 'vgg': + if opt.image_width == 64: + import models.vgg_64 as model + elif opt.image_width == 128: + import models.vgg_128 as model +else: + raise ValueError('Unknown model: %s' % opt.model) + +# --------- loss functions ------------------------------------ +mse_criterion = nn.MSELoss() + + +def kl_criterion(mu, logvar): + # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) + KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) + KLD /= opt.batch_size + return KLD + + +# # --------- transfer to gpu ------------------------------------ +# frame_predictor.cuda() +# posterior.cuda() +# encoder.cuda() +# decoder.cuda() +# mse_criterion.cuda() + +opt.batch_size = BATCH_SIZE +opt.epoch_size = 1000 +opt.n_future = 195 +print(opt) + +# --------- load a dataset ------------------------------------ +train_data_implausible, test_data_implausible = utils.load_dataset(opt, sequential=True, implausible=True) +train_data, test_data = utils.load_dataset(opt, sequential=True, implausible=False) + +train_loader_implausible = DataLoader(train_data_implausible, + num_workers=opt.data_threads, + batch_size=opt.batch_size, + drop_last=True, + pin_memory=True, ) +train_loader = DataLoader(train_data, + num_workers=opt.data_threads, + batch_size=opt.batch_size, + drop_last=True, + pin_memory=True, ) +test_loader = DataLoader(test_data, + num_workers=opt.data_threads, + batch_size=opt.batch_size, + shuffle=True, + drop_last=True, + pin_memory=True) + + +def 
get_training_batch(): + while True: + for sequence in train_loader: + batch = utils.normalize_data(opt, dtype, sequence) + yield batch + + +def get_training_batch_implausible(): + while True: + for sequence in train_loader_implausible: + batch = utils.normalize_data(opt, dtype, sequence) + yield batch + + +training_batch_generator = get_training_batch() +training_batch_generator_implausible = get_training_batch_implausible() + + +def get_testing_batch(): + while True: + for sequence in test_loader: + batch = utils.normalize_data(opt, dtype, sequence) + yield batch + + +testing_batch_generator = get_testing_batch() + + +# --------- plotting funtions ------------------------------------ +def plot(x, epoch): + nsample = 1 + gen_seq = [[] for _ in range(nsample)] + gt_seq = [x[i] for i in range(len(x))] + + h_seq = [encoder(x[i]) for i in range(opt.n_past)] + for s in range(nsample): + frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + gen_seq[s].append(x[0]) + x_in = x[0] + for i in range(1, opt.n_eval): + if opt.last_frame_skip or i < opt.n_past: + h, skip = h_seq[i - 1] + h = h.detach() + elif i < opt.n_past: + h, _ = h_seq[i - 1] + h = h.detach() + if i < opt.n_past: + frame_predictor(h) + x_in = x[i] + gen_seq[s].append(x_in) + else: + h = frame_predictor(h).detach() + x_in = decoder([h, skip]).detach() + gen_seq[s].append(x_in) + + to_plot = [] + gifs = [[] for t in range(opt.n_eval)] + nrow = min(opt.batch_size, 25) + for i in range(nrow): + # ground truth sequence + row = [] + for t in range(opt.n_eval): + row.append(gt_seq[t][i]) + to_plot.append(row) + + for s in range(nsample): + row = [] + for t in range(opt.n_eval): + row.append(gen_seq[s][t][i]) + to_plot.append(row) + for t in range(opt.n_eval): + row = [] + row.append(gt_seq[t][i]) + for s in range(nsample): + row.append(gen_seq[s][t][i]) + gifs[t].append(row) + + fname = '%s/gen/sample_%d.png' % (opt.log_dir, epoch) + utils.save_tensors_image(fname, 
to_plot) + + fname = '%s/gen/sample_%d.gif' % (opt.log_dir, epoch) + utils.save_gif(fname, gifs) + + +def plot_rec(x, epoch): + frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + gen_seq = [] + gen_seq.append(x[0]) + x_in = x[0] + h_seq = [encoder(x[i]) for i in range(opt.n_past + opt.n_future)] + for i in range(1, opt.n_past + opt.n_future): + h_target = h_seq[i][0].detach() + if opt.last_frame_skip or i < opt.n_past: + h, skip = h_seq[i - 1] + else: + h, _ = h_seq[i - 1] + h = h.detach() + if i < opt.n_past: + frame_predictor(h) + gen_seq.append(x[i]) + else: + h_pred = frame_predictor(h).detach() + x_pred = decoder([h_pred, skip]).detach() + gen_seq.append(x_pred) + + to_plot = [] + nrow = min(opt.batch_size, 25) + for i in range(nrow): + row = [] + for t in range(opt.n_past + opt.n_future): + row.append(gen_seq[t][i]) + to_plot.append(row) + fname = '%s/gen/rec_%d.png' % (opt.log_dir, epoch) + utils.save_tensors_image(fname, to_plot) + + +def do_implasubility_test(z_residual_mean, z_residual_cov, visualize=True): + epoch_size = 199 // opt.batch_size # we have a total of 1000 videos + cov_inv = [np.linalg.pinv(cov, hermitian=True) for cov in + z_residual_cov] # we assume the covariance matrix is diagonal and do the inverse for each time + # cov_inv = [np.linalg.inv(np.diag(np.diag(cov))) for cov in z_residual_cov] + # cov_inv = [np.eye(32) for cov in z_residual_cov] + frame_predictor.eval() + posterior.eval() + encoder.eval() + decoder.eval() + progress = progressbar.ProgressBar(max_value=epoch_size).start() + confusion_matrix = [[0, 0], [0, 0]] + # for i in range(50): + # if i % 2 == 0: + # x = next(training_batch_generator) + # else: + # x = next(training_batch_generator_implausible) + for i in range(epoch_size): + h_residual_var = torch.tensor(np.zeros(opt.n_future, dtype=np.float32), requires_grad=False, + device=torch.device('cpu')) + scores = np.array([0 for i in range(opt.n_future)], dtype=np.float32) + 
progress.update(i + 1) + is_implausible = False + try: + if i % 2 == 0: + x = next(training_batch_generator) + else: + x = next(training_batch_generator_implausible) + is_implausible = True + + frames = [frame.cpu() for frame in x] + # print(frames[0][0].numpy().dtype) + # print(frames[0][0]) + # quit() + except TypeError: + print('got None at i = {}, terminating'.format(i)) + break + h_posterior = [encoder(x[j]) for j in range(opt.n_past + opt.n_future)] + # print(h_posterior[0][0].size()) + last_pred = None + frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + prior.hidden = prior.init_hidden() + start = 1 + for j in range(start, opt.n_past + opt.n_future): + h_target = h_posterior[j][0].detach() + if opt.last_frame_skip or j < opt.n_past + start - 1: + h, skip = h_posterior[j - 1] + else: + h = h_posterior[j - 1][0].detach() + # we predict h_t from h_{t-1} + z_t = posterior(h_target[0].detach()).detach() + z_t_hat = prior(h).detach() + + if j >= opt.n_past + start - 1: + # h_res = h_prior_pred - h_posterior[j][0].detach() # predicted h minus observed h + # squared_diff = torch.square(h_res - h_residual_mean[j - opt.n_past]) + # squared_diff = torch.mean(squared_diff, dim=1) # average squared residuals at time j over the dimensions of h + # squared_diff = torch.mean(squared_diff, dim=0) # average squared residuals at time j over the batch + # h_residual_var[j - opt.n_past] += squared_diff + + # residual = (h_prior_pred - h_posterior_pred).cpu().detach() + residual = (z_t - z_t_hat).cpu().detach().numpy() + # err = (residual - z_residual_mean[j - opt.n_past]) + err = residual + err = np.square(err) # [batch, dim_feature] + score = np.sum(err) + # score = np.matmul(err[:, np.newaxis, :], cov_inv[j - opt.n_past][np.newaxis, ...]) # B*1*D * 1*D*D + # score = np.matmul(score, err[:, :, np.newaxis]) # * B*1*D * B*D*1 -> B*1*1 + # score = numpy.nan_to_num(score) + # score /= opt.z_dim + + scores[j - opt.n_past] = 
np.sqrt(score) + # scores[j - opt.n_past] = score + # note: scores[t]^2 ~ Chi^2_df=z_dim so E[scores[t]^2] = z_dim + # if len(err.shape) == 2: # [batch, dim_feature] + # err = err[..., np.newaxis] # make err into a vector [batch, dim_feature, 1] + # # print(cov_inv[j - opt.n_past].shape) + # # print(cov_inv[j - opt.n_past]) + # # print(err.shape) + # print(err) + # print(np.diag(cov_inv[j - opt.n_past][0])) + # quit() + # mahanlanobis_dist = np.matmul(cov_inv[j - opt.n_past], err).transpose(2, 1) + # # print(mahanlanobis_dist.shape) + # mahanlanobis_dist = np.matmul(mahanlanobis_dist, err) # [batch, 1, 1] + # # print(mahanlanobis_dist.shape) + # # print(np.sqrt(mahanlanobis_dist)) + # print(mahanlanobis_dist) + # quit() + # mahanlanobis_dist = torch.mean(mahanlanobis_dist) # scalar + + # err = torch.mean(err, dim=1) # average errs at time j over the dimensions of h + # err = torch.mean(err, dim=0) # average errs at time j over the batch + # h_residual_var[j - opt.n_past] += err.detach() + # h_residual_var[j - opt.n_past] += torch.mean(err, axis=0) + + z_residual_scores_filtered = -0.25 * scores[:-2] + (0.5+0.5) * scores[1:-1] - 0.25 *scores[2:] + + # print(h_residual_var) + if visualize: + for j in range(len(frames)): + frame_cv2 = frames[j][0][0].numpy() + # frame_cv2 /= 3 + frame_cv2 = np.uint8(np.minimum(frame_cv2, 1.0) * 255.) + cv2.imshow('frame', frame_cv2) + cv2.waitKey(15) + + percentile = np.percentile(scores[76:152], 85.0) + thresh = percentile * 1 + 0.5 + spikes_idx = np.argwhere(z_residual_scores_filtered > thresh) + spikes_idx = spikes_idx[ + (spikes_idx >= 75) & (spikes_idx <= 150)] # ignore spikes near the start and end of video + spikes = z_residual_scores_filtered[spikes_idx] + msg = '' + if len(spikes_idx) > 0: + # we add n_past because the first n_past frames are not counted. 
Add 1 because of the filtering + msg = 'thresh {:.1f} IMPLAUSIBLE spikes: '.format(thresh) + str(['{:.2f}@{}'.format(z_residual_scores_filtered[k], k + opt.n_past + 1) for k in spikes_idx]) + confusion_matrix[is_implausible][1] += 1 + else: + max_idx = np.argmax(z_residual_scores_filtered[75:151]) + 75 + msg = 'thresh {:.1f} PLAUSIBLE max {:.2f}@{}'.format(thresh, z_residual_scores_filtered[max_idx], max_idx + opt.n_past + 1) + confusion_matrix[is_implausible][0] += 1 + + print(msg) + + if visualize: + fig = plt.figure() + # plt.ylim(0, 2.0) + plt.xlabel('Time') + plt.title("sqrt of average squared dimensional error") + plt.bar(np.arange(len(scores)), scores) + fig.canvas.draw() + img = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, + sep='') + img = img.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + + plt.xlabel('Time') + plt.title("filtered sqrt of average squared dimensional error") + plt.bar(np.arange(len(z_residual_scores_filtered)), z_residual_scores_filtered) + fig.canvas.draw() + img2 = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, + sep='') + img2 = img2.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + img2 = cv2.cvtColor(img2, cv2.COLOR_RGB2BGR) + + cv2.imshow("plot", img) + cv2.putText(img2, msg, bottomLeftCornerOfText, font, fontScale, fontColor, thickness, cv2.LINE_AA) + cv2.imshow("plot2", img2) + + k = cv2.waitKey(0) + if k == ord('q'): + quit() + # plt.savefig('implausibility_test.png') + # h_residual_var /= epoch_size # get the mean error vector per time + # h_residual_sd = torch.sqrt(h_residual_var) + print('Last i = {}'.format(i)) + + # H_err_cov = torch.tensor(np.zeros((opt.n_future, 128, 128), dtype=np.float32), requires_grad=False, + # device=torch.device('cuda:0')) + + # plot some stuff + return confusion_matrix + + +# f = open('new_mcs_stats_post.json', 'r') +# mcs_stats_dict = json.load(f) +mcs_stats_dict = {} +with open('new_mcs_stats_post.npy', 'rb') as f: + 
mcs_stats_dict['mean'] = np.load(f) + mcs_stats_dict['cov'] = np.load(f) +conf_mat = do_implasubility_test(np.array(mcs_stats_dict['mean']), np.array(mcs_stats_dict['cov']), visualize=True) +print(conf_mat) diff --git a/do_mcs_stats.py b/do_mcs_stats.py new file mode 100644 index 0000000..3a74423 --- /dev/null +++ b/do_mcs_stats.py @@ -0,0 +1,380 @@ +import glob + +import torch +import torch.optim as optim +import torch.nn as nn +import argparse +import os +import random +from torch.autograd import Variable +from torch.utils.data import DataLoader +import utils +import itertools +import progressbar +import numpy as np +import matplotlib.pyplot as plt +import json + +parser = argparse.ArgumentParser() +parser.add_argument('--lr', default=0.004, type=float, help='learning rate') +parser.add_argument('--beta1', default=0.9, type=float, help='momentum term for adam') +parser.add_argument('--batch_size', default=1, type=int, help='batch size') +parser.add_argument('--log_dir', default='logs/nonstochastic', help='base directory to save logs') +parser.add_argument('--model_dir', default='', help='base directory to save logs') +parser.add_argument('--name', default='', help='identifier for directory') +parser.add_argument('--data_root', default='data', help='root directory for data') +parser.add_argument('--optimizer', default='adam', help='optimizer to train with') +parser.add_argument('--niter', type=int, default=300, help='number of epochs to train for') +parser.add_argument('--seed', default=1, type=int, help='manual seed') +parser.add_argument('--epoch_size', type=int, default=1000, help='epoch size') +parser.add_argument('--image_width', type=int, default=64, help='the height / width of the input image to network') +parser.add_argument('--channels', default=1, type=int) +parser.add_argument('--dataset', default='mcs', help='dataset to train with') +parser.add_argument('--mcs_task', default='SpatioTemporalContinuityTraining4', help='mcs task') 
# Remaining command-line options (the parser and the earlier options are defined above).
parser.add_argument('--n_past', type=int, default=5, help='number of frames to condition on')
parser.add_argument('--n_future', type=int, default=195, help='number of frames to predict')
parser.add_argument('--n_eval', type=int, default=30, help='number of frames to predict at eval time')
parser.add_argument('--rnn_size', type=int, default=256, help='dimensionality of hidden layer')
parser.add_argument('--posterior_rnn_layers', type=int, default=1, help='number of layers')
parser.add_argument('--predictor_rnn_layers', type=int, default=2, help='number of layers')
parser.add_argument('--z_dim', type=int, default=10, help='dimensionality of z_t')
parser.add_argument('--g_dim', type=int, default=128,
                    help='dimensionality of encoder output vector and decoder input vector')
parser.add_argument('--beta', type=float, default=0.0001, help='weighting on KL to prior')
parser.add_argument('--model', default='dcgan', help='model type (dcgan | vgg)')
parser.add_argument('--data_threads', type=int, default=8, help='number of data loading threads')
parser.add_argument('--num_digits', type=int, default=2, help='number of digits for moving mnist')
parser.add_argument('--last_frame_skip', action='store_true',
                    help='if true, skip connections go between frame t and frame t+t rather than last ground truth frame')
opt = parser.parse_args()
BATCH_SIZE = opt.batch_size
saved_model = None


def _checkpoint_epoch(path):
    """Return the integer found between the last '_e' and the last '.pth' in *path*."""
    return int(path[path.rfind('_e') + 2: path.rfind('.pth')])


if opt.model_dir != '':
    models = glob.glob(f'{opt.model_dir}/model_*.pth')
    if not models:
        # Previously this surfaced as a bare IndexError from ``sorted(...)[0]``.
        raise FileNotFoundError(
            "No checkpoint matching 'model_*.pth' found in %r" % opt.model_dir)
    # Pick the checkpoint with the highest epoch number.
    latest_model = max(models, key=_checkpoint_epoch)
    print('Loading model ', latest_model)
    # WARNING: torch.load unpickles arbitrary objects; only load checkpoints you trust.
    saved_model = torch.load(latest_model)
    # Keep the command-line values that must survive the swap to the saved options.
    optimizer = opt.optimizer
    model_dir = opt.model_dir
    niter = opt.niter
    opt = saved_model['opt']
    opt.niter = niter  # update number of epochs to train for
    opt.optimizer = optimizer
    opt.model_dir = model_dir
    opt.log_dir = '%s/continued' % opt.log_dir
else:
    raise ValueError("Please specify the model to load with the --model_dir argument")

print("Random Seed: ", opt.seed)
random.seed(opt.seed)
torch.manual_seed(opt.seed)
torch.cuda.manual_seed_all(opt.seed)
dtype = torch.cuda.FloatTensor

# ---------------- optimizers ----------------
# Replace the optimizer name with the corresponding torch.optim class.
if opt.optimizer == 'adam':
    opt.optimizer = optim.Adam
elif opt.optimizer == 'rmsprop':
    opt.optimizer = optim.RMSprop
elif opt.optimizer == 'sgd':
    opt.optimizer = optim.SGD
else:
    raise ValueError('Unknown optimizer: %s' % opt.optimizer)

import models.lstm as lstm_models

# NOTE: the ``else`` branches below cannot run, because an empty --model_dir
# already raised ValueError above. They are kept unchanged.
if opt.model_dir != '':
    frame_predictor = saved_model['frame_predictor']
    frame_predictor.batch_size = BATCH_SIZE
else:
    frame_predictor = lstm_models.lstm(opt.g_dim, opt.g_dim, opt.rnn_size, opt.predictor_rnn_layers, opt.batch_size)
    frame_predictor.apply(utils.init_weights)

# Select the encoder/decoder module for the requested architecture and image size.
if opt.model == 'dcgan':
    if opt.image_width == 64:
        import models.dcgan_64 as model
    elif opt.image_width == 128:
        import models.dcgan_128 as model
elif opt.model == 'vgg':
    if opt.image_width == 64:
        import models.vgg_64 as model
    elif opt.image_width == 128:
        import models.vgg_128 as model
else:
    raise ValueError('Unknown model: %s' % opt.model)

if opt.model_dir != '':
    decoder = saved_model['decoder']
    encoder = saved_model['encoder']
else:
    encoder = model.encoder(opt.g_dim, opt.channels)
    decoder = model.decoder(opt.g_dim, opt.channels)
    encoder.apply(utils.init_weights)
    decoder.apply(utils.init_weights)

# --------- loss functions ------------------------------------
mse_criterion = nn.MSELoss()


def kl_criterion(mu, logvar):
    """Return the KL term -0.5 * sum(1 + logvar - mu^2 - exp(logvar)) / batch size.

    The sum runs over every element of the inputs. The divisor is
    ``opt.batch_size`` as it stands when the function is called.
    """
    # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    KLD /= opt.batch_size
    return KLD


# --------- transfer to gpu ------------------------------------
frame_predictor.cuda()
encoder.cuda()
decoder.cuda()
mse_criterion.cuda()

# These assignments override whatever the checkpoint or the command line supplied.
opt.batch_size = BATCH_SIZE
opt.epoch_size = 1000
opt.n_future = 195
print(opt)

# --------- load a
dataset ------------------------------------ +train_data, test_data = utils.load_dataset(opt, sequential=True) + +train_loader = DataLoader(train_data, + num_workers=opt.data_threads, + batch_size=opt.batch_size, + drop_last=True, + pin_memory=True) +test_loader = DataLoader(test_data, + num_workers=opt.data_threads, + batch_size=opt.batch_size, + shuffle=True, + drop_last=True, + pin_memory=True) + + +def get_training_batch(): + while True: + for sequence in train_loader: + batch = utils.normalize_data(opt, dtype, sequence) + yield batch + + +training_batch_generator = get_training_batch() +training_batch_generator_2 = get_training_batch() + + +def get_testing_batch(): + while True: + for sequence in test_loader: + batch = utils.normalize_data(opt, dtype, sequence) + yield batch + + +testing_batch_generator = get_testing_batch() + + +# --------- plotting funtions ------------------------------------ +def plot(x, epoch): + nsample = 1 + gen_seq = [[] for _ in range(nsample)] + gt_seq = [x[i] for i in range(len(x))] + + h_seq = [encoder(x[i]) for i in range(opt.n_past)] + for s in range(nsample): + frame_predictor.hidden = frame_predictor.init_hidden() + gen_seq[s].append(x[0]) + x_in = x[0] + for i in range(1, opt.n_eval): + if opt.last_frame_skip or i < opt.n_past: + h, skip = h_seq[i - 1] + h = h.detach() + elif i < opt.n_past: + h, _ = h_seq[i - 1] + h = h.detach() + if i < opt.n_past: + frame_predictor(h) + x_in = x[i] + gen_seq[s].append(x_in) + else: + h = frame_predictor(h).detach() + x_in = decoder([h, skip]).detach() + gen_seq[s].append(x_in) + + to_plot = [] + gifs = [[] for t in range(opt.n_eval)] + nrow = min(opt.batch_size, 25) + for i in range(nrow): + # ground truth sequence + row = [] + for t in range(opt.n_eval): + row.append(gt_seq[t][i]) + to_plot.append(row) + + for s in range(nsample): + row = [] + for t in range(opt.n_eval): + row.append(gen_seq[s][t][i]) + to_plot.append(row) + for t in range(opt.n_eval): + row = [] + row.append(gt_seq[t][i]) 
+ for s in range(nsample): + row.append(gen_seq[s][t][i]) + gifs[t].append(row) + + fname = '%s/gen/sample_%d.png' % (opt.log_dir, epoch) + utils.save_tensors_image(fname, to_plot) + + fname = '%s/gen/sample_%d.gif' % (opt.log_dir, epoch) + utils.save_gif(fname, gifs) + + +def plot_rec(x, epoch): + frame_predictor.hidden = frame_predictor.init_hidden() + gen_seq = [] + gen_seq.append(x[0]) + x_in = x[0] + h_seq = [encoder(x[i]) for i in range(opt.n_past + opt.n_future)] + for i in range(1, opt.n_past + opt.n_future): + h_target = h_seq[i][0].detach() + if opt.last_frame_skip or i < opt.n_past: + h, skip = h_seq[i - 1] + else: + h, _ = h_seq[i - 1] + h = h.detach() + if i < opt.n_past: + frame_predictor(h) + gen_seq.append(x[i]) + else: + h_pred = frame_predictor(h).detach() + x_pred = decoder([h_pred, skip]).detach() + gen_seq.append(x_pred) + + to_plot = [] + nrow = min(opt.batch_size, 25) + for i in range(nrow): + row = [] + for t in range(opt.n_past + opt.n_future): + row.append(gen_seq[t][i]) + to_plot.append(row) + fname = '%s/gen/rec_%d.png' % (opt.log_dir, epoch) + utils.save_tensors_image(fname, to_plot) + + +def do_stats(): + epoch_size = 1000 // opt.batch_size // 4 # we have a total of 1000 videos + + frame_predictor.eval() + encoder.eval() + decoder.eval() + progress = progressbar.ProgressBar(max_value=epoch_size).start() + h_residual_mean = torch.tensor(np.zeros((opt.n_future, 128), dtype=np.float32), requires_grad=False, + device=torch.device('cuda:0')) + i = 0 + for i in range(epoch_size): + progress.update(i + 1) + try: + x = next(training_batch_generator) + except TypeError: + print('got None at i = {}, terminating'.format(i)) + break + h_posterior = [encoder(x[j]) for j in range(opt.n_past + opt.n_future)] + # print(h_posterior[0][0].size()) + last_pred = None + frame_predictor.hidden = frame_predictor.init_hidden() + for j in range(1, opt.n_past + opt.n_future): + if opt.last_frame_skip or j < opt.n_past: + h, skip = h_posterior[j - 1] + else: + h 
= h_posterior[j - 1][0].detach() + # we predict h_t from h_{t-1} + h_prior_pred = frame_predictor(torch.cat([h], 1)).detach() + + if j >= opt.n_past: + # h_res = h_prior_pred - h_posterior[j][0].detach() # predicted h minus observed h + # h_res = h_prior_pred + # h_res = torch.mean(h_res, dim=0) # average errors at the same time j over the batch + # h_residual_mean[j - opt.n_past] += h_res + this_minus_last_pred = h_prior_pred - last_pred + this_minus_last_pred = torch.mean(this_minus_last_pred, dim=0) + h_residual_mean[j - opt.n_past] += this_minus_last_pred + last_pred = h_prior_pred + h_residual_mean /= epoch_size # get the mean error vector per time + + # restart training dataset + global train_loader + train_loader = DataLoader(train_data, + num_workers=opt.data_threads, + batch_size=opt.batch_size, + drop_last=True, + pin_memory=True) + h_residual_var = torch.tensor(np.zeros(opt.n_future, dtype=np.float32), requires_grad=False, + device=torch.device('cuda:0')) + h_residual_vars = torch.tensor(np.zeros((opt.n_future, 128), dtype=np.float32), requires_grad=False, + device=torch.device('cuda:0')) + + for i in range(epoch_size): + progress.update(i + 1) + try: + x = next(training_batch_generator_2) + except TypeError: + print('got None at i = {}, terminating'.format(i)) + break + h_posterior = [encoder(x[j]) for j in range(opt.n_past + opt.n_future)] + # print(h_posterior[0][0].size()) + last_pred = None + frame_predictor.hidden = frame_predictor.init_hidden() + for j in range(1, opt.n_past + opt.n_future): + if opt.last_frame_skip or j < opt.n_past: + h, skip = h_posterior[j - 1] + else: + h = h_posterior[j - 1][0].detach() + # we predict h_t from h_{t-1} + h_prior_pred = frame_predictor(torch.cat([h], 1)).detach() + + if j >= opt.n_past: + # h_res = h_prior_pred - h_posterior[j][0].detach() # predicted h minus observed h + # squared_diff = torch.square(h_res - h_residual_mean[j - opt.n_past]) + # squared_diff = torch.mean(squared_diff, dim=1) # average squared 
residuals at time j over the dimensions of h + # squared_diff = torch.mean(squared_diff, dim=0) # average squared residuals at time j over the batch + # h_residual_var[j - opt.n_past] += squared_diff + + this_minus_last_pred = h_prior_pred - last_pred + squared_err = torch.square(this_minus_last_pred - h_residual_mean[j - opt.n_past]) + squared_err = torch.mean(squared_err, dim=0) # average errs at time j over the batch + h_residual_vars[j - opt.n_past] += squared_err.detach() + squared_err = torch.mean(squared_err, dim=0) # average errs at time j over the dimensions of h + h_residual_var[j - opt.n_past] += squared_err.detach() + last_pred = h_prior_pred.detach() + h_residual_var /= epoch_size # get the mean error vector per time + h_residual_vars /= epoch_size + h_residual_sd = torch.sqrt(h_residual_var) + print('Last i = {}'.format(i)) + print('sd of h residual: ', h_residual_sd) + print('var of h residual: ', h_residual_var) + print('norm(mean of h residual)', torch.norm(h_residual_mean, dim=1)) + print('vars of h residual: ', h_residual_vars) + + # H_err_cov = torch.tensor(np.zeros((opt.n_future, 128, 128), dtype=np.float32), requires_grad=False, + # device=torch.device('cuda:0')) + + # plot some stuff + h_residual_mean_norm = torch.norm(h_residual_mean, dim=1).cpu() + plt.subplot(2, 1, 1) + plt.xlabel('Time') + plt.tight_layout() + plt.title("Norm of the residual mean") + plt.bar(np.arange(len(h_residual_mean_norm)), h_residual_mean_norm) + + plt.subplot(2, 1, 2) + plt.xlabel('Time') + plt.tight_layout() + plt.title("Average dimensional sqrt(variance) of the residual") + plt.bar(np.arange(len(h_residual_sd.cpu())), h_residual_sd.cpu()) + plt.savefig('h_residual_mean_1.png') + + stats_dict = {'mean': h_residual_mean.cpu().tolist(), 'var': h_residual_var.cpu().tolist(), + 'vars': h_residual_vars.cpu().tolist()} + f = open('mcs_stats.json', 'w') + json.dump(stats_dict, f) + + +do_stats() diff --git a/do_mcs_stats_posterior.py b/do_mcs_stats_posterior.py new file 
mode 100644 index 0000000..621696a --- /dev/null +++ b/do_mcs_stats_posterior.py @@ -0,0 +1,385 @@ +import glob + +import torch +import torch.optim as optim +import torch.nn as nn +import argparse +import os +import random +from torch.autograd import Variable +from torch.utils.data import DataLoader +import utils +import itertools +import progressbar +import numpy as np +import matplotlib.pyplot as plt +import json + +parser = argparse.ArgumentParser() +parser.add_argument('--lr', default=0.004, type=float, help='learning rate') +parser.add_argument('--beta1', default=0.9, type=float, help='momentum term for adam') +parser.add_argument('--batch_size', default=1, type=int, help='batch size') +parser.add_argument('--log_dir', default='logs/nonstochastic', help='base directory to save logs') +parser.add_argument('--model_dir', default='', help='base directory to save logs') +parser.add_argument('--name', default='', help='identifier for directory') +parser.add_argument('--data_root', default='data', help='root directory for data') +parser.add_argument('--optimizer', default='adam', help='optimizer to train with') +parser.add_argument('--niter', type=int, default=300, help='number of epochs to train for') +parser.add_argument('--seed', default=1, type=int, help='manual seed') +parser.add_argument('--epoch_size', type=int, default=1000, help='epoch size') +parser.add_argument('--image_width', type=int, default=64, help='the height / width of the input image to network') +parser.add_argument('--channels', default=1, type=int) +parser.add_argument('--dataset', default='mcs', help='dataset to train with') +parser.add_argument('--mcs_task', default='SpatioTemporalContinuityTraining4', help='mcs task') +parser.add_argument('--n_past', type=int, default=5, help='number of frames to condition on') +parser.add_argument('--n_future', type=int, default=195, help='number of frames to predict') +parser.add_argument('--n_eval', type=int, default=30, help='number of frames to predict 
at eval time') +parser.add_argument('--rnn_size', type=int, default=256, help='dimensionality of hidden layer') +parser.add_argument('--posterior_rnn_layers', type=int, default=1, help='number of layers') +parser.add_argument('--predictor_rnn_layers', type=int, default=2, help='number of layers') +parser.add_argument('--z_dim', type=int, default=10, help='dimensionality of z_t') +parser.add_argument('--g_dim', type=int, default=128, + help='dimensionality of encoder output vector and decoder input vector') +parser.add_argument('--beta', type=float, default=0.0001, help='weighting on KL to prior') +parser.add_argument('--model', default='dcgan', help='model type (dcgan | vgg)') +parser.add_argument('--data_threads', type=int, default=8, help='number of data loading threads') +parser.add_argument('--num_digits', type=int, default=2, help='number of digits for moving mnist') +parser.add_argument('--last_frame_skip', action='store_true', + help='if true, skip connections go between frame t and frame t+t rather than last ground truth frame') +opt = parser.parse_args() +BATCH_SIZE = opt.batch_size +saved_model = None +if opt.model_dir != '': + models = glob.glob(f'{opt.model_dir}/model_*.pth') + latest_model = sorted(models, key=lambda s: int(s[s.rfind('_e') + 2: s.rfind('.pth')]), reverse=True)[0] + print('Loading model ', latest_model) + saved_model = torch.load(latest_model) + model_dir = opt.model_dir + niter = opt.niter + opt = saved_model['opt'] + opt.niter = niter # update number of epochs to train for + opt.model_dir = model_dir + opt.log_dir = '%s/continued' % opt.log_dir +else: + raise ValueError("Please specify the model to load with the --model_dir argument") + +print("Random Seed: ", opt.seed) +random.seed(opt.seed) +torch.manual_seed(opt.seed) +torch.cuda.manual_seed_all(opt.seed) +dtype = torch.cuda.FloatTensor + + +import models.lstm as lstm_models + +if opt.model_dir != '': + frame_predictor = saved_model['frame_predictor'] + frame_predictor.batch_size = 
BATCH_SIZE + posterior = saved_model['posterior'] + posterior.batch_size = BATCH_SIZE + prior = saved_model['prior'] + prior.batch_size = BATCH_SIZE +else: + raise ValueError("Please specify the model to load with the --model_dir argument") + +if opt.model == 'dcgan': + if opt.image_width == 64: + import models.dcgan_64 as model + elif opt.image_width == 128: + import models.dcgan_128 as model +elif opt.model == 'vgg': + if opt.image_width == 64: + import models.vgg_64 as model + elif opt.image_width == 128: + import models.vgg_128 as model +else: + raise ValueError('Unknown model: %s' % opt.model) + +if opt.model_dir != '': + decoder = saved_model['decoder'] + encoder = saved_model['encoder'] +else: + raise ValueError("Please specify the model to load with the --model_dir argument") + +# --------- loss functions ------------------------------------ +mse_criterion = nn.MSELoss() + + +def kl_criterion(mu, logvar): + # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) + KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) + KLD /= opt.batch_size + return KLD + + +# --------- transfer to gpu ------------------------------------ +frame_predictor.cuda() +encoder.cuda() +decoder.cuda() +mse_criterion.cuda() + +opt.batch_size = BATCH_SIZE +opt.epoch_size = 1000 +opt.n_future = 195 +print(opt) + +# --------- load a dataset ------------------------------------ +train_data, test_data = utils.load_dataset(opt, sequential=True) + +train_loader = DataLoader(train_data, + num_workers=opt.data_threads, + batch_size=opt.batch_size, + drop_last=True, + pin_memory=True) +test_loader = DataLoader(test_data, + num_workers=opt.data_threads, + batch_size=opt.batch_size, + shuffle=True, + drop_last=True, + pin_memory=True) + + +def get_training_batch(): + while True: + for sequence in train_loader: + batch = utils.normalize_data(opt, dtype, sequence) + yield batch + + +training_batch_generator = get_training_batch() +training_batch_generator_2 = get_training_batch() + + +def 
get_testing_batch(): + while True: + for sequence in test_loader: + batch = utils.normalize_data(opt, dtype, sequence) + yield batch + + +testing_batch_generator = get_testing_batch() + + +# --------- plotting funtions ------------------------------------ +def plot(x, epoch): + nsample = 1 + gen_seq = [[] for _ in range(nsample)] + gt_seq = [x[i] for i in range(len(x))] + + h_seq = [encoder(x[i]) for i in range(opt.n_past)] + for s in range(nsample): + frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + gen_seq[s].append(x[0]) + x_in = x[0] + for i in range(1, opt.n_eval): + if opt.last_frame_skip or i < opt.n_past: + h, skip = h_seq[i - 1] + h = h.detach() + elif i < opt.n_past: + h, _ = h_seq[i - 1] + h = h.detach() + if i < opt.n_past: + frame_predictor(h) + x_in = x[i] + gen_seq[s].append(x_in) + else: + h = frame_predictor(h).detach() + x_in = decoder([h, skip]).detach() + gen_seq[s].append(x_in) + + to_plot = [] + gifs = [[] for t in range(opt.n_eval)] + nrow = min(opt.batch_size, 25) + for i in range(nrow): + # ground truth sequence + row = [] + for t in range(opt.n_eval): + row.append(gt_seq[t][i]) + to_plot.append(row) + + for s in range(nsample): + row = [] + for t in range(opt.n_eval): + row.append(gen_seq[s][t][i]) + to_plot.append(row) + for t in range(opt.n_eval): + row = [] + row.append(gt_seq[t][i]) + for s in range(nsample): + row.append(gen_seq[s][t][i]) + gifs[t].append(row) + + fname = '%s/gen/sample_%d.png' % (opt.log_dir, epoch) + utils.save_tensors_image(fname, to_plot) + + fname = '%s/gen/sample_%d.gif' % (opt.log_dir, epoch) + utils.save_gif(fname, gifs) + + +def plot_rec(x, epoch): + frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + gen_seq = [] + gen_seq.append(x[0]) + x_in = x[0] + h_seq = [encoder(x[i]) for i in range(opt.n_past + opt.n_future)] + for i in range(1, opt.n_past + opt.n_future): + h_target = h_seq[i][0].detach() + if 
opt.last_frame_skip or i < opt.n_past: + h, skip = h_seq[i - 1] + else: + h, _ = h_seq[i - 1] + h = h.detach() + if i < opt.n_past: + frame_predictor(h) + gen_seq.append(x[i]) + else: + h_pred = frame_predictor(h).detach() + x_pred = decoder([h_pred, skip]).detach() + gen_seq.append(x_pred) + + to_plot = [] + nrow = min(opt.batch_size, 25) + for i in range(nrow): + row = [] + for t in range(opt.n_past + opt.n_future): + row.append(gen_seq[t][i]) + to_plot.append(row) + fname = '%s/gen/rec_%d.png' % (opt.log_dir, epoch) + utils.save_tensors_image(fname, to_plot) + + +def do_stats(): + epoch_size = 1000 // opt.batch_size // 1 # we have a total of 1000 videos + + frame_predictor.eval() + prior.eval() + posterior.eval() + encoder.eval() + decoder.eval() + progress = progressbar.ProgressBar(max_value=epoch_size).start() + z_residual_mean = torch.tensor(np.zeros((opt.n_future, opt.z_dim), dtype=np.float64), requires_grad=False, + device=torch.device('cuda:0')) + i = 0 + for i in range(epoch_size): + progress.update(i + 1) + try: + x = next(training_batch_generator) + except TypeError: + print('got None at i = {}, terminating'.format(i)) + break + # print(h_posterior[0][0].size()) + frame_predictor.hidden = frame_predictor.init_hidden() + prior.hidden = prior.init_hidden() + posterior.hidden = posterior.init_hidden() + + last_h = encoder(x[0]) + for j in range(1, opt.n_past + opt.n_future): + h_target = encoder(x[j]) + if opt.last_frame_skip or j < opt.n_past: + h, skip = last_h + h = h.detach() + else: + h = last_h[0].detach() + # we predict h_t from h_{t-1} + z_t = posterior(h_target[0].detach()).detach() + z_t_hat = prior(h).detach() + + if j >= opt.n_past: + # h_res = h_prior_pred - h_posterior[j][0].detach() # predicted h minus observed h + # h_res = h_prior_pred + # h_res = torch.mean(h_res, dim=0) # average errors at the same time j over the batch + # z_residual_mean[j - opt.n_past] += h_res + residual = z_t - z_t_hat + residual = torch.mean(residual, dim=0) + 
z_residual_mean[j - opt.n_past] += residual + last_h = h_target + z_residual_mean /= epoch_size # get the mean error vector per time + + # restart training dataset + global train_loader + train_loader = DataLoader(train_data, + num_workers=opt.data_threads, + batch_size=opt.batch_size, + drop_last=True, + pin_memory=True) + + z_cov = torch.tensor(np.zeros((opt.n_future, opt.z_dim, opt.z_dim), dtype=np.float64), requires_grad=False, + device=torch.device('cuda:0')) + + + for i in range(epoch_size): + progress.update(i + 1) + try: + x = next(training_batch_generator_2) + except TypeError: + print('got None at i = {}, terminating'.format(i)) + break + + # print(h_posterior[0][0].size()) + frame_predictor.hidden = frame_predictor.init_hidden() + prior.hidden = prior.init_hidden() + posterior.hidden = posterior.init_hidden() + last_h = encoder(x[0]) + for j in range(1, opt.n_past + opt.n_future): + h_target = encoder(x[j]) + if opt.last_frame_skip or j < opt.n_past: + h, skip = last_h + h = h.detach() + else: + h = last_h[0].detach() + # we predict h_t from h_{t-1} + z_t = posterior(h_target[0].detach()).detach() + z_t_hat = prior(h).detach() + + if j >= opt.n_past: + # h_res = h_prior_pred - h_posterior[j][0].detach() # predicted h minus observed h + # squared_diff = torch.square(h_res - z_residual_mean[j - opt.n_past]) + # squared_diff = torch.mean(squared_diff, dim=1) # average squared residuals at time j over the dimensions of h + # squared_diff = torch.mean(squared_diff, dim=0) # average squared residuals at time j over the batch + # h_residual_var[j - opt.n_past] += squared_diff + + residual = z_t - z_t_hat # B x D + # B x D x 1 * B x 1 x D -> B x D x D + sample_cov = torch.matmul(residual[:, :, np.newaxis], residual[:, np.newaxis, :]) + sample_cov = torch.mean(sample_cov, axis=0) # mean over batch dimension + z_cov[j - opt.n_past] += sample_cov + last_h = h_target + z_cov /= epoch_size + z_cov = z_cov.cpu().numpy() + z_sd = [np.sqrt(np.diag(cov)) for cov in 
z_cov] + z_sd = np.array(z_sd) + print('Last i = {}'.format(i)) + print('sd of z residual: ', z_sd) + print('norm(mean of z residual)', torch.norm(z_residual_mean, dim=1)) + + # H_err_cov = torch.tensor(np.zeros((opt.n_future, 128, 128), dtype=np.float32), requires_grad=False, + # device=torch.device('cuda:0')) + + # plot some stuff + z_residual_mean_norm = torch.norm(z_residual_mean, dim=1).cpu() + plt.subplot(3, 1, 1) + plt.xlabel('Time') + plt.tight_layout() + plt.title("Norm of the residual mean") + plt.bar(np.arange(len(z_residual_mean_norm)), z_residual_mean_norm) + + plt.subplot(3, 1, 2) + plt.xlabel('Time') + plt.tight_layout() + plt.title("Average dimensional sqrt(variance) of the residual") + plt.bar(np.arange(len(z_sd)), np.mean(z_sd, axis=1)) + plt.savefig('z_residual.png') + + stats_dict = {'mean': z_residual_mean.cpu().numpy(), 'cov': z_cov} + print(stats_dict['mean'].dtype) + print(stats_dict['cov'].dtype) + # f = open('new_mcs_stats_post.json', 'w') + # json.dump(stats_dict, f) + with open('new_mcs_stats_post.npy', 'wb') as f: + np.save(f, stats_dict['mean']) + np.save(f, stats_dict['cov']) + + +do_stats() diff --git a/h_residual_mean.png b/h_residual_mean.png new file mode 100644 index 0000000..c679777 Binary files /dev/null and b/h_residual_mean.png differ diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..4d88cc3 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,8 @@ +numpy +torch +shapely +sklearn +matplotlib +scikit-image +progressbar2 +opencv-python \ No newline at end of file diff --git a/train_baseline_collision.py b/train_baseline_collision.py new file mode 100644 index 0000000..1c900b2 --- /dev/null +++ b/train_baseline_collision.py @@ -0,0 +1,447 @@ +import glob + +import torch +import torch.optim as optim +import torch.nn as nn +import argparse +import os +import random +from torch.autograd import Variable +from torch.utils.data import DataLoader +import utils +import itertools +import progressbar +import 
numpy as np +import json + +parser = argparse.ArgumentParser() +parser.add_argument('--lr', default=0.0002, type=float, help='learning rate') +parser.add_argument('--beta1', default=0.9, type=float, help='momentum term for adam') +parser.add_argument('--batch_size', default=13, type=int, help='batch size') +parser.add_argument('--log_dir', default='logs/nonstochastic_posterior', help='base directory to save logs') +parser.add_argument('--model_dir', default='', help='base directory to save logs') +parser.add_argument('--name', default='', help='identifier for directory') +parser.add_argument('--data_root', default='data', help='root directory for data') +parser.add_argument('--optimizer', default='adam', help='optimizer to train with') +parser.add_argument('--niter', type=int, default=40, help='number of epochs to train for') +parser.add_argument('--seed', default=1, type=int, help='manual seed') +parser.add_argument('--epoch_size', type=int, default=500, help='epoch size') +parser.add_argument('--image_width', type=int, default=64, help='the height / width of the input image to network') +parser.add_argument('--channels', default=1, type=int, help='number of channels for input images. 
') +parser.add_argument('--use_edge_kernels', default=True, type=bool, help='whether to use edge kernels to reduce to 1 channel') +parser.add_argument('--dataset', default='mcs', help='dataset to train with') +parser.add_argument('--mcs_task', default='CollisionTraining', help='mcs task') +parser.add_argument('--n_past', type=int, default=5, help='number of frames to condition on') +parser.add_argument('--n_future', type=int, default=30, help='number of frames to predict') +parser.add_argument('--n_eval', type=int, default=35, help='number of frames to predict at eval time') +parser.add_argument('--start_min', type=int, default=75, help='min starting time for sampling sequence (0-indexed)') +parser.add_argument('--start_max', type=int, default=100, help='max starting time for sampling sequence (0-indexed)') +parser.add_argument('--sequence_stride', type=int, default=2, help='factor for sequence temporal subsampling (int)') +parser.add_argument('--rnn_size', type=int, default=256, help='dimensionality of hidden layer') +parser.add_argument('--prior_rnn_layers', type=int, default=1, help='number of layers') +parser.add_argument('--posterior_rnn_layers', type=int, default=1, help='number of layers') +parser.add_argument('--predictor_rnn_layers', type=int, default=2, help='number of layers') +parser.add_argument('--z_dim', type=int, default=32, help='dimensionality of z_t') +parser.add_argument('--g_dim', type=int, default=128, help='dimensionality of encoder output vector and decoder input vector') +parser.add_argument('--beta', type=float, default=0.0001, help='weighting on KL to prior') +parser.add_argument('--gamma', type=float, default=0.0001, help='weighting on h vs h posterior') +parser.add_argument('--model', default='vgg', help='model type (dcgan | vgg)') +parser.add_argument('--data_threads', type=int, default=12, help='number of data loading threads') +parser.add_argument('--num_digits', type=int, default=2, help='number of digits for moving mnist') 
+parser.add_argument('--last_frame_skip', action='store_true', help='if true, skip connections go between frame t and frame t+1 rather than last ground truth frame') + + +opt = parser.parse_args() +saved_model = None +if opt.model_dir != '': + models = glob.glob(f'{opt.model_dir}/model_*.pth') + latest_model = sorted(models, key=lambda s: int(s[s.rfind('_e') + 2: s.rfind('.pth')]), reverse=True)[0] + print('Loading model ', latest_model) + saved_model = torch.load(latest_model) + model_dir = opt.model_dir + niter = opt.niter + lr = opt.lr + batch_size = opt.batch_size + n_future = opt.n_future + n_eval = opt.n_eval + data_root = opt.data_root + opt = saved_model['opt'] + opt.niter = niter # update number of epochs to train for + opt.model_dir = model_dir + opt.n_future = n_future + opt.lr = lr + opt.batch_size = batch_size + opt.n_eval = n_eval + opt.data_root = data_root + opt.log_dir = '%s/continued_lr%s' % (opt.log_dir, opt.lr) +else: + name = 'model=%s%dx%d-rnn_size=%d-predictor-posterior-rnn_layers=%d-%d-n_past=%d-n_future=%d-lr=%.4f-g_dim=%d-z_dim=%d-last_frame_skip=%d-beta=%.7f-gamma=%.7f%s' % (opt.model, opt.image_width, opt.image_width, opt.rnn_size, opt.predictor_rnn_layers, opt.posterior_rnn_layers, opt.n_past, opt.n_future, opt.lr, opt.g_dim, opt.z_dim, opt.last_frame_skip, opt.beta, opt.gamma, opt.name) + if opt.dataset == 'smmnist': + opt.log_dir = '%s/%s-%d/%s' % (opt.log_dir, opt.dataset, opt.num_digits, name) + elif opt.dataset == 'mcs': + opt.log_dir = '%s/%s/%s/%s' % (opt.log_dir, opt.dataset, opt.mcs_task, name) + else: + opt.log_dir = '%s/%s/%s' % (opt.log_dir, opt.dataset, name) + +os.makedirs('%s/gen/' % opt.log_dir, exist_ok=False) +os.makedirs('%s/plots/' % opt.log_dir, exist_ok=False) +with open(os.path.join(opt.log_dir, 'opt.json'), 'w') as f: + opt2 = opt.__dict__.copy() + if isinstance(opt2['optimizer'], type): + opt2['optimizer'] = str(opt2['optimizer']) + json.dump(opt2, f, indent=2) + del opt2 + +print("Random Seed: ", opt.seed) 
+random.seed(opt.seed) +torch.manual_seed(opt.seed) +torch.cuda.manual_seed_all(opt.seed) +dtype = torch.cuda.FloatTensor + + +# ---------------- load the models ---------------- + +print(opt) + +# ---------------- optimizers ---------------- +if opt.optimizer == 'adam': + opt.optimizer = optim.Adam +elif opt.optimizer == 'rmsprop': + opt.optimizer = optim.RMSprop +elif opt.optimizer == 'sgd': + opt.optimizer = optim.SGD +elif isinstance(opt.optimizer, type): + pass +else: + raise ValueError('Unknown optimizer: %s' % opt.optimizer) + +import models.lstm as lstm_models +if opt.model_dir != '': + frame_predictor = saved_model['frame_predictor'] + frame_predictor.batch_size = opt.batch_size + prior = saved_model['prior'] + prior.batch_size = opt.batch_size + posterior = saved_model['posterior'] + posterior.batch_size = opt.batch_size +else: + frame_predictor = lstm_models.lstm(opt.g_dim + opt.z_dim, opt.g_dim, opt.rnn_size, opt.predictor_rnn_layers, opt.batch_size) + posterior = lstm_models.lstm(opt.g_dim, opt.z_dim, opt.rnn_size, opt.posterior_rnn_layers, opt.batch_size) + prior = lstm_models.lstm(opt.g_dim, opt.z_dim, opt.rnn_size, opt.prior_rnn_layers, opt.batch_size) + frame_predictor.apply(utils.init_weights) + posterior.apply(utils.init_weights) + prior.apply(utils.init_weights) + +if opt.model == 'dcgan': + if opt.image_width == 64: + import models.dcgan_64 as model + elif opt.image_width == 128: + import models.dcgan_128 as model +elif opt.model == 'vgg': + if opt.image_width == 64: + import models.vgg_64 as model + elif opt.image_width == 128: + import models.vgg_128 as model +else: + raise ValueError('Unknown model: %s' % opt.model) + +if opt.model_dir != '': + decoder = saved_model['decoder'] + encoder = saved_model['encoder'] +else: + encoder = model.encoder(opt.g_dim, opt.channels) + decoder = model.decoder(opt.g_dim, 1) + encoder.apply(utils.init_weights) + decoder.apply(utils.init_weights) + +if opt.optimizer == optim.SGD: + frame_predictor_optimizer = 
opt.optimizer(frame_predictor.parameters(), lr=opt.lr, momentum=opt.beta1) + posterior_optimizer = opt.optimizer(posterior.parameters(), lr=opt.lr, momentum=opt.beta1) + prior_optimizer = opt.optimizer(prior.parameters(), lr=opt.lr, momentum=opt.beta1) + encoder_optimizer = opt.optimizer(encoder.parameters(), lr=opt.lr, momentum=opt.beta1) + decoder_optimizer = opt.optimizer(decoder.parameters(), lr=opt.lr, momentum=opt.beta1) +else: + frame_predictor_optimizer = opt.optimizer(frame_predictor.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + posterior_optimizer = opt.optimizer(posterior.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + prior_optimizer = opt.optimizer(prior.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + encoder_optimizer = opt.optimizer(encoder.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + decoder_optimizer = opt.optimizer(decoder.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + +# --------- loss functions ------------------------------------ +mse_criterion = nn.MSELoss() +def kl_criterion(mu, logvar): + # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) + KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) + KLD /= opt.batch_size + return KLD + + +# --------- transfer to gpu ------------------------------------ +frame_predictor.cuda() +posterior.cuda() +prior.cuda() +encoder.cuda() +decoder.cuda() +mse_criterion.cuda() + +# --------- load a dataset ------------------------------------ +train_data, test_data = utils.load_dataset(opt) + +train_loader = DataLoader(train_data, + num_workers=opt.data_threads, + batch_size=opt.batch_size, + shuffle=True, + drop_last=True, + pin_memory=True,) +test_loader = DataLoader(test_data, + num_workers=opt.data_threads, + batch_size=opt.batch_size, + shuffle=True, + drop_last=True, + pin_memory=True) + +def get_training_batch(): + while True: + for sequence in train_loader: + batch = utils.normalize_data(opt, dtype, sequence) + yield batch +training_batch_generator = 
get_training_batch() + +def get_testing_batch(): + while True: + for sequence in test_loader: + batch = utils.normalize_data(opt, dtype, sequence) + yield batch +testing_batch_generator = get_testing_batch() + +# --------- plotting funtions ------------------------------------ +def plot(x, epoch): + + nsample = 1 + gen_seq = [[] for _ in range(nsample)] + gt_seq = [utils.torch_rgb_img_to_gray(x[t]) for t in range(len(x))] + + # h_seq = [encoder(x[i]) for i in range(opt.n_past)] + for s in range(nsample): + frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + prior.hidden = prior.init_hidden() + gen_seq[s].append(utils.torch_rgb_img_to_gray(x[0])) + x_in = x[0] + for i in range(1, opt.n_eval): + with torch.no_grad(): + # if input is grayscale + if x_in.shape[1] == 1 and opt.channels == 3: + h = encoder(torch.cat(3*[x_in], dim=1)) # convert to RGB + else: + h = encoder(x_in) + if opt.last_frame_skip or i < opt.n_past: + h, skip = h + else: + h, _ = h + h = h.detach() + if i < opt.n_past: + h_target = encoder(x[i]) + z_t = posterior(h_target[0].detach()) + prior(h) + frame_predictor(torch.cat([h, z_t], 1)) + x_in = x[i] + gen_seq[s].append(utils.torch_rgb_img_to_gray(x_in)) + else: + z_t_hat = prior(h) + h = frame_predictor(torch.cat([h, z_t_hat], 1)).detach() + x_in = decoder([h, skip]) + gen_seq[s].append(x_in) + + + to_plot = [] + gifs = [ [] for t in range(opt.n_eval) ] + nrow = min(opt.batch_size, 25) + for i in range(nrow): + # ground truth sequence + row = [] + for t in range(opt.n_eval): + row.append(gt_seq[t][i]) + to_plot.append(row) + + for s in range(nsample): + row = [] + for t in range(opt.n_eval): + row.append(gen_seq[s][t][i]) + to_plot.append(row) + for t in range(opt.n_eval): + row = [] + row.append(gt_seq[t][i]) + for s in range(nsample): + row.append(gen_seq[s][t][i]) + gifs[t].append(row) + + fname = '%s/gen/sample_%d.png' % (opt.log_dir, epoch) + utils.save_tensors_image(fname, to_plot) + + fname = 
'%s/gen/sample_%d.gif' % (opt.log_dir, epoch) + utils.save_gif(fname, gifs) + + +def plot_rec(x, epoch): + gen_seq = [utils.torch_rgb_img_to_gray(x[0])] + gen_seq_post = [utils.torch_rgb_img_to_gray(x[0])] + + # prediction using posterior Z + frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + h = encoder(x[0]) + h = (h[0].detach(), h[1]) + for i in range(1, min(opt.n_eval, opt.n_past+opt.n_future)): + h_target = encoder(x[i]) + h_target = (h_target[0].detach(), h_target[1]) + if opt.last_frame_skip or i < opt.n_past: + h, skip = h + else: + h, _ = h + z_t = posterior(h_target[0]) + if i < opt.n_past: + frame_predictor(torch.cat([h, z_t], 1)) + gen_seq_post.append(utils.torch_rgb_img_to_gray(x[i])) + else: + h_pred = frame_predictor(torch.cat([h, z_t], 1)).detach() + x_pred = decoder([h_pred, skip]).detach() + gen_seq_post.append(x_pred) + h = h_target + + # prediction using prior Z + frame_predictor.hidden = frame_predictor.init_hidden() + prior.hidden = prior.init_hidden() + h = encoder(x[0]) + h = (h[0].detach(), h[1]) + for i in range(1, min(opt.n_eval, opt.n_past+opt.n_future)): + h_target = encoder(x[i]) + h_target = (h_target[0].detach(), h_target[1]) + if opt.last_frame_skip or i < opt.n_past: + h, skip = h + else: + h, _ = h + h = h.detach() + z_t_hat = prior(h) + if i < opt.n_past: + frame_predictor(torch.cat([h, z_t_hat], 1)) + gen_seq.append(utils.torch_rgb_img_to_gray(x[i])) + else: + h_pred = frame_predictor(torch.cat([h, z_t_hat], 1)).detach() + x_pred = decoder([h_pred, skip]).detach() + gen_seq.append(x_pred) + h = h_target + + to_plot = [] + nrow = min(opt.batch_size, 25) + x_gray = [utils.torch_rgb_img_to_gray(x[t]) for t in range(min(opt.n_eval, opt.n_past+opt.n_future))] + for i in range(nrow): + row_gt = [] + row_post = [] + row = [] + for t in range(min(opt.n_eval, opt.n_past+opt.n_future)): + row_gt.append(x_gray[t][i]) + row_post.append(gen_seq_post[t][i]) + row.append(gen_seq[t][i]) + 
to_plot.append(row_gt) + to_plot.append(row_post) + to_plot.append(row) + fname = '%s/gen/rec_%d.png' % (opt.log_dir, epoch) + utils.save_tensors_image(fname, to_plot) + + +# --------- training funtions ------------------------------------ +def train(x): + frame_predictor.zero_grad() + posterior.zero_grad() + prior.zero_grad() + encoder.zero_grad() + decoder.zero_grad() + + # initialize the hidden state. + frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + prior.hidden = prior.init_hidden() + + mse = 0 + mse_residual = 0 + h = encoder(x[0]) + for i in range(1, opt.n_past+opt.n_future): + h_target = encoder(x[i]) + if opt.last_frame_skip or i < opt.n_past: + h, skip = h + else: + h = h[0] + + z_t = posterior(h_target[0]) + z_t_hat = prior(h) + h_pred = frame_predictor(torch.cat([h, z_t], 1)) + x_pred = decoder([h_pred, skip]) + gray_target_frame = utils.torch_rgb_img_to_gray(x[i]) + mse += mse_criterion(x_pred, gray_target_frame) + # penalize prior for being far from posterior + mse_residual += opt.gamma * torch.mean(torch.square(z_t.detach() - z_t_hat)) + h = h_target + + loss = mse + mse_residual + loss.backward() + + frame_predictor_optimizer.step() + posterior_optimizer.step() + prior_optimizer.step() + encoder_optimizer.step() + decoder_optimizer.step() + + N = opt.n_past+opt.n_future + return mse.data.cpu().numpy()/N, mse_residual.data.cpu().numpy()/N + +# --------- training loop ------------------------------------ +for epoch in range(opt.niter): + frame_predictor.train() + posterior.train() + prior.train() + encoder.train() + decoder.train() + epoch_mse = 0 + epoch_mse_residual = 0 + progress = progressbar.ProgressBar(max_value=opt.epoch_size).start() + # opt.epoch_size = 10 + for i in range(opt.epoch_size): + progress.update(i+1) + x = next(training_batch_generator) + + # train frame_predictor + mse, mse_residual = train(x) + epoch_mse += mse + epoch_mse_residual += mse_residual + + + progress.finish() + 
utils.clear_progressbar() + + print('[%02d] mse loss: %.5f | residual mse: %.20f (%d)' % (epoch, epoch_mse/opt.epoch_size, epoch_mse_residual/opt.epoch_size, epoch*opt.epoch_size*opt.batch_size)) + with open(os.path.join(opt.log_dir, 'loss.txt'), 'a') as f: + f.write('[%02d] mse loss: %.5f | residual mse: %.20f (%d)\n' % (epoch, epoch_mse/opt.epoch_size, epoch_mse_residual/opt.epoch_size, epoch*opt.epoch_size*opt.batch_size)) + + # plot some stuff + frame_predictor.eval() + posterior.eval() + prior.eval() + encoder.eval() + decoder.eval() + x = next(testing_batch_generator) + plot(x, epoch) + plot_rec(x, epoch) + + # save the model + torch.save({ + 'encoder': encoder, + 'decoder': decoder, + 'frame_predictor': frame_predictor, + 'posterior': posterior, + 'prior': prior, + 'opt': opt}, + '%s/model_e%02d.pth' % (opt.log_dir, epoch)) + if epoch % 10 == 0: + print('log dir: %s' % opt.log_dir) + + diff --git a/train_baseline_gravity.py b/train_baseline_gravity.py new file mode 100644 index 0000000..f237590 --- /dev/null +++ b/train_baseline_gravity.py @@ -0,0 +1,446 @@ +import glob + +import torch +import torch.optim as optim +import torch.nn as nn +import argparse +import os +import random +from torch.autograd import Variable +from torch.utils.data import DataLoader +import utils +import itertools +import progressbar +import numpy as np +import json + +parser = argparse.ArgumentParser() +parser.add_argument('--lr', default=0.0002, type=float, help='learning rate') +parser.add_argument('--beta1', default=0.9, type=float, help='momentum term for adam') +parser.add_argument('--batch_size', default=13, type=int, help='batch size') +parser.add_argument('--log_dir', default='logs/nonstochastic_posterior', help='base directory to save logs') +parser.add_argument('--model_dir', default='', help='base directory to save logs') +parser.add_argument('--name', default='', help='identifier for directory') +parser.add_argument('--data_root', default='data', help='root directory for 
data') +parser.add_argument('--optimizer', default='adam', help='optimizer to train with') +parser.add_argument('--niter', type=int, default=40, help='number of epochs to train for') +parser.add_argument('--seed', default=1, type=int, help='manual seed') +parser.add_argument('--epoch_size', type=int, default=500, help='epoch size') +parser.add_argument('--image_width', type=int, default=64, help='the height / width of the input image to network') +parser.add_argument('--channels', default=1, type=int, help='number of channels for input images. ') +parser.add_argument('--use_edge_kernels', default=True, type=bool, help='whether to use edge kernels to reduce to 1 channel') +parser.add_argument('--dataset', default='mcs', help='dataset to train with') +parser.add_argument('--mcs_task', default='GravitySupportTraining', help='mcs task') +parser.add_argument('--n_past', type=int, default=5, help='number of frames to condition on') +parser.add_argument('--n_future', type=int, default=25, help='number of frames to predict') +parser.add_argument('--n_eval', type=int, default=30, help='number of frames to predict at eval time') +parser.add_argument('--start_min', type=int, default=7, help='min starting time for sampling sequence (0-indexed)') +parser.add_argument('--start_max', type=int, default=29, help='max starting time for sampling sequence (0-indexed)') +parser.add_argument('--sequence_stride', type=int, default=1, help='factor for sequence temporal subsampling (int)') +parser.add_argument('--rnn_size', type=int, default=256, help='dimensionality of hidden layer') +parser.add_argument('--prior_rnn_layers', type=int, default=1, help='number of layers') +parser.add_argument('--posterior_rnn_layers', type=int, default=1, help='number of layers') +parser.add_argument('--predictor_rnn_layers', type=int, default=2, help='number of layers') +parser.add_argument('--z_dim', type=int, default=32, help='dimensionality of z_t') +parser.add_argument('--g_dim', type=int, 
default=128, help='dimensionality of encoder output vector and decoder input vector') +parser.add_argument('--beta', type=float, default=0.0001, help='weighting on KL to prior') +parser.add_argument('--gamma', type=float, default=0.0001, help='weighting on h vs h posterior') +parser.add_argument('--model', default='vgg', help='model type (dcgan | vgg)') +parser.add_argument('--data_threads', type=int, default=12, help='number of data loading threads') +parser.add_argument('--num_digits', type=int, default=2, help='number of digits for moving mnist') +parser.add_argument('--last_frame_skip', action='store_true', help='if true, skip connections go between frame t and frame t+1 rather than last ground truth frame') + + +opt = parser.parse_args() +saved_model = None +if opt.model_dir != '': + models = glob.glob(f'{opt.model_dir}/model_*.pth') + latest_model = sorted(models, key=lambda s: int(s[s.rfind('_e') + 2: s.rfind('.pth')]), reverse=True)[0] + print('Loading model ', latest_model) + saved_model = torch.load(latest_model) + model_dir = opt.model_dir + niter = opt.niter + lr = opt.lr + batch_size = opt.batch_size + n_future = opt.n_future + n_eval = opt.n_eval + data_root = opt.data_root + opt = saved_model['opt'] + opt.niter = niter # update number of epochs to train for + opt.model_dir = model_dir + opt.n_future = n_future + opt.lr = lr + opt.batch_size = batch_size + opt.n_eval = n_eval + opt.data_root = data_root + opt.log_dir = '%s/continued_lr%s' % (opt.log_dir, opt.lr) +else: + name = 'model=%s%dx%d-rnn_size=%d-predictor-posterior-rnn_layers=%d-%d-n_past=%d-n_future=%d-lr=%.4f-g_dim=%d-z_dim=%d-last_frame_skip=%d-beta=%.7f-gamma=%.7f%s' % (opt.model, opt.image_width, opt.image_width, opt.rnn_size, opt.predictor_rnn_layers, opt.posterior_rnn_layers, opt.n_past, opt.n_future, opt.lr, opt.g_dim, opt.z_dim, opt.last_frame_skip, opt.beta, opt.gamma, opt.name) + if opt.dataset == 'smmnist': + opt.log_dir = '%s/%s-%d/%s' % (opt.log_dir, opt.dataset, opt.num_digits, 
name) + elif opt.dataset == 'mcs': + opt.log_dir = '%s/%s/%s/%s' % (opt.log_dir, opt.dataset, opt.mcs_task, name) + else: + opt.log_dir = '%s/%s/%s' % (opt.log_dir, opt.dataset, name) + +os.makedirs('%s/gen/' % opt.log_dir, exist_ok=False) +os.makedirs('%s/plots/' % opt.log_dir, exist_ok=False) +with open(os.path.join(opt.log_dir, 'opt.json'), 'w') as f: + opt2 = opt.__dict__.copy() + if isinstance(opt2['optimizer'], type): + opt2['optimizer'] = str(opt2['optimizer']) + json.dump(opt2, f, indent=2) + del opt2 +print("Random Seed: ", opt.seed) +random.seed(opt.seed) +torch.manual_seed(opt.seed) +torch.cuda.manual_seed_all(opt.seed) +dtype = torch.cuda.FloatTensor + + +# ---------------- load the models ---------------- + +print(opt) + +# ---------------- optimizers ---------------- +if opt.optimizer == 'adam': + opt.optimizer = optim.Adam +elif opt.optimizer == 'rmsprop': + opt.optimizer = optim.RMSprop +elif opt.optimizer == 'sgd': + opt.optimizer = optim.SGD +elif isinstance(opt.optimizer, type): + pass +else: + raise ValueError('Unknown optimizer: %s' % opt.optimizer) + +import models.lstm as lstm_models +if opt.model_dir != '': + frame_predictor = saved_model['frame_predictor'] + frame_predictor.batch_size = opt.batch_size + prior = saved_model['prior'] + prior.batch_size = opt.batch_size + posterior = saved_model['posterior'] + posterior.batch_size = opt.batch_size +else: + frame_predictor = lstm_models.lstm(opt.g_dim + opt.z_dim, opt.g_dim, opt.rnn_size, opt.predictor_rnn_layers, opt.batch_size) + posterior = lstm_models.lstm(opt.g_dim, opt.z_dim, opt.rnn_size, opt.posterior_rnn_layers, opt.batch_size) + prior = lstm_models.lstm(opt.g_dim, opt.z_dim, opt.rnn_size, opt.prior_rnn_layers, opt.batch_size) + frame_predictor.apply(utils.init_weights) + posterior.apply(utils.init_weights) + prior.apply(utils.init_weights) + +if opt.model == 'dcgan': + if opt.image_width == 64: + import models.dcgan_64 as model + elif opt.image_width == 128: + import models.dcgan_128 
as model +elif opt.model == 'vgg': + if opt.image_width == 64: + import models.vgg_64 as model + elif opt.image_width == 128: + import models.vgg_128 as model +else: + raise ValueError('Unknown model: %s' % opt.model) + +if opt.model_dir != '': + decoder = saved_model['decoder'] + encoder = saved_model['encoder'] +else: + encoder = model.encoder(opt.g_dim, opt.channels) + decoder = model.decoder(opt.g_dim, 1) + encoder.apply(utils.init_weights) + decoder.apply(utils.init_weights) + +if opt.optimizer == optim.SGD: + frame_predictor_optimizer = opt.optimizer(frame_predictor.parameters(), lr=opt.lr, momentum=opt.beta1) + posterior_optimizer = opt.optimizer(posterior.parameters(), lr=opt.lr, momentum=opt.beta1) + prior_optimizer = opt.optimizer(prior.parameters(), lr=opt.lr, momentum=opt.beta1) + encoder_optimizer = opt.optimizer(encoder.parameters(), lr=opt.lr, momentum=opt.beta1) + decoder_optimizer = opt.optimizer(decoder.parameters(), lr=opt.lr, momentum=opt.beta1) +else: + frame_predictor_optimizer = opt.optimizer(frame_predictor.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + posterior_optimizer = opt.optimizer(posterior.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + prior_optimizer = opt.optimizer(prior.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + encoder_optimizer = opt.optimizer(encoder.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + decoder_optimizer = opt.optimizer(decoder.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + +# --------- loss functions ------------------------------------ +mse_criterion = nn.MSELoss() +def kl_criterion(mu, logvar): + # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) + KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) + KLD /= opt.batch_size + return KLD + + +# --------- transfer to gpu ------------------------------------ +frame_predictor.cuda() +posterior.cuda() +prior.cuda() +encoder.cuda() +decoder.cuda() +mse_criterion.cuda() + +# --------- load a dataset 
------------------------------------ +train_data, test_data = utils.load_dataset(opt) + +train_loader = DataLoader(train_data, + num_workers=opt.data_threads, + batch_size=opt.batch_size, + shuffle=True, + drop_last=True, + pin_memory=True,) +test_loader = DataLoader(test_data, + num_workers=opt.data_threads, + batch_size=opt.batch_size, + shuffle=True, + drop_last=True, + pin_memory=True) + +def get_training_batch(): + while True: + for sequence in train_loader: + batch = utils.normalize_data(opt, dtype, sequence) + yield batch +training_batch_generator = get_training_batch() + +def get_testing_batch(): + while True: + for sequence in test_loader: + batch = utils.normalize_data(opt, dtype, sequence) + yield batch +testing_batch_generator = get_testing_batch() + +# --------- plotting funtions ------------------------------------ +def plot(x, epoch): + + nsample = 1 + gen_seq = [[] for _ in range(nsample)] + gt_seq = [utils.torch_rgb_img_to_gray(x[t]) for t in range(len(x))] + + # h_seq = [encoder(x[i]) for i in range(opt.n_past)] + for s in range(nsample): + frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + prior.hidden = prior.init_hidden() + gen_seq[s].append(utils.torch_rgb_img_to_gray(x[0])) + x_in = x[0] + for i in range(1, opt.n_eval): + with torch.no_grad(): + # if input is grayscale + if x_in.shape[1] == 1 and opt.channels == 3: + h = encoder(torch.cat(3*[x_in], dim=1)) # convert to RGB + else: + h = encoder(x_in) + if opt.last_frame_skip or i < opt.n_past: + h, skip = h + else: + h, _ = h + h = h.detach() + if i < opt.n_past: + h_target = encoder(x[i]) + z_t = posterior(h_target[0].detach()) + prior(h) + frame_predictor(torch.cat([h, z_t], 1)) + x_in = x[i] + gen_seq[s].append(utils.torch_rgb_img_to_gray(x_in)) + else: + z_t_hat = prior(h) + h = frame_predictor(torch.cat([h, z_t_hat], 1)).detach() + x_in = decoder([h, skip]) + gen_seq[s].append(x_in) + + + to_plot = [] + gifs = [ [] for t in 
range(opt.n_eval) ] + nrow = min(opt.batch_size, 25) + for i in range(nrow): + # ground truth sequence + row = [] + for t in range(opt.n_eval): + row.append(gt_seq[t][i]) + to_plot.append(row) + + for s in range(nsample): + row = [] + for t in range(opt.n_eval): + row.append(gen_seq[s][t][i]) + to_plot.append(row) + for t in range(opt.n_eval): + row = [] + row.append(gt_seq[t][i]) + for s in range(nsample): + row.append(gen_seq[s][t][i]) + gifs[t].append(row) + + fname = '%s/gen/sample_%d.png' % (opt.log_dir, epoch) + utils.save_tensors_image(fname, to_plot) + + fname = '%s/gen/sample_%d.gif' % (opt.log_dir, epoch) + utils.save_gif(fname, gifs) + + +def plot_rec(x, epoch): + gen_seq = [utils.torch_rgb_img_to_gray(x[0])] + gen_seq_post = [utils.torch_rgb_img_to_gray(x[0])] + + # prediction using posterior Z + frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + h = encoder(x[0]) + h = (h[0].detach(), h[1]) + for i in range(1, min(opt.n_eval, opt.n_past+opt.n_future)): + h_target = encoder(x[i]) + h_target = (h_target[0].detach(), h_target[1]) + if opt.last_frame_skip or i < opt.n_past: + h, skip = h + else: + h, _ = h + z_t = posterior(h_target[0]) + if i < opt.n_past: + frame_predictor(torch.cat([h, z_t], 1)) + gen_seq_post.append(utils.torch_rgb_img_to_gray(x[i])) + else: + h_pred = frame_predictor(torch.cat([h, z_t], 1)).detach() + x_pred = decoder([h_pred, skip]).detach() + gen_seq_post.append(x_pred) + h = h_target + + # prediction using prior Z + frame_predictor.hidden = frame_predictor.init_hidden() + prior.hidden = prior.init_hidden() + h = encoder(x[0]) + h = (h[0].detach(), h[1]) + for i in range(1, min(opt.n_eval, opt.n_past+opt.n_future)): + h_target = encoder(x[i]) + h_target = (h_target[0].detach(), h_target[1]) + if opt.last_frame_skip or i < opt.n_past: + h, skip = h + else: + h, _ = h + h = h.detach() + z_t_hat = prior(h) + if i < opt.n_past: + frame_predictor(torch.cat([h, z_t_hat], 1)) + 
gen_seq.append(utils.torch_rgb_img_to_gray(x[i])) + else: + h_pred = frame_predictor(torch.cat([h, z_t_hat], 1)).detach() + x_pred = decoder([h_pred, skip]).detach() + gen_seq.append(x_pred) + h = h_target + + to_plot = [] + nrow = min(opt.batch_size, 25) + x_gray = [utils.torch_rgb_img_to_gray(x[t]) for t in range(min(opt.n_eval, opt.n_past+opt.n_future))] + for i in range(nrow): + row_gt = [] + row_post = [] + row = [] + for t in range(min(opt.n_eval, opt.n_past+opt.n_future)): + row_gt.append(x_gray[t][i]) + row_post.append(gen_seq_post[t][i]) + row.append(gen_seq[t][i]) + to_plot.append(row_gt) + to_plot.append(row_post) + to_plot.append(row) + fname = '%s/gen/rec_%d.png' % (opt.log_dir, epoch) + utils.save_tensors_image(fname, to_plot) + + +# --------- training funtions ------------------------------------ +def train(x): + frame_predictor.zero_grad() + posterior.zero_grad() + prior.zero_grad() + encoder.zero_grad() + decoder.zero_grad() + + # initialize the hidden state. + frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + prior.hidden = prior.init_hidden() + + mse = 0 + mse_residual = 0 + h = encoder(x[0]) + for i in range(1, opt.n_past+opt.n_future): + h_target = encoder(x[i]) + if opt.last_frame_skip or i < opt.n_past: + h, skip = h + else: + h = h[0] + + z_t = posterior(h_target[0]) + z_t_hat = prior(h) + h_pred = frame_predictor(torch.cat([h, z_t], 1)) + x_pred = decoder([h_pred, skip]) + gray_target_frame = utils.torch_rgb_img_to_gray(x[i]) + mse += mse_criterion(x_pred, gray_target_frame) + # penalize prior for being far from posterior + mse_residual += opt.gamma * torch.mean(torch.square(z_t.detach() - z_t_hat)) + h = h_target + + loss = mse + mse_residual + loss.backward() + + frame_predictor_optimizer.step() + posterior_optimizer.step() + prior_optimizer.step() + encoder_optimizer.step() + decoder_optimizer.step() + + N = opt.n_past+opt.n_future + return mse.data.cpu().numpy()/N, 
mse_residual.data.cpu().numpy()/N + +# --------- training loop ------------------------------------ +for epoch in range(opt.niter): + frame_predictor.train() + posterior.train() + prior.train() + encoder.train() + decoder.train() + epoch_mse = 0 + epoch_mse_residual = 0 + progress = progressbar.ProgressBar(max_value=opt.epoch_size).start() + # opt.epoch_size = 10 + for i in range(opt.epoch_size): + progress.update(i+1) + x = next(training_batch_generator) + + # train frame_predictor + mse, mse_residual = train(x) + epoch_mse += mse + epoch_mse_residual += mse_residual + + + progress.finish() + utils.clear_progressbar() + + print('[%02d] mse loss: %.5f | residual mse: %.20f (%d)' % (epoch, epoch_mse/opt.epoch_size, epoch_mse_residual/opt.epoch_size, epoch*opt.epoch_size*opt.batch_size)) + with open(os.path.join(opt.log_dir, 'loss.txt'), 'a') as f: + f.write('[%02d] mse loss: %.5f | residual mse: %.20f (%d)\n' % (epoch, epoch_mse/opt.epoch_size, epoch_mse_residual/opt.epoch_size, epoch*opt.epoch_size*opt.batch_size)) + + # plot some stuff + frame_predictor.eval() + posterior.eval() + prior.eval() + encoder.eval() + decoder.eval() + x = next(testing_batch_generator) + plot(x, epoch) + plot_rec(x, epoch) + + # save the model + torch.save({ + 'encoder': encoder, + 'decoder': decoder, + 'frame_predictor': frame_predictor, + 'posterior': posterior, + 'prior': prior, + 'opt': opt}, + '%s/model_e%02d.pth' % (opt.log_dir, epoch)) + if epoch % 10 == 0: + print('log dir: %s' % opt.log_dir) + + diff --git a/train_baseline_object_permanence.py b/train_baseline_object_permanence.py new file mode 100644 index 0000000..085838e --- /dev/null +++ b/train_baseline_object_permanence.py @@ -0,0 +1,448 @@ +import glob + +import torch +import torch.optim as optim +import torch.nn as nn +import argparse +import os +import random +from torch.autograd import Variable +from torch.utils.data import DataLoader +import utils +import itertools +import progressbar +import numpy as np +import json 
+ +parser = argparse.ArgumentParser() +parser.add_argument('--lr', default=0.0002, type=float, help='learning rate') +parser.add_argument('--beta1', default=0.9, type=float, help='momentum term for adam') +parser.add_argument('--batch_size', default=10, type=int, help='batch size') +parser.add_argument('--log_dir', default='logs/nonstochastic_posterior', help='base directory to save logs') +parser.add_argument('--model_dir', default='', help='base directory to save logs') +parser.add_argument('--name', default='', help='identifier for directory') +parser.add_argument('--data_root', default='data', help='root directory for data') +parser.add_argument('--optimizer', default='adam', help='optimizer to train with') +parser.add_argument('--niter', type=int, default=40, help='number of epochs to train for') +parser.add_argument('--seed', default=1, type=int, help='manual seed') +parser.add_argument('--epoch_size', type=int, default=500, help='epoch size') +parser.add_argument('--image_width', type=int, default=64, help='the height / width of the input image to network') +parser.add_argument('--channels', default=1, type=int, help='number of channels for input images. 
') +parser.add_argument('--use_edge_kernels', default=True, type=bool, help='whether to use edge kernels to reduce to 1 channel') +parser.add_argument('--dataset', default='mcs', help='dataset to train with') +parser.add_argument('--mcs_task', default='ObjectPermanenceTraining4', help='mcs task') +parser.add_argument('--n_past', type=int, default=5, help='number of frames to condition on') +parser.add_argument('--n_future', type=int, default=35, help='number of frames to predict') +parser.add_argument('--n_eval', type=int, default=40, help='number of frames to predict at eval time') +parser.add_argument('--start_min', type=int, default=75, help='min starting time for sampling sequence (0-indexed)') +parser.add_argument('--start_max', type=int, default=77, help='max starting time for sampling sequence (0-indexed)') +parser.add_argument('--sequence_stride', type=int, default=1, help='factor for sequence temporal subsampling (int)') +parser.add_argument('--reduce_static_frames', type=bool, default=True, help='reduce number of static frames') +# parser.add_argument('--lifting_frame_index', type=int, default=200, help='index of frame when panels are lifted') +parser.add_argument('--rnn_size', type=int, default=256, help='dimensionality of hidden layer') +parser.add_argument('--prior_rnn_layers', type=int, default=1, help='number of layers') +parser.add_argument('--posterior_rnn_layers', type=int, default=1, help='number of layers') +parser.add_argument('--predictor_rnn_layers', type=int, default=2, help='number of layers') +parser.add_argument('--z_dim', type=int, default=32, help='dimensionality of z_t') +parser.add_argument('--g_dim', type=int, default=128, help='dimensionality of encoder output vector and decoder input vector') +parser.add_argument('--beta', type=float, default=0.0001, help='weighting on KL to prior') +parser.add_argument('--gamma', type=float, default=0.0001, help='weighting on h vs h posterior') +parser.add_argument('--model', default='vgg', 
help='model type (dcgan | vgg)') +parser.add_argument('--data_threads', type=int, default=12, help='number of data loading threads') +parser.add_argument('--num_digits', type=int, default=2, help='number of digits for moving mnist') +parser.add_argument('--last_frame_skip', action='store_true', help='if true, skip connections go between frame t and frame t+1 rather than last ground truth frame') + + +opt = parser.parse_args() +saved_model = None +if opt.model_dir != '': + models = glob.glob(f'{opt.model_dir}/model_*.pth') + latest_model = sorted(models, key=lambda s: int(s[s.rfind('_e') + 2: s.rfind('.pth')]), reverse=True)[0] + print('Loading model ', latest_model) + saved_model = torch.load(latest_model) + model_dir = opt.model_dir + niter = opt.niter + lr = opt.lr + batch_size = opt.batch_size + n_future = opt.n_future + n_eval = opt.n_eval + data_root = opt.data_root + opt = saved_model['opt'] + opt.niter = niter # update number of epochs to train for + opt.model_dir = model_dir + opt.n_future = n_future + opt.lr = lr + opt.batch_size = batch_size + opt.n_eval = n_eval + opt.data_root = data_root + opt.log_dir = '%s/continued_lr%s' % (opt.log_dir, opt.lr) +else: + name = 'model=%s%dx%d-rnn_size=%d-predictor-posterior-rnn_layers=%d-%d-n_past=%d-n_future=%d-lr=%.4f-g_dim=%d-z_dim=%d-last_frame_skip=%d-beta=%.7f-gamma=%.7f%s' % (opt.model, opt.image_width, opt.image_width, opt.rnn_size, opt.predictor_rnn_layers, opt.posterior_rnn_layers, opt.n_past, opt.n_future, opt.lr, opt.g_dim, opt.z_dim, opt.last_frame_skip, opt.beta, opt.gamma, opt.name) + if opt.dataset == 'smmnist': + opt.log_dir = '%s/%s-%d/%s' % (opt.log_dir, opt.dataset, opt.num_digits, name) + elif opt.dataset == 'mcs': + opt.log_dir = '%s/%s/%s/%s' % (opt.log_dir, opt.dataset, opt.mcs_task, name) + else: + opt.log_dir = '%s/%s/%s' % (opt.log_dir, opt.dataset, name) + +os.makedirs('%s/gen/' % opt.log_dir, exist_ok=False) +os.makedirs('%s/plots/' % opt.log_dir, exist_ok=False) +with 
open(os.path.join(opt.log_dir, 'opt.json'), 'w') as f: + opt2 = opt.__dict__.copy() + if isinstance(opt2['optimizer'], type): + opt2['optimizer'] = str(opt2['optimizer']) + json.dump(opt2, f, indent=2) + del opt2 + +print("Random Seed: ", opt.seed) +random.seed(opt.seed) +torch.manual_seed(opt.seed) +torch.cuda.manual_seed_all(opt.seed) +dtype = torch.cuda.FloatTensor + + +# ---------------- load the models ---------------- + +print(opt) + +# ---------------- optimizers ---------------- +if opt.optimizer == 'adam': + opt.optimizer = optim.Adam +elif opt.optimizer == 'rmsprop': + opt.optimizer = optim.RMSprop +elif opt.optimizer == 'sgd': + opt.optimizer = optim.SGD +elif isinstance(opt.optimizer, type): + pass +else: + raise ValueError('Unknown optimizer: %s' % opt.optimizer) + +import models.lstm as lstm_models +if opt.model_dir != '': + frame_predictor = saved_model['frame_predictor'] + frame_predictor.batch_size = opt.batch_size + prior = saved_model['prior'] + prior.batch_size = opt.batch_size + posterior = saved_model['posterior'] + posterior.batch_size = opt.batch_size +else: + frame_predictor = lstm_models.lstm(opt.g_dim + opt.z_dim, opt.g_dim, opt.rnn_size, opt.predictor_rnn_layers, opt.batch_size) + posterior = lstm_models.lstm(opt.g_dim, opt.z_dim, opt.rnn_size, opt.posterior_rnn_layers, opt.batch_size) + prior = lstm_models.lstm(opt.g_dim, opt.z_dim, opt.rnn_size, opt.prior_rnn_layers, opt.batch_size) + frame_predictor.apply(utils.init_weights) + posterior.apply(utils.init_weights) + prior.apply(utils.init_weights) + +if opt.model == 'dcgan': + if opt.image_width == 64: + import models.dcgan_64 as model + elif opt.image_width == 128: + import models.dcgan_128 as model +elif opt.model == 'vgg': + if opt.image_width == 64: + import models.vgg_64 as model + elif opt.image_width == 128: + import models.vgg_128 as model +else: + raise ValueError('Unknown model: %s' % opt.model) + +if opt.model_dir != '': + decoder = saved_model['decoder'] + encoder = 
saved_model['encoder'] +else: + encoder = model.encoder(opt.g_dim, opt.channels) + decoder = model.decoder(opt.g_dim, 1) + encoder.apply(utils.init_weights) + decoder.apply(utils.init_weights) + +if opt.optimizer == optim.SGD: + frame_predictor_optimizer = opt.optimizer(frame_predictor.parameters(), lr=opt.lr, momentum=opt.beta1) + posterior_optimizer = opt.optimizer(posterior.parameters(), lr=opt.lr, momentum=opt.beta1) + prior_optimizer = opt.optimizer(prior.parameters(), lr=opt.lr, momentum=opt.beta1) + encoder_optimizer = opt.optimizer(encoder.parameters(), lr=opt.lr, momentum=opt.beta1) + decoder_optimizer = opt.optimizer(decoder.parameters(), lr=opt.lr, momentum=opt.beta1) +else: + frame_predictor_optimizer = opt.optimizer(frame_predictor.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + posterior_optimizer = opt.optimizer(posterior.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + prior_optimizer = opt.optimizer(prior.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + encoder_optimizer = opt.optimizer(encoder.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + decoder_optimizer = opt.optimizer(decoder.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + +# --------- loss functions ------------------------------------ +mse_criterion = nn.MSELoss() +def kl_criterion(mu, logvar): + # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) + KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) + KLD /= opt.batch_size + return KLD + + +# --------- transfer to gpu ------------------------------------ +frame_predictor.cuda() +posterior.cuda() +prior.cuda() +encoder.cuda() +decoder.cuda() +mse_criterion.cuda() + +# --------- load a dataset ------------------------------------ +train_data, test_data = utils.load_dataset(opt) + +train_loader = DataLoader(train_data, + num_workers=opt.data_threads, + batch_size=opt.batch_size, + shuffle=True, + drop_last=True, + pin_memory=True,) +test_loader = DataLoader(test_data, + num_workers=opt.data_threads, + 
batch_size=opt.batch_size, + shuffle=True, + drop_last=True, + pin_memory=True) + +def get_training_batch(): + while True: + for sequence in train_loader: + batch = utils.normalize_data(opt, dtype, sequence) + yield batch +training_batch_generator = get_training_batch() + +def get_testing_batch(): + while True: + for sequence in test_loader: + batch = utils.normalize_data(opt, dtype, sequence) + yield batch +testing_batch_generator = get_testing_batch() + +# --------- plotting funtions ------------------------------------ +def plot(x, epoch): + + nsample = 1 + gen_seq = [[] for _ in range(nsample)] + gt_seq = [utils.torch_rgb_img_to_gray(x[t]) for t in range(len(x))] + + # h_seq = [encoder(x[i]) for i in range(opt.n_past)] + for s in range(nsample): + frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + prior.hidden = prior.init_hidden() + gen_seq[s].append(utils.torch_rgb_img_to_gray(x[0])) + x_in = x[0] + for i in range(1, opt.n_eval): + with torch.no_grad(): + # if input is grayscale + if x_in.shape[1] == 1 and opt.channels == 3: + h = encoder(torch.cat(3*[x_in], dim=1)) # convert to RGB + else: + h = encoder(x_in) + if opt.last_frame_skip or i < opt.n_past: + h, skip = h + else: + h, _ = h + h = h.detach() + if i < opt.n_past: + h_target = encoder(x[i]) + z_t = posterior(h_target[0].detach()) + prior(h) + frame_predictor(torch.cat([h, z_t], 1)) + x_in = x[i] + gen_seq[s].append(utils.torch_rgb_img_to_gray(x_in)) + else: + z_t_hat = prior(h) + h = frame_predictor(torch.cat([h, z_t_hat], 1)).detach() + x_in = decoder([h, skip]) + gen_seq[s].append(x_in) + + + to_plot = [] + gifs = [ [] for t in range(opt.n_eval) ] + nrow = min(opt.batch_size, 25) + for i in range(nrow): + # ground truth sequence + row = [] + for t in range(opt.n_eval): + row.append(gt_seq[t][i]) + to_plot.append(row) + + for s in range(nsample): + row = [] + for t in range(opt.n_eval): + row.append(gen_seq[s][t][i]) + to_plot.append(row) + for t in 
range(opt.n_eval): + row = [] + row.append(gt_seq[t][i]) + for s in range(nsample): + row.append(gen_seq[s][t][i]) + gifs[t].append(row) + + fname = '%s/gen/sample_%d.png' % (opt.log_dir, epoch) + utils.save_tensors_image(fname, to_plot) + + fname = '%s/gen/sample_%d.gif' % (opt.log_dir, epoch) + utils.save_gif(fname, gifs) + + +def plot_rec(x, epoch): + gen_seq = [utils.torch_rgb_img_to_gray(x[0])] + gen_seq_post = [utils.torch_rgb_img_to_gray(x[0])] + + # prediction using posterior Z + frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + h = encoder(x[0]) + h = (h[0].detach(), h[1]) + for i in range(1, min(opt.n_eval, opt.n_past+opt.n_future)): + h_target = encoder(x[i]) + h_target = (h_target[0].detach(), h_target[1]) + if opt.last_frame_skip or i < opt.n_past: + h, skip = h + else: + h, _ = h + z_t = posterior(h_target[0]) + if i < opt.n_past: + frame_predictor(torch.cat([h, z_t], 1)) + gen_seq_post.append(utils.torch_rgb_img_to_gray(x[i])) + else: + h_pred = frame_predictor(torch.cat([h, z_t], 1)).detach() + x_pred = decoder([h_pred, skip]).detach() + gen_seq_post.append(x_pred) + h = h_target + + # prediction using prior Z + frame_predictor.hidden = frame_predictor.init_hidden() + prior.hidden = prior.init_hidden() + h = encoder(x[0]) + h = (h[0].detach(), h[1]) + for i in range(1, min(opt.n_eval, opt.n_past+opt.n_future)): + h_target = encoder(x[i]) + h_target = (h_target[0].detach(), h_target[1]) + if opt.last_frame_skip or i < opt.n_past: + h, skip = h + else: + h, _ = h + h = h.detach() + z_t_hat = prior(h) + if i < opt.n_past: + frame_predictor(torch.cat([h, z_t_hat], 1)) + gen_seq.append(utils.torch_rgb_img_to_gray(x[i])) + else: + h_pred = frame_predictor(torch.cat([h, z_t_hat], 1)).detach() + x_pred = decoder([h_pred, skip]).detach() + gen_seq.append(x_pred) + h = h_target + + to_plot = [] + nrow = min(opt.batch_size, 25) + x_gray = [utils.torch_rgb_img_to_gray(x[t]) for t in range(min(opt.n_eval, 
opt.n_past+opt.n_future))] + for i in range(nrow): + row_gt = [] + row_post = [] + row = [] + for t in range(min(opt.n_eval, opt.n_past+opt.n_future)): + row_gt.append(x_gray[t][i]) + row_post.append(gen_seq_post[t][i]) + row.append(gen_seq[t][i]) + to_plot.append(row_gt) + to_plot.append(row_post) + to_plot.append(row) + fname = '%s/gen/rec_%d.png' % (opt.log_dir, epoch) + utils.save_tensors_image(fname, to_plot) + + +# --------- training funtions ------------------------------------ +def train(x): + frame_predictor.zero_grad() + posterior.zero_grad() + prior.zero_grad() + encoder.zero_grad() + decoder.zero_grad() + + # initialize the hidden state. + frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + prior.hidden = prior.init_hidden() + + mse = 0 + mse_residual = 0 + h = encoder(x[0]) + for i in range(1, opt.n_past+opt.n_future): + h_target = encoder(x[i]) + if opt.last_frame_skip or i < opt.n_past: + h, skip = h + else: + h = h[0] + + z_t = posterior(h_target[0]) + z_t_hat = prior(h) + h_pred = frame_predictor(torch.cat([h, z_t], 1)) + x_pred = decoder([h_pred, skip]) + gray_target_frame = utils.torch_rgb_img_to_gray(x[i]) + mse += mse_criterion(x_pred, gray_target_frame) + # penalize prior for being far from posterior + mse_residual += opt.gamma * torch.mean(torch.square(z_t.detach() - z_t_hat)) + h = h_target + + loss = mse + mse_residual + loss.backward() + + frame_predictor_optimizer.step() + posterior_optimizer.step() + prior_optimizer.step() + encoder_optimizer.step() + decoder_optimizer.step() + + N = opt.n_past+opt.n_future + return mse.data.cpu().numpy()/N, mse_residual.data.cpu().numpy()/N + +# --------- training loop ------------------------------------ +for epoch in range(opt.niter): + frame_predictor.train() + posterior.train() + prior.train() + encoder.train() + decoder.train() + epoch_mse = 0 + epoch_mse_residual = 0 + progress = progressbar.ProgressBar(max_value=opt.epoch_size).start() + # 
opt.epoch_size = 10 + for i in range(opt.epoch_size): + progress.update(i+1) + x = next(training_batch_generator) + + # train frame_predictor + mse, mse_residual = train(x) + epoch_mse += mse + epoch_mse_residual += mse_residual + + + progress.finish() + utils.clear_progressbar() + + print('[%02d] mse loss: %.5f | residual mse: %.20f (%d)' % (epoch, epoch_mse/opt.epoch_size, epoch_mse_residual/opt.epoch_size, epoch*opt.epoch_size*opt.batch_size)) + with open(os.path.join(opt.log_dir, 'loss.txt'), 'a') as f: + f.write('[%02d] mse loss: %.5f | residual mse: %.20f (%d)\n' % (epoch, epoch_mse/opt.epoch_size, epoch_mse_residual/opt.epoch_size, epoch*opt.epoch_size*opt.batch_size)) + # plot some stuff + frame_predictor.eval() + posterior.eval() + prior.eval() + encoder.eval() + decoder.eval() + x = next(testing_batch_generator) + plot(x, epoch) + plot_rec(x, epoch) + + # save the model + torch.save({ + 'encoder': encoder, + 'decoder': decoder, + 'frame_predictor': frame_predictor, + 'posterior': posterior, + 'prior': prior, + 'opt': opt}, + '%s/model_e%02d.pth' % (opt.log_dir, epoch)) + if epoch % 10 == 0: + print('log dir: %s' % opt.log_dir) + + diff --git a/train_baseline_shapeconstancy.py b/train_baseline_shapeconstancy.py new file mode 100644 index 0000000..4e7d44f --- /dev/null +++ b/train_baseline_shapeconstancy.py @@ -0,0 +1,449 @@ +import glob + +import torch +import torch.optim as optim +import torch.nn as nn +import argparse +import os +import random +from torch.autograd import Variable +from torch.utils.data import DataLoader +import utils +import itertools +import progressbar +import numpy as np +import json + +parser = argparse.ArgumentParser() +parser.add_argument('--lr', default=0.0002, type=float, help='learning rate') +parser.add_argument('--beta1', default=0.9, type=float, help='momentum term for adam') +parser.add_argument('--batch_size', default=13, type=int, help='batch size') +parser.add_argument('--log_dir', default='logs/nonstochastic_posterior', 
help='base directory to save logs') +parser.add_argument('--model_dir', default='', help='base directory to save logs') +parser.add_argument('--name', default='', help='identifier for directory') +parser.add_argument('--data_root', default='data', help='root directory for data') +parser.add_argument('--optimizer', default='adam', help='optimizer to train with') +parser.add_argument('--niter', type=int, default=40, help='number of epochs to train for') +parser.add_argument('--seed', default=1, type=int, help='manual seed') +parser.add_argument('--epoch_size', type=int, default=500, help='epoch size') +parser.add_argument('--image_width', type=int, default=64, help='the height / width of the input image to network') +parser.add_argument('--channels', default=1, type=int, help='number of channels for input images. ') +parser.add_argument('--use_edge_kernels', default=True, type=bool, help='whether to use edge kernels to reduce to 1 channel') +parser.add_argument('--dataset', default='mcs', help='dataset to train with') +parser.add_argument('--mcs_task', default='ShapeConstancyTraining', help='mcs task') +parser.add_argument('--n_past', type=int, default=5, help='number of frames to condition on') +parser.add_argument('--n_future', type=int, default=21, help='number of frames to predict') +parser.add_argument('--n_eval', type=int, default=26, help='number of frames to predict at eval time') +parser.add_argument('--start_min', type=int, default=70, help='min starting time for sampling sequence (0-indexed)') +parser.add_argument('--start_max', type=int, default=75, help='max starting time for sampling sequence (0-indexed)') +parser.add_argument('--sequence_stride', type=int, default=1, help='factor for sequence temporal subsampling (int)') +parser.add_argument('--reduce_static_frames', type=bool, default=True, help='reduce number of static frames') +parser.add_argument('--is_shape_constancy', type=bool, default=True, help='flag indicating using shape constancy dataset') 
+parser.add_argument('--rnn_size', type=int, default=256, help='dimensionality of hidden layer') +parser.add_argument('--prior_rnn_layers', type=int, default=1, help='number of layers') +parser.add_argument('--posterior_rnn_layers', type=int, default=1, help='number of layers') +parser.add_argument('--predictor_rnn_layers', type=int, default=2, help='number of layers') +parser.add_argument('--z_dim', type=int, default=32, help='dimensionality of z_t') +parser.add_argument('--g_dim', type=int, default=128, help='dimensionality of encoder output vector and decoder input vector') +parser.add_argument('--beta', type=float, default=0.0001, help='weighting on KL to prior') +parser.add_argument('--gamma', type=float, default=0.0001, help='weighting on h vs h posterior') +parser.add_argument('--model', default='vgg', help='model type (dcgan | vgg)') +parser.add_argument('--data_threads', type=int, default=12, help='number of data loading threads') +parser.add_argument('--num_digits', type=int, default=2, help='number of digits for moving mnist') +parser.add_argument('--last_frame_skip', action='store_true', help='if true, skip connections go between frame t and frame t+1 rather than last ground truth frame') + + +opt = parser.parse_args() +saved_model = None +if opt.model_dir != '': + models = glob.glob(f'{opt.model_dir}/model_*.pth') + latest_model = sorted(models, key=lambda s: int(s[s.rfind('_e') + 2: s.rfind('.pth')]), reverse=True)[0] + print('Loading model ', latest_model) + saved_model = torch.load(latest_model) + model_dir = opt.model_dir + niter = opt.niter + lr = opt.lr + batch_size = opt.batch_size + n_future = opt.n_future + n_eval = opt.n_eval + data_root = opt.data_root + opt = saved_model['opt'] + opt.niter = niter # update number of epochs to train for + opt.model_dir = model_dir + opt.n_future = n_future + opt.lr = lr + opt.batch_size = batch_size + opt.n_eval = n_eval + opt.data_root = data_root + opt.log_dir = '%s/continued_lr%s' % (opt.log_dir, opt.lr) 
+else: + name = 'model=%s%dx%d-rnn_size=%d-predictor-posterior-rnn_layers=%d-%d-n_past=%d-n_future=%d-lr=%.4f-g_dim=%d-z_dim=%d-last_frame_skip=%d-beta=%.7f-gamma=%.7f%s' % (opt.model, opt.image_width, opt.image_width, opt.rnn_size, opt.predictor_rnn_layers, opt.posterior_rnn_layers, opt.n_past, opt.n_future, opt.lr, opt.g_dim, opt.z_dim, opt.last_frame_skip, opt.beta, opt.gamma, opt.name) + if opt.dataset == 'smmnist': + opt.log_dir = '%s/%s-%d/%s' % (opt.log_dir, opt.dataset, opt.num_digits, name) + elif opt.dataset == 'mcs': + opt.log_dir = '%s/%s/%s/%s' % (opt.log_dir, opt.dataset, opt.mcs_task, name) + else: + opt.log_dir = '%s/%s/%s' % (opt.log_dir, opt.dataset, name) + +os.makedirs('%s/gen/' % opt.log_dir, exist_ok=False) +os.makedirs('%s/plots/' % opt.log_dir, exist_ok=False) +with open(os.path.join(opt.log_dir, 'opt.json'), 'w') as f: + opt2 = opt.__dict__.copy() + if isinstance(opt2['optimizer'], type): + opt2['optimizer'] = str(opt2['optimizer']) + json.dump(opt2, f, indent=2) + del opt2 + +print("Random Seed: ", opt.seed) +random.seed(opt.seed) +torch.manual_seed(opt.seed) +torch.cuda.manual_seed_all(opt.seed) +dtype = torch.cuda.FloatTensor + + +# ---------------- load the models ---------------- + +print(opt) + +# ---------------- optimizers ---------------- +if opt.optimizer == 'adam': + opt.optimizer = optim.Adam +elif opt.optimizer == 'rmsprop': + opt.optimizer = optim.RMSprop +elif opt.optimizer == 'sgd': + opt.optimizer = optim.SGD +elif isinstance(opt.optimizer, type): + pass +else: + raise ValueError('Unknown optimizer: %s' % opt.optimizer) + +import models.lstm as lstm_models +if opt.model_dir != '': + frame_predictor = saved_model['frame_predictor'] + frame_predictor.batch_size = opt.batch_size + prior = saved_model['prior'] + prior.batch_size = opt.batch_size + posterior = saved_model['posterior'] + posterior.batch_size = opt.batch_size +else: + frame_predictor = lstm_models.lstm(opt.g_dim + opt.z_dim, opt.g_dim, opt.rnn_size, 
opt.predictor_rnn_layers, opt.batch_size) + posterior = lstm_models.lstm(opt.g_dim, opt.z_dim, opt.rnn_size, opt.posterior_rnn_layers, opt.batch_size) + prior = lstm_models.lstm(opt.g_dim, opt.z_dim, opt.rnn_size, opt.prior_rnn_layers, opt.batch_size) + frame_predictor.apply(utils.init_weights) + posterior.apply(utils.init_weights) + prior.apply(utils.init_weights) + +if opt.model == 'dcgan': + if opt.image_width == 64: + import models.dcgan_64 as model + elif opt.image_width == 128: + import models.dcgan_128 as model +elif opt.model == 'vgg': + if opt.image_width == 64: + import models.vgg_64 as model + elif opt.image_width == 128: + import models.vgg_128 as model +else: + raise ValueError('Unknown model: %s' % opt.model) + +if opt.model_dir != '': + decoder = saved_model['decoder'] + encoder = saved_model['encoder'] +else: + encoder = model.encoder(opt.g_dim, opt.channels) + decoder = model.decoder(opt.g_dim, 1) + encoder.apply(utils.init_weights) + decoder.apply(utils.init_weights) + +if opt.optimizer == optim.SGD: + frame_predictor_optimizer = opt.optimizer(frame_predictor.parameters(), lr=opt.lr, momentum=opt.beta1) + posterior_optimizer = opt.optimizer(posterior.parameters(), lr=opt.lr, momentum=opt.beta1) + prior_optimizer = opt.optimizer(prior.parameters(), lr=opt.lr, momentum=opt.beta1) + encoder_optimizer = opt.optimizer(encoder.parameters(), lr=opt.lr, momentum=opt.beta1) + decoder_optimizer = opt.optimizer(decoder.parameters(), lr=opt.lr, momentum=opt.beta1) +else: + frame_predictor_optimizer = opt.optimizer(frame_predictor.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + posterior_optimizer = opt.optimizer(posterior.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + prior_optimizer = opt.optimizer(prior.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + encoder_optimizer = opt.optimizer(encoder.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + decoder_optimizer = opt.optimizer(decoder.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + +# 
--------- loss functions ------------------------------------ +mse_criterion = nn.MSELoss() +def kl_criterion(mu, logvar): + # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) + KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) + KLD /= opt.batch_size + return KLD + + +# --------- transfer to gpu ------------------------------------ +frame_predictor.cuda() +posterior.cuda() +prior.cuda() +encoder.cuda() +decoder.cuda() +mse_criterion.cuda() + +# --------- load a dataset ------------------------------------ +train_data, test_data = utils.load_dataset(opt) + +train_loader = DataLoader(train_data, + num_workers=opt.data_threads, + batch_size=opt.batch_size, + shuffle=True, + drop_last=True, + pin_memory=True,) +test_loader = DataLoader(test_data, + num_workers=opt.data_threads, + batch_size=opt.batch_size, + shuffle=True, + drop_last=True, + pin_memory=True) + +def get_training_batch(): + while True: + for sequence in train_loader: + batch = utils.normalize_data(opt, dtype, sequence) + yield batch +training_batch_generator = get_training_batch() + +def get_testing_batch(): + while True: + for sequence in test_loader: + batch = utils.normalize_data(opt, dtype, sequence) + yield batch +testing_batch_generator = get_testing_batch() + +# --------- plotting funtions ------------------------------------ +def plot(x, epoch): + + nsample = 1 + gen_seq = [[] for _ in range(nsample)] + gt_seq = [utils.torch_rgb_img_to_gray(x[t]) for t in range(len(x))] + + # h_seq = [encoder(x[i]) for i in range(opt.n_past)] + for s in range(nsample): + frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + prior.hidden = prior.init_hidden() + gen_seq[s].append(utils.torch_rgb_img_to_gray(x[0])) + x_in = x[0] + for i in range(1, opt.n_eval): + with torch.no_grad(): + # if input is grayscale + if x_in.shape[1] == 1 and opt.channels == 3: + h = encoder(torch.cat(3*[x_in], dim=1)) # convert to RGB + else: + h = encoder(x_in) + if 
opt.last_frame_skip or i < opt.n_past: + h, skip = h + else: + h, _ = h + h = h.detach() + if i < opt.n_past: + h_target = encoder(x[i]) + z_t = posterior(h_target[0].detach()) + prior(h) + frame_predictor(torch.cat([h, z_t], 1)) + x_in = x[i] + gen_seq[s].append(utils.torch_rgb_img_to_gray(x_in)) + else: + z_t_hat = prior(h) + h = frame_predictor(torch.cat([h, z_t_hat], 1)).detach() + x_in = decoder([h, skip]) + gen_seq[s].append(x_in) + + + to_plot = [] + gifs = [ [] for t in range(opt.n_eval) ] + nrow = min(opt.batch_size, 25) + for i in range(nrow): + # ground truth sequence + row = [] + for t in range(opt.n_eval): + row.append(gt_seq[t][i]) + to_plot.append(row) + + for s in range(nsample): + row = [] + for t in range(opt.n_eval): + row.append(gen_seq[s][t][i]) + to_plot.append(row) + for t in range(opt.n_eval): + row = [] + row.append(gt_seq[t][i]) + for s in range(nsample): + row.append(gen_seq[s][t][i]) + gifs[t].append(row) + + fname = '%s/gen/sample_%d.png' % (opt.log_dir, epoch) + utils.save_tensors_image(fname, to_plot) + + fname = '%s/gen/sample_%d.gif' % (opt.log_dir, epoch) + utils.save_gif(fname, gifs) + + +def plot_rec(x, epoch): + gen_seq = [utils.torch_rgb_img_to_gray(x[0])] + gen_seq_post = [utils.torch_rgb_img_to_gray(x[0])] + + # prediction using posterior Z + frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + h = encoder(x[0]) + h = (h[0].detach(), h[1]) + for i in range(1, min(opt.n_eval, opt.n_past+opt.n_future)): + h_target = encoder(x[i]) + h_target = (h_target[0].detach(), h_target[1]) + if opt.last_frame_skip or i < opt.n_past: + h, skip = h + else: + h, _ = h + z_t = posterior(h_target[0]) + if i < opt.n_past: + frame_predictor(torch.cat([h, z_t], 1)) + gen_seq_post.append(utils.torch_rgb_img_to_gray(x[i])) + else: + h_pred = frame_predictor(torch.cat([h, z_t], 1)).detach() + x_pred = decoder([h_pred, skip]).detach() + gen_seq_post.append(x_pred) + h = h_target + + # prediction using 
prior Z + frame_predictor.hidden = frame_predictor.init_hidden() + prior.hidden = prior.init_hidden() + h = encoder(x[0]) + h = (h[0].detach(), h[1]) + for i in range(1, min(opt.n_eval, opt.n_past+opt.n_future)): + h_target = encoder(x[i]) + h_target = (h_target[0].detach(), h_target[1]) + if opt.last_frame_skip or i < opt.n_past: + h, skip = h + else: + h, _ = h + h = h.detach() + z_t_hat = prior(h) + if i < opt.n_past: + frame_predictor(torch.cat([h, z_t_hat], 1)) + gen_seq.append(utils.torch_rgb_img_to_gray(x[i])) + else: + h_pred = frame_predictor(torch.cat([h, z_t_hat], 1)).detach() + x_pred = decoder([h_pred, skip]).detach() + gen_seq.append(x_pred) + h = h_target + + to_plot = [] + nrow = min(opt.batch_size, 25) + x_gray = [utils.torch_rgb_img_to_gray(x[t]) for t in range(min(opt.n_eval, opt.n_past+opt.n_future))] + for i in range(nrow): + row_gt = [] + row_post = [] + row = [] + for t in range(min(opt.n_eval, opt.n_past+opt.n_future)): + row_gt.append(x_gray[t][i]) + row_post.append(gen_seq_post[t][i]) + row.append(gen_seq[t][i]) + to_plot.append(row_gt) + to_plot.append(row_post) + to_plot.append(row) + fname = '%s/gen/rec_%d.png' % (opt.log_dir, epoch) + utils.save_tensors_image(fname, to_plot) + + +# --------- training funtions ------------------------------------ +def train(x): + frame_predictor.zero_grad() + posterior.zero_grad() + prior.zero_grad() + encoder.zero_grad() + decoder.zero_grad() + + # initialize the hidden state. 
+ frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + prior.hidden = prior.init_hidden() + + mse = 0 + mse_residual = 0 + h = encoder(x[0]) + for i in range(1, opt.n_past+opt.n_future): + h_target = encoder(x[i]) + if opt.last_frame_skip or i < opt.n_past: + h, skip = h + else: + h = h[0] + + z_t = posterior(h_target[0]) + z_t_hat = prior(h) + h_pred = frame_predictor(torch.cat([h, z_t], 1)) + x_pred = decoder([h_pred, skip]) + gray_target_frame = utils.torch_rgb_img_to_gray(x[i]) + mse += mse_criterion(x_pred, gray_target_frame) + # penalize prior for being far from posterior + mse_residual += opt.gamma * torch.mean(torch.square(z_t.detach() - z_t_hat)) + h = h_target + + loss = mse + mse_residual + loss.backward() + + frame_predictor_optimizer.step() + posterior_optimizer.step() + prior_optimizer.step() + encoder_optimizer.step() + decoder_optimizer.step() + + N = opt.n_past+opt.n_future + return mse.data.cpu().numpy()/N, mse_residual.data.cpu().numpy()/N + +# --------- training loop ------------------------------------ +for epoch in range(opt.niter): + frame_predictor.train() + posterior.train() + prior.train() + encoder.train() + decoder.train() + epoch_mse = 0 + epoch_mse_residual = 0 + progress = progressbar.ProgressBar(max_value=opt.epoch_size).start() + # opt.epoch_size = 10 + for i in range(opt.epoch_size): + progress.update(i+1) + x = next(training_batch_generator) + + # train frame_predictor + mse, mse_residual = train(x) + epoch_mse += mse + epoch_mse_residual += mse_residual + + + progress.finish() + utils.clear_progressbar() + + print('[%02d] mse loss: %.5f | residual mse: %.20f (%d)' % (epoch, epoch_mse/opt.epoch_size, epoch_mse_residual/opt.epoch_size, epoch*opt.epoch_size*opt.batch_size)) + with open(os.path.join(opt.log_dir, 'loss.txt'), 'a') as f: + f.write('[%02d] mse loss: %.5f | residual mse: %.20f (%d)\n' % (epoch, epoch_mse/opt.epoch_size, epoch_mse_residual/opt.epoch_size, 
epoch*opt.epoch_size*opt.batch_size)) + + # plot some stuff + frame_predictor.eval() + posterior.eval() + prior.eval() + encoder.eval() + decoder.eval() + x = next(testing_batch_generator) + plot(x, epoch) + plot_rec(x, epoch) + + # save the model + torch.save({ + 'encoder': encoder, + 'decoder': decoder, + 'frame_predictor': frame_predictor, + 'posterior': posterior, + 'prior': prior, + 'opt': opt}, + '%s/model_e%02d.pth' % (opt.log_dir, epoch)) + if epoch % 10 == 0: + print('log dir: %s' % opt.log_dir) + + diff --git a/train_baseline_spatialTemporalContinuity.py b/train_baseline_spatialTemporalContinuity.py new file mode 100644 index 0000000..26af3b9 --- /dev/null +++ b/train_baseline_spatialTemporalContinuity.py @@ -0,0 +1,447 @@ +import glob + +import torch +import torch.optim as optim +import torch.nn as nn +import argparse +import os +import random +from torch.autograd import Variable +from torch.utils.data import DataLoader +import utils +import itertools +import progressbar +import numpy as np +import json + +parser = argparse.ArgumentParser() +parser.add_argument('--lr', default=0.0002, type=float, help='learning rate') +parser.add_argument('--beta1', default=0.9, type=float, help='momentum term for adam') +parser.add_argument('--batch_size', default=16, type=int, help='batch size') +parser.add_argument('--log_dir', default='logs/nonstochastic_posterior', help='base directory to save logs') +parser.add_argument('--model_dir', default='', help='base directory to save logs') +parser.add_argument('--name', default='', help='identifier for directory') +parser.add_argument('--data_root', default='data', help='root directory for data') +parser.add_argument('--optimizer', default='adam', help='optimizer to train with') +parser.add_argument('--niter', type=int, default=40, help='number of epochs to train for') +parser.add_argument('--seed', default=1, type=int, help='manual seed') +parser.add_argument('--epoch_size', type=int, default=500, help='epoch size') 
+parser.add_argument('--image_width', type=int, default=64, help='the height / width of the input image to network') +parser.add_argument('--channels', default=1, type=int, help='number of channels for input images. ') +parser.add_argument('--use_edge_kernels', default=True, type=bool, help='whether to use edge kernels to reduce to 1 channel') +parser.add_argument('--dataset', default='mcs', help='dataset to train with') +parser.add_argument('--mcs_task', default='SpatioTemporalContinuityTraining4', help='mcs task') +parser.add_argument('--n_past', type=int, default=5, help='number of frames to condition on') +parser.add_argument('--n_future', type=int, default=25, help='number of frames to predict') +parser.add_argument('--n_eval', type=int, default=30, help='number of frames to predict at eval time') +parser.add_argument('--start_min', type=int, default=80, help='min starting time for sampling sequence (0-indexed)') +parser.add_argument('--start_max', type=int, default=135, help='max starting time for sampling sequence (0-indexed)') +parser.add_argument('--sequence_stride', type=int, default=1, help='factor for sequence temporal subsampling (int)') +parser.add_argument('--rnn_size', type=int, default=256, help='dimensionality of hidden layer') +parser.add_argument('--prior_rnn_layers', type=int, default=1, help='number of layers') +parser.add_argument('--posterior_rnn_layers', type=int, default=1, help='number of layers') +parser.add_argument('--predictor_rnn_layers', type=int, default=2, help='number of layers') +parser.add_argument('--z_dim', type=int, default=32, help='dimensionality of z_t') +parser.add_argument('--g_dim', type=int, default=128, help='dimensionality of encoder output vector and decoder input vector') +parser.add_argument('--beta', type=float, default=0.0001, help='weighting on KL to prior') +parser.add_argument('--gamma', type=float, default=0.0001, help='weighting on h vs h posterior') +parser.add_argument('--model', default='vgg', 
help='model type (dcgan | vgg)') +parser.add_argument('--data_threads', type=int, default=12, help='number of data loading threads') +parser.add_argument('--num_digits', type=int, default=2, help='number of digits for moving mnist') +parser.add_argument('--last_frame_skip', action='store_true', help='if true, skip connections go between frame t and frame t+1 rather than last ground truth frame') + + +opt = parser.parse_args() +saved_model = None +if opt.model_dir != '': + models = glob.glob(f'{opt.model_dir}/model_*.pth') + latest_model = sorted(models, key=lambda s: int(s[s.rfind('_e') + 2: s.rfind('.pth')]), reverse=True)[0] + print('Loading model ', latest_model) + saved_model = torch.load(latest_model) + model_dir = opt.model_dir + niter = opt.niter + lr = opt.lr + batch_size = opt.batch_size + n_future = opt.n_future + n_eval = opt.n_eval + data_root = opt.data_root + opt = saved_model['opt'] + opt.niter = niter # update number of epochs to train for + opt.model_dir = model_dir + opt.n_future = n_future + opt.lr = lr + opt.batch_size = batch_size + opt.n_eval = n_eval + opt.data_root = data_root + opt.log_dir = '%s/continued_lr%s' % (opt.log_dir, opt.lr) +else: + name = 'model=%s%dx%d-rnn_size=%d-predictor-posterior-rnn_layers=%d-%d-n_past=%d-n_future=%d-lr=%.4f-g_dim=%d-z_dim=%d-last_frame_skip=%d-beta=%.7f-gamma=%.7f%s' % (opt.model, opt.image_width, opt.image_width, opt.rnn_size, opt.predictor_rnn_layers, opt.posterior_rnn_layers, opt.n_past, opt.n_future, opt.lr, opt.g_dim, opt.z_dim, opt.last_frame_skip, opt.beta, opt.gamma, opt.name) + if opt.dataset == 'smmnist': + opt.log_dir = '%s/%s-%d/%s' % (opt.log_dir, opt.dataset, opt.num_digits, name) + elif opt.dataset == 'mcs': + opt.log_dir = '%s/%s/%s/%s' % (opt.log_dir, opt.dataset, opt.mcs_task, name) + else: + opt.log_dir = '%s/%s/%s' % (opt.log_dir, opt.dataset, name) + +os.makedirs('%s/gen/' % opt.log_dir, exist_ok=False) +os.makedirs('%s/plots/' % opt.log_dir, exist_ok=False) +with 
open(os.path.join(opt.log_dir, 'opt.json'), 'w') as f: + opt2 = opt.__dict__.copy() + if isinstance(opt2['optimizer'], type): + opt2['optimizer'] = str(opt2['optimizer']) + json.dump(opt2, f, indent=2) + del opt2 + +print("Random Seed: ", opt.seed) +random.seed(opt.seed) +torch.manual_seed(opt.seed) +torch.cuda.manual_seed_all(opt.seed) +dtype = torch.cuda.FloatTensor + + +# ---------------- load the models ---------------- + +print(opt) + +# ---------------- optimizers ---------------- +if opt.optimizer == 'adam': + opt.optimizer = optim.Adam +elif opt.optimizer == 'rmsprop': + opt.optimizer = optim.RMSprop +elif opt.optimizer == 'sgd': + opt.optimizer = optim.SGD +elif isinstance(opt.optimizer, type): + pass +else: + raise ValueError('Unknown optimizer: %s' % opt.optimizer) + +import models.lstm as lstm_models +if opt.model_dir != '': + frame_predictor = saved_model['frame_predictor'] + frame_predictor.batch_size = opt.batch_size + prior = saved_model['prior'] + prior.batch_size = opt.batch_size + posterior = saved_model['posterior'] + posterior.batch_size = opt.batch_size +else: + frame_predictor = lstm_models.lstm(opt.g_dim + opt.z_dim, opt.g_dim, opt.rnn_size, opt.predictor_rnn_layers, opt.batch_size) + posterior = lstm_models.lstm(opt.g_dim, opt.z_dim, opt.rnn_size, opt.posterior_rnn_layers, opt.batch_size) + prior = lstm_models.lstm(opt.g_dim, opt.z_dim, opt.rnn_size, opt.prior_rnn_layers, opt.batch_size) + frame_predictor.apply(utils.init_weights) + posterior.apply(utils.init_weights) + prior.apply(utils.init_weights) + +if opt.model == 'dcgan': + if opt.image_width == 64: + import models.dcgan_64 as model + elif opt.image_width == 128: + import models.dcgan_128 as model +elif opt.model == 'vgg': + if opt.image_width == 64: + import models.vgg_64 as model + elif opt.image_width == 128: + import models.vgg_128 as model +else: + raise ValueError('Unknown model: %s' % opt.model) + +if opt.model_dir != '': + decoder = saved_model['decoder'] + encoder = 
saved_model['encoder'] +else: + encoder = model.encoder(opt.g_dim, opt.channels) + decoder = model.decoder(opt.g_dim, 1) + encoder.apply(utils.init_weights) + decoder.apply(utils.init_weights) + +if opt.optimizer == optim.SGD: + frame_predictor_optimizer = opt.optimizer(frame_predictor.parameters(), lr=opt.lr, momentum=opt.beta1) + posterior_optimizer = opt.optimizer(posterior.parameters(), lr=opt.lr, momentum=opt.beta1) + prior_optimizer = opt.optimizer(prior.parameters(), lr=opt.lr, momentum=opt.beta1) + encoder_optimizer = opt.optimizer(encoder.parameters(), lr=opt.lr, momentum=opt.beta1) + decoder_optimizer = opt.optimizer(decoder.parameters(), lr=opt.lr, momentum=opt.beta1) +else: + frame_predictor_optimizer = opt.optimizer(frame_predictor.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + posterior_optimizer = opt.optimizer(posterior.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + prior_optimizer = opt.optimizer(prior.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + encoder_optimizer = opt.optimizer(encoder.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + decoder_optimizer = opt.optimizer(decoder.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + +# --------- loss functions ------------------------------------ +mse_criterion = nn.MSELoss() +def kl_criterion(mu, logvar): + # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) + KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) + KLD /= opt.batch_size + return KLD + + +# --------- transfer to gpu ------------------------------------ +frame_predictor.cuda() +posterior.cuda() +prior.cuda() +encoder.cuda() +decoder.cuda() +mse_criterion.cuda() + +# --------- load a dataset ------------------------------------ +train_data, test_data = utils.load_dataset(opt) + +train_loader = DataLoader(train_data, + num_workers=opt.data_threads, + batch_size=opt.batch_size, + shuffle=True, + drop_last=True, + pin_memory=True,) +test_loader = DataLoader(test_data, + num_workers=opt.data_threads, + 
batch_size=opt.batch_size, + shuffle=True, + drop_last=True, + pin_memory=True) + +def get_training_batch(): + while True: + for sequence in train_loader: + batch = utils.normalize_data(opt, dtype, sequence) + yield batch +training_batch_generator = get_training_batch() + +def get_testing_batch(): + while True: + for sequence in test_loader: + batch = utils.normalize_data(opt, dtype, sequence) + yield batch +testing_batch_generator = get_testing_batch() + +# --------- plotting funtions ------------------------------------ +def plot(x, epoch): + + nsample = 1 + gen_seq = [[] for _ in range(nsample)] + gt_seq = [utils.torch_rgb_img_to_gray(x[t]) for t in range(len(x))] + + # h_seq = [encoder(x[i]) for i in range(opt.n_past)] + for s in range(nsample): + frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + prior.hidden = prior.init_hidden() + gen_seq[s].append(utils.torch_rgb_img_to_gray(x[0])) + x_in = x[0] + for i in range(1, opt.n_eval): + with torch.no_grad(): + # if input is grayscale + if x_in.shape[1] == 1 and opt.channels == 3: + h = encoder(torch.cat(3*[x_in], dim=1)) # convert to RGB + else: + h = encoder(x_in) + if opt.last_frame_skip or i < opt.n_past: + h, skip = h + else: + h, _ = h + h = h.detach() + if i < opt.n_past: + h_target = encoder(x[i]) + z_t = posterior(h_target[0].detach()) + prior(h) + frame_predictor(torch.cat([h, z_t], 1)) + x_in = x[i] + gen_seq[s].append(utils.torch_rgb_img_to_gray(x_in)) + else: + z_t_hat = prior(h) + h = frame_predictor(torch.cat([h, z_t_hat], 1)).detach() + x_in = decoder([h, skip]) + gen_seq[s].append(x_in) + + + to_plot = [] + gifs = [ [] for t in range(opt.n_eval) ] + nrow = min(opt.batch_size, 25) + for i in range(nrow): + # ground truth sequence + row = [] + for t in range(opt.n_eval): + row.append(gt_seq[t][i]) + to_plot.append(row) + + for s in range(nsample): + row = [] + for t in range(opt.n_eval): + row.append(gen_seq[s][t][i]) + to_plot.append(row) + for t in 
range(opt.n_eval): + row = [] + row.append(gt_seq[t][i]) + for s in range(nsample): + row.append(gen_seq[s][t][i]) + gifs[t].append(row) + + fname = '%s/gen/sample_%d.png' % (opt.log_dir, epoch) + utils.save_tensors_image(fname, to_plot) + + fname = '%s/gen/sample_%d.gif' % (opt.log_dir, epoch) + utils.save_gif(fname, gifs) + + +def plot_rec(x, epoch): + gen_seq = [utils.torch_rgb_img_to_gray(x[0])] + gen_seq_post = [utils.torch_rgb_img_to_gray(x[0])] + + # prediction using posterior Z + frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + h = encoder(x[0]) + h = (h[0].detach(), h[1]) + for i in range(1, min(opt.n_eval, opt.n_past+opt.n_future)): + h_target = encoder(x[i]) + h_target = (h_target[0].detach(), h_target[1]) + if opt.last_frame_skip or i < opt.n_past: + h, skip = h + else: + h, _ = h + z_t = posterior(h_target[0]) + if i < opt.n_past: + frame_predictor(torch.cat([h, z_t], 1)) + gen_seq_post.append(utils.torch_rgb_img_to_gray(x[i])) + else: + h_pred = frame_predictor(torch.cat([h, z_t], 1)).detach() + x_pred = decoder([h_pred, skip]).detach() + gen_seq_post.append(x_pred) + h = h_target + + # prediction using prior Z + frame_predictor.hidden = frame_predictor.init_hidden() + prior.hidden = prior.init_hidden() + h = encoder(x[0]) + h = (h[0].detach(), h[1]) + for i in range(1, min(opt.n_eval, opt.n_past+opt.n_future)): + h_target = encoder(x[i]) + h_target = (h_target[0].detach(), h_target[1]) + if opt.last_frame_skip or i < opt.n_past: + h, skip = h + else: + h, _ = h + h = h.detach() + z_t_hat = prior(h) + if i < opt.n_past: + frame_predictor(torch.cat([h, z_t_hat], 1)) + gen_seq.append(utils.torch_rgb_img_to_gray(x[i])) + else: + h_pred = frame_predictor(torch.cat([h, z_t_hat], 1)).detach() + x_pred = decoder([h_pred, skip]).detach() + gen_seq.append(x_pred) + h = h_target + + to_plot = [] + nrow = min(opt.batch_size, 25) + x_gray = [utils.torch_rgb_img_to_gray(x[t]) for t in range(min(opt.n_eval, 
opt.n_past+opt.n_future))] + for i in range(nrow): + row_gt = [] + row_post = [] + row = [] + for t in range(min(opt.n_eval, opt.n_past+opt.n_future)): + row_gt.append(x_gray[t][i]) + row_post.append(gen_seq_post[t][i]) + row.append(gen_seq[t][i]) + to_plot.append(row_gt) + to_plot.append(row_post) + to_plot.append(row) + fname = '%s/gen/rec_%d.png' % (opt.log_dir, epoch) + utils.save_tensors_image(fname, to_plot) + + +# --------- training funtions ------------------------------------ +def train(x): + frame_predictor.zero_grad() + posterior.zero_grad() + prior.zero_grad() + encoder.zero_grad() + decoder.zero_grad() + + # initialize the hidden state. + frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + prior.hidden = prior.init_hidden() + + mse = 0 + mse_residual = 0 + h = encoder(x[0]) + for i in range(1, opt.n_past+opt.n_future): + h_target = encoder(x[i]) + if opt.last_frame_skip or i < opt.n_past: + h, skip = h + else: + h = h[0] + + z_t = posterior(h_target[0]) + z_t_hat = prior(h) + h_pred = frame_predictor(torch.cat([h, z_t], 1)) + x_pred = decoder([h_pred, skip]) + gray_target_frame = utils.torch_rgb_img_to_gray(x[i]) + mse += mse_criterion(x_pred, gray_target_frame) + # penalize prior for being far from posterior + mse_residual += opt.gamma * torch.mean(torch.square(z_t.detach() - z_t_hat)) + h = h_target + + loss = mse + mse_residual + loss.backward() + + frame_predictor_optimizer.step() + posterior_optimizer.step() + prior_optimizer.step() + encoder_optimizer.step() + decoder_optimizer.step() + + N = opt.n_past+opt.n_future + return mse.data.cpu().numpy()/N, mse_residual.data.cpu().numpy()/N + +# --------- training loop ------------------------------------ +for epoch in range(opt.niter): + frame_predictor.train() + posterior.train() + prior.train() + encoder.train() + decoder.train() + epoch_mse = 0 + epoch_mse_residual = 0 + progress = progressbar.ProgressBar(max_value=opt.epoch_size).start() + # 
opt.epoch_size = 10 + for i in range(opt.epoch_size): + progress.update(i+1) + x = next(training_batch_generator) + + # train frame_predictor + mse, mse_residual = train(x) + epoch_mse += mse + epoch_mse_residual += mse_residual + + + progress.finish() + utils.clear_progressbar() + + print('[%02d] mse loss: %.5f | residual mse: %.20f (%d)' % (epoch, epoch_mse/opt.epoch_size, epoch_mse_residual/opt.epoch_size, epoch*opt.epoch_size*opt.batch_size)) + with open(os.path.join(opt.log_dir, 'loss.txt'), 'a') as f: + f.write('[%02d] mse loss: %.5f | residual mse: %.20f (%d)\n' % (epoch, epoch_mse/opt.epoch_size, epoch_mse_residual/opt.epoch_size, epoch*opt.epoch_size*opt.batch_size)) + + # plot some stuff + frame_predictor.eval() + posterior.eval() + prior.eval() + encoder.eval() + decoder.eval() + x = next(testing_batch_generator) + plot(x, epoch) + plot_rec(x, epoch) + + # save the model + torch.save({ + 'encoder': encoder, + 'decoder': decoder, + 'frame_predictor': frame_predictor, + 'posterior': posterior, + 'prior': prior, + 'opt': opt}, + '%s/model_e%02d.pth' % (opt.log_dir, epoch)) + if epoch % 10 == 0: + print('log dir: %s' % opt.log_dir) + + diff --git a/train_svg_nonstochastic.py b/train_svg_nonstochastic.py new file mode 100644 index 0000000..2ad20ab --- /dev/null +++ b/train_svg_nonstochastic.py @@ -0,0 +1,331 @@ +import glob + +import torch +import torch.optim as optim +import torch.nn as nn +import argparse +import os +import random +from torch.autograd import Variable +from torch.utils.data import DataLoader +import utils +import itertools +import progressbar +import numpy as np + +parser = argparse.ArgumentParser() +parser.add_argument('--lr', default=0.004, type=float, help='learning rate') +parser.add_argument('--beta1', default=0.9, type=float, help='momentum term for adam') +parser.add_argument('--batch_size', default=168, type=int, help='batch size') +parser.add_argument('--log_dir', default='logs/nonstochastic', help='base directory to save logs') 
+parser.add_argument('--model_dir', default='', help='base directory to save logs') +parser.add_argument('--name', default='', help='identifier for directory') +parser.add_argument('--data_root', default='data', help='root directory for data') +parser.add_argument('--optimizer', default='adam', help='optimizer to train with') +parser.add_argument('--niter', type=int, default=300, help='number of epochs to train for') +parser.add_argument('--seed', default=1, type=int, help='manual seed') +parser.add_argument('--epoch_size', type=int, default=400, help='epoch size') +parser.add_argument('--image_width', type=int, default=64, help='the height / width of the input image to network') +parser.add_argument('--channels', default=1, type=int) +parser.add_argument('--dataset', default='mcs', help='dataset to train with') +parser.add_argument('--mcs_task', default='SpatioTemporalContinuityTraining4', help='mcs task') +parser.add_argument('--n_past', type=int, default=5, help='number of frames to condition on') +parser.add_argument('--n_future', type=int, default=10, help='number of frames to predict') +parser.add_argument('--n_eval', type=int, default=30, help='number of frames to predict at eval time') +parser.add_argument('--rnn_size', type=int, default=256, help='dimensionality of hidden layer') +parser.add_argument('--posterior_rnn_layers', type=int, default=1, help='number of layers') +parser.add_argument('--predictor_rnn_layers', type=int, default=2, help='number of layers') +parser.add_argument('--z_dim', type=int, default=10, help='dimensionality of z_t') +parser.add_argument('--g_dim', type=int, default=128, help='dimensionality of encoder output vector and decoder input vector') +parser.add_argument('--beta', type=float, default=0.0001, help='weighting on KL to prior') +parser.add_argument('--model', default='dcgan', help='model type (dcgan | vgg)') +parser.add_argument('--data_threads', type=int, default=8, help='number of data loading threads') 
+parser.add_argument('--num_digits', type=int, default=2, help='number of digits for moving mnist') +parser.add_argument('--last_frame_skip', action='store_true', help='if true, skip connections go between frame t and frame t+t rather than last ground truth frame') + + +opt = parser.parse_args() +saved_model = None +if opt.model_dir != '': + models = glob.glob(f'{opt.model_dir}/model_*.pth') + latest_model = sorted(models, key=lambda s: int(s[s.rfind('_e') + 2: s.rfind('.pth')]), reverse=True)[0] + print('Loading model ', latest_model) + saved_model = torch.load(latest_model) + model_dir = opt.model_dir + niter = opt.niter + opt = saved_model['opt'] + opt.niter = niter # update number of epochs to train for + opt.model_dir = model_dir + opt.log_dir = '%s/continued' % opt.log_dir +else: + name = 'model=%s%dx%d-rnn_size=%d-predictor-posterior-rnn_layers=%d-%d-n_past=%d-n_future=%d-lr=%.4f-g_dim=%d-z_dim=%d-last_frame_skip=%d-beta=%.7f%s' % (opt.model, opt.image_width, opt.image_width, opt.rnn_size, opt.predictor_rnn_layers, opt.posterior_rnn_layers, opt.n_past, opt.n_future, opt.lr, opt.g_dim, opt.z_dim, opt.last_frame_skip, opt.beta, opt.name) + if opt.dataset == 'smmnist': + opt.log_dir = '%s/%s-%d/%s' % (opt.log_dir, opt.dataset, opt.num_digits, name) + elif opt.dataset == 'mcs': + opt.log_dir = '%s/%s/%s/%s' % (opt.log_dir, opt.dataset, opt.mcs_task, name) + else: + opt.log_dir = '%s/%s/%s' % (opt.log_dir, opt.dataset, name) + +os.makedirs('%s/gen/' % opt.log_dir, exist_ok=True) +os.makedirs('%s/plots/' % opt.log_dir, exist_ok=True) + +print("Random Seed: ", opt.seed) +random.seed(opt.seed) +torch.manual_seed(opt.seed) +torch.cuda.manual_seed_all(opt.seed) +dtype = torch.cuda.FloatTensor + + +# ---------------- load the models ---------------- + +print(opt) + +# ---------------- optimizers ---------------- +if opt.optimizer == 'adam': + opt.optimizer = optim.Adam +elif opt.optimizer == 'rmsprop': + opt.optimizer = optim.RMSprop +elif opt.optimizer == 'sgd': + 
opt.optimizer = optim.SGD +else: + raise ValueError('Unknown optimizer: %s' % opt.optimizer) + +import models.lstm as lstm_models +if opt.model_dir != '': + frame_predictor = saved_model['frame_predictor'] +else: + frame_predictor = lstm_models.lstm(opt.g_dim, opt.g_dim, opt.rnn_size, opt.predictor_rnn_layers, opt.batch_size) + frame_predictor.apply(utils.init_weights) + +if opt.model == 'dcgan': + if opt.image_width == 64: + import models.dcgan_64 as model + elif opt.image_width == 128: + import models.dcgan_128 as model +elif opt.model == 'vgg': + if opt.image_width == 64: + import models.vgg_64 as model + elif opt.image_width == 128: + import models.vgg_128 as model +else: + raise ValueError('Unknown model: %s' % opt.model) + +if opt.model_dir != '': + decoder = saved_model['decoder'] + encoder = saved_model['encoder'] +else: + encoder = model.encoder(opt.g_dim, opt.channels) + decoder = model.decoder(opt.g_dim, opt.channels) + encoder.apply(utils.init_weights) + decoder.apply(utils.init_weights) + +frame_predictor_optimizer = opt.optimizer(frame_predictor.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) +encoder_optimizer = opt.optimizer(encoder.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) +decoder_optimizer = opt.optimizer(decoder.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + +# --------- loss functions ------------------------------------ +mse_criterion = nn.MSELoss() +def kl_criterion(mu, logvar): + # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) + KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) + KLD /= opt.batch_size + return KLD + + +# --------- transfer to gpu ------------------------------------ +frame_predictor.cuda() +encoder.cuda() +decoder.cuda() +mse_criterion.cuda() + +# --------- load a dataset ------------------------------------ +train_data, test_data = utils.load_dataset(opt) + +train_loader = DataLoader(train_data, + num_workers=opt.data_threads, + batch_size=opt.batch_size, + shuffle=True, + drop_last=True, + 
def plot(x, epoch):
    """Save a ground-truth vs. generated comparison as a PNG grid and a GIF.

    The first ``opt.n_past`` frames of ``x`` are fed to the predictor as
    context; every later frame up to ``opt.n_eval`` is generated, and each
    generated frame is re-encoded and fed back as the next input.

    Args:
        x: sequence of frame batches indexed as ``x[t][i]`` (time step, then
            batch item); it must provide at least ``opt.n_eval`` time steps.
        epoch: epoch number, used only to build the output file names.
    """
    nsample = 1
    gt_seq = [x[t] for t in range(len(x))]
    gen_seq = [[] for _ in range(nsample)]

    for s in range(nsample):
        frame_predictor.hidden = frame_predictor.init_hidden()
        gen_seq[s].append(x[0])
        x_in = x[0]
        for i in range(1, opt.n_eval):
            with torch.no_grad():
                # Encode the frame that is actually fed at this step. Indexing
                # a pre-computed list of the n_past context encodings runs out
                # of range once last_frame_skip is set, and otherwise feeds
                # frame n_past-2 twice while never feeding frame n_past-1.
                h, step_skip = encoder(x_in)
                if opt.last_frame_skip or i < opt.n_past:
                    skip = step_skip
                if i < opt.n_past:
                    # Context phase: advance the recurrent state only; the
                    # prediction is discarded and the real frame is shown.
                    frame_predictor(h)
                    x_in = x[i]
                else:
                    x_in = decoder([frame_predictor(h), skip])
            gen_seq[s].append(x_in)

    nrow = min(opt.batch_size, 25)
    to_plot = []
    gifs = [[] for _ in range(opt.n_eval)]
    for i in range(nrow):
        # one ground-truth row followed by one row per generated sample
        to_plot.append([gt_seq[t][i] for t in range(opt.n_eval)])
        for s in range(nsample):
            to_plot.append([gen_seq[s][t][i] for t in range(opt.n_eval)])
        # per time step: ground truth next to each generated sample
        for t in range(opt.n_eval):
            gifs[t].append([gt_seq[t][i]] + [gen_seq[s][t][i] for s in range(nsample)])

    fname = '%s/gen/sample_%d.png' % (opt.log_dir, epoch)
    utils.save_tensors_image(fname, to_plot)

    fname = '%s/gen/sample_%d.gif' % (opt.log_dir, epoch)
    utils.save_gif(fname, gifs)
def train(x):
    """Run one optimisation step on a batch and return its scaled MSE.

    For every step ``i`` from 1 to ``n_past + n_future - 1`` the encoding of
    the ground-truth frame ``x[i-1]`` is passed through the frame predictor
    and the decoder, and the result is compared with ``x[i]``.

    Args:
        x: sequence of frame batches with at least
            ``opt.n_past + opt.n_future`` time steps.

    Returns:
        The summed per-step MSE divided by ``opt.n_past + opt.n_future``, as a
        NumPy value.
    """
    frame_predictor.zero_grad()
    encoder.zero_grad()
    decoder.zero_grad()

    # initialize the hidden state.
    frame_predictor.hidden = frame_predictor.init_hidden()

    n_frames = opt.n_past + opt.n_future
    h_seq = [encoder(x[t]) for t in range(n_frames)]
    mse = 0
    for i in range(1, n_frames):
        if opt.last_frame_skip or i < opt.n_past:
            h, skip = h_seq[i - 1]
        else:
            # past the context frames the skip tensors stay frozen at the
            # last conditioning frame
            h = h_seq[i - 1][0]
        x_pred = decoder([frame_predictor(h), skip])
        mse += mse_criterion(x_pred, x[i])

    mse.backward()

    frame_predictor_optimizer.step()
    encoder_optimizer.step()
    decoder_optimizer.step()

    return mse.data.cpu().numpy() / n_frames
torch.autograd import Variable +from torch.utils.data import DataLoader +import utils +import itertools +import progressbar +import numpy as np +import json + +parser = argparse.ArgumentParser() +parser.add_argument('--lr', default=0.0002, type=float, help='learning rate') +parser.add_argument('--beta1', default=0.9, type=float, help='momentum term for adam') +parser.add_argument('--batch_size', default=24, type=int, help='batch size') +parser.add_argument('--log_dir', default='logs/nonstochastic_posterior', help='base directory to save logs') +parser.add_argument('--model_dir', default='', help='base directory to save logs') +parser.add_argument('--name', default='', help='identifier for directory') +parser.add_argument('--data_root', default='data', help='root directory for data') +parser.add_argument('--optimizer', default='adam', help='optimizer to train with') +parser.add_argument('--niter', type=int, default=300, help='number of epochs to train for') +parser.add_argument('--seed', default=1, type=int, help='manual seed') +parser.add_argument('--epoch_size', type=int, default=500, help='epoch size') +parser.add_argument('--image_width', type=int, default=64, help='the height / width of the input image to network') +parser.add_argument('--channels', default=3, type=int, help='number of channels for input images. 
') +parser.add_argument('--use_edge_kernels', action='store_true') +parser.add_argument('--dataset', default='mcs', help='dataset to train with') +parser.add_argument('--mcs_task', default='SpatioTemporalContinuityTraining4', help='mcs task') +parser.add_argument('--n_past', type=int, default=5, help='number of frames to condition on') +parser.add_argument('--n_future', type=int, default=15, help='number of frames to predict') +parser.add_argument('--n_eval', type=int, default=30, help='number of frames to predict at eval time') +parser.add_argument('--rnn_size', type=int, default=256, help='dimensionality of hidden layer') +parser.add_argument('--prior_rnn_layers', type=int, default=1, help='number of layers') +parser.add_argument('--posterior_rnn_layers', type=int, default=1, help='number of layers') +parser.add_argument('--predictor_rnn_layers', type=int, default=2, help='number of layers') +parser.add_argument('--z_dim', type=int, default=32, help='dimensionality of z_t') +parser.add_argument('--g_dim', type=int, default=128, help='dimensionality of encoder output vector and decoder input vector') +parser.add_argument('--beta', type=float, default=0.0001, help='weighting on KL to prior') +parser.add_argument('--gamma', type=float, default=0.0001, help='weighting on h vs h posterior') +parser.add_argument('--model', default='vgg', help='model type (dcgan | vgg)') +parser.add_argument('--data_threads', type=int, default=12, help='number of data loading threads') +parser.add_argument('--num_digits', type=int, default=2, help='number of digits for moving mnist') +parser.add_argument('--last_frame_skip', action='store_true', help='if true, skip connections go between frame t and frame t+1 rather than last ground truth frame') + + +opt = parser.parse_args() +saved_model = None +if opt.model_dir != '': + models = glob.glob(f'{opt.model_dir}/model_*.pth') + latest_model = sorted(models, key=lambda s: int(s[s.rfind('_e') + 2: s.rfind('.pth')]), reverse=True)[0] + 
print('Loading model ', latest_model) + saved_model = torch.load(latest_model) + model_dir = opt.model_dir + niter = opt.niter + lr = opt.lr + batch_size = opt.batch_size + n_future = opt.n_future + n_eval = opt.n_eval + data_root = opt.data_root + opt = saved_model['opt'] + opt.niter = niter # update number of epochs to train for + opt.model_dir = model_dir + opt.n_future = n_future + opt.lr = lr + opt.batch_size = batch_size + opt.n_eval = n_eval + opt.data_root = data_root + opt.log_dir = '%s/continued_lr%s' % (opt.log_dir, opt.lr) +else: + name = 'RGBmodel=%s%dx%d-rnn_size=%d-predictor-posterior-rnn_layers=%d-%d-n_past=%d-n_future=%d-lr=%.4f-g_dim=%d-z_dim=%d-last_frame_skip=%d-beta=%.7f-gamma=%.7f%s' % (opt.model, opt.image_width, opt.image_width, opt.rnn_size, opt.predictor_rnn_layers, opt.posterior_rnn_layers, opt.n_past, opt.n_future, opt.lr, opt.g_dim, opt.z_dim, opt.last_frame_skip, opt.beta, opt.gamma, opt.name) + if opt.dataset == 'smmnist': + opt.log_dir = '%s/%s-%d/%s' % (opt.log_dir, opt.dataset, opt.num_digits, name) + elif opt.dataset == 'mcs': + opt.log_dir = '%s/%s/%s/%s' % (opt.log_dir, opt.dataset, opt.mcs_task, name) + else: + opt.log_dir = '%s/%s/%s' % (opt.log_dir, opt.dataset, name) + +os.makedirs('%s/gen/' % opt.log_dir, exist_ok=False) +os.makedirs('%s/plots/' % opt.log_dir, exist_ok=False) +with open(os.path.join(opt.log_dir, 'opt.json'), 'w') as f: + opt2 = opt.__dict__.copy() + if isinstance(opt2['optimizer'], type): + opt2['optimizer'] = str(opt2['optimizer']) + json.dump(opt2, f, indent=2) + del opt2 + +print("Random Seed: ", opt.seed) +random.seed(opt.seed) +torch.manual_seed(opt.seed) +torch.cuda.manual_seed_all(opt.seed) +dtype = torch.cuda.FloatTensor + + +# ---------------- load the models ---------------- + +print(opt) + +# ---------------- optimizers ---------------- +if opt.optimizer == 'adam': + opt.optimizer = optim.Adam +elif opt.optimizer == 'rmsprop': + opt.optimizer = optim.RMSprop +elif opt.optimizer == 'sgd': + 
opt.optimizer = optim.SGD +elif isinstance(opt.optimizer, type): + pass +else: + raise ValueError('Unknown optimizer: %s' % opt.optimizer) + +import models.lstm as lstm_models +if opt.model_dir != '': + frame_predictor = saved_model['frame_predictor'] + frame_predictor.batch_size = opt.batch_size + prior = saved_model['prior'] + prior.batch_size = opt.batch_size + posterior = saved_model['posterior'] + posterior.batch_size = opt.batch_size +else: + frame_predictor = lstm_models.lstm(opt.g_dim + opt.z_dim, opt.g_dim, opt.rnn_size, opt.predictor_rnn_layers, opt.batch_size) + posterior = lstm_models.lstm(opt.g_dim, opt.z_dim, opt.rnn_size, opt.posterior_rnn_layers, opt.batch_size) + prior = lstm_models.lstm(opt.g_dim, opt.z_dim, opt.rnn_size, opt.prior_rnn_layers, opt.batch_size) + frame_predictor.apply(utils.init_weights) + posterior.apply(utils.init_weights) + prior.apply(utils.init_weights) + +if opt.model == 'dcgan': + if opt.image_width == 64: + import models.dcgan_64 as model + elif opt.image_width == 128: + import models.dcgan_128 as model +elif opt.model == 'vgg': + if opt.image_width == 64: + import models.vgg_64 as model + elif opt.image_width == 128: + import models.vgg_128 as model +else: + raise ValueError('Unknown model: %s' % opt.model) + +if opt.model_dir != '': + decoder = saved_model['decoder'] + encoder = saved_model['encoder'] +else: + encoder = model.encoder(opt.g_dim, opt.channels) + decoder = model.decoder(opt.g_dim, 1) + encoder.apply(utils.init_weights) + decoder.apply(utils.init_weights) + +frame_predictor_optimizer = opt.optimizer(frame_predictor.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) +posterior_optimizer = opt.optimizer(posterior.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) +prior_optimizer = opt.optimizer(prior.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) +encoder_optimizer = opt.optimizer(encoder.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) +decoder_optimizer = opt.optimizer(decoder.parameters(), lr=opt.lr, 
def plot(x, epoch):
    """Save a grayscale ground-truth vs. generated comparison (PNG grid + GIF).

    During the first ``opt.n_past`` steps the posterior, prior and frame
    predictor are advanced on ground-truth frames; afterwards latents come
    from the prior and each decoded frame is fed back as the next input.

    Args:
        x: sequence of frame batches indexed as ``x[t][i]``; it must provide
            at least ``opt.n_eval`` time steps.
        epoch: epoch number, used only to build the output file names.
    """
    nsample = 1
    to_gray = utils.torch_rgb_img_to_gray
    gen_seq = [[] for _ in range(nsample)]
    gt_seq = [to_gray(x[t]) for t in range(len(x))]

    for s in range(nsample):
        frame_predictor.hidden = frame_predictor.init_hidden()
        posterior.hidden = posterior.init_hidden()
        prior.hidden = prior.init_hidden()
        gen_seq[s].append(to_gray(x[0]))
        x_in = x[0]
        for i in range(1, opt.n_eval):
            with torch.no_grad():
                # a single-channel frame is tiled to three channels when the
                # encoder was built for RGB input
                needs_rgb = x_in.shape[1] == 1 and opt.channels == 3
                encoded = encoder(torch.cat(3 * [x_in], dim=1) if needs_rgb else x_in)
                if opt.last_frame_skip or i < opt.n_past:
                    h, skip = encoded
                else:
                    h, _ = encoded
                h = h.detach()

                if i >= opt.n_past:
                    # generation phase: latent from the prior, decoded frame
                    # becomes the next input
                    z_t_hat = prior(h)
                    h = frame_predictor(torch.cat([h, z_t_hat], 1)).detach()
                    x_in = decoder([h, skip])
                    gen_seq[s].append(x_in)
                    continue

                # context phase: advance all recurrent states on real frames
                h_target = encoder(x[i])
                z_t = posterior(h_target[0].detach())
                prior(h)
                frame_predictor(torch.cat([h, z_t], 1))
                x_in = x[i]
                gen_seq[s].append(to_gray(x_in))

    nrow = min(opt.batch_size, 25)
    to_plot = []
    gifs = [[] for _ in range(opt.n_eval)]
    for i in range(nrow):
        # ground truth sequence, then one row per generated sample
        to_plot.append([gt_seq[t][i] for t in range(opt.n_eval)])
        for s in range(nsample):
            to_plot.append([gen_seq[s][t][i] for t in range(opt.n_eval)])
        for t in range(opt.n_eval):
            gifs[t].append([gt_seq[t][i]] + [gen_seq[s][t][i] for s in range(nsample)])

    utils.save_tensors_image('%s/gen/sample_%d.png' % (opt.log_dir, epoch), to_plot)
    utils.save_gif('%s/gen/sample_%d.gif' % (opt.log_dir, epoch), gifs)
def train(x):
    """Run one optimisation step on a batch and return its two scaled losses.

    Each step predicts frame ``x[i]`` from the encoding of ``x[i-1]`` and the
    posterior output for ``x[i]``; the decoder output is compared with the
    grayscale version of ``x[i]``. A second term, weighted by ``opt.gamma``,
    pulls the prior output towards the (detached) posterior output.

    Args:
        x: sequence of frame batches with at least
            ``opt.n_past + opt.n_future`` time steps.

    Returns:
        ``(mse, residual)``: both summed over the steps and divided by
        ``opt.n_past + opt.n_future``, as NumPy values.
    """
    for net in (frame_predictor, posterior, prior, encoder, decoder):
        net.zero_grad()

    # initialize the hidden state.
    for rnn in (frame_predictor, posterior, prior):
        rnn.hidden = rnn.init_hidden()

    n_steps = opt.n_past + opt.n_future
    recon_loss = 0
    residual_loss = 0
    previous = encoder(x[0])
    for i in range(1, n_steps):
        target = encoder(x[i])
        if opt.last_frame_skip or i < opt.n_past:
            h, skip = previous
        else:
            h = previous[0]

        z_t = posterior(target[0])
        z_t_hat = prior(h)
        h_pred = frame_predictor(torch.cat([h, z_t], 1))
        x_pred = decoder([h_pred, skip])
        recon_loss += mse_criterion(x_pred, utils.torch_rgb_img_to_gray(x[i]))
        # penalize prior for being far from posterior
        residual_loss += opt.gamma * torch.mean(torch.square(z_t.detach() - z_t_hat))
        previous = target

    total = recon_loss + residual_loss
    total.backward()

    for optimizer in (frame_predictor_optimizer, posterior_optimizer,
                      prior_optimizer, encoder_optimizer, decoder_optimizer):
        optimizer.step()

    return (recon_loss.data.cpu().numpy() / n_steps,
            residual_loss.data.cpu().numpy() / n_steps)
# Luma weights for the R, G and B channels. Placed on the first GPU when one
# is present and on the CPU otherwise, so importing this module no longer
# requires CUDA.
RGB_weights = torch.tensor(
    [0.299, 0.587, 0.114],
    dtype=torch.float32,
    device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'),
    requires_grad=False,
)


def torch_rgb_img_to_gray(tensor):
    """Collapse the channel axis of an image batch into one weighted channel.

    Args:
        tensor: batch laid out as B x C x H x W. A batch that already has a
            single channel is returned unchanged (same object).

    Returns:
        A B x 1 x H x W tensor on the same device and with the same dtype as
        the input.
    """
    # in: Bx3xHxW out: Bx1xHxW
    if tensor.shape[1] == 1:
        return tensor
    # Follow the input's device and dtype; this is a no-op when they already
    # match the module-level weights.
    weights = RGB_weights.to(device=tensor.device, dtype=tensor.dtype)
    channels_last = torch.transpose(tensor, 1, 3)  # B x W x H x 3
    gray = torch.unsqueeze(torch.matmul(channels_last, weights), -1)  # B x W x H x 1
    return torch.transpose(gray, 3, 1)  # B x 1 x H x W
def load_dataset(opt, sequential=None, implausible=None):
    """Build the train and test datasets selected by ``opt.dataset``.

    Args:
        opt: parsed options. ``dataset``, ``data_root``, ``n_past``,
            ``n_future``, ``n_eval`` and ``image_width`` are read for the
            supported datasets, plus dataset-specific fields (``num_digits``
            for 'smmnist'; ``mcs_task`` and ``channels`` for the MCS variants).
        sequential: forwarded unchanged to the MCS dataset.
        implausible: forwarded unchanged to the MCS dataset.

    Returns:
        ``(train_data, test_data)``; both are ``None`` when ``opt.dataset`` is
        not one of 'smmnist', 'bair', 'kth', 'mcs' or 'mcs_test'.
    """
    train_data = None
    test_data = None
    train_len = opt.n_past + opt.n_future

    if opt.dataset == 'smmnist':
        from data.moving_mnist import MovingMNIST
        common = dict(data_root=opt.data_root,
                      image_size=opt.image_width,
                      deterministic=False,
                      num_digits=opt.num_digits)
        train_data = MovingMNIST(train=True, seq_len=train_len, **common)
        test_data = MovingMNIST(train=False, seq_len=opt.n_eval, **common)
    elif opt.dataset == 'bair':
        from data.bair import RobotPush
        common = dict(data_root=opt.data_root, image_size=opt.image_width)
        train_data = RobotPush(train=True, seq_len=train_len, **common)
        test_data = RobotPush(train=False, seq_len=opt.n_eval, **common)
    elif opt.dataset == 'kth':
        from data.kth import KTH
        common = dict(data_root=opt.data_root, image_size=opt.image_width)
        train_data = KTH(train=True, seq_len=train_len, **common)
        test_data = KTH(train=False, seq_len=opt.n_eval, **common)
    elif opt.dataset in ('mcs', 'mcs_test'):
        from data.mcs import MCS
        common = dict(
            data_root=opt.data_root,
            image_size=opt.image_width,
            task=opt.mcs_task,
            sequential=sequential,
            implausible=implausible,
            im_channels=opt.channels,
            reduce_static_frames=getattr(opt, 'reduce_static_frames', False),
            is_shape_constancy=getattr(opt, 'is_shape_constancy', False),
        )
        # Not every script that calls this function defines these options, so
        # forward only the ones that are present instead of failing on the
        # attribute lookup.
        # NOTE(review): assumes MCS supplies defaults for the omitted ones - confirm.
        for key in ('use_edge_kernels', 'start_min', 'start_max', 'sequence_stride'):
            if hasattr(opt, key):
                common[key] = getattr(opt, key)
        if opt.dataset == 'mcs_test':
            common.update(test_set=True, labels=True)
        train_data = MCS(train=True, seq_len=train_len, **common)
        test_data = MCS(train=False, seq_len=opt.n_eval, **common)

    return train_data, test_data
padding].copy_(image) return result + def save_np_img(fname, x): if x.shape[0] == 1: x = np.tile(x, (3, 1, 1)) - img = scipy.misc.toimage(x, - high=255*x.max(), - channel_axis=0) + img = torch_tensor_to_img(x) img.save(fname) + def make_image(tensor): tensor = tensor.cpu().clamp(0, 1) if tensor.size(0) == 1: tensor = tensor.expand(3, tensor.size(1), tensor.size(2)) # pdb.set_trace() - return scipy.misc.toimage(tensor.numpy(), - high=255*tensor.max(), - channel_axis=0) + return torch_tensor_to_img(tensor) + def draw_text_tensor(tensor, text): np_x = tensor.transpose(0, 1).transpose(1, 2).data.cpu().numpy() - pil = Image.fromarray(np.uint8(np_x*255)) + pil = Image.fromarray(np.uint8(np_x * 255)) draw = ImageDraw.Draw(pil) - draw.text((4, 64), text, (0,0,0)) + draw.text((4, 64), text, (0, 0, 0)) img = np.asarray(pil) return Variable(torch.Tensor(img / 255.)).transpose(1, 2).transpose(0, 1) + def save_gif(filename, inputs, duration=0.25): images = [] for tensor in inputs: img = image_tensor(tensor, padding=0) img = img.cpu() - img = img.transpose(0,1).transpose(1,2).clamp(0,1) - images.append(img.numpy()) + img = img.transpose(0, 1).transpose(1, 2).clamp(0, 1) + images.append((img.numpy() * 255).astype(np.uint8)) imageio.mimsave(filename, images, duration=duration) + def save_gif_with_text(filename, inputs, text, duration=0.25): images = [] for tensor, text in zip(inputs, text): img = image_tensor([draw_text_tensor(ti, texti) for ti, texti in zip(tensor, text)], padding=0) img = img.cpu() - img = img.transpose(0,1).transpose(1,2).clamp(0,1).numpy() + img = img.transpose(0, 1).transpose(1, 2).clamp(0, 1).numpy() images.append(img) imageio.mimsave(filename, images, duration=duration) + def save_image(filename, tensor): img = make_image(tensor) img.save(filename) + def save_tensors_image(filename, inputs, padding=1): images = image_tensor(inputs, padding) return save_image(filename, images) + def prod(l): return functools.reduce(lambda x, y: x * y, l) + def 
batch_flatten(x): return x.resize(x.size(0), prod(x.size()[1:])) + def clear_progressbar(): - # moves up 3 lines - print("\033[2A") - # deletes the whole line, regardless of character position - print("\033[2K") - # moves up two lines again - print("\033[2A") + # # moves up 3 lines + # print("\033[2A") + # # deletes the whole line, regardless of character position + # print("\033[2K") + # # moves up two lines again + # print("\033[2A") + print('\r') + print(' ' * 80) + print('\r') def mse_metric(x1, x2): err = np.sum((x1 - x2) ** 2) err /= float(x1.shape[0] * x1.shape[1] * x1.shape[2]) return err + def eval_seq(gt, pred): T = len(gt) bs = gt[0].shape[0] @@ -222,6 +340,7 @@ def eval_seq(gt, pred): return mse, ssim, psnr + # ssim function used in Babaeizadeh et al. (2017), Fin et al. (2016), etc. def finn_eval_seq(gt, pred): T = len(gt) @@ -246,21 +365,23 @@ def finn_eval_seq(gt, pred): def finn_psnr(x, y): - mse = ((x - y)**2).mean() - return 10*np.log(1/mse)/np.log(10) + mse = ((x - y) ** 2).mean() + return 10 * np.log(1 / mse) / np.log(10) def gaussian2(size, sigma): - A = 1/(2.0*np.pi*sigma**2) - x, y = np.mgrid[-size//2 + 1:size//2 + 1, -size//2 + 1:size//2 + 1] - g = A*np.exp(-((x**2/(2.0*sigma**2))+(y**2/(2.0*sigma**2)))) + A = 1 / (2.0 * np.pi * sigma ** 2) + x, y = np.mgrid[-size // 2 + 1:size // 2 + 1, -size // 2 + 1:size // 2 + 1] + g = A * np.exp(-((x ** 2 / (2.0 * sigma ** 2)) + (y ** 2 / (2.0 * sigma ** 2)))) return g + def fspecial_gauss(size, sigma): - x, y = np.mgrid[-size//2 + 1:size//2 + 1, -size//2 + 1:size//2 + 1] - g = np.exp(-((x**2 + y**2)/(2.0*sigma**2))) - return g/g.sum() - + x, y = np.mgrid[-size // 2 + 1:size // 2 + 1, -size // 2 + 1:size // 2 + 1] + g = np.exp(-((x ** 2 + y ** 2) / (2.0 * sigma ** 2))) + return g / g.sum() + + def finn_ssim(img1, img2, cs_map=False): img1 = img1.astype(np.float64) img2 = img2.astype(np.float64) @@ -269,24 +390,24 @@ def finn_ssim(img1, img2, cs_map=False): window = fspecial_gauss(size, sigma) K1 = 0.01 
K2 = 0.03 - L = 1 #bitdepth of image - C1 = (K1*L)**2 - C2 = (K2*L)**2 + L = 1 # bitdepth of image + C1 = (K1 * L) ** 2 + C2 = (K2 * L) ** 2 mu1 = signal.fftconvolve(img1, window, mode='valid') mu2 = signal.fftconvolve(img2, window, mode='valid') - mu1_sq = mu1*mu1 - mu2_sq = mu2*mu2 - mu1_mu2 = mu1*mu2 - sigma1_sq = signal.fftconvolve(img1*img1, window, mode='valid') - mu1_sq - sigma2_sq = signal.fftconvolve(img2*img2, window, mode='valid') - mu2_sq - sigma12 = signal.fftconvolve(img1*img2, window, mode='valid') - mu1_mu2 + mu1_sq = mu1 * mu1 + mu2_sq = mu2 * mu2 + mu1_mu2 = mu1 * mu2 + sigma1_sq = signal.fftconvolve(img1 * img1, window, mode='valid') - mu1_sq + sigma2_sq = signal.fftconvolve(img2 * img2, window, mode='valid') - mu2_sq + sigma12 = signal.fftconvolve(img1 * img2, window, mode='valid') - mu1_mu2 if cs_map: - return (((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)* - (sigma1_sq + sigma2_sq + C2)), - (2.0*sigma12 + C2)/(sigma1_sq + sigma2_sq + C2)) + return (((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * + (sigma1_sq + sigma2_sq + C2)), + (2.0 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2)) else: - return ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)* - (sigma1_sq + sigma2_sq + C2)) + return ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * + (sigma1_sq + sigma2_sq + C2)) def init_weights(m): @@ -297,4 +418,3 @@ def init_weights(m): elif classname.find('BatchNorm') != -1: m.weight.data.normal_(1.0, 0.02) m.bias.data.fill_(0) -