From 7200ff2aa7b3868c0ec26567b47b3ae63f337dd8 Mon Sep 17 00:00:00 2001 From: Fred Lu Date: Thu, 1 Jul 2021 14:39:40 -0400 Subject: [PATCH 01/26] Updated imports to match new skimage module names --- utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/utils.py b/utils.py index 1d74bbc..21f4014 100755 --- a/utils.py +++ b/utils.py @@ -10,8 +10,8 @@ matplotlib.use('agg') import matplotlib.pyplot as plt import functools -from skimage.measure import compare_psnr as psnr_metric -from skimage.measure import compare_ssim as ssim_metric +from skimage.metrics import peak_signal_noise_ratio as psnr_metric +from skimage.metrics import structural_similarity as ssim_metric from scipy import signal from scipy import ndimage from PIL import Image, ImageDraw From 9c0ee4ad2373d201335a29574ba4dbc7b681a961 Mon Sep 17 00:00:00 2001 From: Fred Lu Date: Tue, 6 Jul 2021 22:01:15 -0400 Subject: [PATCH 02/26] MCS dataset --- data/convert_mcs.py | 54 ++++++ data/mcs.py | 66 ++++++++ train_svg_nonstochastic.py | 333 +++++++++++++++++++++++++++++++++++++ utils.py | 44 ++++- 4 files changed, 489 insertions(+), 8 deletions(-) create mode 100644 data/convert_mcs.py create mode 100644 data/mcs.py create mode 100644 train_svg_nonstochastic.py diff --git a/data/convert_mcs.py b/data/convert_mcs.py new file mode 100644 index 0000000..562ecdb --- /dev/null +++ b/data/convert_mcs.py @@ -0,0 +1,54 @@ +import os +import glob +import argparse +from os import path +import subprocess +from tqdm import tqdm +import threading + +parser = argparse.ArgumentParser() +parser.add_argument("-d", "--datadir", type=str, help="path to the mcs video dataset folder", + default="data/mcs_videos_1000") +parser.add_argument("-s", "--imsize", type=int, help="width of converted (square) images. default 64.", default=64) +args = parser.parse_args() + +DATA_ROOT = args.datadir +IMSIZE = args.imsize + +if not path.exists(DATA_ROOT): + print(f'directory "{DATA_ROOT}" does not exist! Check arguments') +elif not path.exists(path.join(DATA_ROOT, 'raw')): + print("Training videos must be in [datadir]/raw/(task)/*.mp4, where task is the task to which " + "the video belongs") + +def mp4_to_png_worker(path_to_task): + task_name = path.basename(path_to_task[:-1]) # remove the '/' at the end of path to a folder + vids = glob.glob(path.join(path_to_task, '*.mp4')) + vids_tqdm = tqdm(sorted(vids)) + vids_tqdm.set_description(f'Task {task_name}') + for vid in vids_tqdm: + sample_name = path.basename(vid) + sample_name = sample_name[:sample_name.rfind('.')] # get rid of extension name + out_folder = path.join(DATA_ROOT, 'processed', task_name, sample_name) + os.makedirs(out_folder, exist_ok=True) + ffmpeg_exe = 'ffmpeg' + ffmpeg_args = f'-i "{vid}" -s {IMSIZE}x{IMSIZE} "{path.join(out_folder, sample_name + "_%04d.png")}"' + + try: + subprocess.check_call(ffmpeg_exe + ' ' + ffmpeg_args, shell=True, stdout=subprocess.DEVNULL + , stderr=subprocess.DEVNULL) + except subprocess.CalledProcessError: + print('convert_mcs.py: ffmpeg execution failed. Check that you have installed the ffmpeg executable' + 'by typing "ffmpeg" in your terminal.') + quit() + + +tasks = glob.glob(path.join(DATA_ROOT, 'raw', '*/')) +threads = [] +for task in sorted(tasks): + t = threading.Thread(target=mp4_to_png_worker, args=[task], daemon=True) + t.start() + threads.append(t) + +for thread in threads: + thread.join() \ No newline at end of file diff --git a/data/mcs.py b/data/mcs.py new file mode 100644 index 0000000..a62afa4 --- /dev/null +++ b/data/mcs.py @@ -0,0 +1,66 @@ +import logging +import random +import os +import numpy as np +from glob import glob +import torch +import scipy.misc +import imageio +from os import path + +class MCS(object): + + def __init__(self, train, data_root, seq_len = 20, image_size=64, task='ALL'): + self.data_root = '%s/mcs_videos_1000/processed/' % data_root + if not os.path.exists(self.data_root): + raise os.error('data/mcs.py: Data directory not found!') + self.seq_len = seq_len + self.image_size = image_size + + # print('mcs.py: found tasks ', self.tasks) + self.video_folders = {} + if task == 'ALL': + self.tasks = [os.path.basename(folder) for folder in glob(path.join(self.data_root, '*'))] + else: + self.tasks = [task] + + for task in self.tasks: + self.video_folders[task] = [path.basename(folder) for folder in glob(path.join(self.data_root, task, '*'))] + + self.seed_set = False + + def get_sequence(self): + task = random.choice(self.tasks) + vid = random.choice(self.video_folders[task]) + num_frames = len(next(os.walk(path.join(self.data_root, task, vid)))[2]) # dir is your directory path as string + + frame_path = path.join(self.data_root, task, vid, vid + '_') + + start = random.randint(0, num_frames-self.seq_len) + seq = [] + for i in range(start, start+self.seq_len): + # i is 0-indexed so we need to add 1 to i + fname = frame_path + f'{i + 1:04d}.png' + im = imageio.imread(fname)/255. + gray = lambda rgb: np.dot(rgb[..., :3], [0.299, 0.587, 0.114]) + im = gray(im)[..., np.newaxis] + seq.append(im) + return np.array(seq) + + def __getitem__(self, index): + if not self.seed_set: + self.seed_set = True + random.seed(index) + np.random.seed(index) + #torch.manual_seed(index) + return torch.from_numpy(self.get_sequence()) + + def __len__(self): + return 5*1000*200 # approximate + + +# if __name__ == '__main__': +# m = MCS(True, '/home/lol/Hub/svg/data/') +# s = m.__getitem__(0).cuda() +# print(s.device) + diff --git a/train_svg_nonstochastic.py b/train_svg_nonstochastic.py new file mode 100644 index 0000000..ce2adda --- /dev/null +++ b/train_svg_nonstochastic.py @@ -0,0 +1,333 @@ +import glob + +import torch +import torch.optim as optim +import torch.nn as nn +import argparse +import os +import random +from torch.autograd import Variable +from torch.utils.data import DataLoader +import utils +import itertools +import progressbar +import numpy as np + +parser = argparse.ArgumentParser() +parser.add_argument('--lr', default=0.004, type=float, help='learning rate') +parser.add_argument('--beta1', default=0.9, type=float, help='momentum term for adam') +parser.add_argument('--batch_size', default=168, type=int, help='batch size') +parser.add_argument('--log_dir', default='logs/nonstochastic', help='base directory to save logs') +parser.add_argument('--model_dir', default='', help='base directory to save logs') +parser.add_argument('--name', default='', help='identifier for directory') +parser.add_argument('--data_root', default='data', help='root directory for data') +parser.add_argument('--optimizer', default='adam', help='optimizer to train with') +parser.add_argument('--niter', type=int, default=300, help='number of epochs to train for') +parser.add_argument('--seed', default=1, type=int, help='manual seed') +parser.add_argument('--epoch_size', type=int, default=400, help='epoch size') +parser.add_argument('--image_width', type=int, default=64, help='the height / width of the input image to network') +parser.add_argument('--channels', default=1, type=int) +parser.add_argument('--dataset', default='mcs', help='dataset to train with') +parser.add_argument('--mcs_task', default='SpatioTemporalContinuityTraining4', help='mcs task') +parser.add_argument('--n_past', type=int, default=5, help='number of frames to condition on') +parser.add_argument('--n_future', type=int, default=10, help='number of frames to predict') +parser.add_argument('--n_eval', type=int, default=30, help='number of frames to predict at eval time') +parser.add_argument('--rnn_size', type=int, default=256, help='dimensionality of hidden layer') +parser.add_argument('--posterior_rnn_layers', type=int, default=1, help='number of layers') +parser.add_argument('--predictor_rnn_layers', type=int, default=2, help='number of layers') +parser.add_argument('--z_dim', type=int, default=10, help='dimensionality of z_t') +parser.add_argument('--g_dim', type=int, default=128, help='dimensionality of encoder output vector and decoder input vector') +parser.add_argument('--beta', type=float, default=0.0001, help='weighting on KL to prior') +parser.add_argument('--model', default='dcgan', help='model type (dcgan | vgg)') +parser.add_argument('--data_threads', type=int, default=8, help='number of data loading threads') +parser.add_argument('--num_digits', type=int, default=2, help='number of digits for moving mnist') +parser.add_argument('--last_frame_skip', action='store_true', help='if true, skip connections go between frame t and frame t+t rather than last ground truth frame') + + +opt = parser.parse_args() +saved_model = None +if opt.model_dir != '': + models = glob.glob(f'{opt.model_dir}/model_*.pth') + latest_model = sorted(models, key=lambda s: int(s[s.rfind('_e') + 2: s.rfind('.pth')]), reverse=True)[0] + print('Loading model ', latest_model) + saved_model = torch.load(latest_model) + optimizer = opt.optimizer + model_dir = opt.model_dir + niter = opt.niter + opt = saved_model['opt'] + opt.niter = niter # update number of epochs to train for + opt.optimizer = optimizer + opt.model_dir = model_dir + opt.log_dir = '%s/continued' % opt.log_dir +else: + name = 'model=%s%dx%d-rnn_size=%d-predictor-posterior-rnn_layers=%d-%d-n_past=%d-n_future=%d-lr=%.4f-g_dim=%d-z_dim=%d-last_frame_skip=%d-beta=%.7f%s' % (opt.model, opt.image_width, opt.image_width, opt.rnn_size, opt.predictor_rnn_layers, opt.posterior_rnn_layers, opt.n_past, opt.n_future, opt.lr, opt.g_dim, opt.z_dim, opt.last_frame_skip, opt.beta, opt.name) + if opt.dataset == 'smmnist': + opt.log_dir = '%s/%s-%d/%s' % (opt.log_dir, opt.dataset, opt.num_digits, name) + elif opt.dataset == 'mcs': + opt.log_dir = '%s/%s/%s/%s' % (opt.log_dir, opt.dataset, opt.mcs_task, name) + else: + opt.log_dir = '%s/%s/%s' % (opt.log_dir, opt.dataset, name) + +os.makedirs('%s/gen/' % opt.log_dir, exist_ok=True) +os.makedirs('%s/plots/' % opt.log_dir, exist_ok=True) + +print("Random Seed: ", opt.seed) +random.seed(opt.seed) +torch.manual_seed(opt.seed) +torch.cuda.manual_seed_all(opt.seed) +dtype = torch.cuda.FloatTensor + + +# ---------------- load the models ---------------- + +print(opt) + +# ---------------- optimizers ---------------- +if opt.optimizer == 'adam': + opt.optimizer = optim.Adam +elif opt.optimizer == 'rmsprop': + opt.optimizer = optim.RMSprop +elif opt.optimizer == 'sgd': + opt.optimizer = optim.SGD +else: + raise ValueError('Unknown optimizer: %s' % opt.optimizer) + +import models.lstm as lstm_models +if opt.model_dir != '': + frame_predictor = saved_model['frame_predictor'] +else: + frame_predictor = lstm_models.lstm(opt.g_dim, opt.g_dim, opt.rnn_size, opt.predictor_rnn_layers, opt.batch_size) + frame_predictor.apply(utils.init_weights) + +if opt.model == 'dcgan': + if opt.image_width == 64: + import models.dcgan_64 as model + elif opt.image_width == 128: + import models.dcgan_128 as model +elif opt.model == 'vgg': + if opt.image_width == 64: + import models.vgg_64 as model + elif opt.image_width == 128: + import models.vgg_128 as model +else: + raise ValueError('Unknown model: %s' % opt.model) + +if opt.model_dir != '': + decoder = saved_model['decoder'] + encoder = saved_model['encoder'] +else: + encoder = model.encoder(opt.g_dim, opt.channels) + decoder = model.decoder(opt.g_dim, opt.channels) + encoder.apply(utils.init_weights) + decoder.apply(utils.init_weights) + +frame_predictor_optimizer = opt.optimizer(frame_predictor.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) +encoder_optimizer = opt.optimizer(encoder.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) +decoder_optimizer = opt.optimizer(decoder.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + +# --------- loss functions ------------------------------------ +mse_criterion = nn.MSELoss() +def kl_criterion(mu, logvar): + # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) + KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) + KLD /= opt.batch_size + return KLD + + +# --------- transfer to gpu ------------------------------------ +frame_predictor.cuda() +encoder.cuda() +decoder.cuda() +mse_criterion.cuda() + +# --------- load a dataset ------------------------------------ +train_data, test_data = utils.load_dataset(opt) + +train_loader = DataLoader(train_data, + num_workers=opt.data_threads, + batch_size=opt.batch_size, + shuffle=True, + drop_last=True, + pin_memory=True) +test_loader = DataLoader(test_data, + num_workers=opt.data_threads, + batch_size=opt.batch_size, + shuffle=True, + drop_last=True, + pin_memory=True) + +def get_training_batch(): + while True: + for sequence in train_loader: + batch = utils.normalize_data(opt, dtype, sequence) + yield batch +training_batch_generator = get_training_batch() + +def get_testing_batch(): + while True: + for sequence in test_loader: + batch = utils.normalize_data(opt, dtype, sequence) + yield batch +testing_batch_generator = get_testing_batch() + +# --------- plotting funtions ------------------------------------ +def plot(x, epoch): + nsample = 1 + gen_seq = [[] for _ in range(nsample)] + gt_seq = [x[i] for i in range(len(x))] + + h_seq = [encoder(x[i]) for i in range(opt.n_past)] + for s in range(nsample): + frame_predictor.hidden = frame_predictor.init_hidden() + gen_seq[s].append(x[0]) + x_in = x[0] + for i in range(1, opt.n_eval): + if opt.last_frame_skip or i < opt.n_past: + h, skip = h_seq[i-1] + h = h.detach() + elif i < opt.n_past: + h, _ = h_seq[i-1] + h = h.detach() + if i < opt.n_past: + frame_predictor(h) + x_in = x[i] + gen_seq[s].append(x_in) + else: + h = frame_predictor(h).detach() + x_in = decoder([h, skip]).detach() + gen_seq[s].append(x_in) + + to_plot = [] + gifs = [ [] for t in range(opt.n_eval) ] + nrow = min(opt.batch_size, 25) + for i in range(nrow): + # ground truth sequence + row = [] + for t in range(opt.n_eval): + row.append(gt_seq[t][i]) + to_plot.append(row) + + for s in range(nsample): + row = [] + for t in range(opt.n_eval): + row.append(gen_seq[s][t][i]) + to_plot.append(row) + for t in range(opt.n_eval): + row = [] + row.append(gt_seq[t][i]) + for s in range(nsample): + row.append(gen_seq[s][t][i]) + gifs[t].append(row) + + fname = '%s/gen/sample_%d.png' % (opt.log_dir, epoch) + utils.save_tensors_image(fname, to_plot) + + fname = '%s/gen/sample_%d.gif' % (opt.log_dir, epoch) + utils.save_gif(fname, gifs) + + +def plot_rec(x, epoch): + frame_predictor.hidden = frame_predictor.init_hidden() + gen_seq = [] + gen_seq.append(x[0]) + x_in = x[0] + h_seq = [encoder(x[i]) for i in range(opt.n_past+opt.n_future)] + for i in range(1, opt.n_past+opt.n_future): + h_target = h_seq[i][0].detach() + if opt.last_frame_skip or i < opt.n_past: + h, skip = h_seq[i-1] + else: + h, _ = h_seq[i-1] + h = h.detach() + if i < opt.n_past: + frame_predictor(h) + gen_seq.append(x[i]) + else: + h_pred = frame_predictor(h).detach() + x_pred = decoder([h_pred, skip]).detach() + gen_seq.append(x_pred) + + to_plot = [] + nrow = min(opt.batch_size, 25) + for i in range(nrow): + row = [] + for t in range(opt.n_past+opt.n_future): + row.append(gen_seq[t][i]) + to_plot.append(row) + fname = '%s/gen/rec_%d.png' % (opt.log_dir, epoch) + utils.save_tensors_image(fname, to_plot) + + +# --------- training funtions ------------------------------------ +def train(x): + frame_predictor.zero_grad() + encoder.zero_grad() + decoder.zero_grad() + + # initialize the hidden state. + frame_predictor.hidden = frame_predictor.init_hidden() + + h_seq = [encoder(x[i]) for i in range(opt.n_past+opt.n_future)] + mse = 0 + kld = 0 + for i in range(1, opt.n_past+opt.n_future): + h_target = h_seq[i][0] + if opt.last_frame_skip or i < opt.n_past: + h, skip = h_seq[i-1] + else: + h = h_seq[i-1][0] + h_pred = frame_predictor(torch.cat([h], 1)) + x_pred = decoder([h_pred, skip]) + mse += mse_criterion(x_pred, x[i]) + + loss = mse + loss.backward() + + frame_predictor_optimizer.step() + encoder_optimizer.step() + decoder_optimizer.step() + + return mse.data.cpu().numpy()/(opt.n_past+opt.n_future) + +# --------- training loop ------------------------------------ +for epoch in range(opt.niter): + frame_predictor.train() + encoder.train() + decoder.train() + epoch_mse = 0 + epoch_kld = 0 + progress = progressbar.ProgressBar(max_value=opt.epoch_size).start() + for i in range(opt.epoch_size): + progress.update(i+1) + x = next(training_batch_generator) + + # train frame_predictor + mse = train(x) + epoch_mse += mse + epoch_kld += 0 + + + progress.finish() + utils.clear_progressbar() + + print('[%02d] mse loss: %.5f | kld loss: %.5f (%d)' % (epoch, epoch_mse/opt.epoch_size, epoch_kld/opt.epoch_size, epoch*opt.epoch_size*opt.batch_size)) + + # plot some stuff + frame_predictor.eval() + encoder.eval() + decoder.eval() + x = next(testing_batch_generator) + plot(x, epoch) + plot_rec(x, epoch) + + # save the model + torch.save({ + 'encoder': encoder, + 'decoder': decoder, + 'frame_predictor': frame_predictor, + 'opt': opt}, + '%s/model_e%02d.pth' % (opt.log_dir, epoch)) + if epoch % 10 == 0: + print('log dir: %s' % opt.log_dir) + + diff --git a/utils.py b/utils.py index 21f4014..d684fe7 100755 --- a/utils.py +++ b/utils.py @@ -24,7 +24,25 @@ hostname = socket.gethostname() + +def torch_tensor_to_img(tensor): + image_array = tensor.numpy() + image_array -= np.min(image_array) + image_array = np.minimum(image_array, 1.0) + print(image_array.shape) + image_array = np.transpose(image_array, (1, 2, 0)) + img = None + if image_array.shape[2] == 3: # 3-channel image + # array is grayscale, but we convert to RGB + img = Image.fromarray((image_array * 255).astype('uint8'), mode='RGB') + else: + img = Image.fromarray((image_array * 255).astype('uint8'), mode='L').convert('RGB') + return img + + def load_dataset(opt): + train_data = None + test_data = None if opt.dataset == 'smmnist': from data.moving_mnist import MovingMNIST train_data = MovingMNIST( @@ -65,6 +83,20 @@ def load_dataset(opt): data_root=opt.data_root, seq_len=opt.n_eval, image_size=opt.image_width) + elif opt.dataset == 'mcs': + from data.mcs import MCS + train_data = MCS( + train=True, + data_root=opt.data_root, + seq_len=opt.n_past+opt.n_future, + image_size=opt.image_width, + task=opt.mcs_task) + test_data = MCS( + train=False, + data_root=opt.data_root, + seq_len=opt.n_eval, + image_size=opt.image_width, + task=opt.mcs_task) return train_data, test_data @@ -72,7 +104,7 @@ def sequence_input(seq, dtype): return [Variable(x.type(dtype)) for x in seq] def normalize_data(opt, dtype, sequence): - if opt.dataset == 'smmnist' or opt.dataset == 'kth' or opt.dataset == 'bair' : + if opt.dataset == 'smmnist' or opt.dataset == 'kth' or opt.dataset == 'bair' or opt.dataset == 'mcs': sequence.transpose_(0, 1) sequence.transpose_(3, 4).transpose_(2, 3) else: @@ -138,9 +170,7 @@ def image_tensor(inputs, padding=1): def save_np_img(fname, x): if x.shape[0] == 1: x = np.tile(x, (3, 1, 1)) - img = scipy.misc.toimage(x, - high=255*x.max(), - channel_axis=0) + img = torch_tensor_to_img(x) img.save(fname) def make_image(tensor): @@ -148,9 +178,7 @@ def make_image(tensor): if tensor.size(0) == 1: tensor = tensor.expand(3, tensor.size(1), tensor.size(2)) # pdb.set_trace() - return scipy.misc.toimage(tensor.numpy(), - high=255*tensor.max(), - channel_axis=0) + return torch_tensor_to_img(tensor) def draw_text_tensor(tensor, text): np_x = tensor.transpose(0, 1).transpose(1, 2).data.cpu().numpy() @@ -166,7 +194,7 @@ def save_gif(filename, inputs, duration=0.25): img = image_tensor(tensor, padding=0) img = img.cpu() img = img.transpose(0,1).transpose(1,2).clamp(0,1) - images.append(img.numpy()) + images.append((img.numpy()*255).astype(np.uint8)) imageio.mimsave(filename, images, duration=duration) def save_gif_with_text(filename, inputs, text, duration=0.25): From 04297ca59ba534cfd7f47d3c513ca0b3d5455b4e Mon Sep 17 00:00:00 2001 From: Fred Lu Date: Fri, 9 Jul 2021 15:31:39 -0400 Subject: [PATCH 03/26] Testing predicting implausibility with hidden state differences in consecutive frames --- data/mcs.py | 48 ++++-- do_mcs_stats.py | 369 ++++++++++++++++++++++++++++++++++++++++++++ h_residual_mean.png | Bin 0 -> 22194 bytes utils.py | 8 +- 4 files changed, 407 insertions(+), 18 deletions(-) create mode 100644 do_mcs_stats.py create mode 100644 h_residual_mean.png diff --git a/data/mcs.py b/data/mcs.py index a62afa4..9414242 100644 --- a/data/mcs.py +++ b/data/mcs.py @@ -10,7 +10,7 @@ class MCS(object): - def __init__(self, train, data_root, seq_len = 20, image_size=64, task='ALL'): + def __init__(self, train, data_root, seq_len = 20, image_size=64, task='ALL', sequential=None): self.data_root = '%s/mcs_videos_1000/processed/' % data_root if not os.path.exists(self.data_root): raise os.error('data/mcs.py: Data directory not found!') @@ -18,30 +18,44 @@ def __init__(self, train, data_root, seq_len = 20, image_size=64, task='ALL'): self.image_size = image_size # print('mcs.py: found tasks ', self.tasks) - self.video_folders = {} + self.video_folder = {} + self.len_video_folder = {} if task == 'ALL': - self.tasks = [os.path.basename(folder) for folder in glob(path.join(self.data_root, '*'))] + self.tasks = [os.path.basename(folder) for folder in sorted(glob(path.join(self.data_root, '*')))] else: self.tasks = [task] for task in self.tasks: - self.video_folders[task] = [path.basename(folder) for folder in glob(path.join(self.data_root, task, '*'))] + self.video_folder[task] = [path.basename(folder) for folder in sorted(glob(path.join(self.data_root, task, '*')))] + self.len_video_folder[task] = len(self.video_folder[task]) self.seed_set = False + self.sequential = sequential # if set to true, return videos in sequence - def get_sequence(self): - task = random.choice(self.tasks) - vid = random.choice(self.video_folders[task]) - num_frames = len(next(os.walk(path.join(self.data_root, task, vid)))[2]) # dir is your directory path as string - - frame_path = path.join(self.data_root, task, vid, vid + '_') - - start = random.randint(0, num_frames-self.seq_len) + def get_sequence(self, idx=None): + if not self.sequential: + task = random.choice(self.tasks) + vid = random.choice(self.video_folder[task]) + num_frames = len(next(os.walk(path.join(self.data_root, task, vid)))[2]) + frame_path = path.join(self.data_root, task, vid, vid + '_') + else: + assert len(self.tasks) == 1 and idx is not None + # index = idx % self.len_video_folder[task] # loop over the videos under a task + task = self.tasks[0] + if idx >= self.len_video_folder[task]: + return None # we've run out of videos + vid = self.video_folder[task][idx] + num_frames = len(next(os.walk(path.join(self.data_root, task, vid)))[2]) + frame_path = path.join(self.data_root, task, vid, vid + '_') + if num_frames - self.seq_len < 0: + return None + start = random.randint(0, num_frames - self.seq_len) seq = [] - for i in range(start, start+self.seq_len): + # choose a random subsequence of frames in the selected video + for i in range(start, start + self.seq_len): # i is 0-indexed so we need to add 1 to i fname = frame_path + f'{i + 1:04d}.png' - im = imageio.imread(fname)/255. + im = imageio.imread(fname) / 255. gray = lambda rgb: np.dot(rgb[..., :3], [0.299, 0.587, 0.114]) im = gray(im)[..., np.newaxis] seq.append(im) @@ -53,7 +67,11 @@ def __getitem__(self, index): random.seed(index) np.random.seed(index) #torch.manual_seed(index) - return torch.from_numpy(self.get_sequence()) + seq = self.get_sequence(index) + if seq is not None: + return torch.from_numpy(seq) + else: + return None def __len__(self): return 5*1000*200 # approximate diff --git a/do_mcs_stats.py b/do_mcs_stats.py new file mode 100644 index 0000000..9e64c56 --- /dev/null +++ b/do_mcs_stats.py @@ -0,0 +1,369 @@ +import glob + +import torch +import torch.optim as optim +import torch.nn as nn +import argparse +import os +import random +from torch.autograd import Variable +from torch.utils.data import DataLoader +import utils +import itertools +import progressbar +import numpy as np +import matplotlib.pyplot as plt + +parser = argparse.ArgumentParser() +parser.add_argument('--lr', default=0.004, type=float, help='learning rate') +parser.add_argument('--beta1', default=0.9, type=float, help='momentum term for adam') +parser.add_argument('--batch_size', default=1, type=int, help='batch size') +parser.add_argument('--log_dir', default='logs/nonstochastic', help='base directory to save logs') +parser.add_argument('--model_dir', default='', help='base directory to save logs') +parser.add_argument('--name', default='', help='identifier for directory') +parser.add_argument('--data_root', default='data', help='root directory for data') +parser.add_argument('--optimizer', default='adam', help='optimizer to train with') +parser.add_argument('--niter', type=int, default=300, help='number of epochs to train for') +parser.add_argument('--seed', default=1, type=int, help='manual seed') +parser.add_argument('--epoch_size', type=int, default=1000, help='epoch size') +parser.add_argument('--image_width', type=int, default=64, help='the height / width of the input image to network') +parser.add_argument('--channels', default=1, type=int) +parser.add_argument('--dataset', default='mcs', help='dataset to train with') +parser.add_argument('--mcs_task', default='SpatioTemporalContinuityTraining4', help='mcs task') +parser.add_argument('--n_past', type=int, default=5, help='number of frames to condition on') +parser.add_argument('--n_future', type=int, default=195, help='number of frames to predict') +parser.add_argument('--n_eval', type=int, default=30, help='number of frames to predict at eval time') +parser.add_argument('--rnn_size', type=int, default=256, help='dimensionality of hidden layer') +parser.add_argument('--posterior_rnn_layers', type=int, default=1, help='number of layers') +parser.add_argument('--predictor_rnn_layers', type=int, default=2, help='number of layers') +parser.add_argument('--z_dim', type=int, default=10, help='dimensionality of z_t') +parser.add_argument('--g_dim', type=int, default=128, + help='dimensionality of encoder output vector and decoder input vector') +parser.add_argument('--beta', type=float, default=0.0001, help='weighting on KL to prior') +parser.add_argument('--model', default='dcgan', help='model type (dcgan | vgg)') +parser.add_argument('--data_threads', type=int, default=8, help='number of data loading threads') +parser.add_argument('--num_digits', type=int, default=2, help='number of digits for moving mnist') +parser.add_argument('--last_frame_skip', action='store_true', + help='if true, skip connections go between frame t and frame t+t rather than last ground truth frame') +opt = parser.parse_args() +BATCH_SIZE = opt.batch_size +saved_model = None +if opt.model_dir != '': + models = glob.glob(f'{opt.model_dir}/model_*.pth') + latest_model = sorted(models, key=lambda s: int(s[s.rfind('_e') + 2: s.rfind('.pth')]), reverse=True)[0] + print('Loading model ', latest_model) + saved_model = torch.load(latest_model) + optimizer = opt.optimizer + model_dir = opt.model_dir + niter = opt.niter + opt = saved_model['opt'] + opt.niter = niter # update number of epochs to train for + opt.optimizer = optimizer + opt.model_dir = model_dir + opt.log_dir = '%s/continued' % opt.log_dir +else: + raise ValueError("Please specify the model to load with the --model_dir argument") + +print("Random Seed: ", opt.seed) +random.seed(opt.seed) +torch.manual_seed(opt.seed) +torch.cuda.manual_seed_all(opt.seed) +dtype = torch.cuda.FloatTensor + +# ---------------- optimizers ---------------- +if opt.optimizer == 'adam': + opt.optimizer = optim.Adam +elif opt.optimizer == 'rmsprop': + opt.optimizer = optim.RMSprop +elif opt.optimizer == 'sgd': + opt.optimizer = optim.SGD +else: + raise ValueError('Unknown optimizer: %s' % opt.optimizer) + +import models.lstm as lstm_models + +if opt.model_dir != '': + frame_predictor = saved_model['frame_predictor'] + frame_predictor.batch_size = BATCH_SIZE +else: + frame_predictor = lstm_models.lstm(opt.g_dim, opt.g_dim, opt.rnn_size, opt.predictor_rnn_layers, opt.batch_size) + frame_predictor.apply(utils.init_weights) + +if opt.model == 'dcgan': + if opt.image_width == 64: + import models.dcgan_64 as model + elif opt.image_width == 128: + import models.dcgan_128 as model +elif opt.model == 'vgg': + if opt.image_width == 64: + import models.vgg_64 as model + elif opt.image_width == 128: + import models.vgg_128 as model +else: + raise ValueError('Unknown model: %s' % opt.model) + +if opt.model_dir != '': + decoder = saved_model['decoder'] + encoder = saved_model['encoder'] +else: + encoder = model.encoder(opt.g_dim, opt.channels) + decoder = model.decoder(opt.g_dim, opt.channels) + encoder.apply(utils.init_weights) + decoder.apply(utils.init_weights) + +# --------- loss functions ------------------------------------ +mse_criterion = nn.MSELoss() + + +def kl_criterion(mu, logvar): + # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) + KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) + KLD /= opt.batch_size + return KLD + + +# --------- transfer to gpu ------------------------------------ +frame_predictor.cuda() +encoder.cuda() +decoder.cuda() +mse_criterion.cuda() + +opt.batch_size = BATCH_SIZE +opt.epoch_size = 1000 +opt.n_future = 195 +print(opt) + +# --------- load a dataset ------------------------------------ +train_data, test_data = utils.load_dataset(opt, sequential=True) + +train_loader = DataLoader(train_data, + num_workers=opt.data_threads, + batch_size=opt.batch_size, + drop_last=True, + pin_memory=True) +test_loader = DataLoader(test_data, + num_workers=opt.data_threads, + batch_size=opt.batch_size, + shuffle=True, + drop_last=True, + pin_memory=True) + + +def get_training_batch(): + while True: + for sequence in train_loader: + batch = utils.normalize_data(opt, dtype, sequence) + yield batch + + +training_batch_generator = get_training_batch() +training_batch_generator_2 = get_training_batch() + + +def get_testing_batch(): + while True: + for sequence in test_loader: + batch = utils.normalize_data(opt, dtype, sequence) + yield batch + + +testing_batch_generator = get_testing_batch() + + +# --------- plotting funtions ------------------------------------ +def plot(x, epoch): + nsample = 1 + gen_seq = [[] for _ in range(nsample)] + gt_seq = [x[i] for i in range(len(x))] + + h_seq = [encoder(x[i]) for i in range(opt.n_past)] + for s in range(nsample): + frame_predictor.hidden = frame_predictor.init_hidden() + gen_seq[s].append(x[0]) + x_in = x[0] + for i in range(1, opt.n_eval): + if opt.last_frame_skip or i < opt.n_past: + h, skip = h_seq[i - 1] + h = h.detach() + elif i < opt.n_past: + h, _ = h_seq[i - 1] + h = h.detach() + if i < opt.n_past: + frame_predictor(h) + x_in = x[i] + gen_seq[s].append(x_in) + else: + h = frame_predictor(h).detach() + x_in = decoder([h, skip]).detach() + gen_seq[s].append(x_in) + + to_plot = [] + gifs = [[] for t in range(opt.n_eval)] + nrow = min(opt.batch_size, 25) + for i in range(nrow): + # ground truth sequence + row = [] + for t in range(opt.n_eval): + row.append(gt_seq[t][i]) + to_plot.append(row) + + for s in range(nsample): + row = [] + for t in range(opt.n_eval): + row.append(gen_seq[s][t][i]) + to_plot.append(row) + for t in range(opt.n_eval): + row = [] + row.append(gt_seq[t][i]) + for s in range(nsample): + row.append(gen_seq[s][t][i]) + gifs[t].append(row) + + fname = '%s/gen/sample_%d.png' % (opt.log_dir, epoch) + utils.save_tensors_image(fname, to_plot) + + fname = '%s/gen/sample_%d.gif' % (opt.log_dir, epoch) + utils.save_gif(fname, gifs) + + +def plot_rec(x, epoch): + frame_predictor.hidden = frame_predictor.init_hidden() + gen_seq = [] + gen_seq.append(x[0]) + x_in = x[0] + h_seq = [encoder(x[i]) for i in range(opt.n_past + opt.n_future)] + for i in range(1, opt.n_past + opt.n_future): + h_target = h_seq[i][0].detach() + if opt.last_frame_skip or i < opt.n_past: + h, skip = h_seq[i - 1] + else: + h, _ = h_seq[i - 1] + h = h.detach() + if i < opt.n_past: + frame_predictor(h) + gen_seq.append(x[i]) + else: + h_pred = frame_predictor(h).detach() + x_pred = decoder([h_pred, skip]).detach() + gen_seq.append(x_pred) + + to_plot = [] + nrow = min(opt.batch_size, 25) + for i in range(nrow): + row = [] + for t in range(opt.n_past + opt.n_future): + row.append(gen_seq[t][i]) + to_plot.append(row) + fname = '%s/gen/rec_%d.png' % (opt.log_dir, epoch) + utils.save_tensors_image(fname, to_plot) + + +def do_stats(): + epoch_size = 1000 // opt.batch_size # we have a total of 1000 videos + + frame_predictor.eval() + encoder.eval() + decoder.eval() + progress = progressbar.ProgressBar(max_value=epoch_size).start() + h_residual_mean = torch.tensor(np.zeros((opt.n_future, 128), dtype=np.float32), requires_grad=False, + device=torch.device('cuda:0')) + i = 0 + for i in range(epoch_size): + progress.update(i + 1) + try: + x = next(training_batch_generator) + except TypeError: + print('got None at i = {}, terminating'.format(i)) + break + h_posterior = [encoder(x[j]) for j in range(opt.n_past + opt.n_future)] + # print(h_posterior[0][0].size()) + last_pred = None + frame_predictor.hidden = frame_predictor.init_hidden() + for j in range(1, opt.n_past + opt.n_future): + if opt.last_frame_skip or j < opt.n_past: + h, skip = h_posterior[j - 1] + else: + h = h_posterior[j - 1][0].detach() + # we predict h_t from h_{t-1} + h_prior_pred = frame_predictor(torch.cat([h], 1)).detach() + + if j >= opt.n_past: + # h_res = h_prior_pred - h_posterior[j][0].detach() # predicted h minus observed h + # h_res = h_prior_pred + # h_res = torch.mean(h_res, dim=0) # average errors at the same time j over the batch + # h_residual_mean[j - opt.n_past] += h_res + this_minus_last_pred = h_prior_pred - last_pred + this_minus_last_pred = torch.mean(this_minus_last_pred, dim=0) + h_residual_mean[j - opt.n_past] += this_minus_last_pred + last_pred = h_prior_pred + h_residual_mean /= epoch_size # get the mean error vector per time + + # restart training dataset + global train_loader + train_loader = DataLoader(train_data, + num_workers=opt.data_threads, + batch_size=opt.batch_size, + drop_last=True, + pin_memory=True) + h_residual_var = torch.tensor(np.zeros(opt.n_future, dtype=np.float32), requires_grad=False, + device=torch.device('cuda:0')) + + for i in range(epoch_size): + progress.update(i + 1) + try: + x = next(training_batch_generator_2) + except TypeError: + print('got None at i = {}, terminating'.format(i)) + break + h_posterior = [encoder(x[j]) for j in range(opt.n_past + opt.n_future)] + # print(h_posterior[0][0].size()) + last_pred = None + frame_predictor.hidden = frame_predictor.init_hidden() + for j in range(1, opt.n_past + opt.n_future): + if opt.last_frame_skip or j < opt.n_past: + h, skip = h_posterior[j - 1] + else: + h = h_posterior[j - 1][0].detach() + # we predict h_t from h_{t-1} + h_prior_pred = frame_predictor(torch.cat([h], 1)).detach() + + if j >= opt.n_past: + # h_res = h_prior_pred - h_posterior[j][0].detach() # predicted h minus observed h + # squared_diff = torch.square(h_res - h_residual_mean[j - opt.n_past]) + # squared_diff = torch.mean(squared_diff, dim=1) # average squared residuals at time j over the dimensions of h + # squared_diff = torch.mean(squared_diff, dim=0) # average squared residuals at time j over the batch + # h_residual_var[j - opt.n_past] += squared_diff + + this_minus_last_pred = h_prior_pred - last_pred + squared_err = torch.square(this_minus_last_pred - h_residual_mean[j - opt.n_past]) + squared_err = torch.mean(squared_err, dim=1) # average errs at time j over the dimensions of h + squared_err = torch.mean(squared_err, dim=0) # average errs at time j over the batch + h_residual_var[j - opt.n_past] += squared_err.detach() + last_pred = h_prior_pred + h_residual_var /= epoch_size # get the mean error vector per time + h_residual_sd = torch.sqrt(h_residual_var) + print('Last i = {}'.format(i)) + print('sd of h residual: ', h_residual_sd) + print('var of h residual: ', h_residual_var) + print('norm(mean of h residual)', torch.norm(h_residual_mean, dim=1)) + + # H_err_cov = torch.tensor(np.zeros((opt.n_future, 128, 128), dtype=np.float32), requires_grad=False, + # device=torch.device('cuda:0')) + + # plot some stuff + h_residual_mean_norm = torch.norm(h_residual_mean, dim=1).cpu() + plt.subplot(2, 1, 1) + plt.xlabel('Time') + plt.tight_layout() + plt.title("Norm of the residual mean") + plt.bar(np.arange(len(h_residual_mean_norm)), h_residual_mean_norm) + + plt.subplot(2, 1, 2) + plt.xlabel('Time') + plt.tight_layout() + plt.title("Average dimensional sqrt(variance) of the residual") + plt.bar(np.arange(len(h_residual_sd.cpu())), h_residual_sd.cpu()) + plt.savefig('h_residual_mean.png') + + +do_stats() diff --git a/h_residual_mean.png b/h_residual_mean.png new file mode 100644 index 0000000000000000000000000000000000000000..c679777cdd4269e8522f0c7c316b244ae33bffb5 GIT binary patch literal 22194 zcmdVC2UJyCwk^6biwehppb`wAA_|g3iAIo&WR##HC?G*{kYGd+L=Yq=ksK9}91MUY zktC841SBg_^6%4Ar>btfx^>_C?X~vXucpTWHfygv*O+7U(R&~Bp@N*m2I}q96bfa7 zzP{^ zn40i%@^bPW+=qP+DC<4??$-6T$U?O8^BTcSw( z_~R2NoW7qcI(hCyk>~E`KTa%<$;rj=HiP(QiyG^DRP zXDa8_bSv+8sKaQ0U0-O3Mi6(&8r;vbb7gH*_&5Bz7iAfRqP~9ZP71}P_{smZFBIylL}hX@O5XuT0tmo30;8x{#UtA?MDj4T>RxR^_t76yfD-cXqA{2nevB z?T=|=5)b&*8)Q21J<>B?xAd`;tZYMXRS+E``semq1I^rbZEiCREwfk^~`vmwAD=UONVI_$=WBoPFdE;3dC6r zejXbeGq+Ioi+R!VQp)#$VOvf{ymFl4^NN)&UAw7Pw01t+7pkDH?qBkl>B5z^b@ZIq z=70Tkd-UjON=k}heQdUJedXUCt^PaAB%LJ?`ENl$2Dh*wT!=q@t!~@bx1>6&2nI zs!58EwjSE@#~$gRaI%Qayi&Hws=xM3O-&8;nl%cM;vRe^tw(XS&k^n6C+=+7vgL*2 z+*I6!OcQEqYQ?1U8GQPc40`(d+BMe8R6B=9C=?ESyw7#=Vo6?mBo6TNt03UaHL4jR zwe#%_ssj19i#g>rXBfq3s_ zs#)G&UUcS#p!HBJ0&B47)^hv#pQR0j&dzoY4heXQ=hLrwS&y{F?F+l9_~y+U`yO94 z>bCpezFiTS`Sx_6qp;BX=~Jt&_iNeM*y3b^jx^_3B@8qq#02nOHAFNA3fsLxpvT}b z#@kO0aQaPm-k}|s?r{FSzYb~TC}SPGUl^t2f7b3ppf~t)>yq_2BQ8FWl-XX}`R0yQ z6@TM*?44(p?2@^Qvx6#eifkftLuv8Yarj_~SZH$FfdZ!5USXaI)(1z->AAEMvG)C5 z63jK}26fm5Y{M;?@?!HpPDKm`i#ZpNV4WOlQsT;=@Rkb^jC%2c6-$gq%Q2_3fQyk} znVlSr`SeL4DkjG0bE1Y}L%grkudl1`xw`7qn6upcabmK+PT~_ESS~9y0DPau_jhybsLVUJA zE_!U^72k2>`X`HOk(eh>{vgYtmhRzF%1bG)%(EMJXHByjRDBu{V6fQdyjYrky{YzF zZ|`eQlbqqJPO|rK}g2ih8XhVYE+6@~HSqd%w>ZaZN#~&AQ*S!i(-#_i@?(4g{ zYCS`zdC(&ss{!RbWkVf>9R}yGUr+owJuR@BmO0%oM9jJU?*03LuFUoyw5PD61H!|1 zy4sg8@ESMs;TPKTb3ZXq8uSTwQ5_gGza4_TM&6{-*lb1%< zl`vRz7EyI~ck5o()m;}XP`$V?=fTQr+Mc(pF(@R&cyWGQelr8ZHEpMvp51-{I_+0y zKJ4*4>N7piP;r*0E=JyFKVF+{Q1krW!ou@2kFXwJU0DR{TM)utSM)yy_nYTuXJ_x_ z;HXH{DQrK>)8RCiGIMn*4=2@fldi$R1jM5e9Nk9sM7in4M;vu4w5J;=HKgct^t>ADIMIT1B&j$((iT$gj9)m(wH&vk z(N`U2Q}<+TgobL_yJVu*0Ki zerml{uHE(L_S!QPD?c294r-n3Kr6>Ic&1V%CtI6XSsEKqvwF4VZ z^W3>}G=3S*Gd=Vh7!DuDSExYVo}Qk~nWkKlaxYu1ADR4K^DZLMpe|bNX55a{Wn_g& zjzyY~l^2=xnVsnG=399GaP!#Ily%N^!?dO?baYXh_sJdg>C-x5#m2{1|201)#WbRa zo>NmnLE%Y-UQbeE!z?zIkGHqtMMkwf2M<=pRaRG5f8Fa5Dd9QrR9@5zcTl0XW#2A~ z46S0Vk^JfQ1guZ~OmEO^OUGh#YpZ64VT07uee$1?Uk1$E^QPS0+{^>F;+>bTrrEMB zjbr($^`B5>CMtQWV^tDVNy0~2LXcal2)x#m)5$UWdV67^CQ`y`?3<*+!u(kCk7}_c zVRJKCrncw$sg>QB_Uv?O%XLrO40uNuhUE;ChDpz%7W&LR3(unyEsSBvk}9lzenQnwYD6qsdHkI%uBXA!w&+SnOIjG%KZNMf@#AG$*0N*s)t++ zLWFESKjJEgUAb=S_m`#2X@1VXzPbuJOeG51jA(xTD&^grZV>6>LOGPQI2c8_4Iynm<2QvHhr7_|2O)>yc5Gu`mvo zV=)ou&ru)-ka<-AE-k;lUCDIR>^K6d9ywYiO~0z*aZ*+V6VH_g%a$#(L=Hb7D0m*Y zL`qIh_3YUPbq|a~fd>1t`$FsP(2B;k=Q|jEc(RKm1_0S$&UF%J&h$UKU~Qk)m0zkdCigP%>#k5>zMJn@P8 z*3uh zOQ_&+b#V@Z29trw5#%mG+y1+k*AV+P{p9z z%CrK_QaO$1Ei_H344aPQt4F)2@XdewMO z#US95DrRRYI8-ieF3PJYP>EOmeGANy2Md}@WOw3qTTd-RcXyGS)#cxRQGQRbTQ&A$ z7TpNVyG>^)_%H|WulR#PDZ2chzIb%2rIw1~vhUx1;3A$Pn>AFF<8}zQKi}H#mX|Pv z(z*6uE?L3ODX^U4Qv3JU{nIVBkI>+&@PiAJLQ|`wEp2pJN8YX3`ENQ>?OmaDt0|NN z2mX8n?DT6d;&Gychj6&7Ajr>FpKVcNTnU!NRntZ_|yEp_bIz`>B#cA!8pIybxFLhmCo3x!zoJF@; zURjx5DCRh8X>bqxKxa)W&ocjK&!X{^4d2{eiKN5zS-7mMjC$?bmxZY2D&^Oa! zLxcu%M++?W=PoVIHKblT<-sJRs;#XpH~wp8rmVU;CGjMJ@Q1pIB<0SXJ9Ox;6}BRW zquj-q`-J4_965G4DmQ2k6(!vEsEyXPOtICUlzuARz~%GITeOO1uuvb{c9BoKY>x~6 zK>(@WGSk>bq4=yIOWe7UVL!X*rlL)!OP5o^ub9al(qeU?QTe zzq7MdrvJglb21CB5cX;@UB+EtW7UOSvFkSO5x>@$tc4EZT*Tx*2ejT6gqX|K8-r|} z63lj|EBv?`(X=X~!k!$Q{WFf;cP~dAYvS{IaNB?Ll9)hMHUJg^i((;x=SO+=?b!;m zUq8N*Og{lOfaB;!d8j{CF1t@@#{j=v97$()k#HBaMI`(mi=pMVO_R%g9}kuoVWTlSL8Ed;*J#K zvue|Efime=XHzMZ{W^(1M2@55KnD05acH*oOI70G%Jta&^y2?RT`5Vjp{fo(ri5pt zVsD?bUBD!98?SyWIF*t0Ybgh#VqI(Lr(WD%u9QQ6tgXRbb79z!nG#iMytdPuO`6ZJ zet#8yS)koE5^b23)o`;4xQ9E2jmg8wsT3dGm24cvvb&SN(c30Zs5wX-c$pnMsbehU zBhF*Z=9zPVYi(B#YtqVwmE|0nY%Tznji5m$L7Yf=>sR_mREy5ZAS@M8RC3>*6lJVe z*yJX7sW(NfeyxCWTUX+^c{Kgikm9yzj({a8-GWvtRpnyYKeUWZZ-^g~jt@Di|6JHY zdYSFdt~16R1zO!D9v1sYu^o>XH-CHl#B7_JV61s+TJr7L65qjeRIkQQ2$mhxa+$?G_Dg?IG+fH|vmylif4x`{I!Ecp9BQNkrWM7PQ8A8 zJ(Epq1lqKW^_g7*bf0O!l5H2VP7LHTcwSU=y!ie5?*(%sb-+K2&uSu0t&VZ3P0=aj zY4=W0Qg+_|v7obm;e>ZZ_dRsu^>IqWR>BctTlzFL+C4Y2y4+KLb>$&ITU1P_y_As6 z2x;-_GmO;0+I?t?@C^jyH$t{=ex>(FQBlz+BqWJbE9X1i+s0-}j{yF4}9{)w%38K&x0Y zcROZ!wIoeTSZHW6`1^AGw0=|Pxu8?|ZtW~?7Rd*vOCKmMz09T_LT@L4s%5p`=u2wE zHUX31pxp-?@J_tmm0(PV^+2NVi8xGI>L}=4#LkQch0Gb#)7u+}C(o-_zQv*G-a2|2 zw59cjgVDrFU3af5qO;%k%*}O8+yYOt9UWzw?a#{&A@5vi zgV4+dpvmyn7-D007Rnk!`?f=}7x{7O)zolH5T5fh6EnRJDneT_xYRSmCWe~oEC=e# zA}Z@3~%^J9-#4VD%cNG|-o1KZ3H46DSn2bctcx9vM` zt!8v`ZlQMERftam5}E^SVuYNGjC)siw@6M{Fy8M-P*s17ASkX_LhG2cs)Me}D)o3K zeQ$J~>JTmRL4z9|2cdmq3_?G%jD>MaVW{$b|~Wow&2AsUq*3}o1p3>EJkySX zsX~yx4QQ#0OG*rj-RR_m?9Vh{U(R)aN;!v|shnw|l@GOSip%;!l4)fCuLKIW3Kl)T zlG-uiuv+bE;f$~=dU|m>g-+$@f$<8aXy`ld7j16dBqT!2vqKjQ5!6J&&~QC}AlXvq zNB!bXN$jcyk?Gbwgv>LSAEaFw!Vxv*=-#=>j0y}%5B}$Kk44D3l09cI$!-7TBnMYg4@RYwhc{}T*Dmp)T z0hz%N73d5z?+dp!7HUB?)w{YY)?TE#>FDVDr}pKq~7)ux!iF@JmXy{!m`dW7B%WHBKWf8tk+w1DF7+YA8}3=S1H%pnAfcHxCI29=~dcq zXQzgZK&5(sM+bGx4ONSqWCHo>sw)1?du0xzT$O-vZv`{hOWJ-E?!L5YQRIL0e2nEH z`CN%wxu1Iq4&$1-GCyaFpet4(%4DDguU)g|FfK}$?1&PzFgHzaTdDr1vV)Qm5M)}= z1Z`-$j*gBJcJMgF@<6bqSvK!Dl`ah{v^u`2n>!T>p15triWUAS4T4pp>{{53={?ZUu=I+n5*>D=mZ@*i*DIZ zHmLbxFQlfX^0-k`TG5obNK}0Ikh+u)J(uWj&1qLX^AGOad3@KdT_bBYP+ISURuDd~ ztn9Nuw`tSvpsHTwXG@!YfZMH4P~#F=njd#_cUMN&%zvZnG~xZjmgZwlo;pQ1zp4$(YC*}tcY$W)e4tQ zpjp>}4n2*OL=>CccT_GXe=;4X?5CP;weUoCCVZ)JYu35B>Cr}HFn;rIX9-a~LPJYo zgs0^A+E_?71aU(4Opp%~g(<>|*|w7kie>1=6OB zZPaSu$Tw`~`>^N)THsaFb>^QJJ7|~$V-ejWG^e~Rk+%SetHQ)O+9J;8Oe8Q%N=n{y z`uKFOY;8uPwvqFC8I822>)^i4D{@la%Zn|rVlNDTc^UcMFJHfEd&c!9#ap*-;lk%p zi!1D*2c1(?^e&j`-4SCzyeAOJ9+viU77w$4zV@rq=a>vTn0RfWf!cPx`@L@qCjH31 zh-97q`#Vg`EKbwO!)1|fH?FVcFq|<$(-V)bwXSWK&jqBhI_Pq2WqRH>zkS&F8?iz* zfA{WrMJ?of5~G=kb5aoVbja#5I}C9{8Yp=G%*ee47&i&)Ycw;~19+~C+tU0jV*c#` zQd%ss;L-W1w0-f{7MAOIZan_z%WLBx0tu`}+qJy;Pz!iG-!Mdyo1d|lR`flfQlDdG zA=ppUAFHWDz%F7i` zpS}z2P&;OZARe?SL{-wZMn@|G)%%}pGpI=Y1o4kIUq9l3_fNZ8tqM=ax4ym@B0qS; z*>k$h?A-?H4I5O*$3c`<0gud>N)eq6ZU-WHmiuxHq5^zI<`3bll2Q~Gzt3#{eHS3V z6530=8+jPMAt72<=73uIVIoLP>5C3BsYNavR<3Sf>&%muAcg_U7UN7}9zeMce{ma`2q5B80dfU0X}>< zYcM^ERBmTMUuVg*t&gHa zCnFvSWOFsN8c3@R(Xu-n;cSr`-};)ir|`uTkf*_y)Jr)im2~jjpmT+iBe_Gs#8SPt zXJU^eTCJ3NcpnHw(9xlT4y7UXTqNVbP+P7V-bVoe@|T?mmyoL<7>1a69Hy?v7$o2P zSxi6Bnw@t0&K;F(b3Fp5IaCu<<>jTMUhO*u0TG@uC4g{EP7=Gw*}km(&}}oFlV`Z3 zN8nY#ca@-c<0nzje2z{ok&a^QTbUucF1)XDzP(i&90Rk%nH_Ge(Rtm)SpOJQ_5Gk6 z0Y?fFd!nPGlf0keuy}Exw+636d`n91E47lV zoLdi>gq=F3*ZC8v|LVz;CuPvvyK|T1p+yz<4nP*j|ETLJq7S*TGEOOmu{Y$HMYPuy z0CYWf?p|z#XmGfDg@s#X!bIB=wey?vuw4nDB5fF4D9wGoSCCpS;Yk}pBo%oY7-&eK zqOLnPQUwpxMQ-%Id)e5Eg^^nJ?%Y`<6MVGh(gM(e05%X0_$uBo=f&&*Zr%HFpjJ>0 zxY2W_H%iOP3xMn8frX_PU7^1kLn`Eg6U|&08Qy zEtL?GN)S3-Ckbjf%L5yQm474t2NLhY?_0;6v|^>jy|vJ`boylK=+7}PwpnM*K+?sEKo2cY()YNT9^rty3*b}!& zAbGg&+X@^V>hO1N4w)|q=_?Q*ml$HkVWQyUkmofmZps3-k?8XE9%7#9N3ev-ET*v%2YI#1k(9yD?pB!sH@ znQ>V-1_Fd^lKYXh@;{=w$zr>cWw+}Kv8e=~D6kAjpCfeE0ILFZS~` zV2h_uo?Ha?ei{9qevW1TWrT-Dp6yGZWb{C3=!^E036Fi)qd(qTrH_7))6WKW30RDT z@%@E~x~XdZLx-vuy0cj1gN{^@Xcq<)EYVLp#^`5mZeEGqX#%OX3h7%G458g%Ql<=& zWF;x~*j{Sitv4_p{q*bCFMm`Gi>kxzc!T5k$^7Z8Wsq<+tKa4|l2x>9Ott!xg_o}eSZv1^oi6A5Gqn1AE z3wv8#SA1I5Zr8`Orwm}dg+UjM?^PM38H2{;ko=F}3E+VO-*~jrVPUc{=zk(;+})feN2gRA8%4>FupY$0xF zPubVw6MwULS)jvp&B2pHuYgCjPIC!N$2Hb~Qf7zBh?`&E-jf?--%sUpMHD_g2MuJT zIPLr#uWx221#fiSg+k5)i%-GOmvOmP^!eOSk*4ek@#l2~Jpd~~B;Y$Aqg5)A;SjESWN5Dbq|LAaUQ*{1d@VujSEM; zGDVp< z%cT9~6gDxtrO3>j1ey$4l!#M>Rk-&JJC5XatvFnGdsg#O#oalGCd+y&0|-+~2rqBq zPX#S0jY3%uWW^>)D3_BbZzBg>4I3riQlzJ()y%d<+EduP2f$2|B8;_Nu)@d8&J%i# zJLW9(MBa~wmVE~8(%*bzp^j4;AxCeV#6EkM?S8?S7Z!$lKs*pzU4`D(dWm)B0NiL2U^Er-p|@f2ezCOB zvD5&kKrDJT@(%FNpi16o+UvKu1NF86kP?Nw#V;4R!>{5{JPQ(&{48#ZAC#f%>-hv? z_QuT*wB=?ZR}6h}&V|hs-#n_N`uM?vC#+1c78DI*yCfFO4sf&~b<)`M+S}O?6H^0N z;&|}6TM}eBUD4$X?1OvHs#yowKOWTl(d(v2jdXN$Dnt8L%nzG7TiEKf_7Ha&e$nC$V>A?t2G5EsK*w#YUOa^bA^J1XfMA5Fl#V?InG z_R3*5#~vzaqGx@^8k4nnz*Gc$$C>1qov@)<3rfz;&K`g*F)kUTC#PnXS@Fw0Pgoq} zYn@xb6%hM%&hYD&3HWN`LCnP?e=$2woFfpa$oMF-*z@PlgA4Z2Z``;G9t#H{x-Amg z2c7dDUWRf_sCq&TIF6O=Bj7U~85e%`AvcpXgg6B%h=wE-i@+z44lN8o9FW?}MM*;5QpST*;$yStzCI);=%2ySg+L6%Qx5P_U+CkI`YsO|Of~<~b z&!;InwDO^cM@GjlV!Q2MnC?&mQi980y_5dOySHyA&{D%u_#+$HL55r{%=pGZ7#oS@ zFVuO_&UFkBI{{eUFJBnXo!m7_*=K||;YEp&o4=P|-f|v$UN9F~#Ynt37-vSoo!w`h)*jw+#x8K^hT+t?Y)GfMiJR%^`+h9Ke((3~X5xp!R=_HD5s@~7X6Pnv!G(NCIvfWN62xzQ zY#}AGYvql}0PT>e$XCu znr`dXUdxMIoEXo+lDTg5!E+|oIICD2Y-MS0;`~G1|570aup1-f0AF~9d5;uAY5Ell zm6y%Fp||(ef$+Y_1@J!o$ZGY^hMd4Jk0PPxkcpg0fbXrdf3yRI4dQt=zS66x&$Tgi zm=PMy(+zXV1`u?-!d{Io)biJlOT-lIm!DTiT2vTR<&ZJe(d^Y?fW&e*qj3N{QX`-; zDIjxV44^S;2xm5=*T=G9E&*M$l(*C3%+)!)IjpO-Y)pOel}D(G-}4gpi1U` z(7)si^}rqeyWechw&2GAL5!ZXpZF^>p@GIO9z-^F)gkk2!dwHqD0imKvXY*kY+$n4 z@d9?xkhM{{Mf3;N!+|T}@tOVdW$!sA7ZXD*_5x)5nA2fV_H1w*j@1(V2RCU51O_je z#I3a#QGzpW&X^+~);p=N;jv3UJ9s1h0@y7S+=_^Um^vYn<0%cv>0>^Jkh`axm!_Qs z!SWEG*IZPalIA~wKLBim+PafoL8w}|QVE5f8;Mqgy=i>_gF zdK?%5*fe)JA^^7Y>dSboP61Ff0@Jho@I{x!pPtU&0j1G82G-qr0ZEn2pOA<`F-I>d z*PQQ=Nxp>d=curYP3;Rws3hkQvXf-)3S9`XZqqOx4M#FJoEmCcsa;N3zD49TBlHQ- zQEColwmt`Iqz#nH@w9J2OLi!Es5wI#xa^abOm(vUn6|?( zKcKwAY8plbuz3^3^g7>Bfz^OTf@e?OVBNddkjx;Vc_$193ZD_Yr)u*zUlDm57>S7e z*a~DWq{^TB6na>&o1v_~eY?oS$;4?4$>JOo2(rRta7QA{Z4)GSICcrILC6Ns9CXag z$!I0MOh@60N7-B;>?)v5C2>IpVqjn;p3P8)kz){eS_+>!uwfrg5-%*YTB2QIRU##y z?gi&E;@O%YJW{tU%ifB!_-%C{-{TmXNXW z`Q%9*Ra8=9dX^1pi62u9ADpeps1RBpg5t=O5h0V&fDC{kYJ^Fa)8U(Hj^zv=f*0UK z8T*TXO@Dv?itT!%=^t@a=BJ;4_AfdMAEi8jX01VfO!yzDs3pYIqV# z(|2Ta6||D4@@v0a>_26azj9?1;hRgq3@}3eDH4P^lDECReM;yZ8a)|!N6}6xQh2V*^5x4zo;4AAA?t>2%ou{DF5h9g;&FazF4Lr+(9mOM5kt6W zt`L0*jjAl^1+fihmKF<_s=p>z&)ZP#u;}OPi8$pNi1EMf;bCJ8+f<>`s)CFk{JC!+ zt*S%YdH$#CHh7{RK6r2#D|iSBRgA4U0K zA;3;&BBMs2BCoXH9ESy7er6=Tva*toDCb1wBjZMxp*&3bW@1Rdewwl`9+t2JGQFN- zJ!}YG%RMYCj5J_%1)0uEPUA((Y5j3Q2BL_8T&n{c%nNSP)gIZLn|vrTD^_7#ZVCb_ z8h5bw@ZnEHhJxA=gbBL$25l>iJvqRq^Pq#`6BQMuX}R#pX=uO#Y%&k&d4Uai(c10Z zzyAYlIiG~l73+_-=if^%44>R$lN}^rdIXQ=2${fwymt5Ay#TPMLB9tUDjKfN0J{d_ zy>9~&(idGG%nPh0R-@!nT0a92peS?EFJJZTsYW35aCcuUpLO_5x zZ}m&`fTpyxbX~{Nq769RlC}uSllN>FA@M!>+M!DR`7fUDxKx2#pcw~n-{=Mkp0B^~YtkK|s+_wtl%Af>Gn35; z|5L3c_oswIH#QBU$m+JbC2gQ_M2;? zYmcekv6H3Oa7$T=tSmlQhnSJdjS(;s%Si=4{X8h zo^W5aRp%kA3qcs*=qHed4bUH<0EA$#6B3+GumkwO61qbiSQxgA5b`K)5vMm5VPwGw zOpQGn4AA>~NbN-c*CuMDQHP91xH8vbo75ved9Q(ylu~Tgu(lKMDZ_Ej8Z*4b5qTJt z0diHkHI7{&FSazFwpYgAEMr&r-NTkNlpkAy>!@uS1a2{kIcXj{b_{~A<$h17zR-6n z?BD}Q0E96>J{gT-UQ_WfFH)|)|USlF7n-#GB}1u!)kwnHS-9fTQX>dT zZXO=K!~jV@BXb<@*l4n*K2U&jDD4 zA9dg0hC7({&kg1C#fEb3<>7(Vl-Mh|FKE_fA!;>)QEIwE`*Tm_^6)N>mn+P`8j{H|_;MP6J;=^N?_i-}(_b5fp5z&+P-NyBTY$`Z z8E3YsDt@IL&mM2QQlZkN^5ty1?_R5WN&A*~##6#x{P9DR&3vsbCY4d6eM0RA*oRxQ zm2g2NR1VTH_)Q8xx+gH1%*LT|HlULsS|cIz9-2{#hTR+trp!6N=5(8L`T0{-EAQU( zAFb6&3n))V^U&t>O@f6T#Ee&sRil>Lrd?FWvlT))Uuw(l?lWw^EzUsYQQnQrK>JiAp#d(|TVDpU)kTE`ni}91^j8Sjc+Fzb^Ko6G@oFS`OuL5JQ0M=_3Gk z5V6?jI1U{eHaxZ3rQ$7Hd)JJSP>ufa8Ow&X;lOX3Hf{33)S5!*pEk0sM>~Ms0cUq& z)YbvO`}@+?1qhTxQ=_MZU*5TK1u}W|3So4zZSXZ4Hob}>%|5UT5;GI=sjz8n&lkG+ zE0rh{l78NWAadkkGI@1A)5MrXWIawvfH_+7?g|Rq?a^OUp;hs7w#fGXtp%EuP#RwY zT)61|28THWEj6@G~)h zLK7{?q^*=rseIk6q_k;W!l&S4mb}RNXxl@*aziv8A@f3gCXO|_pBFo;L@@$tl3zC^#(VC0L-a2O-FZ&vH z$fz@Gt>!jTwdt7E)LkktgXBUTml|fQ%eRI?U5cv~MIPS1W{U^1jf=#r(WVcvhm(C` zoT`2B4Z2l(geVF;WW0#k)0vqfp}B08!pMy-amlfIY5Src*kUN^n>se0c2b?~bAcP_ z&%_h1*(G#gzx0HF`i8aPL1wAQ@*`y4tw`qwtze?dia(BLzj;XCvvy%)sQaIl-gf=^ z;2+mk`0$HsD70(z|J<31>`9?;E-wr3OM5F8Mc!BcmmD3%ht{2vvu)og^;-!}ANNyJ zvS|Mr*uv%MIypAPe)z{QluN?SDHBC`J=XjmBOT-{7=`j*qC3lXkUJuf>U6}f6gmo}noLn0-`2%q-AabCnE@_KQJ-a5+uV~k2)<=iP@3xt@T<#x`h0mf_rp>n7!yd95ak4J_ct;x#r=b#&y%P|R z#DcsPY!2~W!S6+m6q(E{Tu5R)dGgo|w$%)j<1=r+zvyH4T46@W60}#$IHE$b7NYJ1 z+FApgJOG=%b-WlfF|g$q&?G)#1tolz@Od~Q<1FRSkm2Gft0>|5#_PWRh{p;Q=hZ)8 zJ#8*<%mstqrO*zqv~hEKS&uMU73+9ss5u7s?eg;SoaQH3V1)+>%v8co^I7PJw`|+i zuK-Ofj!<0VSFT*iN(@6IfF^(gR!W0Vw5iDb%^hLHZVi;(pU`%9chA2a8rge!um0uN za+N*Y_mV!g>h<7F68Dlx*g6rG)u63G+f}#mPchMd(g!%Ks z-ZHSFR@ecSU1TNmA@Hx}+1gF~&Tn;TyR|0&bxNe%Z{L9f<$8`F9YJhn2md!SBU(MG zTShSdOQt(8aDsH_tEi=BCoZIU@*jWZ>5Xr!9X2cLNx#D_jz01nB*nN=tAEa;&!I)N zRKVL(4||&5I39&KVNOmpkp>B>bUEy_J;E@S;~*J}{R!tY1{zj62mEghlPH~}y$yOP zosh4DN5>>{Bynny@r6arLKL7uj7Fq2vS6}=hFLfk*6%?g;1SvphFfBEV%tb&Ua8OR z_Po{Kl9@uh!Gs(V5z&H=?|im-6vzNJ$^QO+;`_^Y0w=AStQiMAS0*KYDiD(CL^(&w zk2RPoap6(8T)2VKXz?aUamtK>mh9$H@_Tp@`lx z5zYa2m}7`#87cvIu5N{^pTbYTKtz**#yXe1tJcfM=Uci51fQw?5(XS-!;U?TK6e1~ z)5RInDO(Yr9d(BxmKiep}Zk zd2)iut2Kr1w=}T4cs{*Sew5am@wFIY4=o(*Qx^7>7tV~*79Uf4Hm7?!Zt08WN^9a; zKvyB?1HJW=`woGRvKxK#(f@3d)JMBC^$L7_AY95hT(LZdFTGt+2cw!_9XMxM*mMJ6 zn9C=k0WmAV9+B(5+tL%dNFem8Kqzdw?)g-mP5d8L{}%6c!ra zD|n)VPXo+MjQI-t5cJsH%x&`I{ow4GPzoV@EJQUj21&d--QpEdZ>LvsOu0hFA`=*- z;MDP9AY!K4m+mXWxy!V-r}!Yls)4WGy=MD5+ ztyIFJC;l@J7BOvdz|R85%Fysw$2V>ZU(p(?LhDC)>Q)%RBWlQgwVM%}?$M7%FPi?^ z)ggtIN~D{{SXeDQ)n66bGfj2oI9YXwRzSTzuy=2Xw*I{W1%HnI<)O{;ND<_u9L!A_;uSP?_t$T`yd(sn0EJ%( ztY!vgx^URZ^Tbn!=+tt!Yz;B^fw8RtF#BFOBM8$8ldiWv`YO(dikq#?qWlf{m%sJI z7?V$p!}GsJoS|ZnLqtvvic3gv;1R&Oa47P7xi*wW;_3z=**v=HwPMFn+aKW=Wk_1h zax?XlJ=WLTtzk7u3~*j9pgmgo`R5jLJ{6#Q1|kmW=BWE(b}g|c;3->*l@)TS$BCV- zg<_HcC3CmuHpSu*GCqS*rrVgu`g4lXY`eGIpMp)>7L|V9l502uiev`p2z2!HYTtzS z$ptGwzrc7x#m*PCc%S!0pg&U{XKKxPf} z?j05PnKj-^_XTKzZ2#YOK!_iepq^x~ZiLs)J;lLKD>k1>d{V@>g$g;E01jQdNy#Fie5|u=5092Nz#&dQ}YhC=gRpb63T{wvvGt z=sAX{-9mNHz&A574WBc9vTxtM()whr7I!A0&lshedyD(3gt}RUzc>{O`8>^{Pj0R) zsknbRiVQ!j1$@LJiTZmOMU98h(~z!5niDHj)rDX)4nWm6}S=`@Z0glBv27KYLoa+*rK7N>+ zOMr^rzkly9>Nx99GC@oq1`Khwj6WvLh%!JV#c_zm+1y>V#7>9&4~dATC$ydZbF$W; zxef!a$sKL$*a;7zKB8|dMek+2!&zeMg!Re8!A+s0RN^>->k!y}D;Q7> z>@fUag$dl&+)I9Osn+{P$Rz2mYAEg|#2Jgy5YjvCv#3y?UF$#Q?4G(%xB&>o<;)L^ zB>bOHzkkt7&-{FRULdM-h1OH{^UiNtL-7PpP8~v;T2iP%9V)A#k4 zE>S$w$sts*RAJ6RN=8N*+$!-(7+|28%rgTxvq}O9D4~E66E*_rJ3u=*>Wd7}!uC*8 zbzwQ>F_{FphqL(d$I3a#(R^eqgP6t%h#+V00P-7Pv>Up31I**>lIKy8A&tMtY`2#t z{xssthGhuA_bhg!#b^oTxajZWT*@PnkP;N5mKHr3(XAj=%{hQH1@Q0x@>19P1rcf? zh(JiNR{sdWYXT~ppX?`!f()w*rSmod^qm7L zRGrknM@x{Y?-9cSQWB0YEN%nQ74eH^C&kCbVI*4u*?FKTE#`VAM+kvkeNMyc(dgkk*Y@cs_=SRX^D2V(z;7#P5?1{?Tv$hZA=G h*U_i{mz3oA(sCVL){s5<2mcf!&d8lkK6&x>{{tMr5g7mg literal 0 HcmV?d00001 diff --git a/utils.py b/utils.py index d684fe7..0635a37 100755 --- a/utils.py +++ b/utils.py @@ -40,7 +40,7 @@ def torch_tensor_to_img(tensor): return img -def load_dataset(opt): +def load_dataset(opt, sequential=None): train_data = None test_data = None if opt.dataset == 'smmnist': @@ -90,13 +90,15 @@ def load_dataset(opt): data_root=opt.data_root, seq_len=opt.n_past+opt.n_future, image_size=opt.image_width, - task=opt.mcs_task) + task=opt.mcs_task, + sequential=sequential) test_data = MCS( train=False, data_root=opt.data_root, seq_len=opt.n_eval, image_size=opt.image_width, - task=opt.mcs_task) + task=opt.mcs_task, + sequential=sequential) return train_data, test_data From 248237b8495457d752c02d56603db1ff21a2be3a Mon Sep 17 00:00:00 2001 From: Fred Lu Date: Mon, 19 Jul 2021 15:06:16 -0400 Subject: [PATCH 04/26] Updated prior vs posterior model --- data/mcs.py | 41 ++- do_mcs_implausblility_test.py | 384 +++++++++++++++++++++++ do_mcs_implausblility_test_posterior.py | 396 ++++++++++++++++++++++++ do_mcs_stats.py | 19 +- do_mcs_stats_posterior.py | 377 ++++++++++++++++++++++ train_svg_nonstochastic_posterior.py | 373 ++++++++++++++++++++++ utils.py | 10 +- 7 files changed, 1585 insertions(+), 15 deletions(-) create mode 100644 do_mcs_implausblility_test.py create mode 100644 do_mcs_implausblility_test_posterior.py create mode 100644 do_mcs_stats_posterior.py create mode 100644 train_svg_nonstochastic_posterior.py diff --git a/data/mcs.py b/data/mcs.py index 9414242..9993996 100644 --- a/data/mcs.py +++ b/data/mcs.py @@ -8,14 +8,17 @@ import imageio from os import path + class MCS(object): - def __init__(self, train, data_root, seq_len = 20, image_size=64, task='ALL', sequential=None): + def __init__(self, train, data_root, seq_len=20, image_size=64, task='ALL', sequential=None, implausible=False): + # if implausible is set to True, generates "fake" images by cutting out or repeating frames + self.implausible = implausible self.data_root = '%s/mcs_videos_1000/processed/' % data_root if not os.path.exists(self.data_root): raise os.error('data/mcs.py: Data directory not found!') self.seq_len = seq_len - self.image_size = image_size + self.image_size = image_size # print('mcs.py: found tasks ', self.tasks) self.video_folder = {} @@ -26,7 +29,8 @@ def __init__(self, train, data_root, seq_len = 20, image_size=64, task='ALL', se self.tasks = [task] for task in self.tasks: - self.video_folder[task] = [path.basename(folder) for folder in sorted(glob(path.join(self.data_root, task, '*')))] + self.video_folder[task] = [path.basename(folder) for folder in + sorted(glob(path.join(self.data_root, task, '*')))] self.len_video_folder[task] = len(self.video_folder[task]) self.seed_set = False @@ -61,24 +65,47 @@ def get_sequence(self, idx=None): seq.append(im) return np.array(seq) + def abnormalize_sequence(self, seq): + """ + Takes a sequence and makes it implausible by cutting out and then repeating frames or + repeating frames and then cutting out frames + """ + implausibility_type = random.randint(1, 3) + # start = random.randint(100, 140) + start = 110 + vid_len = len(seq) + duration = 10 + implausibility_type = 0 + if implausibility_type == 1: # object is invisible when/where it shouldn't be + no_object_frame = seq[30] + seq[start:start + duration] = no_object_frame + elif implausibility_type == 2: # object suddenly freezes then teleports where it would be + seq[start:start + duration] = seq[start] + elif implausibility_type == 3: # object jumps forward 10 frames and keeps moving + seq[start:vid_len - duration] = \ + seq[start + duration: vid_len] + elif implausibility_type == 0: + pass # do nothing if type if 0 + return seq + def __getitem__(self, index): if not self.seed_set: self.seed_set = True random.seed(index) np.random.seed(index) - #torch.manual_seed(index) + # torch.manual_seed(index) seq = self.get_sequence(index) if seq is not None: + if self.implausible: + seq = self.abnormalize_sequence(seq) return torch.from_numpy(seq) else: return None def __len__(self): - return 5*1000*200 # approximate - + return 5 * 1000 * 200 # approximate # if __name__ == '__main__': # m = MCS(True, '/home/lol/Hub/svg/data/') # s = m.__getitem__(0).cuda() # print(s.device) - diff --git a/do_mcs_implausblility_test.py b/do_mcs_implausblility_test.py new file mode 100644 index 0000000..91ea12b --- /dev/null +++ b/do_mcs_implausblility_test.py @@ -0,0 +1,384 @@ +import glob + +import cv2 +import torch +import torch.optim as optim +import torch.nn as nn +import argparse +import os +import random +from torch.autograd import Variable +from torch.utils.data import DataLoader +import utils +import itertools +import progressbar +import numpy as np +import matplotlib.pyplot as plt +import json + +parser = argparse.ArgumentParser() +parser.add_argument('--lr', default=0.004, type=float, help='learning rate') +parser.add_argument('--beta1', default=0.9, type=float, help='momentum term for adam') +parser.add_argument('--batch_size', default=1, type=int, help='batch size') +parser.add_argument('--log_dir', default='logs/nonstochastic', help='base directory to save logs') +parser.add_argument('--model_dir', default='', help='base directory to save logs') +parser.add_argument('--name', default='', help='identifier for directory') +parser.add_argument('--data_root', default='data', help='root directory for data') +parser.add_argument('--optimizer', default='adam', help='optimizer to train with') +parser.add_argument('--niter', type=int, default=300, help='number of epochs to train for') +parser.add_argument('--seed', default=1, type=int, help='manual seed') +parser.add_argument('--epoch_size', type=int, default=1000, help='epoch size') +parser.add_argument('--image_width', type=int, default=64, help='the height / width of the input image to network') +parser.add_argument('--channels', default=1, type=int) +parser.add_argument('--dataset', default='mcs', help='dataset to train with') +parser.add_argument('--mcs_task', default='SpatioTemporalContinuityTraining4', help='mcs task') +parser.add_argument('--n_past', type=int, default=5, help='number of frames to condition on') +parser.add_argument('--n_future', type=int, default=195, help='number of frames to predict') +parser.add_argument('--n_eval', type=int, default=30, help='number of frames to predict at eval time') +parser.add_argument('--rnn_size', type=int, default=256, help='dimensionality of hidden layer') +parser.add_argument('--posterior_rnn_layers', type=int, default=1, help='number of layers') +parser.add_argument('--predictor_rnn_layers', type=int, default=2, help='number of layers') +parser.add_argument('--z_dim', type=int, default=10, help='dimensionality of z_t') +parser.add_argument('--g_dim', type=int, default=128, + help='dimensionality of encoder output vector and decoder input vector') +parser.add_argument('--beta', type=float, default=0.0001, help='weighting on KL to prior') +parser.add_argument('--model', default='dcgan', help='model type (dcgan | vgg)') +parser.add_argument('--data_threads', type=int, default=8, help='number of data loading threads') +parser.add_argument('--num_digits', type=int, default=2, help='number of digits for moving mnist') +parser.add_argument('--last_frame_skip', action='store_true', + help='if true, skip connections go between frame t and frame t+t rather than last ground truth frame') +opt = parser.parse_args() +BATCH_SIZE = opt.batch_size +saved_model = None +if opt.model_dir != '': + models = glob.glob(f'{opt.model_dir}/model_*.pth') + latest_model = sorted(models, key=lambda s: int(s[s.rfind('_e') + 2: s.rfind('.pth')]), reverse=True)[0] + print('Loading model ', latest_model) + saved_model = torch.load(latest_model) + optimizer = opt.optimizer + model_dir = opt.model_dir + niter = opt.niter + opt = saved_model['opt'] + opt.niter = niter # update number of epochs to train for + opt.optimizer = optimizer + opt.model_dir = model_dir + opt.log_dir = '%s/continued' % opt.log_dir +else: + raise ValueError("Please specify the model to load with the --model_dir argument") + +print("Random Seed: ", opt.seed) +random.seed(opt.seed) +torch.manual_seed(opt.seed) +torch.cuda.manual_seed_all(opt.seed) +dtype = torch.cuda.FloatTensor + +import models.lstm as lstm_models + +if opt.model_dir != '': + frame_predictor = saved_model['frame_predictor'] + frame_predictor.batch_size = BATCH_SIZE +else: + raise ValueError('Please specify --model_dir') + +if opt.model == 'dcgan': + if opt.image_width == 64: + import models.dcgan_64 as model + elif opt.image_width == 128: + import models.dcgan_128 as model +elif opt.model == 'vgg': + if opt.image_width == 64: + import models.vgg_64 as model + elif opt.image_width == 128: + import models.vgg_128 as model +else: + raise ValueError('Unknown model: %s' % opt.model) + +if opt.model_dir != '': + decoder = saved_model['decoder'] + encoder = saved_model['encoder'] +else: + encoder = model.encoder(opt.g_dim, opt.channels) + decoder = model.decoder(opt.g_dim, opt.channels) + encoder.apply(utils.init_weights) + decoder.apply(utils.init_weights) + +# --------- loss functions ------------------------------------ +mse_criterion = nn.MSELoss() + + +def kl_criterion(mu, logvar): + # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) + KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) + KLD /= opt.batch_size + return KLD + + +# --------- transfer to gpu ------------------------------------ +frame_predictor.cuda() +encoder.cuda() +decoder.cuda() +mse_criterion.cuda() + +opt.batch_size = BATCH_SIZE +opt.epoch_size = 1000 +opt.n_future = 195 +print(opt) + +# --------- load a dataset ------------------------------------ +train_data, test_data = utils.load_dataset(opt, sequential=True, implausible=True) + +train_loader = DataLoader(train_data, + num_workers=opt.data_threads, + batch_size=opt.batch_size, + drop_last=True, + pin_memory=True,) +test_loader = DataLoader(test_data, + num_workers=opt.data_threads, + batch_size=opt.batch_size, + shuffle=True, + drop_last=True, + pin_memory=True) + + +def get_training_batch(): + while True: + for sequence in train_loader: + batch = utils.normalize_data(opt, dtype, sequence) + yield batch + + +training_batch_generator = get_training_batch() +training_batch_generator_2 = get_training_batch() + + +def get_testing_batch(): + while True: + for sequence in test_loader: + batch = utils.normalize_data(opt, dtype, sequence) + yield batch + + +testing_batch_generator = get_testing_batch() + + +# --------- plotting funtions ------------------------------------ +def plot(x, epoch): + nsample = 1 + gen_seq = [[] for _ in range(nsample)] + gt_seq = [x[i] for i in range(len(x))] + + h_seq = [encoder(x[i]) for i in range(opt.n_past)] + for s in range(nsample): + frame_predictor.hidden = frame_predictor.init_hidden() + gen_seq[s].append(x[0]) + x_in = x[0] + for i in range(1, opt.n_eval): + if opt.last_frame_skip or i < opt.n_past: + h, skip = h_seq[i - 1] + h = h.detach() + elif i < opt.n_past: + h, _ = h_seq[i - 1] + h = h.detach() + if i < opt.n_past: + frame_predictor(h) + x_in = x[i] + gen_seq[s].append(x_in) + else: + h = frame_predictor(h).detach() + x_in = decoder([h, skip]).detach() + gen_seq[s].append(x_in) + + to_plot = [] + gifs = [[] for t in range(opt.n_eval)] + nrow = min(opt.batch_size, 25) + for i in range(nrow): + # ground truth sequence + row = [] + for t in range(opt.n_eval): + row.append(gt_seq[t][i]) + to_plot.append(row) + + for s in range(nsample): + row = [] + for t in range(opt.n_eval): + row.append(gen_seq[s][t][i]) + to_plot.append(row) + for t in range(opt.n_eval): + row = [] + row.append(gt_seq[t][i]) + for s in range(nsample): + row.append(gen_seq[s][t][i]) + gifs[t].append(row) + + fname = '%s/gen/sample_%d.png' % (opt.log_dir, epoch) + utils.save_tensors_image(fname, to_plot) + + fname = '%s/gen/sample_%d.gif' % (opt.log_dir, epoch) + utils.save_gif(fname, gifs) + + +def plot_rec(x, epoch): + frame_predictor.hidden = frame_predictor.init_hidden() + gen_seq = [] + gen_seq.append(x[0]) + x_in = x[0] + h_seq = [encoder(x[i]) for i in range(opt.n_past + opt.n_future)] + for i in range(1, opt.n_past + opt.n_future): + h_target = h_seq[i][0].detach() + if opt.last_frame_skip or i < opt.n_past: + h, skip = h_seq[i - 1] + else: + h, _ = h_seq[i - 1] + h = h.detach() + if i < opt.n_past: + frame_predictor(h) + gen_seq.append(x[i]) + else: + h_pred = frame_predictor(h).detach() + x_pred = decoder([h_pred, skip]).detach() + gen_seq.append(x_pred) + + to_plot = [] + nrow = min(opt.batch_size, 25) + for i in range(nrow): + row = [] + for t in range(opt.n_past + opt.n_future): + row.append(gen_seq[t][i]) + to_plot.append(row) + fname = '%s/gen/rec_%d.png' % (opt.log_dir, epoch) + utils.save_tensors_image(fname, to_plot) + + +def do_implasubility_test(h_residual_mean, h_residual_vars): + epoch_size = 1000 // opt.batch_size # we have a total of 1000 videos + cov_inv = [np.diag(1.0 / h_var)[np.newaxis, ...] for h_var in h_residual_vars] # we assume the covariance matrix is diagonal and do the inverse for each time + + frame_predictor.eval() + encoder.eval() + decoder.eval() + progress = progressbar.ProgressBar(max_value=epoch_size).start() + + + for i in range(epoch_size): + h_residual_var = torch.tensor(np.zeros(opt.n_future, dtype=np.float32), requires_grad=False, + device=torch.device('cpu')) + progress.update(i + 1) + try: + x = next(training_batch_generator) + frames = [frame.cpu() for frame in x] + except TypeError: + print('got None at i = {}, terminating'.format(i)) + break + h_posterior = [encoder(x[j]) for j in range(opt.n_past + opt.n_future)] + # print(h_posterior[0][0].size()) + last_pred = None + frame_predictor.hidden = frame_predictor.init_hidden() + for j in range(1, opt.n_past + opt.n_future): + if opt.last_frame_skip or j < opt.n_past: + h, skip = h_posterior[j - 1] + else: + h = h_posterior[j - 1][0].detach() + # we predict h_t from h_{t-1} + h_prior_pred = frame_predictor(torch.cat([h], 1)).detach() + + if j >= opt.n_past: + # h_res = h_prior_pred - h_posterior[j][0].detach() # predicted h minus observed h + # squared_diff = torch.square(h_res - h_residual_mean[j - opt.n_past]) + # squared_diff = torch.mean(squared_diff, dim=1) # average squared residuals at time j over the dimensions of h + # squared_diff = torch.mean(squared_diff, dim=0) # average squared residuals at time j over the batch + # h_residual_var[j - opt.n_past] += squared_diff + + residual = (h_prior_pred - last_pred).cpu() + # err = (residual - h_residual_mean[j - opt.n_past]) + err = residual + err = np.square(err) # [batch, dim_feature] + err = err / h_residual_vars[j - opt.n_past][np.newaxis, ...] # [batch, dim_feature] / [1, dim_feature] + err = torch.sum(err, axis=1) # -> [batch,] + # if len(err.shape) == 2: # [batch, dim_feature] + # err = err[..., np.newaxis] # make err into a vector [batch, dim_feature, 1] + # # print(cov_inv[j - opt.n_past].shape) + # # print(cov_inv[j - opt.n_past]) + # # print(err.shape) + # print(err) + # print(np.diag(cov_inv[j - opt.n_past][0])) + # quit() + # mahanlanobis_dist = np.matmul(cov_inv[j - opt.n_past], err).transpose(2, 1) + # # print(mahanlanobis_dist.shape) + # mahanlanobis_dist = np.matmul(mahanlanobis_dist, err) # [batch, 1, 1] + # # print(mahanlanobis_dist.shape) + # # print(np.sqrt(mahanlanobis_dist)) + # print(mahanlanobis_dist) + # quit() + # mahanlanobis_dist = torch.mean(mahanlanobis_dist) # scalar + + + # err = torch.mean(err, dim=1) # average errs at time j over the dimensions of h + # err = torch.mean(err, dim=0) # average errs at time j over the batch + # h_residual_var[j - opt.n_past] += err.detach() + h_residual_var[j - opt.n_past] += torch.mean(err, axis=0) + last_pred = h_prior_pred.detach() + + h_residual_sd = torch.sqrt(h_residual_var).cpu() + + h_residual_sd_filtered = - h_residual_sd[:-2] + 2 * h_residual_sd[1:-1] - h_residual_sd[2:] + + + print(h_residual_var) + for j in range(len(frames)): + frame = np.uint8(np.minimum(frames[j][0][0], 1) * 255) + cv2.imshow('frame', frame) + cv2.waitKey(10) + fig = plt.figure() + # plt.ylim(0, 10.0) + plt.xlabel('Time') + plt.title("sqrt of average squared dimensional error") + plt.bar(np.arange(len(h_residual_sd)), h_residual_sd) + fig.canvas.draw() + img = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, + sep='') + img = img.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + + plt.xlabel('Time') + plt.title("filtered sqrt of average squared dimensional error") + plt.bar(np.arange(len(h_residual_sd_filtered)), h_residual_sd_filtered) + fig.canvas.draw() + img2 = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, + sep='') + img2 = img2.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + img2 = cv2.cvtColor(img2, cv2.COLOR_RGB2BGR) + + cv2.imshow("plot", img) + cv2.imshow("plot2", img2) + + k = cv2.waitKey(0) + if k == ord('q'): + quit() + # plt.savefig('implausibility_test.png') + # h_residual_var /= epoch_size # get the mean error vector per time + # h_residual_sd = torch.sqrt(h_residual_var) + print('Last i = {}'.format(i)) + + + # H_err_cov = torch.tensor(np.zeros((opt.n_future, 128, 128), dtype=np.float32), requires_grad=False, + # device=torch.device('cuda:0')) + + # plot some stuff + h_residual_mean_norm = torch.norm(h_residual_mean, dim=1).cpu() + plt.subplot(2, 1, 1) + plt.xlabel('Time') + plt.tight_layout() + plt.title("Norm of the residual mean") + plt.bar(np.arange(len(h_residual_mean_norm)), h_residual_mean_norm) + + plt.subplot(2, 1, 2) + plt.xlabel('Time') + plt.tight_layout() + plt.title("Average dimensional sqrt(variance) of the residual") + plt.bar(np.arange(len(h_residual_sd.cpu())), h_residual_sd.cpu()) + plt.savefig('h_residual_mean.png') + + +f = open('mcs_stats.json', 'r') +mcs_stats_dict = json.load(f) +do_implasubility_test(np.array(mcs_stats_dict['mean']), np.array(mcs_stats_dict['vars'])) diff --git a/do_mcs_implausblility_test_posterior.py b/do_mcs_implausblility_test_posterior.py new file mode 100644 index 0000000..141e92d --- /dev/null +++ b/do_mcs_implausblility_test_posterior.py @@ -0,0 +1,396 @@ +import glob + +import cv2 +import torch +import torch.optim as optim +import torch.nn as nn +import argparse +import os +import random +from torch.autograd import Variable +from torch.utils.data import DataLoader +import utils +import itertools +import progressbar +import numpy as np +import matplotlib.pyplot as plt +import json + +parser = argparse.ArgumentParser() +parser.add_argument('--lr', default=0.004, type=float, help='learning rate') +parser.add_argument('--beta1', default=0.9, type=float, help='momentum term for adam') +parser.add_argument('--batch_size', default=1, type=int, help='batch size') +parser.add_argument('--log_dir', default='logs/nonstochastic', help='base directory to save logs') +parser.add_argument('--model_dir', default='', help='base directory to save logs') +parser.add_argument('--name', default='', help='identifier for directory') +parser.add_argument('--data_root', default='data', help='root directory for data') +parser.add_argument('--optimizer', default='adam', help='optimizer to train with') +parser.add_argument('--niter', type=int, default=300, help='number of epochs to train for') +parser.add_argument('--seed', default=1, type=int, help='manual seed') +parser.add_argument('--epoch_size', type=int, default=1000, help='epoch size') +parser.add_argument('--image_width', type=int, default=64, help='the height / width of the input image to network') +parser.add_argument('--channels', default=1, type=int) +parser.add_argument('--dataset', default='mcs', help='dataset to train with') +parser.add_argument('--mcs_task', default='SpatioTemporalContinuityTraining4', help='mcs task') +parser.add_argument('--n_past', type=int, default=5, help='number of frames to condition on') +parser.add_argument('--n_future', type=int, default=195, help='number of frames to predict') +parser.add_argument('--n_eval', type=int, default=30, help='number of frames to predict at eval time') +parser.add_argument('--rnn_size', type=int, default=256, help='dimensionality of hidden layer') +parser.add_argument('--posterior_rnn_layers', type=int, default=1, help='number of layers') +parser.add_argument('--predictor_rnn_layers', type=int, default=2, help='number of layers') +parser.add_argument('--z_dim', type=int, default=10, help='dimensionality of z_t') +parser.add_argument('--g_dim', type=int, default=128, + help='dimensionality of encoder output vector and decoder input vector') +parser.add_argument('--beta', type=float, default=0.0001, help='weighting on KL to prior') +parser.add_argument('--model', default='dcgan', help='model type (dcgan | vgg)') +parser.add_argument('--data_threads', type=int, default=8, help='number of data loading threads') +parser.add_argument('--num_digits', type=int, default=2, help='number of digits for moving mnist') +parser.add_argument('--last_frame_skip', action='store_true', + help='if true, skip connections go between frame t and frame t+t rather than last ground truth frame') +opt = parser.parse_args() +BATCH_SIZE = opt.batch_size +saved_model = None +if opt.model_dir != '': + models = glob.glob(f'{opt.model_dir}/model_*.pth') + latest_model = sorted(models, key=lambda s: int(s[s.rfind('_e') + 2: s.rfind('.pth')]), reverse=True)[0] + print('Loading model ', latest_model) + saved_model = torch.load(latest_model) + optimizer = opt.optimizer + model_dir = opt.model_dir + niter = opt.niter + opt = saved_model['opt'] + opt.niter = niter # update number of epochs to train for + opt.optimizer = optimizer + opt.model_dir = model_dir + opt.log_dir = '%s/continued' % opt.log_dir +else: + raise ValueError("Please specify the model to load with the --model_dir argument") + +print("Random Seed: ", opt.seed) +random.seed(opt.seed) +torch.manual_seed(opt.seed) +torch.cuda.manual_seed_all(opt.seed) +dtype = torch.cuda.FloatTensor + + +import models.lstm as lstm_models + +if opt.model_dir != '': + frame_predictor = saved_model['frame_predictor'] + frame_predictor.batch_size = BATCH_SIZE + posterior = saved_model['posterior'] + posterior.batch_size = BATCH_SIZE +else: + raise ValueError('Please specify --model_dir') + +if opt.model == 'dcgan': + if opt.image_width == 64: + import models.dcgan_64 as model + elif opt.image_width == 128: + import models.dcgan_128 as model +elif opt.model == 'vgg': + if opt.image_width == 64: + import models.vgg_64 as model + elif opt.image_width == 128: + import models.vgg_128 as model +else: + raise ValueError('Unknown model: %s' % opt.model) + +if opt.model_dir != '': + decoder = saved_model['decoder'] + encoder = saved_model['encoder'] +else: + encoder = model.encoder(opt.g_dim, opt.channels) + decoder = model.decoder(opt.g_dim, opt.channels) + encoder.apply(utils.init_weights) + decoder.apply(utils.init_weights) + +# --------- loss functions ------------------------------------ +mse_criterion = nn.MSELoss() + + +def kl_criterion(mu, logvar): + # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) + KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) + KLD /= opt.batch_size + return KLD + + +# --------- transfer to gpu ------------------------------------ +frame_predictor.cuda() +posterior.cuda() +encoder.cuda() +decoder.cuda() +mse_criterion.cuda() + +opt.batch_size = BATCH_SIZE +opt.epoch_size = 1000 +opt.n_future = 195 +print(opt) + +# --------- load a dataset ------------------------------------ +train_data, test_data = utils.load_dataset(opt, sequential=True, implausible=True) + +train_loader = DataLoader(train_data, + num_workers=opt.data_threads, + batch_size=opt.batch_size, + drop_last=True, + pin_memory=True,) +test_loader = DataLoader(test_data, + num_workers=opt.data_threads, + batch_size=opt.batch_size, + shuffle=True, + drop_last=True, + pin_memory=True) + + +def get_training_batch(): + while True: + for sequence in train_loader: + batch = utils.normalize_data(opt, dtype, sequence) + yield batch + + +training_batch_generator = get_training_batch() +training_batch_generator_2 = get_training_batch() + + +def get_testing_batch(): + while True: + for sequence in test_loader: + batch = utils.normalize_data(opt, dtype, sequence) + yield batch + + +testing_batch_generator = get_testing_batch() + + +# --------- plotting funtions ------------------------------------ +def plot(x, epoch): + nsample = 1 + gen_seq = [[] for _ in range(nsample)] + gt_seq = [x[i] for i in range(len(x))] + + h_seq = [encoder(x[i]) for i in range(opt.n_past)] + for s in range(nsample): + frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + gen_seq[s].append(x[0]) + x_in = x[0] + for i in range(1, opt.n_eval): + if opt.last_frame_skip or i < opt.n_past: + h, skip = h_seq[i - 1] + h = h.detach() + elif i < opt.n_past: + h, _ = h_seq[i - 1] + h = h.detach() + if i < opt.n_past: + frame_predictor(h) + x_in = x[i] + gen_seq[s].append(x_in) + else: + h = frame_predictor(h).detach() + x_in = decoder([h, skip]).detach() + gen_seq[s].append(x_in) + + to_plot = [] + gifs = [[] for t in range(opt.n_eval)] + nrow = min(opt.batch_size, 25) + for i in range(nrow): + # ground truth sequence + row = [] + for t in range(opt.n_eval): + row.append(gt_seq[t][i]) + to_plot.append(row) + + for s in range(nsample): + row = [] + for t in range(opt.n_eval): + row.append(gen_seq[s][t][i]) + to_plot.append(row) + for t in range(opt.n_eval): + row = [] + row.append(gt_seq[t][i]) + for s in range(nsample): + row.append(gen_seq[s][t][i]) + gifs[t].append(row) + + fname = '%s/gen/sample_%d.png' % (opt.log_dir, epoch) + utils.save_tensors_image(fname, to_plot) + + fname = '%s/gen/sample_%d.gif' % (opt.log_dir, epoch) + utils.save_gif(fname, gifs) + + +def plot_rec(x, epoch): + frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + gen_seq = [] + gen_seq.append(x[0]) + x_in = x[0] + h_seq = [encoder(x[i]) for i in range(opt.n_past + opt.n_future)] + for i in range(1, opt.n_past + opt.n_future): + h_target = h_seq[i][0].detach() + if opt.last_frame_skip or i < opt.n_past: + h, skip = h_seq[i - 1] + else: + h, _ = h_seq[i - 1] + h = h.detach() + if i < opt.n_past: + frame_predictor(h) + gen_seq.append(x[i]) + else: + h_pred = frame_predictor(h).detach() + x_pred = decoder([h_pred, skip]).detach() + gen_seq.append(x_pred) + + to_plot = [] + nrow = min(opt.batch_size, 25) + for i in range(nrow): + row = [] + for t in range(opt.n_past + opt.n_future): + row.append(gen_seq[t][i]) + to_plot.append(row) + fname = '%s/gen/rec_%d.png' % (opt.log_dir, epoch) + utils.save_tensors_image(fname, to_plot) + + +def do_implasubility_test(h_residual_mean, h_residual_vars): + epoch_size = 1000 // opt.batch_size # we have a total of 1000 videos + cov_inv = [np.diag(1.0 / h_var)[np.newaxis, ...] for h_var in h_residual_vars] # we assume the covariance matrix is diagonal and do the inverse for each time + + frame_predictor.eval() + posterior.eval() + encoder.eval() + decoder.eval() + progress = progressbar.ProgressBar(max_value=epoch_size).start() + + + for i in range(epoch_size): + h_residual_var = torch.tensor(np.zeros(opt.n_future, dtype=np.float32), requires_grad=False, + device=torch.device('cpu')) + progress.update(i + 1) + try: + x = next(training_batch_generator) + frames = [frame.cpu() for frame in x] + except TypeError: + print('got None at i = {}, terminating'.format(i)) + break + h_posterior = [encoder(x[j]) for j in range(opt.n_past + opt.n_future)] + # print(h_posterior[0][0].size()) + last_pred = None + frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + start = 1 + for j in range(start, opt.n_past + opt.n_future): + h_target = h_posterior[j][0].detach() + if opt.last_frame_skip or j < opt.n_past + start - 1: + h, skip = h_posterior[j - 1] + else: + h = h_posterior[j - 1][0].detach() + # we predict h_t from h_{t-1} + h_prior_pred = frame_predictor(h.detach()) + h_posterior_pred = posterior(h_target.detach()) + + if j >= opt.n_past + start - 1: + # h_res = h_prior_pred - h_posterior[j][0].detach() # predicted h minus observed h + # squared_diff = torch.square(h_res - h_residual_mean[j - opt.n_past]) + # squared_diff = torch.mean(squared_diff, dim=1) # average squared residuals at time j over the dimensions of h + # squared_diff = torch.mean(squared_diff, dim=0) # average squared residuals at time j over the batch + # h_residual_var[j - opt.n_past] += squared_diff + + # residual = (h_prior_pred - h_posterior_pred).cpu().detach() + residual = (last_post_pred - h_posterior_pred).cpu().detach() + # err = (residual - h_residual_mean[j - opt.n_past]) + err = residual + err = np.square(err) # [batch, dim_feature] + # err = err / h_residual_vars[j - opt.n_past][np.newaxis, ...] # [batch, dim_feature] / [1, dim_feature] + err = torch.mean(err, axis=1) # -> [batch,] + # if len(err.shape) == 2: # [batch, dim_feature] + # err = err[..., np.newaxis] # make err into a vector [batch, dim_feature, 1] + # # print(cov_inv[j - opt.n_past].shape) + # # print(cov_inv[j - opt.n_past]) + # # print(err.shape) + # print(err) + # print(np.diag(cov_inv[j - opt.n_past][0])) + # quit() + # mahanlanobis_dist = np.matmul(cov_inv[j - opt.n_past], err).transpose(2, 1) + # # print(mahanlanobis_dist.shape) + # mahanlanobis_dist = np.matmul(mahanlanobis_dist, err) # [batch, 1, 1] + # # print(mahanlanobis_dist.shape) + # # print(np.sqrt(mahanlanobis_dist)) + # print(mahanlanobis_dist) + # quit() + # mahanlanobis_dist = torch.mean(mahanlanobis_dist) # scalar + + + # err = torch.mean(err, dim=1) # average errs at time j over the dimensions of h + # err = torch.mean(err, dim=0) # average errs at time j over the batch + # h_residual_var[j - opt.n_past] += err.detach() + h_residual_var[j - opt.n_past] += torch.mean(err, axis=0) + last_post_pred = h_posterior_pred.detach() + + h_residual_sd = torch.sqrt(h_residual_var).cpu() + + h_residual_sd_filtered = - h_residual_sd[:-2] + 2 * h_residual_sd[1:-1] - h_residual_sd[2:] + + + # print(h_residual_var) + for j in range(len(frames)): + frame = np.uint8(np.minimum(frames[j][0][0], 1) * 255) + cv2.imshow('frame', frame) + cv2.waitKey(15) + fig = plt.figure() + plt.ylim(0, 2.0) + plt.xlabel('Time') + plt.title("sqrt of average squared dimensional error") + plt.bar(np.arange(len(h_residual_sd)), h_residual_sd) + fig.canvas.draw() + img = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, + sep='') + img = img.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + + plt.xlabel('Time') + plt.title("filtered sqrt of average squared dimensional error") + plt.bar(np.arange(len(h_residual_sd_filtered)), h_residual_sd_filtered) + fig.canvas.draw() + img2 = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, + sep='') + img2 = img2.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + img2 = cv2.cvtColor(img2, cv2.COLOR_RGB2BGR) + + cv2.imshow("plot", img) + cv2.imshow("plot2", img2) + + k = cv2.waitKey(0) + if k == ord('q'): + quit() + # plt.savefig('implausibility_test.png') + # h_residual_var /= epoch_size # get the mean error vector per time + # h_residual_sd = torch.sqrt(h_residual_var) + print('Last i = {}'.format(i)) + + + # H_err_cov = torch.tensor(np.zeros((opt.n_future, 128, 128), dtype=np.float32), requires_grad=False, + # device=torch.device('cuda:0')) + + # plot some stuff + h_residual_mean_norm = torch.norm(h_residual_mean, dim=1).cpu() + plt.subplot(2, 1, 1) + plt.xlabel('Time') + plt.tight_layout() + plt.title("Norm of the residual mean") + plt.bar(np.arange(len(h_residual_mean_norm)), h_residual_mean_norm) + + plt.subplot(2, 1, 2) + plt.xlabel('Time') + plt.tight_layout() + plt.title("Average dimensional sqrt(variance) of the residual") + plt.bar(np.arange(len(h_residual_sd.cpu())), h_residual_sd.cpu()) + plt.savefig('h_residual_mean.png') + + +f = open('mcs_stats_post.json', 'r') +mcs_stats_dict = json.load(f) +do_implasubility_test(np.array(mcs_stats_dict['mean']), np.array(mcs_stats_dict['vars'])) diff --git a/do_mcs_stats.py b/do_mcs_stats.py index 9e64c56..3a74423 100644 --- a/do_mcs_stats.py +++ b/do_mcs_stats.py @@ -13,6 +13,7 @@ import progressbar import numpy as np import matplotlib.pyplot as plt +import json parser = argparse.ArgumentParser() parser.add_argument('--lr', default=0.004, type=float, help='learning rate') @@ -259,7 +260,7 @@ def plot_rec(x, epoch): def do_stats(): - epoch_size = 1000 // opt.batch_size # we have a total of 1000 videos + epoch_size = 1000 // opt.batch_size // 4 # we have a total of 1000 videos frame_predictor.eval() encoder.eval() @@ -307,6 +308,8 @@ def do_stats(): pin_memory=True) h_residual_var = torch.tensor(np.zeros(opt.n_future, dtype=np.float32), requires_grad=False, device=torch.device('cuda:0')) + h_residual_vars = torch.tensor(np.zeros((opt.n_future, 128), dtype=np.float32), requires_grad=False, + device=torch.device('cuda:0')) for i in range(epoch_size): progress.update(i + 1) @@ -336,16 +339,19 @@ def do_stats(): this_minus_last_pred = h_prior_pred - last_pred squared_err = torch.square(this_minus_last_pred - h_residual_mean[j - opt.n_past]) - squared_err = torch.mean(squared_err, dim=1) # average errs at time j over the dimensions of h squared_err = torch.mean(squared_err, dim=0) # average errs at time j over the batch + h_residual_vars[j - opt.n_past] += squared_err.detach() + squared_err = torch.mean(squared_err, dim=0) # average errs at time j over the dimensions of h h_residual_var[j - opt.n_past] += squared_err.detach() - last_pred = h_prior_pred + last_pred = h_prior_pred.detach() h_residual_var /= epoch_size # get the mean error vector per time + h_residual_vars /= epoch_size h_residual_sd = torch.sqrt(h_residual_var) print('Last i = {}'.format(i)) print('sd of h residual: ', h_residual_sd) print('var of h residual: ', h_residual_var) print('norm(mean of h residual)', torch.norm(h_residual_mean, dim=1)) + print('vars of h residual: ', h_residual_vars) # H_err_cov = torch.tensor(np.zeros((opt.n_future, 128, 128), dtype=np.float32), requires_grad=False, # device=torch.device('cuda:0')) @@ -363,7 +369,12 @@ def do_stats(): plt.tight_layout() plt.title("Average dimensional sqrt(variance) of the residual") plt.bar(np.arange(len(h_residual_sd.cpu())), h_residual_sd.cpu()) - plt.savefig('h_residual_mean.png') + plt.savefig('h_residual_mean_1.png') + + stats_dict = {'mean': h_residual_mean.cpu().tolist(), 'var': h_residual_var.cpu().tolist(), + 'vars': h_residual_vars.cpu().tolist()} + f = open('mcs_stats.json', 'w') + json.dump(stats_dict, f) do_stats() diff --git a/do_mcs_stats_posterior.py b/do_mcs_stats_posterior.py new file mode 100644 index 0000000..5ccec1c --- /dev/null +++ b/do_mcs_stats_posterior.py @@ -0,0 +1,377 @@ +import glob + +import torch +import torch.optim as optim +import torch.nn as nn +import argparse +import os +import random +from torch.autograd import Variable +from torch.utils.data import DataLoader +import utils +import itertools +import progressbar +import numpy as np +import matplotlib.pyplot as plt +import json + +parser = argparse.ArgumentParser() +parser.add_argument('--lr', default=0.004, type=float, help='learning rate') +parser.add_argument('--beta1', default=0.9, type=float, help='momentum term for adam') +parser.add_argument('--batch_size', default=1, type=int, help='batch size') +parser.add_argument('--log_dir', default='logs/nonstochastic', help='base directory to save logs') +parser.add_argument('--model_dir', default='', help='base directory to save logs') +parser.add_argument('--name', default='', help='identifier for directory') +parser.add_argument('--data_root', default='data', help='root directory for data') +parser.add_argument('--optimizer', default='adam', help='optimizer to train with') +parser.add_argument('--niter', type=int, default=300, help='number of epochs to train for') +parser.add_argument('--seed', default=1, type=int, help='manual seed') +parser.add_argument('--epoch_size', type=int, default=1000, help='epoch size') +parser.add_argument('--image_width', type=int, default=64, help='the height / width of the input image to network') +parser.add_argument('--channels', default=1, type=int) +parser.add_argument('--dataset', default='mcs', help='dataset to train with') +parser.add_argument('--mcs_task', default='SpatioTemporalContinuityTraining4', help='mcs task') +parser.add_argument('--n_past', type=int, default=5, help='number of frames to condition on') +parser.add_argument('--n_future', type=int, default=195, help='number of frames to predict') +parser.add_argument('--n_eval', type=int, default=30, help='number of frames to predict at eval time') +parser.add_argument('--rnn_size', type=int, default=256, help='dimensionality of hidden layer') +parser.add_argument('--posterior_rnn_layers', type=int, default=1, help='number of layers') +parser.add_argument('--predictor_rnn_layers', type=int, default=2, help='number of layers') +parser.add_argument('--z_dim', type=int, default=10, help='dimensionality of z_t') +parser.add_argument('--g_dim', type=int, default=128, + help='dimensionality of encoder output vector and decoder input vector') +parser.add_argument('--beta', type=float, default=0.0001, help='weighting on KL to prior') +parser.add_argument('--model', default='dcgan', help='model type (dcgan | vgg)') +parser.add_argument('--data_threads', type=int, default=8, help='number of data loading threads') +parser.add_argument('--num_digits', type=int, default=2, help='number of digits for moving mnist') +parser.add_argument('--last_frame_skip', action='store_true', + help='if true, skip connections go between frame t and frame t+t rather than last ground truth frame') +opt = parser.parse_args() +BATCH_SIZE = opt.batch_size +saved_model = None +if opt.model_dir != '': + models = glob.glob(f'{opt.model_dir}/model_*.pth') + latest_model = sorted(models, key=lambda s: int(s[s.rfind('_e') + 2: s.rfind('.pth')]), reverse=True)[0] + print('Loading model ', latest_model) + saved_model = torch.load(latest_model) + model_dir = opt.model_dir + niter = opt.niter + opt = saved_model['opt'] + opt.niter = niter # update number of epochs to train for + opt.model_dir = model_dir + opt.log_dir = '%s/continued' % opt.log_dir +else: + raise ValueError("Please specify the model to load with the --model_dir argument") + +print("Random Seed: ", opt.seed) +random.seed(opt.seed) +torch.manual_seed(opt.seed) +torch.cuda.manual_seed_all(opt.seed) +dtype = torch.cuda.FloatTensor + + +import models.lstm as lstm_models + +if opt.model_dir != '': + frame_predictor = saved_model['frame_predictor'] + frame_predictor.batch_size = BATCH_SIZE + posterior = saved_model['posterior'] + posterior.batch_size = BATCH_SIZE +else: + frame_predictor = lstm_models.lstm(opt.g_dim, opt.g_dim, opt.rnn_size, opt.predictor_rnn_layers, opt.batch_size) + frame_predictor.apply(utils.init_weights) + +if opt.model == 'dcgan': + if opt.image_width == 64: + import models.dcgan_64 as model + elif opt.image_width == 128: + import models.dcgan_128 as model +elif opt.model == 'vgg': + if opt.image_width == 64: + import models.vgg_64 as model + elif opt.image_width == 128: + import models.vgg_128 as model +else: + raise ValueError('Unknown model: %s' % opt.model) + +if opt.model_dir != '': + decoder = saved_model['decoder'] + encoder = saved_model['encoder'] +else: + encoder = model.encoder(opt.g_dim, opt.channels) + decoder = model.decoder(opt.g_dim, opt.channels) + encoder.apply(utils.init_weights) + decoder.apply(utils.init_weights) + +# --------- loss functions ------------------------------------ +mse_criterion = nn.MSELoss() + + +def kl_criterion(mu, logvar): + # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) + KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) + KLD /= opt.batch_size + return KLD + + +# --------- transfer to gpu ------------------------------------ +frame_predictor.cuda() +encoder.cuda() +decoder.cuda() +mse_criterion.cuda() + +opt.batch_size = BATCH_SIZE +opt.epoch_size = 1000 +opt.n_future = 195 +print(opt) + +# --------- load a dataset ------------------------------------ +train_data, test_data = utils.load_dataset(opt, sequential=True) + +train_loader = DataLoader(train_data, + num_workers=opt.data_threads, + batch_size=opt.batch_size, + drop_last=True, + pin_memory=True) +test_loader = DataLoader(test_data, + num_workers=opt.data_threads, + batch_size=opt.batch_size, + shuffle=True, + drop_last=True, + pin_memory=True) + + +def get_training_batch(): + while True: + for sequence in train_loader: + batch = utils.normalize_data(opt, dtype, sequence) + yield batch + + +training_batch_generator = get_training_batch() +training_batch_generator_2 = get_training_batch() + + +def get_testing_batch(): + while True: + for sequence in test_loader: + batch = utils.normalize_data(opt, dtype, sequence) + yield batch + + +testing_batch_generator = get_testing_batch() + + +# --------- plotting funtions ------------------------------------ +def plot(x, epoch): + nsample = 1 + gen_seq = [[] for _ in range(nsample)] + gt_seq = [x[i] for i in range(len(x))] + + h_seq = [encoder(x[i]) for i in range(opt.n_past)] + for s in range(nsample): + frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + gen_seq[s].append(x[0]) + x_in = x[0] + for i in range(1, opt.n_eval): + if opt.last_frame_skip or i < opt.n_past: + h, skip = h_seq[i - 1] + h = h.detach() + elif i < opt.n_past: + h, _ = h_seq[i - 1] + h = h.detach() + if i < opt.n_past: + frame_predictor(h) + x_in = x[i] + gen_seq[s].append(x_in) + else: + h = frame_predictor(h).detach() + x_in = decoder([h, skip]).detach() + gen_seq[s].append(x_in) + + to_plot = [] + gifs = [[] for t in range(opt.n_eval)] + nrow = min(opt.batch_size, 25) + for i in range(nrow): + # ground truth sequence + row = [] + for t in range(opt.n_eval): + row.append(gt_seq[t][i]) + to_plot.append(row) + + for s in range(nsample): + row = [] + for t in range(opt.n_eval): + row.append(gen_seq[s][t][i]) + to_plot.append(row) + for t in range(opt.n_eval): + row = [] + row.append(gt_seq[t][i]) + for s in range(nsample): + row.append(gen_seq[s][t][i]) + gifs[t].append(row) + + fname = '%s/gen/sample_%d.png' % (opt.log_dir, epoch) + utils.save_tensors_image(fname, to_plot) + + fname = '%s/gen/sample_%d.gif' % (opt.log_dir, epoch) + utils.save_gif(fname, gifs) + + +def plot_rec(x, epoch): + frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + gen_seq = [] + gen_seq.append(x[0]) + x_in = x[0] + h_seq = [encoder(x[i]) for i in range(opt.n_past + opt.n_future)] + for i in range(1, opt.n_past + opt.n_future): + h_target = h_seq[i][0].detach() + if opt.last_frame_skip or i < opt.n_past: + h, skip = h_seq[i - 1] + else: + h, _ = h_seq[i - 1] + h = h.detach() + if i < opt.n_past: + frame_predictor(h) + gen_seq.append(x[i]) + else: + h_pred = frame_predictor(h).detach() + x_pred = decoder([h_pred, skip]).detach() + gen_seq.append(x_pred) + + to_plot = [] + nrow = min(opt.batch_size, 25) + for i in range(nrow): + row = [] + for t in range(opt.n_past + opt.n_future): + row.append(gen_seq[t][i]) + to_plot.append(row) + fname = '%s/gen/rec_%d.png' % (opt.log_dir, epoch) + utils.save_tensors_image(fname, to_plot) + + +def do_stats(): + epoch_size = 1000 // opt.batch_size // 1 # we have a total of 1000 videos + + frame_predictor.eval() + encoder.eval() + decoder.eval() + progress = progressbar.ProgressBar(max_value=epoch_size).start() + h_residual_mean = torch.tensor(np.zeros((opt.n_future, 128), dtype=np.float32), requires_grad=False, + device=torch.device('cuda:0')) + i = 0 + for i in range(epoch_size): + progress.update(i + 1) + try: + x = next(training_batch_generator) + except TypeError: + print('got None at i = {}, terminating'.format(i)) + break + h_posterior = [encoder(x[j]) for j in range(opt.n_past + opt.n_future)] + # print(h_posterior[0][0].size()) + last_pred = None + frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + for j in range(1, opt.n_past + opt.n_future): + h_target = h_posterior[j][0].detach() + if opt.last_frame_skip or j < opt.n_past: + h, skip = h_posterior[j - 1] + else: + h = h_posterior[j - 1][0].detach() + # we predict h_t from h_{t-1} + h_prior_pred = frame_predictor(h).detach() + h_posterior_pred = posterior(h_target).detach() + + if j >= opt.n_past: + # h_res = h_prior_pred - h_posterior[j][0].detach() # predicted h minus observed h + # h_res = h_prior_pred + # h_res = torch.mean(h_res, dim=0) # average errors at the same time j over the batch + # h_residual_mean[j - opt.n_past] += h_res + residual = h_prior_pred - h_posterior_pred + residual = torch.mean(residual, dim=0) + h_residual_mean[j - opt.n_past] += residual + h_residual_mean /= epoch_size # get the mean error vector per time + + # restart training dataset + global train_loader + train_loader = DataLoader(train_data, + num_workers=opt.data_threads, + batch_size=opt.batch_size, + drop_last=True, + pin_memory=True) + h_residual_var = torch.tensor(np.zeros(opt.n_future, dtype=np.float32), requires_grad=False, + device=torch.device('cuda:0')) + h_residual_vars = torch.tensor(np.zeros((opt.n_future, 128), dtype=np.float32), requires_grad=False, + device=torch.device('cuda:0')) + + for i in range(epoch_size): + progress.update(i + 1) + try: + x = next(training_batch_generator_2) + except TypeError: + print('got None at i = {}, terminating'.format(i)) + break + h_posterior = [encoder(x[j]) for j in range(opt.n_past + opt.n_future)] + # print(h_posterior[0][0].size()) + frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + for j in range(1, opt.n_past + opt.n_future): + h_target = h_posterior[j][0].detach() + if opt.last_frame_skip or j < opt.n_past: + h, skip = h_posterior[j - 1] + else: + h = h_posterior[j - 1][0].detach() + # we predict h_t from h_{t-1} + h_prior_pred = frame_predictor(h).detach() + h_posterior_pred = posterior(h_target).detach() + + if j >= opt.n_past: + # h_res = h_prior_pred - h_posterior[j][0].detach() # predicted h minus observed h + # squared_diff = torch.square(h_res - h_residual_mean[j - opt.n_past]) + # squared_diff = torch.mean(squared_diff, dim=1) # average squared residuals at time j over the dimensions of h + # squared_diff = torch.mean(squared_diff, dim=0) # average squared residuals at time j over the batch + # h_residual_var[j - opt.n_past] += squared_diff + + residual = h_prior_pred - h_posterior_pred + squared_err = torch.square(residual - h_residual_mean[j - opt.n_past]) + # squared_err = torch.square(residual) + squared_err = torch.mean(squared_err, dim=0) # average errs at time j over the batch + h_residual_vars[j - opt.n_past] += squared_err.detach() + squared_err = torch.mean(squared_err, dim=0) # average errs at time j over the dimensions of h + h_residual_var[j - opt.n_past] += squared_err.detach() + h_residual_var /= epoch_size + h_residual_vars /= epoch_size + h_residual_sd = torch.sqrt(h_residual_var) + print('Last i = {}'.format(i)) + print('sd of h residual: ', h_residual_sd) + print('var of h residual: ', h_residual_var) + print('norm(mean of h residual)', torch.norm(h_residual_mean, dim=1)) + print('vars of h residual: ', h_residual_vars) + + # H_err_cov = torch.tensor(np.zeros((opt.n_future, 128, 128), dtype=np.float32), requires_grad=False, + # device=torch.device('cuda:0')) + + # plot some stuff + h_residual_mean_norm = torch.norm(h_residual_mean, dim=1).cpu() + plt.subplot(2, 1, 1) + plt.xlabel('Time') + plt.tight_layout() + plt.title("Norm of the residual mean") + plt.bar(np.arange(len(h_residual_mean_norm)), h_residual_mean_norm) + + plt.subplot(2, 1, 2) + plt.xlabel('Time') + plt.tight_layout() + plt.title("Average dimensional sqrt(variance) of the residual") + plt.bar(np.arange(len(h_residual_sd.cpu())), h_residual_sd.cpu()) + plt.savefig('post_h_residual.png') + + stats_dict = {'mean': h_residual_mean.cpu().tolist(), 'var': h_residual_var.cpu().tolist(), + 'vars': h_residual_vars.cpu().tolist()} + f = open('mcs_stats_post.json', 'w') + json.dump(stats_dict, f) + + +do_stats() diff --git a/train_svg_nonstochastic_posterior.py b/train_svg_nonstochastic_posterior.py new file mode 100644 index 0000000..1dce6a7 --- /dev/null +++ b/train_svg_nonstochastic_posterior.py @@ -0,0 +1,373 @@ +import glob + +import torch +import torch.optim as optim +import torch.nn as nn +import argparse +import os +import random +from torch.autograd import Variable +from torch.utils.data import DataLoader +import utils +import itertools +import progressbar +import numpy as np + +parser = argparse.ArgumentParser() +parser.add_argument('--lr', default=0.008, type=float, help='learning rate') +parser.add_argument('--beta1', default=0.9, type=float, help='momentum term for adam') +parser.add_argument('--batch_size', default=128, type=int, help='batch size') +parser.add_argument('--log_dir', default='logs/nonstochastic_posterior', help='base directory to save logs') +parser.add_argument('--model_dir', default='', help='base directory to save logs') +parser.add_argument('--name', default='', help='identifier for directory') +parser.add_argument('--data_root', default='data', help='root directory for data') +parser.add_argument('--optimizer', default='adam', help='optimizer to train with') +parser.add_argument('--niter', type=int, default=300, help='number of epochs to train for') +parser.add_argument('--seed', default=1, type=int, help='manual seed') +parser.add_argument('--epoch_size', type=int, default=400, help='epoch size') +parser.add_argument('--image_width', type=int, default=64, help='the height / width of the input image to network') +parser.add_argument('--channels', default=1, type=int) +parser.add_argument('--dataset', default='mcs', help='dataset to train with') +parser.add_argument('--mcs_task', default='SpatioTemporalContinuityTraining4', help='mcs task') +parser.add_argument('--n_past', type=int, default=5, help='number of frames to condition on') +parser.add_argument('--n_future', type=int, default=15, help='number of frames to predict') +parser.add_argument('--n_eval', type=int, default=30, help='number of frames to predict at eval time') +parser.add_argument('--rnn_size', type=int, default=256, help='dimensionality of hidden layer') +parser.add_argument('--posterior_rnn_layers', type=int, default=1, help='number of layers') +parser.add_argument('--predictor_rnn_layers', type=int, default=2, help='number of layers') +parser.add_argument('--z_dim', type=int, default=10, help='dimensionality of z_t') +parser.add_argument('--g_dim', type=int, default=128, help='dimensionality of encoder output vector and decoder input vector') +parser.add_argument('--beta', type=float, default=0.0001, help='weighting on KL to prior') +parser.add_argument('--gamma', type=float, default=0.0001, help='weighting on h vs h posterior') +parser.add_argument('--model', default='dcgan', help='model type (dcgan | vgg)') +parser.add_argument('--data_threads', type=int, default=12, help='number of data loading threads') +parser.add_argument('--num_digits', type=int, default=2, help='number of digits for moving mnist') +parser.add_argument('--last_frame_skip', action='store_true', help='if true, skip connections go between frame t and frame t+1 rather than last ground truth frame') + + +opt = parser.parse_args() +saved_model = None +if opt.model_dir != '': + models = glob.glob(f'{opt.model_dir}/model_*.pth') + latest_model = sorted(models, key=lambda s: int(s[s.rfind('_e') + 2: s.rfind('.pth')]), reverse=True)[0] + print('Loading model ', latest_model) + saved_model = torch.load(latest_model) + optimizer = opt.optimizer + model_dir = opt.model_dir + niter = opt.niter + lr = opt.lr + opt = saved_model['opt'] + opt.niter = niter # update number of epochs to train for + opt.optimizer = optimizer + opt.model_dir = model_dir + opt.lr = lr + opt.log_dir = '%s/continued_lr%s' % (opt.log_dir, opt.lr) +else: + name = 'model=%s%dx%d-rnn_size=%d-predictor-posterior-rnn_layers=%d-%d-n_past=%d-n_future=%d-lr=%.4f-g_dim=%d-z_dim=%d-last_frame_skip=%d-beta=%.7f-gamma=%.7f%s' % (opt.model, opt.image_width, opt.image_width, opt.rnn_size, opt.predictor_rnn_layers, opt.posterior_rnn_layers, opt.n_past, opt.n_future, opt.lr, opt.g_dim, opt.z_dim, opt.last_frame_skip, opt.beta, opt.gamma, opt.name) + if opt.dataset == 'smmnist': + opt.log_dir = '%s/%s-%d/%s' % (opt.log_dir, opt.dataset, opt.num_digits, name) + elif opt.dataset == 'mcs': + opt.log_dir = '%s/%s/%s/%s' % (opt.log_dir, opt.dataset, opt.mcs_task, name) + else: + opt.log_dir = '%s/%s/%s' % (opt.log_dir, opt.dataset, name) + +os.makedirs('%s/gen/' % opt.log_dir, exist_ok=True) +os.makedirs('%s/plots/' % opt.log_dir, exist_ok=True) + +print("Random Seed: ", opt.seed) +random.seed(opt.seed) +torch.manual_seed(opt.seed) +torch.cuda.manual_seed_all(opt.seed) +dtype = torch.cuda.FloatTensor + + +# ---------------- load the models ---------------- + +print(opt) + +# ---------------- optimizers ---------------- +if opt.optimizer == 'adam': + opt.optimizer = optim.Adam +elif opt.optimizer == 'rmsprop': + opt.optimizer = optim.RMSprop +elif opt.optimizer == 'sgd': + opt.optimizer = optim.SGD +else: + raise ValueError('Unknown optimizer: %s' % opt.optimizer) + +import models.lstm as lstm_models +if opt.model_dir != '': + frame_predictor = saved_model['frame_predictor'] + posterior = saved_model['posterior'] +else: + frame_predictor = lstm_models.lstm(opt.g_dim, opt.g_dim, opt.rnn_size, opt.predictor_rnn_layers, opt.batch_size) + frame_predictor.apply(utils.init_weights) + posterior = lstm_models.lstm(opt.g_dim, opt.g_dim, opt.rnn_size, opt.predictor_rnn_layers, opt.batch_size) + posterior.apply(utils.init_weights) + +if opt.model == 'dcgan': + if opt.image_width == 64: + import models.dcgan_64 as model + elif opt.image_width == 128: + import models.dcgan_128 as model +elif opt.model == 'vgg': + if opt.image_width == 64: + import models.vgg_64 as model + elif opt.image_width == 128: + import models.vgg_128 as model +else: + raise ValueError('Unknown model: %s' % opt.model) + +if opt.model_dir != '': + decoder = saved_model['decoder'] + encoder = saved_model['encoder'] +else: + encoder = model.encoder(opt.g_dim, opt.channels) + decoder = model.decoder(opt.g_dim, opt.channels) + encoder.apply(utils.init_weights) + decoder.apply(utils.init_weights) + +frame_predictor_optimizer = opt.optimizer(frame_predictor.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) +posterior_optimizer = opt.optimizer(posterior.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) +encoder_optimizer = opt.optimizer(encoder.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) +decoder_optimizer = opt.optimizer(decoder.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + +# --------- loss functions ------------------------------------ +mse_criterion = nn.MSELoss() +def kl_criterion(mu, logvar): + # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) + KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) + KLD /= opt.batch_size + return KLD + + +# --------- transfer to gpu ------------------------------------ +frame_predictor.cuda() +posterior.cuda() +encoder.cuda() +decoder.cuda() +mse_criterion.cuda() + +# --------- load a dataset ------------------------------------ +train_data, test_data = utils.load_dataset(opt) + +train_loader = DataLoader(train_data, + num_workers=opt.data_threads, + batch_size=opt.batch_size, + shuffle=True, + drop_last=True, + pin_memory=True) +test_loader = DataLoader(test_data, + num_workers=opt.data_threads, + batch_size=opt.batch_size, + shuffle=True, + drop_last=True, + pin_memory=True) + +def get_training_batch(): + while True: + for sequence in train_loader: + batch = utils.normalize_data(opt, dtype, sequence) + yield batch +training_batch_generator = get_training_batch() + +def get_testing_batch(): + while True: + for sequence in test_loader: + batch = utils.normalize_data(opt, dtype, sequence) + yield batch +testing_batch_generator = get_testing_batch() + +# --------- plotting funtions ------------------------------------ +def plot(x, epoch): + + nsample = 1 + gen_seq = [[] for _ in range(nsample)] + gt_seq = [x[i] for i in range(len(x))] + + h_seq = [encoder(x[i]) for i in range(opt.n_past)] + for s in range(nsample): + frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + gen_seq[s].append(x[0]) + x_in = x[0] + for i in range(1, opt.n_eval): + with torch.no_grad(): + h = encoder(x_in) + if opt.last_frame_skip or i < opt.n_past: + h, skip = h + else: + h, _ = h_seq[i-1] + h = h.detach() + if i < opt.n_past: + frame_predictor(h) + x_in = x[i] + gen_seq[s].append(x_in) + else: + h = frame_predictor(h).detach() + x_in = decoder([h, skip]) + gen_seq[s].append(x_in) + + to_plot = [] + gifs = [ [] for t in range(opt.n_eval) ] + nrow = min(opt.batch_size, 25) + for i in range(nrow): + # ground truth sequence + row = [] + for t in range(opt.n_eval): + row.append(gt_seq[t][i]) + to_plot.append(row) + + for s in range(nsample): + row = [] + for t in range(opt.n_eval): + row.append(gen_seq[s][t][i]) + to_plot.append(row) + for t in range(opt.n_eval): + row = [] + row.append(gt_seq[t][i]) + for s in range(nsample): + row.append(gen_seq[s][t][i]) + gifs[t].append(row) + + fname = '%s/gen/sample_%d.png' % (opt.log_dir, epoch) + utils.save_tensors_image(fname, to_plot) + + fname = '%s/gen/sample_%d.gif' % (opt.log_dir, epoch) + utils.save_gif(fname, gifs) + + +def plot_rec(x, epoch): + frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + + gen_seq = [] + gen_seq.append(x[0]) + gen_seq_post = [] + gen_seq_post.append(x[0]) + x_in = x[0] + h_seq = [encoder(x[i]) for i in range(opt.n_past+opt.n_future)] + for i in range(1, opt.n_past+opt.n_future): + h_target = h_seq[i][0].detach() + if opt.last_frame_skip or i < opt.n_past: + h, skip = h_seq[i-1] + else: + h, _ = h_seq[i-1] + h = h.detach() + if i < opt.n_past: + frame_predictor(h) + gen_seq.append(x[i]) + gen_seq_post.append(x[i]) + else: + h_pred = frame_predictor(h).detach() + x_pred = decoder([h_pred, skip]).detach() + h_posterior = posterior(h_target).detach() + x_posterior = decoder([h_posterior, skip]).detach() + gen_seq.append(x_pred) + gen_seq_post.append(x_posterior) + + to_plot = [] + nrow = min(opt.batch_size * 3, 25 * 3) + for i in range(nrow): + row_gt = [] + row_post = [] + row = [] + for t in range(opt.n_past+opt.n_future): + row_gt.append(x[t][i]) + row_post.append(gen_seq_post[t][i]) + row.append(gen_seq[t][i]) + to_plot.append(row_gt) + to_plot.append(row_post) + to_plot.append(row) + fname = '%s/gen/rec_%d.png' % (opt.log_dir, epoch) + utils.save_tensors_image(fname, to_plot) + + +# --------- training funtions ------------------------------------ +def train(x): + frame_predictor.zero_grad() + posterior.zero_grad() + encoder.zero_grad() + decoder.zero_grad() + + # initialize the hidden state. + frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + + h_seq = [encoder(x[i]) for i in range(opt.n_past+opt.n_future)] + mse = 0 + mse_post = 0 + mse_diff_post = 0 + for i in range(1, opt.n_past+opt.n_future): + h_target = h_seq[i][0] + if opt.last_frame_skip or i < opt.n_past: + h, skip = h_seq[i-1] + else: + h = h_seq[i-1][0] + h_pred = frame_predictor(h) + x_pred = decoder([h_pred, skip]) + h_posterior = posterior(h_target) + x_posterior = decoder([h_posterior, skip]) + mse += mse_criterion(x_pred, x[i]) + mse_post += mse_criterion(x_posterior, x[i]) + mse_diff_post += opt.gamma * torch.mean(torch.square(h_posterior.detach() - h_pred)) + + loss = mse + mse_post + mse_diff_post + loss.backward() + + frame_predictor_optimizer.step() + posterior_optimizer.step() + encoder_optimizer.step() + decoder_optimizer.step() + + N = opt.n_past+opt.n_future + return mse.data.cpu().numpy()/N, mse_post.data.cpu().numpy()/N, mse_diff_post.data.cpu().numpy()/N + +# --------- training loop ------------------------------------ +for epoch in range(opt.niter): + frame_predictor.train() + posterior.train() + encoder.train() + decoder.train() + epoch_mse = 0 + epoch_mse_posterior = 0 + epoch_posterior_diff = 0 + progress = progressbar.ProgressBar(max_value=opt.epoch_size).start() + # opt.epoch_size = 10 + for i in range(opt.epoch_size): + progress.update(i+1) + x = next(training_batch_generator) + + # train frame_predictor + mse, mse_posterior, posterior_diff = train(x) + epoch_mse += mse + epoch_mse_posterior += mse_posterior + epoch_posterior_diff += posterior_diff + + + progress.finish() + utils.clear_progressbar() + + print('[%02d] mse loss: %.5f, %.5f posterior | posterior diff loss: %.20f (%d)' % (epoch, epoch_mse/opt.epoch_size, epoch_mse_posterior/opt.epoch_size, epoch_posterior_diff/opt.epoch_size, epoch*opt.epoch_size*opt.batch_size)) + + # plot some stuff + frame_predictor.eval() + posterior.eval() + encoder.eval() + decoder.eval() + x = next(testing_batch_generator) + plot(x, epoch) + plot_rec(x, epoch) + + # save the model + torch.save({ + 'encoder': encoder, + 'decoder': decoder, + 'frame_predictor': frame_predictor, + 'posterior': posterior, + 'opt': opt}, + '%s/model_e%02d.pth' % (opt.log_dir, epoch)) + if epoch % 10 == 0: + print('log dir: %s' % opt.log_dir) + + diff --git a/utils.py b/utils.py index 0635a37..371b9ab 100755 --- a/utils.py +++ b/utils.py @@ -29,7 +29,7 @@ def torch_tensor_to_img(tensor): image_array = tensor.numpy() image_array -= np.min(image_array) image_array = np.minimum(image_array, 1.0) - print(image_array.shape) + # print(image_array.shape) image_array = np.transpose(image_array, (1, 2, 0)) img = None if image_array.shape[2] == 3: # 3-channel image @@ -40,7 +40,7 @@ def torch_tensor_to_img(tensor): return img -def load_dataset(opt, sequential=None): +def load_dataset(opt, sequential=None, implausible=None): train_data = None test_data = None if opt.dataset == 'smmnist': @@ -91,14 +91,16 @@ def load_dataset(opt, sequential=None): seq_len=opt.n_past+opt.n_future, image_size=opt.image_width, task=opt.mcs_task, - sequential=sequential) + sequential=sequential, + implausible=implausible) test_data = MCS( train=False, data_root=opt.data_root, seq_len=opt.n_eval, image_size=opt.image_width, task=opt.mcs_task, - sequential=sequential) + sequential=sequential, + implausible=implausible) return train_data, test_data From 473c776fa33caec9b724c5d00ec62e0d4d4af2ba Mon Sep 17 00:00:00 2001 From: Fred Lu Date: Fri, 23 Jul 2021 18:06:14 -0400 Subject: [PATCH 05/26] updating prior vs posterior model --- _do_mcs_implausblility_test_posterior.py | 393 +++++++++++++++++++++++ _do_mcs_stats_posterior.py | 377 ++++++++++++++++++++++ _train_svg_nonstochastic_posterior.py | 373 +++++++++++++++++++++ data/mcs.py | 11 +- do_mcs_implausblility_test_posterior.py | 204 +++++++----- do_mcs_stats_posterior.py | 114 ++++--- train_svg_nonstochastic_posterior.py | 146 ++++++--- utils.py | 28 ++ 8 files changed, 1456 insertions(+), 190 deletions(-) create mode 100644 _do_mcs_implausblility_test_posterior.py create mode 100644 _do_mcs_stats_posterior.py create mode 100644 _train_svg_nonstochastic_posterior.py diff --git a/_do_mcs_implausblility_test_posterior.py b/_do_mcs_implausblility_test_posterior.py new file mode 100644 index 0000000..35ec6cb --- /dev/null +++ b/_do_mcs_implausblility_test_posterior.py @@ -0,0 +1,393 @@ +import glob + +import cv2 +import torch +import torch.optim as optim +import torch.nn as nn +import argparse +import os +import random +from torch.autograd import Variable +from torch.utils.data import DataLoader +import utils +import itertools +import progressbar +import numpy as np +import matplotlib.pyplot as plt +import json + +parser = argparse.ArgumentParser() +parser.add_argument('--lr', default=0.004, type=float, help='learning rate') +parser.add_argument('--beta1', default=0.9, type=float, help='momentum term for adam') +parser.add_argument('--batch_size', default=1, type=int, help='batch size') +parser.add_argument('--log_dir', default='logs/nonstochastic', help='base directory to save logs') +parser.add_argument('--model_dir', default='', help='base directory to save logs') +parser.add_argument('--name', default='', help='identifier for directory') +parser.add_argument('--data_root', default='data', help='root directory for data') +parser.add_argument('--optimizer', default='adam', help='optimizer to train with') +parser.add_argument('--niter', type=int, default=300, help='number of epochs to train for') +parser.add_argument('--seed', default=1, type=int, help='manual seed') +parser.add_argument('--epoch_size', type=int, default=1000, help='epoch size') +parser.add_argument('--image_width', type=int, default=64, help='the height / width of the input image to network') +parser.add_argument('--channels', default=1, type=int) +parser.add_argument('--dataset', default='mcs', help='dataset to train with') +parser.add_argument('--mcs_task', default='SpatioTemporalContinuityTraining4', help='mcs task') +parser.add_argument('--n_past', type=int, default=5, help='number of frames to condition on') +parser.add_argument('--n_future', type=int, default=195, help='number of frames to predict') +parser.add_argument('--n_eval', type=int, default=30, help='number of frames to predict at eval time') +parser.add_argument('--rnn_size', type=int, default=256, help='dimensionality of hidden layer') +parser.add_argument('--posterior_rnn_layers', type=int, default=1, help='number of layers') +parser.add_argument('--predictor_rnn_layers', type=int, default=2, help='number of layers') +parser.add_argument('--z_dim', type=int, default=10, help='dimensionality of z_t') +parser.add_argument('--g_dim', type=int, default=128, + help='dimensionality of encoder output vector and decoder input vector') +parser.add_argument('--beta', type=float, default=0.0001, help='weighting on KL to prior') +parser.add_argument('--model', default='dcgan', help='model type (dcgan | vgg)') +parser.add_argument('--data_threads', type=int, default=8, help='number of data loading threads') +parser.add_argument('--num_digits', type=int, default=2, help='number of digits for moving mnist') +parser.add_argument('--last_frame_skip', action='store_true', + help='if true, skip connections go between frame t and frame t+t rather than last ground truth frame') +opt = parser.parse_args() +BATCH_SIZE = opt.batch_size +saved_model = None +if opt.model_dir != '': + models = glob.glob(f'{opt.model_dir}/model_*.pth') + latest_model = sorted(models, key=lambda s: int(s[s.rfind('_e') + 2: s.rfind('.pth')]), reverse=True)[0] + print('Loading model ', latest_model) + saved_model = torch.load(latest_model) + optimizer = opt.optimizer + model_dir = opt.model_dir + niter = opt.niter + opt = saved_model['opt'] + opt.niter = niter # update number of epochs to train for + opt.optimizer = optimizer + opt.model_dir = model_dir + opt.log_dir = '%s/continued' % opt.log_dir +else: + raise ValueError("Please specify the model to load with the --model_dir argument") + +print("Random Seed: ", opt.seed) +random.seed(opt.seed) +torch.manual_seed(opt.seed) +torch.cuda.manual_seed_all(opt.seed) +dtype = torch.cuda.FloatTensor + + +import models.lstm as lstm_models + +if opt.model_dir != '': + frame_predictor = saved_model['frame_predictor'] + frame_predictor.batch_size = BATCH_SIZE + posterior = saved_model['posterior'] + posterior.batch_size = BATCH_SIZE +else: + raise ValueError('Please specify --model_dir') + +if opt.model == 'dcgan': + if opt.image_width == 64: + import models.dcgan_64 as model + elif opt.image_width == 128: + import models.dcgan_128 as model +elif opt.model == 'vgg': + if opt.image_width == 64: + import models.vgg_64 as model + elif opt.image_width == 128: + import models.vgg_128 as model +else: + raise ValueError('Unknown model: %s' % opt.model) + +if opt.model_dir != '': + decoder = saved_model['decoder'] + encoder = saved_model['encoder'] +else: + raise ValueError("Please specify the model to load with the --model_dir argument") + +# --------- loss functions ------------------------------------ +mse_criterion = nn.MSELoss() + + +def kl_criterion(mu, logvar): + # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) + KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) + KLD /= opt.batch_size + return KLD + + +# --------- transfer to gpu ------------------------------------ +frame_predictor.cuda() +posterior.cuda() +encoder.cuda() +decoder.cuda() +mse_criterion.cuda() + +opt.batch_size = BATCH_SIZE +opt.epoch_size = 1000 +opt.n_future = 195 +print(opt) + +# --------- load a dataset ------------------------------------ +train_data, test_data = utils.load_dataset(opt, sequential=True, implausible=True) + +train_loader = DataLoader(train_data, + num_workers=opt.data_threads, + batch_size=opt.batch_size, + drop_last=True, + pin_memory=True,) +test_loader = DataLoader(test_data, + num_workers=opt.data_threads, + batch_size=opt.batch_size, + shuffle=True, + drop_last=True, + pin_memory=True) + + +def get_training_batch(): + while True: + for sequence in train_loader: + batch = utils.normalize_data(opt, dtype, sequence) + yield batch + + +training_batch_generator = get_training_batch() +training_batch_generator_2 = get_training_batch() + + +def get_testing_batch(): + while True: + for sequence in test_loader: + batch = utils.normalize_data(opt, dtype, sequence) + yield batch + + +testing_batch_generator = get_testing_batch() + + +# --------- plotting funtions ------------------------------------ +def plot(x, epoch): + nsample = 1 + gen_seq = [[] for _ in range(nsample)] + gt_seq = [x[i] for i in range(len(x))] + + h_seq = [encoder(x[i]) for i in range(opt.n_past)] + for s in range(nsample): + frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + gen_seq[s].append(x[0]) + x_in = x[0] + for i in range(1, opt.n_eval): + if opt.last_frame_skip or i < opt.n_past: + h, skip = h_seq[i - 1] + h = h.detach() + elif i < opt.n_past: + h, _ = h_seq[i - 1] + h = h.detach() + if i < opt.n_past: + frame_predictor(h) + x_in = x[i] + gen_seq[s].append(x_in) + else: + h = frame_predictor(h).detach() + x_in = decoder([h, skip]).detach() + gen_seq[s].append(x_in) + + to_plot = [] + gifs = [[] for t in range(opt.n_eval)] + nrow = min(opt.batch_size, 25) + for i in range(nrow): + # ground truth sequence + row = [] + for t in range(opt.n_eval): + row.append(gt_seq[t][i]) + to_plot.append(row) + + for s in range(nsample): + row = [] + for t in range(opt.n_eval): + row.append(gen_seq[s][t][i]) + to_plot.append(row) + for t in range(opt.n_eval): + row = [] + row.append(gt_seq[t][i]) + for s in range(nsample): + row.append(gen_seq[s][t][i]) + gifs[t].append(row) + + fname = '%s/gen/sample_%d.png' % (opt.log_dir, epoch) + utils.save_tensors_image(fname, to_plot) + + fname = '%s/gen/sample_%d.gif' % (opt.log_dir, epoch) + utils.save_gif(fname, gifs) + + +def plot_rec(x, epoch): + frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + gen_seq = [] + gen_seq.append(x[0]) + x_in = x[0] + h_seq = [encoder(x[i]) for i in range(opt.n_past + opt.n_future)] + for i in range(1, opt.n_past + opt.n_future): + h_target = h_seq[i][0].detach() + if opt.last_frame_skip or i < opt.n_past: + h, skip = h_seq[i - 1] + else: + h, _ = h_seq[i - 1] + h = h.detach() + if i < opt.n_past: + frame_predictor(h) + gen_seq.append(x[i]) + else: + h_pred = frame_predictor(h).detach() + x_pred = decoder([h_pred, skip]).detach() + gen_seq.append(x_pred) + + to_plot = [] + nrow = min(opt.batch_size, 25) + for i in range(nrow): + row = [] + for t in range(opt.n_past + opt.n_future): + row.append(gen_seq[t][i]) + to_plot.append(row) + fname = '%s/gen/rec_%d.png' % (opt.log_dir, epoch) + utils.save_tensors_image(fname, to_plot) + + +def do_implasubility_test(h_residual_mean, h_residual_vars): + epoch_size = 1000 // opt.batch_size # we have a total of 1000 videos + cov_inv = [np.diag(1.0 / h_var)[np.newaxis, ...] for h_var in h_residual_vars] # we assume the covariance matrix is diagonal and do the inverse for each time + + frame_predictor.eval() + posterior.eval() + encoder.eval() + decoder.eval() + progress = progressbar.ProgressBar(max_value=epoch_size).start() + + + for i in range(epoch_size): + h_residual_var = torch.tensor(np.zeros(opt.n_future, dtype=np.float32), requires_grad=False, + device=torch.device('cpu')) + progress.update(i + 1) + try: + x = next(training_batch_generator) + frames = [frame.cpu() for frame in x] + except TypeError: + print('got None at i = {}, terminating'.format(i)) + break + h_posterior = [encoder(x[j]) for j in range(opt.n_past + opt.n_future)] + # print(h_posterior[0][0].size()) + last_pred = None + frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + start = 1 + for j in range(start, opt.n_past + opt.n_future): + h_target = h_posterior[j][0].detach() + if opt.last_frame_skip or j < opt.n_past + start - 1: + h, skip = h_posterior[j - 1] + else: + h = h_posterior[j - 1][0].detach() + # we predict h_t from h_{t-1} + h_prior_pred = frame_predictor(h.detach()) + h_posterior_pred = posterior(h_target.detach()) + + if j >= opt.n_past + start - 1: + # h_res = h_prior_pred - h_posterior[j][0].detach() # predicted h minus observed h + # squared_diff = torch.square(h_res - h_residual_mean[j - opt.n_past]) + # squared_diff = torch.mean(squared_diff, dim=1) # average squared residuals at time j over the dimensions of h + # squared_diff = torch.mean(squared_diff, dim=0) # average squared residuals at time j over the batch + # h_residual_var[j - opt.n_past] += squared_diff + + # residual = (h_prior_pred - h_posterior_pred).cpu().detach() + residual = (last_post_pred - h_posterior_pred).cpu().detach() + # err = (residual - h_residual_mean[j - opt.n_past]) + err = residual + err = np.square(err) # [batch, dim_feature] + # err = err / h_residual_vars[j - opt.n_past][np.newaxis, ...] # [batch, dim_feature] / [1, dim_feature] + err = torch.mean(err, axis=1) # -> [batch,] + # if len(err.shape) == 2: # [batch, dim_feature] + # err = err[..., np.newaxis] # make err into a vector [batch, dim_feature, 1] + # # print(cov_inv[j - opt.n_past].shape) + # # print(cov_inv[j - opt.n_past]) + # # print(err.shape) + # print(err) + # print(np.diag(cov_inv[j - opt.n_past][0])) + # quit() + # mahanlanobis_dist = np.matmul(cov_inv[j - opt.n_past], err).transpose(2, 1) + # # print(mahanlanobis_dist.shape) + # mahanlanobis_dist = np.matmul(mahanlanobis_dist, err) # [batch, 1, 1] + # # print(mahanlanobis_dist.shape) + # # print(np.sqrt(mahanlanobis_dist)) + # print(mahanlanobis_dist) + # quit() + # mahanlanobis_dist = torch.mean(mahanlanobis_dist) # scalar + + + # err = torch.mean(err, dim=1) # average errs at time j over the dimensions of h + # err = torch.mean(err, dim=0) # average errs at time j over the batch + # h_residual_var[j - opt.n_past] += err.detach() + h_residual_var[j - opt.n_past] += torch.mean(err, axis=0) + last_post_pred = h_posterior_pred.detach() + + h_residual_sd = torch.sqrt(h_residual_var).cpu() + + h_residual_sd_filtered = - h_residual_sd[:-2] + 2 * h_residual_sd[1:-1] - h_residual_sd[2:] + + + # print(h_residual_var) + for j in range(len(frames)): + frame = np.uint8(np.minimum(frames[j][0][0], 1) * 255) + cv2.imshow('frame', frame) + cv2.waitKey(15) + fig = plt.figure() + plt.ylim(0, 2.0) + plt.xlabel('Time') + plt.title("sqrt of average squared dimensional error") + plt.bar(np.arange(len(h_residual_sd)), h_residual_sd) + fig.canvas.draw() + img = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, + sep='') + img = img.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + + plt.xlabel('Time') + plt.title("filtered sqrt of average squared dimensional error") + plt.bar(np.arange(len(h_residual_sd_filtered)), h_residual_sd_filtered) + fig.canvas.draw() + img2 = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, + sep='') + img2 = img2.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + img2 = cv2.cvtColor(img2, cv2.COLOR_RGB2BGR) + + cv2.imshow("plot", img) + cv2.imshow("plot2", img2) + + k = cv2.waitKey(0) + if k == ord('q'): + quit() + # plt.savefig('implausibility_test.png') + # h_residual_var /= epoch_size # get the mean error vector per time + # h_residual_sd = torch.sqrt(h_residual_var) + print('Last i = {}'.format(i)) + + + # H_err_cov = torch.tensor(np.zeros((opt.n_future, 128, 128), dtype=np.float32), requires_grad=False, + # device=torch.device('cuda:0')) + + # plot some stuff + h_residual_mean_norm = torch.norm(h_residual_mean, dim=1).cpu() + plt.subplot(2, 1, 1) + plt.xlabel('Time') + plt.tight_layout() + plt.title("Norm of the residual mean") + plt.bar(np.arange(len(h_residual_mean_norm)), h_residual_mean_norm) + + plt.subplot(2, 1, 2) + plt.xlabel('Time') + plt.tight_layout() + plt.title("Average dimensional sqrt(variance) of the residual") + plt.bar(np.arange(len(h_residual_sd.cpu())), h_residual_sd.cpu()) + plt.savefig('h_residual_mean.png') + + +f = open('mcs_stats_post.json', 'r') +mcs_stats_dict = json.load(f) +do_implasubility_test(np.array(mcs_stats_dict['mean']), np.array(mcs_stats_dict['vars'])) diff --git a/_do_mcs_stats_posterior.py b/_do_mcs_stats_posterior.py new file mode 100644 index 0000000..5ccec1c --- /dev/null +++ b/_do_mcs_stats_posterior.py @@ -0,0 +1,377 @@ +import glob + +import torch +import torch.optim as optim +import torch.nn as nn +import argparse +import os +import random +from torch.autograd import Variable +from torch.utils.data import DataLoader +import utils +import itertools +import progressbar +import numpy as np +import matplotlib.pyplot as plt +import json + +parser = argparse.ArgumentParser() +parser.add_argument('--lr', default=0.004, type=float, help='learning rate') +parser.add_argument('--beta1', default=0.9, type=float, help='momentum term for adam') +parser.add_argument('--batch_size', default=1, type=int, help='batch size') +parser.add_argument('--log_dir', default='logs/nonstochastic', help='base directory to save logs') +parser.add_argument('--model_dir', default='', help='base directory to save logs') +parser.add_argument('--name', default='', help='identifier for directory') +parser.add_argument('--data_root', default='data', help='root directory for data') +parser.add_argument('--optimizer', default='adam', help='optimizer to train with') +parser.add_argument('--niter', type=int, default=300, help='number of epochs to train for') +parser.add_argument('--seed', default=1, type=int, help='manual seed') +parser.add_argument('--epoch_size', type=int, default=1000, help='epoch size') +parser.add_argument('--image_width', type=int, default=64, help='the height / width of the input image to network') +parser.add_argument('--channels', default=1, type=int) +parser.add_argument('--dataset', default='mcs', help='dataset to train with') +parser.add_argument('--mcs_task', default='SpatioTemporalContinuityTraining4', help='mcs task') +parser.add_argument('--n_past', type=int, default=5, help='number of frames to condition on') +parser.add_argument('--n_future', type=int, default=195, help='number of frames to predict') +parser.add_argument('--n_eval', type=int, default=30, help='number of frames to predict at eval time') +parser.add_argument('--rnn_size', type=int, default=256, help='dimensionality of hidden layer') +parser.add_argument('--posterior_rnn_layers', type=int, default=1, help='number of layers') +parser.add_argument('--predictor_rnn_layers', type=int, default=2, help='number of layers') +parser.add_argument('--z_dim', type=int, default=10, help='dimensionality of z_t') +parser.add_argument('--g_dim', type=int, default=128, + help='dimensionality of encoder output vector and decoder input vector') +parser.add_argument('--beta', type=float, default=0.0001, help='weighting on KL to prior') +parser.add_argument('--model', default='dcgan', help='model type (dcgan | vgg)') +parser.add_argument('--data_threads', type=int, default=8, help='number of data loading threads') +parser.add_argument('--num_digits', type=int, default=2, help='number of digits for moving mnist') +parser.add_argument('--last_frame_skip', action='store_true', + help='if true, skip connections go between frame t and frame t+t rather than last ground truth frame') +opt = parser.parse_args() +BATCH_SIZE = opt.batch_size +saved_model = None +if opt.model_dir != '': + models = glob.glob(f'{opt.model_dir}/model_*.pth') + latest_model = sorted(models, key=lambda s: int(s[s.rfind('_e') + 2: s.rfind('.pth')]), reverse=True)[0] + print('Loading model ', latest_model) + saved_model = torch.load(latest_model) + model_dir = opt.model_dir + niter = opt.niter + opt = saved_model['opt'] + opt.niter = niter # update number of epochs to train for + opt.model_dir = model_dir + opt.log_dir = '%s/continued' % opt.log_dir +else: + raise ValueError("Please specify the model to load with the --model_dir argument") + +print("Random Seed: ", opt.seed) +random.seed(opt.seed) +torch.manual_seed(opt.seed) +torch.cuda.manual_seed_all(opt.seed) +dtype = torch.cuda.FloatTensor + + +import models.lstm as lstm_models + +if opt.model_dir != '': + frame_predictor = saved_model['frame_predictor'] + frame_predictor.batch_size = BATCH_SIZE + posterior = saved_model['posterior'] + posterior.batch_size = BATCH_SIZE +else: + frame_predictor = lstm_models.lstm(opt.g_dim, opt.g_dim, opt.rnn_size, opt.predictor_rnn_layers, opt.batch_size) + frame_predictor.apply(utils.init_weights) + +if opt.model == 'dcgan': + if opt.image_width == 64: + import models.dcgan_64 as model + elif opt.image_width == 128: + import models.dcgan_128 as model +elif opt.model == 'vgg': + if opt.image_width == 64: + import models.vgg_64 as model + elif opt.image_width == 128: + import models.vgg_128 as model +else: + raise ValueError('Unknown model: %s' % opt.model) + +if opt.model_dir != '': + decoder = saved_model['decoder'] + encoder = saved_model['encoder'] +else: + encoder = model.encoder(opt.g_dim, opt.channels) + decoder = model.decoder(opt.g_dim, opt.channels) + encoder.apply(utils.init_weights) + decoder.apply(utils.init_weights) + +# --------- loss functions ------------------------------------ +mse_criterion = nn.MSELoss() + + +def kl_criterion(mu, logvar): + # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) + KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) + KLD /= opt.batch_size + return KLD + + +# --------- transfer to gpu ------------------------------------ +frame_predictor.cuda() +encoder.cuda() +decoder.cuda() +mse_criterion.cuda() + +opt.batch_size = BATCH_SIZE +opt.epoch_size = 1000 +opt.n_future = 195 +print(opt) + +# --------- load a dataset ------------------------------------ +train_data, test_data = utils.load_dataset(opt, sequential=True) + +train_loader = DataLoader(train_data, + num_workers=opt.data_threads, + batch_size=opt.batch_size, + drop_last=True, + pin_memory=True) +test_loader = DataLoader(test_data, + num_workers=opt.data_threads, + batch_size=opt.batch_size, + shuffle=True, + drop_last=True, + pin_memory=True) + + +def get_training_batch(): + while True: + for sequence in train_loader: + batch = utils.normalize_data(opt, dtype, sequence) + yield batch + + +training_batch_generator = get_training_batch() +training_batch_generator_2 = get_training_batch() + + +def get_testing_batch(): + while True: + for sequence in test_loader: + batch = utils.normalize_data(opt, dtype, sequence) + yield batch + + +testing_batch_generator = get_testing_batch() + + +# --------- plotting funtions ------------------------------------ +def plot(x, epoch): + nsample = 1 + gen_seq = [[] for _ in range(nsample)] + gt_seq = [x[i] for i in range(len(x))] + + h_seq = [encoder(x[i]) for i in range(opt.n_past)] + for s in range(nsample): + frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + gen_seq[s].append(x[0]) + x_in = x[0] + for i in range(1, opt.n_eval): + if opt.last_frame_skip or i < opt.n_past: + h, skip = h_seq[i - 1] + h = h.detach() + elif i < opt.n_past: + h, _ = h_seq[i - 1] + h = h.detach() + if i < opt.n_past: + frame_predictor(h) + x_in = x[i] + gen_seq[s].append(x_in) + else: + h = frame_predictor(h).detach() + x_in = decoder([h, skip]).detach() + gen_seq[s].append(x_in) + + to_plot = [] + gifs = [[] for t in range(opt.n_eval)] + nrow = min(opt.batch_size, 25) + for i in range(nrow): + # ground truth sequence + row = [] + for t in range(opt.n_eval): + row.append(gt_seq[t][i]) + to_plot.append(row) + + for s in range(nsample): + row = [] + for t in range(opt.n_eval): + row.append(gen_seq[s][t][i]) + to_plot.append(row) + for t in range(opt.n_eval): + row = [] + row.append(gt_seq[t][i]) + for s in range(nsample): + row.append(gen_seq[s][t][i]) + gifs[t].append(row) + + fname = '%s/gen/sample_%d.png' % (opt.log_dir, epoch) + utils.save_tensors_image(fname, to_plot) + + fname = '%s/gen/sample_%d.gif' % (opt.log_dir, epoch) + utils.save_gif(fname, gifs) + + +def plot_rec(x, epoch): + frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + gen_seq = [] + gen_seq.append(x[0]) + x_in = x[0] + h_seq = [encoder(x[i]) for i in range(opt.n_past + opt.n_future)] + for i in range(1, opt.n_past + opt.n_future): + h_target = h_seq[i][0].detach() + if opt.last_frame_skip or i < opt.n_past: + h, skip = h_seq[i - 1] + else: + h, _ = h_seq[i - 1] + h = h.detach() + if i < opt.n_past: + frame_predictor(h) + gen_seq.append(x[i]) + else: + h_pred = frame_predictor(h).detach() + x_pred = decoder([h_pred, skip]).detach() + gen_seq.append(x_pred) + + to_plot = [] + nrow = min(opt.batch_size, 25) + for i in range(nrow): + row = [] + for t in range(opt.n_past + opt.n_future): + row.append(gen_seq[t][i]) + to_plot.append(row) + fname = '%s/gen/rec_%d.png' % (opt.log_dir, epoch) + utils.save_tensors_image(fname, to_plot) + + +def do_stats(): + epoch_size = 1000 // opt.batch_size // 1 # we have a total of 1000 videos + + frame_predictor.eval() + encoder.eval() + decoder.eval() + progress = progressbar.ProgressBar(max_value=epoch_size).start() + h_residual_mean = torch.tensor(np.zeros((opt.n_future, 128), dtype=np.float32), requires_grad=False, + device=torch.device('cuda:0')) + i = 0 + for i in range(epoch_size): + progress.update(i + 1) + try: + x = next(training_batch_generator) + except TypeError: + print('got None at i = {}, terminating'.format(i)) + break + h_posterior = [encoder(x[j]) for j in range(opt.n_past + opt.n_future)] + # print(h_posterior[0][0].size()) + last_pred = None + frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + for j in range(1, opt.n_past + opt.n_future): + h_target = h_posterior[j][0].detach() + if opt.last_frame_skip or j < opt.n_past: + h, skip = h_posterior[j - 1] + else: + h = h_posterior[j - 1][0].detach() + # we predict h_t from h_{t-1} + h_prior_pred = frame_predictor(h).detach() + h_posterior_pred = posterior(h_target).detach() + + if j >= opt.n_past: + # h_res = h_prior_pred - h_posterior[j][0].detach() # predicted h minus observed h + # h_res = h_prior_pred + # h_res = torch.mean(h_res, dim=0) # average errors at the same time j over the batch + # h_residual_mean[j - opt.n_past] += h_res + residual = h_prior_pred - h_posterior_pred + residual = torch.mean(residual, dim=0) + h_residual_mean[j - opt.n_past] += residual + h_residual_mean /= epoch_size # get the mean error vector per time + + # restart training dataset + global train_loader + train_loader = DataLoader(train_data, + num_workers=opt.data_threads, + batch_size=opt.batch_size, + drop_last=True, + pin_memory=True) + h_residual_var = torch.tensor(np.zeros(opt.n_future, dtype=np.float32), requires_grad=False, + device=torch.device('cuda:0')) + h_residual_vars = torch.tensor(np.zeros((opt.n_future, 128), dtype=np.float32), requires_grad=False, + device=torch.device('cuda:0')) + + for i in range(epoch_size): + progress.update(i + 1) + try: + x = next(training_batch_generator_2) + except TypeError: + print('got None at i = {}, terminating'.format(i)) + break + h_posterior = [encoder(x[j]) for j in range(opt.n_past + opt.n_future)] + # print(h_posterior[0][0].size()) + frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + for j in range(1, opt.n_past + opt.n_future): + h_target = h_posterior[j][0].detach() + if opt.last_frame_skip or j < opt.n_past: + h, skip = h_posterior[j - 1] + else: + h = h_posterior[j - 1][0].detach() + # we predict h_t from h_{t-1} + h_prior_pred = frame_predictor(h).detach() + h_posterior_pred = posterior(h_target).detach() + + if j >= opt.n_past: + # h_res = h_prior_pred - h_posterior[j][0].detach() # predicted h minus observed h + # squared_diff = torch.square(h_res - h_residual_mean[j - opt.n_past]) + # squared_diff = torch.mean(squared_diff, dim=1) # average squared residuals at time j over the dimensions of h + # squared_diff = torch.mean(squared_diff, dim=0) # average squared residuals at time j over the batch + # h_residual_var[j - opt.n_past] += squared_diff + + residual = h_prior_pred - h_posterior_pred + squared_err = torch.square(residual - h_residual_mean[j - opt.n_past]) + # squared_err = torch.square(residual) + squared_err = torch.mean(squared_err, dim=0) # average errs at time j over the batch + h_residual_vars[j - opt.n_past] += squared_err.detach() + squared_err = torch.mean(squared_err, dim=0) # average errs at time j over the dimensions of h + h_residual_var[j - opt.n_past] += squared_err.detach() + h_residual_var /= epoch_size + h_residual_vars /= epoch_size + h_residual_sd = torch.sqrt(h_residual_var) + print('Last i = {}'.format(i)) + print('sd of h residual: ', h_residual_sd) + print('var of h residual: ', h_residual_var) + print('norm(mean of h residual)', torch.norm(h_residual_mean, dim=1)) + print('vars of h residual: ', h_residual_vars) + + # H_err_cov = torch.tensor(np.zeros((opt.n_future, 128, 128), dtype=np.float32), requires_grad=False, + # device=torch.device('cuda:0')) + + # plot some stuff + h_residual_mean_norm = torch.norm(h_residual_mean, dim=1).cpu() + plt.subplot(2, 1, 1) + plt.xlabel('Time') + plt.tight_layout() + plt.title("Norm of the residual mean") + plt.bar(np.arange(len(h_residual_mean_norm)), h_residual_mean_norm) + + plt.subplot(2, 1, 2) + plt.xlabel('Time') + plt.tight_layout() + plt.title("Average dimensional sqrt(variance) of the residual") + plt.bar(np.arange(len(h_residual_sd.cpu())), h_residual_sd.cpu()) + plt.savefig('post_h_residual.png') + + stats_dict = {'mean': h_residual_mean.cpu().tolist(), 'var': h_residual_var.cpu().tolist(), + 'vars': h_residual_vars.cpu().tolist()} + f = open('mcs_stats_post.json', 'w') + json.dump(stats_dict, f) + + +do_stats() diff --git a/_train_svg_nonstochastic_posterior.py b/_train_svg_nonstochastic_posterior.py new file mode 100644 index 0000000..1dce6a7 --- /dev/null +++ b/_train_svg_nonstochastic_posterior.py @@ -0,0 +1,373 @@ +import glob + +import torch +import torch.optim as optim +import torch.nn as nn +import argparse +import os +import random +from torch.autograd import Variable +from torch.utils.data import DataLoader +import utils +import itertools +import progressbar +import numpy as np + +parser = argparse.ArgumentParser() +parser.add_argument('--lr', default=0.008, type=float, help='learning rate') +parser.add_argument('--beta1', default=0.9, type=float, help='momentum term for adam') +parser.add_argument('--batch_size', default=128, type=int, help='batch size') +parser.add_argument('--log_dir', default='logs/nonstochastic_posterior', help='base directory to save logs') +parser.add_argument('--model_dir', default='', help='base directory to save logs') +parser.add_argument('--name', default='', help='identifier for directory') +parser.add_argument('--data_root', default='data', help='root directory for data') +parser.add_argument('--optimizer', default='adam', help='optimizer to train with') +parser.add_argument('--niter', type=int, default=300, help='number of epochs to train for') +parser.add_argument('--seed', default=1, type=int, help='manual seed') +parser.add_argument('--epoch_size', type=int, default=400, help='epoch size') +parser.add_argument('--image_width', type=int, default=64, help='the height / width of the input image to network') +parser.add_argument('--channels', default=1, type=int) +parser.add_argument('--dataset', default='mcs', help='dataset to train with') +parser.add_argument('--mcs_task', default='SpatioTemporalContinuityTraining4', help='mcs task') +parser.add_argument('--n_past', type=int, default=5, help='number of frames to condition on') +parser.add_argument('--n_future', type=int, default=15, help='number of frames to predict') +parser.add_argument('--n_eval', type=int, default=30, help='number of frames to predict at eval time') +parser.add_argument('--rnn_size', type=int, default=256, help='dimensionality of hidden layer') +parser.add_argument('--posterior_rnn_layers', type=int, default=1, help='number of layers') +parser.add_argument('--predictor_rnn_layers', type=int, default=2, help='number of layers') +parser.add_argument('--z_dim', type=int, default=10, help='dimensionality of z_t') +parser.add_argument('--g_dim', type=int, default=128, help='dimensionality of encoder output vector and decoder input vector') +parser.add_argument('--beta', type=float, default=0.0001, help='weighting on KL to prior') +parser.add_argument('--gamma', type=float, default=0.0001, help='weighting on h vs h posterior') +parser.add_argument('--model', default='dcgan', help='model type (dcgan | vgg)') +parser.add_argument('--data_threads', type=int, default=12, help='number of data loading threads') +parser.add_argument('--num_digits', type=int, default=2, help='number of digits for moving mnist') +parser.add_argument('--last_frame_skip', action='store_true', help='if true, skip connections go between frame t and frame t+1 rather than last ground truth frame') + + +opt = parser.parse_args() +saved_model = None +if opt.model_dir != '': + models = glob.glob(f'{opt.model_dir}/model_*.pth') + latest_model = sorted(models, key=lambda s: int(s[s.rfind('_e') + 2: s.rfind('.pth')]), reverse=True)[0] + print('Loading model ', latest_model) + saved_model = torch.load(latest_model) + optimizer = opt.optimizer + model_dir = opt.model_dir + niter = opt.niter + lr = opt.lr + opt = saved_model['opt'] + opt.niter = niter # update number of epochs to train for + opt.optimizer = optimizer + opt.model_dir = model_dir + opt.lr = lr + opt.log_dir = '%s/continued_lr%s' % (opt.log_dir, opt.lr) +else: + name = 'model=%s%dx%d-rnn_size=%d-predictor-posterior-rnn_layers=%d-%d-n_past=%d-n_future=%d-lr=%.4f-g_dim=%d-z_dim=%d-last_frame_skip=%d-beta=%.7f-gamma=%.7f%s' % (opt.model, opt.image_width, opt.image_width, opt.rnn_size, opt.predictor_rnn_layers, opt.posterior_rnn_layers, opt.n_past, opt.n_future, opt.lr, opt.g_dim, opt.z_dim, opt.last_frame_skip, opt.beta, opt.gamma, opt.name) + if opt.dataset == 'smmnist': + opt.log_dir = '%s/%s-%d/%s' % (opt.log_dir, opt.dataset, opt.num_digits, name) + elif opt.dataset == 'mcs': + opt.log_dir = '%s/%s/%s/%s' % (opt.log_dir, opt.dataset, opt.mcs_task, name) + else: + opt.log_dir = '%s/%s/%s' % (opt.log_dir, opt.dataset, name) + +os.makedirs('%s/gen/' % opt.log_dir, exist_ok=True) +os.makedirs('%s/plots/' % opt.log_dir, exist_ok=True) + +print("Random Seed: ", opt.seed) +random.seed(opt.seed) +torch.manual_seed(opt.seed) +torch.cuda.manual_seed_all(opt.seed) +dtype = torch.cuda.FloatTensor + + +# ---------------- load the models ---------------- + +print(opt) + +# ---------------- optimizers ---------------- +if opt.optimizer == 'adam': + opt.optimizer = optim.Adam +elif opt.optimizer == 'rmsprop': + opt.optimizer = optim.RMSprop +elif opt.optimizer == 'sgd': + opt.optimizer = optim.SGD +else: + raise ValueError('Unknown optimizer: %s' % opt.optimizer) + +import models.lstm as lstm_models +if opt.model_dir != '': + frame_predictor = saved_model['frame_predictor'] + posterior = saved_model['posterior'] +else: + frame_predictor = lstm_models.lstm(opt.g_dim, opt.g_dim, opt.rnn_size, opt.predictor_rnn_layers, opt.batch_size) + frame_predictor.apply(utils.init_weights) + posterior = lstm_models.lstm(opt.g_dim, opt.g_dim, opt.rnn_size, opt.predictor_rnn_layers, opt.batch_size) + posterior.apply(utils.init_weights) + +if opt.model == 'dcgan': + if opt.image_width == 64: + import models.dcgan_64 as model + elif opt.image_width == 128: + import models.dcgan_128 as model +elif opt.model == 'vgg': + if opt.image_width == 64: + import models.vgg_64 as model + elif opt.image_width == 128: + import models.vgg_128 as model +else: + raise ValueError('Unknown model: %s' % opt.model) + +if opt.model_dir != '': + decoder = saved_model['decoder'] + encoder = saved_model['encoder'] +else: + encoder = model.encoder(opt.g_dim, opt.channels) + decoder = model.decoder(opt.g_dim, opt.channels) + encoder.apply(utils.init_weights) + decoder.apply(utils.init_weights) + +frame_predictor_optimizer = opt.optimizer(frame_predictor.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) +posterior_optimizer = opt.optimizer(posterior.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) +encoder_optimizer = opt.optimizer(encoder.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) +decoder_optimizer = opt.optimizer(decoder.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + +# --------- loss functions ------------------------------------ +mse_criterion = nn.MSELoss() +def kl_criterion(mu, logvar): + # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) + KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) + KLD /= opt.batch_size + return KLD + + +# --------- transfer to gpu ------------------------------------ +frame_predictor.cuda() +posterior.cuda() +encoder.cuda() +decoder.cuda() +mse_criterion.cuda() + +# --------- load a dataset ------------------------------------ +train_data, test_data = utils.load_dataset(opt) + +train_loader = DataLoader(train_data, + num_workers=opt.data_threads, + batch_size=opt.batch_size, + shuffle=True, + drop_last=True, + pin_memory=True) +test_loader = DataLoader(test_data, + num_workers=opt.data_threads, + batch_size=opt.batch_size, + shuffle=True, + drop_last=True, + pin_memory=True) + +def get_training_batch(): + while True: + for sequence in train_loader: + batch = utils.normalize_data(opt, dtype, sequence) + yield batch +training_batch_generator = get_training_batch() + +def get_testing_batch(): + while True: + for sequence in test_loader: + batch = utils.normalize_data(opt, dtype, sequence) + yield batch +testing_batch_generator = get_testing_batch() + +# --------- plotting funtions ------------------------------------ +def plot(x, epoch): + + nsample = 1 + gen_seq = [[] for _ in range(nsample)] + gt_seq = [x[i] for i in range(len(x))] + + h_seq = [encoder(x[i]) for i in range(opt.n_past)] + for s in range(nsample): + frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + gen_seq[s].append(x[0]) + x_in = x[0] + for i in range(1, opt.n_eval): + with torch.no_grad(): + h = encoder(x_in) + if opt.last_frame_skip or i < opt.n_past: + h, skip = h + else: + h, _ = h_seq[i-1] + h = h.detach() + if i < opt.n_past: + frame_predictor(h) + x_in = x[i] + gen_seq[s].append(x_in) + else: + h = frame_predictor(h).detach() + x_in = decoder([h, skip]) + gen_seq[s].append(x_in) + + to_plot = [] + gifs = [ [] for t in range(opt.n_eval) ] + nrow = min(opt.batch_size, 25) + for i in range(nrow): + # ground truth sequence + row = [] + for t in range(opt.n_eval): + row.append(gt_seq[t][i]) + to_plot.append(row) + + for s in range(nsample): + row = [] + for t in range(opt.n_eval): + row.append(gen_seq[s][t][i]) + to_plot.append(row) + for t in range(opt.n_eval): + row = [] + row.append(gt_seq[t][i]) + for s in range(nsample): + row.append(gen_seq[s][t][i]) + gifs[t].append(row) + + fname = '%s/gen/sample_%d.png' % (opt.log_dir, epoch) + utils.save_tensors_image(fname, to_plot) + + fname = '%s/gen/sample_%d.gif' % (opt.log_dir, epoch) + utils.save_gif(fname, gifs) + + +def plot_rec(x, epoch): + frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + + gen_seq = [] + gen_seq.append(x[0]) + gen_seq_post = [] + gen_seq_post.append(x[0]) + x_in = x[0] + h_seq = [encoder(x[i]) for i in range(opt.n_past+opt.n_future)] + for i in range(1, opt.n_past+opt.n_future): + h_target = h_seq[i][0].detach() + if opt.last_frame_skip or i < opt.n_past: + h, skip = h_seq[i-1] + else: + h, _ = h_seq[i-1] + h = h.detach() + if i < opt.n_past: + frame_predictor(h) + gen_seq.append(x[i]) + gen_seq_post.append(x[i]) + else: + h_pred = frame_predictor(h).detach() + x_pred = decoder([h_pred, skip]).detach() + h_posterior = posterior(h_target).detach() + x_posterior = decoder([h_posterior, skip]).detach() + gen_seq.append(x_pred) + gen_seq_post.append(x_posterior) + + to_plot = [] + nrow = min(opt.batch_size * 3, 25 * 3) + for i in range(nrow): + row_gt = [] + row_post = [] + row = [] + for t in range(opt.n_past+opt.n_future): + row_gt.append(x[t][i]) + row_post.append(gen_seq_post[t][i]) + row.append(gen_seq[t][i]) + to_plot.append(row_gt) + to_plot.append(row_post) + to_plot.append(row) + fname = '%s/gen/rec_%d.png' % (opt.log_dir, epoch) + utils.save_tensors_image(fname, to_plot) + + +# --------- training funtions ------------------------------------ +def train(x): + frame_predictor.zero_grad() + posterior.zero_grad() + encoder.zero_grad() + decoder.zero_grad() + + # initialize the hidden state. + frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + + h_seq = [encoder(x[i]) for i in range(opt.n_past+opt.n_future)] + mse = 0 + mse_post = 0 + mse_diff_post = 0 + for i in range(1, opt.n_past+opt.n_future): + h_target = h_seq[i][0] + if opt.last_frame_skip or i < opt.n_past: + h, skip = h_seq[i-1] + else: + h = h_seq[i-1][0] + h_pred = frame_predictor(h) + x_pred = decoder([h_pred, skip]) + h_posterior = posterior(h_target) + x_posterior = decoder([h_posterior, skip]) + mse += mse_criterion(x_pred, x[i]) + mse_post += mse_criterion(x_posterior, x[i]) + mse_diff_post += opt.gamma * torch.mean(torch.square(h_posterior.detach() - h_pred)) + + loss = mse + mse_post + mse_diff_post + loss.backward() + + frame_predictor_optimizer.step() + posterior_optimizer.step() + encoder_optimizer.step() + decoder_optimizer.step() + + N = opt.n_past+opt.n_future + return mse.data.cpu().numpy()/N, mse_post.data.cpu().numpy()/N, mse_diff_post.data.cpu().numpy()/N + +# --------- training loop ------------------------------------ +for epoch in range(opt.niter): + frame_predictor.train() + posterior.train() + encoder.train() + decoder.train() + epoch_mse = 0 + epoch_mse_posterior = 0 + epoch_posterior_diff = 0 + progress = progressbar.ProgressBar(max_value=opt.epoch_size).start() + # opt.epoch_size = 10 + for i in range(opt.epoch_size): + progress.update(i+1) + x = next(training_batch_generator) + + # train frame_predictor + mse, mse_posterior, posterior_diff = train(x) + epoch_mse += mse + epoch_mse_posterior += mse_posterior + epoch_posterior_diff += posterior_diff + + + progress.finish() + utils.clear_progressbar() + + print('[%02d] mse loss: %.5f, %.5f posterior | posterior diff loss: %.20f (%d)' % (epoch, epoch_mse/opt.epoch_size, epoch_mse_posterior/opt.epoch_size, epoch_posterior_diff/opt.epoch_size, epoch*opt.epoch_size*opt.batch_size)) + + # plot some stuff + frame_predictor.eval() + posterior.eval() + encoder.eval() + decoder.eval() + x = next(testing_batch_generator) + plot(x, epoch) + plot_rec(x, epoch) + + # save the model + torch.save({ + 'encoder': encoder, + 'decoder': decoder, + 'frame_predictor': frame_predictor, + 'posterior': posterior, + 'opt': opt}, + '%s/model_e%02d.pth' % (opt.log_dir, epoch)) + if epoch % 10 == 0: + print('log dir: %s' % opt.log_dir) + + diff --git a/data/mcs.py b/data/mcs.py index 9993996..89d12f9 100644 --- a/data/mcs.py +++ b/data/mcs.py @@ -11,10 +11,11 @@ class MCS(object): - def __init__(self, train, data_root, seq_len=20, image_size=64, task='ALL', sequential=None, implausible=False): + def __init__(self, train, data_root, seq_len=20, image_size=64, task='ALL', sequential=None, implausible=False, test_set=False): # if implausible is set to True, generates "fake" images by cutting out or repeating frames self.implausible = implausible self.data_root = '%s/mcs_videos_1000/processed/' % data_root + self.data_root = '%s/mcs_videos_test/processed/' % data_root if not os.path.exists(self.data_root): raise os.error('data/mcs.py: Data directory not found!') self.seq_len = seq_len @@ -60,8 +61,8 @@ def get_sequence(self, idx=None): # i is 0-indexed so we need to add 1 to i fname = frame_path + f'{i + 1:04d}.png' im = imageio.imread(fname) / 255. - gray = lambda rgb: np.dot(rgb[..., :3], [0.299, 0.587, 0.114]) - im = gray(im)[..., np.newaxis] + # gray = lambda rgb: np.dot(rgb[..., :3], [0.299, 0.587, 0.114]) + # im = gray(im)[..., np.newaxis] seq.append(im) return np.array(seq) @@ -74,8 +75,8 @@ def abnormalize_sequence(self, seq): # start = random.randint(100, 140) start = 110 vid_len = len(seq) - duration = 10 - implausibility_type = 0 + duration = 7 + implausibility_type = 1 if implausibility_type == 1: # object is invisible when/where it shouldn't be no_object_frame = seq[30] seq[start:start + duration] = no_object_frame diff --git a/do_mcs_implausblility_test_posterior.py b/do_mcs_implausblility_test_posterior.py index 141e92d..d477d72 100644 --- a/do_mcs_implausblility_test_posterior.py +++ b/do_mcs_implausblility_test_posterior.py @@ -30,7 +30,7 @@ parser.add_argument('--epoch_size', type=int, default=1000, help='epoch size') parser.add_argument('--image_width', type=int, default=64, help='the height / width of the input image to network') parser.add_argument('--channels', default=1, type=int) -parser.add_argument('--dataset', default='mcs', help='dataset to train with') +parser.add_argument('--dataset', default='mcs_test', help='dataset to train with') parser.add_argument('--mcs_task', default='SpatioTemporalContinuityTraining4', help='mcs task') parser.add_argument('--n_past', type=int, default=5, help='number of frames to condition on') parser.add_argument('--n_future', type=int, default=195, help='number of frames to predict') @@ -66,13 +66,18 @@ else: raise ValueError("Please specify the model to load with the --model_dir argument") +font = cv2.FONT_HERSHEY_SIMPLEX +bottomLeftCornerOfText = (30, 30) +fontScale = 0.6 +fontColor = (0, 0, 0) +thickness = 1 + print("Random Seed: ", opt.seed) random.seed(opt.seed) torch.manual_seed(opt.seed) torch.cuda.manual_seed_all(opt.seed) dtype = torch.cuda.FloatTensor - import models.lstm as lstm_models if opt.model_dir != '': @@ -80,6 +85,8 @@ frame_predictor.batch_size = BATCH_SIZE posterior = saved_model['posterior'] posterior.batch_size = BATCH_SIZE + prior = saved_model['prior'] + prior.batch_size = BATCH_SIZE else: raise ValueError('Please specify --model_dir') @@ -100,10 +107,7 @@ decoder = saved_model['decoder'] encoder = saved_model['encoder'] else: - encoder = model.encoder(opt.g_dim, opt.channels) - decoder = model.decoder(opt.g_dim, opt.channels) - encoder.apply(utils.init_weights) - decoder.apply(utils.init_weights) + raise ValueError("Please specify the model to load with the --model_dir argument") # --------- loss functions ------------------------------------ mse_criterion = nn.MSELoss() @@ -129,13 +133,19 @@ def kl_criterion(mu, logvar): print(opt) # --------- load a dataset ------------------------------------ -train_data, test_data = utils.load_dataset(opt, sequential=True, implausible=True) - +train_data_implausible, test_data_implausible = utils.load_dataset(opt, sequential=True, implausible=True) +train_data, test_data = utils.load_dataset(opt, sequential=True, implausible=False) + +train_loader_implausible = DataLoader(train_data_implausible, + num_workers=opt.data_threads, + batch_size=opt.batch_size, + drop_last=True, + pin_memory=True, ) train_loader = DataLoader(train_data, num_workers=opt.data_threads, batch_size=opt.batch_size, drop_last=True, - pin_memory=True,) + pin_memory=True, ) test_loader = DataLoader(test_data, num_workers=opt.data_threads, batch_size=opt.batch_size, @@ -151,8 +161,15 @@ def get_training_batch(): yield batch +def get_training_batch_implausible(): + while True: + for sequence in train_loader_implausible: + batch = utils.normalize_data(opt, dtype, sequence) + yield batch + + training_batch_generator = get_training_batch() -training_batch_generator_2 = get_training_batch() +training_batch_generator_implausible = get_training_batch_implausible() def get_testing_batch(): @@ -255,23 +272,36 @@ def plot_rec(x, epoch): utils.save_tensors_image(fname, to_plot) -def do_implasubility_test(h_residual_mean, h_residual_vars): - epoch_size = 1000 // opt.batch_size # we have a total of 1000 videos - cov_inv = [np.diag(1.0 / h_var)[np.newaxis, ...] for h_var in h_residual_vars] # we assume the covariance matrix is diagonal and do the inverse for each time - +def do_implasubility_test(z_residual_mean, z_residual_cov, visualize=True): + epoch_size = 199 // opt.batch_size # we have a total of 1000 videos + cov_inv = [np.linalg.pinv(cov, hermitian=True) for cov in + z_residual_cov] # we assume the covariance matrix is diagonal and do the inverse for each time + # cov_inv = [np.linalg.inv(np.diag(np.diag(cov))) for cov in z_residual_cov] + # cov_inv = [np.eye(32) for cov in z_residual_cov] frame_predictor.eval() posterior.eval() encoder.eval() decoder.eval() progress = progressbar.ProgressBar(max_value=epoch_size).start() - - + confusion_matrix = [[0, 0], [0, 0]] + # for i in range(50): + # if i % 2 == 0: + # x = next(training_batch_generator) + # else: + # x = next(training_batch_generator_implausible) for i in range(epoch_size): h_residual_var = torch.tensor(np.zeros(opt.n_future, dtype=np.float32), requires_grad=False, device=torch.device('cpu')) + scores = np.array([0 for i in range(opt.n_future)], dtype=np.float32) progress.update(i + 1) + is_implausible = False try: - x = next(training_batch_generator) + if i % 2 == 0: + x = next(training_batch_generator) + else: + x = next(training_batch_generator_implausible) + is_implausible = True + frames = [frame.cpu() for frame in x] except TypeError: print('got None at i = {}, terminating'.format(i)) @@ -289,8 +319,8 @@ def do_implasubility_test(h_residual_mean, h_residual_vars): else: h = h_posterior[j - 1][0].detach() # we predict h_t from h_{t-1} - h_prior_pred = frame_predictor(h.detach()) - h_posterior_pred = posterior(h_target.detach()) + z_t = posterior(h_target[0].detach()).detach() + z_t_hat = prior(h).detach() if j >= opt.n_past + start - 1: # h_res = h_prior_pred - h_posterior[j][0].detach() # predicted h minus observed h @@ -300,12 +330,19 @@ def do_implasubility_test(h_residual_mean, h_residual_vars): # h_residual_var[j - opt.n_past] += squared_diff # residual = (h_prior_pred - h_posterior_pred).cpu().detach() - residual = (last_post_pred - h_posterior_pred).cpu().detach() - # err = (residual - h_residual_mean[j - opt.n_past]) + residual = (z_t - z_t_hat).cpu().detach().numpy() + # err = (residual - z_residual_mean[j - opt.n_past]) err = residual err = np.square(err) # [batch, dim_feature] - # err = err / h_residual_vars[j - opt.n_past][np.newaxis, ...] # [batch, dim_feature] / [1, dim_feature] - err = torch.mean(err, axis=1) # -> [batch,] + score = np.sum(err) + # score = np.matmul(err[:, np.newaxis, :], cov_inv[j - opt.n_past][np.newaxis, ...]) # B*1*D * 1*D*D + # score = np.matmul(score, err[:, :, np.newaxis]) # * B*1*D * B*D*1 -> B*1*1 + # score = numpy.nan_to_num(score) + # score /= opt.z_dim + + scores[j - opt.n_past] = np.sqrt(score) + # scores[j - opt.n_past] = score + # note: scores[t]^2 ~ Chi^2_df=z_dim so E[scores[t]^2] = z_dim # if len(err.shape) == 2: # [batch, dim_feature] # err = err[..., np.newaxis] # make err into a vector [batch, dim_feature, 1] # # print(cov_inv[j - opt.n_past].shape) @@ -323,74 +360,83 @@ def do_implasubility_test(h_residual_mean, h_residual_vars): # quit() # mahanlanobis_dist = torch.mean(mahanlanobis_dist) # scalar - # err = torch.mean(err, dim=1) # average errs at time j over the dimensions of h # err = torch.mean(err, dim=0) # average errs at time j over the batch # h_residual_var[j - opt.n_past] += err.detach() - h_residual_var[j - opt.n_past] += torch.mean(err, axis=0) - last_post_pred = h_posterior_pred.detach() - - h_residual_sd = torch.sqrt(h_residual_var).cpu() - - h_residual_sd_filtered = - h_residual_sd[:-2] + 2 * h_residual_sd[1:-1] - h_residual_sd[2:] + # h_residual_var[j - opt.n_past] += torch.mean(err, axis=0) + z_residual_scores_filtered = -0.25 * scores[:-2] + (0.5+0.5) * scores[1:-1] - 0.25 *scores[2:] # print(h_residual_var) - for j in range(len(frames)): - frame = np.uint8(np.minimum(frames[j][0][0], 1) * 255) - cv2.imshow('frame', frame) - cv2.waitKey(15) - fig = plt.figure() - plt.ylim(0, 2.0) - plt.xlabel('Time') - plt.title("sqrt of average squared dimensional error") - plt.bar(np.arange(len(h_residual_sd)), h_residual_sd) - fig.canvas.draw() - img = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, - sep='') - img = img.reshape(fig.canvas.get_width_height()[::-1] + (3,)) - img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) - - plt.xlabel('Time') - plt.title("filtered sqrt of average squared dimensional error") - plt.bar(np.arange(len(h_residual_sd_filtered)), h_residual_sd_filtered) - fig.canvas.draw() - img2 = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, - sep='') - img2 = img2.reshape(fig.canvas.get_width_height()[::-1] + (3,)) - img2 = cv2.cvtColor(img2, cv2.COLOR_RGB2BGR) - - cv2.imshow("plot", img) - cv2.imshow("plot2", img2) - - k = cv2.waitKey(0) - if k == ord('q'): - quit() + if visualize: + for j in range(len(frames)): + frame = np.uint8(np.minimum(frames[j][0][0], 1) * 255) + cv2.imshow('frame', frame) + cv2.waitKey(15) + + percentile = np.percentile(scores[76:152], 85.0) + thresh = percentile * 1 + 0.5 + spikes_idx = np.argwhere(z_residual_scores_filtered > thresh) + spikes_idx = spikes_idx[ + (spikes_idx >= 75) & (spikes_idx <= 150)] # ignore spikes near the start and end of video + spikes = z_residual_scores_filtered[spikes_idx] + msg = '' + if len(spikes_idx) > 0: + # we add n_past because the first n_past frames are not counted. Add 1 because of the filtering + msg = 'thresh {:.1f} IMPLAUSIBLE spikes: '.format(thresh) + str(['{:.2f}@{}'.format(z_residual_scores_filtered[k], k + opt.n_past + 1) for k in spikes_idx]) + confusion_matrix[is_implausible][1] += 1 + else: + max_idx = np.argmax(z_residual_scores_filtered[75:151]) + 75 + msg = 'thresh {:.1f} PLAUSIBLE max {:.2f}@{}'.format(thresh, z_residual_scores_filtered[max_idx], max_idx + opt.n_past + 1) + confusion_matrix[is_implausible][0] += 1 + + print(msg) + + if visualize: + fig = plt.figure() + # plt.ylim(0, 2.0) + plt.xlabel('Time') + plt.title("sqrt of average squared dimensional error") + plt.bar(np.arange(len(scores)), scores) + fig.canvas.draw() + img = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, + sep='') + img = img.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + + plt.xlabel('Time') + plt.title("filtered sqrt of average squared dimensional error") + plt.bar(np.arange(len(z_residual_scores_filtered)), z_residual_scores_filtered) + fig.canvas.draw() + img2 = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, + sep='') + img2 = img2.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + img2 = cv2.cvtColor(img2, cv2.COLOR_RGB2BGR) + + cv2.imshow("plot", img) + cv2.putText(img2, msg, bottomLeftCornerOfText, font, fontScale, fontColor, thickness, cv2.LINE_AA) + cv2.imshow("plot2", img2) + + k = cv2.waitKey(0) + if k == ord('q'): + quit() # plt.savefig('implausibility_test.png') # h_residual_var /= epoch_size # get the mean error vector per time # h_residual_sd = torch.sqrt(h_residual_var) print('Last i = {}'.format(i)) - # H_err_cov = torch.tensor(np.zeros((opt.n_future, 128, 128), dtype=np.float32), requires_grad=False, # device=torch.device('cuda:0')) # plot some stuff - h_residual_mean_norm = torch.norm(h_residual_mean, dim=1).cpu() - plt.subplot(2, 1, 1) - plt.xlabel('Time') - plt.tight_layout() - plt.title("Norm of the residual mean") - plt.bar(np.arange(len(h_residual_mean_norm)), h_residual_mean_norm) - - plt.subplot(2, 1, 2) - plt.xlabel('Time') - plt.tight_layout() - plt.title("Average dimensional sqrt(variance) of the residual") - plt.bar(np.arange(len(h_residual_sd.cpu())), h_residual_sd.cpu()) - plt.savefig('h_residual_mean.png') - - -f = open('mcs_stats_post.json', 'r') -mcs_stats_dict = json.load(f) -do_implasubility_test(np.array(mcs_stats_dict['mean']), np.array(mcs_stats_dict['vars'])) + return confusion_matrix + + +# f = open('new_mcs_stats_post.json', 'r') +# mcs_stats_dict = json.load(f) +mcs_stats_dict = {} +with open('new_mcs_stats_post.npy', 'rb') as f: + mcs_stats_dict['mean'] = np.load(f) + mcs_stats_dict['cov'] = np.load(f) +conf_mat = do_implasubility_test(np.array(mcs_stats_dict['mean']), np.array(mcs_stats_dict['cov']), visualize=True) +print(conf_mat) diff --git a/do_mcs_stats_posterior.py b/do_mcs_stats_posterior.py index 5ccec1c..621696a 100644 --- a/do_mcs_stats_posterior.py +++ b/do_mcs_stats_posterior.py @@ -77,9 +77,10 @@ frame_predictor.batch_size = BATCH_SIZE posterior = saved_model['posterior'] posterior.batch_size = BATCH_SIZE + prior = saved_model['prior'] + prior.batch_size = BATCH_SIZE else: - frame_predictor = lstm_models.lstm(opt.g_dim, opt.g_dim, opt.rnn_size, opt.predictor_rnn_layers, opt.batch_size) - frame_predictor.apply(utils.init_weights) + raise ValueError("Please specify the model to load with the --model_dir argument") if opt.model == 'dcgan': if opt.image_width == 64: @@ -98,10 +99,7 @@ decoder = saved_model['decoder'] encoder = saved_model['encoder'] else: - encoder = model.encoder(opt.g_dim, opt.channels) - decoder = model.decoder(opt.g_dim, opt.channels) - encoder.apply(utils.init_weights) - decoder.apply(utils.init_weights) + raise ValueError("Please specify the model to load with the --model_dir argument") # --------- loss functions ------------------------------------ mse_criterion = nn.MSELoss() @@ -256,10 +254,12 @@ def do_stats(): epoch_size = 1000 // opt.batch_size // 1 # we have a total of 1000 videos frame_predictor.eval() + prior.eval() + posterior.eval() encoder.eval() decoder.eval() progress = progressbar.ProgressBar(max_value=epoch_size).start() - h_residual_mean = torch.tensor(np.zeros((opt.n_future, 128), dtype=np.float32), requires_grad=False, + z_residual_mean = torch.tensor(np.zeros((opt.n_future, opt.z_dim), dtype=np.float64), requires_grad=False, device=torch.device('cuda:0')) i = 0 for i in range(epoch_size): @@ -269,30 +269,33 @@ def do_stats(): except TypeError: print('got None at i = {}, terminating'.format(i)) break - h_posterior = [encoder(x[j]) for j in range(opt.n_past + opt.n_future)] # print(h_posterior[0][0].size()) - last_pred = None frame_predictor.hidden = frame_predictor.init_hidden() + prior.hidden = prior.init_hidden() posterior.hidden = posterior.init_hidden() + + last_h = encoder(x[0]) for j in range(1, opt.n_past + opt.n_future): - h_target = h_posterior[j][0].detach() + h_target = encoder(x[j]) if opt.last_frame_skip or j < opt.n_past: - h, skip = h_posterior[j - 1] + h, skip = last_h + h = h.detach() else: - h = h_posterior[j - 1][0].detach() + h = last_h[0].detach() # we predict h_t from h_{t-1} - h_prior_pred = frame_predictor(h).detach() - h_posterior_pred = posterior(h_target).detach() + z_t = posterior(h_target[0].detach()).detach() + z_t_hat = prior(h).detach() if j >= opt.n_past: # h_res = h_prior_pred - h_posterior[j][0].detach() # predicted h minus observed h # h_res = h_prior_pred # h_res = torch.mean(h_res, dim=0) # average errors at the same time j over the batch - # h_residual_mean[j - opt.n_past] += h_res - residual = h_prior_pred - h_posterior_pred + # z_residual_mean[j - opt.n_past] += h_res + residual = z_t - z_t_hat residual = torch.mean(residual, dim=0) - h_residual_mean[j - opt.n_past] += residual - h_residual_mean /= epoch_size # get the mean error vector per time + z_residual_mean[j - opt.n_past] += residual + last_h = h_target + z_residual_mean /= epoch_size # get the mean error vector per time # restart training dataset global train_loader @@ -301,11 +304,11 @@ def do_stats(): batch_size=opt.batch_size, drop_last=True, pin_memory=True) - h_residual_var = torch.tensor(np.zeros(opt.n_future, dtype=np.float32), requires_grad=False, - device=torch.device('cuda:0')) - h_residual_vars = torch.tensor(np.zeros((opt.n_future, 128), dtype=np.float32), requires_grad=False, + + z_cov = torch.tensor(np.zeros((opt.n_future, opt.z_dim, opt.z_dim), dtype=np.float64), requires_grad=False, device=torch.device('cuda:0')) + for i in range(epoch_size): progress.update(i + 1) try: @@ -313,65 +316,70 @@ def do_stats(): except TypeError: print('got None at i = {}, terminating'.format(i)) break - h_posterior = [encoder(x[j]) for j in range(opt.n_past + opt.n_future)] + # print(h_posterior[0][0].size()) frame_predictor.hidden = frame_predictor.init_hidden() + prior.hidden = prior.init_hidden() posterior.hidden = posterior.init_hidden() + last_h = encoder(x[0]) for j in range(1, opt.n_past + opt.n_future): - h_target = h_posterior[j][0].detach() + h_target = encoder(x[j]) if opt.last_frame_skip or j < opt.n_past: - h, skip = h_posterior[j - 1] + h, skip = last_h + h = h.detach() else: - h = h_posterior[j - 1][0].detach() + h = last_h[0].detach() # we predict h_t from h_{t-1} - h_prior_pred = frame_predictor(h).detach() - h_posterior_pred = posterior(h_target).detach() + z_t = posterior(h_target[0].detach()).detach() + z_t_hat = prior(h).detach() if j >= opt.n_past: # h_res = h_prior_pred - h_posterior[j][0].detach() # predicted h minus observed h - # squared_diff = torch.square(h_res - h_residual_mean[j - opt.n_past]) + # squared_diff = torch.square(h_res - z_residual_mean[j - opt.n_past]) # squared_diff = torch.mean(squared_diff, dim=1) # average squared residuals at time j over the dimensions of h # squared_diff = torch.mean(squared_diff, dim=0) # average squared residuals at time j over the batch # h_residual_var[j - opt.n_past] += squared_diff - residual = h_prior_pred - h_posterior_pred - squared_err = torch.square(residual - h_residual_mean[j - opt.n_past]) - # squared_err = torch.square(residual) - squared_err = torch.mean(squared_err, dim=0) # average errs at time j over the batch - h_residual_vars[j - opt.n_past] += squared_err.detach() - squared_err = torch.mean(squared_err, dim=0) # average errs at time j over the dimensions of h - h_residual_var[j - opt.n_past] += squared_err.detach() - h_residual_var /= epoch_size - h_residual_vars /= epoch_size - h_residual_sd = torch.sqrt(h_residual_var) + residual = z_t - z_t_hat # B x D + # B x D x 1 * B x 1 x D -> B x D x D + sample_cov = torch.matmul(residual[:, :, np.newaxis], residual[:, np.newaxis, :]) + sample_cov = torch.mean(sample_cov, axis=0) # mean over batch dimension + z_cov[j - opt.n_past] += sample_cov + last_h = h_target + z_cov /= epoch_size + z_cov = z_cov.cpu().numpy() + z_sd = [np.sqrt(np.diag(cov)) for cov in z_cov] + z_sd = np.array(z_sd) print('Last i = {}'.format(i)) - print('sd of h residual: ', h_residual_sd) - print('var of h residual: ', h_residual_var) - print('norm(mean of h residual)', torch.norm(h_residual_mean, dim=1)) - print('vars of h residual: ', h_residual_vars) + print('sd of z residual: ', z_sd) + print('norm(mean of z residual)', torch.norm(z_residual_mean, dim=1)) # H_err_cov = torch.tensor(np.zeros((opt.n_future, 128, 128), dtype=np.float32), requires_grad=False, # device=torch.device('cuda:0')) # plot some stuff - h_residual_mean_norm = torch.norm(h_residual_mean, dim=1).cpu() - plt.subplot(2, 1, 1) + z_residual_mean_norm = torch.norm(z_residual_mean, dim=1).cpu() + plt.subplot(3, 1, 1) plt.xlabel('Time') plt.tight_layout() plt.title("Norm of the residual mean") - plt.bar(np.arange(len(h_residual_mean_norm)), h_residual_mean_norm) + plt.bar(np.arange(len(z_residual_mean_norm)), z_residual_mean_norm) - plt.subplot(2, 1, 2) + plt.subplot(3, 1, 2) plt.xlabel('Time') plt.tight_layout() plt.title("Average dimensional sqrt(variance) of the residual") - plt.bar(np.arange(len(h_residual_sd.cpu())), h_residual_sd.cpu()) - plt.savefig('post_h_residual.png') - - stats_dict = {'mean': h_residual_mean.cpu().tolist(), 'var': h_residual_var.cpu().tolist(), - 'vars': h_residual_vars.cpu().tolist()} - f = open('mcs_stats_post.json', 'w') - json.dump(stats_dict, f) + plt.bar(np.arange(len(z_sd)), np.mean(z_sd, axis=1)) + plt.savefig('z_residual.png') + + stats_dict = {'mean': z_residual_mean.cpu().numpy(), 'cov': z_cov} + print(stats_dict['mean'].dtype) + print(stats_dict['cov'].dtype) + # f = open('new_mcs_stats_post.json', 'w') + # json.dump(stats_dict, f) + with open('new_mcs_stats_post.npy', 'wb') as f: + np.save(f, stats_dict['mean']) + np.save(f, stats_dict['cov']) do_stats() diff --git a/train_svg_nonstochastic_posterior.py b/train_svg_nonstochastic_posterior.py index 1dce6a7..20b994f 100644 --- a/train_svg_nonstochastic_posterior.py +++ b/train_svg_nonstochastic_posterior.py @@ -14,9 +14,9 @@ import numpy as np parser = argparse.ArgumentParser() -parser.add_argument('--lr', default=0.008, type=float, help='learning rate') +parser.add_argument('--lr', default=0.004, type=float, help='learning rate') parser.add_argument('--beta1', default=0.9, type=float, help='momentum term for adam') -parser.add_argument('--batch_size', default=128, type=int, help='batch size') +parser.add_argument('--batch_size', default=22, type=int, help='batch size') parser.add_argument('--log_dir', default='logs/nonstochastic_posterior', help='base directory to save logs') parser.add_argument('--model_dir', default='', help='base directory to save logs') parser.add_argument('--name', default='', help='identifier for directory') @@ -24,22 +24,23 @@ parser.add_argument('--optimizer', default='adam', help='optimizer to train with') parser.add_argument('--niter', type=int, default=300, help='number of epochs to train for') parser.add_argument('--seed', default=1, type=int, help='manual seed') -parser.add_argument('--epoch_size', type=int, default=400, help='epoch size') +parser.add_argument('--epoch_size', type=int, default=500, help='epoch size') parser.add_argument('--image_width', type=int, default=64, help='the height / width of the input image to network') -parser.add_argument('--channels', default=1, type=int) +parser.add_argument('--channels', default=3, type=int, help='number of channels for input images. ') parser.add_argument('--dataset', default='mcs', help='dataset to train with') parser.add_argument('--mcs_task', default='SpatioTemporalContinuityTraining4', help='mcs task') parser.add_argument('--n_past', type=int, default=5, help='number of frames to condition on') parser.add_argument('--n_future', type=int, default=15, help='number of frames to predict') parser.add_argument('--n_eval', type=int, default=30, help='number of frames to predict at eval time') parser.add_argument('--rnn_size', type=int, default=256, help='dimensionality of hidden layer') +parser.add_argument('--prior_rnn_layers', type=int, default=1, help='number of layers') parser.add_argument('--posterior_rnn_layers', type=int, default=1, help='number of layers') parser.add_argument('--predictor_rnn_layers', type=int, default=2, help='number of layers') -parser.add_argument('--z_dim', type=int, default=10, help='dimensionality of z_t') +parser.add_argument('--z_dim', type=int, default=32, help='dimensionality of z_t') parser.add_argument('--g_dim', type=int, default=128, help='dimensionality of encoder output vector and decoder input vector') parser.add_argument('--beta', type=float, default=0.0001, help='weighting on KL to prior') parser.add_argument('--gamma', type=float, default=0.0001, help='weighting on h vs h posterior') -parser.add_argument('--model', default='dcgan', help='model type (dcgan | vgg)') +parser.add_argument('--model', default='vgg', help='model type (dcgan | vgg)') parser.add_argument('--data_threads', type=int, default=12, help='number of data loading threads') parser.add_argument('--num_digits', type=int, default=2, help='number of digits for moving mnist') parser.add_argument('--last_frame_skip', action='store_true', help='if true, skip connections go between frame t and frame t+1 rather than last ground truth frame') @@ -63,7 +64,7 @@ opt.lr = lr opt.log_dir = '%s/continued_lr%s' % (opt.log_dir, opt.lr) else: - name = 'model=%s%dx%d-rnn_size=%d-predictor-posterior-rnn_layers=%d-%d-n_past=%d-n_future=%d-lr=%.4f-g_dim=%d-z_dim=%d-last_frame_skip=%d-beta=%.7f-gamma=%.7f%s' % (opt.model, opt.image_width, opt.image_width, opt.rnn_size, opt.predictor_rnn_layers, opt.posterior_rnn_layers, opt.n_past, opt.n_future, opt.lr, opt.g_dim, opt.z_dim, opt.last_frame_skip, opt.beta, opt.gamma, opt.name) + name = 'NEWmodel=%s%dx%d-rnn_size=%d-predictor-posterior-rnn_layers=%d-%d-n_past=%d-n_future=%d-lr=%.4f-g_dim=%d-z_dim=%d-last_frame_skip=%d-beta=%.7f-gamma=%.7f%s' % (opt.model, opt.image_width, opt.image_width, opt.rnn_size, opt.predictor_rnn_layers, opt.posterior_rnn_layers, opt.n_past, opt.n_future, opt.lr, opt.g_dim, opt.z_dim, opt.last_frame_skip, opt.beta, opt.gamma, opt.name) if opt.dataset == 'smmnist': opt.log_dir = '%s/%s-%d/%s' % (opt.log_dir, opt.dataset, opt.num_digits, name) elif opt.dataset == 'mcs': @@ -98,12 +99,15 @@ import models.lstm as lstm_models if opt.model_dir != '': frame_predictor = saved_model['frame_predictor'] + prior = saved_model['prior'] posterior = saved_model['posterior'] else: - frame_predictor = lstm_models.lstm(opt.g_dim, opt.g_dim, opt.rnn_size, opt.predictor_rnn_layers, opt.batch_size) + frame_predictor = lstm_models.lstm(opt.g_dim + opt.z_dim, opt.g_dim, opt.rnn_size, opt.predictor_rnn_layers, opt.batch_size) + posterior = lstm_models.lstm(opt.g_dim, opt.z_dim, opt.rnn_size, opt.posterior_rnn_layers, opt.batch_size) + prior = lstm_models.lstm(opt.g_dim, opt.z_dim, opt.rnn_size, opt.prior_rnn_layers, opt.batch_size) frame_predictor.apply(utils.init_weights) - posterior = lstm_models.lstm(opt.g_dim, opt.g_dim, opt.rnn_size, opt.predictor_rnn_layers, opt.batch_size) posterior.apply(utils.init_weights) + prior.apply(utils.init_weights) if opt.model == 'dcgan': if opt.image_width == 64: @@ -123,12 +127,13 @@ encoder = saved_model['encoder'] else: encoder = model.encoder(opt.g_dim, opt.channels) - decoder = model.decoder(opt.g_dim, opt.channels) + decoder = model.decoder(opt.g_dim, 1) encoder.apply(utils.init_weights) decoder.apply(utils.init_weights) frame_predictor_optimizer = opt.optimizer(frame_predictor.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) posterior_optimizer = opt.optimizer(posterior.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) +prior_optimizer = opt.optimizer(prior.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) encoder_optimizer = opt.optimizer(encoder.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) decoder_optimizer = opt.optimizer(decoder.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) @@ -144,6 +149,7 @@ def kl_criterion(mu, logvar): # --------- transfer to gpu ------------------------------------ frame_predictor.cuda() posterior.cuda() +prior.cuda() encoder.cuda() decoder.cuda() mse_criterion.cuda() @@ -185,10 +191,11 @@ def plot(x, epoch): gen_seq = [[] for _ in range(nsample)] gt_seq = [x[i] for i in range(len(x))] - h_seq = [encoder(x[i]) for i in range(opt.n_past)] + # h_seq = [encoder(x[i]) for i in range(opt.n_past)] for s in range(nsample): frame_predictor.hidden = frame_predictor.init_hidden() posterior.hidden = posterior.init_hidden() + prior.hidden = prior.init_hidden() gen_seq[s].append(x[0]) x_in = x[0] for i in range(1, opt.n_eval): @@ -197,17 +204,23 @@ def plot(x, epoch): if opt.last_frame_skip or i < opt.n_past: h, skip = h else: - h, _ = h_seq[i-1] + h, _ = h h = h.detach() if i < opt.n_past: - frame_predictor(h) + h_target = encoder(x[i]) + h_target = h_target[0].detach() + z_t = posterior(h_target) + prior(h) + frame_predictor(torch.cat([h, z_t], 1)) x_in = x[i] gen_seq[s].append(x_in) else: - h = frame_predictor(h).detach() + z_t_hat = prior(h) + h = frame_predictor(torch.cat([h, z_t_hat], 1)).detach() x_in = decoder([h, skip]) gen_seq[s].append(x_in) + to_plot = [] gifs = [ [] for t in range(opt.n_eval) ] nrow = min(opt.batch_size, 25) @@ -238,42 +251,63 @@ def plot(x, epoch): def plot_rec(x, epoch): + gen_seq = [utils.torch_rgb_img_to_gray(x[0])] + gen_seq_post = [utils.torch_rgb_img_to_gray(x[0])] + + # prediction using posterior Z frame_predictor.hidden = frame_predictor.init_hidden() posterior.hidden = posterior.init_hidden() - - gen_seq = [] - gen_seq.append(x[0]) - gen_seq_post = [] - gen_seq_post.append(x[0]) - x_in = x[0] - h_seq = [encoder(x[i]) for i in range(opt.n_past+opt.n_future)] + h = encoder(x[0]) + h = (h[0].detach(), h[1].detach()) for i in range(1, opt.n_past+opt.n_future): - h_target = h_seq[i][0].detach() + h_target = encoder(x[i]) + h_target = (h_target[0].detach(), h_target[1].detach()) if opt.last_frame_skip or i < opt.n_past: - h, skip = h_seq[i-1] + h, skip = h + else: + h, _ = h + z_t = posterior(h_target) + if i < opt.n_past: + frame_predictor(torch.cat([h, z_t], 1)) + gen_seq_post.append(x[i]) + else: + h_pred = frame_predictor(torch.cat([h, z_t], 1)).detach() + x_pred = decoder([h_pred, skip]).detach() + gen_seq_post.append(x_pred) + h = h_target + + # prediction using prior Z + frame_predictor.hidden = frame_predictor.init_hidden() + prior.hidden = prior.init_hidden() + h = encoder(x[0]) + h = (h[0].detach(), h[1].detach()) + for i in range(1, opt.n_past+opt.n_future): + h_target = encoder(x[i]) + h_target = (h_target[0].detach(), h_target[1].detach()) + if opt.last_frame_skip or i < opt.n_past: + h, skip = h else: - h, _ = h_seq[i-1] + h, _ = h h = h.detach() + z_t_hat = prior(h) if i < opt.n_past: - frame_predictor(h) + frame_predictor(torch.cat([h, z_t_hat], 1)) gen_seq.append(x[i]) - gen_seq_post.append(x[i]) else: - h_pred = frame_predictor(h).detach() + h_pred = frame_predictor(torch.cat([h, z_t_hat], 1)).detach() x_pred = decoder([h_pred, skip]).detach() - h_posterior = posterior(h_target).detach() - x_posterior = decoder([h_posterior, skip]).detach() gen_seq.append(x_pred) - gen_seq_post.append(x_posterior) + h = h_target to_plot = [] - nrow = min(opt.batch_size * 3, 25 * 3) + nrow = min(opt.batch_size, 25) + x_gray = [utils.torch_rgb_img_to_gray(time) for time in range(opt.n_past+opt.n_future)] for i in range(nrow): row_gt = [] row_post = [] row = [] for t in range(opt.n_past+opt.n_future): - row_gt.append(x[t][i]) + row_gt.append(x_gray[t][i]) row_post.append(gen_seq_post[t][i]) row.append(gen_seq[t][i]) to_plot.append(row_gt) @@ -287,41 +321,47 @@ def plot_rec(x, epoch): def train(x): frame_predictor.zero_grad() posterior.zero_grad() + prior.zero_grad() encoder.zero_grad() decoder.zero_grad() # initialize the hidden state. frame_predictor.hidden = frame_predictor.init_hidden() posterior.hidden = posterior.init_hidden() + prior.hidden = prior.init_hidden() - h_seq = [encoder(x[i]) for i in range(opt.n_past+opt.n_future)] mse = 0 - mse_post = 0 - mse_diff_post = 0 + mse_residual = 0 + h = encoder(x[0]) for i in range(1, opt.n_past+opt.n_future): - h_target = h_seq[i][0] - if opt.last_frame_skip or i < opt.n_past: - h, skip = h_seq[i-1] + h_target = encoder(x[i]) + if opt.last_frame_skip or i < opt.n_past: + h, skip = h else: - h = h_seq[i-1][0] - h_pred = frame_predictor(h) + h = h[0] + + z_t = posterior(h_target[0]) + z_t_hat = prior(h) + h_pred = frame_predictor(torch.cat([h, z_t], 1)) x_pred = decoder([h_pred, skip]) - h_posterior = posterior(h_target) - x_posterior = decoder([h_posterior, skip]) - mse += mse_criterion(x_pred, x[i]) - mse_post += mse_criterion(x_posterior, x[i]) - mse_diff_post += opt.gamma * torch.mean(torch.square(h_posterior.detach() - h_pred)) - loss = mse + mse_post + mse_diff_post + gray_target_frame = utils.torch_rgb_img_to_gray(x[i]) + mse += mse_criterion(x_pred, gray_target_frame) + # penalize prior for being far from posterior + mse_residual += opt.gamma * torch.mean(torch.square(z_t.detach() - z_t_hat)) + h = h_target + + loss = mse + mse_residual loss.backward() frame_predictor_optimizer.step() posterior_optimizer.step() + prior_optimizer.step() encoder_optimizer.step() decoder_optimizer.step() N = opt.n_past+opt.n_future - return mse.data.cpu().numpy()/N, mse_post.data.cpu().numpy()/N, mse_diff_post.data.cpu().numpy()/N + return mse.data.cpu().numpy()/N, mse_residual.data.cpu().numpy()/N # --------- training loop ------------------------------------ for epoch in range(opt.niter): @@ -330,8 +370,7 @@ def train(x): encoder.train() decoder.train() epoch_mse = 0 - epoch_mse_posterior = 0 - epoch_posterior_diff = 0 + epoch_mse_residual = 0 progress = progressbar.ProgressBar(max_value=opt.epoch_size).start() # opt.epoch_size = 10 for i in range(opt.epoch_size): @@ -339,20 +378,20 @@ def train(x): x = next(training_batch_generator) # train frame_predictor - mse, mse_posterior, posterior_diff = train(x) + mse, mse_residual = train(x) epoch_mse += mse - epoch_mse_posterior += mse_posterior - epoch_posterior_diff += posterior_diff + epoch_mse_residual += mse_residual progress.finish() utils.clear_progressbar() - print('[%02d] mse loss: %.5f, %.5f posterior | posterior diff loss: %.20f (%d)' % (epoch, epoch_mse/opt.epoch_size, epoch_mse_posterior/opt.epoch_size, epoch_posterior_diff/opt.epoch_size, epoch*opt.epoch_size*opt.batch_size)) + print('[%02d] mse loss: %.5f | residual mse: %.20f (%d)' % (epoch, epoch_mse/opt.epoch_size, epoch_mse_residual/opt.epoch_size, epoch*opt.epoch_size*opt.batch_size)) # plot some stuff frame_predictor.eval() posterior.eval() + prior.eval() encoder.eval() decoder.eval() x = next(testing_batch_generator) @@ -365,6 +404,7 @@ def train(x): 'decoder': decoder, 'frame_predictor': frame_predictor, 'posterior': posterior, + 'prior': posterior, 'opt': opt}, '%s/model_e%02d.pth' % (opt.log_dir, epoch)) if epoch % 10 == 0: diff --git a/utils.py b/utils.py index 371b9ab..6f6de51 100755 --- a/utils.py +++ b/utils.py @@ -24,6 +24,14 @@ hostname = socket.gethostname() +RGB_weights = torch.tensor(np.array([0.299, 0.587, 0.114]), dtype=torch.float32, device=torch.device('cuda:0')) + +def torch_rgb_img_to_gray(tensor): + # in: Bx3xHxW out: Bx1xHxW + tensor = torch.transpose(tensor, 1, 3) # B x W x H x 3 + tensor = torch.unsqueeze(torch.matmul(tensor, RGB_weights), -1) # B x W x H x 1 + tensor = torch.transpose(tensor, 3, 1) # B x 1 x H x W + return tensor def torch_tensor_to_img(tensor): image_array = tensor.numpy() @@ -101,6 +109,26 @@ def load_dataset(opt, sequential=None, implausible=None): task=opt.mcs_task, sequential=sequential, implausible=implausible) + elif opt.dataset == 'mcs_test': + from data.mcs import MCS + train_data = MCS( + train=True, + data_root=opt.data_root, + seq_len=opt.n_past + opt.n_future, + image_size=opt.image_width, + task=opt.mcs_task, + sequential=sequential, + implausible=implausible, + test_set=True) + test_data = MCS( + train=False, + data_root=opt.data_root, + seq_len=opt.n_eval, + image_size=opt.image_width, + task=opt.mcs_task, + sequential=sequential, + implausible=implausible, + test_set=True) return train_data, test_data From 8d455094d1246381d81546f60cb005afd2ab8583 Mon Sep 17 00:00:00 2001 From: Fred Lu Date: Tue, 27 Jul 2021 19:39:08 -0400 Subject: [PATCH 06/26] gravity update --- data/mcs.py | 45 +++++- do_mcs_implausblility_test_posterior.py | 10 +- train_svg_nonstochastic_posterior.py | 59 ++++--- utils.py | 203 ++++++++++++++---------- 4 files changed, 201 insertions(+), 116 deletions(-) diff --git a/data/mcs.py b/data/mcs.py index 89d12f9..c00d80d 100644 --- a/data/mcs.py +++ b/data/mcs.py @@ -1,6 +1,8 @@ import logging import random import os + +import cv2 import numpy as np from glob import glob import torch @@ -11,15 +13,21 @@ class MCS(object): - def __init__(self, train, data_root, seq_len=20, image_size=64, task='ALL', sequential=None, implausible=False, test_set=False): + def __init__(self, train, data_root, seq_len=20, image_size=64, task='ALL', sequential=None, implausible=False, + test_set=False, im_channels=1, use_edge_kernels=True): # if implausible is set to True, generates "fake" images by cutting out or repeating frames self.implausible = implausible self.data_root = '%s/mcs_videos_1000/processed/' % data_root - self.data_root = '%s/mcs_videos_test/processed/' % data_root + # self.data_root = '%s/mcs_videos_test/processed/' % data_root if not os.path.exists(self.data_root): raise os.error('data/mcs.py: Data directory not found!') self.seq_len = seq_len self.image_size = image_size + self.im_channels = im_channels + if use_edge_kernels: + if im_channels != 1: + raise AssertionError('Using edge kernels implies the output images are grayscale! Set im_channels to 1!') + self.use_edge_kernels = use_edge_kernels # print('mcs.py: found tasks ', self.tasks) self.video_folder = {} @@ -60,9 +68,34 @@ def get_sequence(self, idx=None): for i in range(start, start + self.seq_len): # i is 0-indexed so we need to add 1 to i fname = frame_path + f'{i + 1:04d}.png' - im = imageio.imread(fname) / 255. - # gray = lambda rgb: np.dot(rgb[..., :3], [0.299, 0.587, 0.114]) - # im = gray(im)[..., np.newaxis] + im = imageio.imread(fname) / np.float32(255.) + if self.im_channels == 1: # convert to grayscale if specified + if not self.use_edge_kernels: + # regular grayscale conversion + gray = lambda rgb: np.dot(rgb[..., :3], [0.299, 0.587, 0.114]) + im = gray(im)[..., np.newaxis] + else: + edge_map = np.zeros((self.image_size, self.image_size), dtype=np.float32) + ddepth = cv2.CV_32F + scale = 1 + delta = 0 + for channel in range(3): + color = im[..., channel] + # grad_x = cv2.Sobel(color, ddepth, 1, 0, ksize=5, scale=scale, delta=delta, + # borderType=cv2.BORDER_DEFAULT) + # grad_y = cv2.Sobel(color, ddepth, 0, 1, ksize=5, scale=scale, delta=delta, + # borderType=cv2.BORDER_DEFAULT) + grad_x = cv2.Scharr(color, ddepth, 1, 0, scale=scale, delta=delta, + borderType=cv2.BORDER_DEFAULT) + grad_y = cv2.Scharr(color, ddepth, 0, 1, scale=scale, delta=delta, + borderType=cv2.BORDER_DEFAULT) + # abs_grad_x = np.abs(grad_x) + # abs_grad_y = np.abs(grad_y) + edge_map += np.sqrt(grad_x**2 + grad_y**2) + edge_map /= 3*2 # 3 channels and two directions per channel + edge_map /= 6 # to reduce magnitude + im = edge_map[..., np.newaxis] + seq.append(im) return np.array(seq) @@ -73,7 +106,7 @@ def abnormalize_sequence(self, seq): """ implausibility_type = random.randint(1, 3) # start = random.randint(100, 140) - start = 110 + start = 115 vid_len = len(seq) duration = 7 implausibility_type = 1 diff --git a/do_mcs_implausblility_test_posterior.py b/do_mcs_implausblility_test_posterior.py index d477d72..de00033 100644 --- a/do_mcs_implausblility_test_posterior.py +++ b/do_mcs_implausblility_test_posterior.py @@ -30,6 +30,7 @@ parser.add_argument('--epoch_size', type=int, default=1000, help='epoch size') parser.add_argument('--image_width', type=int, default=64, help='the height / width of the input image to network') parser.add_argument('--channels', default=1, type=int) +parser.add_argument('--use_edge_kernels', action='store_true') parser.add_argument('--dataset', default='mcs_test', help='dataset to train with') parser.add_argument('--mcs_task', default='SpatioTemporalContinuityTraining4', help='mcs task') parser.add_argument('--n_past', type=int, default=5, help='number of frames to condition on') @@ -303,6 +304,9 @@ def do_implasubility_test(z_residual_mean, z_residual_cov, visualize=True): is_implausible = True frames = [frame.cpu() for frame in x] + # print(frames[0][0].numpy().dtype) + # print(frames[0][0]) + # quit() except TypeError: print('got None at i = {}, terminating'.format(i)) break @@ -370,8 +374,10 @@ def do_implasubility_test(z_residual_mean, z_residual_cov, visualize=True): # print(h_residual_var) if visualize: for j in range(len(frames)): - frame = np.uint8(np.minimum(frames[j][0][0], 1) * 255) - cv2.imshow('frame', frame) + frame_cv2 = frames[j][0][0].numpy() + frame_cv2 /= 3 + frame_cv2 = np.uint8(np.minimum(frame_cv2, 1.0) * 255.) + cv2.imshow('frame', frame_cv2) cv2.waitKey(15) percentile = np.percentile(scores[76:152], 85.0) diff --git a/train_svg_nonstochastic_posterior.py b/train_svg_nonstochastic_posterior.py index 20b994f..f35b9c2 100644 --- a/train_svg_nonstochastic_posterior.py +++ b/train_svg_nonstochastic_posterior.py @@ -12,11 +12,12 @@ import itertools import progressbar import numpy as np +import json parser = argparse.ArgumentParser() -parser.add_argument('--lr', default=0.004, type=float, help='learning rate') +parser.add_argument('--lr', default=0.0001, type=float, help='learning rate') parser.add_argument('--beta1', default=0.9, type=float, help='momentum term for adam') -parser.add_argument('--batch_size', default=22, type=int, help='batch size') +parser.add_argument('--batch_size', default=24, type=int, help='batch size') parser.add_argument('--log_dir', default='logs/nonstochastic_posterior', help='base directory to save logs') parser.add_argument('--model_dir', default='', help='base directory to save logs') parser.add_argument('--name', default='', help='identifier for directory') @@ -27,6 +28,7 @@ parser.add_argument('--epoch_size', type=int, default=500, help='epoch size') parser.add_argument('--image_width', type=int, default=64, help='the height / width of the input image to network') parser.add_argument('--channels', default=3, type=int, help='number of channels for input images. ') +parser.add_argument('--use_edge_kernels', action='store_true') parser.add_argument('--dataset', default='mcs', help='dataset to train with') parser.add_argument('--mcs_task', default='SpatioTemporalContinuityTraining4', help='mcs task') parser.add_argument('--n_past', type=int, default=5, help='number of frames to condition on') @@ -57,14 +59,20 @@ model_dir = opt.model_dir niter = opt.niter lr = opt.lr + batch_size = opt.batch_size + n_future = opt.n_future + n_eval = opt.n_eval opt = saved_model['opt'] opt.niter = niter # update number of epochs to train for opt.optimizer = optimizer opt.model_dir = model_dir + opt.n_future = n_future opt.lr = lr + opt.batch_size = batch_size + opt.n_eval = n_eval opt.log_dir = '%s/continued_lr%s' % (opt.log_dir, opt.lr) else: - name = 'NEWmodel=%s%dx%d-rnn_size=%d-predictor-posterior-rnn_layers=%d-%d-n_past=%d-n_future=%d-lr=%.4f-g_dim=%d-z_dim=%d-last_frame_skip=%d-beta=%.7f-gamma=%.7f%s' % (opt.model, opt.image_width, opt.image_width, opt.rnn_size, opt.predictor_rnn_layers, opt.posterior_rnn_layers, opt.n_past, opt.n_future, opt.lr, opt.g_dim, opt.z_dim, opt.last_frame_skip, opt.beta, opt.gamma, opt.name) + name = 'RGBmodel=%s%dx%d-rnn_size=%d-predictor-posterior-rnn_layers=%d-%d-n_past=%d-n_future=%d-lr=%.4f-g_dim=%d-z_dim=%d-last_frame_skip=%d-beta=%.7f-gamma=%.7f%s' % (opt.model, opt.image_width, opt.image_width, opt.rnn_size, opt.predictor_rnn_layers, opt.posterior_rnn_layers, opt.n_past, opt.n_future, opt.lr, opt.g_dim, opt.z_dim, opt.last_frame_skip, opt.beta, opt.gamma, opt.name) if opt.dataset == 'smmnist': opt.log_dir = '%s/%s-%d/%s' % (opt.log_dir, opt.dataset, opt.num_digits, name) elif opt.dataset == 'mcs': @@ -72,8 +80,10 @@ else: opt.log_dir = '%s/%s/%s' % (opt.log_dir, opt.dataset, name) -os.makedirs('%s/gen/' % opt.log_dir, exist_ok=True) -os.makedirs('%s/plots/' % opt.log_dir, exist_ok=True) +os.makedirs('%s/gen/' % opt.log_dir, exist_ok=False) +os.makedirs('%s/plots/' % opt.log_dir, exist_ok=False) +with open(os.path.join(opt.log_dir, 'opt.json'), 'w') as f: + json.dump(opt.__dict__, f, indent=2) print("Random Seed: ", opt.seed) random.seed(opt.seed) @@ -99,8 +109,11 @@ import models.lstm as lstm_models if opt.model_dir != '': frame_predictor = saved_model['frame_predictor'] + frame_predictor.batch_size = opt.batch_size prior = saved_model['prior'] + prior.batch_size = opt.batch_size posterior = saved_model['posterior'] + posterior.batch_size = opt.batch_size else: frame_predictor = lstm_models.lstm(opt.g_dim + opt.z_dim, opt.g_dim, opt.rnn_size, opt.predictor_rnn_layers, opt.batch_size) posterior = lstm_models.lstm(opt.g_dim, opt.z_dim, opt.rnn_size, opt.posterior_rnn_layers, opt.batch_size) @@ -162,7 +175,7 @@ def kl_criterion(mu, logvar): batch_size=opt.batch_size, shuffle=True, drop_last=True, - pin_memory=True) + pin_memory=True,) test_loader = DataLoader(test_data, num_workers=opt.data_threads, batch_size=opt.batch_size, @@ -189,18 +202,22 @@ def plot(x, epoch): nsample = 1 gen_seq = [[] for _ in range(nsample)] - gt_seq = [x[i] for i in range(len(x))] + gt_seq = [utils.torch_rgb_img_to_gray(x[t]) for t in range(len(x))] # h_seq = [encoder(x[i]) for i in range(opt.n_past)] for s in range(nsample): frame_predictor.hidden = frame_predictor.init_hidden() posterior.hidden = posterior.init_hidden() prior.hidden = prior.init_hidden() - gen_seq[s].append(x[0]) + gen_seq[s].append(utils.torch_rgb_img_to_gray(x[0])) x_in = x[0] for i in range(1, opt.n_eval): with torch.no_grad(): - h = encoder(x_in) + # if input is grayscale + if x_in.shape[1] == 1 and opt.channels == 3: + h = encoder(torch.cat(3*[x_in], dim=1)) # convert to RGB + else: + h = encoder(x_in) if opt.last_frame_skip or i < opt.n_past: h, skip = h else: @@ -208,12 +225,11 @@ def plot(x, epoch): h = h.detach() if i < opt.n_past: h_target = encoder(x[i]) - h_target = h_target[0].detach() - z_t = posterior(h_target) + z_t = posterior(h_target[0].detach()) prior(h) frame_predictor(torch.cat([h, z_t], 1)) x_in = x[i] - gen_seq[s].append(x_in) + gen_seq[s].append(utils.torch_rgb_img_to_gray(x_in)) else: z_t_hat = prior(h) h = frame_predictor(torch.cat([h, z_t_hat], 1)).detach() @@ -234,7 +250,7 @@ def plot(x, epoch): for s in range(nsample): row = [] for t in range(opt.n_eval): - row.append(gen_seq[s][t][i]) + row.append(gen_seq[s][t][i]) to_plot.append(row) for t in range(opt.n_eval): row = [] @@ -258,18 +274,18 @@ def plot_rec(x, epoch): frame_predictor.hidden = frame_predictor.init_hidden() posterior.hidden = posterior.init_hidden() h = encoder(x[0]) - h = (h[0].detach(), h[1].detach()) + h = (h[0].detach(), h[1]) for i in range(1, opt.n_past+opt.n_future): h_target = encoder(x[i]) - h_target = (h_target[0].detach(), h_target[1].detach()) + h_target = (h_target[0].detach(), h_target[1]) if opt.last_frame_skip or i < opt.n_past: h, skip = h else: h, _ = h - z_t = posterior(h_target) + z_t = posterior(h_target[0]) if i < opt.n_past: frame_predictor(torch.cat([h, z_t], 1)) - gen_seq_post.append(x[i]) + gen_seq_post.append(utils.torch_rgb_img_to_gray(x[i])) else: h_pred = frame_predictor(torch.cat([h, z_t], 1)).detach() x_pred = decoder([h_pred, skip]).detach() @@ -280,10 +296,10 @@ def plot_rec(x, epoch): frame_predictor.hidden = frame_predictor.init_hidden() prior.hidden = prior.init_hidden() h = encoder(x[0]) - h = (h[0].detach(), h[1].detach()) + h = (h[0].detach(), h[1]) for i in range(1, opt.n_past+opt.n_future): h_target = encoder(x[i]) - h_target = (h_target[0].detach(), h_target[1].detach()) + h_target = (h_target[0].detach(), h_target[1]) if opt.last_frame_skip or i < opt.n_past: h, skip = h else: @@ -292,7 +308,7 @@ def plot_rec(x, epoch): z_t_hat = prior(h) if i < opt.n_past: frame_predictor(torch.cat([h, z_t_hat], 1)) - gen_seq.append(x[i]) + gen_seq.append(utils.torch_rgb_img_to_gray(x[i])) else: h_pred = frame_predictor(torch.cat([h, z_t_hat], 1)).detach() x_pred = decoder([h_pred, skip]).detach() @@ -301,7 +317,7 @@ def plot_rec(x, epoch): to_plot = [] nrow = min(opt.batch_size, 25) - x_gray = [utils.torch_rgb_img_to_gray(time) for time in range(opt.n_past+opt.n_future)] + x_gray = [utils.torch_rgb_img_to_gray(x[t]) for t in range(opt.n_past+opt.n_future)] for i in range(nrow): row_gt = [] row_post = [] @@ -344,7 +360,6 @@ def train(x): z_t_hat = prior(h) h_pred = frame_predictor(torch.cat([h, z_t], 1)) x_pred = decoder([h_pred, skip]) - gray_target_frame = utils.torch_rgb_img_to_gray(x[i]) mse += mse_criterion(x_pred, gray_target_frame) # penalize prior for being far from posterior diff --git a/utils.py b/utils.py index 6f6de51..22f704c 100755 --- a/utils.py +++ b/utils.py @@ -7,6 +7,7 @@ from sklearn.manifold import TSNE import scipy.misc import matplotlib + matplotlib.use('agg') import matplotlib.pyplot as plt import functools @@ -16,23 +17,27 @@ from scipy import ndimage from PIL import Image, ImageDraw - from torchvision import datasets, transforms from torch.autograd import Variable import imageio - hostname = socket.gethostname() -RGB_weights = torch.tensor(np.array([0.299, 0.587, 0.114]), dtype=torch.float32, device=torch.device('cuda:0')) +RGB_weights = torch.tensor(np.array([0.299, 0.587, 0.114]), dtype=torch.float32, device=torch.device('cuda:0'), + requires_grad=False).detach() + def torch_rgb_img_to_gray(tensor): # in: Bx3xHxW out: Bx1xHxW + if tensor.shape[1] == 1: + return tensor + # assert tensor.shape[1] == 3 # make sure input image has 3 (RGB) channels tensor = torch.transpose(tensor, 1, 3) # B x W x H x 3 tensor = torch.unsqueeze(torch.matmul(tensor, RGB_weights), -1) # B x W x H x 1 tensor = torch.transpose(tensor, 3, 1) # B x 1 x H x W return tensor + def torch_tensor_to_img(tensor): image_array = tensor.numpy() image_array -= np.min(image_array) @@ -54,61 +59,65 @@ def load_dataset(opt, sequential=None, implausible=None): if opt.dataset == 'smmnist': from data.moving_mnist import MovingMNIST train_data = MovingMNIST( - train=True, - data_root=opt.data_root, - seq_len=opt.n_past+opt.n_future, - image_size=opt.image_width, - deterministic=False, - num_digits=opt.num_digits) + train=True, + data_root=opt.data_root, + seq_len=opt.n_past + opt.n_future, + image_size=opt.image_width, + deterministic=False, + num_digits=opt.num_digits) test_data = MovingMNIST( - train=False, - data_root=opt.data_root, - seq_len=opt.n_eval, - image_size=opt.image_width, - deterministic=False, - num_digits=opt.num_digits) + train=False, + data_root=opt.data_root, + seq_len=opt.n_eval, + image_size=opt.image_width, + deterministic=False, + num_digits=opt.num_digits) elif opt.dataset == 'bair': - from data.bair import RobotPush + from data.bair import RobotPush train_data = RobotPush( - data_root=opt.data_root, - train=True, - seq_len=opt.n_past+opt.n_future, - image_size=opt.image_width) + data_root=opt.data_root, + train=True, + seq_len=opt.n_past + opt.n_future, + image_size=opt.image_width) test_data = RobotPush( - data_root=opt.data_root, - train=False, - seq_len=opt.n_eval, - image_size=opt.image_width) + data_root=opt.data_root, + train=False, + seq_len=opt.n_eval, + image_size=opt.image_width) elif opt.dataset == 'kth': - from data.kth import KTH + from data.kth import KTH train_data = KTH( - train=True, - data_root=opt.data_root, - seq_len=opt.n_past+opt.n_future, - image_size=opt.image_width) + train=True, + data_root=opt.data_root, + seq_len=opt.n_past + opt.n_future, + image_size=opt.image_width) test_data = KTH( - train=False, - data_root=opt.data_root, - seq_len=opt.n_eval, - image_size=opt.image_width) + train=False, + data_root=opt.data_root, + seq_len=opt.n_eval, + image_size=opt.image_width) elif opt.dataset == 'mcs': from data.mcs import MCS train_data = MCS( - train=True, - data_root=opt.data_root, - seq_len=opt.n_past+opt.n_future, - image_size=opt.image_width, - task=opt.mcs_task, - sequential=sequential, - implausible=implausible) + train=True, + data_root=opt.data_root, + seq_len=opt.n_past + opt.n_future, + image_size=opt.image_width, + task=opt.mcs_task, + sequential=sequential, + implausible=implausible, + im_channels=opt.channels, + use_edge_kernels=opt.use_edge_kernels) test_data = MCS( - train=False, - data_root=opt.data_root, - seq_len=opt.n_eval, - image_size=opt.image_width, - task=opt.mcs_task, - sequential=sequential, - implausible=implausible) + train=False, + data_root=opt.data_root, + seq_len=opt.n_eval, + image_size=opt.image_width, + task=opt.mcs_task, + sequential=sequential, + implausible=implausible, + im_channels=opt.channels, + use_edge_kernels=opt.use_edge_kernels) elif opt.dataset == 'mcs_test': from data.mcs import MCS train_data = MCS( @@ -119,7 +128,9 @@ def load_dataset(opt, sequential=None, implausible=None): task=opt.mcs_task, sequential=sequential, implausible=implausible, - test_set=True) + test_set=True, + im_channels=opt.channels, + use_edge_kernels=opt.use_edge_kernels) test_data = MCS( train=False, data_root=opt.data_root, @@ -128,13 +139,17 @@ def load_dataset(opt, sequential=None, implausible=None): task=opt.mcs_task, sequential=sequential, implausible=implausible, - test_set=True) - + test_set=True, + im_channels=opt.channels, + use_edge_kernels=opt.use_edge_kernels) + return train_data, test_data + def sequence_input(seq, dtype): return [Variable(x.type(dtype)) for x in seq] + def normalize_data(opt, dtype, sequence): if opt.dataset == 'smmnist' or opt.dataset == 'kth' or opt.dataset == 'bair' or opt.dataset == 'mcs': sequence.transpose_(0, 1) @@ -144,12 +159,14 @@ def normalize_data(opt, dtype, sequence): return sequence_input(sequence, dtype) + def is_sequence(arg): return (not hasattr(arg, "strip") and not type(arg) is np.ndarray and not hasattr(arg, "dot") and (hasattr(arg, "__getitem__") or - hasattr(arg, "__iter__"))) + hasattr(arg, "__iter__"))) + def image_tensor(inputs, padding=1): # assert is_sequence(inputs) @@ -169,11 +186,11 @@ def image_tensor(inputs, padding=1): y_dim = images[0].size(1) result = torch.ones(c_dim, - x_dim * len(images) + padding * (len(images)-1), + x_dim * len(images) + padding * (len(images) - 1), y_dim) for i, image in enumerate(images): - result[:, i * x_dim + i * padding : - (i+1) * x_dim + i * padding, :].copy_(image) + result[:, i * x_dim + i * padding: + (i + 1) * x_dim + i * padding, :].copy_(image) return result @@ -193,18 +210,20 @@ def image_tensor(inputs, padding=1): result = torch.ones(c_dim, x_dim, - y_dim * len(images) + padding * (len(images)-1)) + y_dim * len(images) + padding * (len(images) - 1)) for i, image in enumerate(images): - result[:, :, i * y_dim + i * padding : - (i+1) * y_dim + i * padding].copy_(image) + result[:, :, i * y_dim + i * padding: + (i + 1) * y_dim + i * padding].copy_(image) return result + def save_np_img(fname, x): if x.shape[0] == 1: x = np.tile(x, (3, 1, 1)) img = torch_tensor_to_img(x) img.save(fname) + def make_image(tensor): tensor = tensor.cpu().clamp(0, 1) if tensor.size(0) == 1: @@ -212,46 +231,54 @@ def make_image(tensor): # pdb.set_trace() return torch_tensor_to_img(tensor) + def draw_text_tensor(tensor, text): np_x = tensor.transpose(0, 1).transpose(1, 2).data.cpu().numpy() - pil = Image.fromarray(np.uint8(np_x*255)) + pil = Image.fromarray(np.uint8(np_x * 255)) draw = ImageDraw.Draw(pil) - draw.text((4, 64), text, (0,0,0)) + draw.text((4, 64), text, (0, 0, 0)) img = np.asarray(pil) return Variable(torch.Tensor(img / 255.)).transpose(1, 2).transpose(0, 1) + def save_gif(filename, inputs, duration=0.25): images = [] for tensor in inputs: img = image_tensor(tensor, padding=0) img = img.cpu() - img = img.transpose(0,1).transpose(1,2).clamp(0,1) - images.append((img.numpy()*255).astype(np.uint8)) + img = img.transpose(0, 1).transpose(1, 2).clamp(0, 1) + images.append((img.numpy() * 255).astype(np.uint8)) imageio.mimsave(filename, images, duration=duration) + def save_gif_with_text(filename, inputs, text, duration=0.25): images = [] for tensor, text in zip(inputs, text): img = image_tensor([draw_text_tensor(ti, texti) for ti, texti in zip(tensor, text)], padding=0) img = img.cpu() - img = img.transpose(0,1).transpose(1,2).clamp(0,1).numpy() + img = img.transpose(0, 1).transpose(1, 2).clamp(0, 1).numpy() images.append(img) imageio.mimsave(filename, images, duration=duration) + def save_image(filename, tensor): img = make_image(tensor) img.save(filename) + def save_tensors_image(filename, inputs, padding=1): images = image_tensor(inputs, padding) return save_image(filename, images) + def prod(l): return functools.reduce(lambda x, y: x * y, l) + def batch_flatten(x): return x.resize(x.size(0), prod(x.size()[1:])) + def clear_progressbar(): # moves up 3 lines print("\033[2A") @@ -260,11 +287,13 @@ def clear_progressbar(): # moves up two lines again print("\033[2A") + def mse_metric(x1, x2): err = np.sum((x1 - x2) ** 2) err /= float(x1.shape[0] * x1.shape[1] * x1.shape[2]) return err + def eval_seq(gt, pred): T = len(gt) bs = gt[0].shape[0] @@ -282,6 +311,7 @@ def eval_seq(gt, pred): return mse, ssim, psnr + # ssim function used in Babaeizadeh et al. (2017), Fin et al. (2016), etc. def finn_eval_seq(gt, pred): T = len(gt) @@ -306,21 +336,23 @@ def finn_eval_seq(gt, pred): def finn_psnr(x, y): - mse = ((x - y)**2).mean() - return 10*np.log(1/mse)/np.log(10) + mse = ((x - y) ** 2).mean() + return 10 * np.log(1 / mse) / np.log(10) def gaussian2(size, sigma): - A = 1/(2.0*np.pi*sigma**2) - x, y = np.mgrid[-size//2 + 1:size//2 + 1, -size//2 + 1:size//2 + 1] - g = A*np.exp(-((x**2/(2.0*sigma**2))+(y**2/(2.0*sigma**2)))) + A = 1 / (2.0 * np.pi * sigma ** 2) + x, y = np.mgrid[-size // 2 + 1:size // 2 + 1, -size // 2 + 1:size // 2 + 1] + g = A * np.exp(-((x ** 2 / (2.0 * sigma ** 2)) + (y ** 2 / (2.0 * sigma ** 2)))) return g + def fspecial_gauss(size, sigma): - x, y = np.mgrid[-size//2 + 1:size//2 + 1, -size//2 + 1:size//2 + 1] - g = np.exp(-((x**2 + y**2)/(2.0*sigma**2))) - return g/g.sum() - + x, y = np.mgrid[-size // 2 + 1:size // 2 + 1, -size // 2 + 1:size // 2 + 1] + g = np.exp(-((x ** 2 + y ** 2) / (2.0 * sigma ** 2))) + return g / g.sum() + + def finn_ssim(img1, img2, cs_map=False): img1 = img1.astype(np.float64) img2 = img2.astype(np.float64) @@ -329,24 +361,24 @@ def finn_ssim(img1, img2, cs_map=False): window = fspecial_gauss(size, sigma) K1 = 0.01 K2 = 0.03 - L = 1 #bitdepth of image - C1 = (K1*L)**2 - C2 = (K2*L)**2 + L = 1 # bitdepth of image + C1 = (K1 * L) ** 2 + C2 = (K2 * L) ** 2 mu1 = signal.fftconvolve(img1, window, mode='valid') mu2 = signal.fftconvolve(img2, window, mode='valid') - mu1_sq = mu1*mu1 - mu2_sq = mu2*mu2 - mu1_mu2 = mu1*mu2 - sigma1_sq = signal.fftconvolve(img1*img1, window, mode='valid') - mu1_sq - sigma2_sq = signal.fftconvolve(img2*img2, window, mode='valid') - mu2_sq - sigma12 = signal.fftconvolve(img1*img2, window, mode='valid') - mu1_mu2 + mu1_sq = mu1 * mu1 + mu2_sq = mu2 * mu2 + mu1_mu2 = mu1 * mu2 + sigma1_sq = signal.fftconvolve(img1 * img1, window, mode='valid') - mu1_sq + sigma2_sq = signal.fftconvolve(img2 * img2, window, mode='valid') - mu2_sq + sigma12 = signal.fftconvolve(img1 * img2, window, mode='valid') - mu1_mu2 if cs_map: - return (((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)* - (sigma1_sq + sigma2_sq + C2)), - (2.0*sigma12 + C2)/(sigma1_sq + sigma2_sq + C2)) + return (((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * + (sigma1_sq + sigma2_sq + C2)), + (2.0 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2)) else: - return ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)* - (sigma1_sq + sigma2_sq + C2)) + return ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * + (sigma1_sq + sigma2_sq + C2)) def init_weights(m): @@ -357,4 +389,3 @@ def init_weights(m): elif classname.find('BatchNorm') != -1: m.weight.data.normal_(1.0, 0.02) m.bias.data.fill_(0) - From f1152f67a39426920988a9652081d493314e22af Mon Sep 17 00:00:00 2001 From: Fred Lu Date: Thu, 29 Jul 2021 16:54:23 -0400 Subject: [PATCH 07/26] gravity first version complete --- _train_svg_nonstochastic_posterior.py | 2 - data/convert_mcs.py | 4 +- data/mcs.py | 29 +- do_mcs_implausblility_test_gravity.py | 486 ++++++++++++++++++++++++ do_mcs_implausblility_test_posterior.py | 12 +- train_svg_nonstochastic.py | 2 - train_svg_nonstochastic_posterior.py | 2 - utils.py | 9 +- 8 files changed, 523 insertions(+), 23 deletions(-) create mode 100644 do_mcs_implausblility_test_gravity.py diff --git a/_train_svg_nonstochastic_posterior.py b/_train_svg_nonstochastic_posterior.py index 1dce6a7..fddd604 100644 --- a/_train_svg_nonstochastic_posterior.py +++ b/_train_svg_nonstochastic_posterior.py @@ -52,13 +52,11 @@ latest_model = sorted(models, key=lambda s: int(s[s.rfind('_e') + 2: s.rfind('.pth')]), reverse=True)[0] print('Loading model ', latest_model) saved_model = torch.load(latest_model) - optimizer = opt.optimizer model_dir = opt.model_dir niter = opt.niter lr = opt.lr opt = saved_model['opt'] opt.niter = niter # update number of epochs to train for - opt.optimizer = optimizer opt.model_dir = model_dir opt.lr = lr opt.log_dir = '%s/continued_lr%s' % (opt.log_dir, opt.lr) diff --git a/data/convert_mcs.py b/data/convert_mcs.py index 562ecdb..e436a8e 100644 --- a/data/convert_mcs.py +++ b/data/convert_mcs.py @@ -16,7 +16,7 @@ IMSIZE = args.imsize if not path.exists(DATA_ROOT): - print(f'directory "{DATA_ROOT}" does not exist! Check arguments') + print(f'directory "{DATA_ROOT}" does not exist! Check -d dataset argument') elif not path.exists(path.join(DATA_ROOT, 'raw')): print("Training videos must be in [datadir]/raw/(task)/*.mp4, where task is the task to which " "the video belongs") @@ -51,4 +51,4 @@ def mp4_to_png_worker(path_to_task): threads.append(t) for thread in threads: - thread.join() \ No newline at end of file + thread.join() diff --git a/data/mcs.py b/data/mcs.py index c00d80d..d801c42 100644 --- a/data/mcs.py +++ b/data/mcs.py @@ -14,16 +14,19 @@ class MCS(object): def __init__(self, train, data_root, seq_len=20, image_size=64, task='ALL', sequential=None, implausible=False, - test_set=False, im_channels=1, use_edge_kernels=True): + test_set=False, im_channels=1, use_edge_kernels=True, labels=False): # if implausible is set to True, generates "fake" images by cutting out or repeating frames self.implausible = implausible - self.data_root = '%s/mcs_videos_1000/processed/' % data_root - # self.data_root = '%s/mcs_videos_test/processed/' % data_root + if test_set: + self.data_root = '%s/mcs_videos_test/processed/' % data_root + else: + self.data_root = '%s/mcs_videos_1000/processed/' % data_root if not os.path.exists(self.data_root): raise os.error('data/mcs.py: Data directory not found!') self.seq_len = seq_len self.image_size = image_size self.im_channels = im_channels + self.labels = labels if use_edge_kernels: if im_channels != 1: raise AssertionError('Using edge kernels implies the output images are grayscale! Set im_channels to 1!') @@ -41,11 +44,11 @@ def __init__(self, train, data_root, seq_len=20, image_size=64, task='ALL', sequ self.video_folder[task] = [path.basename(folder) for folder in sorted(glob(path.join(self.data_root, task, '*')))] self.len_video_folder[task] = len(self.video_folder[task]) - self.seed_set = False self.sequential = sequential # if set to true, return videos in sequence def get_sequence(self, idx=None): + if not self.sequential: task = random.choice(self.tasks) vid = random.choice(self.video_folder[task]) @@ -60,6 +63,8 @@ def get_sequence(self, idx=None): vid = self.video_folder[task][idx] num_frames = len(next(os.walk(path.join(self.data_root, task, vid)))[2]) frame_path = path.join(self.data_root, task, vid, vid + '_') + label = str(os.path.basename(vid)) + label = label[label.rfind('_') + 1:] if num_frames - self.seq_len < 0: return None start = random.randint(0, num_frames - self.seq_len) @@ -97,7 +102,10 @@ def get_sequence(self, idx=None): im = edge_map[..., np.newaxis] seq.append(im) - return np.array(seq) + if self.labels: + return np.array(seq), label + else: + return np.array(seq) def abnormalize_sequence(self, seq): """ @@ -128,11 +136,18 @@ def __getitem__(self, index): random.seed(index) np.random.seed(index) # torch.manual_seed(index) - seq = self.get_sequence(index) + if self.labels: + seq, labels = self.get_sequence(index) + else: + seq = self.get_sequence(index) + if seq is not None: if self.implausible: seq = self.abnormalize_sequence(seq) - return torch.from_numpy(seq) + if self.labels: + return torch.from_numpy(seq), labels + else: + return torch.from_numpy(seq) else: return None diff --git a/do_mcs_implausblility_test_gravity.py b/do_mcs_implausblility_test_gravity.py new file mode 100644 index 0000000..fce2107 --- /dev/null +++ b/do_mcs_implausblility_test_gravity.py @@ -0,0 +1,486 @@ +import glob + +import cv2 +import torch +import torch.optim as optim +import torch.nn as nn +import argparse +import os +import random +from torch.autograd import Variable +from torch.utils.data import DataLoader +import utils +import itertools +import progressbar +import numpy as np +import matplotlib.pyplot as plt +import json + +parser = argparse.ArgumentParser() +parser.add_argument('--lr', default=0.004, type=float, help='learning rate') +parser.add_argument('--beta1', default=0.9, type=float, help='momentum term for adam') +parser.add_argument('--batch_size', default=1, type=int, help='batch size') +parser.add_argument('--log_dir', default='logs/nonstochastic', help='base directory to save logs') +parser.add_argument('--model_dir', default='', help='base directory to save logs') +parser.add_argument('--name', default='', help='identifier for directory') +parser.add_argument('--data_root', default='data', help='root directory for data') +parser.add_argument('--optimizer', default='adam', help='optimizer to train with') +parser.add_argument('--niter', type=int, default=300, help='number of epochs to train for') +parser.add_argument('--seed', default=1, type=int, help='manual seed') +parser.add_argument('--epoch_size', type=int, default=1000, help='epoch size') +parser.add_argument('--image_width', type=int, default=64, help='the height / width of the input image to network') +parser.add_argument('--channels', default=1, type=int) +parser.add_argument('--use_edge_kernels', action='store_true') +parser.add_argument('--dataset', default='mcs_test', help='dataset to train with') +parser.add_argument('--mcs_task', default='GravitySupportEvaluation', help='mcs task') +parser.add_argument('--n_past', type=int, default=5, help='number of frames to condition on') +parser.add_argument('--n_future', type=int, default=55, help='number of frames to predict') +parser.add_argument('--n_eval', type=int, default=30, help='number of frames to predict at eval time') +parser.add_argument('--rnn_size', type=int, default=256, help='dimensionality of hidden layer') +parser.add_argument('--posterior_rnn_layers', type=int, default=1, help='number of layers') +parser.add_argument('--predictor_rnn_layers', type=int, default=2, help='number of layers') +parser.add_argument('--z_dim', type=int, default=10, help='dimensionality of z_t') +parser.add_argument('--g_dim', type=int, default=128, + help='dimensionality of encoder output vector and decoder input vector') +parser.add_argument('--beta', type=float, default=0.0001, help='weighting on KL to prior') +parser.add_argument('--model', default='dcgan', help='model type (dcgan | vgg)') +parser.add_argument('--data_threads', type=int, default=8, help='number of data loading threads') +parser.add_argument('--num_digits', type=int, default=2, help='number of digits for moving mnist') +parser.add_argument('--last_frame_skip', action='store_true', + help='if true, skip connections go between frame t and frame t+t rather than last ground truth frame') +opt = parser.parse_args() +BATCH_SIZE = opt.batch_size +saved_model = None +if opt.model_dir != '': + models = glob.glob(f'{opt.model_dir}/model_*.pth') + latest_model = sorted(models, key=lambda s: int(s[s.rfind('_e') + 2: s.rfind('.pth')]), reverse=True)[0] + print('Loading model ', latest_model) + saved_model = torch.load(latest_model) + niter = opt.niter + dataset = opt.dataset + mcs_task = opt.mcs_task + n_future = opt.n_future + opt = saved_model['opt'] + opt.batch_size = BATCH_SIZE + opt.niter = niter # update number of epochs to train for + opt.dataset = dataset + opt.mcs_task = mcs_task + opt.n_future = n_future +else: + raise ValueError("Please specify the model to load with the --model_dir argument") + +font = cv2.FONT_HERSHEY_SIMPLEX +bottomLeftCornerOfText = (30, 30) +fontScale = 0.6 +fontColor = (0, 0, 0) +thickness = 1 + +print("Random Seed: ", opt.seed) +random.seed(opt.seed) +torch.manual_seed(opt.seed) +torch.cuda.manual_seed_all(opt.seed) +dtype = torch.cuda.FloatTensor + +import models.lstm as lstm_models + +if opt.model_dir != '': + frame_predictor = saved_model['frame_predictor'] + frame_predictor.batch_size = BATCH_SIZE + posterior = saved_model['posterior'] + posterior.batch_size = BATCH_SIZE + prior = saved_model['prior'] + prior.batch_size = BATCH_SIZE +else: + raise ValueError('Please specify --model_dir') + +if opt.model == 'dcgan': + if opt.image_width == 64: + import models.dcgan_64 as model + elif opt.image_width == 128: + import models.dcgan_128 as model +elif opt.model == 'vgg': + if opt.image_width == 64: + import models.vgg_64 as model + elif opt.image_width == 128: + import models.vgg_128 as model +else: + raise ValueError('Unknown model: %s' % opt.model) + +if opt.model_dir != '': + decoder = saved_model['decoder'] + encoder = saved_model['encoder'] +else: + raise ValueError("Please specify the model to load with the --model_dir argument") + +# --------- loss functions ------------------------------------ +mse_criterion = nn.MSELoss() + + +def kl_criterion(mu, logvar): + # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) + KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) + KLD /= opt.batch_size + return KLD + + +# --------- transfer to gpu ------------------------------------ +frame_predictor.cuda() +posterior.cuda() +encoder.cuda() +decoder.cuda() +mse_criterion.cuda() + +print(opt) + +# --------- load a dataset ------------------------------------ +train_data, test_data = utils.load_dataset(opt, sequential=True, implausible=False) + +train_loader = DataLoader(train_data, + num_workers=opt.data_threads, + batch_size=opt.batch_size, + shuffle=False, + drop_last=True, + pin_memory=True, ) +test_loader = DataLoader(test_data, + num_workers=opt.data_threads, + batch_size=opt.batch_size, + shuffle=False, + drop_last=True, + pin_memory=True) + + +def get_training_batch(): + while True: + for sequence, labels in train_loader: + batch = utils.normalize_data(opt, dtype, sequence) + yield batch, labels + + +training_batch_generator = get_training_batch() + + +def get_testing_batch(): + while True: + for sequence, labels in test_loader: + batch = utils.normalize_data(opt, dtype, sequence) + yield batch, labels + + +testing_batch_generator = get_testing_batch() + + +# --------- plotting funtions ------------------------------------ +def plot(x, epoch): + nsample = 1 + gen_seq = [[] for _ in range(nsample)] + gt_seq = [x[i] for i in range(len(x))] + + h_seq = [encoder(x[i]) for i in range(opt.n_past)] + for s in range(nsample): + frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + gen_seq[s].append(x[0]) + x_in = x[0] + for i in range(1, opt.n_eval): + if opt.last_frame_skip or i < opt.n_past: + h, skip = h_seq[i - 1] + h = h.detach() + elif i < opt.n_past: + h, _ = h_seq[i - 1] + h = h.detach() + if i < opt.n_past: + frame_predictor(h) + x_in = x[i] + gen_seq[s].append(x_in) + else: + h = frame_predictor(h).detach() + x_in = decoder([h, skip]).detach() + gen_seq[s].append(x_in) + + to_plot = [] + gifs = [[] for t in range(opt.n_eval)] + nrow = min(opt.batch_size, 25) + for i in range(nrow): + # ground truth sequence + row = [] + for t in range(opt.n_eval): + row.append(gt_seq[t][i]) + to_plot.append(row) + + for s in range(nsample): + row = [] + for t in range(opt.n_eval): + row.append(gen_seq[s][t][i]) + to_plot.append(row) + for t in range(opt.n_eval): + row = [] + row.append(gt_seq[t][i]) + for s in range(nsample): + row.append(gen_seq[s][t][i]) + gifs[t].append(row) + + fname = '%s/gen/sample_%d.png' % (opt.log_dir, epoch) + utils.save_tensors_image(fname, to_plot) + + fname = '%s/gen/sample_%d.gif' % (opt.log_dir, epoch) + utils.save_gif(fname, gifs) + + +def plot_rec(x, epoch): + frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + gen_seq = [] + gen_seq.append(x[0]) + x_in = x[0] + h_seq = [encoder(x[i]) for i in range(opt.n_past + opt.n_future)] + for i in range(1, opt.n_past + opt.n_future): + h_target = h_seq[i][0].detach() + if opt.last_frame_skip or i < opt.n_past: + h, skip = h_seq[i - 1] + else: + h, _ = h_seq[i - 1] + h = h.detach() + if i < opt.n_past: + frame_predictor(h) + gen_seq.append(x[i]) + else: + h_pred = frame_predictor(h).detach() + x_pred = decoder([h_pred, skip]).detach() + gen_seq.append(x_pred) + + to_plot = [] + nrow = min(opt.batch_size, 25) + for i in range(nrow): + row = [] + for t in range(opt.n_past + opt.n_future): + row.append(gen_seq[t][i]) + to_plot.append(row) + fname = '%s/gen/rec_%d.png' % (opt.log_dir, epoch) + utils.save_tensors_image(fname, to_plot) + + +def get_center_maximal_contour(img, draw=False): + if len(img.shape) == 3: # bgr 2 gray + img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + + # Find contours + contours, hierarchy = cv2.findContours(img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + if len(contours) == 0: + return None + max_contour = max(contours, key=cv2.contourArea) + if cv2.contourArea(max_contour) == 0: + # print(max_contour) + # print(tuple(max_contour[0][0])) + if draw: + cv2.circle(img, tuple(max_contour[0][0]), 1, 150, 1) + return tuple(max_contour[0][0]) + M = cv2.moments(max_contour) + cx, cy = int(M['m10'] / M['m00']), int(M['m01'] / M['m00']) + if draw: + cv2.circle(img, (cx, cy), 1, 150, 1) + return (cx, cy) + + +def high_pass(frame, min=10): + fr = frame.copy() + fr[fr < min] = 0 # remove low brightness pixels + if np.sum(fr) != 0: + frame = fr + + frame = cv2.medianBlur(frame, 5) + # frame = get_maximal_contour(frame) + return frame + + +def do_implasubility_test(z_residual_mean, z_residual_cov, thresh, visualize=True): + MOTION_THRESH = 0.001 + epoch_size = 200 // opt.batch_size # we have a total of 1000 videos + cov_inv = [np.linalg.pinv(cov, hermitian=True) for cov in + z_residual_cov] # we assume the covariance matrix is diagonal and do the inverse for each time + # cov_inv = [np.linalg.inv(np.diag(np.diag(cov))) for cov in z_residual_cov] + # cov_inv = [np.eye(32) for cov in z_residual_cov] + frame_predictor.eval() + posterior.eval() + encoder.eval() + decoder.eval() + progress = progressbar.ProgressBar(max_value=epoch_size).start() + confusion_matrix = [[0, 0], [0, 0]] + # for i in range(50): + # if i % 2 == 0: + # x = next(training_batch_generator) + # else: + # x = next(training_batch_generator_implausible) + for i in range(epoch_size): + h_residual_var = torch.tensor(np.zeros(opt.n_future, dtype=np.float32), requires_grad=False, + device=torch.device('cpu')) + scores = np.array([0 for i in range(opt.n_future)], dtype=np.float32) + progress.update(i + 1) + is_implausible = False + try: + x, labels = next(training_batch_generator) + x_diff = [torch.abs(x[t] - x[t - 1]).detach() for t in range(1, len(x))] + x_diff = [torch.mean(frame, dim=(1, 2, 3)) for frame in x_diff] # mean along C, H, W + motion_start_time = [] + + # find when motion starts (object is about to be dropped) + for batch in range(len(x_diff[0])): + started = False + for t in range(30, len(x_diff)): + if x_diff[t][batch] > MOTION_THRESH: + motion_start_time.append( + t + 1 + 1) # add one because x_diff starts at t=1, then add another frame + started = True + break + if not started: + motion_start_time.append(33) # default to frame 33 if no motion is found + frames = [frame.cpu() for frame in x] + # print(frames[0][0].numpy().dtype) + # print(frames[0][0]) + # quit() + except TypeError: + print('got None at i = {}, terminating'.format(i)) + break + h_posterior = [encoder(x[j]) for j in range(0, opt.n_past + opt.n_future)] + # print(h_posterior[0][0].size()) + last_pred = None + frame_predictor.hidden = frame_predictor.init_hidden() + prior.hidden = prior.init_hidden() + posterior.hidden = posterior.init_hidden() + + # start is the first frame that the prior model sees (so start + n_past is the first frame predicted) + if motion_start_time[0] is not None: + start = motion_start_time[0] - 4 + x_in = x[start] + x_out_seq = [x[start].cpu()] + for j in range(start + 1, opt.n_past + opt.n_future + 30): + h = encoder(x_in) + if opt.last_frame_skip or j < opt.n_past + start - 1: + h, skip = h + else: + h, _ = h + + if j < opt.n_past + start + 5: + z_t = posterior(h_posterior[j][0].detach()) + prior(h) + h_post = frame_predictor(torch.cat([h, z_t], 1)) + x_in = x[j] + x_out_seq.append(decoder([h_post, skip]).detach().cpu()) + else: + z_t_hat = prior(h) + h = frame_predictor(torch.cat([h, z_t_hat], 1)).detach() + x_in = decoder([h, skip]) + x_out_seq.append(x_in.detach().cpu()) + + if j >= opt.n_past + opt.n_future and torch.mean(torch.abs(x_out_seq[-1] - x_out_seq[-2]), + dim=(1, 2, 3)) <= MOTION_THRESH / 3: + # print('broke at j= ', j) + break + + # z_residual_scores_filtered = -0.25 * scores[:-2] + (0.5 + 0.5) * scores[1:-1] - 0.25 * scores[2:] + # print(motion_start_time[0], len(frames), len(x_out_seq)) + # print(h_residual_var) + background = frames[0][0][0].numpy().copy() + background = np.uint8(np.minimum(background / 1.2, 1.0) * 255.) + if visualize and motion_start_time[0] is not None: + k = 0 + j = start + source_center = None + pred_center = None + while not (j >= len(frames) and k >= len(x_out_seq)): + frame_cv2 = frames[min(j, len(frames) - 1)][0][0].numpy().copy() + frame_cv2 /= 1.2 + frame_cv2 = np.uint8(np.minimum(frame_cv2, 1.0) * 255.) + source_diff = high_pass(cv2.absdiff(frame_cv2, background)) + source_center = get_center_maximal_contour(source_diff, draw=True) + cv2.imshow('source', cv2.resize(frame_cv2, (384, 384), interpolation=cv2.INTER_NEAREST)) + cv2.imshow('source diff', cv2.resize(source_diff, (384, 384), interpolation=cv2.INTER_NEAREST)) + + out_cv2 = x_out_seq[min(k, len(x_out_seq) - 1)][0][0].numpy().copy() + out_cv2 /= 1.2 + # out_cv2 /= np.max(out_cv2) + out_cv2 = np.uint8(np.minimum(out_cv2, 1.0) * 255.) + pred_diff = high_pass(cv2.absdiff(out_cv2, background)) + pred_center = get_center_maximal_contour(pred_diff, draw=True) + cv2.imshow('prediction from first 5 frames', + cv2.resize(out_cv2, (384, 384), interpolation=cv2.INTER_NEAREST)) + cv2.imshow('pred diff', + cv2.resize(pred_diff, (384, 384), interpolation=cv2.INTER_NEAREST)) + cv2.waitKey(0) + j += 1 + k += 1 + if not visualize: + frame_cv2 = frames[-1][0][0].numpy().copy() + frame_cv2 = np.uint8(np.minimum(frame_cv2 / 1.2, 1.0) * 255.) + source_diff = high_pass(cv2.absdiff(frame_cv2, background)) + source_center = get_center_maximal_contour(source_diff, draw=True) + out_cv2 = x_out_seq[-1][0][0].numpy().copy() + out_cv2 = np.uint8(np.minimum(out_cv2 / 1.2, 1.0) * 255.) + pred_diff = high_pass(cv2.absdiff(out_cv2, background)) + pred_center = get_center_maximal_contour(pred_diff, draw=True) + + if (source_center is not None) and (pred_center is not None): + is_implausible = np.abs(source_center[-1] - pred_center[-1]) > thresh # a good thresh seems to be + + if is_implausible: + msg = 'implausible' + else: + msg = 'plausible' + # print('gt: ', labels, 'prediction: ', msg) + confusion_matrix[int(labels[0] == 'implausible')][int(is_implausible)] += 1 + if visualize: + cv2.waitKey(0) + + + # percentile = np.percentile(scores[76:152], 85.0) + # thresh = percentile * 1 + 0.5 + # spikes_idx = np.argwhere(z_residual_scores_filtered > thresh) + # spikes_idx = spikes_idx[ + # (spikes_idx >= 75) & (spikes_idx <= 150)] # ignore spikes near the start and end of video + # spikes = z_residual_scores_filtered[spikes_idx] + # msg = '' + # if len(spikes_idx) > 0: + # # we add n_past because the first n_past frames are not counted. Add 1 because of the filtering + # msg = 'thresh {:.1f} IMPLAUSIBLE spikes: '.format(thresh) + str( + # ['{:.2f}@{}'.format(z_residual_scores_filtered[k], k + opt.n_past + 1) for k in spikes_idx]) + # confusion_matrix[is_implausible][1] += 1 + # else: + # max_idx = np.argmax(z_residual_scores_filtered[75:151]) + 75 + # msg = 'thresh {:.1f} PLAUSIBLE max {:.2f}@{}'.format(thresh, z_residual_scores_filtered[max_idx], + # max_idx + opt.n_past + 1) + # confusion_matrix[is_implausible][0] += 1 + # + # print(msg) + # plt.savefig('implausibility_test.png') + # h_residual_var /= epoch_size # get the mean error vector per time + # h_residual_sd = torch.sqrt(h_residual_var) + print('Last i = {}'.format(i)) + + # H_err_cov = torch.tensor(np.zeros((opt.n_future, 128, 128), dtype=np.float32), requires_grad=False, + # device=torch.device('cuda:0')) + + # plot some stuff + return confusion_matrix + + +# f = open('new_mcs_stats_post.json', 'r') +# mcs_stats_dict = json.load(f) +mcs_stats_dict = {} +with open('new_mcs_stats_post.npy', 'rb') as f: + mcs_stats_dict['mean'] = np.load(f) + mcs_stats_dict['cov'] = np.load(f) + +ROC_curve = {} +for thr in range(2, 11): + train_data, test_data = utils.load_dataset(opt, sequential=True, implausible=False) + + train_loader = DataLoader(train_data, + num_workers=opt.data_threads, + batch_size=opt.batch_size, + shuffle=False, + drop_last=True, + pin_memory=True, ) + training_batch_generator = get_training_batch() + + conf_mat = do_implasubility_test(np.array(mcs_stats_dict['mean']), np.array(mcs_stats_dict['cov']), thr, visualize=True) + ROC_curve[thr] = conf_mat +print(ROC_curve) diff --git a/do_mcs_implausblility_test_posterior.py b/do_mcs_implausblility_test_posterior.py index de00033..a4eb013 100644 --- a/do_mcs_implausblility_test_posterior.py +++ b/do_mcs_implausblility_test_posterior.py @@ -56,14 +56,16 @@ latest_model = sorted(models, key=lambda s: int(s[s.rfind('_e') + 2: s.rfind('.pth')]), reverse=True)[0] print('Loading model ', latest_model) saved_model = torch.load(latest_model) - optimizer = opt.optimizer - model_dir = opt.model_dir niter = opt.niter + dataset = opt.dataset + mcs_task = opt.mcs_task + n_future = opt.n_future opt = saved_model['opt'] + opt.batch_size = BATCH_SIZE opt.niter = niter # update number of epochs to train for - opt.optimizer = optimizer - opt.model_dir = model_dir - opt.log_dir = '%s/continued' % opt.log_dir + opt.dataset = dataset + opt.mcs_task = mcs_task + opt.n_future = n_future else: raise ValueError("Please specify the model to load with the --model_dir argument") diff --git a/train_svg_nonstochastic.py b/train_svg_nonstochastic.py index ce2adda..2ad20ab 100644 --- a/train_svg_nonstochastic.py +++ b/train_svg_nonstochastic.py @@ -51,12 +51,10 @@ latest_model = sorted(models, key=lambda s: int(s[s.rfind('_e') + 2: s.rfind('.pth')]), reverse=True)[0] print('Loading model ', latest_model) saved_model = torch.load(latest_model) - optimizer = opt.optimizer model_dir = opt.model_dir niter = opt.niter opt = saved_model['opt'] opt.niter = niter # update number of epochs to train for - opt.optimizer = optimizer opt.model_dir = model_dir opt.log_dir = '%s/continued' % opt.log_dir else: diff --git a/train_svg_nonstochastic_posterior.py b/train_svg_nonstochastic_posterior.py index f35b9c2..bc8478e 100644 --- a/train_svg_nonstochastic_posterior.py +++ b/train_svg_nonstochastic_posterior.py @@ -55,7 +55,6 @@ latest_model = sorted(models, key=lambda s: int(s[s.rfind('_e') + 2: s.rfind('.pth')]), reverse=True)[0] print('Loading model ', latest_model) saved_model = torch.load(latest_model) - optimizer = opt.optimizer model_dir = opt.model_dir niter = opt.niter lr = opt.lr @@ -64,7 +63,6 @@ n_eval = opt.n_eval opt = saved_model['opt'] opt.niter = niter # update number of epochs to train for - opt.optimizer = optimizer opt.model_dir = model_dir opt.n_future = n_future opt.lr = lr diff --git a/utils.py b/utils.py index 22f704c..6c29971 100755 --- a/utils.py +++ b/utils.py @@ -130,7 +130,8 @@ def load_dataset(opt, sequential=None, implausible=None): implausible=implausible, test_set=True, im_channels=opt.channels, - use_edge_kernels=opt.use_edge_kernels) + use_edge_kernels=opt.use_edge_kernels, + labels=True) test_data = MCS( train=False, data_root=opt.data_root, @@ -141,7 +142,8 @@ def load_dataset(opt, sequential=None, implausible=None): implausible=implausible, test_set=True, im_channels=opt.channels, - use_edge_kernels=opt.use_edge_kernels) + use_edge_kernels=opt.use_edge_kernels, + labels=True) return train_data, test_data @@ -151,7 +153,8 @@ def sequence_input(seq, dtype): def normalize_data(opt, dtype, sequence): - if opt.dataset == 'smmnist' or opt.dataset == 'kth' or opt.dataset == 'bair' or opt.dataset == 'mcs': + if opt.dataset == 'smmnist' or opt.dataset == 'kth' or opt.dataset == 'bair' or opt.dataset == 'mcs'\ + or opt.dataset == 'mcs_test': sequence.transpose_(0, 1) sequence.transpose_(3, 4).transpose_(2, 3) else: From defb8fe8ee63920f1c731714a7959762b98c1635 Mon Sep 17 00:00:00 2001 From: Fred Lu Date: Tue, 17 Aug 2021 11:33:02 -0400 Subject: [PATCH 08/26] object perm and collision update --- ... => _do_mcs_implausblility_test_gravity.py | 0 data/mcs.py | 30 +- do_mcs_implausblility_test_gravity_v2.py | 461 ++++++++++++++++++ train_baseline_collision.py | 429 ++++++++++++++++ train_baseline_object_permanence.py | 445 +++++++++++++++++ utils.py | 10 +- 6 files changed, 1364 insertions(+), 11 deletions(-) rename do_mcs_implausblility_test_gravity.py => _do_mcs_implausblility_test_gravity.py (100%) create mode 100644 do_mcs_implausblility_test_gravity_v2.py create mode 100644 train_baseline_collision.py create mode 100644 train_baseline_object_permanence.py diff --git a/do_mcs_implausblility_test_gravity.py b/_do_mcs_implausblility_test_gravity.py similarity index 100% rename from do_mcs_implausblility_test_gravity.py rename to _do_mcs_implausblility_test_gravity.py diff --git a/data/mcs.py b/data/mcs.py index d801c42..3d0fe68 100644 --- a/data/mcs.py +++ b/data/mcs.py @@ -14,7 +14,7 @@ class MCS(object): def __init__(self, train, data_root, seq_len=20, image_size=64, task='ALL', sequential=None, implausible=False, - test_set=False, im_channels=1, use_edge_kernels=True, labels=False): + test_set=False, im_channels=1, use_edge_kernels=True, labels=False, start_min=None, start_max=None, sequence_stride=None): # if implausible is set to True, generates "fake" images by cutting out or repeating frames self.implausible = implausible if test_set: @@ -31,6 +31,9 @@ def __init__(self, train, data_root, seq_len=20, image_size=64, task='ALL', sequ if im_channels != 1: raise AssertionError('Using edge kernels implies the output images are grayscale! Set im_channels to 1!') self.use_edge_kernels = use_edge_kernels + self.start_min = start_min + self.start_max = start_max + self.sequence_stride = sequence_stride # print('mcs.py: found tasks ', self.tasks) self.video_folder = {} @@ -48,7 +51,7 @@ def __init__(self, train, data_root, seq_len=20, image_size=64, task='ALL', sequ self.sequential = sequential # if set to true, return videos in sequence def get_sequence(self, idx=None): - + stride = max(1, self.sequence_stride) if not self.sequential: task = random.choice(self.tasks) vid = random.choice(self.video_folder[task]) @@ -61,16 +64,25 @@ def get_sequence(self, idx=None): if idx >= self.len_video_folder[task]: return None # we've run out of videos vid = self.video_folder[task][idx] - num_frames = len(next(os.walk(path.join(self.data_root, task, vid)))[2]) + num_frames = len(next(os.walk(path.join(self.data_root, task, vid)))[2]) # how many frames we can use after considering the stride frame_path = path.join(self.data_root, task, vid, vid + '_') label = str(os.path.basename(vid)) label = label[label.rfind('_') + 1:] - if num_frames - self.seq_len < 0: - return None - start = random.randint(0, num_frames - self.seq_len) + + start_min = 0 + start_max = num_frames - 1 - (self.seq_len - 1) * stride + if start_max < 0: + raise ValueError("Number of frames in the dataset less than the desired sequence length!") + if self.start_min: + start_min = self.start_min + if self.start_max: + assert self.start_max <= start_max # so the sequence doesn't start too late + start_max = self.start_max + assert start_min <= start_max + start = random.randint(start_min, start_max) seq = [] # choose a random subsequence of frames in the selected video - for i in range(start, start + self.seq_len): + for i in range(start, start + self.seq_len * stride, stride): # i is 0-indexed so we need to add 1 to i fname = frame_path + f'{i + 1:04d}.png' im = imageio.imread(fname) / np.float32(255.) @@ -97,8 +109,8 @@ def get_sequence(self, idx=None): # abs_grad_x = np.abs(grad_x) # abs_grad_y = np.abs(grad_y) edge_map += np.sqrt(grad_x**2 + grad_y**2) - edge_map /= 3*2 # 3 channels and two directions per channel - edge_map /= 6 # to reduce magnitude + edge_map /= 3 # 3 channels + edge_map /= 12 # to reduce magnitude im = edge_map[..., np.newaxis] seq.append(im) diff --git a/do_mcs_implausblility_test_gravity_v2.py b/do_mcs_implausblility_test_gravity_v2.py new file mode 100644 index 0000000..6c7c1c6 --- /dev/null +++ b/do_mcs_implausblility_test_gravity_v2.py @@ -0,0 +1,461 @@ +import glob +from typing import List + +import cv2 +import torch +import torch.optim as optim +import torch.nn as nn +import argparse +import os +import random + +from shapely.geometry import Polygon +from torch.autograd import Variable +from torch.utils.data import DataLoader +import utils +import itertools +import progressbar +import numpy as np +import matplotlib.pyplot as plt +import json + +parser = argparse.ArgumentParser() +parser.add_argument('--lr', default=0.004, type=float, help='learning rate') +parser.add_argument('--beta1', default=0.9, type=float, help='momentum term for adam') +parser.add_argument('--batch_size', default=1, type=int, help='batch size') +parser.add_argument('--log_dir', default='logs/nonstochastic', help='base directory to save logs') +parser.add_argument('--model_dir', default='', help='base directory to save logs') +parser.add_argument('--name', default='', help='identifier for directory') +parser.add_argument('--data_root', default='data', help='root directory for data') +parser.add_argument('--optimizer', default='adam', help='optimizer to train with') +parser.add_argument('--niter', type=int, default=300, help='number of epochs to train for') +parser.add_argument('--seed', default=1, type=int, help='manual seed') +parser.add_argument('--epoch_size', type=int, default=1000, help='epoch size') +parser.add_argument('--image_width', type=int, default=64, help='the height / width of the input image to network') +parser.add_argument('--channels', default=1, type=int) +parser.add_argument('--use_edge_kernels', action='store_true') +parser.add_argument('--dataset', default='mcs_test', help='dataset to train with') +parser.add_argument('--mcs_task', default='GravitySupportEvaluation', help='mcs task') +parser.add_argument('--n_past', type=int, default=5, help='number of frames to condition on') +parser.add_argument('--n_future', type=int, default=55, help='number of frames to predict') +parser.add_argument('--n_eval', type=int, default=30, help='number of frames to predict at eval time') +parser.add_argument('--rnn_size', type=int, default=256, help='dimensionality of hidden layer') +parser.add_argument('--posterior_rnn_layers', type=int, default=1, help='number of layers') +parser.add_argument('--predictor_rnn_layers', type=int, default=2, help='number of layers') +parser.add_argument('--z_dim', type=int, default=10, help='dimensionality of z_t') +parser.add_argument('--g_dim', type=int, default=128, + help='dimensionality of encoder output vector and decoder input vector') +parser.add_argument('--beta', type=float, default=0.0001, help='weighting on KL to prior') +parser.add_argument('--model', default='dcgan', help='model type (dcgan | vgg)') +parser.add_argument('--data_threads', type=int, default=8, help='number of data loading threads') +parser.add_argument('--num_digits', type=int, default=2, help='number of digits for moving mnist') +parser.add_argument('--last_frame_skip', action='store_true', + help='if true, skip connections go between frame t and frame t+t rather than last ground truth frame') +opt = parser.parse_args() +BATCH_SIZE = opt.batch_size +saved_model = None +if opt.model_dir != '': + models = glob.glob(f'{opt.model_dir}/model_*.pth') + latest_model = sorted(models, key=lambda s: int(s[s.rfind('_e') + 2: s.rfind('.pth')]), reverse=True)[0] + print('Loading model ', latest_model) + saved_model = torch.load(latest_model) + niter = opt.niter + dataset = opt.dataset + mcs_task = opt.mcs_task + n_future = opt.n_future + opt = saved_model['opt'] + opt.batch_size = BATCH_SIZE + opt.niter = niter # update number of epochs to train for + opt.dataset = dataset + opt.mcs_task = mcs_task + opt.n_future = n_future +else: + raise ValueError("Please specify the model to load with the --model_dir argument") + +font = cv2.FONT_HERSHEY_SIMPLEX +bottomLeftCornerOfText = (30, 30) +fontScale = 0.6 +fontColor = (0, 0, 0) +thickness = 1 + +print("Random Seed: ", opt.seed) +random.seed(opt.seed) +torch.manual_seed(opt.seed) +torch.cuda.manual_seed_all(opt.seed) +dtype = torch.cuda.FloatTensor + +import models.lstm as lstm_models + +if opt.model_dir != '': + frame_predictor = saved_model['frame_predictor'] + frame_predictor.batch_size = BATCH_SIZE + posterior = saved_model['posterior'] + posterior.batch_size = BATCH_SIZE + prior = saved_model['prior'] + prior.batch_size = BATCH_SIZE +else: + raise ValueError('Please specify --model_dir') + +if opt.model == 'dcgan': + if opt.image_width == 64: + import models.dcgan_64 as model + elif opt.image_width == 128: + import models.dcgan_128 as model +elif opt.model == 'vgg': + if opt.image_width == 64: + import models.vgg_64 as model + elif opt.image_width == 128: + import models.vgg_128 as model +else: + raise ValueError('Unknown model: %s' % opt.model) + +if opt.model_dir != '': + decoder = saved_model['decoder'] + encoder = saved_model['encoder'] +else: + raise ValueError("Please specify the model to load with the --model_dir argument") + +# --------- loss functions ------------------------------------ +mse_criterion = nn.MSELoss() + + +def kl_criterion(mu, logvar): + # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) + KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) + KLD /= opt.batch_size + return KLD + + +# --------- transfer to gpu ------------------------------------ +frame_predictor.cuda() +posterior.cuda() +encoder.cuda() +decoder.cuda() +mse_criterion.cuda() + +print(opt) + +# --------- load a dataset ------------------------------------ +train_data, test_data = utils.load_dataset(opt, sequential=True, implausible=False) + +train_loader = DataLoader(train_data, + num_workers=opt.data_threads, + batch_size=opt.batch_size, + shuffle=False, + drop_last=True, + pin_memory=True, ) +test_loader = DataLoader(test_data, + num_workers=opt.data_threads, + batch_size=opt.batch_size, + shuffle=False, + drop_last=True, + pin_memory=True) + + +def get_training_batch(): + while True: + for sequence, labels in train_loader: + batch = utils.normalize_data(opt, dtype, sequence) + yield batch, labels + + +training_batch_generator = get_training_batch() + + +def get_testing_batch(): + while True: + for sequence, labels in test_loader: + batch = utils.normalize_data(opt, dtype, sequence) + yield batch, labels + + +testing_batch_generator = get_testing_batch() + + +def get_center_maximal_contour(img, draw=False): + if len(img.shape) == 3: # bgr 2 gray + img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + + # Find contours + contours, hierarchy = cv2.findContours(img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + if len(contours) == 0: + return None + max_contour = max(contours, key=cv2.contourArea) + if cv2.contourArea(max_contour) == 0: + # print(max_contour) + # print(tuple(max_contour[0][0])) + if draw: + cv2.circle(img, tuple(max_contour[0][0]), 1, 150, 1) + return tuple(max_contour[0][0]) + M = cv2.moments(max_contour) + cx, cy = int(M['m10'] / M['m00']), int(M['m01'] / M['m00']) + if draw: + cv2.circle(img, (cx, cy), 1, 150, 1) + return (cx, cy) + + +def get_polygons_from_img(img, top_n_contours=2): + img = high_pass(img.copy(), min=9, rad=7) + contours, hierarchy = cv2.findContours(img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + contours = sorted(contours, key=cv2.contourArea, reverse=True) + contours = contours[:top_n_contours] + img = cv2.drawContours(img, contours, -1, 180, 1) + # contours = filter(lambda x: cv2.contourArea(x) > 0, contours) + polys = [] + for cnt in contours: + if cv2.contourArea(cnt) == 0: + continue + epsilon = 0.01 * cv2.arcLength(cnt, True) + poly_approx = cv2.approxPolyDP(cnt, epsilon, True) # shape: NUM_POINTS, 1, 2 (x and y locations) + poly_approx = np.squeeze(poly_approx, axis=1) # compress the second dimension -> (NUM_POINTS, 2) + try: + poly_approx = Polygon(poly_approx) + except ValueError: + print('too few vertices for polygon. original list: ', poly_approx) + idx = np.random.choice(list(range(len(poly_approx))), 3, replace=True) + poly_approx = Polygon(poly_approx[idx]) + print('resampled list: ', poly_approx) + # poly_approx -= get_center_of_mass(poly_approx) # center the COM at origin + c = poly_approx.centroid.coords[0] + cv2.circle(img, (round(c[0]), round(c[1])), 1, 150, 1) + polys.append(poly_approx) + return polys, img + + +def get_implausibility_score(poly_list1: List[Polygon], poly_list2: List[Polygon], thresh=5): + # thresh: when max of min distance > thresh, implausibility score > 0.5 + min_dists = [] + for p1 in poly_list1: + p1_centeroid = p1.centroid + min_dist = 100000000000 + for p2 in poly_list2: + min_dist = min(min_dist, p1_centeroid.distance(p2.centroid)) + min_dists.append(min_dist) + + poly_list1, poly_list2 = poly_list2, poly_list1 + for p1 in poly_list1: + p1_centeroid = p1.centroid + min_dist = 100000000000 + for p2 in poly_list2: + min_dist = min(min_dist, p1_centeroid.distance(p2.centroid)) + min_dists.append(min_dist) + if len(min_dists) == 0: + return 0 # if both the source and prediction has nothing going on, we say it's plausible + max_min_dist = max(min_dists) + # plausibility = 1 / (max(3, max_min_dist) - 2) # R+ -> [0,1] + # return 1 - plausibility + x = (max_min_dist - thresh) / 2 # this means that when MMD >= 4, we have implausibility score >= 0.5 and <0.5 if not. + implausibility = 1 / (1 + np.exp(-x)) + return implausibility + + +def high_pass(frame, min=10, rad=5): + fr = frame.copy() + fr[fr < min] = 0 # remove low brightness pixels + if np.sum(fr) != 0: + frame = fr + + frame = cv2.medianBlur(frame, rad) + # frame = get_maximal_contour(frame) + return frame + + +def do_implasubility_test(z_residual_mean, z_residual_cov, thresh, visualize=True): + MOTION_THRESH = 0.001 + epoch_size = 200 // opt.batch_size # we have a total of 1000 videos + cov_inv = [np.linalg.pinv(cov, hermitian=True) for cov in + z_residual_cov] # we assume the covariance matrix is diagonal and do the inverse for each time + # cov_inv = [np.linalg.inv(np.diag(np.diag(cov))) for cov in z_residual_cov] + # cov_inv = [np.eye(32) for cov in z_residual_cov] + frame_predictor.eval() + posterior.eval() + encoder.eval() + decoder.eval() + progress = progressbar.ProgressBar(max_value=epoch_size).start() + confusion_matrix = [[0, 0], [0, 0]] + # for i in range(50): + # if i % 2 == 0: + # x = next(training_batch_generator) + # else: + # x = next(training_batch_generator_implausible) + for i in range(epoch_size): + h_residual_var = torch.tensor(np.zeros(opt.n_future, dtype=np.float32), requires_grad=False, + device=torch.device('cpu')) + scores = np.array([0 for i in range(opt.n_future)], dtype=np.float32) + progress.update(i + 1) + is_implausible = False + try: + x, labels = next(training_batch_generator) + x_diff = [torch.abs(x[t] - x[t - 1]).detach() for t in range(1, len(x))] + x_diff = [torch.mean(frame, dim=(1, 2, 3)) for frame in x_diff] # mean along C, H, W + motion_start_time = [] + + # find when motion starts (object is about to be dropped) + for batch in range(len(x_diff[0])): + started = False + for t in range(30, len(x_diff)): + if x_diff[t][batch] > MOTION_THRESH: + motion_start_time.append( + t + 1 + 1) # add one because x_diff starts at t=1, then add another frame + started = True + break + if not started: + motion_start_time.append(33) # default to frame 33 if no motion is found + frames = [frame.cpu() for frame in x] + # print(frames[0][0].numpy().dtype) + # print(frames[0][0]) + # quit() + except TypeError: + print('got None at i = {}, terminating'.format(i)) + break + h_posterior = [encoder(x[j]) for j in range(0, opt.n_past + opt.n_future)] + # print(h_posterior[0][0].size()) + last_pred = None + frame_predictor.hidden = frame_predictor.init_hidden() + prior.hidden = prior.init_hidden() + posterior.hidden = posterior.init_hidden() + + # start is the first frame that the prior model sees (so start + n_past is the first frame predicted) + if motion_start_time[0] is not None: + start = motion_start_time[0] - 4 + x_in = x[start] + x_out_seq = [x[start].cpu()] + for j in range(start + 1, opt.n_past + opt.n_future + 30): + h = encoder(x_in) + if opt.last_frame_skip or j < opt.n_past + start - 1: + h, skip = h + else: + h, _ = h + + if j < opt.n_past + start + 5: + z_t = posterior(h_posterior[j][0].detach()) + prior(h) + h_post = frame_predictor(torch.cat([h, z_t], 1)) + x_in = x[j] + x_out_seq.append(decoder([h_post, skip]).detach().cpu()) + else: + z_t_hat = prior(h) + h = frame_predictor(torch.cat([h, z_t_hat], 1)).detach() + x_in = decoder([h, skip]) + x_out_seq.append(x_in.detach().cpu()) + + if j >= opt.n_past + opt.n_future and torch.mean(torch.abs(x_out_seq[-1] - x_out_seq[-2]), + dim=(1, 2, 3)) <= MOTION_THRESH / 3: + # print('broke at j= ', j) + break + + # z_residual_scores_filtered = -0.25 * scores[:-2] + (0.5 + 0.5) * scores[1:-1] - 0.25 * scores[2:] + # print(motion_start_time[0], len(frames), len(x_out_seq)) + # print(h_residual_var) + background = frames[0][0][0].numpy().copy() + background = np.uint8(np.minimum(background / 1.2, 1.0) * 255.) + if visualize and motion_start_time[0] is not None: + k = 0 + j = start + source_center = None + pred_center = None + while not (j >= len(frames) and k >= len(x_out_seq)): + frame_cv2 = frames[min(j, len(frames) - 1)][0][0].numpy().copy() + frame_cv2 /= 1.2 + frame_cv2 = np.uint8(np.minimum(frame_cv2, 1.0) * 255.) + source_diff = cv2.absdiff(frame_cv2, background) + polys_source, source_diff_marked = get_polygons_from_img(source_diff) + # source_center = get_center_maximal_contour(source_diff, draw=True) + cv2.imshow('source', cv2.resize(frame_cv2, (384, 384), interpolation=cv2.INTER_NEAREST)) + cv2.imshow('source diff', cv2.resize(source_diff_marked, (384, 384), interpolation=cv2.INTER_NEAREST)) + + out_cv2 = x_out_seq[min(k, len(x_out_seq) - 1)][0][0].numpy().copy() + out_cv2 /= 1.2 + # out_cv2 /= np.max(out_cv2) + out_cv2 = np.uint8(np.minimum(out_cv2, 1.0) * 255.) + pred_diff = cv2.absdiff(out_cv2, background) + polys_pred, pred_diff_marked = get_polygons_from_img(pred_diff) + # pred_center = get_center_maximal_contour(pred_diff, draw=True) + cv2.imshow('prediction from first 5 frames', + cv2.resize(out_cv2, (384, 384), interpolation=cv2.INTER_NEAREST)) + cv2.imshow('pred diff', + cv2.resize(pred_diff_marked, (384, 384), interpolation=cv2.INTER_NEAREST)) + src_pred_diff = cv2.absdiff(source_diff, pred_diff) + imp_score = get_implausibility_score(polys_source, polys_pred) + src_pred_diff = cv2.resize(src_pred_diff, (384, 384), interpolation=cv2.INTER_NEAREST) + cv2.putText(src_pred_diff, str(imp_score), (20, 20), font, fontScale, 128, thickness) + cv2.imshow('source-pred diff', src_pred_diff) + cv2.waitKey(0) + j += 1 + k += 1 + if not visualize: + frame_cv2 = frames[-1][0][0].numpy().copy() + frame_cv2 = np.uint8(np.minimum(frame_cv2 / 1.2, 1.0) * 255.) + source_diff = high_pass(cv2.absdiff(frame_cv2, background)) + source_center = get_center_maximal_contour(source_diff, draw=True) + out_cv2 = x_out_seq[-1][0][0].numpy().copy() + out_cv2 = np.uint8(np.minimum(out_cv2 / 1.2, 1.0) * 255.) + pred_diff = high_pass(cv2.absdiff(out_cv2, background)) + pred_center = get_center_maximal_contour(pred_diff, draw=True) + + if (source_center is not None) and (pred_center is not None): + is_implausible = np.abs(source_center[-1] - pred_center[-1]) > thresh # a good thresh seems to be + + if is_implausible: + msg = 'implausible' + else: + msg = 'plausible' + # print('gt: ', labels, 'prediction: ', msg) + confusion_matrix[int(labels[0] == 'implausible')][int(is_implausible)] += 1 + if visualize: + cv2.waitKey(0) + + + # percentile = np.percentile(scores[76:152], 85.0) + # thresh = percentile * 1 + 0.5 + # spikes_idx = np.argwhere(z_residual_scores_filtered > thresh) + # spikes_idx = spikes_idx[ + # (spikes_idx >= 75) & (spikes_idx <= 150)] # ignore spikes near the start and end of video + # spikes = z_residual_scores_filtered[spikes_idx] + # msg = '' + # if len(spikes_idx) > 0: + # # we add n_past because the first n_past frames are not counted. Add 1 because of the filtering + # msg = 'thresh {:.1f} IMPLAUSIBLE spikes: '.format(thresh) + str( + # ['{:.2f}@{}'.format(z_residual_scores_filtered[k], k + opt.n_past + 1) for k in spikes_idx]) + # confusion_matrix[is_implausible][1] += 1 + # else: + # max_idx = np.argmax(z_residual_scores_filtered[75:151]) + 75 + # msg = 'thresh {:.1f} PLAUSIBLE max {:.2f}@{}'.format(thresh, z_residual_scores_filtered[max_idx], + # max_idx + opt.n_past + 1) + # confusion_matrix[is_implausible][0] += 1 + # + # print(msg) + # plt.savefig('implausibility_test.png') + # h_residual_var /= epoch_size # get the mean error vector per time + # h_residual_sd = torch.sqrt(h_residual_var) + print('Last i = {}'.format(i)) + + # H_err_cov = torch.tensor(np.zeros((opt.n_future, 128, 128), dtype=np.float32), requires_grad=False, + # device=torch.device('cuda:0')) + + # plot some stuff + return confusion_matrix + + +# f = open('new_mcs_stats_post.json', 'r') +# mcs_stats_dict = json.load(f) +mcs_stats_dict = {} +with open('new_mcs_stats_post.npy', 'rb') as f: + mcs_stats_dict['mean'] = np.load(f) + mcs_stats_dict['cov'] = np.load(f) + +ROC_curve = {} +for thr in range(2, 11): + train_data, test_data = utils.load_dataset(opt, sequential=True, implausible=False) + + train_loader = DataLoader(train_data, + num_workers=opt.data_threads, + batch_size=opt.batch_size, + shuffle=False, + drop_last=True, + pin_memory=True, ) + training_batch_generator = get_training_batch() + + conf_mat = do_implasubility_test(np.array(mcs_stats_dict['mean']), np.array(mcs_stats_dict['cov']), thr, visualize=True) + ROC_curve[thr] = conf_mat +print(ROC_curve) diff --git a/train_baseline_collision.py b/train_baseline_collision.py new file mode 100644 index 0000000..5d60370 --- /dev/null +++ b/train_baseline_collision.py @@ -0,0 +1,429 @@ +import glob + +import torch +import torch.optim as optim +import torch.nn as nn +import argparse +import os +import random +from torch.autograd import Variable +from torch.utils.data import DataLoader +import utils +import itertools +import progressbar +import numpy as np +import json + +parser = argparse.ArgumentParser() +parser.add_argument('--lr', default=0.0004, type=float, help='learning rate') +parser.add_argument('--beta1', default=0.9, type=float, help='momentum term for adam') +parser.add_argument('--batch_size', default=13, type=int, help='batch size') +parser.add_argument('--log_dir', default='logs/nonstochastic_posterior', help='base directory to save logs') +parser.add_argument('--model_dir', default='', help='base directory to save logs') +parser.add_argument('--name', default='', help='identifier for directory') +parser.add_argument('--data_root', default='data', help='root directory for data') +parser.add_argument('--optimizer', default='adam', help='optimizer to train with') +parser.add_argument('--niter', type=int, default=40, help='number of epochs to train for') +parser.add_argument('--seed', default=1, type=int, help='manual seed') +parser.add_argument('--epoch_size', type=int, default=500, help='epoch size') +parser.add_argument('--image_width', type=int, default=64, help='the height / width of the input image to network') +parser.add_argument('--channels', default=1, type=int, help='number of channels for input images. ') +parser.add_argument('--use_edge_kernels', default=True, type=bool, help='whether to use edge kernels to reduce to 1 channel') +parser.add_argument('--dataset', default='mcs', help='dataset to train with') +parser.add_argument('--mcs_task', default='CollisionTraining', help='mcs task') +parser.add_argument('--n_past', type=int, default=5, help='number of frames to condition on') +parser.add_argument('--n_future', type=int, default=30, help='number of frames to predict') +parser.add_argument('--n_eval', type=int, default=35, help='number of frames to predict at eval time') +parser.add_argument('--start_min', type=int, default=75, help='min starting time for sampling sequence (0-indexed)') +parser.add_argument('--start_max', type=int, default=100, help='max starting time for sampling sequence (0-indexed)') +parser.add_argument('--sequence_stride', type=int, default=2, help='factor for sequence temporal subsampling (int)') +parser.add_argument('--rnn_size', type=int, default=256, help='dimensionality of hidden layer') +parser.add_argument('--prior_rnn_layers', type=int, default=1, help='number of layers') +parser.add_argument('--posterior_rnn_layers', type=int, default=1, help='number of layers') +parser.add_argument('--predictor_rnn_layers', type=int, default=2, help='number of layers') +parser.add_argument('--z_dim', type=int, default=32, help='dimensionality of z_t') +parser.add_argument('--g_dim', type=int, default=128, help='dimensionality of encoder output vector and decoder input vector') +parser.add_argument('--beta', type=float, default=0.0001, help='weighting on KL to prior') +parser.add_argument('--gamma', type=float, default=0.0001, help='weighting on h vs h posterior') +parser.add_argument('--model', default='vgg', help='model type (dcgan | vgg)') +parser.add_argument('--data_threads', type=int, default=12, help='number of data loading threads') +parser.add_argument('--num_digits', type=int, default=2, help='number of digits for moving mnist') +parser.add_argument('--last_frame_skip', action='store_true', help='if true, skip connections go between frame t and frame t+1 rather than last ground truth frame') + + +opt = parser.parse_args() +saved_model = None +if opt.model_dir != '': + models = glob.glob(f'{opt.model_dir}/model_*.pth') + latest_model = sorted(models, key=lambda s: int(s[s.rfind('_e') + 2: s.rfind('.pth')]), reverse=True)[0] + print('Loading model ', latest_model) + saved_model = torch.load(latest_model) + model_dir = opt.model_dir + niter = opt.niter + lr = opt.lr + batch_size = opt.batch_size + n_future = opt.n_future + n_eval = opt.n_eval + opt = saved_model['opt'] + opt.niter = niter # update number of epochs to train for + opt.model_dir = model_dir + opt.n_future = n_future + opt.lr = lr + opt.batch_size = batch_size + opt.n_eval = n_eval + opt.log_dir = '%s/continued_lr%s' % (opt.log_dir, opt.lr) +else: + name = 'model=%s%dx%d-rnn_size=%d-predictor-posterior-rnn_layers=%d-%d-n_past=%d-n_future=%d-lr=%.4f-g_dim=%d-z_dim=%d-last_frame_skip=%d-beta=%.7f-gamma=%.7f%s' % (opt.model, opt.image_width, opt.image_width, opt.rnn_size, opt.predictor_rnn_layers, opt.posterior_rnn_layers, opt.n_past, opt.n_future, opt.lr, opt.g_dim, opt.z_dim, opt.last_frame_skip, opt.beta, opt.gamma, opt.name) + if opt.dataset == 'smmnist': + opt.log_dir = '%s/%s-%d/%s' % (opt.log_dir, opt.dataset, opt.num_digits, name) + elif opt.dataset == 'mcs': + opt.log_dir = '%s/%s/%s/%s' % (opt.log_dir, opt.dataset, opt.mcs_task, name) + else: + opt.log_dir = '%s/%s/%s' % (opt.log_dir, opt.dataset, name) + +os.makedirs('%s/gen/' % opt.log_dir, exist_ok=False) +os.makedirs('%s/plots/' % opt.log_dir, exist_ok=False) +with open(os.path.join(opt.log_dir, 'opt.json'), 'w') as f: + json.dump(opt.__dict__, f, indent=2) + +print("Random Seed: ", opt.seed) +random.seed(opt.seed) +torch.manual_seed(opt.seed) +torch.cuda.manual_seed_all(opt.seed) +dtype = torch.cuda.FloatTensor + + +# ---------------- load the models ---------------- + +print(opt) + +# ---------------- optimizers ---------------- +if opt.optimizer == 'adam': + opt.optimizer = optim.Adam +elif opt.optimizer == 'rmsprop': + opt.optimizer = optim.RMSprop +elif opt.optimizer == 'sgd': + opt.optimizer = optim.SGD +else: + raise ValueError('Unknown optimizer: %s' % opt.optimizer) + +import models.lstm as lstm_models +if opt.model_dir != '': + frame_predictor = saved_model['frame_predictor'] + frame_predictor.batch_size = opt.batch_size + prior = saved_model['prior'] + prior.batch_size = opt.batch_size + posterior = saved_model['posterior'] + posterior.batch_size = opt.batch_size +else: + frame_predictor = lstm_models.lstm(opt.g_dim + opt.z_dim, opt.g_dim, opt.rnn_size, opt.predictor_rnn_layers, opt.batch_size) + posterior = lstm_models.lstm(opt.g_dim, opt.z_dim, opt.rnn_size, opt.posterior_rnn_layers, opt.batch_size) + prior = lstm_models.lstm(opt.g_dim, opt.z_dim, opt.rnn_size, opt.prior_rnn_layers, opt.batch_size) + frame_predictor.apply(utils.init_weights) + posterior.apply(utils.init_weights) + prior.apply(utils.init_weights) + +if opt.model == 'dcgan': + if opt.image_width == 64: + import models.dcgan_64 as model + elif opt.image_width == 128: + import models.dcgan_128 as model +elif opt.model == 'vgg': + if opt.image_width == 64: + import models.vgg_64 as model + elif opt.image_width == 128: + import models.vgg_128 as model +else: + raise ValueError('Unknown model: %s' % opt.model) + +if opt.model_dir != '': + decoder = saved_model['decoder'] + encoder = saved_model['encoder'] +else: + encoder = model.encoder(opt.g_dim, opt.channels) + decoder = model.decoder(opt.g_dim, 1) + encoder.apply(utils.init_weights) + decoder.apply(utils.init_weights) + +frame_predictor_optimizer = opt.optimizer(frame_predictor.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) +posterior_optimizer = opt.optimizer(posterior.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) +prior_optimizer = opt.optimizer(prior.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) +encoder_optimizer = opt.optimizer(encoder.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) +decoder_optimizer = opt.optimizer(decoder.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + +# --------- loss functions ------------------------------------ +mse_criterion = nn.MSELoss() +def kl_criterion(mu, logvar): + # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) + KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) + KLD /= opt.batch_size + return KLD + + +# --------- transfer to gpu ------------------------------------ +frame_predictor.cuda() +posterior.cuda() +prior.cuda() +encoder.cuda() +decoder.cuda() +mse_criterion.cuda() + +# --------- load a dataset ------------------------------------ +train_data, test_data = utils.load_dataset(opt) + +train_loader = DataLoader(train_data, + num_workers=opt.data_threads, + batch_size=opt.batch_size, + shuffle=True, + drop_last=True, + pin_memory=True,) +test_loader = DataLoader(test_data, + num_workers=opt.data_threads, + batch_size=opt.batch_size, + shuffle=True, + drop_last=True, + pin_memory=True) + +def get_training_batch(): + while True: + for sequence in train_loader: + batch = utils.normalize_data(opt, dtype, sequence) + yield batch +training_batch_generator = get_training_batch() + +def get_testing_batch(): + while True: + for sequence in test_loader: + batch = utils.normalize_data(opt, dtype, sequence) + yield batch +testing_batch_generator = get_testing_batch() + +# --------- plotting funtions ------------------------------------ +def plot(x, epoch): + + nsample = 1 + gen_seq = [[] for _ in range(nsample)] + gt_seq = [utils.torch_rgb_img_to_gray(x[t]) for t in range(len(x))] + + # h_seq = [encoder(x[i]) for i in range(opt.n_past)] + for s in range(nsample): + frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + prior.hidden = prior.init_hidden() + gen_seq[s].append(utils.torch_rgb_img_to_gray(x[0])) + x_in = x[0] + for i in range(1, opt.n_eval): + with torch.no_grad(): + # if input is grayscale + if x_in.shape[1] == 1 and opt.channels == 3: + h = encoder(torch.cat(3*[x_in], dim=1)) # convert to RGB + else: + h = encoder(x_in) + if opt.last_frame_skip or i < opt.n_past: + h, skip = h + else: + h, _ = h + h = h.detach() + if i < opt.n_past: + h_target = encoder(x[i]) + z_t = posterior(h_target[0].detach()) + prior(h) + frame_predictor(torch.cat([h, z_t], 1)) + x_in = x[i] + gen_seq[s].append(utils.torch_rgb_img_to_gray(x_in)) + else: + z_t_hat = prior(h) + h = frame_predictor(torch.cat([h, z_t_hat], 1)).detach() + x_in = decoder([h, skip]) + gen_seq[s].append(x_in) + + + to_plot = [] + gifs = [ [] for t in range(opt.n_eval) ] + nrow = min(opt.batch_size, 25) + for i in range(nrow): + # ground truth sequence + row = [] + for t in range(opt.n_eval): + row.append(gt_seq[t][i]) + to_plot.append(row) + + for s in range(nsample): + row = [] + for t in range(opt.n_eval): + row.append(gen_seq[s][t][i]) + to_plot.append(row) + for t in range(opt.n_eval): + row = [] + row.append(gt_seq[t][i]) + for s in range(nsample): + row.append(gen_seq[s][t][i]) + gifs[t].append(row) + + fname = '%s/gen/sample_%d.png' % (opt.log_dir, epoch) + utils.save_tensors_image(fname, to_plot) + + fname = '%s/gen/sample_%d.gif' % (opt.log_dir, epoch) + utils.save_gif(fname, gifs) + + +def plot_rec(x, epoch): + gen_seq = [utils.torch_rgb_img_to_gray(x[0])] + gen_seq_post = [utils.torch_rgb_img_to_gray(x[0])] + + # prediction using posterior Z + frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + h = encoder(x[0]) + h = (h[0].detach(), h[1]) + for i in range(1, min(opt.n_eval, opt.n_past+opt.n_future)): + h_target = encoder(x[i]) + h_target = (h_target[0].detach(), h_target[1]) + if opt.last_frame_skip or i < opt.n_past: + h, skip = h + else: + h, _ = h + z_t = posterior(h_target[0]) + if i < opt.n_past: + frame_predictor(torch.cat([h, z_t], 1)) + gen_seq_post.append(utils.torch_rgb_img_to_gray(x[i])) + else: + h_pred = frame_predictor(torch.cat([h, z_t], 1)).detach() + x_pred = decoder([h_pred, skip]).detach() + gen_seq_post.append(x_pred) + h = h_target + + # prediction using prior Z + frame_predictor.hidden = frame_predictor.init_hidden() + prior.hidden = prior.init_hidden() + h = encoder(x[0]) + h = (h[0].detach(), h[1]) + for i in range(1, min(opt.n_eval, opt.n_past+opt.n_future)): + h_target = encoder(x[i]) + h_target = (h_target[0].detach(), h_target[1]) + if opt.last_frame_skip or i < opt.n_past: + h, skip = h + else: + h, _ = h + h = h.detach() + z_t_hat = prior(h) + if i < opt.n_past: + frame_predictor(torch.cat([h, z_t_hat], 1)) + gen_seq.append(utils.torch_rgb_img_to_gray(x[i])) + else: + h_pred = frame_predictor(torch.cat([h, z_t_hat], 1)).detach() + x_pred = decoder([h_pred, skip]).detach() + gen_seq.append(x_pred) + h = h_target + + to_plot = [] + nrow = min(opt.batch_size, 25) + x_gray = [utils.torch_rgb_img_to_gray(x[t]) for t in range(min(opt.n_eval, opt.n_past+opt.n_future))] + for i in range(nrow): + row_gt = [] + row_post = [] + row = [] + for t in range(min(opt.n_eval, opt.n_past+opt.n_future)): + row_gt.append(x_gray[t][i]) + row_post.append(gen_seq_post[t][i]) + row.append(gen_seq[t][i]) + to_plot.append(row_gt) + to_plot.append(row_post) + to_plot.append(row) + fname = '%s/gen/rec_%d.png' % (opt.log_dir, epoch) + utils.save_tensors_image(fname, to_plot) + + +# --------- training funtions ------------------------------------ +def train(x): + frame_predictor.zero_grad() + posterior.zero_grad() + prior.zero_grad() + encoder.zero_grad() + decoder.zero_grad() + + # initialize the hidden state. + frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + prior.hidden = prior.init_hidden() + + mse = 0 + mse_residual = 0 + h = encoder(x[0]) + for i in range(1, opt.n_past+opt.n_future): + h_target = encoder(x[i]) + if opt.last_frame_skip or i < opt.n_past: + h, skip = h + else: + h = h[0] + + z_t = posterior(h_target[0]) + z_t_hat = prior(h) + h_pred = frame_predictor(torch.cat([h, z_t], 1)) + x_pred = decoder([h_pred, skip]) + gray_target_frame = utils.torch_rgb_img_to_gray(x[i]) + mse += mse_criterion(x_pred, gray_target_frame) + # penalize prior for being far from posterior + mse_residual += opt.gamma * torch.mean(torch.square(z_t.detach() - z_t_hat)) + h = h_target + + loss = mse + mse_residual + loss.backward() + + frame_predictor_optimizer.step() + posterior_optimizer.step() + prior_optimizer.step() + encoder_optimizer.step() + decoder_optimizer.step() + + N = opt.n_past+opt.n_future + return mse.data.cpu().numpy()/N, mse_residual.data.cpu().numpy()/N + +# --------- training loop ------------------------------------ +for epoch in range(opt.niter): + frame_predictor.train() + posterior.train() + encoder.train() + decoder.train() + epoch_mse = 0 + epoch_mse_residual = 0 + progress = progressbar.ProgressBar(max_value=opt.epoch_size).start() + # opt.epoch_size = 10 + for i in range(opt.epoch_size): + progress.update(i+1) + x = next(training_batch_generator) + + # train frame_predictor + mse, mse_residual = train(x) + epoch_mse += mse + epoch_mse_residual += mse_residual + + + progress.finish() + utils.clear_progressbar() + + print('[%02d] mse loss: %.5f | residual mse: %.20f (%d)' % (epoch, epoch_mse/opt.epoch_size, epoch_mse_residual/opt.epoch_size, epoch*opt.epoch_size*opt.batch_size)) + + # plot some stuff + frame_predictor.eval() + posterior.eval() + prior.eval() + encoder.eval() + decoder.eval() + x = next(testing_batch_generator) + plot(x, epoch) + plot_rec(x, epoch) + + # save the model + torch.save({ + 'encoder': encoder, + 'decoder': decoder, + 'frame_predictor': frame_predictor, + 'posterior': posterior, + 'prior': posterior, + 'opt': opt}, + '%s/model_e%02d.pth' % (opt.log_dir, epoch)) + if epoch % 10 == 0: + print('log dir: %s' % opt.log_dir) + + diff --git a/train_baseline_object_permanence.py b/train_baseline_object_permanence.py new file mode 100644 index 0000000..c60480f --- /dev/null +++ b/train_baseline_object_permanence.py @@ -0,0 +1,445 @@ +import glob + +import torch +import torch.optim as optim +import torch.nn as nn +import argparse +import os +import random +from torch.autograd import Variable +from torch.utils.data import DataLoader +import utils +import itertools +import progressbar +import numpy as np +import json + +parser = argparse.ArgumentParser() +parser.add_argument('--lr', default=0.0004, type=float, help='learning rate') +parser.add_argument('--beta1', default=0.9, type=float, help='momentum term for adam') +parser.add_argument('--batch_size', default=10, type=int, help='batch size') +parser.add_argument('--log_dir', default='logs/nonstochastic_posterior', help='base directory to save logs') +parser.add_argument('--model_dir', default='', help='base directory to save logs') +parser.add_argument('--name', default='', help='identifier for directory') +parser.add_argument('--data_root', default='data', help='root directory for data') +parser.add_argument('--optimizer', default='adam', help='optimizer to train with') +parser.add_argument('--niter', type=int, default=40, help='number of epochs to train for') +parser.add_argument('--seed', default=1, type=int, help='manual seed') +parser.add_argument('--epoch_size', type=int, default=500, help='epoch size') +parser.add_argument('--image_width', type=int, default=64, help='the height / width of the input image to network') +parser.add_argument('--channels', default=1, type=int, help='number of channels for input images. ') +parser.add_argument('--use_edge_kernels', default=True, type=bool, help='whether to use edge kernels to reduce to 1 channel') +parser.add_argument('--dataset', default='mcs', help='dataset to train with') +parser.add_argument('--mcs_task', default='ObjectPermanenceTraining4', help='mcs task') +parser.add_argument('--n_past', type=int, default=5, help='number of frames to condition on') +parser.add_argument('--n_future', type=int, default=45, help='number of frames to predict') +parser.add_argument('--n_eval', type=int, default=50, help='number of frames to predict at eval time') +parser.add_argument('--start_min', type=int, default=65, help='min starting time for sampling sequence (0-indexed)') +parser.add_argument('--start_max', type=int, default=85, help='max starting time for sampling sequence (0-indexed)') +parser.add_argument('--sequence_stride', type=int, default=3, help='factor for sequence temporal subsampling (int)') +parser.add_argument('--rnn_size', type=int, default=256, help='dimensionality of hidden layer') +parser.add_argument('--prior_rnn_layers', type=int, default=1, help='number of layers') +parser.add_argument('--posterior_rnn_layers', type=int, default=1, help='number of layers') +parser.add_argument('--predictor_rnn_layers', type=int, default=2, help='number of layers') +parser.add_argument('--z_dim', type=int, default=32, help='dimensionality of z_t') +parser.add_argument('--g_dim', type=int, default=128, help='dimensionality of encoder output vector and decoder input vector') +parser.add_argument('--beta', type=float, default=0.0001, help='weighting on KL to prior') +parser.add_argument('--gamma', type=float, default=0.0001, help='weighting on h vs h posterior') +parser.add_argument('--model', default='vgg', help='model type (dcgan | vgg)') +parser.add_argument('--data_threads', type=int, default=12, help='number of data loading threads') +parser.add_argument('--num_digits', type=int, default=2, help='number of digits for moving mnist') +parser.add_argument('--last_frame_skip', action='store_true', help='if true, skip connections go between frame t and frame t+1 rather than last ground truth frame') + + +opt = parser.parse_args() +saved_model = None +if opt.model_dir != '': + models = glob.glob(f'{opt.model_dir}/model_*.pth') + latest_model = sorted(models, key=lambda s: int(s[s.rfind('_e') + 2: s.rfind('.pth')]), reverse=True)[0] + print('Loading model ', latest_model) + saved_model = torch.load(latest_model) + model_dir = opt.model_dir + niter = opt.niter + lr = opt.lr + batch_size = opt.batch_size + n_future = opt.n_future + n_eval = opt.n_eval + opt = saved_model['opt'] + opt.niter = niter # update number of epochs to train for + opt.model_dir = model_dir + opt.n_future = n_future + opt.lr = lr + opt.batch_size = batch_size + opt.n_eval = n_eval + opt.log_dir = '%s/continued_lr%s' % (opt.log_dir, opt.lr) +else: + name = 'nmodel=%s%dx%d-rnn_size=%d-predictor-posterior-rnn_layers=%d-%d-n_past=%d-n_future=%d-lr=%.4f-g_dim=%d-z_dim=%d-last_frame_skip=%d-beta=%.7f-gamma=%.7f%s' % (opt.model, opt.image_width, opt.image_width, opt.rnn_size, opt.predictor_rnn_layers, opt.posterior_rnn_layers, opt.n_past, opt.n_future, opt.lr, opt.g_dim, opt.z_dim, opt.last_frame_skip, opt.beta, opt.gamma, opt.name) + if opt.dataset == 'smmnist': + opt.log_dir = '%s/%s-%d/%s' % (opt.log_dir, opt.dataset, opt.num_digits, name) + elif opt.dataset == 'mcs': + opt.log_dir = '%s/%s/%s/%s' % (opt.log_dir, opt.dataset, opt.mcs_task, name) + else: + opt.log_dir = '%s/%s/%s' % (opt.log_dir, opt.dataset, name) + +os.makedirs('%s/gen/' % opt.log_dir, exist_ok=True) +os.makedirs('%s/plots/' % opt.log_dir, exist_ok=True) +with open(os.path.join(opt.log_dir, 'opt.json'), 'w') as f: + opt2 = opt.__dict__.copy() + if isinstance(opt2['optimizer'], type): + opt2['optimizer'] = str(opt2['optimizer']) + json.dump(opt2, f, indent=2) + +print("Random Seed: ", opt.seed) +random.seed(opt.seed) +torch.manual_seed(opt.seed) +torch.cuda.manual_seed_all(opt.seed) +dtype = torch.cuda.FloatTensor + + +# ---------------- load the models ---------------- + +print(opt) + +# ---------------- optimizers ---------------- +if opt.optimizer == 'adam': + opt.optimizer = optim.Adam +elif opt.optimizer == 'rmsprop': + opt.optimizer = optim.RMSprop +elif opt.optimizer == 'sgd': + opt.optimizer = optim.SGD +elif isinstance(opt.optimizer, type): + pass +else: + raise ValueError('Unknown optimizer: %s' % opt.optimizer) + +import models.lstm as lstm_models +if opt.model_dir != '': + frame_predictor = saved_model['frame_predictor'] + frame_predictor.batch_size = opt.batch_size + prior = saved_model['prior'] + prior.batch_size = opt.batch_size + posterior = saved_model['posterior'] + posterior.batch_size = opt.batch_size +else: + frame_predictor = lstm_models.lstm(opt.g_dim + opt.z_dim, opt.g_dim, opt.rnn_size, opt.predictor_rnn_layers, opt.batch_size) + posterior = lstm_models.lstm(opt.g_dim, opt.z_dim, opt.rnn_size, opt.posterior_rnn_layers, opt.batch_size) + prior = lstm_models.lstm(opt.g_dim, opt.z_dim, opt.rnn_size, opt.prior_rnn_layers, opt.batch_size) + frame_predictor.apply(utils.init_weights) + posterior.apply(utils.init_weights) + prior.apply(utils.init_weights) + +if opt.model == 'dcgan': + if opt.image_width == 64: + import models.dcgan_64 as model + elif opt.image_width == 128: + import models.dcgan_128 as model +elif opt.model == 'vgg': + if opt.image_width == 64: + import models.vgg_64 as model + elif opt.image_width == 128: + import models.vgg_128 as model +else: + raise ValueError('Unknown model: %s' % opt.model) + +if opt.model_dir != '': + decoder = saved_model['decoder'] + encoder = saved_model['encoder'] +else: + encoder = model.encoder(opt.g_dim, opt.channels) + decoder = model.decoder(opt.g_dim, 1) + encoder.apply(utils.init_weights) + decoder.apply(utils.init_weights) + +frame_predictor_optimizer = opt.optimizer(frame_predictor.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) +posterior_optimizer = opt.optimizer(posterior.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) +prior_optimizer = opt.optimizer(prior.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) +encoder_optimizer = opt.optimizer(encoder.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) +decoder_optimizer = opt.optimizer(decoder.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + +# --------- loss functions ------------------------------------ +mse_criterion = nn.MSELoss() +def kl_criterion(mu, logvar): + # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) + KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) + KLD /= opt.batch_size + return KLD + + +# --------- transfer to gpu ------------------------------------ +frame_predictor.cuda() +posterior.cuda() +prior.cuda() +encoder.cuda() +decoder.cuda() +mse_criterion.cuda() + +# --------- load a dataset ------------------------------------ +train_data, test_data = utils.load_dataset(opt) + +train_loader = DataLoader(train_data, + num_workers=opt.data_threads, + batch_size=opt.batch_size, + shuffle=True, + drop_last=True, + pin_memory=True,) +test_loader = DataLoader(test_data, + num_workers=opt.data_threads, + batch_size=opt.batch_size, + shuffle=True, + drop_last=True, + pin_memory=True) + +def get_training_batch(): + while True: + for sequence in train_loader: + batch = utils.normalize_data(opt, dtype, sequence) + yield batch +training_batch_generator = get_training_batch() + +def get_testing_batch(): + while True: + for sequence in test_loader: + batch = utils.normalize_data(opt, dtype, sequence) + yield batch +testing_batch_generator = get_testing_batch() + +# --------- plotting funtions ------------------------------------ +def plot(x, epoch): + + nsample = 1 + gen_seq = [[] for _ in range(nsample)] + gt_seq = [utils.torch_rgb_img_to_gray(x[t]) for t in range(len(x))] + + # h_seq = [encoder(x[i]) for i in range(opt.n_past)] + for s in range(nsample): + frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + prior.hidden = prior.init_hidden() + gen_seq[s].append(utils.torch_rgb_img_to_gray(x[0])) + x_in = x[0] + for i in range(1, opt.n_eval): + with torch.no_grad(): + # if input is grayscale + if x_in.shape[1] == 1 and opt.channels == 3: + h = encoder(torch.cat(3*[x_in], dim=1)) # convert to RGB + else: + h = encoder(x_in) + if opt.last_frame_skip or i < opt.n_past: + h, skip = h + else: + h, _ = h + h = h.detach() + if i < opt.n_past: + h_target = encoder(x[i]) + z_t = posterior(h_target[0].detach()) + prior(h) + frame_predictor(torch.cat([h, z_t], 1)) + x_in = x[i] + gen_seq[s].append(utils.torch_rgb_img_to_gray(x_in)) + else: + z_t_hat = prior(h) + h = frame_predictor(torch.cat([h, z_t_hat], 1)).detach() + x_in = decoder([h, skip]) + gen_seq[s].append(x_in) + + + to_plot = [] + gifs = [ [] for t in range(opt.n_eval) ] + nrow = min(opt.batch_size, 25) + for i in range(nrow): + # ground truth sequence + row = [] + for t in range(opt.n_eval): + row.append(gt_seq[t][i]) + to_plot.append(row) + + for s in range(nsample): + row = [] + for t in range(opt.n_eval): + row.append(gen_seq[s][t][i]) + to_plot.append(row) + for t in range(opt.n_eval): + row = [] + row.append(gt_seq[t][i]) + for s in range(nsample): + row.append(gen_seq[s][t][i]) + gifs[t].append(row) + + fname = '%s/gen/sample_%d.png' % (opt.log_dir, epoch) + utils.save_tensors_image(fname, to_plot) + + fname = '%s/gen/sample_%d.gif' % (opt.log_dir, epoch) + utils.save_gif(fname, gifs) + + +def plot_rec(x, epoch): + gen_seq = [utils.torch_rgb_img_to_gray(x[0])] + gen_seq_post = [utils.torch_rgb_img_to_gray(x[0])] + + # prediction using posterior Z + frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + h = encoder(x[0]) + h = (h[0].detach(), h[1]) + for i in range(1, min(opt.n_eval, opt.n_past+opt.n_future)): + h_target = encoder(x[i]) + h_target = (h_target[0].detach(), h_target[1]) + if opt.last_frame_skip or i < opt.n_past: + h, skip = h + else: + h, _ = h + z_t = posterior(h_target[0]) + if i < opt.n_past: + frame_predictor(torch.cat([h, z_t], 1)) + gen_seq_post.append(utils.torch_rgb_img_to_gray(x[i])) + else: + h_pred = frame_predictor(torch.cat([h, z_t], 1)).detach() + x_pred = decoder([h_pred, skip]).detach() + gen_seq_post.append(x_pred) + h = h_target + + # prediction using prior Z + frame_predictor.hidden = frame_predictor.init_hidden() + prior.hidden = prior.init_hidden() + h = encoder(x[0]) + h = (h[0].detach(), h[1]) + for i in range(1, min(opt.n_eval, opt.n_past+opt.n_future)): + h_target = encoder(x[i]) + h_target = (h_target[0].detach(), h_target[1]) + if opt.last_frame_skip or i < opt.n_past: + h, skip = h + else: + h, _ = h + h = h.detach() + z_t_hat = prior(h) + if i < opt.n_past: + frame_predictor(torch.cat([h, z_t_hat], 1)) + gen_seq.append(utils.torch_rgb_img_to_gray(x[i])) + else: + h_pred = frame_predictor(torch.cat([h, z_t_hat], 1)).detach() + x_pred = decoder([h_pred, skip]).detach() + gen_seq.append(x_pred) + h = h_target + + to_plot = [] + nrow = min(opt.batch_size, 25) + x_gray = [utils.torch_rgb_img_to_gray(x[t]) for t in range(min(opt.n_eval, opt.n_past+opt.n_future))] + for i in range(nrow): + row_gt = [] + row_post = [] + row = [] + for t in range(min(opt.n_eval, opt.n_past+opt.n_future)): + row_gt.append(x_gray[t][i]) + row_post.append(gen_seq_post[t][i]) + row.append(gen_seq[t][i]) + to_plot.append(row_gt) + to_plot.append(row_post) + to_plot.append(row) + fname = '%s/gen/rec_%d.png' % (opt.log_dir, epoch) + utils.save_tensors_image(fname, to_plot) + + +# --------- training funtions ------------------------------------ +def train(x): + frame_predictor.zero_grad() + posterior.zero_grad() + prior.zero_grad() + encoder.zero_grad() + decoder.zero_grad() + + # initialize the hidden state. + frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + prior.hidden = prior.init_hidden() + + mse = 0 + mse_residual = 0 + # x: T x B x C x H x W + x_diff = [1] + for i in range(1, len(x)): + diff = torch.abs(x[i] - x[i - 1]) + diff = torch.mean(diff, dim=(1, 2, 3)).detach() # mean over channels, width, and height + x_diff.append(diff) + + # print(i, x_diff > 1e-6) + h = encoder(x[0]) + for i in range(1, opt.n_past+opt.n_future): + h_target = encoder(x[i]) + if opt.last_frame_skip or i < opt.n_past: + h, skip = h + else: + h = h[0] + + z_t = posterior(h_target[0]) + z_t_hat = prior(h) + h_pred = frame_predictor(torch.cat([h, z_t], 1)) + x_pred = decoder([h_pred, skip]) + gray_target_frame = utils.torch_rgb_img_to_gray(x[i]) + + still_frames = x_diff[i] <= 5e-6 + weights = torch.pow(0.05, still_frames).detach()[:, None, None, None] + mse += mse_criterion(weights * x_pred, weights * gray_target_frame) + # penalize prior for being far from posterior + mse_residual += opt.gamma * torch.mean(weights * torch.square(z_t.detach() - z_t_hat)) + h = h_target + + loss = mse + mse_residual + loss.backward() + + frame_predictor_optimizer.step() + posterior_optimizer.step() + prior_optimizer.step() + encoder_optimizer.step() + decoder_optimizer.step() + + N = opt.n_past+opt.n_future + return mse.data.cpu().numpy()/N, mse_residual.data.cpu().numpy()/N + +# --------- training loop ------------------------------------ +for epoch in range(opt.niter): + frame_predictor.train() + posterior.train() + encoder.train() + decoder.train() + epoch_mse = 0 + epoch_mse_residual = 0 + progress = progressbar.ProgressBar(max_value=opt.epoch_size).start() + # opt.epoch_size = 10 + for i in range(opt.epoch_size): + progress.update(i+1) + x = next(training_batch_generator) + + # train frame_predictor + mse, mse_residual = train(x) + epoch_mse += mse + epoch_mse_residual += mse_residual + + + progress.finish() + utils.clear_progressbar() + + print('[%02d] mse loss: %.5f | residual mse: %.20f (%d)' % (epoch, epoch_mse/opt.epoch_size, epoch_mse_residual/opt.epoch_size, epoch*opt.epoch_size*opt.batch_size)) + + # plot some stuff + frame_predictor.eval() + posterior.eval() + prior.eval() + encoder.eval() + decoder.eval() + x = next(testing_batch_generator) + plot(x, epoch) + plot_rec(x, epoch) + + # save the model + torch.save({ + 'encoder': encoder, + 'decoder': decoder, + 'frame_predictor': frame_predictor, + 'posterior': posterior, + 'prior': posterior, + 'opt': opt}, + '%s/model_e%02d.pth' % (opt.log_dir, epoch)) + if epoch % 10 == 0: + print('log dir: %s' % opt.log_dir) + + diff --git a/utils.py b/utils.py index 6c29971..ec97c89 100755 --- a/utils.py +++ b/utils.py @@ -107,7 +107,10 @@ def load_dataset(opt, sequential=None, implausible=None): sequential=sequential, implausible=implausible, im_channels=opt.channels, - use_edge_kernels=opt.use_edge_kernels) + use_edge_kernels=opt.use_edge_kernels, + start_min=opt.start_min, + start_max=opt.start_max, + sequence_stride=opt.sequence_stride) test_data = MCS( train=False, data_root=opt.data_root, @@ -117,7 +120,10 @@ def load_dataset(opt, sequential=None, implausible=None): sequential=sequential, implausible=implausible, im_channels=opt.channels, - use_edge_kernels=opt.use_edge_kernels) + use_edge_kernels=opt.use_edge_kernels, + start_min=opt.start_min, + start_max=opt.start_max, + sequence_stride=opt.sequence_stride) elif opt.dataset == 'mcs_test': from data.mcs import MCS train_data = MCS( From e647d781c3e76e7aecd97f2ce0bf85e2d294b023 Mon Sep 17 00:00:00 2001 From: Fred Lu Date: Thu, 26 Aug 2021 16:12:38 -0400 Subject: [PATCH 09/26] gravity update --- train_baseline_collision.py | 2 +- train_baseline_object_permanence.py | 2 +- train_svg_nonstochastic_posterior.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/train_baseline_collision.py b/train_baseline_collision.py index 5d60370..4cb527c 100644 --- a/train_baseline_collision.py +++ b/train_baseline_collision.py @@ -420,7 +420,7 @@ def train(x): 'decoder': decoder, 'frame_predictor': frame_predictor, 'posterior': posterior, - 'prior': posterior, + 'prior': prior, 'opt': opt}, '%s/model_e%02d.pth' % (opt.log_dir, epoch)) if epoch % 10 == 0: diff --git a/train_baseline_object_permanence.py b/train_baseline_object_permanence.py index c60480f..e3bbf73 100644 --- a/train_baseline_object_permanence.py +++ b/train_baseline_object_permanence.py @@ -436,7 +436,7 @@ def train(x): 'decoder': decoder, 'frame_predictor': frame_predictor, 'posterior': posterior, - 'prior': posterior, + 'prior': prior, 'opt': opt}, '%s/model_e%02d.pth' % (opt.log_dir, epoch)) if epoch % 10 == 0: diff --git a/train_svg_nonstochastic_posterior.py b/train_svg_nonstochastic_posterior.py index bc8478e..95a1f3b 100644 --- a/train_svg_nonstochastic_posterior.py +++ b/train_svg_nonstochastic_posterior.py @@ -417,7 +417,7 @@ def train(x): 'decoder': decoder, 'frame_predictor': frame_predictor, 'posterior': posterior, - 'prior': posterior, + 'prior': prior, 'opt': opt}, '%s/model_e%02d.pth' % (opt.log_dir, epoch)) if epoch % 10 == 0: From c0aef610b4b42f39c32c7b18f5e545410c075522 Mon Sep 17 00:00:00 2001 From: Fred Lu Date: Thu, 26 Aug 2021 16:18:51 -0400 Subject: [PATCH 10/26] updated LR --- train_baseline_collision.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_baseline_collision.py b/train_baseline_collision.py index 4cb527c..71e33ab 100644 --- a/train_baseline_collision.py +++ b/train_baseline_collision.py @@ -15,7 +15,7 @@ import json parser = argparse.ArgumentParser() -parser.add_argument('--lr', default=0.0004, type=float, help='learning rate') +parser.add_argument('--lr', default=0.0001, type=float, help='learning rate') parser.add_argument('--beta1', default=0.9, type=float, help='momentum term for adam') parser.add_argument('--batch_size', default=13, type=int, help='batch size') parser.add_argument('--log_dir', default='logs/nonstochastic_posterior', help='base directory to save logs') From ab9d19fb6a0e7a99062fbf7002484c66dc1d7428 Mon Sep 17 00:00:00 2001 From: Fred Lu Date: Thu, 26 Aug 2021 17:08:01 -0400 Subject: [PATCH 11/26] updated LR. Added requirements.txt --- requirements.txt | 3 +++ train_baseline_collision.py | 2 +- train_baseline_object_permanence.py | 2 +- train_svg_nonstochastic_posterior.py | 2 +- 4 files changed, 6 insertions(+), 3 deletions(-) create mode 100644 requirements.txt diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..d1f2f23 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,3 @@ +numpy +torch +shapely diff --git a/train_baseline_collision.py b/train_baseline_collision.py index 71e33ab..ce4b5e7 100644 --- a/train_baseline_collision.py +++ b/train_baseline_collision.py @@ -15,7 +15,7 @@ import json parser = argparse.ArgumentParser() -parser.add_argument('--lr', default=0.0001, type=float, help='learning rate') +parser.add_argument('--lr', default=0.0002, type=float, help='learning rate') parser.add_argument('--beta1', default=0.9, type=float, help='momentum term for adam') parser.add_argument('--batch_size', default=13, type=int, help='batch size') parser.add_argument('--log_dir', default='logs/nonstochastic_posterior', help='base directory to save logs') diff --git a/train_baseline_object_permanence.py b/train_baseline_object_permanence.py index e3bbf73..0a314c2 100644 --- a/train_baseline_object_permanence.py +++ b/train_baseline_object_permanence.py @@ -15,7 +15,7 @@ import json parser = argparse.ArgumentParser() -parser.add_argument('--lr', default=0.0004, type=float, help='learning rate') +parser.add_argument('--lr', default=0.0002, type=float, help='learning rate') parser.add_argument('--beta1', default=0.9, type=float, help='momentum term for adam') parser.add_argument('--batch_size', default=10, type=int, help='batch size') parser.add_argument('--log_dir', default='logs/nonstochastic_posterior', help='base directory to save logs') diff --git a/train_svg_nonstochastic_posterior.py b/train_svg_nonstochastic_posterior.py index 95a1f3b..032b931 100644 --- a/train_svg_nonstochastic_posterior.py +++ b/train_svg_nonstochastic_posterior.py @@ -15,7 +15,7 @@ import json parser = argparse.ArgumentParser() -parser.add_argument('--lr', default=0.0001, type=float, help='learning rate') +parser.add_argument('--lr', default=0.0002, type=float, help='learning rate') parser.add_argument('--beta1', default=0.9, type=float, help='momentum term for adam') parser.add_argument('--batch_size', default=24, type=int, help='batch size') parser.add_argument('--log_dir', default='logs/nonstochastic_posterior', help='base directory to save logs') From 9d4fa6ad93a8f655cf3e6dfee3d4d89966a69930 Mon Sep 17 00:00:00 2001 From: Fred Lu Date: Thu, 26 Aug 2021 18:17:16 -0400 Subject: [PATCH 12/26] Added gravity training --- requirements.txt | 5 + train_baseline_gravity.py | 429 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 434 insertions(+) create mode 100644 train_baseline_gravity.py diff --git a/requirements.txt b/requirements.txt index d1f2f23..4d88cc3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,8 @@ numpy torch shapely +sklearn +matplotlib +scikit-image +progressbar2 +opencv-python \ No newline at end of file diff --git a/train_baseline_gravity.py b/train_baseline_gravity.py new file mode 100644 index 0000000..86940ec --- /dev/null +++ b/train_baseline_gravity.py @@ -0,0 +1,429 @@ +import glob + +import torch +import torch.optim as optim +import torch.nn as nn +import argparse +import os +import random +from torch.autograd import Variable +from torch.utils.data import DataLoader +import utils +import itertools +import progressbar +import numpy as np +import json + +parser = argparse.ArgumentParser() +parser.add_argument('--lr', default=0.0002, type=float, help='learning rate') +parser.add_argument('--beta1', default=0.9, type=float, help='momentum term for adam') +parser.add_argument('--batch_size', default=13, type=int, help='batch size') +parser.add_argument('--log_dir', default='logs/nonstochastic_posterior', help='base directory to save logs') +parser.add_argument('--model_dir', default='', help='base directory to save logs') +parser.add_argument('--name', default='', help='identifier for directory') +parser.add_argument('--data_root', default='data', help='root directory for data') +parser.add_argument('--optimizer', default='adam', help='optimizer to train with') +parser.add_argument('--niter', type=int, default=40, help='number of epochs to train for') +parser.add_argument('--seed', default=1, type=int, help='manual seed') +parser.add_argument('--epoch_size', type=int, default=500, help='epoch size') +parser.add_argument('--image_width', type=int, default=64, help='the height / width of the input image to network') +parser.add_argument('--channels', default=1, type=int, help='number of channels for input images. ') +parser.add_argument('--use_edge_kernels', default=True, type=bool, help='whether to use edge kernels to reduce to 1 channel') +parser.add_argument('--dataset', default='mcs', help='dataset to train with') +parser.add_argument('--mcs_task', default='GravitySupportTraining', help='mcs task') +parser.add_argument('--n_past', type=int, default=5, help='number of frames to condition on') +parser.add_argument('--n_future', type=int, default=25, help='number of frames to predict') +parser.add_argument('--n_eval', type=int, default=30, help='number of frames to predict at eval time') +parser.add_argument('--start_min', type=int, default=7, help='min starting time for sampling sequence (0-indexed)') +parser.add_argument('--start_max', type=int, default=29, help='max starting time for sampling sequence (0-indexed)') +parser.add_argument('--sequence_stride', type=int, default=1, help='factor for sequence temporal subsampling (int)') +parser.add_argument('--rnn_size', type=int, default=256, help='dimensionality of hidden layer') +parser.add_argument('--prior_rnn_layers', type=int, default=1, help='number of layers') +parser.add_argument('--posterior_rnn_layers', type=int, default=1, help='number of layers') +parser.add_argument('--predictor_rnn_layers', type=int, default=2, help='number of layers') +parser.add_argument('--z_dim', type=int, default=32, help='dimensionality of z_t') +parser.add_argument('--g_dim', type=int, default=128, help='dimensionality of encoder output vector and decoder input vector') +parser.add_argument('--beta', type=float, default=0.0001, help='weighting on KL to prior') +parser.add_argument('--gamma', type=float, default=0.0001, help='weighting on h vs h posterior') +parser.add_argument('--model', default='vgg', help='model type (dcgan | vgg)') +parser.add_argument('--data_threads', type=int, default=12, help='number of data loading threads') +parser.add_argument('--num_digits', type=int, default=2, help='number of digits for moving mnist') +parser.add_argument('--last_frame_skip', action='store_true', help='if true, skip connections go between frame t and frame t+1 rather than last ground truth frame') + + +opt = parser.parse_args() +saved_model = None +if opt.model_dir != '': + models = glob.glob(f'{opt.model_dir}/model_*.pth') + latest_model = sorted(models, key=lambda s: int(s[s.rfind('_e') + 2: s.rfind('.pth')]), reverse=True)[0] + print('Loading model ', latest_model) + saved_model = torch.load(latest_model) + model_dir = opt.model_dir + niter = opt.niter + lr = opt.lr + batch_size = opt.batch_size + n_future = opt.n_future + n_eval = opt.n_eval + opt = saved_model['opt'] + opt.niter = niter # update number of epochs to train for + opt.model_dir = model_dir + opt.n_future = n_future + opt.lr = lr + opt.batch_size = batch_size + opt.n_eval = n_eval + opt.log_dir = '%s/continued_lr%s' % (opt.log_dir, opt.lr) +else: + name = 'model=%s%dx%d-rnn_size=%d-predictor-posterior-rnn_layers=%d-%d-n_past=%d-n_future=%d-lr=%.4f-g_dim=%d-z_dim=%d-last_frame_skip=%d-beta=%.7f-gamma=%.7f%s' % (opt.model, opt.image_width, opt.image_width, opt.rnn_size, opt.predictor_rnn_layers, opt.posterior_rnn_layers, opt.n_past, opt.n_future, opt.lr, opt.g_dim, opt.z_dim, opt.last_frame_skip, opt.beta, opt.gamma, opt.name) + if opt.dataset == 'smmnist': + opt.log_dir = '%s/%s-%d/%s' % (opt.log_dir, opt.dataset, opt.num_digits, name) + elif opt.dataset == 'mcs': + opt.log_dir = '%s/%s/%s/%s' % (opt.log_dir, opt.dataset, opt.mcs_task, name) + else: + opt.log_dir = '%s/%s/%s' % (opt.log_dir, opt.dataset, name) + +os.makedirs('%s/gen/' % opt.log_dir, exist_ok=False) +os.makedirs('%s/plots/' % opt.log_dir, exist_ok=False) +with open(os.path.join(opt.log_dir, 'opt.json'), 'w') as f: + json.dump(opt.__dict__, f, indent=2) + +print("Random Seed: ", opt.seed) +random.seed(opt.seed) +torch.manual_seed(opt.seed) +torch.cuda.manual_seed_all(opt.seed) +dtype = torch.cuda.FloatTensor + + +# ---------------- load the models ---------------- + +print(opt) + +# ---------------- optimizers ---------------- +if opt.optimizer == 'adam': + opt.optimizer = optim.Adam +elif opt.optimizer == 'rmsprop': + opt.optimizer = optim.RMSprop +elif opt.optimizer == 'sgd': + opt.optimizer = optim.SGD +else: + raise ValueError('Unknown optimizer: %s' % opt.optimizer) + +import models.lstm as lstm_models +if opt.model_dir != '': + frame_predictor = saved_model['frame_predictor'] + frame_predictor.batch_size = opt.batch_size + prior = saved_model['prior'] + prior.batch_size = opt.batch_size + posterior = saved_model['posterior'] + posterior.batch_size = opt.batch_size +else: + frame_predictor = lstm_models.lstm(opt.g_dim + opt.z_dim, opt.g_dim, opt.rnn_size, opt.predictor_rnn_layers, opt.batch_size) + posterior = lstm_models.lstm(opt.g_dim, opt.z_dim, opt.rnn_size, opt.posterior_rnn_layers, opt.batch_size) + prior = lstm_models.lstm(opt.g_dim, opt.z_dim, opt.rnn_size, opt.prior_rnn_layers, opt.batch_size) + frame_predictor.apply(utils.init_weights) + posterior.apply(utils.init_weights) + prior.apply(utils.init_weights) + +if opt.model == 'dcgan': + if opt.image_width == 64: + import models.dcgan_64 as model + elif opt.image_width == 128: + import models.dcgan_128 as model +elif opt.model == 'vgg': + if opt.image_width == 64: + import models.vgg_64 as model + elif opt.image_width == 128: + import models.vgg_128 as model +else: + raise ValueError('Unknown model: %s' % opt.model) + +if opt.model_dir != '': + decoder = saved_model['decoder'] + encoder = saved_model['encoder'] +else: + encoder = model.encoder(opt.g_dim, opt.channels) + decoder = model.decoder(opt.g_dim, 1) + encoder.apply(utils.init_weights) + decoder.apply(utils.init_weights) + +frame_predictor_optimizer = opt.optimizer(frame_predictor.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) +posterior_optimizer = opt.optimizer(posterior.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) +prior_optimizer = opt.optimizer(prior.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) +encoder_optimizer = opt.optimizer(encoder.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) +decoder_optimizer = opt.optimizer(decoder.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + +# --------- loss functions ------------------------------------ +mse_criterion = nn.MSELoss() +def kl_criterion(mu, logvar): + # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) + KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) + KLD /= opt.batch_size + return KLD + + +# --------- transfer to gpu ------------------------------------ +frame_predictor.cuda() +posterior.cuda() +prior.cuda() +encoder.cuda() +decoder.cuda() +mse_criterion.cuda() + +# --------- load a dataset ------------------------------------ +train_data, test_data = utils.load_dataset(opt) + +train_loader = DataLoader(train_data, + num_workers=opt.data_threads, + batch_size=opt.batch_size, + shuffle=True, + drop_last=True, + pin_memory=True,) +test_loader = DataLoader(test_data, + num_workers=opt.data_threads, + batch_size=opt.batch_size, + shuffle=True, + drop_last=True, + pin_memory=True) + +def get_training_batch(): + while True: + for sequence in train_loader: + batch = utils.normalize_data(opt, dtype, sequence) + yield batch +training_batch_generator = get_training_batch() + +def get_testing_batch(): + while True: + for sequence in test_loader: + batch = utils.normalize_data(opt, dtype, sequence) + yield batch +testing_batch_generator = get_testing_batch() + +# --------- plotting funtions ------------------------------------ +def plot(x, epoch): + + nsample = 1 + gen_seq = [[] for _ in range(nsample)] + gt_seq = [utils.torch_rgb_img_to_gray(x[t]) for t in range(len(x))] + + # h_seq = [encoder(x[i]) for i in range(opt.n_past)] + for s in range(nsample): + frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + prior.hidden = prior.init_hidden() + gen_seq[s].append(utils.torch_rgb_img_to_gray(x[0])) + x_in = x[0] + for i in range(1, opt.n_eval): + with torch.no_grad(): + # if input is grayscale + if x_in.shape[1] == 1 and opt.channels == 3: + h = encoder(torch.cat(3*[x_in], dim=1)) # convert to RGB + else: + h = encoder(x_in) + if opt.last_frame_skip or i < opt.n_past: + h, skip = h + else: + h, _ = h + h = h.detach() + if i < opt.n_past: + h_target = encoder(x[i]) + z_t = posterior(h_target[0].detach()) + prior(h) + frame_predictor(torch.cat([h, z_t], 1)) + x_in = x[i] + gen_seq[s].append(utils.torch_rgb_img_to_gray(x_in)) + else: + z_t_hat = prior(h) + h = frame_predictor(torch.cat([h, z_t_hat], 1)).detach() + x_in = decoder([h, skip]) + gen_seq[s].append(x_in) + + + to_plot = [] + gifs = [ [] for t in range(opt.n_eval) ] + nrow = min(opt.batch_size, 25) + for i in range(nrow): + # ground truth sequence + row = [] + for t in range(opt.n_eval): + row.append(gt_seq[t][i]) + to_plot.append(row) + + for s in range(nsample): + row = [] + for t in range(opt.n_eval): + row.append(gen_seq[s][t][i]) + to_plot.append(row) + for t in range(opt.n_eval): + row = [] + row.append(gt_seq[t][i]) + for s in range(nsample): + row.append(gen_seq[s][t][i]) + gifs[t].append(row) + + fname = '%s/gen/sample_%d.png' % (opt.log_dir, epoch) + utils.save_tensors_image(fname, to_plot) + + fname = '%s/gen/sample_%d.gif' % (opt.log_dir, epoch) + utils.save_gif(fname, gifs) + + +def plot_rec(x, epoch): + gen_seq = [utils.torch_rgb_img_to_gray(x[0])] + gen_seq_post = [utils.torch_rgb_img_to_gray(x[0])] + + # prediction using posterior Z + frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + h = encoder(x[0]) + h = (h[0].detach(), h[1]) + for i in range(1, min(opt.n_eval, opt.n_past+opt.n_future)): + h_target = encoder(x[i]) + h_target = (h_target[0].detach(), h_target[1]) + if opt.last_frame_skip or i < opt.n_past: + h, skip = h + else: + h, _ = h + z_t = posterior(h_target[0]) + if i < opt.n_past: + frame_predictor(torch.cat([h, z_t], 1)) + gen_seq_post.append(utils.torch_rgb_img_to_gray(x[i])) + else: + h_pred = frame_predictor(torch.cat([h, z_t], 1)).detach() + x_pred = decoder([h_pred, skip]).detach() + gen_seq_post.append(x_pred) + h = h_target + + # prediction using prior Z + frame_predictor.hidden = frame_predictor.init_hidden() + prior.hidden = prior.init_hidden() + h = encoder(x[0]) + h = (h[0].detach(), h[1]) + for i in range(1, min(opt.n_eval, opt.n_past+opt.n_future)): + h_target = encoder(x[i]) + h_target = (h_target[0].detach(), h_target[1]) + if opt.last_frame_skip or i < opt.n_past: + h, skip = h + else: + h, _ = h + h = h.detach() + z_t_hat = prior(h) + if i < opt.n_past: + frame_predictor(torch.cat([h, z_t_hat], 1)) + gen_seq.append(utils.torch_rgb_img_to_gray(x[i])) + else: + h_pred = frame_predictor(torch.cat([h, z_t_hat], 1)).detach() + x_pred = decoder([h_pred, skip]).detach() + gen_seq.append(x_pred) + h = h_target + + to_plot = [] + nrow = min(opt.batch_size, 25) + x_gray = [utils.torch_rgb_img_to_gray(x[t]) for t in range(min(opt.n_eval, opt.n_past+opt.n_future))] + for i in range(nrow): + row_gt = [] + row_post = [] + row = [] + for t in range(min(opt.n_eval, opt.n_past+opt.n_future)): + row_gt.append(x_gray[t][i]) + row_post.append(gen_seq_post[t][i]) + row.append(gen_seq[t][i]) + to_plot.append(row_gt) + to_plot.append(row_post) + to_plot.append(row) + fname = '%s/gen/rec_%d.png' % (opt.log_dir, epoch) + utils.save_tensors_image(fname, to_plot) + + +# --------- training funtions ------------------------------------ +def train(x): + frame_predictor.zero_grad() + posterior.zero_grad() + prior.zero_grad() + encoder.zero_grad() + decoder.zero_grad() + + # initialize the hidden state. + frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + prior.hidden = prior.init_hidden() + + mse = 0 + mse_residual = 0 + h = encoder(x[0]) + for i in range(1, opt.n_past+opt.n_future): + h_target = encoder(x[i]) + if opt.last_frame_skip or i < opt.n_past: + h, skip = h + else: + h = h[0] + + z_t = posterior(h_target[0]) + z_t_hat = prior(h) + h_pred = frame_predictor(torch.cat([h, z_t], 1)) + x_pred = decoder([h_pred, skip]) + gray_target_frame = utils.torch_rgb_img_to_gray(x[i]) + mse += mse_criterion(x_pred, gray_target_frame) + # penalize prior for being far from posterior + mse_residual += opt.gamma * torch.mean(torch.square(z_t.detach() - z_t_hat)) + h = h_target + + loss = mse + mse_residual + loss.backward() + + frame_predictor_optimizer.step() + posterior_optimizer.step() + prior_optimizer.step() + encoder_optimizer.step() + decoder_optimizer.step() + + N = opt.n_past+opt.n_future + return mse.data.cpu().numpy()/N, mse_residual.data.cpu().numpy()/N + +# --------- training loop ------------------------------------ +for epoch in range(opt.niter): + frame_predictor.train() + posterior.train() + encoder.train() + decoder.train() + epoch_mse = 0 + epoch_mse_residual = 0 + progress = progressbar.ProgressBar(max_value=opt.epoch_size).start() + # opt.epoch_size = 10 + for i in range(opt.epoch_size): + progress.update(i+1) + x = next(training_batch_generator) + + # train frame_predictor + mse, mse_residual = train(x) + epoch_mse += mse + epoch_mse_residual += mse_residual + + + progress.finish() + utils.clear_progressbar() + + print('[%02d] mse loss: %.5f | residual mse: %.20f (%d)' % (epoch, epoch_mse/opt.epoch_size, epoch_mse_residual/opt.epoch_size, epoch*opt.epoch_size*opt.batch_size)) + + # plot some stuff + frame_predictor.eval() + posterior.eval() + prior.eval() + encoder.eval() + decoder.eval() + x = next(testing_batch_generator) + plot(x, epoch) + plot_rec(x, epoch) + + # save the model + torch.save({ + 'encoder': encoder, + 'decoder': decoder, + 'frame_predictor': frame_predictor, + 'posterior': posterior, + 'prior': prior, + 'opt': opt}, + '%s/model_e%02d.pth' % (opt.log_dir, epoch)) + if epoch % 10 == 0: + print('log dir: %s' % opt.log_dir) + + From 321df206451fc3fdff8eb5de4e922dd4f93d8d4e Mon Sep 17 00:00:00 2001 From: Fred Lu Date: Thu, 26 Aug 2021 22:03:31 -0400 Subject: [PATCH 13/26] Added spatial temporal continuity training --- train_baseline_object_permanence.py | 2 +- train_baseline_spatialTemporalContinuity.py | 429 ++++++++++++++++++++ 2 files changed, 430 insertions(+), 1 deletion(-) create mode 100644 train_baseline_spatialTemporalContinuity.py diff --git a/train_baseline_object_permanence.py b/train_baseline_object_permanence.py index 0a314c2..2ef3c2f 100644 --- a/train_baseline_object_permanence.py +++ b/train_baseline_object_permanence.py @@ -73,7 +73,7 @@ opt.n_eval = n_eval opt.log_dir = '%s/continued_lr%s' % (opt.log_dir, opt.lr) else: - name = 'nmodel=%s%dx%d-rnn_size=%d-predictor-posterior-rnn_layers=%d-%d-n_past=%d-n_future=%d-lr=%.4f-g_dim=%d-z_dim=%d-last_frame_skip=%d-beta=%.7f-gamma=%.7f%s' % (opt.model, opt.image_width, opt.image_width, opt.rnn_size, opt.predictor_rnn_layers, opt.posterior_rnn_layers, opt.n_past, opt.n_future, opt.lr, opt.g_dim, opt.z_dim, opt.last_frame_skip, opt.beta, opt.gamma, opt.name) + name = 'model=%s%dx%d-rnn_size=%d-predictor-posterior-rnn_layers=%d-%d-n_past=%d-n_future=%d-lr=%.4f-g_dim=%d-z_dim=%d-last_frame_skip=%d-beta=%.7f-gamma=%.7f%s' % (opt.model, opt.image_width, opt.image_width, opt.rnn_size, opt.predictor_rnn_layers, opt.posterior_rnn_layers, opt.n_past, opt.n_future, opt.lr, opt.g_dim, opt.z_dim, opt.last_frame_skip, opt.beta, opt.gamma, opt.name) if opt.dataset == 'smmnist': opt.log_dir = '%s/%s-%d/%s' % (opt.log_dir, opt.dataset, opt.num_digits, name) elif opt.dataset == 'mcs': diff --git a/train_baseline_spatialTemporalContinuity.py b/train_baseline_spatialTemporalContinuity.py new file mode 100644 index 0000000..0bce3a9 --- /dev/null +++ b/train_baseline_spatialTemporalContinuity.py @@ -0,0 +1,429 @@ +import glob + +import torch +import torch.optim as optim +import torch.nn as nn +import argparse +import os +import random +from torch.autograd import Variable +from torch.utils.data import DataLoader +import utils +import itertools +import progressbar +import numpy as np +import json + +parser = argparse.ArgumentParser() +parser.add_argument('--lr', default=0.0002, type=float, help='learning rate') +parser.add_argument('--beta1', default=0.9, type=float, help='momentum term for adam') +parser.add_argument('--batch_size', default=16, type=int, help='batch size') +parser.add_argument('--log_dir', default='logs/nonstochastic_posterior', help='base directory to save logs') +parser.add_argument('--model_dir', default='', help='base directory to save logs') +parser.add_argument('--name', default='', help='identifier for directory') +parser.add_argument('--data_root', default='data', help='root directory for data') +parser.add_argument('--optimizer', default='adam', help='optimizer to train with') +parser.add_argument('--niter', type=int, default=40, help='number of epochs to train for') +parser.add_argument('--seed', default=1, type=int, help='manual seed') +parser.add_argument('--epoch_size', type=int, default=500, help='epoch size') +parser.add_argument('--image_width', type=int, default=64, help='the height / width of the input image to network') +parser.add_argument('--channels', default=1, type=int, help='number of channels for input images. ') +parser.add_argument('--use_edge_kernels', default=True, type=bool, help='whether to use edge kernels to reduce to 1 channel') +parser.add_argument('--dataset', default='mcs', help='dataset to train with') +parser.add_argument('--mcs_task', default='SpatioTemporalContinuityTraining4', help='mcs task') +parser.add_argument('--n_past', type=int, default=5, help='number of frames to condition on') +parser.add_argument('--n_future', type=int, default=25, help='number of frames to predict') +parser.add_argument('--n_eval', type=int, default=30, help='number of frames to predict at eval time') +parser.add_argument('--start_min', type=int, default=80, help='min starting time for sampling sequence (0-indexed)') +parser.add_argument('--start_max', type=int, default=135, help='max starting time for sampling sequence (0-indexed)') +parser.add_argument('--sequence_stride', type=int, default=1, help='factor for sequence temporal subsampling (int)') +parser.add_argument('--rnn_size', type=int, default=256, help='dimensionality of hidden layer') +parser.add_argument('--prior_rnn_layers', type=int, default=1, help='number of layers') +parser.add_argument('--posterior_rnn_layers', type=int, default=1, help='number of layers') +parser.add_argument('--predictor_rnn_layers', type=int, default=2, help='number of layers') +parser.add_argument('--z_dim', type=int, default=32, help='dimensionality of z_t') +parser.add_argument('--g_dim', type=int, default=128, help='dimensionality of encoder output vector and decoder input vector') +parser.add_argument('--beta', type=float, default=0.0001, help='weighting on KL to prior') +parser.add_argument('--gamma', type=float, default=0.0001, help='weighting on h vs h posterior') +parser.add_argument('--model', default='vgg', help='model type (dcgan | vgg)') +parser.add_argument('--data_threads', type=int, default=12, help='number of data loading threads') +parser.add_argument('--num_digits', type=int, default=2, help='number of digits for moving mnist') +parser.add_argument('--last_frame_skip', action='store_true', help='if true, skip connections go between frame t and frame t+1 rather than last ground truth frame') + + +opt = parser.parse_args() +saved_model = None +if opt.model_dir != '': + models = glob.glob(f'{opt.model_dir}/model_*.pth') + latest_model = sorted(models, key=lambda s: int(s[s.rfind('_e') + 2: s.rfind('.pth')]), reverse=True)[0] + print('Loading model ', latest_model) + saved_model = torch.load(latest_model) + model_dir = opt.model_dir + niter = opt.niter + lr = opt.lr + batch_size = opt.batch_size + n_future = opt.n_future + n_eval = opt.n_eval + opt = saved_model['opt'] + opt.niter = niter # update number of epochs to train for + opt.model_dir = model_dir + opt.n_future = n_future + opt.lr = lr + opt.batch_size = batch_size + opt.n_eval = n_eval + opt.log_dir = '%s/continued_lr%s' % (opt.log_dir, opt.lr) +else: + name = 'model=%s%dx%d-rnn_size=%d-predictor-posterior-rnn_layers=%d-%d-n_past=%d-n_future=%d-lr=%.4f-g_dim=%d-z_dim=%d-last_frame_skip=%d-beta=%.7f-gamma=%.7f%s' % (opt.model, opt.image_width, opt.image_width, opt.rnn_size, opt.predictor_rnn_layers, opt.posterior_rnn_layers, opt.n_past, opt.n_future, opt.lr, opt.g_dim, opt.z_dim, opt.last_frame_skip, opt.beta, opt.gamma, opt.name) + if opt.dataset == 'smmnist': + opt.log_dir = '%s/%s-%d/%s' % (opt.log_dir, opt.dataset, opt.num_digits, name) + elif opt.dataset == 'mcs': + opt.log_dir = '%s/%s/%s/%s' % (opt.log_dir, opt.dataset, opt.mcs_task, name) + else: + opt.log_dir = '%s/%s/%s' % (opt.log_dir, opt.dataset, name) + +os.makedirs('%s/gen/' % opt.log_dir, exist_ok=False) +os.makedirs('%s/plots/' % opt.log_dir, exist_ok=False) +with open(os.path.join(opt.log_dir, 'opt.json'), 'w') as f: + json.dump(opt.__dict__, f, indent=2) + +print("Random Seed: ", opt.seed) +random.seed(opt.seed) +torch.manual_seed(opt.seed) +torch.cuda.manual_seed_all(opt.seed) +dtype = torch.cuda.FloatTensor + + +# ---------------- load the models ---------------- + +print(opt) + +# ---------------- optimizers ---------------- +if opt.optimizer == 'adam': + opt.optimizer = optim.Adam +elif opt.optimizer == 'rmsprop': + opt.optimizer = optim.RMSprop +elif opt.optimizer == 'sgd': + opt.optimizer = optim.SGD +else: + raise ValueError('Unknown optimizer: %s' % opt.optimizer) + +import models.lstm as lstm_models +if opt.model_dir != '': + frame_predictor = saved_model['frame_predictor'] + frame_predictor.batch_size = opt.batch_size + prior = saved_model['prior'] + prior.batch_size = opt.batch_size + posterior = saved_model['posterior'] + posterior.batch_size = opt.batch_size +else: + frame_predictor = lstm_models.lstm(opt.g_dim + opt.z_dim, opt.g_dim, opt.rnn_size, opt.predictor_rnn_layers, opt.batch_size) + posterior = lstm_models.lstm(opt.g_dim, opt.z_dim, opt.rnn_size, opt.posterior_rnn_layers, opt.batch_size) + prior = lstm_models.lstm(opt.g_dim, opt.z_dim, opt.rnn_size, opt.prior_rnn_layers, opt.batch_size) + frame_predictor.apply(utils.init_weights) + posterior.apply(utils.init_weights) + prior.apply(utils.init_weights) + +if opt.model == 'dcgan': + if opt.image_width == 64: + import models.dcgan_64 as model + elif opt.image_width == 128: + import models.dcgan_128 as model +elif opt.model == 'vgg': + if opt.image_width == 64: + import models.vgg_64 as model + elif opt.image_width == 128: + import models.vgg_128 as model +else: + raise ValueError('Unknown model: %s' % opt.model) + +if opt.model_dir != '': + decoder = saved_model['decoder'] + encoder = saved_model['encoder'] +else: + encoder = model.encoder(opt.g_dim, opt.channels) + decoder = model.decoder(opt.g_dim, 1) + encoder.apply(utils.init_weights) + decoder.apply(utils.init_weights) + +frame_predictor_optimizer = opt.optimizer(frame_predictor.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) +posterior_optimizer = opt.optimizer(posterior.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) +prior_optimizer = opt.optimizer(prior.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) +encoder_optimizer = opt.optimizer(encoder.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) +decoder_optimizer = opt.optimizer(decoder.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + +# --------- loss functions ------------------------------------ +mse_criterion = nn.MSELoss() +def kl_criterion(mu, logvar): + # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) + KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) + KLD /= opt.batch_size + return KLD + + +# --------- transfer to gpu ------------------------------------ +frame_predictor.cuda() +posterior.cuda() +prior.cuda() +encoder.cuda() +decoder.cuda() +mse_criterion.cuda() + +# --------- load a dataset ------------------------------------ +train_data, test_data = utils.load_dataset(opt) + +train_loader = DataLoader(train_data, + num_workers=opt.data_threads, + batch_size=opt.batch_size, + shuffle=True, + drop_last=True, + pin_memory=True,) +test_loader = DataLoader(test_data, + num_workers=opt.data_threads, + batch_size=opt.batch_size, + shuffle=True, + drop_last=True, + pin_memory=True) + +def get_training_batch(): + while True: + for sequence in train_loader: + batch = utils.normalize_data(opt, dtype, sequence) + yield batch +training_batch_generator = get_training_batch() + +def get_testing_batch(): + while True: + for sequence in test_loader: + batch = utils.normalize_data(opt, dtype, sequence) + yield batch +testing_batch_generator = get_testing_batch() + +# --------- plotting funtions ------------------------------------ +def plot(x, epoch): + + nsample = 1 + gen_seq = [[] for _ in range(nsample)] + gt_seq = [utils.torch_rgb_img_to_gray(x[t]) for t in range(len(x))] + + # h_seq = [encoder(x[i]) for i in range(opt.n_past)] + for s in range(nsample): + frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + prior.hidden = prior.init_hidden() + gen_seq[s].append(utils.torch_rgb_img_to_gray(x[0])) + x_in = x[0] + for i in range(1, opt.n_eval): + with torch.no_grad(): + # if input is grayscale + if x_in.shape[1] == 1 and opt.channels == 3: + h = encoder(torch.cat(3*[x_in], dim=1)) # convert to RGB + else: + h = encoder(x_in) + if opt.last_frame_skip or i < opt.n_past: + h, skip = h + else: + h, _ = h + h = h.detach() + if i < opt.n_past: + h_target = encoder(x[i]) + z_t = posterior(h_target[0].detach()) + prior(h) + frame_predictor(torch.cat([h, z_t], 1)) + x_in = x[i] + gen_seq[s].append(utils.torch_rgb_img_to_gray(x_in)) + else: + z_t_hat = prior(h) + h = frame_predictor(torch.cat([h, z_t_hat], 1)).detach() + x_in = decoder([h, skip]) + gen_seq[s].append(x_in) + + + to_plot = [] + gifs = [ [] for t in range(opt.n_eval) ] + nrow = min(opt.batch_size, 25) + for i in range(nrow): + # ground truth sequence + row = [] + for t in range(opt.n_eval): + row.append(gt_seq[t][i]) + to_plot.append(row) + + for s in range(nsample): + row = [] + for t in range(opt.n_eval): + row.append(gen_seq[s][t][i]) + to_plot.append(row) + for t in range(opt.n_eval): + row = [] + row.append(gt_seq[t][i]) + for s in range(nsample): + row.append(gen_seq[s][t][i]) + gifs[t].append(row) + + fname = '%s/gen/sample_%d.png' % (opt.log_dir, epoch) + utils.save_tensors_image(fname, to_plot) + + fname = '%s/gen/sample_%d.gif' % (opt.log_dir, epoch) + utils.save_gif(fname, gifs) + + +def plot_rec(x, epoch): + gen_seq = [utils.torch_rgb_img_to_gray(x[0])] + gen_seq_post = [utils.torch_rgb_img_to_gray(x[0])] + + # prediction using posterior Z + frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + h = encoder(x[0]) + h = (h[0].detach(), h[1]) + for i in range(1, min(opt.n_eval, opt.n_past+opt.n_future)): + h_target = encoder(x[i]) + h_target = (h_target[0].detach(), h_target[1]) + if opt.last_frame_skip or i < opt.n_past: + h, skip = h + else: + h, _ = h + z_t = posterior(h_target[0]) + if i < opt.n_past: + frame_predictor(torch.cat([h, z_t], 1)) + gen_seq_post.append(utils.torch_rgb_img_to_gray(x[i])) + else: + h_pred = frame_predictor(torch.cat([h, z_t], 1)).detach() + x_pred = decoder([h_pred, skip]).detach() + gen_seq_post.append(x_pred) + h = h_target + + # prediction using prior Z + frame_predictor.hidden = frame_predictor.init_hidden() + prior.hidden = prior.init_hidden() + h = encoder(x[0]) + h = (h[0].detach(), h[1]) + for i in range(1, min(opt.n_eval, opt.n_past+opt.n_future)): + h_target = encoder(x[i]) + h_target = (h_target[0].detach(), h_target[1]) + if opt.last_frame_skip or i < opt.n_past: + h, skip = h + else: + h, _ = h + h = h.detach() + z_t_hat = prior(h) + if i < opt.n_past: + frame_predictor(torch.cat([h, z_t_hat], 1)) + gen_seq.append(utils.torch_rgb_img_to_gray(x[i])) + else: + h_pred = frame_predictor(torch.cat([h, z_t_hat], 1)).detach() + x_pred = decoder([h_pred, skip]).detach() + gen_seq.append(x_pred) + h = h_target + + to_plot = [] + nrow = min(opt.batch_size, 25) + x_gray = [utils.torch_rgb_img_to_gray(x[t]) for t in range(min(opt.n_eval, opt.n_past+opt.n_future))] + for i in range(nrow): + row_gt = [] + row_post = [] + row = [] + for t in range(min(opt.n_eval, opt.n_past+opt.n_future)): + row_gt.append(x_gray[t][i]) + row_post.append(gen_seq_post[t][i]) + row.append(gen_seq[t][i]) + to_plot.append(row_gt) + to_plot.append(row_post) + to_plot.append(row) + fname = '%s/gen/rec_%d.png' % (opt.log_dir, epoch) + utils.save_tensors_image(fname, to_plot) + + +# --------- training funtions ------------------------------------ +def train(x): + frame_predictor.zero_grad() + posterior.zero_grad() + prior.zero_grad() + encoder.zero_grad() + decoder.zero_grad() + + # initialize the hidden state. + frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + prior.hidden = prior.init_hidden() + + mse = 0 + mse_residual = 0 + h = encoder(x[0]) + for i in range(1, opt.n_past+opt.n_future): + h_target = encoder(x[i]) + if opt.last_frame_skip or i < opt.n_past: + h, skip = h + else: + h = h[0] + + z_t = posterior(h_target[0]) + z_t_hat = prior(h) + h_pred = frame_predictor(torch.cat([h, z_t], 1)) + x_pred = decoder([h_pred, skip]) + gray_target_frame = utils.torch_rgb_img_to_gray(x[i]) + mse += mse_criterion(x_pred, gray_target_frame) + # penalize prior for being far from posterior + mse_residual += opt.gamma * torch.mean(torch.square(z_t.detach() - z_t_hat)) + h = h_target + + loss = mse + mse_residual + loss.backward() + + frame_predictor_optimizer.step() + posterior_optimizer.step() + prior_optimizer.step() + encoder_optimizer.step() + decoder_optimizer.step() + + N = opt.n_past+opt.n_future + return mse.data.cpu().numpy()/N, mse_residual.data.cpu().numpy()/N + +# --------- training loop ------------------------------------ +for epoch in range(opt.niter): + frame_predictor.train() + posterior.train() + encoder.train() + decoder.train() + epoch_mse = 0 + epoch_mse_residual = 0 + progress = progressbar.ProgressBar(max_value=opt.epoch_size).start() + # opt.epoch_size = 10 + for i in range(opt.epoch_size): + progress.update(i+1) + x = next(training_batch_generator) + + # train frame_predictor + mse, mse_residual = train(x) + epoch_mse += mse + epoch_mse_residual += mse_residual + + + progress.finish() + utils.clear_progressbar() + + print('[%02d] mse loss: %.5f | residual mse: %.20f (%d)' % (epoch, epoch_mse/opt.epoch_size, epoch_mse_residual/opt.epoch_size, epoch*opt.epoch_size*opt.batch_size)) + + # plot some stuff + frame_predictor.eval() + posterior.eval() + prior.eval() + encoder.eval() + decoder.eval() + x = next(testing_batch_generator) + plot(x, epoch) + plot_rec(x, epoch) + + # save the model + torch.save({ + 'encoder': encoder, + 'decoder': decoder, + 'frame_predictor': frame_predictor, + 'posterior': posterior, + 'prior': prior, + 'opt': opt}, + '%s/model_e%02d.pth' % (opt.log_dir, epoch)) + if epoch % 10 == 0: + print('log dir: %s' % opt.log_dir) + + From 5b76566cb619530249b39d2952b4523c0cb9d322 Mon Sep 17 00:00:00 2001 From: Fred Lu Date: Fri, 27 Aug 2021 14:14:30 -0400 Subject: [PATCH 14/26] added training tasks --- train_baseline_collision.py | 6 +- train_baseline_gravity.py | 7 +- train_baseline_object_permanence.py | 1 + train_baseline_shapeconstancy.py | 433 ++++++++++++++++++++ train_baseline_spatialTemporalContinuity.py | 6 +- train_svg_nonstochastic_posterior.py | 6 +- 6 files changed, 454 insertions(+), 5 deletions(-) create mode 100644 train_baseline_shapeconstancy.py diff --git a/train_baseline_collision.py b/train_baseline_collision.py index ce4b5e7..e348ec9 100644 --- a/train_baseline_collision.py +++ b/train_baseline_collision.py @@ -84,7 +84,11 @@ os.makedirs('%s/gen/' % opt.log_dir, exist_ok=False) os.makedirs('%s/plots/' % opt.log_dir, exist_ok=False) with open(os.path.join(opt.log_dir, 'opt.json'), 'w') as f: - json.dump(opt.__dict__, f, indent=2) + opt2 = opt.__dict__.copy() + if isinstance(opt2['optimizer'], type): + opt2['optimizer'] = str(opt2['optimizer']) + json.dump(opt2, f, indent=2) + del opt2 print("Random Seed: ", opt.seed) random.seed(opt.seed) diff --git a/train_baseline_gravity.py b/train_baseline_gravity.py index 86940ec..a2adea4 100644 --- a/train_baseline_gravity.py +++ b/train_baseline_gravity.py @@ -84,8 +84,11 @@ os.makedirs('%s/gen/' % opt.log_dir, exist_ok=False) os.makedirs('%s/plots/' % opt.log_dir, exist_ok=False) with open(os.path.join(opt.log_dir, 'opt.json'), 'w') as f: - json.dump(opt.__dict__, f, indent=2) - + opt2 = opt.__dict__.copy() + if isinstance(opt2['optimizer'], type): + opt2['optimizer'] = str(opt2['optimizer']) + json.dump(opt2, f, indent=2) + del opt2 print("Random Seed: ", opt.seed) random.seed(opt.seed) torch.manual_seed(opt.seed) diff --git a/train_baseline_object_permanence.py b/train_baseline_object_permanence.py index 2ef3c2f..453aee2 100644 --- a/train_baseline_object_permanence.py +++ b/train_baseline_object_permanence.py @@ -88,6 +88,7 @@ if isinstance(opt2['optimizer'], type): opt2['optimizer'] = str(opt2['optimizer']) json.dump(opt2, f, indent=2) + del opt2 print("Random Seed: ", opt.seed) random.seed(opt.seed) diff --git a/train_baseline_shapeconstancy.py b/train_baseline_shapeconstancy.py new file mode 100644 index 0000000..8122238 --- /dev/null +++ b/train_baseline_shapeconstancy.py @@ -0,0 +1,433 @@ +import glob + +import torch +import torch.optim as optim +import torch.nn as nn +import argparse +import os +import random +from torch.autograd import Variable +from torch.utils.data import DataLoader +import utils +import itertools +import progressbar +import numpy as np +import json + +parser = argparse.ArgumentParser() +parser.add_argument('--lr', default=0.0002, type=float, help='learning rate') +parser.add_argument('--beta1', default=0.9, type=float, help='momentum term for adam') +parser.add_argument('--batch_size', default=13, type=int, help='batch size') +parser.add_argument('--log_dir', default='logs/nonstochastic_posterior', help='base directory to save logs') +parser.add_argument('--model_dir', default='', help='base directory to save logs') +parser.add_argument('--name', default='', help='identifier for directory') +parser.add_argument('--data_root', default='data', help='root directory for data') +parser.add_argument('--optimizer', default='adam', help='optimizer to train with') +parser.add_argument('--niter', type=int, default=40, help='number of epochs to train for') +parser.add_argument('--seed', default=1, type=int, help='manual seed') +parser.add_argument('--epoch_size', type=int, default=500, help='epoch size') +parser.add_argument('--image_width', type=int, default=64, help='the height / width of the input image to network') +parser.add_argument('--channels', default=1, type=int, help='number of channels for input images. ') +parser.add_argument('--use_edge_kernels', default=True, type=bool, help='whether to use edge kernels to reduce to 1 channel') +parser.add_argument('--dataset', default='mcs', help='dataset to train with') +parser.add_argument('--mcs_task', default='GravitySupportTraining', help='mcs task') +parser.add_argument('--n_past', type=int, default=5, help='number of frames to condition on') +parser.add_argument('--n_future', type=int, default=30, help='number of frames to predict') +parser.add_argument('--n_eval', type=int, default=35, help='number of frames to predict at eval time') +parser.add_argument('--start_min', type=int, default=79, help='min starting time for sampling sequence (0-indexed)') +parser.add_argument('--start_max', type=int, default=89, help='max starting time for sampling sequence (0-indexed)') +parser.add_argument('--sequence_stride', type=int, default=2, help='factor for sequence temporal subsampling (int)') +parser.add_argument('--rnn_size', type=int, default=256, help='dimensionality of hidden layer') +parser.add_argument('--prior_rnn_layers', type=int, default=1, help='number of layers') +parser.add_argument('--posterior_rnn_layers', type=int, default=1, help='number of layers') +parser.add_argument('--predictor_rnn_layers', type=int, default=2, help='number of layers') +parser.add_argument('--z_dim', type=int, default=32, help='dimensionality of z_t') +parser.add_argument('--g_dim', type=int, default=128, help='dimensionality of encoder output vector and decoder input vector') +parser.add_argument('--beta', type=float, default=0.0001, help='weighting on KL to prior') +parser.add_argument('--gamma', type=float, default=0.0001, help='weighting on h vs h posterior') +parser.add_argument('--model', default='vgg', help='model type (dcgan | vgg)') +parser.add_argument('--data_threads', type=int, default=12, help='number of data loading threads') +parser.add_argument('--num_digits', type=int, default=2, help='number of digits for moving mnist') +parser.add_argument('--last_frame_skip', action='store_true', help='if true, skip connections go between frame t and frame t+1 rather than last ground truth frame') + + +opt = parser.parse_args() +saved_model = None +if opt.model_dir != '': + models = glob.glob(f'{opt.model_dir}/model_*.pth') + latest_model = sorted(models, key=lambda s: int(s[s.rfind('_e') + 2: s.rfind('.pth')]), reverse=True)[0] + print('Loading model ', latest_model) + saved_model = torch.load(latest_model) + model_dir = opt.model_dir + niter = opt.niter + lr = opt.lr + batch_size = opt.batch_size + n_future = opt.n_future + n_eval = opt.n_eval + opt = saved_model['opt'] + opt.niter = niter # update number of epochs to train for + opt.model_dir = model_dir + opt.n_future = n_future + opt.lr = lr + opt.batch_size = batch_size + opt.n_eval = n_eval + opt.log_dir = '%s/continued_lr%s' % (opt.log_dir, opt.lr) +else: + name = 'model=%s%dx%d-rnn_size=%d-predictor-posterior-rnn_layers=%d-%d-n_past=%d-n_future=%d-lr=%.4f-g_dim=%d-z_dim=%d-last_frame_skip=%d-beta=%.7f-gamma=%.7f%s' % (opt.model, opt.image_width, opt.image_width, opt.rnn_size, opt.predictor_rnn_layers, opt.posterior_rnn_layers, opt.n_past, opt.n_future, opt.lr, opt.g_dim, opt.z_dim, opt.last_frame_skip, opt.beta, opt.gamma, opt.name) + if opt.dataset == 'smmnist': + opt.log_dir = '%s/%s-%d/%s' % (opt.log_dir, opt.dataset, opt.num_digits, name) + elif opt.dataset == 'mcs': + opt.log_dir = '%s/%s/%s/%s' % (opt.log_dir, opt.dataset, opt.mcs_task, name) + else: + opt.log_dir = '%s/%s/%s' % (opt.log_dir, opt.dataset, name) + +os.makedirs('%s/gen/' % opt.log_dir, exist_ok=False) +os.makedirs('%s/plots/' % opt.log_dir, exist_ok=False) +with open(os.path.join(opt.log_dir, 'opt.json'), 'w') as f: + opt2 = opt.__dict__.copy() + if isinstance(opt2['optimizer'], type): + opt2['optimizer'] = str(opt2['optimizer']) + json.dump(opt2, f, indent=2) + del opt2 + +print("Random Seed: ", opt.seed) +random.seed(opt.seed) +torch.manual_seed(opt.seed) +torch.cuda.manual_seed_all(opt.seed) +dtype = torch.cuda.FloatTensor + + +# ---------------- load the models ---------------- + +print(opt) + +# ---------------- optimizers ---------------- +if opt.optimizer == 'adam': + opt.optimizer = optim.Adam +elif opt.optimizer == 'rmsprop': + opt.optimizer = optim.RMSprop +elif opt.optimizer == 'sgd': + opt.optimizer = optim.SGD +else: + raise ValueError('Unknown optimizer: %s' % opt.optimizer) + +import models.lstm as lstm_models +if opt.model_dir != '': + frame_predictor = saved_model['frame_predictor'] + frame_predictor.batch_size = opt.batch_size + prior = saved_model['prior'] + prior.batch_size = opt.batch_size + posterior = saved_model['posterior'] + posterior.batch_size = opt.batch_size +else: + frame_predictor = lstm_models.lstm(opt.g_dim + opt.z_dim, opt.g_dim, opt.rnn_size, opt.predictor_rnn_layers, opt.batch_size) + posterior = lstm_models.lstm(opt.g_dim, opt.z_dim, opt.rnn_size, opt.posterior_rnn_layers, opt.batch_size) + prior = lstm_models.lstm(opt.g_dim, opt.z_dim, opt.rnn_size, opt.prior_rnn_layers, opt.batch_size) + frame_predictor.apply(utils.init_weights) + posterior.apply(utils.init_weights) + prior.apply(utils.init_weights) + +if opt.model == 'dcgan': + if opt.image_width == 64: + import models.dcgan_64 as model + elif opt.image_width == 128: + import models.dcgan_128 as model +elif opt.model == 'vgg': + if opt.image_width == 64: + import models.vgg_64 as model + elif opt.image_width == 128: + import models.vgg_128 as model +else: + raise ValueError('Unknown model: %s' % opt.model) + +if opt.model_dir != '': + decoder = saved_model['decoder'] + encoder = saved_model['encoder'] +else: + encoder = model.encoder(opt.g_dim, opt.channels) + decoder = model.decoder(opt.g_dim, 1) + encoder.apply(utils.init_weights) + decoder.apply(utils.init_weights) + +frame_predictor_optimizer = opt.optimizer(frame_predictor.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) +posterior_optimizer = opt.optimizer(posterior.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) +prior_optimizer = opt.optimizer(prior.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) +encoder_optimizer = opt.optimizer(encoder.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) +decoder_optimizer = opt.optimizer(decoder.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + +# --------- loss functions ------------------------------------ +mse_criterion = nn.MSELoss() +def kl_criterion(mu, logvar): + # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) + KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) + KLD /= opt.batch_size + return KLD + + +# --------- transfer to gpu ------------------------------------ +frame_predictor.cuda() +posterior.cuda() +prior.cuda() +encoder.cuda() +decoder.cuda() +mse_criterion.cuda() + +# --------- load a dataset ------------------------------------ +train_data, test_data = utils.load_dataset(opt) + +train_loader = DataLoader(train_data, + num_workers=opt.data_threads, + batch_size=opt.batch_size, + shuffle=True, + drop_last=True, + pin_memory=True,) +test_loader = DataLoader(test_data, + num_workers=opt.data_threads, + batch_size=opt.batch_size, + shuffle=True, + drop_last=True, + pin_memory=True) + +def get_training_batch(): + while True: + for sequence in train_loader: + batch = utils.normalize_data(opt, dtype, sequence) + yield batch +training_batch_generator = get_training_batch() + +def get_testing_batch(): + while True: + for sequence in test_loader: + batch = utils.normalize_data(opt, dtype, sequence) + yield batch +testing_batch_generator = get_testing_batch() + +# --------- plotting funtions ------------------------------------ +def plot(x, epoch): + + nsample = 1 + gen_seq = [[] for _ in range(nsample)] + gt_seq = [utils.torch_rgb_img_to_gray(x[t]) for t in range(len(x))] + + # h_seq = [encoder(x[i]) for i in range(opt.n_past)] + for s in range(nsample): + frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + prior.hidden = prior.init_hidden() + gen_seq[s].append(utils.torch_rgb_img_to_gray(x[0])) + x_in = x[0] + for i in range(1, opt.n_eval): + with torch.no_grad(): + # if input is grayscale + if x_in.shape[1] == 1 and opt.channels == 3: + h = encoder(torch.cat(3*[x_in], dim=1)) # convert to RGB + else: + h = encoder(x_in) + if opt.last_frame_skip or i < opt.n_past: + h, skip = h + else: + h, _ = h + h = h.detach() + if i < opt.n_past: + h_target = encoder(x[i]) + z_t = posterior(h_target[0].detach()) + prior(h) + frame_predictor(torch.cat([h, z_t], 1)) + x_in = x[i] + gen_seq[s].append(utils.torch_rgb_img_to_gray(x_in)) + else: + z_t_hat = prior(h) + h = frame_predictor(torch.cat([h, z_t_hat], 1)).detach() + x_in = decoder([h, skip]) + gen_seq[s].append(x_in) + + + to_plot = [] + gifs = [ [] for t in range(opt.n_eval) ] + nrow = min(opt.batch_size, 25) + for i in range(nrow): + # ground truth sequence + row = [] + for t in range(opt.n_eval): + row.append(gt_seq[t][i]) + to_plot.append(row) + + for s in range(nsample): + row = [] + for t in range(opt.n_eval): + row.append(gen_seq[s][t][i]) + to_plot.append(row) + for t in range(opt.n_eval): + row = [] + row.append(gt_seq[t][i]) + for s in range(nsample): + row.append(gen_seq[s][t][i]) + gifs[t].append(row) + + fname = '%s/gen/sample_%d.png' % (opt.log_dir, epoch) + utils.save_tensors_image(fname, to_plot) + + fname = '%s/gen/sample_%d.gif' % (opt.log_dir, epoch) + utils.save_gif(fname, gifs) + + +def plot_rec(x, epoch): + gen_seq = [utils.torch_rgb_img_to_gray(x[0])] + gen_seq_post = [utils.torch_rgb_img_to_gray(x[0])] + + # prediction using posterior Z + frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + h = encoder(x[0]) + h = (h[0].detach(), h[1]) + for i in range(1, min(opt.n_eval, opt.n_past+opt.n_future)): + h_target = encoder(x[i]) + h_target = (h_target[0].detach(), h_target[1]) + if opt.last_frame_skip or i < opt.n_past: + h, skip = h + else: + h, _ = h + z_t = posterior(h_target[0]) + if i < opt.n_past: + frame_predictor(torch.cat([h, z_t], 1)) + gen_seq_post.append(utils.torch_rgb_img_to_gray(x[i])) + else: + h_pred = frame_predictor(torch.cat([h, z_t], 1)).detach() + x_pred = decoder([h_pred, skip]).detach() + gen_seq_post.append(x_pred) + h = h_target + + # prediction using prior Z + frame_predictor.hidden = frame_predictor.init_hidden() + prior.hidden = prior.init_hidden() + h = encoder(x[0]) + h = (h[0].detach(), h[1]) + for i in range(1, min(opt.n_eval, opt.n_past+opt.n_future)): + h_target = encoder(x[i]) + h_target = (h_target[0].detach(), h_target[1]) + if opt.last_frame_skip or i < opt.n_past: + h, skip = h + else: + h, _ = h + h = h.detach() + z_t_hat = prior(h) + if i < opt.n_past: + frame_predictor(torch.cat([h, z_t_hat], 1)) + gen_seq.append(utils.torch_rgb_img_to_gray(x[i])) + else: + h_pred = frame_predictor(torch.cat([h, z_t_hat], 1)).detach() + x_pred = decoder([h_pred, skip]).detach() + gen_seq.append(x_pred) + h = h_target + + to_plot = [] + nrow = min(opt.batch_size, 25) + x_gray = [utils.torch_rgb_img_to_gray(x[t]) for t in range(min(opt.n_eval, opt.n_past+opt.n_future))] + for i in range(nrow): + row_gt = [] + row_post = [] + row = [] + for t in range(min(opt.n_eval, opt.n_past+opt.n_future)): + row_gt.append(x_gray[t][i]) + row_post.append(gen_seq_post[t][i]) + row.append(gen_seq[t][i]) + to_plot.append(row_gt) + to_plot.append(row_post) + to_plot.append(row) + fname = '%s/gen/rec_%d.png' % (opt.log_dir, epoch) + utils.save_tensors_image(fname, to_plot) + + +# --------- training funtions ------------------------------------ +def train(x): + frame_predictor.zero_grad() + posterior.zero_grad() + prior.zero_grad() + encoder.zero_grad() + decoder.zero_grad() + + # initialize the hidden state. + frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + prior.hidden = prior.init_hidden() + + mse = 0 + mse_residual = 0 + h = encoder(x[0]) + for i in range(1, opt.n_past+opt.n_future): + h_target = encoder(x[i]) + if opt.last_frame_skip or i < opt.n_past: + h, skip = h + else: + h = h[0] + + z_t = posterior(h_target[0]) + z_t_hat = prior(h) + h_pred = frame_predictor(torch.cat([h, z_t], 1)) + x_pred = decoder([h_pred, skip]) + gray_target_frame = utils.torch_rgb_img_to_gray(x[i]) + mse += mse_criterion(x_pred, gray_target_frame) + # penalize prior for being far from posterior + mse_residual += opt.gamma * torch.mean(torch.square(z_t.detach() - z_t_hat)) + h = h_target + + loss = mse + mse_residual + loss.backward() + + frame_predictor_optimizer.step() + posterior_optimizer.step() + prior_optimizer.step() + encoder_optimizer.step() + decoder_optimizer.step() + + N = opt.n_past+opt.n_future + return mse.data.cpu().numpy()/N, mse_residual.data.cpu().numpy()/N + +# --------- training loop ------------------------------------ +for epoch in range(opt.niter): + frame_predictor.train() + posterior.train() + encoder.train() + decoder.train() + epoch_mse = 0 + epoch_mse_residual = 0 + progress = progressbar.ProgressBar(max_value=opt.epoch_size).start() + # opt.epoch_size = 10 + for i in range(opt.epoch_size): + progress.update(i+1) + x = next(training_batch_generator) + + # train frame_predictor + mse, mse_residual = train(x) + epoch_mse += mse + epoch_mse_residual += mse_residual + + + progress.finish() + utils.clear_progressbar() + + print('[%02d] mse loss: %.5f | residual mse: %.20f (%d)' % (epoch, epoch_mse/opt.epoch_size, epoch_mse_residual/opt.epoch_size, epoch*opt.epoch_size*opt.batch_size)) + + # plot some stuff + frame_predictor.eval() + posterior.eval() + prior.eval() + encoder.eval() + decoder.eval() + x = next(testing_batch_generator) + plot(x, epoch) + plot_rec(x, epoch) + + # save the model + torch.save({ + 'encoder': encoder, + 'decoder': decoder, + 'frame_predictor': frame_predictor, + 'posterior': posterior, + 'prior': prior, + 'opt': opt}, + '%s/model_e%02d.pth' % (opt.log_dir, epoch)) + if epoch % 10 == 0: + print('log dir: %s' % opt.log_dir) + + diff --git a/train_baseline_spatialTemporalContinuity.py b/train_baseline_spatialTemporalContinuity.py index 0bce3a9..a04a986 100644 --- a/train_baseline_spatialTemporalContinuity.py +++ b/train_baseline_spatialTemporalContinuity.py @@ -84,7 +84,11 @@ os.makedirs('%s/gen/' % opt.log_dir, exist_ok=False) os.makedirs('%s/plots/' % opt.log_dir, exist_ok=False) with open(os.path.join(opt.log_dir, 'opt.json'), 'w') as f: - json.dump(opt.__dict__, f, indent=2) + opt2 = opt.__dict__.copy() + if isinstance(opt2['optimizer'], type): + opt2['optimizer'] = str(opt2['optimizer']) + json.dump(opt2, f, indent=2) + del opt2 print("Random Seed: ", opt.seed) random.seed(opt.seed) diff --git a/train_svg_nonstochastic_posterior.py b/train_svg_nonstochastic_posterior.py index 032b931..5227607 100644 --- a/train_svg_nonstochastic_posterior.py +++ b/train_svg_nonstochastic_posterior.py @@ -81,7 +81,11 @@ os.makedirs('%s/gen/' % opt.log_dir, exist_ok=False) os.makedirs('%s/plots/' % opt.log_dir, exist_ok=False) with open(os.path.join(opt.log_dir, 'opt.json'), 'w') as f: - json.dump(opt.__dict__, f, indent=2) + opt2 = opt.__dict__.copy() + if isinstance(opt2['optimizer'], type): + opt2['optimizer'] = str(opt2['optimizer']) + json.dump(opt2, f, indent=2) + del opt2 print("Random Seed: ", opt.seed) random.seed(opt.seed) From 36cb663ff20cce6d9ee1255d2c02e6b627f19132 Mon Sep 17 00:00:00 2001 From: Fred Lu Date: Fri, 27 Aug 2021 14:42:14 -0400 Subject: [PATCH 15/26] updated clearing progress bar behavior --- utils.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/utils.py b/utils.py index ec97c89..7454037 100755 --- a/utils.py +++ b/utils.py @@ -289,13 +289,15 @@ def batch_flatten(x): def clear_progressbar(): - # moves up 3 lines - print("\033[2A") - # deletes the whole line, regardless of character position - print("\033[2K") - # moves up two lines again - print("\033[2A") - + # # moves up 3 lines + # print("\033[2A") + # # deletes the whole line, regardless of character position + # print("\033[2K") + # # moves up two lines again + # print("\033[2A") + print('\r') + print(' ' * 80) + print('\r') def mse_metric(x1, x2): err = np.sum((x1 - x2) ** 2) From 360b20a3ad7f857f57c41d182675406bc6e26ce6 Mon Sep 17 00:00:00 2001 From: Fred Lu Date: Fri, 27 Aug 2021 15:02:08 -0400 Subject: [PATCH 16/26] updated clearing progress bar behavior --- train_baseline_collision.py | 4 ++++ train_baseline_gravity.py | 4 ++++ train_baseline_object_permanence.py | 2 ++ train_baseline_shapeconstancy.py | 4 ++++ train_baseline_spatialTemporalContinuity.py | 4 ++++ train_svg_nonstochastic_posterior.py | 4 ++++ 6 files changed, 22 insertions(+) diff --git a/train_baseline_collision.py b/train_baseline_collision.py index e348ec9..d70aa2a 100644 --- a/train_baseline_collision.py +++ b/train_baseline_collision.py @@ -64,6 +64,7 @@ batch_size = opt.batch_size n_future = opt.n_future n_eval = opt.n_eval + data_root = opt.data_root opt = saved_model['opt'] opt.niter = niter # update number of epochs to train for opt.model_dir = model_dir @@ -71,6 +72,7 @@ opt.lr = lr opt.batch_size = batch_size opt.n_eval = n_eval + opt.data_root = data_root opt.log_dir = '%s/continued_lr%s' % (opt.log_dir, opt.lr) else: name = 'model=%s%dx%d-rnn_size=%d-predictor-posterior-rnn_layers=%d-%d-n_past=%d-n_future=%d-lr=%.4f-g_dim=%d-z_dim=%d-last_frame_skip=%d-beta=%.7f-gamma=%.7f%s' % (opt.model, opt.image_width, opt.image_width, opt.rnn_size, opt.predictor_rnn_layers, opt.posterior_rnn_layers, opt.n_past, opt.n_future, opt.lr, opt.g_dim, opt.z_dim, opt.last_frame_skip, opt.beta, opt.gamma, opt.name) @@ -108,6 +110,8 @@ opt.optimizer = optim.RMSprop elif opt.optimizer == 'sgd': opt.optimizer = optim.SGD +elif isinstance(opt.optimizer, type): + pass else: raise ValueError('Unknown optimizer: %s' % opt.optimizer) diff --git a/train_baseline_gravity.py b/train_baseline_gravity.py index a2adea4..a71c090 100644 --- a/train_baseline_gravity.py +++ b/train_baseline_gravity.py @@ -64,6 +64,7 @@ batch_size = opt.batch_size n_future = opt.n_future n_eval = opt.n_eval + data_root = opt.data_root opt = saved_model['opt'] opt.niter = niter # update number of epochs to train for opt.model_dir = model_dir @@ -71,6 +72,7 @@ opt.lr = lr opt.batch_size = batch_size opt.n_eval = n_eval + opt.data_root = data_root opt.log_dir = '%s/continued_lr%s' % (opt.log_dir, opt.lr) else: name = 'model=%s%dx%d-rnn_size=%d-predictor-posterior-rnn_layers=%d-%d-n_past=%d-n_future=%d-lr=%.4f-g_dim=%d-z_dim=%d-last_frame_skip=%d-beta=%.7f-gamma=%.7f%s' % (opt.model, opt.image_width, opt.image_width, opt.rnn_size, opt.predictor_rnn_layers, opt.posterior_rnn_layers, opt.n_past, opt.n_future, opt.lr, opt.g_dim, opt.z_dim, opt.last_frame_skip, opt.beta, opt.gamma, opt.name) @@ -107,6 +109,8 @@ opt.optimizer = optim.RMSprop elif opt.optimizer == 'sgd': opt.optimizer = optim.SGD +elif isinstance(opt.optimizer, type): + pass else: raise ValueError('Unknown optimizer: %s' % opt.optimizer) diff --git a/train_baseline_object_permanence.py b/train_baseline_object_permanence.py index 453aee2..5795e29 100644 --- a/train_baseline_object_permanence.py +++ b/train_baseline_object_permanence.py @@ -64,6 +64,7 @@ batch_size = opt.batch_size n_future = opt.n_future n_eval = opt.n_eval + data_root = opt.data_root opt = saved_model['opt'] opt.niter = niter # update number of epochs to train for opt.model_dir = model_dir @@ -71,6 +72,7 @@ opt.lr = lr opt.batch_size = batch_size opt.n_eval = n_eval + opt.data_root = data_root opt.log_dir = '%s/continued_lr%s' % (opt.log_dir, opt.lr) else: name = 'model=%s%dx%d-rnn_size=%d-predictor-posterior-rnn_layers=%d-%d-n_past=%d-n_future=%d-lr=%.4f-g_dim=%d-z_dim=%d-last_frame_skip=%d-beta=%.7f-gamma=%.7f%s' % (opt.model, opt.image_width, opt.image_width, opt.rnn_size, opt.predictor_rnn_layers, opt.posterior_rnn_layers, opt.n_past, opt.n_future, opt.lr, opt.g_dim, opt.z_dim, opt.last_frame_skip, opt.beta, opt.gamma, opt.name) diff --git a/train_baseline_shapeconstancy.py b/train_baseline_shapeconstancy.py index 8122238..8ad4480 100644 --- a/train_baseline_shapeconstancy.py +++ b/train_baseline_shapeconstancy.py @@ -64,6 +64,7 @@ batch_size = opt.batch_size n_future = opt.n_future n_eval = opt.n_eval + data_root = opt.data_root opt = saved_model['opt'] opt.niter = niter # update number of epochs to train for opt.model_dir = model_dir @@ -71,6 +72,7 @@ opt.lr = lr opt.batch_size = batch_size opt.n_eval = n_eval + opt.data_root = data_root opt.log_dir = '%s/continued_lr%s' % (opt.log_dir, opt.lr) else: name = 'model=%s%dx%d-rnn_size=%d-predictor-posterior-rnn_layers=%d-%d-n_past=%d-n_future=%d-lr=%.4f-g_dim=%d-z_dim=%d-last_frame_skip=%d-beta=%.7f-gamma=%.7f%s' % (opt.model, opt.image_width, opt.image_width, opt.rnn_size, opt.predictor_rnn_layers, opt.posterior_rnn_layers, opt.n_past, opt.n_future, opt.lr, opt.g_dim, opt.z_dim, opt.last_frame_skip, opt.beta, opt.gamma, opt.name) @@ -108,6 +110,8 @@ opt.optimizer = optim.RMSprop elif opt.optimizer == 'sgd': opt.optimizer = optim.SGD +elif isinstance(opt.optimizer, type): + pass else: raise ValueError('Unknown optimizer: %s' % opt.optimizer) diff --git a/train_baseline_spatialTemporalContinuity.py b/train_baseline_spatialTemporalContinuity.py index a04a986..2e482fa 100644 --- a/train_baseline_spatialTemporalContinuity.py +++ b/train_baseline_spatialTemporalContinuity.py @@ -64,6 +64,7 @@ batch_size = opt.batch_size n_future = opt.n_future n_eval = opt.n_eval + data_root = opt.data_root opt = saved_model['opt'] opt.niter = niter # update number of epochs to train for opt.model_dir = model_dir @@ -71,6 +72,7 @@ opt.lr = lr opt.batch_size = batch_size opt.n_eval = n_eval + opt.data_root = data_root opt.log_dir = '%s/continued_lr%s' % (opt.log_dir, opt.lr) else: name = 'model=%s%dx%d-rnn_size=%d-predictor-posterior-rnn_layers=%d-%d-n_past=%d-n_future=%d-lr=%.4f-g_dim=%d-z_dim=%d-last_frame_skip=%d-beta=%.7f-gamma=%.7f%s' % (opt.model, opt.image_width, opt.image_width, opt.rnn_size, opt.predictor_rnn_layers, opt.posterior_rnn_layers, opt.n_past, opt.n_future, opt.lr, opt.g_dim, opt.z_dim, opt.last_frame_skip, opt.beta, opt.gamma, opt.name) @@ -108,6 +110,8 @@ opt.optimizer = optim.RMSprop elif opt.optimizer == 'sgd': opt.optimizer = optim.SGD +elif isinstance(opt.optimizer, type): + pass else: raise ValueError('Unknown optimizer: %s' % opt.optimizer) diff --git a/train_svg_nonstochastic_posterior.py b/train_svg_nonstochastic_posterior.py index 5227607..5008401 100644 --- a/train_svg_nonstochastic_posterior.py +++ b/train_svg_nonstochastic_posterior.py @@ -61,6 +61,7 @@ batch_size = opt.batch_size n_future = opt.n_future n_eval = opt.n_eval + data_root = opt.data_root opt = saved_model['opt'] opt.niter = niter # update number of epochs to train for opt.model_dir = model_dir @@ -68,6 +69,7 @@ opt.lr = lr opt.batch_size = batch_size opt.n_eval = n_eval + opt.data_root = data_root opt.log_dir = '%s/continued_lr%s' % (opt.log_dir, opt.lr) else: name = 'RGBmodel=%s%dx%d-rnn_size=%d-predictor-posterior-rnn_layers=%d-%d-n_past=%d-n_future=%d-lr=%.4f-g_dim=%d-z_dim=%d-last_frame_skip=%d-beta=%.7f-gamma=%.7f%s' % (opt.model, opt.image_width, opt.image_width, opt.rnn_size, opt.predictor_rnn_layers, opt.posterior_rnn_layers, opt.n_past, opt.n_future, opt.lr, opt.g_dim, opt.z_dim, opt.last_frame_skip, opt.beta, opt.gamma, opt.name) @@ -105,6 +107,8 @@ opt.optimizer = optim.RMSprop elif opt.optimizer == 'sgd': opt.optimizer = optim.SGD +elif isinstance(opt.optimizer, type): + pass else: raise ValueError('Unknown optimizer: %s' % opt.optimizer) From 6416c5e3f9ee2a1d0e5cc8e30b9a5a4f9a552424 Mon Sep 17 00:00:00 2001 From: Fred Lu Date: Sun, 29 Aug 2021 11:42:55 -0400 Subject: [PATCH 17/26] updated msc task --- train_baseline_shapeconstancy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_baseline_shapeconstancy.py b/train_baseline_shapeconstancy.py index 8ad4480..e7d0732 100644 --- a/train_baseline_shapeconstancy.py +++ b/train_baseline_shapeconstancy.py @@ -30,7 +30,7 @@ parser.add_argument('--channels', default=1, type=int, help='number of channels for input images. ') parser.add_argument('--use_edge_kernels', default=True, type=bool, help='whether to use edge kernels to reduce to 1 channel') parser.add_argument('--dataset', default='mcs', help='dataset to train with') -parser.add_argument('--mcs_task', default='GravitySupportTraining', help='mcs task') +parser.add_argument('--mcs_task', default='ShapeConstancyTraining', help='mcs task') parser.add_argument('--n_past', type=int, default=5, help='number of frames to condition on') parser.add_argument('--n_future', type=int, default=30, help='number of frames to predict') parser.add_argument('--n_eval', type=int, default=35, help='number of frames to predict at eval time') From 42f710ca631a1e861cf557d5e5de3e7f57bdb8e1 Mon Sep 17 00:00:00 2001 From: Fred Lu Date: Sun, 5 Sep 2021 20:06:03 -0400 Subject: [PATCH 18/26] Update README.md --- README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/README.md b/README.md index 934d2e3..e2177ff 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,6 @@ +# Fred Lu's fork of "Stochastic Video Generation with a Learned Prior" +I modified Emily Denton and Rob Fergus' paper [Stochastic Video Generation with a Learned Prior](https://arxiv.org/abs/1802.07687) to build a predictive model for the [Machine Common Sense](https://github.com/NextCenturyCorporation/MCS) project that observes plausible physics events and uses the prediction it makes to detect if an unseen physics event was plausible and generate a heatmap of implausible regions if the event is deemed implausible. This repo only contains the training code. The evaluation process has additional algorithms that are not yet included here. + # Stochastic Video Generation with a Learned Prior This is code for the paper [Stochastic Video Generation with a Learned Prior](https://arxiv.org/abs/1802.07687) by Emily Denton and Rob Fergus. See the [project page](https://sites.google.com/view/svglp/) for details and generated video sequences. From 3edb491fec0f36c67aa896bd29c7fc669fffc66d Mon Sep 17 00:00:00 2001 From: Fred Lu Date: Wed, 8 Sep 2021 19:14:51 -0400 Subject: [PATCH 19/26] Gravity, collision, and continuity baselines done. Working on two remaining tasks --- _train_baseline_object_permanence.py | 448 ++++++++++++++++++++++++ _train_baseline_shapeconstancy.py | 437 +++++++++++++++++++++++ data/mcs.py | 42 ++- do_mcs_implausblility_test_posterior.py | 47 ++- train_baseline_object_permanence.py | 5 +- utils.py | 14 +- 6 files changed, 961 insertions(+), 32 deletions(-) create mode 100644 _train_baseline_object_permanence.py create mode 100644 _train_baseline_shapeconstancy.py diff --git a/_train_baseline_object_permanence.py b/_train_baseline_object_permanence.py new file mode 100644 index 0000000..5795e29 --- /dev/null +++ b/_train_baseline_object_permanence.py @@ -0,0 +1,448 @@ +import glob + +import torch +import torch.optim as optim +import torch.nn as nn +import argparse +import os +import random +from torch.autograd import Variable +from torch.utils.data import DataLoader +import utils +import itertools +import progressbar +import numpy as np +import json + +parser = argparse.ArgumentParser() +parser.add_argument('--lr', default=0.0002, type=float, help='learning rate') +parser.add_argument('--beta1', default=0.9, type=float, help='momentum term for adam') +parser.add_argument('--batch_size', default=10, type=int, help='batch size') +parser.add_argument('--log_dir', default='logs/nonstochastic_posterior', help='base directory to save logs') +parser.add_argument('--model_dir', default='', help='base directory to save logs') +parser.add_argument('--name', default='', help='identifier for directory') +parser.add_argument('--data_root', default='data', help='root directory for data') +parser.add_argument('--optimizer', default='adam', help='optimizer to train with') +parser.add_argument('--niter', type=int, default=40, help='number of epochs to train for') +parser.add_argument('--seed', default=1, type=int, help='manual seed') +parser.add_argument('--epoch_size', type=int, default=500, help='epoch size') +parser.add_argument('--image_width', type=int, default=64, help='the height / width of the input image to network') +parser.add_argument('--channels', default=1, type=int, help='number of channels for input images. ') +parser.add_argument('--use_edge_kernels', default=True, type=bool, help='whether to use edge kernels to reduce to 1 channel') +parser.add_argument('--dataset', default='mcs', help='dataset to train with') +parser.add_argument('--mcs_task', default='ObjectPermanenceTraining4', help='mcs task') +parser.add_argument('--n_past', type=int, default=5, help='number of frames to condition on') +parser.add_argument('--n_future', type=int, default=45, help='number of frames to predict') +parser.add_argument('--n_eval', type=int, default=50, help='number of frames to predict at eval time') +parser.add_argument('--start_min', type=int, default=65, help='min starting time for sampling sequence (0-indexed)') +parser.add_argument('--start_max', type=int, default=85, help='max starting time for sampling sequence (0-indexed)') +parser.add_argument('--sequence_stride', type=int, default=3, help='factor for sequence temporal subsampling (int)') +parser.add_argument('--rnn_size', type=int, default=256, help='dimensionality of hidden layer') +parser.add_argument('--prior_rnn_layers', type=int, default=1, help='number of layers') +parser.add_argument('--posterior_rnn_layers', type=int, default=1, help='number of layers') +parser.add_argument('--predictor_rnn_layers', type=int, default=2, help='number of layers') +parser.add_argument('--z_dim', type=int, default=32, help='dimensionality of z_t') +parser.add_argument('--g_dim', type=int, default=128, help='dimensionality of encoder output vector and decoder input vector') +parser.add_argument('--beta', type=float, default=0.0001, help='weighting on KL to prior') +parser.add_argument('--gamma', type=float, default=0.0001, help='weighting on h vs h posterior') +parser.add_argument('--model', default='vgg', help='model type (dcgan | vgg)') +parser.add_argument('--data_threads', type=int, default=12, help='number of data loading threads') +parser.add_argument('--num_digits', type=int, default=2, help='number of digits for moving mnist') +parser.add_argument('--last_frame_skip', action='store_true', help='if true, skip connections go between frame t and frame t+1 rather than last ground truth frame') + + +opt = parser.parse_args() +saved_model = None +if opt.model_dir != '': + models = glob.glob(f'{opt.model_dir}/model_*.pth') + latest_model = sorted(models, key=lambda s: int(s[s.rfind('_e') + 2: s.rfind('.pth')]), reverse=True)[0] + print('Loading model ', latest_model) + saved_model = torch.load(latest_model) + model_dir = opt.model_dir + niter = opt.niter + lr = opt.lr + batch_size = opt.batch_size + n_future = opt.n_future + n_eval = opt.n_eval + data_root = opt.data_root + opt = saved_model['opt'] + opt.niter = niter # update number of epochs to train for + opt.model_dir = model_dir + opt.n_future = n_future + opt.lr = lr + opt.batch_size = batch_size + opt.n_eval = n_eval + opt.data_root = data_root + opt.log_dir = '%s/continued_lr%s' % (opt.log_dir, opt.lr) +else: + name = 'model=%s%dx%d-rnn_size=%d-predictor-posterior-rnn_layers=%d-%d-n_past=%d-n_future=%d-lr=%.4f-g_dim=%d-z_dim=%d-last_frame_skip=%d-beta=%.7f-gamma=%.7f%s' % (opt.model, opt.image_width, opt.image_width, opt.rnn_size, opt.predictor_rnn_layers, opt.posterior_rnn_layers, opt.n_past, opt.n_future, opt.lr, opt.g_dim, opt.z_dim, opt.last_frame_skip, opt.beta, opt.gamma, opt.name) + if opt.dataset == 'smmnist': + opt.log_dir = '%s/%s-%d/%s' % (opt.log_dir, opt.dataset, opt.num_digits, name) + elif opt.dataset == 'mcs': + opt.log_dir = '%s/%s/%s/%s' % (opt.log_dir, opt.dataset, opt.mcs_task, name) + else: + opt.log_dir = '%s/%s/%s' % (opt.log_dir, opt.dataset, name) + +os.makedirs('%s/gen/' % opt.log_dir, exist_ok=True) +os.makedirs('%s/plots/' % opt.log_dir, exist_ok=True) +with open(os.path.join(opt.log_dir, 'opt.json'), 'w') as f: + opt2 = opt.__dict__.copy() + if isinstance(opt2['optimizer'], type): + opt2['optimizer'] = str(opt2['optimizer']) + json.dump(opt2, f, indent=2) + del opt2 + +print("Random Seed: ", opt.seed) +random.seed(opt.seed) +torch.manual_seed(opt.seed) +torch.cuda.manual_seed_all(opt.seed) +dtype = torch.cuda.FloatTensor + + +# ---------------- load the models ---------------- + +print(opt) + +# ---------------- optimizers ---------------- +if opt.optimizer == 'adam': + opt.optimizer = optim.Adam +elif opt.optimizer == 'rmsprop': + opt.optimizer = optim.RMSprop +elif opt.optimizer == 'sgd': + opt.optimizer = optim.SGD +elif isinstance(opt.optimizer, type): + pass +else: + raise ValueError('Unknown optimizer: %s' % opt.optimizer) + +import models.lstm as lstm_models +if opt.model_dir != '': + frame_predictor = saved_model['frame_predictor'] + frame_predictor.batch_size = opt.batch_size + prior = saved_model['prior'] + prior.batch_size = opt.batch_size + posterior = saved_model['posterior'] + posterior.batch_size = opt.batch_size +else: + frame_predictor = lstm_models.lstm(opt.g_dim + opt.z_dim, opt.g_dim, opt.rnn_size, opt.predictor_rnn_layers, opt.batch_size) + posterior = lstm_models.lstm(opt.g_dim, opt.z_dim, opt.rnn_size, opt.posterior_rnn_layers, opt.batch_size) + prior = lstm_models.lstm(opt.g_dim, opt.z_dim, opt.rnn_size, opt.prior_rnn_layers, opt.batch_size) + frame_predictor.apply(utils.init_weights) + posterior.apply(utils.init_weights) + prior.apply(utils.init_weights) + +if opt.model == 'dcgan': + if opt.image_width == 64: + import models.dcgan_64 as model + elif opt.image_width == 128: + import models.dcgan_128 as model +elif opt.model == 'vgg': + if opt.image_width == 64: + import models.vgg_64 as model + elif opt.image_width == 128: + import models.vgg_128 as model +else: + raise ValueError('Unknown model: %s' % opt.model) + +if opt.model_dir != '': + decoder = saved_model['decoder'] + encoder = saved_model['encoder'] +else: + encoder = model.encoder(opt.g_dim, opt.channels) + decoder = model.decoder(opt.g_dim, 1) + encoder.apply(utils.init_weights) + decoder.apply(utils.init_weights) + +frame_predictor_optimizer = opt.optimizer(frame_predictor.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) +posterior_optimizer = opt.optimizer(posterior.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) +prior_optimizer = opt.optimizer(prior.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) +encoder_optimizer = opt.optimizer(encoder.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) +decoder_optimizer = opt.optimizer(decoder.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + +# --------- loss functions ------------------------------------ +mse_criterion = nn.MSELoss() +def kl_criterion(mu, logvar): + # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) + KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) + KLD /= opt.batch_size + return KLD + + +# --------- transfer to gpu ------------------------------------ +frame_predictor.cuda() +posterior.cuda() +prior.cuda() +encoder.cuda() +decoder.cuda() +mse_criterion.cuda() + +# --------- load a dataset ------------------------------------ +train_data, test_data = utils.load_dataset(opt) + +train_loader = DataLoader(train_data, + num_workers=opt.data_threads, + batch_size=opt.batch_size, + shuffle=True, + drop_last=True, + pin_memory=True,) +test_loader = DataLoader(test_data, + num_workers=opt.data_threads, + batch_size=opt.batch_size, + shuffle=True, + drop_last=True, + pin_memory=True) + +def get_training_batch(): + while True: + for sequence in train_loader: + batch = utils.normalize_data(opt, dtype, sequence) + yield batch +training_batch_generator = get_training_batch() + +def get_testing_batch(): + while True: + for sequence in test_loader: + batch = utils.normalize_data(opt, dtype, sequence) + yield batch +testing_batch_generator = get_testing_batch() + +# --------- plotting funtions ------------------------------------ +def plot(x, epoch): + + nsample = 1 + gen_seq = [[] for _ in range(nsample)] + gt_seq = [utils.torch_rgb_img_to_gray(x[t]) for t in range(len(x))] + + # h_seq = [encoder(x[i]) for i in range(opt.n_past)] + for s in range(nsample): + frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + prior.hidden = prior.init_hidden() + gen_seq[s].append(utils.torch_rgb_img_to_gray(x[0])) + x_in = x[0] + for i in range(1, opt.n_eval): + with torch.no_grad(): + # if input is grayscale + if x_in.shape[1] == 1 and opt.channels == 3: + h = encoder(torch.cat(3*[x_in], dim=1)) # convert to RGB + else: + h = encoder(x_in) + if opt.last_frame_skip or i < opt.n_past: + h, skip = h + else: + h, _ = h + h = h.detach() + if i < opt.n_past: + h_target = encoder(x[i]) + z_t = posterior(h_target[0].detach()) + prior(h) + frame_predictor(torch.cat([h, z_t], 1)) + x_in = x[i] + gen_seq[s].append(utils.torch_rgb_img_to_gray(x_in)) + else: + z_t_hat = prior(h) + h = frame_predictor(torch.cat([h, z_t_hat], 1)).detach() + x_in = decoder([h, skip]) + gen_seq[s].append(x_in) + + + to_plot = [] + gifs = [ [] for t in range(opt.n_eval) ] + nrow = min(opt.batch_size, 25) + for i in range(nrow): + # ground truth sequence + row = [] + for t in range(opt.n_eval): + row.append(gt_seq[t][i]) + to_plot.append(row) + + for s in range(nsample): + row = [] + for t in range(opt.n_eval): + row.append(gen_seq[s][t][i]) + to_plot.append(row) + for t in range(opt.n_eval): + row = [] + row.append(gt_seq[t][i]) + for s in range(nsample): + row.append(gen_seq[s][t][i]) + gifs[t].append(row) + + fname = '%s/gen/sample_%d.png' % (opt.log_dir, epoch) + utils.save_tensors_image(fname, to_plot) + + fname = '%s/gen/sample_%d.gif' % (opt.log_dir, epoch) + utils.save_gif(fname, gifs) + + +def plot_rec(x, epoch): + gen_seq = [utils.torch_rgb_img_to_gray(x[0])] + gen_seq_post = [utils.torch_rgb_img_to_gray(x[0])] + + # prediction using posterior Z + frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + h = encoder(x[0]) + h = (h[0].detach(), h[1]) + for i in range(1, min(opt.n_eval, opt.n_past+opt.n_future)): + h_target = encoder(x[i]) + h_target = (h_target[0].detach(), h_target[1]) + if opt.last_frame_skip or i < opt.n_past: + h, skip = h + else: + h, _ = h + z_t = posterior(h_target[0]) + if i < opt.n_past: + frame_predictor(torch.cat([h, z_t], 1)) + gen_seq_post.append(utils.torch_rgb_img_to_gray(x[i])) + else: + h_pred = frame_predictor(torch.cat([h, z_t], 1)).detach() + x_pred = decoder([h_pred, skip]).detach() + gen_seq_post.append(x_pred) + h = h_target + + # prediction using prior Z + frame_predictor.hidden = frame_predictor.init_hidden() + prior.hidden = prior.init_hidden() + h = encoder(x[0]) + h = (h[0].detach(), h[1]) + for i in range(1, min(opt.n_eval, opt.n_past+opt.n_future)): + h_target = encoder(x[i]) + h_target = (h_target[0].detach(), h_target[1]) + if opt.last_frame_skip or i < opt.n_past: + h, skip = h + else: + h, _ = h + h = h.detach() + z_t_hat = prior(h) + if i < opt.n_past: + frame_predictor(torch.cat([h, z_t_hat], 1)) + gen_seq.append(utils.torch_rgb_img_to_gray(x[i])) + else: + h_pred = frame_predictor(torch.cat([h, z_t_hat], 1)).detach() + x_pred = decoder([h_pred, skip]).detach() + gen_seq.append(x_pred) + h = h_target + + to_plot = [] + nrow = min(opt.batch_size, 25) + x_gray = [utils.torch_rgb_img_to_gray(x[t]) for t in range(min(opt.n_eval, opt.n_past+opt.n_future))] + for i in range(nrow): + row_gt = [] + row_post = [] + row = [] + for t in range(min(opt.n_eval, opt.n_past+opt.n_future)): + row_gt.append(x_gray[t][i]) + row_post.append(gen_seq_post[t][i]) + row.append(gen_seq[t][i]) + to_plot.append(row_gt) + to_plot.append(row_post) + to_plot.append(row) + fname = '%s/gen/rec_%d.png' % (opt.log_dir, epoch) + utils.save_tensors_image(fname, to_plot) + + +# --------- training funtions ------------------------------------ +def train(x): + frame_predictor.zero_grad() + posterior.zero_grad() + prior.zero_grad() + encoder.zero_grad() + decoder.zero_grad() + + # initialize the hidden state. + frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + prior.hidden = prior.init_hidden() + + mse = 0 + mse_residual = 0 + # x: T x B x C x H x W + x_diff = [1] + for i in range(1, len(x)): + diff = torch.abs(x[i] - x[i - 1]) + diff = torch.mean(diff, dim=(1, 2, 3)).detach() # mean over channels, width, and height + x_diff.append(diff) + + # print(i, x_diff > 1e-6) + h = encoder(x[0]) + for i in range(1, opt.n_past+opt.n_future): + h_target = encoder(x[i]) + if opt.last_frame_skip or i < opt.n_past: + h, skip = h + else: + h = h[0] + + z_t = posterior(h_target[0]) + z_t_hat = prior(h) + h_pred = frame_predictor(torch.cat([h, z_t], 1)) + x_pred = decoder([h_pred, skip]) + gray_target_frame = utils.torch_rgb_img_to_gray(x[i]) + + still_frames = x_diff[i] <= 5e-6 + weights = torch.pow(0.05, still_frames).detach()[:, None, None, None] + mse += mse_criterion(weights * x_pred, weights * gray_target_frame) + # penalize prior for being far from posterior + mse_residual += opt.gamma * torch.mean(weights * torch.square(z_t.detach() - z_t_hat)) + h = h_target + + loss = mse + mse_residual + loss.backward() + + frame_predictor_optimizer.step() + posterior_optimizer.step() + prior_optimizer.step() + encoder_optimizer.step() + decoder_optimizer.step() + + N = opt.n_past+opt.n_future + return mse.data.cpu().numpy()/N, mse_residual.data.cpu().numpy()/N + +# --------- training loop ------------------------------------ +for epoch in range(opt.niter): + frame_predictor.train() + posterior.train() + encoder.train() + decoder.train() + epoch_mse = 0 + epoch_mse_residual = 0 + progress = progressbar.ProgressBar(max_value=opt.epoch_size).start() + # opt.epoch_size = 10 + for i in range(opt.epoch_size): + progress.update(i+1) + x = next(training_batch_generator) + + # train frame_predictor + mse, mse_residual = train(x) + epoch_mse += mse + epoch_mse_residual += mse_residual + + + progress.finish() + utils.clear_progressbar() + + print('[%02d] mse loss: %.5f | residual mse: %.20f (%d)' % (epoch, epoch_mse/opt.epoch_size, epoch_mse_residual/opt.epoch_size, epoch*opt.epoch_size*opt.batch_size)) + + # plot some stuff + frame_predictor.eval() + posterior.eval() + prior.eval() + encoder.eval() + decoder.eval() + x = next(testing_batch_generator) + plot(x, epoch) + plot_rec(x, epoch) + + # save the model + torch.save({ + 'encoder': encoder, + 'decoder': decoder, + 'frame_predictor': frame_predictor, + 'posterior': posterior, + 'prior': prior, + 'opt': opt}, + '%s/model_e%02d.pth' % (opt.log_dir, epoch)) + if epoch % 10 == 0: + print('log dir: %s' % opt.log_dir) + + diff --git a/_train_baseline_shapeconstancy.py b/_train_baseline_shapeconstancy.py new file mode 100644 index 0000000..8ad4480 --- /dev/null +++ b/_train_baseline_shapeconstancy.py @@ -0,0 +1,437 @@ +import glob + +import torch +import torch.optim as optim +import torch.nn as nn +import argparse +import os +import random +from torch.autograd import Variable +from torch.utils.data import DataLoader +import utils +import itertools +import progressbar +import numpy as np +import json + +parser = argparse.ArgumentParser() +parser.add_argument('--lr', default=0.0002, type=float, help='learning rate') +parser.add_argument('--beta1', default=0.9, type=float, help='momentum term for adam') +parser.add_argument('--batch_size', default=13, type=int, help='batch size') +parser.add_argument('--log_dir', default='logs/nonstochastic_posterior', help='base directory to save logs') +parser.add_argument('--model_dir', default='', help='base directory to save logs') +parser.add_argument('--name', default='', help='identifier for directory') +parser.add_argument('--data_root', default='data', help='root directory for data') +parser.add_argument('--optimizer', default='adam', help='optimizer to train with') +parser.add_argument('--niter', type=int, default=40, help='number of epochs to train for') +parser.add_argument('--seed', default=1, type=int, help='manual seed') +parser.add_argument('--epoch_size', type=int, default=500, help='epoch size') +parser.add_argument('--image_width', type=int, default=64, help='the height / width of the input image to network') +parser.add_argument('--channels', default=1, type=int, help='number of channels for input images. ') +parser.add_argument('--use_edge_kernels', default=True, type=bool, help='whether to use edge kernels to reduce to 1 channel') +parser.add_argument('--dataset', default='mcs', help='dataset to train with') +parser.add_argument('--mcs_task', default='GravitySupportTraining', help='mcs task') +parser.add_argument('--n_past', type=int, default=5, help='number of frames to condition on') +parser.add_argument('--n_future', type=int, default=30, help='number of frames to predict') +parser.add_argument('--n_eval', type=int, default=35, help='number of frames to predict at eval time') +parser.add_argument('--start_min', type=int, default=79, help='min starting time for sampling sequence (0-indexed)') +parser.add_argument('--start_max', type=int, default=89, help='max starting time for sampling sequence (0-indexed)') +parser.add_argument('--sequence_stride', type=int, default=2, help='factor for sequence temporal subsampling (int)') +parser.add_argument('--rnn_size', type=int, default=256, help='dimensionality of hidden layer') +parser.add_argument('--prior_rnn_layers', type=int, default=1, help='number of layers') +parser.add_argument('--posterior_rnn_layers', type=int, default=1, help='number of layers') +parser.add_argument('--predictor_rnn_layers', type=int, default=2, help='number of layers') +parser.add_argument('--z_dim', type=int, default=32, help='dimensionality of z_t') +parser.add_argument('--g_dim', type=int, default=128, help='dimensionality of encoder output vector and decoder input vector') +parser.add_argument('--beta', type=float, default=0.0001, help='weighting on KL to prior') +parser.add_argument('--gamma', type=float, default=0.0001, help='weighting on h vs h posterior') +parser.add_argument('--model', default='vgg', help='model type (dcgan | vgg)') +parser.add_argument('--data_threads', type=int, default=12, help='number of data loading threads') +parser.add_argument('--num_digits', type=int, default=2, help='number of digits for moving mnist') +parser.add_argument('--last_frame_skip', action='store_true', help='if true, skip connections go between frame t and frame t+1 rather than last ground truth frame') + + +opt = parser.parse_args() +saved_model = None +if opt.model_dir != '': + models = glob.glob(f'{opt.model_dir}/model_*.pth') + latest_model = sorted(models, key=lambda s: int(s[s.rfind('_e') + 2: s.rfind('.pth')]), reverse=True)[0] + print('Loading model ', latest_model) + saved_model = torch.load(latest_model) + model_dir = opt.model_dir + niter = opt.niter + lr = opt.lr + batch_size = opt.batch_size + n_future = opt.n_future + n_eval = opt.n_eval + data_root = opt.data_root + opt = saved_model['opt'] + opt.niter = niter # update number of epochs to train for + opt.model_dir = model_dir + opt.n_future = n_future + opt.lr = lr + opt.batch_size = batch_size + opt.n_eval = n_eval + opt.data_root = data_root + opt.log_dir = '%s/continued_lr%s' % (opt.log_dir, opt.lr) +else: + name = 'model=%s%dx%d-rnn_size=%d-predictor-posterior-rnn_layers=%d-%d-n_past=%d-n_future=%d-lr=%.4f-g_dim=%d-z_dim=%d-last_frame_skip=%d-beta=%.7f-gamma=%.7f%s' % (opt.model, opt.image_width, opt.image_width, opt.rnn_size, opt.predictor_rnn_layers, opt.posterior_rnn_layers, opt.n_past, opt.n_future, opt.lr, opt.g_dim, opt.z_dim, opt.last_frame_skip, opt.beta, opt.gamma, opt.name) + if opt.dataset == 'smmnist': + opt.log_dir = '%s/%s-%d/%s' % (opt.log_dir, opt.dataset, opt.num_digits, name) + elif opt.dataset == 'mcs': + opt.log_dir = '%s/%s/%s/%s' % (opt.log_dir, opt.dataset, opt.mcs_task, name) + else: + opt.log_dir = '%s/%s/%s' % (opt.log_dir, opt.dataset, name) + +os.makedirs('%s/gen/' % opt.log_dir, exist_ok=False) +os.makedirs('%s/plots/' % opt.log_dir, exist_ok=False) +with open(os.path.join(opt.log_dir, 'opt.json'), 'w') as f: + opt2 = opt.__dict__.copy() + if isinstance(opt2['optimizer'], type): + opt2['optimizer'] = str(opt2['optimizer']) + json.dump(opt2, f, indent=2) + del opt2 + +print("Random Seed: ", opt.seed) +random.seed(opt.seed) +torch.manual_seed(opt.seed) +torch.cuda.manual_seed_all(opt.seed) +dtype = torch.cuda.FloatTensor + + +# ---------------- load the models ---------------- + +print(opt) + +# ---------------- optimizers ---------------- +if opt.optimizer == 'adam': + opt.optimizer = optim.Adam +elif opt.optimizer == 'rmsprop': + opt.optimizer = optim.RMSprop +elif opt.optimizer == 'sgd': + opt.optimizer = optim.SGD +elif isinstance(opt.optimizer, type): + pass +else: + raise ValueError('Unknown optimizer: %s' % opt.optimizer) + +import models.lstm as lstm_models +if opt.model_dir != '': + frame_predictor = saved_model['frame_predictor'] + frame_predictor.batch_size = opt.batch_size + prior = saved_model['prior'] + prior.batch_size = opt.batch_size + posterior = saved_model['posterior'] + posterior.batch_size = opt.batch_size +else: + frame_predictor = lstm_models.lstm(opt.g_dim + opt.z_dim, opt.g_dim, opt.rnn_size, opt.predictor_rnn_layers, opt.batch_size) + posterior = lstm_models.lstm(opt.g_dim, opt.z_dim, opt.rnn_size, opt.posterior_rnn_layers, opt.batch_size) + prior = lstm_models.lstm(opt.g_dim, opt.z_dim, opt.rnn_size, opt.prior_rnn_layers, opt.batch_size) + frame_predictor.apply(utils.init_weights) + posterior.apply(utils.init_weights) + prior.apply(utils.init_weights) + +if opt.model == 'dcgan': + if opt.image_width == 64: + import models.dcgan_64 as model + elif opt.image_width == 128: + import models.dcgan_128 as model +elif opt.model == 'vgg': + if opt.image_width == 64: + import models.vgg_64 as model + elif opt.image_width == 128: + import models.vgg_128 as model +else: + raise ValueError('Unknown model: %s' % opt.model) + +if opt.model_dir != '': + decoder = saved_model['decoder'] + encoder = saved_model['encoder'] +else: + encoder = model.encoder(opt.g_dim, opt.channels) + decoder = model.decoder(opt.g_dim, 1) + encoder.apply(utils.init_weights) + decoder.apply(utils.init_weights) + +frame_predictor_optimizer = opt.optimizer(frame_predictor.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) +posterior_optimizer = opt.optimizer(posterior.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) +prior_optimizer = opt.optimizer(prior.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) +encoder_optimizer = opt.optimizer(encoder.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) +decoder_optimizer = opt.optimizer(decoder.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + +# --------- loss functions ------------------------------------ +mse_criterion = nn.MSELoss() +def kl_criterion(mu, logvar): + # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) + KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) + KLD /= opt.batch_size + return KLD + + +# --------- transfer to gpu ------------------------------------ +frame_predictor.cuda() +posterior.cuda() +prior.cuda() +encoder.cuda() +decoder.cuda() +mse_criterion.cuda() + +# --------- load a dataset ------------------------------------ +train_data, test_data = utils.load_dataset(opt) + +train_loader = DataLoader(train_data, + num_workers=opt.data_threads, + batch_size=opt.batch_size, + shuffle=True, + drop_last=True, + pin_memory=True,) +test_loader = DataLoader(test_data, + num_workers=opt.data_threads, + batch_size=opt.batch_size, + shuffle=True, + drop_last=True, + pin_memory=True) + +def get_training_batch(): + while True: + for sequence in train_loader: + batch = utils.normalize_data(opt, dtype, sequence) + yield batch +training_batch_generator = get_training_batch() + +def get_testing_batch(): + while True: + for sequence in test_loader: + batch = utils.normalize_data(opt, dtype, sequence) + yield batch +testing_batch_generator = get_testing_batch() + +# --------- plotting funtions ------------------------------------ +def plot(x, epoch): + + nsample = 1 + gen_seq = [[] for _ in range(nsample)] + gt_seq = [utils.torch_rgb_img_to_gray(x[t]) for t in range(len(x))] + + # h_seq = [encoder(x[i]) for i in range(opt.n_past)] + for s in range(nsample): + frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + prior.hidden = prior.init_hidden() + gen_seq[s].append(utils.torch_rgb_img_to_gray(x[0])) + x_in = x[0] + for i in range(1, opt.n_eval): + with torch.no_grad(): + # if input is grayscale + if x_in.shape[1] == 1 and opt.channels == 3: + h = encoder(torch.cat(3*[x_in], dim=1)) # convert to RGB + else: + h = encoder(x_in) + if opt.last_frame_skip or i < opt.n_past: + h, skip = h + else: + h, _ = h + h = h.detach() + if i < opt.n_past: + h_target = encoder(x[i]) + z_t = posterior(h_target[0].detach()) + prior(h) + frame_predictor(torch.cat([h, z_t], 1)) + x_in = x[i] + gen_seq[s].append(utils.torch_rgb_img_to_gray(x_in)) + else: + z_t_hat = prior(h) + h = frame_predictor(torch.cat([h, z_t_hat], 1)).detach() + x_in = decoder([h, skip]) + gen_seq[s].append(x_in) + + + to_plot = [] + gifs = [ [] for t in range(opt.n_eval) ] + nrow = min(opt.batch_size, 25) + for i in range(nrow): + # ground truth sequence + row = [] + for t in range(opt.n_eval): + row.append(gt_seq[t][i]) + to_plot.append(row) + + for s in range(nsample): + row = [] + for t in range(opt.n_eval): + row.append(gen_seq[s][t][i]) + to_plot.append(row) + for t in range(opt.n_eval): + row = [] + row.append(gt_seq[t][i]) + for s in range(nsample): + row.append(gen_seq[s][t][i]) + gifs[t].append(row) + + fname = '%s/gen/sample_%d.png' % (opt.log_dir, epoch) + utils.save_tensors_image(fname, to_plot) + + fname = '%s/gen/sample_%d.gif' % (opt.log_dir, epoch) + utils.save_gif(fname, gifs) + + +def plot_rec(x, epoch): + gen_seq = [utils.torch_rgb_img_to_gray(x[0])] + gen_seq_post = [utils.torch_rgb_img_to_gray(x[0])] + + # prediction using posterior Z + frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + h = encoder(x[0]) + h = (h[0].detach(), h[1]) + for i in range(1, min(opt.n_eval, opt.n_past+opt.n_future)): + h_target = encoder(x[i]) + h_target = (h_target[0].detach(), h_target[1]) + if opt.last_frame_skip or i < opt.n_past: + h, skip = h + else: + h, _ = h + z_t = posterior(h_target[0]) + if i < opt.n_past: + frame_predictor(torch.cat([h, z_t], 1)) + gen_seq_post.append(utils.torch_rgb_img_to_gray(x[i])) + else: + h_pred = frame_predictor(torch.cat([h, z_t], 1)).detach() + x_pred = decoder([h_pred, skip]).detach() + gen_seq_post.append(x_pred) + h = h_target + + # prediction using prior Z + frame_predictor.hidden = frame_predictor.init_hidden() + prior.hidden = prior.init_hidden() + h = encoder(x[0]) + h = (h[0].detach(), h[1]) + for i in range(1, min(opt.n_eval, opt.n_past+opt.n_future)): + h_target = encoder(x[i]) + h_target = (h_target[0].detach(), h_target[1]) + if opt.last_frame_skip or i < opt.n_past: + h, skip = h + else: + h, _ = h + h = h.detach() + z_t_hat = prior(h) + if i < opt.n_past: + frame_predictor(torch.cat([h, z_t_hat], 1)) + gen_seq.append(utils.torch_rgb_img_to_gray(x[i])) + else: + h_pred = frame_predictor(torch.cat([h, z_t_hat], 1)).detach() + x_pred = decoder([h_pred, skip]).detach() + gen_seq.append(x_pred) + h = h_target + + to_plot = [] + nrow = min(opt.batch_size, 25) + x_gray = [utils.torch_rgb_img_to_gray(x[t]) for t in range(min(opt.n_eval, opt.n_past+opt.n_future))] + for i in range(nrow): + row_gt = [] + row_post = [] + row = [] + for t in range(min(opt.n_eval, opt.n_past+opt.n_future)): + row_gt.append(x_gray[t][i]) + row_post.append(gen_seq_post[t][i]) + row.append(gen_seq[t][i]) + to_plot.append(row_gt) + to_plot.append(row_post) + to_plot.append(row) + fname = '%s/gen/rec_%d.png' % (opt.log_dir, epoch) + utils.save_tensors_image(fname, to_plot) + + +# --------- training funtions ------------------------------------ +def train(x): + frame_predictor.zero_grad() + posterior.zero_grad() + prior.zero_grad() + encoder.zero_grad() + decoder.zero_grad() + + # initialize the hidden state. + frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + prior.hidden = prior.init_hidden() + + mse = 0 + mse_residual = 0 + h = encoder(x[0]) + for i in range(1, opt.n_past+opt.n_future): + h_target = encoder(x[i]) + if opt.last_frame_skip or i < opt.n_past: + h, skip = h + else: + h = h[0] + + z_t = posterior(h_target[0]) + z_t_hat = prior(h) + h_pred = frame_predictor(torch.cat([h, z_t], 1)) + x_pred = decoder([h_pred, skip]) + gray_target_frame = utils.torch_rgb_img_to_gray(x[i]) + mse += mse_criterion(x_pred, gray_target_frame) + # penalize prior for being far from posterior + mse_residual += opt.gamma * torch.mean(torch.square(z_t.detach() - z_t_hat)) + h = h_target + + loss = mse + mse_residual + loss.backward() + + frame_predictor_optimizer.step() + posterior_optimizer.step() + prior_optimizer.step() + encoder_optimizer.step() + decoder_optimizer.step() + + N = opt.n_past+opt.n_future + return mse.data.cpu().numpy()/N, mse_residual.data.cpu().numpy()/N + +# --------- training loop ------------------------------------ +for epoch in range(opt.niter): + frame_predictor.train() + posterior.train() + encoder.train() + decoder.train() + epoch_mse = 0 + epoch_mse_residual = 0 + progress = progressbar.ProgressBar(max_value=opt.epoch_size).start() + # opt.epoch_size = 10 + for i in range(opt.epoch_size): + progress.update(i+1) + x = next(training_batch_generator) + + # train frame_predictor + mse, mse_residual = train(x) + epoch_mse += mse + epoch_mse_residual += mse_residual + + + progress.finish() + utils.clear_progressbar() + + print('[%02d] mse loss: %.5f | residual mse: %.20f (%d)' % (epoch, epoch_mse/opt.epoch_size, epoch_mse_residual/opt.epoch_size, epoch*opt.epoch_size*opt.batch_size)) + + # plot some stuff + frame_predictor.eval() + posterior.eval() + prior.eval() + encoder.eval() + decoder.eval() + x = next(testing_batch_generator) + plot(x, epoch) + plot_rec(x, epoch) + + # save the model + torch.save({ + 'encoder': encoder, + 'decoder': decoder, + 'frame_predictor': frame_predictor, + 'posterior': posterior, + 'prior': prior, + 'opt': opt}, + '%s/model_e%02d.pth' % (opt.log_dir, epoch)) + if epoch % 10 == 0: + print('log dir: %s' % opt.log_dir) + + diff --git a/data/mcs.py b/data/mcs.py index 3d0fe68..edb9f8c 100644 --- a/data/mcs.py +++ b/data/mcs.py @@ -14,7 +14,8 @@ class MCS(object): def __init__(self, train, data_root, seq_len=20, image_size=64, task='ALL', sequential=None, implausible=False, - test_set=False, im_channels=1, use_edge_kernels=True, labels=False, start_min=None, start_max=None, sequence_stride=None): + test_set=False, im_channels=1, use_edge_kernels=True, labels=False, start_min=None, start_max=None, sequence_stride=None, + reduce_static_frames=False): # if implausible is set to True, generates "fake" images by cutting out or repeating frames self.implausible = implausible if test_set: @@ -34,6 +35,8 @@ def __init__(self, train, data_root, seq_len=20, image_size=64, task='ALL', sequ self.start_min = start_min self.start_max = start_max self.sequence_stride = sequence_stride + self.reduce_static_frames = reduce_static_frames + self.motion_threshold = 0.001 # print('mcs.py: found tasks ', self.tasks) self.video_folder = {} @@ -81,8 +84,16 @@ def get_sequence(self, idx=None): assert start_min <= start_max start = random.randint(start_min, start_max) seq = [] + last_im = None + first_movement = None + first_static = None + second_movement = None # choose a random subsequence of frames in the selected video - for i in range(start, start + self.seq_len * stride, stride): + if not self.reduce_static_frames: + end = start + self.seq_len * stride + else: + end = num_frames + for i in range(start, end, stride): # i is 0-indexed so we need to add 1 to i fname = frame_path + f'{i + 1:04d}.png' im = imageio.imread(fname) / np.float32(255.) @@ -112,8 +123,33 @@ def get_sequence(self, idx=None): edge_map /= 3 # 3 channels edge_map /= 12 # to reduce magnitude im = edge_map[..., np.newaxis] - + if self.reduce_static_frames and last_im is not None: + motion_magnitude = np.mean(cv2.absdiff(im, last_im)) * 256 + print(motion_magnitude) + if first_movement is None and motion_magnitude > self.motion_threshold: + first_movement = i + elif first_movement is not None and motion_magnitude <= self.motion_threshold: + first_static = i + elif first_static is not None and motion_magnitude > self.motion_threshold: + second_movement = i + last_im = im seq.append(im) + if self.reduce_static_frames: + new_seq = [] + + # keep the first and last static frames, then add the frames before and after to fill the sequence + len_before = (self.seq_len - 2) // 2 + len_after = (self.seq_len - 2) - len_before + new_seq += seq[first_static - start - len_before: first_static - start + 1] + new_seq += seq[second_movement - 1 - start: second_movement - start + len_after] + + print(first_static, second_movement) + print(first_static - start - len_before, first_static - start + 1) + print(second_movement - 1 - start, second_movement - start + len_after) + print(len(new_seq), len(seq)) + for img in new_seq: + cv2.imshow('img', img[:, :, 0]) + cv2.waitKey(0) if self.labels: return np.array(seq), label else: diff --git a/do_mcs_implausblility_test_posterior.py b/do_mcs_implausblility_test_posterior.py index a4eb013..27a8ce4 100644 --- a/do_mcs_implausblility_test_posterior.py +++ b/do_mcs_implausblility_test_posterior.py @@ -31,7 +31,7 @@ parser.add_argument('--image_width', type=int, default=64, help='the height / width of the input image to network') parser.add_argument('--channels', default=1, type=int) parser.add_argument('--use_edge_kernels', action='store_true') -parser.add_argument('--dataset', default='mcs_test', help='dataset to train with') +parser.add_argument('--dataset', default='mcs', help='dataset to train with') parser.add_argument('--mcs_task', default='SpatioTemporalContinuityTraining4', help='mcs task') parser.add_argument('--n_past', type=int, default=5, help='number of frames to condition on') parser.add_argument('--n_future', type=int, default=195, help='number of frames to predict') @@ -49,7 +49,7 @@ parser.add_argument('--last_frame_skip', action='store_true', help='if true, skip connections go between frame t and frame t+t rather than last ground truth frame') opt = parser.parse_args() -BATCH_SIZE = opt.batch_size +BATCH_SIZE = 1 saved_model = None if opt.model_dir != '': models = glob.glob(f'{opt.model_dir}/model_*.pth') @@ -60,12 +60,24 @@ dataset = opt.dataset mcs_task = opt.mcs_task n_future = opt.n_future + data_root = opt.data_root opt = saved_model['opt'] opt.batch_size = BATCH_SIZE opt.niter = niter # update number of epochs to train for opt.dataset = dataset opt.mcs_task = mcs_task opt.n_future = n_future + opt.data_root = data_root + opt.start_min = 0 + opt.start_max = None + frame_predictor = saved_model['frame_predictor'].cuda() + frame_predictor.batch_size = BATCH_SIZE + posterior = saved_model['posterior'].cuda() + posterior.batch_size = BATCH_SIZE + prior = saved_model['prior'].cuda() + prior.batch_size = BATCH_SIZE + decoder = saved_model['decoder'].cuda() + encoder = saved_model['encoder'].cuda() else: raise ValueError("Please specify the model to load with the --model_dir argument") @@ -83,16 +95,6 @@ import models.lstm as lstm_models -if opt.model_dir != '': - frame_predictor = saved_model['frame_predictor'] - frame_predictor.batch_size = BATCH_SIZE - posterior = saved_model['posterior'] - posterior.batch_size = BATCH_SIZE - prior = saved_model['prior'] - prior.batch_size = BATCH_SIZE -else: - raise ValueError('Please specify --model_dir') - if opt.model == 'dcgan': if opt.image_width == 64: import models.dcgan_64 as model @@ -106,12 +108,6 @@ else: raise ValueError('Unknown model: %s' % opt.model) -if opt.model_dir != '': - decoder = saved_model['decoder'] - encoder = saved_model['encoder'] -else: - raise ValueError("Please specify the model to load with the --model_dir argument") - # --------- loss functions ------------------------------------ mse_criterion = nn.MSELoss() @@ -123,12 +119,12 @@ def kl_criterion(mu, logvar): return KLD -# --------- transfer to gpu ------------------------------------ -frame_predictor.cuda() -posterior.cuda() -encoder.cuda() -decoder.cuda() -mse_criterion.cuda() +# # --------- transfer to gpu ------------------------------------ +# frame_predictor.cuda() +# posterior.cuda() +# encoder.cuda() +# decoder.cuda() +# mse_criterion.cuda() opt.batch_size = BATCH_SIZE opt.epoch_size = 1000 @@ -317,6 +313,7 @@ def do_implasubility_test(z_residual_mean, z_residual_cov, visualize=True): last_pred = None frame_predictor.hidden = frame_predictor.init_hidden() posterior.hidden = posterior.init_hidden() + prior.hidden = prior.init_hidden() start = 1 for j in range(start, opt.n_past + opt.n_future): h_target = h_posterior[j][0].detach() @@ -377,7 +374,7 @@ def do_implasubility_test(z_residual_mean, z_residual_cov, visualize=True): if visualize: for j in range(len(frames)): frame_cv2 = frames[j][0][0].numpy() - frame_cv2 /= 3 + # frame_cv2 /= 3 frame_cv2 = np.uint8(np.minimum(frame_cv2, 1.0) * 255.) cv2.imshow('frame', frame_cv2) cv2.waitKey(15) diff --git a/train_baseline_object_permanence.py b/train_baseline_object_permanence.py index 5795e29..7af3e82 100644 --- a/train_baseline_object_permanence.py +++ b/train_baseline_object_permanence.py @@ -36,7 +36,8 @@ parser.add_argument('--n_eval', type=int, default=50, help='number of frames to predict at eval time') parser.add_argument('--start_min', type=int, default=65, help='min starting time for sampling sequence (0-indexed)') parser.add_argument('--start_max', type=int, default=85, help='max starting time for sampling sequence (0-indexed)') -parser.add_argument('--sequence_stride', type=int, default=3, help='factor for sequence temporal subsampling (int)') +parser.add_argument('--sequence_stride', type=int, default=1, help='factor for sequence temporal subsampling (int)') +parser.add_argument('--reduce_static_frames', type=bool, default=True, help='reduce number of static frames') parser.add_argument('--rnn_size', type=int, default=256, help='dimensionality of hidden layer') parser.add_argument('--prior_rnn_layers', type=int, default=1, help='number of layers') parser.add_argument('--posterior_rnn_layers', type=int, default=1, help='number of layers') @@ -75,7 +76,7 @@ opt.data_root = data_root opt.log_dir = '%s/continued_lr%s' % (opt.log_dir, opt.lr) else: - name = 'model=%s%dx%d-rnn_size=%d-predictor-posterior-rnn_layers=%d-%d-n_past=%d-n_future=%d-lr=%.4f-g_dim=%d-z_dim=%d-last_frame_skip=%d-beta=%.7f-gamma=%.7f%s' % (opt.model, opt.image_width, opt.image_width, opt.rnn_size, opt.predictor_rnn_layers, opt.posterior_rnn_layers, opt.n_past, opt.n_future, opt.lr, opt.g_dim, opt.z_dim, opt.last_frame_skip, opt.beta, opt.gamma, opt.name) + name = 'newmodel=%s%dx%d-rnn_size=%d-predictor-posterior-rnn_layers=%d-%d-n_past=%d-n_future=%d-lr=%.4f-g_dim=%d-z_dim=%d-last_frame_skip=%d-beta=%.7f-gamma=%.7f%s' % (opt.model, opt.image_width, opt.image_width, opt.rnn_size, opt.predictor_rnn_layers, opt.posterior_rnn_layers, opt.n_past, opt.n_future, opt.lr, opt.g_dim, opt.z_dim, opt.last_frame_skip, opt.beta, opt.gamma, opt.name) if opt.dataset == 'smmnist': opt.log_dir = '%s/%s-%d/%s' % (opt.log_dir, opt.dataset, opt.num_digits, name) elif opt.dataset == 'mcs': diff --git a/utils.py b/utils.py index 7454037..4a3a84e 100755 --- a/utils.py +++ b/utils.py @@ -110,7 +110,8 @@ def load_dataset(opt, sequential=None, implausible=None): use_edge_kernels=opt.use_edge_kernels, start_min=opt.start_min, start_max=opt.start_max, - sequence_stride=opt.sequence_stride) + sequence_stride=opt.sequence_stride, + reduce_static_frames=opt.reduce_static_frames) test_data = MCS( train=False, data_root=opt.data_root, @@ -123,7 +124,8 @@ def load_dataset(opt, sequential=None, implausible=None): use_edge_kernels=opt.use_edge_kernels, start_min=opt.start_min, start_max=opt.start_max, - sequence_stride=opt.sequence_stride) + sequence_stride=opt.sequence_stride, + reduce_static_frames=opt.reduce_static_frames) elif opt.dataset == 'mcs_test': from data.mcs import MCS train_data = MCS( @@ -137,6 +139,10 @@ def load_dataset(opt, sequential=None, implausible=None): test_set=True, im_channels=opt.channels, use_edge_kernels=opt.use_edge_kernels, + start_min=opt.start_min, + start_max=opt.start_max, + sequence_stride=opt.sequence_stride, + reduce_static_frames=opt.reduce_static_frames, labels=True) test_data = MCS( train=False, @@ -149,6 +155,10 @@ def load_dataset(opt, sequential=None, implausible=None): test_set=True, im_channels=opt.channels, use_edge_kernels=opt.use_edge_kernels, + start_min=opt.start_min, + start_max=opt.start_max, + sequence_stride=opt.sequence_stride, + reduce_static_frames=opt.reduce_static_frames, labels=True) return train_data, test_data From da198ecd7d21d3df5a85ca84fc24e0a893498bc9 Mon Sep 17 00:00:00 2001 From: Fred Lu Date: Thu, 16 Sep 2021 22:37:06 -0400 Subject: [PATCH 20/26] Gravity, collision, and continuity baselines done. Working on two remaining tasks --- data/mcs.py | 45 +++++++++++++++++------------ train_baseline_object_permanence.py | 19 ++++++------ utils.py | 8 +++-- 3 files changed, 42 insertions(+), 30 deletions(-) diff --git a/data/mcs.py b/data/mcs.py index edb9f8c..41863e7 100644 --- a/data/mcs.py +++ b/data/mcs.py @@ -15,7 +15,7 @@ class MCS(object): def __init__(self, train, data_root, seq_len=20, image_size=64, task='ALL', sequential=None, implausible=False, test_set=False, im_channels=1, use_edge_kernels=True, labels=False, start_min=None, start_max=None, sequence_stride=None, - reduce_static_frames=False): + reduce_static_frames=False, object_exiting_frame_offset=20, lifting_frame_index=None,): # if implausible is set to True, generates "fake" images by cutting out or repeating frames self.implausible = implausible if test_set: @@ -36,7 +36,7 @@ def __init__(self, train, data_root, seq_len=20, image_size=64, task='ALL', sequ self.start_max = start_max self.sequence_stride = sequence_stride self.reduce_static_frames = reduce_static_frames - self.motion_threshold = 0.001 + self.motion_threshold = 0.1 # print('mcs.py: found tasks ', self.tasks) self.video_folder = {} @@ -125,31 +125,40 @@ def get_sequence(self, idx=None): im = edge_map[..., np.newaxis] if self.reduce_static_frames and last_im is not None: motion_magnitude = np.mean(cv2.absdiff(im, last_im)) * 256 - print(motion_magnitude) + # print(motion_magnitude) if first_movement is None and motion_magnitude > self.motion_threshold: first_movement = i - elif first_movement is not None and motion_magnitude <= self.motion_threshold: + elif first_movement is not None and first_static is None and motion_magnitude <= self.motion_threshold: first_static = i - elif first_static is not None and motion_magnitude > self.motion_threshold: - second_movement = i + # elif first_static is not None and second_movement is None and motion_magnitude > self.motion_threshold: + # second_movement = i last_im = im seq.append(im) if self.reduce_static_frames: new_seq = [] # keep the first and last static frames, then add the frames before and after to fill the sequence - len_before = (self.seq_len - 2) // 2 - len_after = (self.seq_len - 2) - len_before - new_seq += seq[first_static - start - len_before: first_static - start + 1] - new_seq += seq[second_movement - 1 - start: second_movement - start + len_after] - - print(first_static, second_movement) - print(first_static - start - len_before, first_static - start + 1) - print(second_movement - 1 - start, second_movement - start + len_after) - print(len(new_seq), len(seq)) - for img in new_seq: - cv2.imshow('img', img[:, :, 0]) - cv2.waitKey(0) + # len_before = (self.seq_len - 1) // 2 + # len_after = (self.seq_len - 1) - len_before + # new_seq += seq[first_movement - start: first_movement - start + 5] + first_static = min(165, max(start + 7, first_static)) # so don't run into negative indices + new_seq += seq[first_static - start - 7: first_static - start + 3] + assert len(new_seq) == 10 + new_seq += seq[first_static - start + 3 + 12: first_static - start + 3 + 22] + assert len(new_seq) == 20 + new_seq += seq[200 - start: 200 - start + (self.seq_len - 20)] + assert len(new_seq) == 40 + # new_seq = seq[first_movement - start:] + # new_seq += seq[first_static - start:] + # new_seq += seq[second_movement - start:] + # print(first_movement, first_static, len(new_seq), self.seq_len) + # print(first_static - start - len_before, first_static - start + 1) + # print(second_movement - 1 - start, second_movement - start + len_after) + # print(len(new_seq), len(seq)) + # for img in new_seq: + # cv2.imshow('img', img[:, :, 0]) + # cv2.waitKey(1) + seq = new_seq if self.labels: return np.array(seq), label else: diff --git a/train_baseline_object_permanence.py b/train_baseline_object_permanence.py index 7af3e82..8cc7f8e 100644 --- a/train_baseline_object_permanence.py +++ b/train_baseline_object_permanence.py @@ -31,13 +31,14 @@ parser.add_argument('--use_edge_kernels', default=True, type=bool, help='whether to use edge kernels to reduce to 1 channel') parser.add_argument('--dataset', default='mcs', help='dataset to train with') parser.add_argument('--mcs_task', default='ObjectPermanenceTraining4', help='mcs task') -parser.add_argument('--n_past', type=int, default=5, help='number of frames to condition on') -parser.add_argument('--n_future', type=int, default=45, help='number of frames to predict') -parser.add_argument('--n_eval', type=int, default=50, help='number of frames to predict at eval time') -parser.add_argument('--start_min', type=int, default=65, help='min starting time for sampling sequence (0-indexed)') -parser.add_argument('--start_max', type=int, default=85, help='max starting time for sampling sequence (0-indexed)') +parser.add_argument('--n_past', type=int, default=20, help='number of frames to condition on') +parser.add_argument('--n_future', type=int, default=20, help='number of frames to predict') +parser.add_argument('--n_eval', type=int, default=40, help='number of frames to predict at eval time') +parser.add_argument('--start_min', type=int, default=75, help='min starting time for sampling sequence (0-indexed)') +parser.add_argument('--start_max', type=int, default=75, help='max starting time for sampling sequence (0-indexed)') parser.add_argument('--sequence_stride', type=int, default=1, help='factor for sequence temporal subsampling (int)') parser.add_argument('--reduce_static_frames', type=bool, default=True, help='reduce number of static frames') +parser.add_argument('--lifting_frame_index', type=int, default=200, help='index of frame when panels are lifted') parser.add_argument('--rnn_size', type=int, default=256, help='dimensionality of hidden layer') parser.add_argument('--prior_rnn_layers', type=int, default=1, help='number of layers') parser.add_argument('--posterior_rnn_layers', type=int, default=1, help='number of layers') @@ -47,7 +48,7 @@ parser.add_argument('--beta', type=float, default=0.0001, help='weighting on KL to prior') parser.add_argument('--gamma', type=float, default=0.0001, help='weighting on h vs h posterior') parser.add_argument('--model', default='vgg', help='model type (dcgan | vgg)') -parser.add_argument('--data_threads', type=int, default=12, help='number of data loading threads') +parser.add_argument('--data_threads', type=int, default=1, help='number of data loading threads') parser.add_argument('--num_digits', type=int, default=2, help='number of digits for moving mnist') parser.add_argument('--last_frame_skip', action='store_true', help='if true, skip connections go between frame t and frame t+1 rather than last ground truth frame') @@ -380,11 +381,9 @@ def train(x): x_pred = decoder([h_pred, skip]) gray_target_frame = utils.torch_rgb_img_to_gray(x[i]) - still_frames = x_diff[i] <= 5e-6 - weights = torch.pow(0.05, still_frames).detach()[:, None, None, None] - mse += mse_criterion(weights * x_pred, weights * gray_target_frame) + mse += mse_criterion(x_pred, gray_target_frame) # penalize prior for being far from posterior - mse_residual += opt.gamma * torch.mean(weights * torch.square(z_t.detach() - z_t_hat)) + mse_residual += opt.gamma * torch.mean(torch.square(z_t.detach() - z_t_hat)) h = h_target loss = mse + mse_residual diff --git a/utils.py b/utils.py index 4a3a84e..a3fb598 100755 --- a/utils.py +++ b/utils.py @@ -111,7 +111,8 @@ def load_dataset(opt, sequential=None, implausible=None): start_min=opt.start_min, start_max=opt.start_max, sequence_stride=opt.sequence_stride, - reduce_static_frames=opt.reduce_static_frames) + reduce_static_frames=opt.reduce_static_frames, + lifting_frame_index=opt.lifting_frame_index,) test_data = MCS( train=False, data_root=opt.data_root, @@ -125,7 +126,8 @@ def load_dataset(opt, sequential=None, implausible=None): start_min=opt.start_min, start_max=opt.start_max, sequence_stride=opt.sequence_stride, - reduce_static_frames=opt.reduce_static_frames) + reduce_static_frames=opt.reduce_static_frames, + lifting_frame_index=opt.lifting_frame_index,) elif opt.dataset == 'mcs_test': from data.mcs import MCS train_data = MCS( @@ -143,6 +145,7 @@ def load_dataset(opt, sequential=None, implausible=None): start_max=opt.start_max, sequence_stride=opt.sequence_stride, reduce_static_frames=opt.reduce_static_frames, + lifting_frame_index=opt.lifting_frame_index, labels=True) test_data = MCS( train=False, @@ -159,6 +162,7 @@ def load_dataset(opt, sequential=None, implausible=None): start_max=opt.start_max, sequence_stride=opt.sequence_stride, reduce_static_frames=opt.reduce_static_frames, + lifting_frame_index=opt.lifting_frame_index, labels=True) return train_data, test_data From fedf685b9f905814c3d6c2f2012414eb021020e5 Mon Sep 17 00:00:00 2001 From: Fred Date: Fri, 17 Sep 2021 21:11:38 -0400 Subject: [PATCH 21/26] updating obj permanence training --- _train_baseline_object_permanence.py | 3 +++ _train_baseline_shapeconstancy.py | 3 +++ train_baseline_collision.py | 3 +++ train_baseline_gravity.py | 3 +++ train_baseline_object_permanence.py | 23 +++++++-------------- train_baseline_shapeconstancy.py | 3 +++ train_baseline_spatialTemporalContinuity.py | 3 +++ 7 files changed, 26 insertions(+), 15 deletions(-) diff --git a/_train_baseline_object_permanence.py b/_train_baseline_object_permanence.py index 5795e29..e4117f0 100644 --- a/_train_baseline_object_permanence.py +++ b/_train_baseline_object_permanence.py @@ -402,6 +402,7 @@ def train(x): for epoch in range(opt.niter): frame_predictor.train() posterior.train() + prior.train() encoder.train() decoder.train() epoch_mse = 0 @@ -422,6 +423,8 @@ def train(x): utils.clear_progressbar() print('[%02d] mse loss: %.5f | residual mse: %.20f (%d)' % (epoch, epoch_mse/opt.epoch_size, epoch_mse_residual/opt.epoch_size, epoch*opt.epoch_size*opt.batch_size)) + with open(os.path.join(opt.log_dir, 'loss.txt'), 'a') as f: + f.write('[%02d] mse loss: %.5f | residual mse: %.20f (%d)' % (epoch, epoch_mse/opt.epoch_size, epoch_mse_residual/opt.epoch_size, epoch*opt.epoch_size*opt.batch_size)) # plot some stuff frame_predictor.eval() diff --git a/_train_baseline_shapeconstancy.py b/_train_baseline_shapeconstancy.py index 8ad4480..424ec0b 100644 --- a/_train_baseline_shapeconstancy.py +++ b/_train_baseline_shapeconstancy.py @@ -391,6 +391,7 @@ def train(x): for epoch in range(opt.niter): frame_predictor.train() posterior.train() + prior.train() encoder.train() decoder.train() epoch_mse = 0 @@ -411,6 +412,8 @@ def train(x): utils.clear_progressbar() print('[%02d] mse loss: %.5f | residual mse: %.20f (%d)' % (epoch, epoch_mse/opt.epoch_size, epoch_mse_residual/opt.epoch_size, epoch*opt.epoch_size*opt.batch_size)) + with open(os.path.join(opt.log_dir, 'loss.txt'), 'a') as f: + f.write('[%02d] mse loss: %.5f | residual mse: %.20f (%d)' % (epoch, epoch_mse/opt.epoch_size, epoch_mse_residual/opt.epoch_size, epoch*opt.epoch_size*opt.batch_size)) # plot some stuff frame_predictor.eval() diff --git a/train_baseline_collision.py b/train_baseline_collision.py index d70aa2a..a521217 100644 --- a/train_baseline_collision.py +++ b/train_baseline_collision.py @@ -391,6 +391,7 @@ def train(x): for epoch in range(opt.niter): frame_predictor.train() posterior.train() + prior.train() encoder.train() decoder.train() epoch_mse = 0 @@ -411,6 +412,8 @@ def train(x): utils.clear_progressbar() print('[%02d] mse loss: %.5f | residual mse: %.20f (%d)' % (epoch, epoch_mse/opt.epoch_size, epoch_mse_residual/opt.epoch_size, epoch*opt.epoch_size*opt.batch_size)) + with open(os.path.join(opt.log_dir, 'loss.txt'), 'a') as f: + f.write('[%02d] mse loss: %.5f | residual mse: %.20f (%d)' % (epoch, epoch_mse/opt.epoch_size, epoch_mse_residual/opt.epoch_size, epoch*opt.epoch_size*opt.batch_size)) # plot some stuff frame_predictor.eval() diff --git a/train_baseline_gravity.py b/train_baseline_gravity.py index a71c090..e712556 100644 --- a/train_baseline_gravity.py +++ b/train_baseline_gravity.py @@ -390,6 +390,7 @@ def train(x): for epoch in range(opt.niter): frame_predictor.train() posterior.train() + prior.train() encoder.train() decoder.train() epoch_mse = 0 @@ -410,6 +411,8 @@ def train(x): utils.clear_progressbar() print('[%02d] mse loss: %.5f | residual mse: %.20f (%d)' % (epoch, epoch_mse/opt.epoch_size, epoch_mse_residual/opt.epoch_size, epoch*opt.epoch_size*opt.batch_size)) + with open(os.path.join(opt.log_dir, 'loss.txt'), 'a') as f: + f.write('[%02d] mse loss: %.5f | residual mse: %.20f (%d)' % (epoch, epoch_mse/opt.epoch_size, epoch_mse_residual/opt.epoch_size, epoch*opt.epoch_size*opt.batch_size)) # plot some stuff frame_predictor.eval() diff --git a/train_baseline_object_permanence.py b/train_baseline_object_permanence.py index 8cc7f8e..6d6e00d 100644 --- a/train_baseline_object_permanence.py +++ b/train_baseline_object_permanence.py @@ -35,7 +35,7 @@ parser.add_argument('--n_future', type=int, default=20, help='number of frames to predict') parser.add_argument('--n_eval', type=int, default=40, help='number of frames to predict at eval time') parser.add_argument('--start_min', type=int, default=75, help='min starting time for sampling sequence (0-indexed)') -parser.add_argument('--start_max', type=int, default=75, help='max starting time for sampling sequence (0-indexed)') +parser.add_argument('--start_max', type=int, default=77, help='max starting time for sampling sequence (0-indexed)') parser.add_argument('--sequence_stride', type=int, default=1, help='factor for sequence temporal subsampling (int)') parser.add_argument('--reduce_static_frames', type=bool, default=True, help='reduce number of static frames') parser.add_argument('--lifting_frame_index', type=int, default=200, help='index of frame when panels are lifted') @@ -48,7 +48,7 @@ parser.add_argument('--beta', type=float, default=0.0001, help='weighting on KL to prior') parser.add_argument('--gamma', type=float, default=0.0001, help='weighting on h vs h posterior') parser.add_argument('--model', default='vgg', help='model type (dcgan | vgg)') -parser.add_argument('--data_threads', type=int, default=1, help='number of data loading threads') +parser.add_argument('--data_threads', type=int, default=12, help='number of data loading threads') parser.add_argument('--num_digits', type=int, default=2, help='number of digits for moving mnist') parser.add_argument('--last_frame_skip', action='store_true', help='if true, skip connections go between frame t and frame t+1 rather than last ground truth frame') @@ -77,7 +77,7 @@ opt.data_root = data_root opt.log_dir = '%s/continued_lr%s' % (opt.log_dir, opt.lr) else: - name = 'newmodel=%s%dx%d-rnn_size=%d-predictor-posterior-rnn_layers=%d-%d-n_past=%d-n_future=%d-lr=%.4f-g_dim=%d-z_dim=%d-last_frame_skip=%d-beta=%.7f-gamma=%.7f%s' % (opt.model, opt.image_width, opt.image_width, opt.rnn_size, opt.predictor_rnn_layers, opt.posterior_rnn_layers, opt.n_past, opt.n_future, opt.lr, opt.g_dim, opt.z_dim, opt.last_frame_skip, opt.beta, opt.gamma, opt.name) + name = 'model=%s%dx%d-rnn_size=%d-predictor-posterior-rnn_layers=%d-%d-n_past=%d-n_future=%d-lr=%.4f-g_dim=%d-z_dim=%d-last_frame_skip=%d-beta=%.7f-gamma=%.7f%s' % (opt.model, opt.image_width, opt.image_width, opt.rnn_size, opt.predictor_rnn_layers, opt.posterior_rnn_layers, opt.n_past, opt.n_future, opt.lr, opt.g_dim, opt.z_dim, opt.last_frame_skip, opt.beta, opt.gamma, opt.name) if opt.dataset == 'smmnist': opt.log_dir = '%s/%s-%d/%s' % (opt.log_dir, opt.dataset, opt.num_digits, name) elif opt.dataset == 'mcs': @@ -85,8 +85,8 @@ else: opt.log_dir = '%s/%s/%s' % (opt.log_dir, opt.dataset, name) -os.makedirs('%s/gen/' % opt.log_dir, exist_ok=True) -os.makedirs('%s/plots/' % opt.log_dir, exist_ok=True) +os.makedirs('%s/gen/' % opt.log_dir, exist_ok=False) +os.makedirs('%s/plots/' % opt.log_dir, exist_ok=False) with open(os.path.join(opt.log_dir, 'opt.json'), 'w') as f: opt2 = opt.__dict__.copy() if isinstance(opt2['optimizer'], type): @@ -359,14 +359,6 @@ def train(x): mse = 0 mse_residual = 0 - # x: T x B x C x H x W - x_diff = [1] - for i in range(1, len(x)): - diff = torch.abs(x[i] - x[i - 1]) - diff = torch.mean(diff, dim=(1, 2, 3)).detach() # mean over channels, width, and height - x_diff.append(diff) - - # print(i, x_diff > 1e-6) h = encoder(x[0]) for i in range(1, opt.n_past+opt.n_future): h_target = encoder(x[i]) @@ -380,7 +372,6 @@ def train(x): h_pred = frame_predictor(torch.cat([h, z_t], 1)) x_pred = decoder([h_pred, skip]) gray_target_frame = utils.torch_rgb_img_to_gray(x[i]) - mse += mse_criterion(x_pred, gray_target_frame) # penalize prior for being far from posterior mse_residual += opt.gamma * torch.mean(torch.square(z_t.detach() - z_t_hat)) @@ -402,6 +393,7 @@ def train(x): for epoch in range(opt.niter): frame_predictor.train() posterior.train() + prior.train() encoder.train() decoder.train() epoch_mse = 0 @@ -422,7 +414,8 @@ def train(x): utils.clear_progressbar() print('[%02d] mse loss: %.5f | residual mse: %.20f (%d)' % (epoch, epoch_mse/opt.epoch_size, epoch_mse_residual/opt.epoch_size, epoch*opt.epoch_size*opt.batch_size)) - + with open(os.path.join(opt.log_dir, 'loss.txt'), 'a') as f: + f.write('[%02d] mse loss: %.5f | residual mse: %.20f (%d)' % (epoch, epoch_mse/opt.epoch_size, epoch_mse_residual/opt.epoch_size, epoch*opt.epoch_size*opt.batch_size)) # plot some stuff frame_predictor.eval() posterior.eval() diff --git a/train_baseline_shapeconstancy.py b/train_baseline_shapeconstancy.py index e7d0732..e392a8f 100644 --- a/train_baseline_shapeconstancy.py +++ b/train_baseline_shapeconstancy.py @@ -391,6 +391,7 @@ def train(x): for epoch in range(opt.niter): frame_predictor.train() posterior.train() + prior.train() encoder.train() decoder.train() epoch_mse = 0 @@ -411,6 +412,8 @@ def train(x): utils.clear_progressbar() print('[%02d] mse loss: %.5f | residual mse: %.20f (%d)' % (epoch, epoch_mse/opt.epoch_size, epoch_mse_residual/opt.epoch_size, epoch*opt.epoch_size*opt.batch_size)) + with open(os.path.join(opt.log_dir, 'loss.txt'), 'a') as f: + f.write('[%02d] mse loss: %.5f | residual mse: %.20f (%d)' % (epoch, epoch_mse/opt.epoch_size, epoch_mse_residual/opt.epoch_size, epoch*opt.epoch_size*opt.batch_size)) # plot some stuff frame_predictor.eval() diff --git a/train_baseline_spatialTemporalContinuity.py b/train_baseline_spatialTemporalContinuity.py index 2e482fa..9dfcf6f 100644 --- a/train_baseline_spatialTemporalContinuity.py +++ b/train_baseline_spatialTemporalContinuity.py @@ -391,6 +391,7 @@ def train(x): for epoch in range(opt.niter): frame_predictor.train() posterior.train() + prior.train() encoder.train() decoder.train() epoch_mse = 0 @@ -411,6 +412,8 @@ def train(x): utils.clear_progressbar() print('[%02d] mse loss: %.5f | residual mse: %.20f (%d)' % (epoch, epoch_mse/opt.epoch_size, epoch_mse_residual/opt.epoch_size, epoch*opt.epoch_size*opt.batch_size)) + with open(os.path.join(opt.log_dir, 'loss.txt'), 'a') as f: + f.write('[%02d] mse loss: %.5f | residual mse: %.20f (%d)' % (epoch, epoch_mse/opt.epoch_size, epoch_mse_residual/opt.epoch_size, epoch*opt.epoch_size*opt.batch_size)) # plot some stuff frame_predictor.eval() From 1bf52e99a54f287a6bc0547b9af766a965a0a805 Mon Sep 17 00:00:00 2001 From: Fred Date: Fri, 17 Sep 2021 21:50:15 -0400 Subject: [PATCH 22/26] updating obj permanence training --- _train_baseline_object_permanence.py | 2 +- _train_baseline_shapeconstancy.py | 2 +- train_baseline_collision.py | 2 +- train_baseline_gravity.py | 2 +- train_baseline_object_permanence.py | 2 +- train_baseline_shapeconstancy.py | 2 +- train_baseline_spatialTemporalContinuity.py | 2 +- 7 files changed, 7 insertions(+), 7 deletions(-) diff --git a/_train_baseline_object_permanence.py b/_train_baseline_object_permanence.py index e4117f0..c6f4a3b 100644 --- a/_train_baseline_object_permanence.py +++ b/_train_baseline_object_permanence.py @@ -424,7 +424,7 @@ def train(x): print('[%02d] mse loss: %.5f | residual mse: %.20f (%d)' % (epoch, epoch_mse/opt.epoch_size, epoch_mse_residual/opt.epoch_size, epoch*opt.epoch_size*opt.batch_size)) with open(os.path.join(opt.log_dir, 'loss.txt'), 'a') as f: - f.write('[%02d] mse loss: %.5f | residual mse: %.20f (%d)' % (epoch, epoch_mse/opt.epoch_size, epoch_mse_residual/opt.epoch_size, epoch*opt.epoch_size*opt.batch_size)) + f.write('[%02d] mse loss: %.5f | residual mse: %.20f (%d)\n' % (epoch, epoch_mse/opt.epoch_size, epoch_mse_residual/opt.epoch_size, epoch*opt.epoch_size*opt.batch_size)) # plot some stuff frame_predictor.eval() diff --git a/_train_baseline_shapeconstancy.py b/_train_baseline_shapeconstancy.py index 424ec0b..7d20951 100644 --- a/_train_baseline_shapeconstancy.py +++ b/_train_baseline_shapeconstancy.py @@ -413,7 +413,7 @@ def train(x): print('[%02d] mse loss: %.5f | residual mse: %.20f (%d)' % (epoch, epoch_mse/opt.epoch_size, epoch_mse_residual/opt.epoch_size, epoch*opt.epoch_size*opt.batch_size)) with open(os.path.join(opt.log_dir, 'loss.txt'), 'a') as f: - f.write('[%02d] mse loss: %.5f | residual mse: %.20f (%d)' % (epoch, epoch_mse/opt.epoch_size, epoch_mse_residual/opt.epoch_size, epoch*opt.epoch_size*opt.batch_size)) + f.write('[%02d] mse loss: %.5f | residual mse: %.20f (%d)\n' % (epoch, epoch_mse/opt.epoch_size, epoch_mse_residual/opt.epoch_size, epoch*opt.epoch_size*opt.batch_size)) # plot some stuff frame_predictor.eval() diff --git a/train_baseline_collision.py b/train_baseline_collision.py index a521217..9d28b00 100644 --- a/train_baseline_collision.py +++ b/train_baseline_collision.py @@ -413,7 +413,7 @@ def train(x): print('[%02d] mse loss: %.5f | residual mse: %.20f (%d)' % (epoch, epoch_mse/opt.epoch_size, epoch_mse_residual/opt.epoch_size, epoch*opt.epoch_size*opt.batch_size)) with open(os.path.join(opt.log_dir, 'loss.txt'), 'a') as f: - f.write('[%02d] mse loss: %.5f | residual mse: %.20f (%d)' % (epoch, epoch_mse/opt.epoch_size, epoch_mse_residual/opt.epoch_size, epoch*opt.epoch_size*opt.batch_size)) + f.write('[%02d] mse loss: %.5f | residual mse: %.20f (%d)\n' % (epoch, epoch_mse/opt.epoch_size, epoch_mse_residual/opt.epoch_size, epoch*opt.epoch_size*opt.batch_size)) # plot some stuff frame_predictor.eval() diff --git a/train_baseline_gravity.py b/train_baseline_gravity.py index e712556..42eeef0 100644 --- a/train_baseline_gravity.py +++ b/train_baseline_gravity.py @@ -412,7 +412,7 @@ def train(x): print('[%02d] mse loss: %.5f | residual mse: %.20f (%d)' % (epoch, epoch_mse/opt.epoch_size, epoch_mse_residual/opt.epoch_size, epoch*opt.epoch_size*opt.batch_size)) with open(os.path.join(opt.log_dir, 'loss.txt'), 'a') as f: - f.write('[%02d] mse loss: %.5f | residual mse: %.20f (%d)' % (epoch, epoch_mse/opt.epoch_size, epoch_mse_residual/opt.epoch_size, epoch*opt.epoch_size*opt.batch_size)) + f.write('[%02d] mse loss: %.5f | residual mse: %.20f (%d)\n' % (epoch, epoch_mse/opt.epoch_size, epoch_mse_residual/opt.epoch_size, epoch*opt.epoch_size*opt.batch_size)) # plot some stuff frame_predictor.eval() diff --git a/train_baseline_object_permanence.py b/train_baseline_object_permanence.py index 6d6e00d..752f3ba 100644 --- a/train_baseline_object_permanence.py +++ b/train_baseline_object_permanence.py @@ -415,7 +415,7 @@ def train(x): print('[%02d] mse loss: %.5f | residual mse: %.20f (%d)' % (epoch, epoch_mse/opt.epoch_size, epoch_mse_residual/opt.epoch_size, epoch*opt.epoch_size*opt.batch_size)) with open(os.path.join(opt.log_dir, 'loss.txt'), 'a') as f: - f.write('[%02d] mse loss: %.5f | residual mse: %.20f (%d)' % (epoch, epoch_mse/opt.epoch_size, epoch_mse_residual/opt.epoch_size, epoch*opt.epoch_size*opt.batch_size)) + f.write('[%02d] mse loss: %.5f | residual mse: %.20f (%d)\n' % (epoch, epoch_mse/opt.epoch_size, epoch_mse_residual/opt.epoch_size, epoch*opt.epoch_size*opt.batch_size)) # plot some stuff frame_predictor.eval() posterior.eval() diff --git a/train_baseline_shapeconstancy.py b/train_baseline_shapeconstancy.py index e392a8f..96f87fb 100644 --- a/train_baseline_shapeconstancy.py +++ b/train_baseline_shapeconstancy.py @@ -413,7 +413,7 @@ def train(x): print('[%02d] mse loss: %.5f | residual mse: %.20f (%d)' % (epoch, epoch_mse/opt.epoch_size, epoch_mse_residual/opt.epoch_size, epoch*opt.epoch_size*opt.batch_size)) with open(os.path.join(opt.log_dir, 'loss.txt'), 'a') as f: - f.write('[%02d] mse loss: %.5f | residual mse: %.20f (%d)' % (epoch, epoch_mse/opt.epoch_size, epoch_mse_residual/opt.epoch_size, epoch*opt.epoch_size*opt.batch_size)) + f.write('[%02d] mse loss: %.5f | residual mse: %.20f (%d)\n' % (epoch, epoch_mse/opt.epoch_size, epoch_mse_residual/opt.epoch_size, epoch*opt.epoch_size*opt.batch_size)) # plot some stuff frame_predictor.eval() diff --git a/train_baseline_spatialTemporalContinuity.py b/train_baseline_spatialTemporalContinuity.py index 9dfcf6f..e8aa5f5 100644 --- a/train_baseline_spatialTemporalContinuity.py +++ b/train_baseline_spatialTemporalContinuity.py @@ -413,7 +413,7 @@ def train(x): print('[%02d] mse loss: %.5f | residual mse: %.20f (%d)' % (epoch, epoch_mse/opt.epoch_size, epoch_mse_residual/opt.epoch_size, epoch*opt.epoch_size*opt.batch_size)) with open(os.path.join(opt.log_dir, 'loss.txt'), 'a') as f: - f.write('[%02d] mse loss: %.5f | residual mse: %.20f (%d)' % (epoch, epoch_mse/opt.epoch_size, epoch_mse_residual/opt.epoch_size, epoch*opt.epoch_size*opt.batch_size)) + f.write('[%02d] mse loss: %.5f | residual mse: %.20f (%d)\n' % (epoch, epoch_mse/opt.epoch_size, epoch_mse_residual/opt.epoch_size, epoch*opt.epoch_size*opt.batch_size)) # plot some stuff frame_predictor.eval() From 119a892351614d898658ed6f1a3dbc22d48f2796 Mon Sep 17 00:00:00 2001 From: Fred Date: Fri, 17 Sep 2021 22:39:56 -0400 Subject: [PATCH 23/26] updating obj permanence training --- train_baseline_object_permanence.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/train_baseline_object_permanence.py b/train_baseline_object_permanence.py index 752f3ba..916c65e 100644 --- a/train_baseline_object_permanence.py +++ b/train_baseline_object_permanence.py @@ -31,8 +31,8 @@ parser.add_argument('--use_edge_kernels', default=True, type=bool, help='whether to use edge kernels to reduce to 1 channel') parser.add_argument('--dataset', default='mcs', help='dataset to train with') parser.add_argument('--mcs_task', default='ObjectPermanenceTraining4', help='mcs task') -parser.add_argument('--n_past', type=int, default=20, help='number of frames to condition on') -parser.add_argument('--n_future', type=int, default=20, help='number of frames to predict') +parser.add_argument('--n_past', type=int, default=5, help='number of frames to condition on') +parser.add_argument('--n_future', type=int, default=35, help='number of frames to predict') parser.add_argument('--n_eval', type=int, default=40, help='number of frames to predict at eval time') parser.add_argument('--start_min', type=int, default=75, help='min starting time for sampling sequence (0-indexed)') parser.add_argument('--start_max', type=int, default=77, help='max starting time for sampling sequence (0-indexed)') From d439f76cb63de13c5a5f3bc10ae980f69c33826c Mon Sep 17 00:00:00 2001 From: Fred Lu Date: Sat, 18 Sep 2021 10:49:11 -0400 Subject: [PATCH 24/26] added SGD optimizer --- train_baseline_object_permanence.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/train_baseline_object_permanence.py b/train_baseline_object_permanence.py index 916c65e..870615f 100644 --- a/train_baseline_object_permanence.py +++ b/train_baseline_object_permanence.py @@ -155,11 +155,18 @@ encoder.apply(utils.init_weights) decoder.apply(utils.init_weights) -frame_predictor_optimizer = opt.optimizer(frame_predictor.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) -posterior_optimizer = opt.optimizer(posterior.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) -prior_optimizer = opt.optimizer(prior.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) -encoder_optimizer = opt.optimizer(encoder.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) -decoder_optimizer = opt.optimizer(decoder.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) +if opt.optimizer == optim.SGD: + frame_predictor_optimizer = opt.optimizer(frame_predictor.parameters(), lr=opt.lr, momentum=opt.beta1) + posterior_optimizer = opt.optimizer(posterior.parameters(), lr=opt.lr, momentum=opt.beta1) + prior_optimizer = opt.optimizer(prior.parameters(), lr=opt.lr, momentum=opt.beta1) + encoder_optimizer = opt.optimizer(encoder.parameters(), lr=opt.lr, momentum=opt.beta1) + decoder_optimizer = opt.optimizer(decoder.parameters(), lr=opt.lr, momentum=opt.beta1) +else: + frame_predictor_optimizer = opt.optimizer(frame_predictor.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + posterior_optimizer = opt.optimizer(posterior.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + prior_optimizer = opt.optimizer(prior.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + encoder_optimizer = opt.optimizer(encoder.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + decoder_optimizer = opt.optimizer(decoder.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) # --------- loss functions ------------------------------------ mse_criterion = nn.MSELoss() From ef5398c70ccbb04fcf562dce82839e1c87488a2a Mon Sep 17 00:00:00 2001 From: Fred Lu Date: Thu, 23 Sep 2021 18:53:13 -0400 Subject: [PATCH 25/26] Updated shape constancy training --- data/mcs.py | 28 +++++++++++++++++++--------- train_baseline_object_permanence.py | 2 +- train_baseline_shapeconstancy.py | 14 ++++++++------ utils.py | 20 ++++++++++++-------- 4 files changed, 40 insertions(+), 24 deletions(-) diff --git a/data/mcs.py b/data/mcs.py index 41863e7..abf952e 100644 --- a/data/mcs.py +++ b/data/mcs.py @@ -15,7 +15,7 @@ class MCS(object): def __init__(self, train, data_root, seq_len=20, image_size=64, task='ALL', sequential=None, implausible=False, test_set=False, im_channels=1, use_edge_kernels=True, labels=False, start_min=None, start_max=None, sequence_stride=None, - reduce_static_frames=False, object_exiting_frame_offset=20, lifting_frame_index=None,): + reduce_static_frames=False, is_shape_constancy=False): # if implausible is set to True, generates "fake" images by cutting out or repeating frames self.implausible = implausible if test_set: @@ -37,6 +37,7 @@ def __init__(self, train, data_root, seq_len=20, image_size=64, task='ALL', sequ self.sequence_stride = sequence_stride self.reduce_static_frames = reduce_static_frames self.motion_threshold = 0.1 + self.is_shape_constancy = is_shape_constancy # print('mcs.py: found tasks ', self.tasks) self.video_folder = {} @@ -141,13 +142,22 @@ def get_sequence(self, idx=None): # len_before = (self.seq_len - 1) // 2 # len_after = (self.seq_len - 1) - len_before # new_seq += seq[first_movement - start: first_movement - start + 5] - first_static = min(165, max(start + 7, first_static)) # so don't run into negative indices - new_seq += seq[first_static - start - 7: first_static - start + 3] - assert len(new_seq) == 10 - new_seq += seq[first_static - start + 3 + 12: first_static - start + 3 + 22] - assert len(new_seq) == 20 - new_seq += seq[200 - start: 200 - start + (self.seq_len - 20)] - assert len(new_seq) == 40 + if self.is_shape_constancy: + first_static = min(105, max(start + 10, first_static)) # so we don't run into negative indices + new_seq += seq[first_static - start - 10: first_static - start + 3] + assert len(new_seq) == 13 + else: + first_static = min(165, max(start + 7, first_static)) # so we don't run into negative indices + new_seq += seq[first_static - start - 7: first_static - start + 3] + assert len(new_seq) == 10 + + if self.is_shape_constancy: + new_seq += seq[144 - start: 144 - start + (self.seq_len - 13)] + else: + new_seq += seq[first_static - start + 3 + 12: first_static - start + 3 + 22] + assert len(new_seq) == 20 + new_seq += seq[200 - start: 200 - start + (self.seq_len - 20)] + assert len(new_seq) == self.seq_len # new_seq = seq[first_movement - start:] # new_seq += seq[first_static - start:] # new_seq += seq[second_movement - start:] @@ -157,7 +167,7 @@ def get_sequence(self, idx=None): # print(len(new_seq), len(seq)) # for img in new_seq: # cv2.imshow('img', img[:, :, 0]) - # cv2.waitKey(1) + # cv2.waitKey(0) seq = new_seq if self.labels: return np.array(seq), label diff --git a/train_baseline_object_permanence.py b/train_baseline_object_permanence.py index 870615f..085838e 100644 --- a/train_baseline_object_permanence.py +++ b/train_baseline_object_permanence.py @@ -38,7 +38,7 @@ parser.add_argument('--start_max', type=int, default=77, help='max starting time for sampling sequence (0-indexed)') parser.add_argument('--sequence_stride', type=int, default=1, help='factor for sequence temporal subsampling (int)') parser.add_argument('--reduce_static_frames', type=bool, default=True, help='reduce number of static frames') -parser.add_argument('--lifting_frame_index', type=int, default=200, help='index of frame when panels are lifted') +# parser.add_argument('--lifting_frame_index', type=int, default=200, help='index of frame when panels are lifted') parser.add_argument('--rnn_size', type=int, default=256, help='dimensionality of hidden layer') parser.add_argument('--prior_rnn_layers', type=int, default=1, help='number of layers') parser.add_argument('--posterior_rnn_layers', type=int, default=1, help='number of layers') diff --git a/train_baseline_shapeconstancy.py b/train_baseline_shapeconstancy.py index 96f87fb..efb2638 100644 --- a/train_baseline_shapeconstancy.py +++ b/train_baseline_shapeconstancy.py @@ -32,11 +32,13 @@ parser.add_argument('--dataset', default='mcs', help='dataset to train with') parser.add_argument('--mcs_task', default='ShapeConstancyTraining', help='mcs task') parser.add_argument('--n_past', type=int, default=5, help='number of frames to condition on') -parser.add_argument('--n_future', type=int, default=30, help='number of frames to predict') -parser.add_argument('--n_eval', type=int, default=35, help='number of frames to predict at eval time') -parser.add_argument('--start_min', type=int, default=79, help='min starting time for sampling sequence (0-indexed)') -parser.add_argument('--start_max', type=int, default=89, help='max starting time for sampling sequence (0-indexed)') -parser.add_argument('--sequence_stride', type=int, default=2, help='factor for sequence temporal subsampling (int)') +parser.add_argument('--n_future', type=int, default=21, help='number of frames to predict') +parser.add_argument('--n_eval', type=int, default=26, help='number of frames to predict at eval time') +parser.add_argument('--start_min', type=int, default=70, help='min starting time for sampling sequence (0-indexed)') +parser.add_argument('--start_max', type=int, default=75, help='max starting time for sampling sequence (0-indexed)') +parser.add_argument('--sequence_stride', type=int, default=1, help='factor for sequence temporal subsampling (int)') +parser.add_argument('--reduce_static_frames', type=bool, default=True, help='reduce number of static frames') +parser.add_argument('--is_shape_constancy', type=bool, default=True, help='flag indicating using shape constancy dataset') parser.add_argument('--rnn_size', type=int, default=256, help='dimensionality of hidden layer') parser.add_argument('--prior_rnn_layers', type=int, default=1, help='number of layers') parser.add_argument('--posterior_rnn_layers', type=int, default=1, help='number of layers') @@ -46,7 +48,7 @@ parser.add_argument('--beta', type=float, default=0.0001, help='weighting on KL to prior') parser.add_argument('--gamma', type=float, default=0.0001, help='weighting on h vs h posterior') parser.add_argument('--model', default='vgg', help='model type (dcgan | vgg)') -parser.add_argument('--data_threads', type=int, default=12, help='number of data loading threads') +parser.add_argument('--data_threads', type=int, default=1, help='number of data loading threads') parser.add_argument('--num_digits', type=int, default=2, help='number of digits for moving mnist') parser.add_argument('--last_frame_skip', action='store_true', help='if true, skip connections go between frame t and frame t+1 rather than last ground truth frame') diff --git a/utils.py b/utils.py index a3fb598..6050384 100755 --- a/utils.py +++ b/utils.py @@ -98,6 +98,8 @@ def load_dataset(opt, sequential=None, implausible=None): image_size=opt.image_width) elif opt.dataset == 'mcs': from data.mcs import MCS + is_shape_constancy = opt.is_shape_constancy if "is_shape_constancy" in opt.__dict__ else False + reduce_static_frames = opt.reduce_static_frames if "reduce_static_frames" in opt.__dict__ else False train_data = MCS( train=True, data_root=opt.data_root, @@ -111,8 +113,8 @@ def load_dataset(opt, sequential=None, implausible=None): start_min=opt.start_min, start_max=opt.start_max, sequence_stride=opt.sequence_stride, - reduce_static_frames=opt.reduce_static_frames, - lifting_frame_index=opt.lifting_frame_index,) + reduce_static_frames=reduce_static_frames, + is_shape_constancy=is_shape_constancy,) test_data = MCS( train=False, data_root=opt.data_root, @@ -126,9 +128,11 @@ def load_dataset(opt, sequential=None, implausible=None): start_min=opt.start_min, start_max=opt.start_max, sequence_stride=opt.sequence_stride, - reduce_static_frames=opt.reduce_static_frames, - lifting_frame_index=opt.lifting_frame_index,) + reduce_static_frames=reduce_static_frames, + is_shape_constancy=is_shape_constancy,) elif opt.dataset == 'mcs_test': + is_shape_constancy = opt.is_shape_constancy if "is_shape_constancy" in opt.__dict__ else False + reduce_static_frames = opt.reduce_static_frames if "reduce_static_frames" in opt.__dict__ else False from data.mcs import MCS train_data = MCS( train=True, @@ -144,8 +148,8 @@ def load_dataset(opt, sequential=None, implausible=None): start_min=opt.start_min, start_max=opt.start_max, sequence_stride=opt.sequence_stride, - reduce_static_frames=opt.reduce_static_frames, - lifting_frame_index=opt.lifting_frame_index, + reduce_static_frames=reduce_static_frames, + is_shape_constancy=is_shape_constancy, labels=True) test_data = MCS( train=False, @@ -161,8 +165,8 @@ def load_dataset(opt, sequential=None, implausible=None): start_min=opt.start_min, start_max=opt.start_max, sequence_stride=opt.sequence_stride, - reduce_static_frames=opt.reduce_static_frames, - lifting_frame_index=opt.lifting_frame_index, + reduce_static_frames=reduce_static_frames, + is_shape_constancy=is_shape_constancy, labels=True) return train_data, test_data From 5a36e57c359ba2238efc8e596055934503376331 Mon Sep 17 00:00:00 2001 From: Fred Lu Date: Thu, 23 Sep 2021 19:07:32 -0400 Subject: [PATCH 26/26] Updated shape constancy training --- train_baseline_collision.py | 17 ++++++++++++----- train_baseline_gravity.py | 17 ++++++++++++----- train_baseline_shapeconstancy.py | 19 +++++++++++++------ train_baseline_spatialTemporalContinuity.py | 17 ++++++++++++----- 4 files changed, 49 insertions(+), 21 deletions(-) diff --git a/train_baseline_collision.py b/train_baseline_collision.py index 9d28b00..1c900b2 100644 --- a/train_baseline_collision.py +++ b/train_baseline_collision.py @@ -153,11 +153,18 @@ encoder.apply(utils.init_weights) decoder.apply(utils.init_weights) -frame_predictor_optimizer = opt.optimizer(frame_predictor.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) -posterior_optimizer = opt.optimizer(posterior.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) -prior_optimizer = opt.optimizer(prior.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) -encoder_optimizer = opt.optimizer(encoder.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) -decoder_optimizer = opt.optimizer(decoder.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) +if opt.optimizer == optim.SGD: + frame_predictor_optimizer = opt.optimizer(frame_predictor.parameters(), lr=opt.lr, momentum=opt.beta1) + posterior_optimizer = opt.optimizer(posterior.parameters(), lr=opt.lr, momentum=opt.beta1) + prior_optimizer = opt.optimizer(prior.parameters(), lr=opt.lr, momentum=opt.beta1) + encoder_optimizer = opt.optimizer(encoder.parameters(), lr=opt.lr, momentum=opt.beta1) + decoder_optimizer = opt.optimizer(decoder.parameters(), lr=opt.lr, momentum=opt.beta1) +else: + frame_predictor_optimizer = opt.optimizer(frame_predictor.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + posterior_optimizer = opt.optimizer(posterior.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + prior_optimizer = opt.optimizer(prior.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + encoder_optimizer = opt.optimizer(encoder.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + decoder_optimizer = opt.optimizer(decoder.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) # --------- loss functions ------------------------------------ mse_criterion = nn.MSELoss() diff --git a/train_baseline_gravity.py b/train_baseline_gravity.py index 42eeef0..f237590 100644 --- a/train_baseline_gravity.py +++ b/train_baseline_gravity.py @@ -152,11 +152,18 @@ encoder.apply(utils.init_weights) decoder.apply(utils.init_weights) -frame_predictor_optimizer = opt.optimizer(frame_predictor.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) -posterior_optimizer = opt.optimizer(posterior.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) -prior_optimizer = opt.optimizer(prior.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) -encoder_optimizer = opt.optimizer(encoder.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) -decoder_optimizer = opt.optimizer(decoder.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) +if opt.optimizer == optim.SGD: + frame_predictor_optimizer = opt.optimizer(frame_predictor.parameters(), lr=opt.lr, momentum=opt.beta1) + posterior_optimizer = opt.optimizer(posterior.parameters(), lr=opt.lr, momentum=opt.beta1) + prior_optimizer = opt.optimizer(prior.parameters(), lr=opt.lr, momentum=opt.beta1) + encoder_optimizer = opt.optimizer(encoder.parameters(), lr=opt.lr, momentum=opt.beta1) + decoder_optimizer = opt.optimizer(decoder.parameters(), lr=opt.lr, momentum=opt.beta1) +else: + frame_predictor_optimizer = opt.optimizer(frame_predictor.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + posterior_optimizer = opt.optimizer(posterior.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + prior_optimizer = opt.optimizer(prior.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + encoder_optimizer = opt.optimizer(encoder.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + decoder_optimizer = opt.optimizer(decoder.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) # --------- loss functions ------------------------------------ mse_criterion = nn.MSELoss() diff --git a/train_baseline_shapeconstancy.py b/train_baseline_shapeconstancy.py index efb2638..4e7d44f 100644 --- a/train_baseline_shapeconstancy.py +++ b/train_baseline_shapeconstancy.py @@ -48,7 +48,7 @@ parser.add_argument('--beta', type=float, default=0.0001, help='weighting on KL to prior') parser.add_argument('--gamma', type=float, default=0.0001, help='weighting on h vs h posterior') parser.add_argument('--model', default='vgg', help='model type (dcgan | vgg)') -parser.add_argument('--data_threads', type=int, default=1, help='number of data loading threads') +parser.add_argument('--data_threads', type=int, default=12, help='number of data loading threads') parser.add_argument('--num_digits', type=int, default=2, help='number of digits for moving mnist') parser.add_argument('--last_frame_skip', action='store_true', help='if true, skip connections go between frame t and frame t+1 rather than last ground truth frame') @@ -155,11 +155,18 @@ encoder.apply(utils.init_weights) decoder.apply(utils.init_weights) -frame_predictor_optimizer = opt.optimizer(frame_predictor.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) -posterior_optimizer = opt.optimizer(posterior.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) -prior_optimizer = opt.optimizer(prior.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) -encoder_optimizer = opt.optimizer(encoder.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) -decoder_optimizer = opt.optimizer(decoder.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) +if opt.optimizer == optim.SGD: + frame_predictor_optimizer = opt.optimizer(frame_predictor.parameters(), lr=opt.lr, momentum=opt.beta1) + posterior_optimizer = opt.optimizer(posterior.parameters(), lr=opt.lr, momentum=opt.beta1) + prior_optimizer = opt.optimizer(prior.parameters(), lr=opt.lr, momentum=opt.beta1) + encoder_optimizer = opt.optimizer(encoder.parameters(), lr=opt.lr, momentum=opt.beta1) + decoder_optimizer = opt.optimizer(decoder.parameters(), lr=opt.lr, momentum=opt.beta1) +else: + frame_predictor_optimizer = opt.optimizer(frame_predictor.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + posterior_optimizer = opt.optimizer(posterior.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + prior_optimizer = opt.optimizer(prior.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + encoder_optimizer = opt.optimizer(encoder.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + decoder_optimizer = opt.optimizer(decoder.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) # --------- loss functions ------------------------------------ mse_criterion = nn.MSELoss() diff --git a/train_baseline_spatialTemporalContinuity.py b/train_baseline_spatialTemporalContinuity.py index e8aa5f5..26af3b9 100644 --- a/train_baseline_spatialTemporalContinuity.py +++ b/train_baseline_spatialTemporalContinuity.py @@ -153,11 +153,18 @@ encoder.apply(utils.init_weights) decoder.apply(utils.init_weights) -frame_predictor_optimizer = opt.optimizer(frame_predictor.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) -posterior_optimizer = opt.optimizer(posterior.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) -prior_optimizer = opt.optimizer(prior.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) -encoder_optimizer = opt.optimizer(encoder.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) -decoder_optimizer = opt.optimizer(decoder.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) +if opt.optimizer == optim.SGD: + frame_predictor_optimizer = opt.optimizer(frame_predictor.parameters(), lr=opt.lr, momentum=opt.beta1) + posterior_optimizer = opt.optimizer(posterior.parameters(), lr=opt.lr, momentum=opt.beta1) + prior_optimizer = opt.optimizer(prior.parameters(), lr=opt.lr, momentum=opt.beta1) + encoder_optimizer = opt.optimizer(encoder.parameters(), lr=opt.lr, momentum=opt.beta1) + decoder_optimizer = opt.optimizer(decoder.parameters(), lr=opt.lr, momentum=opt.beta1) +else: + frame_predictor_optimizer = opt.optimizer(frame_predictor.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + posterior_optimizer = opt.optimizer(posterior.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + prior_optimizer = opt.optimizer(prior.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + encoder_optimizer = opt.optimizer(encoder.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + decoder_optimizer = opt.optimizer(decoder.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) # --------- loss functions ------------------------------------ mse_criterion = nn.MSELoss()