From 9cb849390284efb0dd85a5c22033452715f3b833 Mon Sep 17 00:00:00 2001 From: Devin Flake Date: Sat, 25 Feb 2017 09:54:13 -0700 Subject: [PATCH] updated tensorflow functions --- dqn.py | 8 ++++---- train.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/dqn.py b/dqn.py index 994cd61..91e79c3 100755 --- a/dqn.py +++ b/dqn.py @@ -45,8 +45,8 @@ def __init__(self, env, params): self.actions = tf.placeholder(tf.float32, [None, self.num_actions]) self.q_target = tf.placeholder(tf.float32, [None]) - self.q_train = tf.reduce_max(tf.mul(self.train_net.y, self.actions), reduction_indices=1) - self.diff = tf.sub(self.q_target, self.q_train) + self.q_train = tf.reduce_max(tf.multiply(self.train_net.y, self.actions), reduction_indices=1) + self.diff = tf.subtract(self.q_target, self.q_train) half = tf.constant(0.5) if params.clip_delta > 0: @@ -54,9 +54,9 @@ def __init__(self, env, params): clipped_diff = tf.clip_by_value(abs_diff, 0, 1) linear_part = abs_diff - clipped_diff quadratic_part = tf.square(clipped_diff) - self.diff_square = tf.mul(half, tf.add(quadratic_part, linear_part)) + self.diff_square = tf.multiply(half, tf.add(quadratic_part, linear_part)) else: - self.diff_square = tf.mul(half, tf.square(self.diff)) + self.diff_square = tf.multiply(half, tf.square(self.diff)) if params.accumulator == 'sum': self.loss = tf.reduce_sum(self.diff_square) diff --git a/train.py b/train.py index 8cfbdaa..21c2a92 100755 --- a/train.py +++ b/train.py @@ -11,7 +11,7 @@ def __init__(self, agent): def run(self): with tf.Session() as sess: - sess.run(tf.initialize_all_variables()) + sess.run(tf.global_variables_initializer()) self.agent.randomRestart() successes = 0