diff --git a/acme/agents/jax/td3/learning.py b/acme/agents/jax/td3/learning.py
index 4459c9d7d6..354afda5a6 100644
--- a/acme/agents/jax/td3/learning.py
+++ b/acme/agents/jax/td3/learning.py
@@ -111,7 +111,7 @@ def policy_loss(
           in_axes=(None, 0, 0))
       dq_da = grad_critic(critic_params, transition.observation, action)
       batch_dpg_learning = jax.vmap(rlax.dpg_loss, in_axes=(0, 0))
-      loss = jnp.mean(batch_dpg_learning(action, dq_da))
+      loss = jnp.mean(jnp.sum(batch_dpg_learning(action, dq_da), axis=-1))
       if bc_alpha is not None:
         # BC regularization for offline RL
         q_sa = networks.critic_network.apply(critic_params,
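
Note: a minimal standalone sketch of why the extra jnp.sum matters, assuming rlax.dpg_loss returns one l2-style loss value per action dimension (so the vmapped call yields a [batch, action_dim] array, not a [batch] vector). The shapes and variable names below are illustrative, not taken from the diff.

    # Illustrative sketch, not part of the patch.
    import jax
    import jax.numpy as jnp
    import rlax

    batch_size, action_dim = 4, 3  # assumed shapes for illustration
    key_a, key_g = jax.random.split(jax.random.PRNGKey(0))
    action = jax.random.normal(key_a, (batch_size, action_dim))  # [B, A]
    dq_da = jax.random.normal(key_g, (batch_size, action_dim))   # [B, A]

    # Assumption: rlax.dpg_loss returns a per-action-dimension loss, so the
    # vmapped version produces a [B, A] array.
    per_dim_loss = jax.vmap(rlax.dpg_loss, in_axes=(0, 0))(action, dq_da)
    assert per_dim_loss.shape == (batch_size, action_dim)

    # Before the patch: a single mean over both axes divides by B * A.
    old_loss = jnp.mean(per_dim_loss)
    # After the patch: sum over action dimensions first, then average over
    # the batch, which matches the usual per-example DPG objective.
    new_loss = jnp.mean(jnp.sum(per_dim_loss, axis=-1))

Under these assumptions the two differ by a constant factor of action_dim, which rescales the policy gradient relative to other loss terms (e.g. the BC regularizer gated by bc_alpha in the surrounding code).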