From dc9dfcb57c46a8031c238314c6b4c06cd1bb588f Mon Sep 17 00:00:00 2001
From: vmoens
Date: Thu, 16 Oct 2025 21:41:12 +0100
Subject: [PATCH] Update

[ghstack-poisoned]
---
 torchrl/objectives/llm/grpo.py | 48 +++++++++++++++++++++++++++++-----
 1 file changed, 41 insertions(+), 7 deletions(-)

diff --git a/torchrl/objectives/llm/grpo.py b/torchrl/objectives/llm/grpo.py
index ad3dd5188ca..a57613dc9e2 100644
--- a/torchrl/objectives/llm/grpo.py
+++ b/torchrl/objectives/llm/grpo.py
@@ -86,6 +86,10 @@ class GRPOLoss(LossModule):
             When set, tokens with 0.5 * (log(pi_theta/pi_ref))^2 > kl_mask_threshold are masked out from the loss.
             This stabilizes updates by skipping tokens that drifted too far from the reference distribution
             (see table and description; enables per-token trust region).
+        aggregation (str, optional): loss aggregation strategy for the policy objective.
+            - "token_mean": global masked token mean (weights long sequences more). Default.
+            - "prompt_mean": per-sample masked mean over tokens, then mean across samples (equal sample weight).
+            - "none": return per-token loss (mask applied, no aggregation). Useful for downstream custom reductions.
         entropy_bonus (bool, optional): if ``True``, an entropy bonus will be added to the
             loss to favour exploratory policies.
         samples_mc_entropy (int, optional): if the distribution retrieved from the policy
@@ -150,6 +154,7 @@ def __init__(
         *,
         clip_epsilon: float | tuple[float, float] = 0.2,
         kl_mask_threshold: float | None = None,
+        aggregation: str | None = "token_mean",
         entropy_bonus: bool = True,
         samples_mc_entropy: int = 1,
         entropy_coeff: float = 0.01,
@@ -170,6 +175,7 @@
         self.entropy_coeff = entropy_coeff
         self.reduction = reduction
         self.kl_mask_threshold = kl_mask_threshold
+        self.aggregation = aggregation or "token_mean"
 
         # Determine device and register clip epsilon as buffer
         if device is None:
@@ -396,13 +402,13 @@ def forward(self, tensordict: TensorDictBase) -> GRPOLossOutput:
             td_out.set("loss_entropy", -self.entropy_coeff * entropy)
         td_out.set("ESS", _reduce(ess / batch, self.reduction))
-        td_out = td_out.named_apply(
-            lambda name, value: _reduce(
-                value, reduction=self.reduction, mask=mask
-            ).squeeze(-1)
-            if name.startswith("loss_")
-            else value,
-        )
+        # Aggregate loss terms according to aggregation strategy
+        for key in list(td_out.keys()):
+            if isinstance(key, tuple) or not isinstance(key, str):
+                continue
+            if key.startswith("loss_"):
+                val = td_out.get(key)
+                td_out.set(key, self._aggregate_loss_value(val, mask))
         if self.kl_to_ref_coeff is not None and self.kl_to_ref_coeff > 0:
             # FIXME: parameterize this
             loss_kl, kl_penalty = self._kl_to_ref(
@@ -446,6 +452,34 @@ def _compute_policy_objective(
         gain = torch.stack([gain1, gain2], -1).min(dim=-1).values
         return -gain, clip_fraction
 
+    def _aggregate_loss_value(
+        self, value: torch.Tensor, mask: torch.Tensor
+    ) -> torch.Tensor:
+        """Aggregate a per-token loss tensor using the configured strategy.
+
+        Supports:
+        - token_mean: masked mean across all tokens (default)
+        - prompt_mean: per-sample masked mean over tokens, then mean across batch
+        - none: return per-token loss with masked-out tokens set to 0
+
+        The input `value` is expected to have shape [..., T, 1] where T is the token dimension,
+        and `mask` has shape [..., T].
+        """
+        if self.aggregation == "none" or self.reduction == "none":
+            mask_exp = expand_as_right(mask, value)
+            return torch.where(mask_exp, value, value.new_zeros(()).expand_as(value))
+
+        if self.aggregation == "prompt_mean":
+            # Mean over valid tokens per sample, then mean across batch
+            mask_exp = expand_as_right(mask, value).to(value.dtype)
+            token_sum = (value * mask_exp).sum(dim=-2, keepdim=False)
+            token_count = mask_exp.sum(dim=-2, keepdim=False).clamp_min(1.0)
+            sample_mean = token_sum / token_count
+            return sample_mean.mean(dim=0, keepdim=False)
+
+        # token_mean (global masked mean)
+        return _reduce(value, reduction="mean", mask=mask).squeeze(-1)
+
     def _get_entropy(
         self, dist: d.Distribution, adv_shape: torch.Size
     ) -> torch.Tensor | TensorDict:
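
Note (illustration, not part of the patch): the difference between the "token_mean" and "prompt_mean" strategies described above can be checked with a small, self-contained sketch. The aggregate helper and the toy tensors below are ad hoc stand-ins for the new _aggregate_loss_value logic, written in plain PyTorch only; shapes follow the method's docstring (value is [B, T, 1], mask is [B, T]).

    import torch

    def aggregate(value: torch.Tensor, mask: torch.Tensor, mode: str) -> torch.Tensor:
        # Expand the token mask to the loss shape and cast to float for arithmetic.
        mask_exp = mask.unsqueeze(-1).to(value.dtype)  # [B, T, 1]
        if mode == "none":
            # Keep per-token losses, zeroing masked-out positions.
            return value * mask_exp
        if mode == "prompt_mean":
            # Mean over valid tokens per sample, then mean across the batch.
            per_sample = (value * mask_exp).sum(dim=-2) / mask_exp.sum(dim=-2).clamp_min(1.0)
            return per_sample.mean()
        # "token_mean": one global mean over all valid tokens.
        return (value * mask_exp).sum() / mask_exp.sum().clamp_min(1.0)

    # Two prompts: one with 4 valid tokens of loss 1.0, one with a single valid token of loss 3.0.
    value = torch.tensor([[[1.0], [1.0], [1.0], [1.0]],
                          [[3.0], [0.0], [0.0], [0.0]]])
    mask = torch.tensor([[True, True, True, True],
                         [True, False, False, False]])

    print(aggregate(value, mask, "token_mean"))   # tensor(1.4000): the long prompt dominates
    print(aggregate(value, mask, "prompt_mean"))  # tensor(2.): each prompt weighted equally

With the patch applied, the second behaviour corresponds to passing aggregation="prompt_mean" when constructing GRPOLoss; the default "token_mean" reproduces the first.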