36 changes: 15 additions & 21 deletions src/forge/losses/reinforce_loss.py
@@ -5,8 +5,6 @@
 # LICENSE file in the root directory of this source tree.
 
 import torch
-
-from forge.util.ops import compute_logprobs
 from torch import nn
 
 
@@ -23,28 +21,24 @@ class ReinforceLoss(nn.Module):
     numerical noise. GRPO is more resilient in this case.
     """
 
-    def __init__(self):
+    def __init__(
+        self, prob_ratio_min: float | None = None, prob_ratio_max: float | None = None
+    ):
         super().__init__()
+        self.prob_ratio_min = prob_ratio_min
+        self.prob_ratio_max = prob_ratio_max
 
-    def forward(
-        self, trainer_logits, target_ids, target_mask, target_weights, target_log_probs
-    ):
-        trainer_log_probs = compute_logprobs(trainer_logits, target_ids, align=False)
-        target_mask = target_mask.detach()
-        target_weights = target_weights
-        target_mask_sum = target_mask.sum()
-        target_mask_sum = torch.maximum(
-            target_mask_sum, torch.ones_like(target_mask_sum)
+    def forward(self, logprobs, sampling_logprobs, advantages, padding_mask):
+        prob_ratio = torch.exp(logprobs - sampling_logprobs)
+        prob_ratio = torch.clamp(
+            prob_ratio, min=self.prob_ratio_min, max=self.prob_ratio_max
         )
-        sampler_log_probs = target_log_probs
+        advantages = advantages * prob_ratio
 
-        # Importance sampling ratio
-        logp_diff = trainer_log_probs - sampler_log_probs.detach()
-        importance_weights = torch.exp(logp_diff).detach()
-        importance_weights = torch.clamp(importance_weights, min=0.1, max=10.0)
-        weighted_advantages = target_weights * importance_weights
+        per_token_loss = -logprobs * advantages
+        sequence_length = padding_mask.sum(dim=1).clamp(min=1.0)
+        per_sequence_loss = (per_token_loss * padding_mask).sum(dim=1) / sequence_length
 
-        numerator = (-trainer_log_probs * weighted_advantages * target_mask).sum()
+        loss = per_sequence_loss.mean()
 
-        denominator = target_mask_sum
-        return numerator / denominator
+        return loss
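
To make the new interface concrete, here is a minimal usage sketch. The import path `forge.losses.reinforce_loss` is inferred from the file path in the diff header; the tensor shapes (all four arguments as per-token tensors of shape `(batch, seq_len)`) and the clamp bounds are illustrative assumptions, not something the PR specifies.

```python
import torch

# Assumed module path, taken from the file path in the diff header.
from forge.losses.reinforce_loss import ReinforceLoss

# Illustrative shapes: 2 sequences of 5 tokens each (assumption).
batch_size, seq_len = 2, 5

# Bounds mirror the 0.1/10.0 range hard-coded in the old forward().
loss_fn = ReinforceLoss(prob_ratio_min=0.1, prob_ratio_max=10.0)

# Per-token log-probs under the trainer policy (gradients flow through
# these) and under the sampling policy, plus per-token advantages and
# a 0/1 padding mask.
logprobs = torch.randn(batch_size, seq_len, requires_grad=True)
sampling_logprobs = torch.randn(batch_size, seq_len)
advantages = torch.randn(batch_size, seq_len)
padding_mask = torch.ones(batch_size, seq_len)

loss = loss_fn(logprobs, sampling_logprobs, advantages, padding_mask)
loss.backward()
```

One caveat when calling the constructor with its defaults: `torch.clamp` requires at least one of `min`/`max` to be non-None, which is why the sketch passes explicit bounds.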