diff --git a/MPIGDriver.py b/MPIGDriver.py index 289f2e4..1f73586 100755 --- a/MPIGDriver.py +++ b/MPIGDriver.py @@ -156,7 +156,7 @@ mode='easgd', sync_every=args.sync_every, worker_optimizer=args.worker_optimizer, worker_optimizer_params=args.worker_optimizer_params, - elastic_force=args.elastic_force/(comm.Get_size()-1), + elastic_force=args.elastic_force/(min(1,comm.Get_size()-1)), elastic_lr=args.elastic_lr, elastic_momentum=args.elastic_momentum) else: diff --git a/models/get_3d.py b/models/get_3d.py index 53a4a7b..42a7938 100644 --- a/models/get_3d.py +++ b/models/get_3d.py @@ -7,7 +7,7 @@ print ("hum") import numpy as np import sys - +import keras def get_data(datafile): #get data for training #print ('Loading Data from .....', datafile) @@ -20,7 +20,11 @@ def get_data(datafile): X = X.astype(np.float32) y = y.astype(np.float32) y = y/100. - ecal = np.squeeze(np.sum(X, axis=(1, 2, 3))) + if keras.backend.image_data_format() !='channels_last': + X =np.moveaxis(X, -1, 1) + ecal = np.squeeze(np.sum(X, axis=(2, 3, 4))) + else: + ecal = np.squeeze(np.sum(X, axis=(1, 2, 3))) print (X.shape) print (y.shape) print (ecal.shape) diff --git a/mpi_learn/train/GanModel.py b/mpi_learn/train/GanModel.py index 53095da..cf17028 100644 --- a/mpi_learn/train/GanModel.py +++ b/mpi_learn/train/GanModel.py @@ -114,8 +114,16 @@ def _Model(**args): else: return Model(**args) def discriminator(fixed_bn = False, discr_drop_out=0.2): + if keras.backend.image_data_format() =='channels_last': + dshape=(25, 25, 25,1) + daxis=(1,2,3) + else: + dshape=(1, 25, 25, 25) + daxis=(2,3,4) + + image = Input(shape=dshape, name='image') + - image = Input(shape=( 25, 25, 25,1 ), name='image') bnm=2 if fixed_bn else 0 f=(5,5,5) @@ -164,19 +172,21 @@ def discriminator(fixed_bn = False, discr_drop_out=0.2): fake = _Dense(1, activation='sigmoid', name='classification')(dnn_out) aux = _Dense(1, activation='linear', name='energy')(dnn_out) - ecal = Lambda(lambda x: K.sum(x, axis=(1, 2, 3)), name='sum_cell')(image) + ecal = Lambda(lambda x: K.sum(x, daxis), name='sum_cell')(image) return _Model(output=[fake, aux, ecal], input=image, name='discriminator_model') def generator(latent_size=200, return_intermediate=False, with_bn=True): - + if keras.backend.image_data_format() =='channels_last': + dim = (7,7,8,8) + else: + dim = (8, 7, 7,8) latent = Input(shape=(latent_size, )) - bnm=0 x = _Dense(64 * 7* 7, init='glorot_normal', name='gen_dense1' )(latent) - x = Reshape((7, 7,8, 8))(x) + x = Reshape(dim)(x) x = _Conv3D(64, 6, 6, 8, border_mode='same', init='he_uniform', name='gen_c1' )(x) @@ -212,9 +222,14 @@ def generator(latent_size=200, return_intermediate=False, with_bn=True): return _Model(input=[latent], output=fake_image, name='generator_model') def get_sums(images): - sumsx = np.squeeze(np.sum(images, axis=(2,3))) - sumsy = np.squeeze(np.sum(images, axis=(1,3))) - sumsz = np.squeeze(np.sum(images, axis=(1,2))) + if keras.backend.image_data_format() =='channels_last': + sumsx = np.squeeze(np.sum(images, axis=(2,3))) + sumsy = np.squeeze(np.sum(images, axis=(1,3))) + sumsz = np.squeeze(np.sum(images, axis=(1,2))) + else: + sumsx = np.squeeze(np.sum(images, axis=(3,4))) + sumsy = np.squeeze(np.sum(images, axis=(2,4))) + sumsz = np.squeeze(np.sum(images, axis=(2,3))) return sumsx, sumsy, sumsz def get_moments(images, sumsx, sumsy, sumsz, totalE, m): @@ -535,6 +550,9 @@ def make_opt(**args): loss=['binary_crossentropy', 'mean_absolute_percentage_error', 'mean_absolute_percentage_error'], loss_weights=self.discr_loss_weights ) + if kv2: + self.discriminator.trainable = True #workaround for keras 2 bug + self.combined.metrics_names = self.discriminator.metrics_names print (self.discriminator.metrics_names) print (self.combined.metrics_names)