From 145f5c11ff4e09f593bd2aa3dda84092371bfe33 Mon Sep 17 00:00:00 2001 From: pascalempucl Date: Thu, 1 Sep 2022 13:04:22 +0100 Subject: [PATCH 1/7] Initial commit --- tensorflow_mri/python/models/cyclegan.py | 360 ++++++++++++++++++ tensorflow_mri/python/models/cyclegan_test.py | 0 2 files changed, 360 insertions(+) create mode 100644 tensorflow_mri/python/models/cyclegan.py create mode 100644 tensorflow_mri/python/models/cyclegan_test.py diff --git a/tensorflow_mri/python/models/cyclegan.py b/tensorflow_mri/python/models/cyclegan.py new file mode 100644 index 00000000..f0d483c2 --- /dev/null +++ b/tensorflow_mri/python/models/cyclegan.py @@ -0,0 +1,360 @@ +# Copyright 2021 University College London. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
# ==============================================================================

"""CycleGAN layers and composite generator/discriminator model implementation."""

import tensorflow as tf

from keras.initializers import RandomNormal
from keras.layers import Activation
from keras.layers import Concatenate
from keras.layers import Conv2D
from keras.layers import Conv2DTranspose
from keras.layers import Layer
from keras.layers import LeakyReLU
from keras.losses import MeanAbsoluteError
from keras.losses import MeanSquaredError
from keras.models import Model
from keras.optimizers import Adam

from tensorflow_addons.layers import InstanceNormalization


class CGResNetBlock(Layer):
  """ResNet block layer specific to the CycleGAN generator architecture.

  Two 3x3 convolutions with instance normalization; the block input is
  concatenated onto the output as the skip connection.
  """
  def __init__(self, n_filters):
    """
    Args:
      n_filters: An `int`, the number of filters in the first convolution.
        The second convolution is fixed at 256 filters.
    """
    super().__init__()
    self.init = RandomNormal(stddev=0.02)
    self.conv2a = Conv2D(n_filters, (3, 3), padding="same",
                         kernel_initializer=self.init)
    self.conv2b = Conv2D(256, (3, 3), padding="same",
                         kernel_initializer=self.init)
    self.instance_norm1 = InstanceNormalization(axis=-1)
    self.instance_norm2 = InstanceNormalization(axis=-1)
    self.relu = Activation("relu")
    self.concat = Concatenate()

  def call(self, inputs):
    x = self.conv2a(inputs)
    x = self.instance_norm1(x)
    x = self.relu(x)
    x = self.conv2b(x)
    x = self.instance_norm2(x)
    # Skip connection: concatenate the block output with its input.
    return self.concat([x, inputs])


class CGEncoder(Layer):
  """Encoder layer specific to the CycleGAN generator architecture.

  A 7x7 stem followed by two stride-2 downsampling convolutions
  (64 -> 128 -> 256 filters), each with instance norm and ReLU.
  """
  def __init__(self):
    super().__init__()
    self.init = RandomNormal(stddev=0.02)
    self.conv2a = Conv2D(64, (7, 7), padding='same',
                         kernel_initializer=self.init)
    self.conv2b = Conv2D(128, (3, 3), strides=(2, 2), padding='same',
                         kernel_initializer=self.init)
    self.conv2c = Conv2D(256, (3, 3), strides=(2, 2), padding='same',
                         kernel_initializer=self.init)
    self.instance_norm1 = InstanceNormalization(axis=-1)
    self.instance_norm2 = InstanceNormalization(axis=-1)
    self.instance_norm3 = InstanceNormalization(axis=-1)
    self.relu = Activation('relu')

  def call(self, inputs):
    x = self.conv2a(inputs)
    x = self.instance_norm1(x)
    x = self.relu(x)

    x = self.conv2b(x)
    x = self.instance_norm2(x)
    x = self.relu(x)

    x = self.conv2c(x)
    x = self.instance_norm3(x)
    return self.relu(x)


class CGDecoder(Layer):
  """Decoder layer specific to the CycleGAN generator architecture.

  Two stride-2 transposed convolutions (128 -> 64 filters) followed by a
  7x7 convolution down to a single channel with `tanh` output.
  """
  def __init__(self):
    super().__init__()
    self.init = RandomNormal(stddev=0.02)
    self.conv2ta = Conv2DTranspose(128, (3, 3), strides=(2, 2), padding='same',
                                   kernel_initializer=self.init)
    self.conv2tb = Conv2DTranspose(64, (3, 3), strides=(2, 2), padding='same',
                                   kernel_initializer=self.init)
    self.conv2c = Conv2D(1, (7, 7), padding='same',
                         kernel_initializer=self.init)
    self.instance_norm1 = InstanceNormalization(axis=-1)
    self.instance_norm2 = InstanceNormalization(axis=-1)
    self.instance_norm3 = InstanceNormalization(axis=-1)
    self.tanh = Activation('tanh')
    self.relu = Activation('relu')

  def call(self, inputs):
    x = self.conv2ta(inputs)
    x = self.instance_norm1(x)
    x = self.relu(x)

    x = self.conv2tb(x)
    x = self.instance_norm2(x)
    x = self.relu(x)

    x = self.conv2c(x)
    x = self.instance_norm3(x)
    # tanh maps the single-channel output into [-1, 1].
    return self.tanh(x)


class CGConvEncoder(Layer):
  """Convolutional encoder layer specific to the CycleGAN discriminator
  architecture (PatchGAN-style stack of 4x4 convolutions)."""
  def __init__(self):
    super().__init__()
    self.init = RandomNormal(stddev=0.02)

    self.conv2a = Conv2D(64, (4, 4), strides=(3, 3), padding='same',
                         kernel_initializer=self.init)
    self.conv2b = Conv2D(128, (4, 4), strides=(3, 3), padding='same',
                         kernel_initializer=self.init)
    self.conv2c = Conv2D(256, (4, 4), strides=(3, 3), padding='same',
                         kernel_initializer=self.init)
    self.conv2d = Conv2D(512, (4, 4), strides=(2, 2), padding='same',
                         kernel_initializer=self.init)
    self.conv2e = Conv2D(512, (4, 4), padding='same',
                         kernel_initializer=self.init)
    self.conv2f = Conv2D(1, (4, 4), padding='same',
                         kernel_initializer=self.init)

    self.instance_norm1 = InstanceNormalization(axis=-1)
    self.instance_norm2 = InstanceNormalization(axis=-1)
    self.instance_norm3 = InstanceNormalization(axis=-1)
    self.instance_norm4 = InstanceNormalization(axis=-1)

    self.leakyrelu = LeakyReLU(alpha=0.2)
    # NOTE: the original also stored an unused `sigmoid` activation; the
    # final (least-squares GAN) prediction is deliberately unbounded, so
    # it has been removed.

  def call(self, inputs):
    """
    Args:
      inputs: A `tf.Tensor` of shape `[..., n]` and same dtype as `self`.

    Returns:
      A `tf.Tensor` of shape `[..., 1]` and same dtype as `self`.
    """
    x = self.conv2a(inputs)
    x = self.leakyrelu(x)

    x = self.conv2b(x)
    x = self.instance_norm1(x)
    x = self.leakyrelu(x)

    x = self.conv2c(x)
    x = self.instance_norm2(x)
    x = self.leakyrelu(x)

    x = self.conv2d(x)
    x = self.instance_norm3(x)
    x = self.leakyrelu(x)

    x = self.conv2e(x)
    x = self.instance_norm4(x)
    x = self.leakyrelu(x)

    return self.conv2f(x)


class CGGenerator(Model):
  """CycleGAN generator model: encoder -> N ResNet blocks -> decoder."""
  def __init__(self, image_shape, n_resnet=None):
    """
    Args:
      image_shape: Shape of the input images; index 1 is compared against
        256 to choose the default ResNet depth.
      n_resnet: Optional `int`, the number of ResNet blocks. If `None`,
        defaults to 6 for images smaller than 256 pixels and 9 otherwise.
    """
    super().__init__()
    if n_resnet is None:
      # If the image size is less than 256x256, only use 6 resnet blocks.
      # (The original left this branch as `pass`, so `n_resnet` stayed
      # `None` and `range(0, None)` raised a TypeError.)
      n_resnet = 6 if image_shape[1] < 256 else 9
    self.n_resnet = n_resnet
    self.encoder = CGEncoder()
    self.decoder = CGDecoder()
    # Create the blocks once here. The original appended fresh (untracked,
    # untrained) blocks inside `call`, growing the list on every forward
    # pass.
    self.res_blocks = [CGResNetBlock(256) for _ in range(self.n_resnet)]

  def call(self, inputs):
    x = self.encoder(inputs)
    for block in self.res_blocks:
      x = block(x)
    return self.decoder(x)


