diff --git a/DDQN/ddqn_agent.py b/DDQN/ddqn_agent.py index 66a15fe..a357e18 100644 --- a/DDQN/ddqn_agent.py +++ b/DDQN/ddqn_agent.py @@ -83,7 +83,7 @@ def learn(self): q_next = self.q_next.forward(states_) q_eval = self.q_eval.forward(states_) - max_actions = T.argmax(q_eval, dim=1) + max_actions = T.argmax(q_eval, dim=1).detach() q_next[dones] = 0.0 q_target = rewards + self.gamma*q_next[indices, max_actions]