diff --git a/ML/Pytorch/GANs/4. WGAN-GP/train.py b/ML/Pytorch/GANs/4. WGAN-GP/train.py index 1b23c77c..f1fe4f5c 100644 --- a/ML/Pytorch/GANs/4. WGAN-GP/train.py +++ b/ML/Pytorch/GANs/4. WGAN-GP/train.py @@ -33,7 +33,7 @@ transforms = transforms.Compose( [ - transforms.Resize(IMAGE_SIZE), + transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)), transforms.ToTensor(), transforms.Normalize( [0.5 for _ in range(CHANNELS_IMG)], [0.5 for _ in range(CHANNELS_IMG)]