class CGDiscriminator(Model):
  """CycleGAN discriminator model (thin wrapper around `CGConvEncoder`)."""
  def __init__(self):
    super().__init__()
    self.conv_encoder = CGConvEncoder()

  def compile(self, optimizer=None, loss_weights=(0.5,), **kwargs):
    """Compiles the discriminator.

    Args:
      optimizer: A Keras optimizer. Defaults to `Adam(learning_rate=0.0002,
        beta_1=0.5)`. Built lazily so a single default instance is not
        shared by every discriminator.
      loss_weights: Loss weights forwarded to `Model.compile`.
      **kwargs: Forwarded to `Model.compile`.
    """
    if optimizer is None:
      optimizer = Adam(learning_rate=0.0002, beta_1=0.5)
    # The original had `losds_weights=loss_weights **kwargs`: a misspelled
    # keyword and a missing comma, which raised a TypeError.
    super().compile(optimizer=optimizer, loss_weights=loss_weights, **kwargs)

  def call(self, inputs):
    return self.conv_encoder(inputs)


class CycleGAN(Model):
  """Composite CycleGAN model used to concurrently train both generators
  and discriminators with a custom training loop.

  Adapted from: https://keras.io/examples/generative/cyclegan/

  By default, the image size for both generators and discriminators is
  128x128, but any square image dimension can be used, although > 256x256
  and < 64x64 is not recommended.
  """

  def __init__(
      self,
      generator_G=None,
      generator_F=None,
      discriminator_X=None,
      discriminator_Y=None,
      lambda_cycle=10.0,
      lambda_identity=1.0,
  ):
    """
    Args:
      generator_G: Aliased -> clean generator. Defaults to a fresh
        `CGGenerator(image_shape=(1, 128, 128))`.
      generator_F: Clean -> aliased generator. Defaults likewise.
      discriminator_X: Discriminator for domain X. Defaults to a fresh
        `CGDiscriminator()`.
      discriminator_Y: Discriminator for domain Y. Defaults likewise.
      lambda_cycle: A `float`, weight of the cycle-consistency loss.
      lambda_identity: A `float`, extra weight of the identity loss
        (multiplied with `lambda_cycle`).
    """
    super().__init__()
    # Defaults are built here, not in the signature: default arguments are
    # evaluated once at definition time, so every `CycleGAN()` would have
    # shared the same generator/discriminator instances.
    self.gen_G = generator_G if generator_G is not None else (
        CGGenerator(image_shape=(1, 128, 128)))
    self.gen_F = generator_F if generator_F is not None else (
        CGGenerator(image_shape=(1, 128, 128)))
    self.disc_X = discriminator_X if discriminator_X is not None else (
        CGDiscriminator())
    self.disc_Y = discriminator_Y if discriminator_Y is not None else (
        CGDiscriminator())
    self.lambda_cycle = lambda_cycle
    self.lambda_identity = lambda_identity

    # Loss function for evaluating adversarial loss (least-squares GAN).
    self.adv_loss_fn = MeanSquaredError()

  def generator_loss_fn(self, fake):
    """Least-squares adversarial loss for a generator."""
    return self.adv_loss_fn(tf.ones_like(fake), fake)

  def discriminator_loss_fn(self, real, fake):
    """Least-squares adversarial loss for a discriminator."""
    real_loss = self.adv_loss_fn(tf.ones_like(real), real)
    fake_loss = self.adv_loss_fn(tf.zeros_like(fake), fake)
    return (real_loss + fake_loss) * 0.5

  def compile(
      self,
      gen_G_optimizer=None,
      gen_F_optimizer=None,
      disc_X_optimizer=None,
      disc_Y_optimizer=None,
      gen_loss_fn=None,
      disc_loss_fn=None,
  ):
    """Configures the four optimizers and the loss functions.

    Each optimizer defaults to a fresh `Adam(learning_rate=0.0002,
    beta_1=0.5)`; they must be distinct instances because each keeps its
    own slot variables.
    """
    super().compile()
    def _default_opt():
      return Adam(learning_rate=0.0002, beta_1=0.5)
    self.gen_G_optimizer = gen_G_optimizer or _default_opt()
    self.gen_F_optimizer = gen_F_optimizer or _default_opt()
    self.disc_X_optimizer = disc_X_optimizer or _default_opt()
    self.disc_Y_optimizer = disc_Y_optimizer or _default_opt()
    # The original defaulted these keywords to the *unbound* functions,
    # which forced `train_step` to pass `self` explicitly. The bound
    # methods are now used unless a custom callable is supplied
    # (signatures: fn(fake) and fn(real, fake)).
    if gen_loss_fn is not None:
      self.generator_loss_fn = gen_loss_fn
    if disc_loss_fn is not None:
      self.discriminator_loss_fn = disc_loss_fn
    self.cycle_loss_fn = MeanAbsoluteError()
    self.identity_loss_fn = MeanAbsoluteError()

  def train_step(self, batch_data):
    """Runs one optimization step for both generators and discriminators.

    Args:
      batch_data: A `(real_x, real_y)` tuple, where `real_x` is the
        aliased data and `real_y` is the clean data.

    Returns:
      A `dict` of the four scalar losses.
    """
    real_x, real_y = batch_data

    # `persistent=True` is required because four separate gradients are
    # taken from the same tape (two generators, two discriminators); a
    # non-persistent tape is released after its first `gradient` call.
    with tf.GradientTape(persistent=True) as tape:
      # Aliased -> clean and clean -> aliased translations.
      fake_y = self.gen_G(real_x, training=True)
      fake_x = self.gen_F(real_y, training=True)

      # Cycle (aliased -> fake clean -> fake aliased) and
      # cycle (clean -> fake aliased -> fake clean).
      cycle_x = self.gen_F(fake_y, training=True)
      cycle_y = self.gen_G(fake_x, training=True)

      # Identity mappings.
      same_x = self.gen_F(real_x, training=True)
      same_y = self.gen_G(real_y, training=True)

      # Discriminator outputs.
      disc_real_x = self.disc_X(real_x, training=True)
      disc_fake_x = self.disc_X(fake_x, training=True)
      disc_real_y = self.disc_Y(real_y, training=True)
      disc_fake_y = self.disc_Y(fake_y, training=True)

      # Adversarial loss for the generators.
      gen_G_loss = self.generator_loss_fn(disc_fake_y)
      gen_F_loss = self.generator_loss_fn(disc_fake_x)

      # Cycle loss for the generators.
      cycle_loss_G = self.lambda_cycle * self.cycle_loss_fn(real_y, cycle_y)
      cycle_loss_F = self.lambda_cycle * self.cycle_loss_fn(real_x, cycle_x)

      # Generator identity losses.
      id_loss_G = (self.identity_loss_fn(real_y, same_y)
                   * self.lambda_cycle
                   * self.lambda_identity)
      id_loss_F = (self.identity_loss_fn(real_x, same_x)
                   * self.lambda_cycle
                   * self.lambda_identity)

      # Total generator losses.
      total_loss_G = gen_G_loss + cycle_loss_G + id_loss_G
      total_loss_F = gen_F_loss + cycle_loss_F + id_loss_F

      # Discriminator losses.
      disc_X_loss = self.discriminator_loss_fn(disc_real_x, disc_fake_x)
      disc_Y_loss = self.discriminator_loss_fn(disc_real_y, disc_fake_y)

    # Generator gradients.
    grads_G = tape.gradient(total_loss_G, self.gen_G.trainable_variables)
    grads_F = tape.gradient(total_loss_F, self.gen_F.trainable_variables)

    # Discriminator gradients.
    disc_X_grads = tape.gradient(disc_X_loss, self.disc_X.trainable_variables)
    disc_Y_grads = tape.gradient(disc_Y_loss, self.disc_Y.trainable_variables)

    # Update the weights of the generators.
    self.gen_G_optimizer.apply_gradients(
        zip(grads_G, self.gen_G.trainable_variables))
    self.gen_F_optimizer.apply_gradients(
        zip(grads_F, self.gen_F.trainable_variables))

    # Update the weights of the discriminators.
    self.disc_X_optimizer.apply_gradients(
        zip(disc_X_grads, self.disc_X.trainable_variables))
    self.disc_Y_optimizer.apply_gradients(
        zip(disc_Y_grads, self.disc_Y.trainable_variables))

    # A persistent tape is not freed automatically; release it explicitly.
    del tape

    return {
        "G_loss": total_loss_G,
        "F_loss": total_loss_F,
        "D_X_loss": disc_X_loss,
        "D_Y_loss": disc_Y_loss,
    }
