From 928259f2efbb7a193f39c0ffe0dd03f9498d98a8 Mon Sep 17 00:00:00 2001
From: Bohdan Naida
Date: Mon, 17 Nov 2025 19:10:36 +0100
Subject: [PATCH 1/7] removed redundant variable assignments

---
 src/forge/losses/reinforce_loss.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/src/forge/losses/reinforce_loss.py b/src/forge/losses/reinforce_loss.py
index f1c92b667..046ecf1bf 100644
--- a/src/forge/losses/reinforce_loss.py
+++ b/src/forge/losses/reinforce_loss.py
@@ -30,16 +30,14 @@ def forward(
         self, trainer_logits, target_ids, target_mask, target_weights, target_log_probs
     ):
         trainer_log_probs = compute_logprobs(trainer_logits, target_ids, align=False)
-        target_mask = target_mask.detach()
         target_weights = target_weights
         target_mask_sum = target_mask.sum()
         target_mask_sum = torch.maximum(
             target_mask_sum, torch.ones_like(target_mask_sum)
         )
-        sampler_log_probs = target_log_probs
 
         # Importance sampling ratio
-        logp_diff = trainer_log_probs - sampler_log_probs.detach()
+        logp_diff = trainer_log_probs - target_log_probs.detach()
         importance_weights = torch.exp(logp_diff).detach()
         importance_weights = torch.clamp(importance_weights, min=0.1, max=10.0)
         weighted_advantages = target_weights * importance_weights

From 512970d4c32f6600594c3ca3ff17410eb526f3a5 Mon Sep 17 00:00:00 2001
From: Bohdan Naida
Date: Mon, 17 Nov 2025 19:12:11 +0100
Subject: [PATCH 2/7] remove duplicate logprobs computation

---
 src/forge/losses/reinforce_loss.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/src/forge/losses/reinforce_loss.py b/src/forge/losses/reinforce_loss.py
index 046ecf1bf..91a4f7fc2 100644
--- a/src/forge/losses/reinforce_loss.py
+++ b/src/forge/losses/reinforce_loss.py
@@ -5,8 +5,6 @@
 # LICENSE file in the root directory of this source tree.
 
 import torch
-
-from forge.util.ops import compute_logprobs
 from torch import nn
 
 
@@ -27,9 +25,13 @@ def __init__(self):
         super().__init__()
 
     def forward(
-        self, trainer_logits, target_ids, target_mask, target_weights, target_log_probs
+        self,
+        trainer_log_probs,
+        target_ids,
+        target_mask,
+        target_weights,
+        target_log_probs,
     ):
-        trainer_log_probs = compute_logprobs(trainer_logits, target_ids, align=False)
         target_weights = target_weights
         target_mask_sum = target_mask.sum()
         target_mask_sum = torch.maximum(
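Note on PATCH 2/7: the loss now expects precomputed trainer log-probs instead of raw logits, so the `compute_logprobs` call moves to the caller. A minimal caller-side sketch, assuming `compute_logprobs(logits, ids, align=False)` from `forge.util.ops` (the import removed above) returns per-token log-probs of shape `(batch, seq_len)`; the shapes and dummy values are illustrative, not part of the patch:

```python
import torch

from forge.util.ops import compute_logprobs

# Dummy shapes: (batch=2, seq_len=8, vocab=32000); real values come from the trainer.
trainer_logits = torch.randn(2, 8, 32000)
target_ids = torch.randint(0, 32000, (2, 8))

# Computed once here and passed into ReinforceLoss.forward, instead of being
# recomputed inside the loss as before this patch.
trainer_log_probs = compute_logprobs(trainer_logits, target_ids, align=False)
```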
From 44f5beee420d754c2ae198ffdeca3be49d4be87e Mon Sep 17 00:00:00 2001
From: Bohdan Naida
Date: Mon, 17 Nov 2025 19:14:31 +0100
Subject: [PATCH 3/7] rename parameters to follow RL terminology

---
 src/forge/losses/reinforce_loss.py | 21 ++++++++++-----------
 1 file changed, 10 insertions(+), 11 deletions(-)

diff --git a/src/forge/losses/reinforce_loss.py b/src/forge/losses/reinforce_loss.py
index 91a4f7fc2..61e7ddf7b 100644
--- a/src/forge/losses/reinforce_loss.py
+++ b/src/forge/losses/reinforce_loss.py
@@ -26,25 +26,24 @@ def __init__(self):
 
     def forward(
         self,
-        trainer_log_probs,
+        logprobs,
         target_ids,
-        target_mask,
-        target_weights,
-        target_log_probs,
+        padding_mask,
+        advantages,
+        sampling_logprobs,
     ):
-        target_weights = target_weights
-        target_mask_sum = target_mask.sum()
+        target_mask_sum = padding_mask.sum()
         target_mask_sum = torch.maximum(
             target_mask_sum, torch.ones_like(target_mask_sum)
         )
 
         # Importance sampling ratio
-        logp_diff = trainer_log_probs - target_log_probs.detach()
-        importance_weights = torch.exp(logp_diff).detach()
-        importance_weights = torch.clamp(importance_weights, min=0.1, max=10.0)
-        weighted_advantages = target_weights * importance_weights
+        logp_diff = logprobs - sampling_logprobs.detach()
+        prob_ratio = torch.exp(logp_diff).detach()
+        prob_ratio = torch.clamp(prob_ratio, min=0.1, max=10.0)
+        advantages = advantages * prob_ratio
 
-        numerator = (-trainer_log_probs * weighted_advantages * target_mask).sum()
+        numerator = (-logprobs * advantages * padding_mask).sum()
 
         denominator = target_mask_sum
         return numerator / denominator

From f23f2af4b675342c91047345633e1e7bacc9a034 Mon Sep 17 00:00:00 2001
From: Bohdan Naida
Date: Mon, 17 Nov 2025 19:17:47 +0100
Subject: [PATCH 4/7] make clamping optional and moved parameters to the constructor

---
 src/forge/losses/reinforce_loss.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/src/forge/losses/reinforce_loss.py b/src/forge/losses/reinforce_loss.py
index 61e7ddf7b..1b22303dc 100644
--- a/src/forge/losses/reinforce_loss.py
+++ b/src/forge/losses/reinforce_loss.py
@@ -21,8 +21,12 @@ class ReinforceLoss(nn.Module):
     numerical noise. GRPO is more resilient in this case.
     """
 
-    def __init__(self):
+    def __init__(
+        self, prob_ratio_min: float | None = None, prob_ratio_max: float | None = None
+    ):
         super().__init__()
+        self.prob_ratio_min = prob_ratio_min
+        self.prob_ratio_max = prob_ratio_max
 
     def forward(
         self,
@@ -40,7 +44,9 @@ def forward(
         # Importance sampling ratio
         logp_diff = logprobs - sampling_logprobs.detach()
         prob_ratio = torch.exp(logp_diff).detach()
-        prob_ratio = torch.clamp(prob_ratio, min=0.1, max=10.0)
+        prob_ratio = torch.clamp(
+            prob_ratio, min=self.prob_ratio_min, max=self.prob_ratio_max
+        )
         advantages = advantages * prob_ratio
 
         numerator = (-logprobs * advantages * padding_mask).sum()
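Note on PATCH 4/7: with the bounds moved into the constructor, the previously hard-coded range can be restored explicitly. A small usage sketch (0.1 and 10.0 are the values removed above; keep in mind that `torch.clamp` requires at least one of `min`/`max` to be non-None, so at least one bound should be supplied whenever the ratio is to be clamped):

```python
from forge.losses.reinforce_loss import ReinforceLoss

# Same behaviour as the old hard-coded clamp(prob_ratio, min=0.1, max=10.0).
loss_fn = ReinforceLoss(prob_ratio_min=0.1, prob_ratio_max=10.0)

# One-sided clamping, e.g. only capping large importance ratios.
loss_fn_capped = ReinforceLoss(prob_ratio_max=10.0)
```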
""" - def __init__(self): + def __init__( + self, prob_ratio_min: float | None = None, prob_ratio_max: float | None = None + ): super().__init__() + self.prob_ratio_min = prob_ratio_min + self.prob_ratio_max = prob_ratio_max def forward( self, @@ -40,7 +44,9 @@ def forward( # Importance sampling ratio logp_diff = logprobs - sampling_logprobs.detach() prob_ratio = torch.exp(logp_diff).detach() - prob_ratio = torch.clamp(prob_ratio, min=0.1, max=10.0) + prob_ratio = torch.clamp( + prob_ratio, min=self.prob_ratio_min, max=self.prob_ratio_max + ) advantages = advantages * prob_ratio numerator = (-logprobs * advantages * padding_mask).sum() From e21955006156c2eb993a3a245ae98c231811f306 Mon Sep 17 00:00:00 2001 From: Bohdan Naida Date: Mon, 17 Nov 2025 19:21:10 +0100 Subject: [PATCH 5/7] normalize by sequence length instead of total tokens --- src/forge/losses/reinforce_loss.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/src/forge/losses/reinforce_loss.py b/src/forge/losses/reinforce_loss.py index 1b22303dc..54db71611 100644 --- a/src/forge/losses/reinforce_loss.py +++ b/src/forge/losses/reinforce_loss.py @@ -36,12 +36,6 @@ def forward( advantages, sampling_logprobs, ): - target_mask_sum = padding_mask.sum() - target_mask_sum = torch.maximum( - target_mask_sum, torch.ones_like(target_mask_sum) - ) - - # Importance sampling ratio logp_diff = logprobs - sampling_logprobs.detach() prob_ratio = torch.exp(logp_diff).detach() prob_ratio = torch.clamp( @@ -49,7 +43,10 @@ def forward( ) advantages = advantages * prob_ratio - numerator = (-logprobs * advantages * padding_mask).sum() + per_token_loss = -logprobs * advantages + sequence_length = padding_mask.sum(dim=1).clamp(min=1.0) + per_sequence_loss = (per_token_loss * padding_mask).sum(dim=1) / sequence_length + + loss = per_sequence_loss.mean() - denominator = target_mask_sum - return numerator / denominator + return loss From ee56fd8150cfa3e7a8cdde7c27719c92533357fd Mon Sep 17 00:00:00 2001 From: Bohdan Naida Date: Mon, 17 Nov 2025 19:22:22 +0100 Subject: [PATCH 6/7] remove redundant detach as we get sampling_logprobs from vllm --- src/forge/losses/reinforce_loss.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/forge/losses/reinforce_loss.py b/src/forge/losses/reinforce_loss.py index 54db71611..92d4dab6b 100644 --- a/src/forge/losses/reinforce_loss.py +++ b/src/forge/losses/reinforce_loss.py @@ -36,8 +36,7 @@ def forward( advantages, sampling_logprobs, ): - logp_diff = logprobs - sampling_logprobs.detach() - prob_ratio = torch.exp(logp_diff).detach() + prob_ratio = torch.exp(logprobs - sampling_logprobs) prob_ratio = torch.clamp( prob_ratio, min=self.prob_ratio_min, max=self.prob_ratio_max ) From f3e49e43a035c157cea7877ea7eb4bd1b4b5337f Mon Sep 17 00:00:00 2001 From: Bohdan Naida Date: Mon, 17 Nov 2025 19:39:11 +0100 Subject: [PATCH 7/7] remove unused func parameters and align API with SimpleGRPOLoss --- src/forge/losses/reinforce_loss.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/src/forge/losses/reinforce_loss.py b/src/forge/losses/reinforce_loss.py index 92d4dab6b..8f8b6e846 100644 --- a/src/forge/losses/reinforce_loss.py +++ b/src/forge/losses/reinforce_loss.py @@ -28,14 +28,7 @@ def __init__( self.prob_ratio_min = prob_ratio_min self.prob_ratio_max = prob_ratio_max - def forward( - self, - logprobs, - target_ids, - padding_mask, - advantages, - sampling_logprobs, - ): + def forward(self, logprobs, sampling_logprobs, advantages, padding_mask): 
From f3e49e43a035c157cea7877ea7eb4bd1b4b5337f Mon Sep 17 00:00:00 2001
From: Bohdan Naida
Date: Mon, 17 Nov 2025 19:39:11 +0100
Subject: [PATCH 7/7] remove unused func parameters and align API with SimpleGRPOLoss

---
 src/forge/losses/reinforce_loss.py | 9 +--------
 1 file changed, 1 insertion(+), 8 deletions(-)

diff --git a/src/forge/losses/reinforce_loss.py b/src/forge/losses/reinforce_loss.py
index 92d4dab6b..8f8b6e846 100644
--- a/src/forge/losses/reinforce_loss.py
+++ b/src/forge/losses/reinforce_loss.py
@@ -28,14 +28,7 @@ def __init__(
         self.prob_ratio_min = prob_ratio_min
         self.prob_ratio_max = prob_ratio_max
 
-    def forward(
-        self,
-        logprobs,
-        target_ids,
-        padding_mask,
-        advantages,
-        sampling_logprobs,
-    ):
+    def forward(self, logprobs, sampling_logprobs, advantages, padding_mask):
         prob_ratio = torch.exp(logprobs - sampling_logprobs)
         prob_ratio = torch.clamp(
             prob_ratio, min=self.prob_ratio_min, max=self.prob_ratio_max
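After the series, `ReinforceLoss.forward` takes only the four tensors it uses, aligned with `SimpleGRPOLoss` per the subject above. An end-to-end sketch of calling the refactored module; the shapes are assumptions (per-token advantages of shape `(batch, seq_len)`, though a broadcastable `(batch, 1)` tensor would also work) and the values are random placeholders:

```python
import torch

from forge.losses.reinforce_loss import ReinforceLoss

loss_fn = ReinforceLoss(prob_ratio_min=0.1, prob_ratio_max=10.0)

batch, seq_len = 2, 8
logprobs = torch.randn(batch, seq_len, requires_grad=True)  # trainer log-probs
sampling_logprobs = torch.randn(batch, seq_len)             # from the sampler (e.g. vLLM)
advantages = torch.randn(batch, seq_len)                    # per-token advantages
padding_mask = torch.ones(batch, seq_len)                    # 1.0 = real token, 0.0 = padding

loss = loss_fn(logprobs, sampling_logprobs, advantages, padding_mask)
loss.backward()
```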