From 70442c1bb8a79e114f7f5b8d4f8625f5e73ce7d7 Mon Sep 17 00:00:00 2001 From: Michelangelo Conserva Date: Mon, 19 Sep 2022 08:46:25 +0100 Subject: [PATCH] _total_steps assignment position bug The _total_steps should be increased after the agent takes an action, as in the JAX implementation of this agent. For values of the sgd_period higher than one, this bug prevents the agent from training. --- bsuite/baselines/tf/boot_dqn/agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bsuite/baselines/tf/boot_dqn/agent.py b/bsuite/baselines/tf/boot_dqn/agent.py index e423e984..095b9c7c 100644 --- a/bsuite/baselines/tf/boot_dqn/agent.py +++ b/bsuite/baselines/tf/boot_dqn/agent.py @@ -120,7 +120,6 @@ def _step(self, transitions: Sequence[tf.Tensor]): loss = tf.reduce_mean(tf.stack(losses)) gradients = tape.gradient(loss, variables) - self._total_steps.assign_add(1) self._optimizer.apply(gradients, variables) # Periodically update the target network. @@ -132,6 +131,7 @@ def select_action(self, timestep: dm_env.TimeStep) -> base.Action: """Select values via Thompson sampling, then use epsilon-greedy policy.""" + self._total_steps.assign_add(1) if self._rng.rand() < self._epsilon_fn(self._total_steps.numpy()): return self._rng.randint(self._num_actions)