self.res_blocks[i](x) + for block in self.res_blocks: + x = block(x) x = self.decoder(x) return x @@ -218,6 +221,10 @@ class CycleGAN(Model): def __init__( self, + G_loss_fn, + D_loss_fn, + adversarial_loss_fn, + identity_loss_fn, generator_G = CGGenerator(image_shape=(1,128,128)), generator_F = CGGenerator(image_shape=(1,128,128)), discriminator_X = CGDiscriminator(), @@ -233,9 +240,15 @@ def __init__( self.lambda_cycle = lambda_cycle self.lambda_identity = lambda_identity - # Loss function for evaluating adversarial loss - self.adv_loss_fn = MeanSquaredError() - + self.G_loss_fn = G_loss_fn + self.D_loss_fn = D_loss_fn + self.adv_loss_fn = adversarial_loss_fn + self.identity_loss_fn = identity_loss_fn + + # TODO - Move these to a tutorial notebook and add them + # as input to the cycleGAN as they are now not needed + # within the class here. + ''' # Define the loss function for the generators def generator_loss_fn(self, fake): fake_loss = self.adv_loss_fn(tf.ones_like(fake), fake) @@ -246,15 +259,14 @@ def discriminator_loss_fn(self, real, fake): real_loss = self.adv_loss_fn(tf.ones_like(real), real) fake_loss = self.adv_loss_fn(tf.zeros_like(fake), fake) return (real_loss + fake_loss) * 0.5 + ''' def compile( self, gen_G_optimizer=Adam(lr=0.0002, beta_1=0.5), gen_F_optimizer=Adam(lr=0.0002, beta_1=0.5), disc_X_optimizer=Adam(lr=0.0002, beta_1=0.5), - disc_Y_optimizer=Adam(lr=0.0002, beta_1=0.5), - gen_loss_fn=generator_loss_fn, - disc_loss_fn=discriminator_loss_fn, + disc_Y_optimizer=Adam(lr=0.0002, beta_1=0.5) ): super(CycleGAN, self).compile() @@ -262,10 +274,10 @@ def compile( self.gen_F_optimizer = gen_F_optimizer self.disc_X_optimizer = disc_X_optimizer self.disc_Y_optimizer = disc_Y_optimizer - self.generator_loss_fn = gen_loss_fn - self.discriminator_loss_fn = disc_loss_fn - self.cycle_loss_fn = MeanAbsoluteError() - self.identity_loss_fn = MeanAbsoluteError() + self.generator_loss_fn = self.G_loss_fn + self.discriminator_loss_fn = self.D_loss_fn + 
self.cycle_loss_fn = self.adv_loss_fn + self.identity_loss_fn = self.identity_loss_fn def train_step(self, batch_data): # real_x is the ailiased data From 88e347ed395184df1f05bab13ab4ed973fc5f2db Mon Sep 17 00:00:00 2001 From: pascalempucl Date: Thu, 1 Sep 2022 14:50:29 +0000 Subject: [PATCH 3/7] Added linting changes --- tensorflow_mri/python/models/cyclegan.py | 679 ++++++++++++----------- 1 file changed, 348 insertions(+), 331 deletions(-) diff --git a/tensorflow_mri/python/models/cyclegan.py b/tensorflow_mri/python/models/cyclegan.py index 3ad9c58b..3a295397 100644 --- a/tensorflow_mri/python/models/cyclegan.py +++ b/tensorflow_mri/python/models/cyclegan.py @@ -10,363 +10,380 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License. +# limitations under the License.s # ============================================================================== -"""CycleGAN Layers and Composite Generator/Discriminator Model Implementation.""" +# pylint: disable=arguments-differ + +"""CycleGAN Layers and Composite Generator/Discriminator +Model Implementation.""" import tensorflow as tf -from keras import layers -import tensorflow_addons as tfa -from keras.optimizers import Adam +from keras.activations import sigmoid from keras.initializers import RandomNormal - -from keras.models import Model from keras.layers import Layer from keras.layers import Conv2D from keras.layers import LeakyReLU from keras.layers import Conv2DTranspose -from keras.layers import LeakyReLU from keras.layers import Activation from keras.layers import Concatenate -from tensorflow_addons.layers import InstanceNormalization +from keras.models import Model +from keras.optimizers import Adam -from keras.activations import sigmoid -from keras.losses import MeanSquaredError, MeanAbsoluteError +from 
class CGResNetBlock(Layer):
  """ResNet Block Layer specific to the CycleGAN Generator Architecture.

  Two 3x3 convolutions with instance normalization; the block input is
  concatenated onto the output as the skip connection.
  """
  def __init__(self, n_filters):
    """
    Args:
      n_filters: An `int`, the number of filters in the first convolution.
        The second convolution is fixed at 256 filters.
    """
    super().__init__()
    self.init = RandomNormal(stddev=0.02)
    self.conv2a = Conv2D(n_filters, (3,3), padding="same",
                         kernel_initializer=self.init)
    self.conv2b = Conv2D(256, (3,3), padding="same",
                         kernel_initializer=self.init)
    self.instance_norm1 = InstanceNormalization(axis=-1)
    self.instance_norm2 = InstanceNormalization(axis=-1)
    self.relu = Activation("relu")
    self.concat = Concatenate()

  def call(self, inputs):
    x = self.conv2a(inputs)
    x = self.instance_norm1(x)
    x = self.relu(x)
    x = self.conv2b(x)
    x = self.instance_norm2(x)
    # Skip connection: concatenate the block output with its input.
    return self.concat([x, inputs])


class CGEncoder(Layer):
  """Encoder Layer specific to the CycleGAN Generator Architecture.

  A 7x7 stem followed by two stride-2 downsampling convolutions
  (64 -> 128 -> 256 filters), each with instance norm and ReLU.
  """

  def __init__(self):
    super().__init__()
    self.init = RandomNormal(stddev=0.02)
    self.conv2a = Conv2D(64, (7,7), padding='same',
                         kernel_initializer=self.init)
    self.conv2b = Conv2D(128, (3,3), strides=(2,2), padding='same',
                         kernel_initializer=self.init)
    self.conv2c = Conv2D(256, (3,3), strides=(2,2), padding='same',
                         kernel_initializer=self.init)
    self.instance_norm1 = InstanceNormalization(axis=-1)
    self.instance_norm2 = InstanceNormalization(axis=-1)
    self.instance_norm3 = InstanceNormalization(axis=-1)
    self.relu = Activation('relu')

  def call(self, inputs):
    x = self.conv2a(inputs)
    x = self.instance_norm1(x)
    x = self.relu(x)

    x = self.conv2b(x)
    x = self.instance_norm2(x)
    x = self.relu(x)

    x = self.conv2c(x)
    x = self.instance_norm3(x)
    return self.relu(x)


class CGDecoder(Layer):
  """Decoder Layer specific to the CycleGAN Generator Architecture.

  Two stride-2 transposed convolutions (128 -> 64 filters) followed by a
  7x7 convolution down to a single channel with `tanh` output.
  """
  def __init__(self):
    super().__init__()
    self.init = RandomNormal(stddev=0.02)
    self.conv2ta = Conv2DTranspose(128, (3,3), strides=(2,2), padding='same',
                                   kernel_initializer=self.init)
    self.conv2tb = Conv2DTranspose(64, (3,3), strides=(2,2), padding='same',
                                   kernel_initializer=self.init)
    self.conv2c = Conv2D(1, (7,7), padding='same', kernel_initializer=self.init)
    self.instance_norm1 = InstanceNormalization(axis=-1)
    self.instance_norm2 = InstanceNormalization(axis=-1)
    self.instance_norm3 = InstanceNormalization(axis=-1)
    self.tanh = Activation('tanh')
    self.relu = Activation('relu')

  def call(self, inputs):
    x = self.conv2ta(inputs)
    x = self.instance_norm1(x)
    x = self.relu(x)

    x = self.conv2tb(x)
    x = self.instance_norm2(x)
    x = self.relu(x)

    x = self.conv2c(x)
    x = self.instance_norm3(x)
    # tanh maps the single-channel output into [-1, 1].
    return self.tanh(x)


class CGConvEncoder(Layer):
  """Convolutional Encoder Layer specific to the
  CycleGAN Discriminator Architecture (PatchGAN-style 4x4 stack)."""
  def __init__(self):
    super().__init__()
    self.init = RandomNormal(stddev=0.02)

    self.conv2a = Conv2D(64, (4,4), strides=(3,3), padding='same',
                         kernel_initializer=self.init)
    self.conv2b = Conv2D(128, (4,4), strides=(3,3), padding='same',
                         kernel_initializer=self.init)
    self.conv2c = Conv2D(256, (4,4), strides=(3,3), padding='same',
                         kernel_initializer=self.init)
    self.conv2d = Conv2D(512, (4,4), strides=(2,2), padding='same',
                         kernel_initializer=self.init)
    self.conv2e = Conv2D(512, (4,4), padding='same',
                         kernel_initializer=self.init)
    self.conv2f = Conv2D(1, (4,4), padding='same',
                         kernel_initializer=self.init)

    self.instance_norm1 = InstanceNormalization(axis=-1)
    self.instance_norm2 = InstanceNormalization(axis=-1)
    self.instance_norm3 = InstanceNormalization(axis=-1)
    self.instance_norm4 = InstanceNormalization(axis=-1)

    self.leakyrelu = LeakyReLU(alpha=0.2)
    # NOTE: the previous version stored an unused `sigmoid` activation;
    # the final (least-squares GAN) prediction is deliberately unbounded,
    # so it has been removed.

  def call(self, inputs):
    """
    Args:
      inputs: A `tf.Tensor` of shape `[..., n]` and same dtype as `self`.

    Returns:
      A `tf.Tensor` of shape `[..., 1]` and same dtype as `self`.
    """
    x = self.conv2a(inputs)
    x = self.leakyrelu(x)

    x = self.conv2b(x)
    x = self.instance_norm1(x)
    x = self.leakyrelu(x)

    x = self.conv2c(x)
    x = self.instance_norm2(x)
    x = self.leakyrelu(x)

    x = self.conv2d(x)
    x = self.instance_norm3(x)
    # Was `self.leakyReLU(x)` — a leftover from the lint rename to
    # `self.leakyrelu`, which raised AttributeError at runtime.
    x = self.leakyrelu(x)

    x = self.conv2e(x)
    x = self.instance_norm4(x)
    x = self.leakyrelu(x)

    return self.conv2f(x)


class CGGenerator(Model):
  """CycleGAN Generator Model: encoder -> N ResNet blocks -> decoder."""
  def __init__(self, image_shape, n_resnet=None):
    """
    Args:
      image_shape: Shape of the input images; index 1 is compared against
        256 to choose the default ResNet depth.
      n_resnet: Optional `int`, the number of ResNet blocks. If `None`,
        defaults to 6 for images smaller than 256 pixels and 9 otherwise.
    """
    super().__init__()
    if n_resnet is None:
      # If the image size is less than 256x256, only use 6 resnet blocks.
      # (Previously this branch was `pass`-only, so `n_resnet` stayed
      # `None` and `range(0, None)` raised a TypeError.)
      n_resnet = 6 if image_shape[1] < 256 else 9
    self.n_resnet = n_resnet
    self.encoder = CGEncoder()
    self.decoder = CGDecoder()
    self.res_blocks = [CGResNetBlock(256) for _ in range(self.n_resnet)]

  def call(self, inputs):
    x = self.encoder(inputs)
    for block in self.res_blocks:
      x = block(x)
    return self.decoder(x)


class CGDiscriminator(Model):
  """CycleGAN Discriminator Model (thin wrapper around `CGConvEncoder`)."""
  def __init__(self):
    super().__init__()
    self.conv_encoder = CGConvEncoder()

  def compile(self, optimizer=None, loss_weights=(0.5,), **kwargs):
    """Compiles the discriminator.

    Args:
      optimizer: A Keras optimizer. Defaults to `Adam(learning_rate=0.0002,
        beta_1=0.5)`. Built lazily so a single default instance is not
        shared by every discriminator.
      loss_weights: Loss weights forwarded to `Model.compile`. Note that
        the previous default `(0.5)` was a plain float, not a tuple.
      **kwargs: Forwarded to `Model.compile`.
    """
    if optimizer is None:
      optimizer = Adam(learning_rate=0.0002, beta_1=0.5)
    # Was `losds_weights=loss_weights **kwargs`: a misspelled keyword and
    # a missing comma, which raised a TypeError.
    super().compile(optimizer=optimizer, loss_weights=loss_weights, **kwargs)

  def call(self, inputs):
    return self.conv_encoder(inputs)
class CycleGAN(Model):
  """Composite CycleGAN model.

  Concurrently trains both generators and both discriminators with a custom
  training loop (see `train_step`).

  Adapted from: https://keras.io/examples/generative/cyclegan/

  By default, the image size for both generators and discriminators is
  128x128, but any square image dimension can be used (> 256x256 and
  < 64x64 is not recommended).
  """

  def __init__(self,
               g_loss_fn,
               d_loss_fn,
               adversarial_loss_fn,
               identity_loss_fn,
               generator_g=None,
               generator_f=None,
               discriminator_x=None,
               discriminator_y=None,
               lambda_cycle=10.0,
               lambda_identity=1.0):
    """Initializes the composite model.

    Args:
      g_loss_fn: A callable `(model, disc_fake) -> loss` computing the
        adversarial loss of a generator.
      d_loss_fn: A callable `(model, disc_real, disc_fake) -> loss`
        computing the loss of a discriminator.
      adversarial_loss_fn: Loss callable used for the cycle-consistency
        term.
      identity_loss_fn: Loss callable used for the identity-mapping term.
      generator_g: Generator mapping domain X (aliased) to domain Y
        (clean). A new `CGGenerator` is created if not given.
      generator_f: Generator mapping domain Y (clean) to domain X
        (aliased). A new `CGGenerator` is created if not given.
      discriminator_x: Discriminator for domain X. A new `CGDiscriminator`
        is created if not given.
      discriminator_y: Discriminator for domain Y. A new `CGDiscriminator`
        is created if not given.
      lambda_cycle: A `float`. Weight of the cycle-consistency loss.
      lambda_identity: A `float`. Weight of the identity loss (also
        multiplied by `lambda_cycle`).
    """
    super().__init__()
    # Default sub-models are created here instead of as default argument
    # values: a default argument is evaluated once, at function definition
    # time, and would be silently shared by every `CycleGAN` instance.
    self.gen_g = generator_g or CGGenerator(image_shape=(1, 128, 128))
    self.gen_f = generator_f or CGGenerator(image_shape=(1, 128, 128))
    self.disc_x = discriminator_x or CGDiscriminator()
    self.disc_y = discriminator_y or CGDiscriminator()
    self.lambda_cycle = lambda_cycle
    self.lambda_identity = lambda_identity

    self.g_loss_fn = g_loss_fn
    self.d_loss_fn = d_loss_fn
    self.adv_loss_fn = adversarial_loss_fn
    self.identity_loss_fn = identity_loss_fn

  def compile(self,
              gen_g_optimizer=None,
              gen_f_optimizer=None,
              disc_x_optimizer=None,
              disc_y_optimizer=None):
    """Configures one optimizer per sub-model.

    Each optimizer defaults to `Adam(learning_rate=0.0002, beta_1=0.5)`.
    The defaults are created here rather than in the signature so that each
    compiled model gets its own optimizer state. `learning_rate` replaces
    the deprecated `lr` alias.
    """
    super().compile()
    self.gen_g_optimizer = (
        gen_g_optimizer or Adam(learning_rate=0.0002, beta_1=0.5))
    self.gen_f_optimizer = (
        gen_f_optimizer or Adam(learning_rate=0.0002, beta_1=0.5))
    self.disc_x_optimizer = (
        disc_x_optimizer or Adam(learning_rate=0.0002, beta_1=0.5))
    self.disc_y_optimizer = (
        disc_y_optimizer or Adam(learning_rate=0.0002, beta_1=0.5))
    self.generator_loss_fn = self.g_loss_fn
    self.discriminator_loss_fn = self.d_loss_fn
    self.cycle_loss_fn = self.adv_loss_fn
    # NOTE: the original no-op `self.identity_loss_fn = self.identity_loss_fn`
    # self-assignment was removed; `__init__` already stores the callable.

  def call(self, inputs):
    raise NotImplementedError("Directly call either the generator or "
                              "discriminator model during inference.")

  def train_step(self, batch_data):
    """Runs one optimization step for all four sub-models.

    Args:
      batch_data: A `(real_x, real_y)` tuple, where `real_x` is the aliased
        data and `real_y` is the clean data.

    Returns:
      A dict mapping loss names to their scalar values for this step.
    """
    real_x, real_y = batch_data

    # `persistent=True` lets us call `tape.gradient` four times (once per
    # sub-model); a non-persistent tape releases its resources after the
    # first call.
    with tf.GradientTape(persistent=True) as tape:
      # Aliased -> clean.
      fake_y = self.gen_g(real_x, training=True)
      # Clean -> aliased.
      fake_x = self.gen_f(real_y, training=True)

      # Cycle (aliased -> fake clean -> fake aliased).
      cycle_x = self.gen_f(fake_y, training=True)
      # Cycle (clean -> fake aliased -> fake clean).
      cycle_y = self.gen_g(fake_x, training=True)

      # Identity mapping.
      same_x = self.gen_f(real_x, training=True)
      same_y = self.gen_g(real_y, training=True)

      # Discriminator outputs.
      disc_real_x = self.disc_x(real_x, training=True)
      disc_fake_x = self.disc_x(fake_x, training=True)
      disc_real_y = self.disc_y(real_y, training=True)
      disc_fake_y = self.disc_y(fake_y, training=True)

      # Adversarial loss for the generators. The loss callables are stored
      # as instance attributes (plain functions, not bound methods), so
      # `self` is passed explicitly.
      gen_g_loss = self.generator_loss_fn(self, disc_fake_y)
      gen_f_loss = self.generator_loss_fn(self, disc_fake_x)

      # Cycle-consistency loss for the generators.
      cycle_loss_g = self.lambda_cycle * self.cycle_loss_fn(real_y, cycle_y)
      cycle_loss_f = self.lambda_cycle * self.cycle_loss_fn(real_x, cycle_x)

      # Generator identity losses.
      id_loss_g = (self.identity_loss_fn(real_y, same_y)
                   * self.lambda_cycle
                   * self.lambda_identity)
      id_loss_f = (self.identity_loss_fn(real_x, same_x)
                   * self.lambda_cycle
                   * self.lambda_identity)

      # Total generator losses.
      total_loss_g = gen_g_loss + cycle_loss_g + id_loss_g
      total_loss_f = gen_f_loss + cycle_loss_f + id_loss_f

      # Discriminator losses.
      disc_x_loss = self.discriminator_loss_fn(self, disc_real_x, disc_fake_x)
      disc_y_loss = self.discriminator_loss_fn(self, disc_real_y, disc_fake_y)

    # Generator gradients. BUGFIX: the originals referenced the old
    # upper-case attributes (`self.gen_G`, `self.disc_X`, ...), which no
    # longer exist after the rename to lower-case and raised AttributeError.
    grads_g = tape.gradient(total_loss_g, self.gen_g.trainable_variables)
    grads_f = tape.gradient(total_loss_f, self.gen_f.trainable_variables)

    # Discriminator gradients.
    disc_x_grads = tape.gradient(disc_x_loss, self.disc_x.trainable_variables)
    disc_y_grads = tape.gradient(disc_y_loss, self.disc_y.trainable_variables)

    # A persistent tape holds on to its resources until deleted.
    del tape

    # Update the generator weights.
    self.gen_g_optimizer.apply_gradients(
        zip(grads_g, self.gen_g.trainable_variables))
    self.gen_f_optimizer.apply_gradients(
        zip(grads_f, self.gen_f.trainable_variables))

    # Update the discriminator weights.
    self.disc_x_optimizer.apply_gradients(
        zip(disc_x_grads, self.disc_x.trainable_variables))
    self.disc_y_optimizer.apply_gradients(
        zip(disc_y_grads, self.disc_y.trainable_variables))

    return {
        "g_loss": total_loss_g,
        "f_loss": total_loss_f,
        "d_x_loss": disc_x_loss,
        "d_y_loss": disc_y_loss,
    }
-12,45 +12,35 @@ # See the License for the specific language governing permissions and # limitations under the License.s # ============================================================================== - +"""CycleGAN model and composite generator/discriminator implementation.""" # pylint: disable=arguments-differ -"""CycleGAN Layers and Composite Generator/Discriminator -Model Implementation.""" - +import keras import tensorflow as tf +import tensorflow_addons as tfa + + +class CGResNetBlock(keras.layers.Layer): + """ResNet block layer specific to the CycleGAN generator architecture. -from keras.activations import sigmoid -from keras.initializers import RandomNormal -from keras.layers import Layer -from keras.layers import Conv2D -from keras.layers import LeakyReLU -from keras.layers import Conv2DTranspose -from keras.layers import Activation -from keras.layers import Concatenate -from keras.models import Model -from keras.optimizers import Adam - -from tensorflow_addons.layers import InstanceNormalization - -class CGResNetBlock(Layer): - """ResNet Block Layer specific to the CycleGAN Generator Architecture.""" + Args: + n_filters: An `int` denoting the number of convolutional filters + to use. + """ + # TODO: this layer needs a `get_config` method, or just remove the `n_filters` + # argument, since it's always set to the same value of 256. def __init__(self, n_filters): - """ - Args: - n_filters: An `int` denoting the number of convolutional filters - to use, and same dtype as `self`. 
- """ super().__init__() - self.init = RandomNormal(stddev=0.02) - self.conv2a = Conv2D(n_filters, (3,3), padding="same", - kernel_initializer=self.init) - self.conv2b = Conv2D(256, (3,3), padding="same", - kernel_initializer=self.init) - self.instance_norm1 = InstanceNormalization(axis=-1) - self.instance_norm2 = InstanceNormalization(axis=-1) - self.relu = Activation("relu") - self.concat = Concatenate() + self.conv2a = keras.layers.Conv2D( + n_filters, (3, 3), padding="same", + kernel_initializer=_make_default_initializer()) + self.conv2b = keras.layers.Conv2D( + 256, (3, 3), padding="same", + kernel_initializer=_make_default_initializer()) + self.instance_norm1 = tfa.layers.InstanceNormalization(axis=-1) + self.instance_norm2 = tfa.layers.InstanceNormalization(axis=-1) + self.relu = keras.layers.Activation("relu") + self.concat = keras.layers.Concatenate() def call(self, inputs): x = self.conv2a(inputs) @@ -61,22 +51,22 @@ def call(self, inputs): x = self.concat([x, inputs]) return x -class CGEncoder(Layer): - """Encoder Layer specific to the CycleGAN Generator Architecture.""" + +class CGEncoder(keras.layers.Layer): + """Encoder layer specific to the CycleGAN generator architecture.""" def __init__(self): super().__init__() - self.init = RandomNormal(stddev=0.02) - self.conv2a = Conv2D(64, (7,7), padding='same', - kernel_initializer=self.init) - self.conv2b = Conv2D(128, (3,3), strides=(2,2), padding='same', - kernel_initializer=self.init) - self.conv2c = Conv2D(256, (3,3), strides=(2,2), padding='same', - kernel_initializer=self.init) - self.instance_norm1 = InstanceNormalization(axis=-1) - self.instance_norm2 = InstanceNormalization(axis=-1) - self.instance_norm3 = InstanceNormalization(axis=-1) - self.relu = Activation('relu') + self.conv2a = keras.layers.Conv2D(64, (7,7), padding='same', + kernel_initializer=_make_default_initializer()) + self.conv2b = keras.layers.Conv2D(128, (3,3), strides=(2,2), padding='same', + 
kernel_initializer=_make_default_initializer()) + self.conv2c = keras.layers.Conv2D(256, (3,3), strides=(2,2), padding='same', + kernel_initializer=_make_default_initializer()) + self.instance_norm1 = tfa.layers.InstanceNormalization(axis=-1) + self.instance_norm2 = tfa.layers.InstanceNormalization(axis=-1) + self.instance_norm3 = tfa.layers.InstanceNormalization(axis=-1) + self.relu = keras.layers.Activation('relu') def call(self, inputs): x = self.conv2a(inputs) @@ -93,21 +83,25 @@ def call(self, inputs): return x -class CGDecoder(Layer): - """Decoder Layer specific to the CycleGAN Generator Architecture.""" + +class CGDecoder(keras.layers.Layer): + """Decoder layer specific to the CycleGAN generator architecture.""" def __init__(self): super().__init__() - self.init = RandomNormal(stddev=0.02) - self.conv2ta = Conv2DTranspose(128, (3,3), strides=(2,2), padding='same', - kernel_initializer=self.init) - self.conv2tb = Conv2DTranspose(64, (3,3), strides=(2,2), padding='same', - kernel_initializer=self.init) - self.conv2c = Conv2D(1, (7,7), padding='same', kernel_initializer=self.init) - self.instance_norm1 = InstanceNormalization(axis=-1) - self.instance_norm2 = InstanceNormalization(axis=-1) - self.instance_norm3 = InstanceNormalization(axis=-1) - self.tanh = Activation('tanh') - self.relu = Activation('relu') + self.conv2ta = keras.layers.keras.layers.Conv2DTranspose( + 128, (3, 3), strides=(2, 2), padding='same', + kernel_initializer=_make_default_initializer()) + self.conv2tb = keras.layers.keras.layers.Conv2DTranspose( + 64, (3, 3), strides=(2, 2), padding='same', + kernel_initializer=_make_default_initializer()) + self.conv2c = keras.layers.Conv2D( + 1, (7,7), padding='same', + kernel_initializer=_make_default_initializer()) + self.instance_norm1 = tfa.layers.InstanceNormalization(axis=-1) + self.instance_norm2 = tfa.layers.InstanceNormalization(axis=-1) + self.instance_norm3 = tfa.layers.InstanceNormalization(axis=-1) + self.tanh = 
keras.layers.Activation('tanh') + self.relu = keras.layers.Activation('relu') def call(self, inputs): x = self.conv2ta(inputs) @@ -124,33 +118,35 @@ def call(self, inputs): return out_image -class CGConvEncoder(Layer): - """Convolutional Encoder Layer specific to the - CycleGAN Discriminator Architecture.""" + +class CGConvEncoder(keras.layers.Layer): + """Convolutional layer specific to the CycleGAN discriminator architecture.""" def __init__(self): super().__init__() - self.init = RandomNormal(stddev=0.02) - - self.conv2a = Conv2D(64, (4,4), strides=(3,3), padding='same', - kernel_initializer=self.init) - self.conv2b = Conv2D(128, (4,4), strides=(3,3), padding='same', - kernel_initializer=self.init) - self.conv2c = Conv2D(256, (4,4), strides=(3,3), padding='same', - kernel_initializer=self.init) - self.conv2d = Conv2D(512, (4,4), strides=(2,2), padding='same', - kernel_initializer=self.init) - self.conv2e = Conv2D(512, (4,4), padding='same', - kernel_initializer=self.init) - self.conv2f = Conv2D(1, (4,4), padding='same', - kernel_initializer=self.init) - - self.instance_norm1 = InstanceNormalization(axis=-1) - self.instance_norm2 = InstanceNormalization(axis=-1) - self.instance_norm3 = InstanceNormalization(axis=-1) - self.instance_norm4 = InstanceNormalization(axis=-1) - - self.leakyrelu = LeakyReLU(alpha=0.2) - self.sigmoid = sigmoid + + self.conv2a = keras.layers.Conv2D( + 64, (4, 4), strides=(3, 3), padding='same', + kernel_initializer=_make_default_initializer()) + self.conv2b = keras.layers.Conv2D( + 128, (4, 4), strides=(3, 3), padding='same', + kernel_initializer=_make_default_initializer()) + self.conv2c = keras.layers.Conv2D( + 256, (4, 4), strides=(3, 3), padding='same', + kernel_initializer=_make_default_initializer()) + self.conv2d = keras.layers.Conv2D( + 512, (4, 4), strides=(2, 2), padding='same', + kernel_initializer=_make_default_initializer()) + self.conv2e = keras.layers.Conv2D( + 512, (4, 4), padding='same', + 
kernel_initializer=_make_default_initializer()) + self.conv2f = keras.layers.Conv2D( + 1, (4, 4), padding='same', + kernel_initializer=_make_default_initializer()) + + self.instance_norm1 = tfa.layers.InstanceNormalization(axis=-1) + self.instance_norm2 = tfa.layers.InstanceNormalization(axis=-1) + self.instance_norm3 = tfa.layers.InstanceNormalization(axis=-1) + self.instance_norm4 = tfa.layers.InstanceNormalization(axis=-1) def call(self, inputs): """ @@ -182,8 +178,11 @@ def call(self, inputs): return out_pred -class CGGenerator(Model): + +class CGGenerator(keras.Model): """CycleGAN Generator Model.""" + # TODO: `image_shape` should not be an argument to the constructor. + # It is typically inferred from the input tensor. def __init__(self, image_shape, n_resnet=None): super().__init__() @@ -207,21 +206,28 @@ def call(self, inputs): x = self.decoder(x) return x -class CGDiscriminator(Model): - """CycleGAN Discriminator Model.""" + +class CGDiscriminator(keras.Model): + """CycleGAN discriminator Model.""" def __init__(self): super().__init__() self.conv_encoder = CGConvEncoder() - def compile(self, optimizer=Adam(lr=0.0002, beta_1=0.5), - loss_weights=(0.5), **kwargs): - super().compile(optimizer=optimizer, losds_weights=loss_weights **kwargs) + def compile(self, + optimizer=None, + loss_weights=(0.5,), + **kwargs): + super().compile(optimizer=optimizer or _make_default_optimizer(), + loss_weights=loss_weights, + **kwargs) def call(self, inputs): return self.conv_encoder(inputs) -class CycleGAN(Model): - """ + +class CycleGAN(keras.Model): + """CycleGAN model. + This is the main composite CycleGAN Model used to concurrently train both generators and discriminators with a custom training loop. 
@@ -233,23 +239,22 @@ class CycleGAN(Model): """ def __init__( - self, - g_loss_fn, - d_loss_fn, - adversarial_loss_fn, - identity_loss_fn, - generator_g = CGGenerator(image_shape=(1,128,128)), - generator_f = CGGenerator(image_shape=(1,128,128)), - discriminator_x = CGDiscriminator(), - discriminator_y = CGDiscriminator(), - lambda_cycle=10.0, - lambda_identity=1.0, - ): + self, + g_loss_fn, + d_loss_fn, + adversarial_loss_fn, + identity_loss_fn, + generator_g=None, + generator_f=None, + discriminator_x=None, + discriminator_y=None, + lambda_cycle=10.0, + lambda_identity=1.0): super().__init__() - self.gen_g = generator_g - self.gen_f = generator_f - self.disc_x = discriminator_x - self.disc_y = discriminator_y + self.gen_g = generator_g or _make_default_generator() + self.gen_f = generator_f or _make_default_generator() + self.disc_x = discriminator_x or _make_default_discriminator() + self.disc_y = discriminator_y or _make_default_discriminator() self.lambda_cycle = lambda_cycle self.lambda_identity = lambda_identity @@ -275,26 +280,25 @@ def __init__( def compile( - self, - gen_g_optimizer=Adam(lr=0.0002, beta_1=0.5), - gen_f_optimizer=Adam(lr=0.0002, beta_1=0.5), - disc_x_optimizer=Adam(lr=0.0002, beta_1=0.5), - disc_y_optimizer=Adam(lr=0.0002, beta_1=0.5) - ): - + self, + gen_g_optimizer=None, + gen_f_optimizer=None, + disc_x_optimizer=None, + disc_y_optimizer=None): super().compile() - self.gen_g_optimizer = gen_g_optimizer - self.gen_f_optimizer = gen_f_optimizer - self.disc_x_optimizer = disc_x_optimizer - self.disc_y_optimizer = disc_y_optimizer + self.gen_g_optimizer = gen_g_optimizer or _make_default_optimizer() + self.gen_f_optimizer = gen_f_optimizer or _make_default_optimizer() + self.disc_x_optimizer = disc_x_optimizer or _make_default_optimizer() + self.disc_y_optimizer = disc_y_optimizer or _make_default_optimizer() self.generator_loss_fn = self.g_loss_fn self.discriminator_loss_fn = self.d_loss_fn self.cycle_loss_fn = self.adv_loss_fn 
self.identity_loss_fn = self.identity_loss_fn def call(self): - raise NotImplementedError("Directly call either the Generator or " - "Discriminator model during inference.") + raise NotImplementedError( + "Directly call either the generator or " + "discriminator model during inference.") def train_step(self, batch_data): # real_x is the ailiased data @@ -382,8 +386,24 @@ def train_step(self, batch_data): ) return { - "g_loss": total_loss_g, - "f_loss": total_loss_f, - "d_x_loss": disc_x_loss, - "d_y_loss": disc_y_loss, + "g_loss": total_loss_g, + "f_loss": total_loss_f, + "d_x_loss": disc_x_loss, + "d_y_loss": disc_y_loss, } + + +def _make_default_optimizer(): + return keras.optimizers.Adam(lr=0.0002, beta_1=0.5) + + +def _make_default_generator(): + return CGGenerator(image_shape=[1, 128, 128]) + + +def _make_default_discriminator(): + return CGDiscriminator() + + +def _make_default_initializer(): + return keras.initializers.RandomNormal(stddev=0.02) diff --git a/tensorflow_mri/python/models/cyclegan_test.py b/tensorflow_mri/python/models/cycle_gan_test.py similarity index 100% rename from tensorflow_mri/python/models/cyclegan_test.py rename to tensorflow_mri/python/models/cycle_gan_test.py From 8c0445713abab16c7b666aba321b0d1d4df912d9 Mon Sep 17 00:00:00 2001 From: pascalempucl <111292309+pascalempucl@users.noreply.github.com> Date: Wed, 21 Sep 2022 15:02:54 +0100 Subject: [PATCH 6/7] Added CycleGAN 3D Architecture [skip ci] --- tensorflow_mri/python/models/cycle_gan_3D.py | 357 +++++++++++++++++++ 1 file changed, 357 insertions(+) create mode 100644 tensorflow_mri/python/models/cycle_gan_3D.py diff --git a/tensorflow_mri/python/models/cycle_gan_3D.py b/tensorflow_mri/python/models/cycle_gan_3D.py new file mode 100644 index 00000000..e963c1ba --- /dev/null +++ b/tensorflow_mri/python/models/cycle_gan_3D.py @@ -0,0 +1,357 @@ +# Copyright 2021 University College London. All Rights Reserved. 
"""CycleGAN 3D layers and composite generator/discriminator implementation."""

import tensorflow as tf

from keras.initializers import RandomNormal
from keras.layers import Activation
from keras.layers import Concatenate
from keras.layers import Conv3D
from keras.layers import Conv3DTranspose
from keras.layers import Layer
from keras.layers import LeakyReLU
from keras.losses import MeanAbsoluteError
from keras.losses import MeanSquaredError
from keras.models import Model
from keras.optimizers import Adam
from tensorflow_addons.layers import InstanceNormalization


class CG3DResNetBlock(Layer):
  """ResNet block layer specific to the CycleGAN 3D generator architecture.

  Args:
    n_filters: An `int` denoting the number of convolutional filters to use.
  """
  def __init__(self, n_filters):
    super().__init__()
    self.init = RandomNormal(stddev=0.02)
    self.conv2a = Conv3D(n_filters, (3, 3, 3), padding="same",
                         kernel_initializer=self.init)
    self.conv2b = Conv3D(256, (3, 3, 3), padding="same",
                         kernel_initializer=self.init)
    self.instance_norm1 = InstanceNormalization(axis=-1)
    self.instance_norm2 = InstanceNormalization(axis=-1)
    self.relu = Activation("relu")
    self.concat = Concatenate()

  def call(self, input_tensor):
    x = self.conv2a(input_tensor)
    x = self.instance_norm1(x)
    x = self.relu(x)
    x = self.conv2b(x)
    x = self.instance_norm2(x)
    # Skip connection: concatenate (not add) the block input.
    x = self.concat([x, input_tensor])
    return x


class CG3DEncoder(Layer):
  """Encoder layer specific to the CycleGAN 3D generator architecture."""

  def __init__(self):
    super().__init__()
    self.init = RandomNormal(stddev=0.02)
    self.conv2a = Conv3D(64, (7, 7, 7), padding='same',
                         kernel_initializer=self.init)
    self.conv2b = Conv3D(128, (3, 3, 3), strides=(2, 2, 2), padding='same',
                         kernel_initializer=self.init)
    self.conv2c = Conv3D(256, (3, 3, 3), strides=(2, 2, 2), padding='same',
                         kernel_initializer=self.init)
    self.instance_norm1 = InstanceNormalization(axis=-1)
    self.instance_norm2 = InstanceNormalization(axis=-1)
    self.instance_norm3 = InstanceNormalization(axis=-1)
    self.relu = Activation('relu')

  def call(self, input_tensor):
    x = self.conv2a(input_tensor)
    x = self.instance_norm1(x)
    x = self.relu(x)

    x = self.conv2b(x)
    x = self.instance_norm2(x)
    x = self.relu(x)

    x = self.conv2c(x)
    x = self.instance_norm3(x)
    x = self.relu(x)

    return x


class CG3DDecoder(Layer):
  """Decoder layer specific to the CycleGAN 3D generator architecture."""
  def __init__(self):
    super().__init__()
    self.init = RandomNormal(stddev=0.02)
    self.conv2ta = Conv3DTranspose(128, (3, 3, 3), strides=(2, 2, 2),
                                   padding='same',
                                   kernel_initializer=self.init)
    self.conv2tb = Conv3DTranspose(64, (3, 3, 3), strides=(2, 2, 2),
                                   padding='same',
                                   kernel_initializer=self.init)
    self.conv2c = Conv3D(1, (7, 7, 7), padding='same',
                         kernel_initializer=self.init)
    self.instance_norm1 = InstanceNormalization(axis=-1)
    self.instance_norm2 = InstanceNormalization(axis=-1)
    self.instance_norm3 = InstanceNormalization(axis=-1)
    self.tanh = Activation('tanh')
    self.relu = Activation('relu')

  def call(self, input_tensor):
    x = self.conv2ta(input_tensor)
    x = self.instance_norm1(x)
    x = self.relu(x)

    x = self.conv2tb(x)
    x = self.instance_norm2(x)
    x = self.relu(x)

    x = self.conv2c(x)
    x = self.instance_norm3(x)
    # Final tanh maps the single-channel output into [-1, 1].
    out_image = self.tanh(x)

    return out_image


class CG3DConvEncoder(Layer):
  """Convolutional encoder specific to the CycleGAN 3D discriminator."""
  def __init__(self):
    super().__init__()
    self.init = RandomNormal(stddev=0.02)

    self.conv2a = Conv3D(64, (4, 4, 4), strides=(2, 2, 2), padding='same',
                         kernel_initializer=self.init)
    self.conv2b = Conv3D(128, (4, 4, 4), strides=(2, 2, 2), padding='same',
                         kernel_initializer=self.init)
    self.conv2c = Conv3D(128, (4, 4, 4), padding='same',
                         kernel_initializer=self.init)
    self.conv2d = Conv3D(1, (4, 4, 4), padding='same',
                         kernel_initializer=self.init)

    self.instance_norm1 = InstanceNormalization(axis=-1)
    self.instance_norm2 = InstanceNormalization(axis=-1)

    self.leakyReLU = LeakyReLU(alpha=0.2)
    # NOTE: the original also stored `self.sigmoid = sigmoid`; it was never
    # used in `call` and has been removed as dead code.

  def call(self, input_tensor):
    """Runs the discriminator convolution stack.

    Args:
      input_tensor: A `tf.Tensor` of shape `[..., n]` and same dtype as
        `self`.

    Returns:
      A `tf.Tensor` of shape `[..., 1]` and same dtype as `self`.
    """
    x = self.conv2a(input_tensor)
    x = self.leakyReLU(x)

    x = self.conv2b(x)
    x = self.instance_norm1(x)
    x = self.leakyReLU(x)

    x = self.conv2c(x)
    x = self.instance_norm2(x)
    x = self.leakyReLU(x)

    out_pred = self.conv2d(x)

    return out_pred


class CG3DGenerator(Model):
  """CycleGAN 3D generator model (encoder -> ResNet blocks -> decoder)."""
  def __init__(self, image_shape, n_resnet=None):
    """Initializes the generator.

    Args:
      image_shape: Shape of the input volumes (only used by the default
        block-count heuristic).
      n_resnet: An `int`, the number of ResNet blocks; defaults to 6 when
        not given (the count used for images smaller than 256x256).
    """
    super().__init__()
    # The original `(image_shape[1] < 256 and n_resnet is None) or
    # n_resnet is None` reduces to just `n_resnet is None`.
    if n_resnet is None:
      n_resnet = 6
    self.n_resnet = n_resnet

    self.encoder = CG3DEncoder()
    self.decoder = CG3DDecoder()
    # BUGFIX: the original appended a new CG3DResNetBlock on *every* call
    # to `call`, so each forward pass ran through freshly initialized
    # (untrained) blocks and the list grew without bound. The blocks are
    # now created once, here.
    self.res_blocks = [CG3DResNetBlock(256) for _ in range(self.n_resnet)]

  def call(self, inputs):
    x = self.encoder(inputs)
    for block in self.res_blocks:
      x = block(x)
    x = self.decoder(x)
    return x


class CG3DDiscriminator(Model):
  """CycleGAN 3D discriminator model."""
  def __init__(self):
    super().__init__()
    self.conv_encoder = CG3DConvEncoder()

  def compile(self, optimizer=None, loss_weights=(0.5,), **kwargs):
    """Compiles the discriminator, defaulting to the CycleGAN Adam setup.

    BUGFIX: the original called
    `super().compile(optimizer=optimizer, losds_weights=loss_weights **kwargs)`
    — a misspelled keyword and a missing comma that turned `**kwargs` into
    an exponentiation, raising at call time. The optimizer default is also
    created per call instead of being a shared default argument.
    """
    super().compile(
        optimizer=optimizer or Adam(learning_rate=0.0002, beta_1=0.5),
        loss_weights=loss_weights,
        **kwargs)

  def call(self, input_image):
    return self.conv_encoder(input_image)


class CycleGAN3D(Model):
  """Composite CycleGAN 3D model.

  This is the main composite CycleGAN model used to concurrently train both
  generators and discriminators with a custom training loop.

  Adapted from: https://keras.io/examples/generative/cyclegan/

  By default, the image size for both generators and discriminators is
  128x128, but any square image dimension can be used (> 256x256 and
  < 64x64 is not recommended).
  """

  def __init__(
      self,
      generator_G=None,
      generator_F=None,
      discriminator_X=None,
      discriminator_Y=None,
      lambda_cycle=10.0,
      lambda_identity=1.0):
    super().__init__()
    # Default sub-models are created here instead of as default argument
    # values: a default argument is evaluated once, at class-definition
    # time, and would be shared by every CycleGAN3D instance.
    self.gen_G = generator_G or CG3DGenerator(image_shape=(20, 1, 128, 128))
    self.gen_F = generator_F or CG3DGenerator(image_shape=(20, 1, 128, 128))
    self.disc_X = discriminator_X or CG3DDiscriminator()
    self.disc_Y = discriminator_Y or CG3DDiscriminator()
    self.lambda_cycle = lambda_cycle
    self.lambda_identity = lambda_identity
    # Loss function for evaluating adversarial loss.
    self.adv_loss_fn = MeanSquaredError()

  # Default loss for the generators. Called as `fn(model, fake)` because it
  # is stored as a plain (unbound) instance attribute in `compile`.
  def generator_loss_fn(self, fake):
    fake_loss = self.adv_loss_fn(tf.ones_like(fake), fake)
    return fake_loss

  # Default loss for the discriminators; same unbound-call convention.
  def discriminator_loss_fn(self, real, fake):
    real_loss = self.adv_loss_fn(tf.ones_like(real), real)
    fake_loss = self.adv_loss_fn(tf.zeros_like(fake), fake)
    return (real_loss + fake_loss) * 0.5

  def compile(
      self,
      gen_G_optimizer=None,
      gen_F_optimizer=None,
      disc_X_optimizer=None,
      disc_Y_optimizer=None,
      gen_loss_fn=generator_loss_fn,
      disc_loss_fn=discriminator_loss_fn,
      # Loss functions for the cycle and identity losses (customizable).
      cycle_loss_fn=None,
      identity_loss_fn=None,
      **kwargs):
    """Configures one optimizer per sub-model and the loss callables."""
    # Forward **kwargs (the original silently dropped them).
    super().compile(**kwargs)
    # Optimizer defaults are created per call so state is never shared.
    self.gen_G_optimizer = (
        gen_G_optimizer or Adam(learning_rate=0.0002, beta_1=0.5))
    self.gen_F_optimizer = (
        gen_F_optimizer or Adam(learning_rate=0.0002, beta_1=0.5))
    self.disc_X_optimizer = (
        disc_X_optimizer or Adam(learning_rate=0.0002, beta_1=0.5))
    self.disc_Y_optimizer = (
        disc_Y_optimizer or Adam(learning_rate=0.0002, beta_1=0.5))
    self.generator_loss_fn = gen_loss_fn
    self.discriminator_loss_fn = disc_loss_fn
    self.cycle_loss_fn = cycle_loss_fn or MeanAbsoluteError()
    self.identity_loss_fn = identity_loss_fn or MeanAbsoluteError()

  def train_step(self, batch_data):
    """Runs one optimization step for all four sub-models.

    Args:
      batch_data: A `(real_x, real_y)` tuple, where `real_x` is the aliased
        data and `real_y` is the clean data.

    Returns:
      A dict mapping loss names to their scalar values for this step.
    """
    real_x, real_y = batch_data

    # `persistent=True` lets us call `tape.gradient` four times (once per
    # sub-model); a non-persistent tape releases its resources after the
    # first call.
    with tf.GradientTape(persistent=True) as tape:
      # Aliased to clean.
      fake_y = self.gen_G(real_x, training=True)
      # Clean to aliased.
      fake_x = self.gen_F(real_y, training=True)

      # Cycle (aliased -> fake clean -> fake aliased).
      cycle_x = self.gen_F(fake_y, training=True)
      # Cycle (clean -> fake aliased -> fake clean).
      cycle_y = self.gen_G(fake_x, training=True)

      # Identity mapping.
      same_x = self.gen_F(real_x, training=True)
      same_y = self.gen_G(real_y, training=True)

      # Discriminator outputs.
      disc_real_x = self.disc_X(real_x, training=True)
      disc_fake_x = self.disc_X(fake_x, training=True)
      disc_real_y = self.disc_Y(real_y, training=True)
      disc_fake_y = self.disc_Y(fake_y, training=True)

      # Adversarial loss for the generators (`self` passed explicitly —
      # see the loss-fn storage note in `compile`).
      gen_G_loss = self.generator_loss_fn(self, disc_fake_y)
      gen_F_loss = self.generator_loss_fn(self, disc_fake_x)

      # Cycle loss for the generators.
      cycle_loss_G = self.lambda_cycle * self.cycle_loss_fn(real_y, cycle_y)
      cycle_loss_F = self.lambda_cycle * self.cycle_loss_fn(real_x, cycle_x)

      # Generator identity losses.
      id_loss_G = (
          self.identity_loss_fn(real_y, same_y)
          * self.lambda_cycle
          * self.lambda_identity
      )
      id_loss_F = (
          self.identity_loss_fn(real_x, same_x)
          * self.lambda_cycle
          * self.lambda_identity
      )

      # Total generator losses.
      total_loss_G = gen_G_loss + cycle_loss_G + id_loss_G
      total_loss_F = gen_F_loss + cycle_loss_F + id_loss_F

      # Discriminator losses.
      disc_X_loss = self.discriminator_loss_fn(self, disc_real_x, disc_fake_x)
      disc_Y_loss = self.discriminator_loss_fn(self, disc_real_y, disc_fake_y)

    # Generator gradients.
    grads_G = tape.gradient(total_loss_G, self.gen_G.trainable_variables)
    grads_F = tape.gradient(total_loss_F, self.gen_F.trainable_variables)

    # Discriminator gradients.
    disc_X_grads = tape.gradient(disc_X_loss, self.disc_X.trainable_variables)
    disc_Y_grads = tape.gradient(disc_Y_loss, self.disc_Y.trainable_variables)

    # A persistent tape holds on to its resources until deleted.
    del tape

    # Update the weights of the generators.
    self.gen_G_optimizer.apply_gradients(
        zip(grads_G, self.gen_G.trainable_variables))
    self.gen_F_optimizer.apply_gradients(
        zip(grads_F, self.gen_F.trainable_variables))

    # Update the weights of the discriminators.
    self.disc_X_optimizer.apply_gradients(
        zip(disc_X_grads, self.disc_X.trainable_variables))
    self.disc_Y_optimizer.apply_gradients(
        zip(disc_Y_grads, self.disc_Y.trainable_variables))

    return {
        "G_loss": total_loss_G,
        "F_loss": total_loss_F,
        "D_X_loss": disc_X_loss,
        "D_Y_loss": disc_Y_loss,
    }

  def predict_step(self, data):
    raise NotImplementedError("You can't call `model.predict()` on the \n"
                              "CycleGAN Model! You must call this on the "
                              "Generator sub-model with \n"
                              "model.gen_G() or model.gen_F().")