@@ -86,6 +86,10 @@ class GRPOLoss(LossModule):
8686 When set, tokens with 0.5 * (log(pi_theta/pi_ref))^2 > kl_mask_threshold are masked out from the loss.
8787 This stabilizes updates by skipping tokens that drifted too far from the reference distribution
8888 (see table and description; enables per-token trust region).
89+ aggregation (str, optional): loss aggregation strategy for the policy objective.
90+ - "token_mean": global masked token mean (weights long sequences more). Default.
91+ - "prompt_mean": per-sample masked mean over tokens, then mean across samples (equal sample weight).
92+ - "none": return per-token loss (mask applied, no aggregation). Useful for downstream custom reductions.
8993 entropy_bonus (bool, optional): if ``True``, an entropy bonus will be added to the
9094 loss to favour exploratory policies.
9195 samples_mc_entropy (int, optional): if the distribution retrieved from the policy
@@ -147,6 +151,7 @@ def __init__(
147151 * ,
148152 clip_epsilon : float | tuple [float , float ] = 0.2 ,
149153 kl_mask_threshold : float | None = None ,
154+ aggregation : str | None = "token_mean" ,
150155 entropy_bonus : bool = True ,
151156 samples_mc_entropy : int = 1 ,
152157 entropy_coeff : float = 0.01 ,
@@ -167,6 +172,7 @@ def __init__(
167172 self .entropy_coeff = entropy_coeff
168173 self .reduction = reduction if reduction is not None else "mean"
169174 self .kl_mask_threshold = kl_mask_threshold
175+ self .aggregation = aggregation or "token_mean"
170176
171177 # Determine device and register clip epsilon as buffer
172178 if device is None :
@@ -397,13 +403,13 @@ def forward(self, tensordict: TensorDictBase) -> GRPOLossOutput:
397403 td_out .set ("loss_entropy" , - self .entropy_coeff * entropy )
398404
399405 td_out .set ("ESS" , _reduce (ess / batch , self .reduction ))
400- td_out = td_out . named_apply (
401- lambda name , value : _reduce (
402- value , reduction = self . reduction , mask = mask
403- ). squeeze ( - 1 )
404- if name .startswith ("loss_" )
405- else value ,
406- )
406+ # Aggregate loss terms according to aggregation strategy
407+ for key in list ( td_out . keys ()):
408+ if isinstance ( key , tuple ) or not isinstance ( key , str ):
409+ continue
410+ if key .startswith ("loss_" ):
411+ val = td_out . get ( key )
412+ td_out . set ( key , self . _aggregate_loss_value ( val , mask ) )
407413 if self .kl_to_ref_coeff is not None and self .kl_to_ref_coeff > 0 :
408414 # FIXME: parameterize this
409415 loss_kl , kl_penalty = self ._kl_to_ref (
@@ -447,6 +453,34 @@ def _compute_policy_objective(
447453 gain = torch .stack ([gain1 , gain2 ], - 1 ).min (dim = - 1 ).values
448454 return - gain , clip_fraction
449455
456+ def _aggregate_loss_value (
457+ self , value : torch .Tensor , mask : torch .Tensor
458+ ) -> torch .Tensor :
459+ """Aggregate a per-token loss tensor using the configured strategy.
460+
461+ Supports:
462+ - token_mean: masked mean across all tokens (default)
463+ - prompt_mean: per-sample masked mean over tokens, then mean across batch
464+ - none: return per-token loss with masked-out tokens set to 0
465+
466+ The input `value` is expected to have shape [..., T, 1] where T is the token dimension,
467+ and `mask` has shape [..., T].
468+ """
469+ if self .aggregation == "none" or self .reduction == "none" :
470+ mask_exp = expand_as_right (mask , value )
471+ return torch .where (mask_exp , value , value .new_zeros (()).expand_as (value ))
472+
473+ if self .aggregation == "prompt_mean" :
474+ # Mean over valid tokens per sample, then mean across batch
475+ mask_exp = expand_as_right (mask , value ).to (value .dtype )
476+ token_sum = (value * mask_exp ).sum (dim = - 2 , keepdim = False )
477+ token_count = mask_exp .sum (dim = - 2 , keepdim = False ).clamp_min (1.0 )
478+ sample_mean = token_sum / token_count
479+ return sample_mean .mean (dim = 0 , keepdim = False )
480+
481+ # token_mean (global masked mean)
482+ return _reduce (value , reduction = "mean" , mask = mask ).squeeze (- 1 )
483+
450484 def _get_entropy (
451485 self , dist : d .Distribution , adv_shape : torch .Size
452486 ) -> torch .Tensor | TensorDict :
0 commit comments