From d1ad9aca9362936ae851bb49f33612f0e46f407d Mon Sep 17 00:00:00 2001
From: zhuboli <55901904+zhuboli@users.noreply.github.com>
Date: Wed, 7 Oct 2020 16:34:03 -0700
Subject: [PATCH] fix retrace

---
 alf/algorithms/ddpg_algorithm.py  |  3 ++-
 alf/algorithms/sac_algorithm.py   |  3 ++-
 alf/algorithms/sarsa_algorithm.py |  2 +-
 alf/algorithms/td_loss.py         | 35 ++++++++++++++++++++++++++-----
 alf/utils/value_ops.py            | 30 ++++++++++++++++++++++++++
 alf/utils/value_ops_test.py       | 25 ++++++++++++++++++++++
 6 files changed, 90 insertions(+), 8 deletions(-)

diff --git a/alf/algorithms/ddpg_algorithm.py b/alf/algorithms/ddpg_algorithm.py
index 6c5ca7284..98ed6a7cc 100644
--- a/alf/algorithms/ddpg_algorithm.py
+++ b/alf/algorithms/ddpg_algorithm.py
@@ -326,7 +326,8 @@ def calc_loss(self, experience, train_info: DdpgInfo):
             critic_losses[i] = self._critic_losses[i](
                 experience=experience,
                 value=train_info.critic.q_values[:, :, i, ...],
-                target_value=train_info.critic.target_q_values).loss
+                target_value=train_info.critic.target_q_values,
+                train_info=train_info).loss
 
         critic_loss = math_ops.add_n(critic_losses)
 
diff --git a/alf/algorithms/sac_algorithm.py b/alf/algorithms/sac_algorithm.py
index 65633d297..de44c4cf3 100644
--- a/alf/algorithms/sac_algorithm.py
+++ b/alf/algorithms/sac_algorithm.py
@@ -757,7 +757,8 @@ def _calc_critic_loss(self, experience, train_info: SacInfo):
             critic_losses.append(
                 l(experience=experience,
                   value=critic_info.critics[:, :, i, ...],
-                  target_value=critic_info.target_critic).loss)
+                  target_value=critic_info.target_critic,
+                  train_info=train_info).loss)
 
         critic_loss = math_ops.add_n(critic_losses)
 
diff --git a/alf/algorithms/sarsa_algorithm.py b/alf/algorithms/sarsa_algorithm.py
index 86d07a74f..7c22fcb1b 100644
--- a/alf/algorithms/sarsa_algorithm.py
+++ b/alf/algorithms/sarsa_algorithm.py
@@ -435,7 +435,7 @@ def calc_loss(self, experience, info: SarsaInfo):
             target_critic = tensor_utils.tensor_prepend_zero(
                 info.target_critics)
             loss_info = self._critic_losses[i](shifted_experience, critic,
-                                               target_critic)
+                                               target_critic, info)
             critic_losses.append(nest_map(lambda l: l[:-1], loss_info.loss))
 
         critic_loss = math_ops.add_n(critic_losses)
diff --git a/alf/algorithms/td_loss.py b/alf/algorithms/td_loss.py
index b72dc592d..281a7ff1c 100644
--- a/alf/algorithms/td_loss.py
+++ b/alf/algorithms/td_loss.py
@@ -29,6 +29,7 @@ def __init__(self,
                  gamma=0.99,
                  td_error_loss_fn=element_wise_squared_loss,
                  td_lambda=0.95,
+                 use_retrace=0,
                  debug_summaries=False,
                  name="TDLoss"):
         r"""Create a TDLoss object.
@@ -44,7 +45,7 @@
             :math:`G_t^\lambda = \hat{A}^{GAE}_t + V(s_t)`
             where the generalized advantage estimation is defined as:
             :math:`\hat{A}^{GAE}_t = \sum_{i=t}^{T-1}(\gamma\lambda)^{i-t}(R_{i+1} + \gamma V(s_{i+1}) - V(s_i))`
-
+            use_retrace: 0 selects the one-step or multi-step TD loss; 1 selects the retrace loss.
 
         References:
             Schulman et al. `High-Dimensional Continuous Control Using Generalized Advantage Estimation
@@ -69,8 +70,8 @@
         self._td_error_loss_fn = td_error_loss_fn
         self._lambda = td_lambda
         self._debug_summaries = debug_summaries
-
-    def forward(self, experience, value, target_value):
+        self._use_retrace = use_retrace
+    def forward(self, experience, value, target_value, train_info):
         """Cacluate the loss.
 
         The first dimension of all the tensors is time dimension and the second
@@ -84,6 +85,8 @@
             target_value (torch.Tensor): the time-major tensor for the value at
                 each time step. This is used to calculate return. ``target_value``
                 can be same as ``value``.
