SAC code doesn't appropriately implement target_q #93

@gautams3

Description

The computation of the target Q value in `critic_loss_fn()` in the SERL SAC code has a potential bug.

In this file, if you set config['backup_entropy']=True, the term temperature * next_action_log_probs is subtracted from target_q. This is mathematically equivalent to

$$y(r,s',d) = r + \gamma (1-d) \left[ \min_{1,2} Q(s',a') \right] - \alpha \log \pi_{\theta}(a'|s')$$

where $y(r,s',d)$ = `target_q`, $r$ = `batch['rewards']`, $\gamma$ = `config['discount']`, $(1-d)$ = `batch['masks']`, $\min_{1,2} Q(s',a')$ = `target_next_min_q`, $\alpha$ = `temperature`, $\log \pi_{\theta}(a'|s')$ = `next_action_log_probs`, and $a' \sim \pi(\cdot \mid s')$.

But the formula for target_q should be

$$y(r,s',d) = r + \gamma (1-d) \left[ \min_{1,2} Q(s',a') - \alpha \log \pi_{\theta}(a'|s') \right]$$

i.e., the $\alpha \log \pi_{\theta}(a'|s')$ term should also be multiplied by $\gamma (1-d)$. This way the entropy bonus is discounted (and masked at terminal transitions) consistently with the Q term, so the value function estimates are accurate.

Sources:
[1] SAC paper, see eq 3 for Value function
[2] Spinning up RL by OpenAI, SAC pseudocode, see line 12 for computing target q values.

The fix is quite simple: subtract the $\alpha \log \pi_{\theta}(a'|s')$ term before multiplying by $\gamma (1-d)$.
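A minimal sketch of the two computations, using the variable names from the issue (the actual SERL code shapes and names may differ):

```python
def target_q_buggy(rewards, masks, target_next_min_q,
                   next_action_log_probs, discount, temperature):
    # Entropy term subtracted AFTER discounting: it is not scaled by
    # gamma * (1 - d), which is the behavior described in the issue.
    return (rewards + discount * masks * target_next_min_q
            - temperature * next_action_log_probs)


def target_q_fixed(rewards, masks, target_next_min_q,
                   next_action_log_probs, discount, temperature):
    # Entropy term subtracted BEFORE discounting, matching eq. 3 of the
    # SAC paper: the whole backup is scaled by gamma * (1 - d).
    return rewards + discount * masks * (
        target_next_min_q - temperature * next_action_log_probs)
```

The two versions disagree whenever $\gamma (1-d) \neq 1$; most visibly at a terminal transition (`masks == 0`), where the buggy version still subtracts the entropy term even though no bootstrapped value should contribute.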
