diff --git a/train.py b/train.py index 454d6a9..79efe62 100644 --- a/train.py +++ b/train.py @@ -1,6 +1,7 @@ import config import trainer_step + def main(): cfg = config.Config(filenamequeue="../data/dataset/layout_1205.tfrecords") t = trainer_step.Trainer(cfg) diff --git a/trainer_step.py b/trainer_step.py index 1f80a71..5fbe173 100644 --- a/trainer_step.py +++ b/trainer_step.py @@ -11,7 +11,7 @@ tf.compat.v1.disable_eager_execution() # slim = tf.contrib.slim -DATA_DIR = os.path.join('..','data') +DATA_DIR = os.path.join('..') class Trainer(object): def __init__(self, config): filenamequeue = tf.compat.v1.train.string_input_producer([config.filenamequeue]) @@ -58,7 +58,8 @@ def _build_model(self, filenamequeue, config): randomz = tf.random.normal([config.batch_size, config.z_dim]) # testing case testLayout, testImgfea, testSemvec, testTexfea = self.inputs(config) - randomz_val = np.load(os.path.join(DATA_DIR,'sample','noiseVector_128.npy')) + # randomz_val = np.load(os.path.join(DATA_DIR,'/drive/MyDrive/LayoutNetme/sample','noiseVector_128.npy')) + randomz_val = np.load('../drive/MyDrive/LayoutNetme/sample/noiseVector_128.npy') testSemvec, _ = model.embeddingSemvec(testSemvec, training_, reuse=True) testImgfea, _ = model.embeddingImg(testImgfea, training_, reuse=True) testTexfea, _ = model.embeddingTex(testTexfea, training_, reuse=True) @@ -223,8 +224,8 @@ def sample(self): return (inputdata + 1) / 2.0, (gen + 1) / 2.0, fea def testing(self): - new_saver = tf.compat.v1.train.import_meta_graph('./log/layoutNet-100.meta') - new_saver.restore(self.sess, './log/layoutNet-100') + new_saver = tf.compat.v1.train.import_meta_graph('/layoutNet-100.meta') + new_saver.restore(self.sess, '/layoutNet-100') gen, fea = self.sess.run([self.model["Gtest"], self.model["Etest"]], feed_dict={self.model["is_training"]: False}) @@ -232,12 +233,12 @@ def testing(self): return (gen + 1) / 2.0, fea def inputs(self, config): - layoutpath = DATA_DIR + '/sample/layout/' - imgfeapath = DATA_DIR + '/sample/visfea/' - texfeapath = DATA_DIR + '/sample/texfea/' - semvecpath = DATA_DIR + '/sample/semvec/' + layoutpath = DATA_DIR + '/drive/MyDrive/LayoutNetme/validation/layout/' + imgfeapath = DATA_DIR + '/drive/MyDrive/LayoutNetme/validation/visfea/' + texfeapath = DATA_DIR + '/drive/MyDrive/LayoutNetme/validation/texfea/' + semvecpath = DATA_DIR + '/drive/MyDrive/LayoutNetme/validation/semvec/' - f = open((DATA_DIR + '/sample/imgSel_128.txt'), 'r') + f = open((DATA_DIR + '/drive/MyDrive/LayoutNetme/validation/imgSel_60.txt'), 'r') name = f.read() namelist = name.split() n_samples = len(namelist)