3 changes: 2 additions & 1 deletion alf/algorithms/ddpg_algorithm.py
@@ -326,7 +326,8 @@ def calc_loss(self, experience, train_info: DdpgInfo):
critic_losses[i] = self._critic_losses[i](
experience=experience,
value=train_info.critic.q_values[:, :, i, ...],
target_value=train_info.critic.target_q_values).loss
target_value=train_info.critic.target_q_values,
train_info=train_info).loss

critic_loss = math_ops.add_n(critic_losses)

3 changes: 2 additions & 1 deletion alf/algorithms/sac_algorithm.py
@@ -757,7 +757,8 @@ def _calc_critic_loss(self, experience, train_info: SacInfo):
critic_losses.append(
l(experience=experience,
value=critic_info.critics[:, :, i, ...],
target_value=critic_info.target_critic).loss)
target_value=critic_info.target_critic,
train_info=train_info).loss)

critic_loss = math_ops.add_n(critic_losses)

2 changes: 1 addition & 1 deletion alf/algorithms/sarsa_algorithm.py
@@ -435,7 +435,7 @@ def calc_loss(self, experience, info: SarsaInfo):
target_critic = tensor_utils.tensor_prepend_zero(
info.target_critics)
loss_info = self._critic_losses[i](shifted_experience, critic,
target_critic)
target_critic, info)
critic_losses.append(nest_map(lambda l: l[:-1], loss_info.loss))

critic_loss = math_ops.add_n(critic_losses)
35 changes: 30 additions & 5 deletions alf/algorithms/td_loss.py
@@ -29,6 +29,7 @@ def __init__(self,
gamma=0.99,
td_error_loss_fn=element_wise_squared_loss,
td_lambda=0.95,
use_retrace=False,
debug_summaries=False,
name="TDLoss"):
r"""Create a TDLoss object.
@@ -44,7 +45,7 @@ def __init__(self,
:math:`G_t^\lambda = \hat{A}^{GAE}_t + V(s_t)`
where the generalized advantage estimation is defined as:
:math:`\hat{A}^{GAE}_t = \sum_{i=t}^{T-1}(\gamma\lambda)^{i-t}(R_{i+1} + \gamma V(s_{i+1}) - V(s_i))`

``use_retrace=False`` selects the one-step or multi-step (TD-lambda) loss
described above; ``use_retrace=True`` selects the retrace loss, which weights
each TD error by a clipped importance ratio.
References:

Schulman et al. `High-Dimensional Continuous Control Using Generalized Advantage Estimation
@@ -69,8 +70,8 @@ def __init__(self,
self._td_error_loss_fn = td_error_loss_fn
self._lambda = td_lambda
self._debug_summaries = debug_summaries

def forward(self, experience, value, target_value):
self._use_retrace = use_retrace

def forward(self, experience, value, target_value, train_info=None):
"""Cacluate the loss.

The first dimension of all the tensors is time dimension and the second
@@ -84,6 +85,8 @@ def forward(self, experience, value, target_value):
target_value (torch.Tensor): the time-major tensor for the value at
each time step. This is used to calculate return. ``target_value``
can be same as ``value``.
train_info (namedtuple): algorithm-specific training info (e.g. SacInfo,
    SarsaInfo) whose ``action_distribution`` is used to calculate the
    (clipped) importance ratio. Only needed when ``use_retrace`` is True.
Returns:
LossInfo: with the ``extra`` field same as ``loss``.
"""
@@ -99,15 +102,37 @@
values=target_value,
step_types=experience.step_type,
discounts=experience.discount * self._gamma)
else:
elif not self._use_retrace:
advantages = value_ops.generalized_advantage_estimation(
rewards=experience.reward,
values=target_value,
step_types=experience.step_type,
discounts=experience.discount * self._gamma,
td_lambda=self._lambda)
returns = advantages + target_value[:-1]

else:
scope = alf.summary.scope(self.__class__.__name__)
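# Ratio of the action probability under the current policy to that under
# the rollout (behavior) policy; the clipped ratio weights the retrace terms.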
importance_ratio, importance_ratio_clipped = value_ops.action_importance_ratio(
action_distribution=train_info.action_distribution,
collect_action_distribution=experience.rollout_info.action_distribution,
action=experience.action,
clipping_mode='capping',
importance_ratio_clipping=0.0,
log_prob_clipping=0.0,
scope=scope,
check_numerics=False,
debug_summaries=self._debug_summaries)
advantages = value_ops.generalized_advantage_estimation_retrace(
importance_ratio=importance_ratio_clipped,
rewards=experience.reward,
values=value,
target_value=target_value,
step_types=experience.step_type,
discounts=experience.discount * self._gamma,
time_major=True,
td_lambda=self._lambda)
returns = advantages + value[:-1]
returns = returns.detach()
value = value[:-1]

if self._debug_summaries and alf.summary.should_record_summaries():
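For intuition, here is a standalone sketch of the return target that the new
``use_retrace`` branch produces (illustrative only, not the ALF API; the helper
name retrace_return_sketch is made up, tensors are time-major, and episode
boundaries are ignored). It is the usual GAE recursion with each discounted
lambda term additionally scaled by the clipped importance ratio; with the ratio
fixed at 1 and ``values`` equal to ``target_values`` it reduces to the ordinary
lambda-return.

import torch


def retrace_return_sketch(rewards, values, target_values, discounts,
                          td_lambda, ratio):
    # One-step TD errors, bootstrapped with the target network's values.
    delta = rewards[1:] + discounts[1:] * target_values[1:] - values[:-1]
    # Discounted lambda weight, scaled per step by the clipped ratio.
    weight = discounts[1:] * td_lambda * ratio
    adv = torch.zeros_like(values)
    for t in reversed(range(rewards.shape[0] - 1)):
        adv[t] = delta[t] + weight[t] * adv[t + 1]
    # Same recovery of returns as the new branch of TDLoss.forward():
    # returns = advantages + value[:-1], detached before computing the TD loss.
    return (adv[:-1] + values[:-1]).detach()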
30 changes: 30 additions & 0 deletions alf/utils/value_ops.py
@@ -255,3 +255,33 @@ def generalized_advantage_estimation(rewards,
advs = advs.transpose(0, 1)

return advs.detach()

def generalized_advantage_estimation_retrace(importance_ratio, discounts,
                                             rewards, td_lambda, time_major,
                                             values, target_value, step_types):
    """Like generalized_advantage_estimation, but each discounted lambda
    weight is also scaled by the (already clipped) importance ratio."""
if not time_major:
discounts = discounts.transpose(0, 1)
rewards = rewards.transpose(0, 1)
values = values.transpose(0, 1)
step_types = step_types.transpose(0, 1)
importance_ratio = importance_ratio.transpose(0, 1)
target_value = target_value.transpose(0, 1)

assert values.shape[0] >= 2, ("The sequence length needs to be "
"at least 2. Got {s}".format(
s=values.shape[0]))
advs = torch.zeros_like(values)
is_lasts = (step_types == StepType.LAST).to(dtype=torch.float32)
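# One-step TD errors, bootstrapped with the target values.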
delta = rewards[1:] + discounts[1:] * target_value[1:] - values[:-1]

weighted_discounts = discounts[1:] * td_lambda * importance_ratio
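# Backward recursion: advs[t] = delta[t] + weighted_discounts[t] * advs[t + 1];
# is_lasts zeroes the estimate at episode boundaries.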
with torch.no_grad():
for t in reversed(range(rewards.shape[0] - 1)):
advs[t] = (1 - is_lasts[t]) * \
(delta[t] + weighted_discounts[t] * advs[t + 1])
advs = advs[:-1]
if not time_major:
advs = advs.transpose(0, 1)

return advs.detach()
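A minimal usage sketch of the new helper (toy batch-major values mirroring the
unit test below; import paths are assumptions, and ``importance_ratio`` covers
the T - 1 transitions and is expected to be clipped already):

import torch

from alf.data_structures import StepType
from alf.utils import value_ops

# One batch of T = 4 steps, all MID (no episode resets).
rewards = torch.tensor([[1.0] * 4])
values = torch.tensor([[0.5] * 4])
target_value = torch.tensor([[0.6] * 4])
discounts = torch.tensor([[0.99] * 4])
step_types = torch.tensor([[StepType.MID] * 4], dtype=torch.int64)
importance_ratio = torch.tensor([[0.9] * 3])

advs = value_ops.generalized_advantage_estimation_retrace(
    importance_ratio=importance_ratio,
    rewards=rewards,
    values=values,
    target_value=target_value,
    step_types=step_types,
    discounts=discounts,
    time_major=False,
    td_lambda=0.95)
# TDLoss.forward() turns these advantages into return targets:
returns = advs + values[:, :-1]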
25 changes: 25 additions & 0 deletions alf/utils/value_ops_test.py
@@ -170,7 +170,32 @@ def test_generalized_advantage_estimation(self):
discounts=discounts,
td_lambda=td_lambda,
expected=expected)

class GeneralizedAdvantageEstimationRetraceTest(unittest.TestCase):
    """Tests for value_ops.generalized_advantage_estimation_retrace."""

def test_generalized_advantage_estimation_retrace(self):
values = torch.tensor([[2.] * 4], dtype=torch.float32)
step_types = torch.tensor([[StepType.MID] * 4], dtype=torch.int64)
rewards = torch.tensor([[3.] * 4], dtype=torch.float32)
discounts = torch.tensor([[0.9] * 4], dtype=torch.float32)
td_lambda = 0.6 / 0.9
target_value = torch.tensor([[3.] * 4], dtype=torch.float32)
importance_ratio = torch.tensor([[0.8] * 3], dtype=torch.float32)
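# With constant inputs every TD error is
#   d = r + discount * target_value - value = 3 + 0.9 * 3 - 2 = 3.7,
# the recursion weight is discount * td_lambda * ratio = 0.9 * (0.6 / 0.9) * 0.8 = 0.48,
# and adv[t] = d + 0.48 * adv[t + 1] with adv[T] = 0.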
d = 3 * 0.9 + 3 - 2
expected = torch.tensor(
    [[(d * 0.6 * 0.8) * 0.6 * 0.8 + 0.6 * 0.8 * d + d, d * 0.6 * 0.8 + d, d]],
    dtype=torch.float32)
np.testing.assert_array_almost_equal(
value_ops.generalized_advantage_estimation_retrace(
rewards=rewards,
values=values,
target_value=target_value,
step_types=step_types,
discounts=discounts,
td_lambda=td_lambda,
importance_ratio=importance_ratio,
time_major=False), expected)

if __name__ == '__main__':
unittest.main()