+            train_info (SarsaInfo|SacInfo): information used to calculate
+                ``importance_ratio`` and ``importance_ratio_clipped``.
         Returns:
             LossInfo: with the ``extra`` field same as ``loss``.
         """
@@ -99,7 +102,7 @@
                 values=target_value,
                 step_types=experience.step_type,
                 discounts=experience.discount * self._gamma)
-        else:
+        elif self._use_retrace == 0:
             advantages = value_ops.generalized_advantage_estimation(
                 rewards=experience.reward,
                 values=target_value,
@@ -107,7 +110,29 @@
                 discounts=experience.discount * self._gamma,
                 td_lambda=self._lambda)
             returns = advantages + target_value[:-1]
-
+        else:
+            scope = alf.summary.scope(self.__class__.__name__)
+            importance_ratio, importance_ratio_clipped = value_ops.action_importance_ratio(
+                action_distribution=train_info.action_distribution,
+                collect_action_distribution=experience.rollout_info.action_distribution,
+                action=experience.action,
+                clipping_mode='capping',
+                importance_ratio_clipping=0.0,
+                log_prob_clipping=0.0,
+                scope=scope,
+                check_numerics=False,
+                debug_summaries=True)
+            advantages = value_ops.generalized_advantage_estimation_retrace(
+                importance_ratio=importance_ratio_clipped,
+                rewards=experience.reward,
+                values=value,
+                target_value=target_value,
+                step_types=experience.step_type,
+                discounts=experience.discount * self._gamma,
+                time_major=True,
+                td_lambda=self._lambda)
+            returns = advantages + value[:-1]
+
         returns = returns.detach()
         value = value[:-1]
         if self._debug_summaries and alf.summary.should_record_summaries():
diff --git a/alf/utils/value_ops.py b/alf/utils/value_ops.py
index a6bf85a23..a93d2cb4d 100644
--- a/alf/utils/value_ops.py
+++ b/alf/utils/value_ops.py
@@ -255,3 +255,33 @@ def generalized_advantage_estimation(rewards,
         advs = advs.transpose(0, 1)
 
     return advs.detach()
+
+# Added for the retrace method: GAE-style recursion weighted by the clipped importance ratio.
+def generalized_advantage_estimation_retrace(importance_ratio, discounts, rewards, td_lambda, time_major, values, target_value, step_types):
+    # importance_ratio = torch.min(importance_ratio, torch.tensor(1.))
+    if not time_major:
+        discounts = discounts.transpose(0, 1)
+        rewards = rewards.transpose(0, 1)
+        values = values.transpose(0, 1)
+        step_types = step_types.transpose(0, 1)
+        importance_ratio = importance_ratio.transpose(0, 1)
+        target_value = target_value.transpose(0, 1)
+
+    assert values.shape[0] >= 2, ("The sequence length needs to be "
+                                  "at least 2. Got {s}".format(
+                                      s=values.shape[0]))
+    advs = torch.zeros_like(values)
+    is_lasts = (step_types == StepType.LAST).to(dtype=torch.float32)
+    delta = (rewards[1:] + discounts[1:] * target_value[1:] - values[:-1])
+
+    # Trace coefficient: c_t = discount_{t+1} * td_lambda * importance_ratio_t
+    weighted_discounts = discounts[1:] * td_lambda * importance_ratio
+    with torch.no_grad():
+        for t in reversed(range(rewards.shape[0] - 1)):
+            advs[t] = (1 - is_lasts[t]) * \
+                (delta[t] + weighted_discounts[t] * advs[t + 1])
+    advs = advs[:-1]
+    if not time_major:
+        advs = advs.transpose(0, 1)
+
+    return advs.detach()
\ No newline at end of file
diff --git a/alf/utils/value_ops_test.py b/alf/utils/value_ops_test.py
index ebd526127..024c12bca 100644
--- a/alf/utils/value_ops_test.py
+++ b/alf/utils/value_ops_test.py
@@ -170,7 +170,32 @@ def test_generalized_advantage_estimation(self):
             discounts=discounts,
             td_lambda=td_lambda,
             expected=expected)
+
+class GeneralizedAdvantageRetraceTest(unittest.TestCase):
+    """Tests for alf.utils.value_ops.generalized_advantage_estimation_retrace.
+    """
+    def test_generalized_advantage_estimation_retrace(self):
+        values = torch.tensor([[2.] * 4], dtype=torch.float32)
+        step_types = torch.tensor([[StepType.MID] * 4], dtype=torch.int64)
+        rewards = torch.tensor([[3.] * 4], dtype=torch.float32)
+        discounts = torch.tensor([[0.9] * 4], dtype=torch.float32)
+        td_lambda = 0.6 / 0.9  # so that discount * td_lambda == 0.6
+        target_value = torch.tensor([[3.] * 4], dtype=torch.float32)
+        importance_ratio = torch.tensor([[0.8] * 3], dtype=torch.float32)
+        d = 3 * 0.9 + 3 - 2  # one-step TD error: r + discount * target_value - value
+        expected = torch.tensor([[(d * 0.6 * 0.8) * 0.6 * 0.8 + 0.6 * 0.8 * d + d, d * 0.6 * 0.8 + d, d]],
+                                dtype=torch.float32)
+        np.testing.assert_array_almost_equal(
+            value_ops.generalized_advantage_estimation_retrace(
+                rewards=rewards,
+                values=values,
+                target_value=target_value,
+                step_types=step_types,
+                discounts=discounts,
+                td_lambda=td_lambda,
+                importance_ratio=importance_ratio,
+                time_major=False), expected)
 
 
 if __name__ == '__main__':
     unittest.main()