diff --git a/draw.py b/draw.py index 0fa5c94..1a2a49c 100644 --- a/draw.py +++ b/draw.py @@ -17,6 +17,8 @@ tf.flags.DEFINE_string("data_dir", "", "") tf.flags.DEFINE_boolean("read_attn", True, "enable attention for reader") tf.flags.DEFINE_boolean("write_attn",True, "enable attention for writer") +tf.flags.DEFINE_integer("num_inter_threads", 0, "number of inter_threads") +tf.flags.DEFINE_integer("num_intra_threads", 0, "number of intra_threads") FLAGS = tf.flags.FLAGS ## MODEL PARAMETERS ## @@ -215,7 +217,8 @@ def binary_crossentropy(t,o): Lxs=[0]*train_iters Lzs=[0]*train_iters -sess=tf.InteractiveSession() +config = tf.ConfigProto(inter_op_parallelism_threads=FLAGS.num_inter_threads, intra_op_parallelism_threads=FLAGS.num_intra_threads) +sess=tf.InteractiveSession(config = config) saver = tf.train.Saver() # saves variables learned during training tf.global_variables_initializer().run()