From 8e4f63fb7c9b7a945603ae68c9e7f649510257aa Mon Sep 17 00:00:00 2001
From: Yicheng Luo
Date: Thu, 7 Jul 2022 00:33:41 +0100
Subject: [PATCH] Fix policy loss gradient in TD3

The policy loss computed from dq_da is currently incorrect: the loss terms
for each action dimension should be summed rather than averaged, as per
https://github.com/deepmind/rlax/blob/master/rlax/_src/policy_gradients_test.py#L55

Note that the D4PG agent also does not sum over the action dimension, so a
hypothetical D4PG-BC agent would run into the same issue.

The current policy loss is fine when this version is used to run online TD3:
with an optimizer such as Adam, which normalizes gradients by their magnitude,
the constant scaling factor does not affect the computed policy updates.
However, it becomes problematic if the user wants to use Acme's version of TD3
to reproduce results from the TD3-BC paper with its default `bc_alpha`
hyperparameter (2.5). Without the sum, the relative magnitudes of the critic
gradient and the BC loss differ from the original implementation.

I have noticed that this version of TD3 performs badly on some of the D4RL
locomotion datasets (e.g., hopper-medium-replay-v2): without summing over the
action dimension, the evaluation return is very unstable.
---
 acme/agents/jax/td3/learning.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/acme/agents/jax/td3/learning.py b/acme/agents/jax/td3/learning.py
index 4459c9d7d6..354afda5a6 100644
--- a/acme/agents/jax/td3/learning.py
+++ b/acme/agents/jax/td3/learning.py
@@ -111,7 +111,7 @@ def policy_loss(
           in_axes=(None, 0, 0))
       dq_da = grad_critic(critic_params, transition.observation, action)
       batch_dpg_learning = jax.vmap(rlax.dpg_loss, in_axes=(0, 0))
-      loss = jnp.mean(batch_dpg_learning(action, dq_da))
+      loss = jnp.mean(jnp.sum(batch_dpg_learning(action, dq_da), axis=-1))
       if bc_alpha is not None:
         # BC regularization for offline RL
         q_sa = networks.critic_network.apply(critic_params,
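
For illustration, here is a minimal, self-contained sketch of why the reduction
matters once a BC term is added. The shapes, the `dpg_losses` array, and the
TD3-BC-style weighting in the comments are assumptions made for the example,
not the actual Acme code:

```python
import jax.numpy as jnp

# Illustrative shapes only: a batch of 256 transitions, 6-dimensional actions.
batch_size, action_dim = 256, 6

# Per-dimension DPG losses, shape [batch_size, action_dim], as produced by
# vmapping rlax.dpg_loss over the batch.
dpg_losses = jnp.ones((batch_size, action_dim))

# Current reduction: mean over both batch and action dimensions.
loss_mean = jnp.mean(dpg_losses)                   # 1.0

# Proposed reduction: sum over action dimensions, mean over the batch.
loss_sum = jnp.mean(jnp.sum(dpg_losses, axis=-1))  # 6.0 (= action_dim)

# For the pure DPG loss, an optimizer like Adam absorbs this constant factor.
# In a TD3-BC-style objective, however, the DPG term is combined with a
# behavioural-cloning term (lmbda and bc_mse below are hypothetical names):
#   lmbda = bc_alpha / mean(|Q(s, a)|)
#   policy_loss = lmbda * dpg_term + bc_mse
# Averaging instead of summing shrinks the DPG term by a factor of
# `action_dim`, so the default bc_alpha = 2.5 no longer gives the same
# critic/BC balance as the original TD3-BC implementation.
print(loss_mean, loss_sum)
```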