diff --git a/data/moving_mnist.py b/data/moving_mnist.py index 9cad5fa..c206034 100644 --- a/data/moving_mnist.py +++ b/data/moving_mnist.py @@ -22,7 +22,7 @@ def __init__(self, train, data_root, seq_len=20, num_digits=2, image_size=64, de train=train, download=True, transform=transforms.Compose( - [transforms.Scale(self.digit_size), + [transforms.Resize(self.digit_size), transforms.ToTensor()])) self.N = len(self.data) diff --git a/utils.py b/utils.py index 1d74bbc..cd853e0 100755 --- a/utils.py +++ b/utils.py @@ -148,7 +148,8 @@ def make_image(tensor): if tensor.size(0) == 1: tensor = tensor.expand(3, tensor.size(1), tensor.size(2)) # pdb.set_trace() - return scipy.misc.toimage(tensor.numpy(), + tensor = tensor.detach().numpy() + return scipy.misc.toimage(tensor, high=255*tensor.max(), channel_axis=0)