@@ -82,6 +82,10 @@ class GRPOLoss(LossModule):
             - float x: symmetric clipping [1 - x, 1 + x] (default: 0.2)
             - tuple (eps_low, eps_high): asymmetric clipping [1 - eps_low, 1 + eps_high] as in DAPO Clip-Higher
               recommended defaults from DAPO: (0.20, 0.28); see Eq. (10) in the paper.
+        kl_mask_threshold (float | None, optional): enable token-wise trust-region filtering (KL-Mask).
+            When set, tokens with 0.5 * (log(pi_theta / pi_ref))^2 > kl_mask_threshold are masked out of the loss.
+            This stabilizes updates by skipping tokens that drifted too far from the reference distribution,
+            i.e. it enforces a per-token trust region with respect to the reference policy.
         entropy_bonus (bool, optional): if ``True``, an entropy bonus will be added to the
             loss to favour exploratory policies.
         samples_mc_entropy (int, optional): if the distribution retrieved from the policy
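For intuition, the masking criterion described in the new docstring entry can be sketched in a few lines. This is an illustrative example only: the log-probabilities and the threshold value below are invented and are not taken from this change.

```python
import torch

# Invented per-token log-probabilities under the current and reference policies.
cur_log_prob = torch.tensor([-1.20, -0.35, -2.80, -0.90])
ref_log_prob = torch.tensor([-1.25, -0.30, -0.70, -0.95])

# Per-token divergence proxy from the docstring: 0.5 * (log(pi_theta / pi_ref))^2
kl_token = 0.5 * (cur_log_prob - ref_log_prob) ** 2

kl_mask_threshold = 0.5  # illustrative value, not a recommended default
keep = kl_token <= kl_mask_threshold
print(kl_token)  # tensor([0.0013, 0.0013, 2.2050, 0.0013])
print(keep)      # tensor([ True,  True, False,  True]) -> the third token is dropped
```

Only the third token, whose log-probability drifted far from the reference, falls outside the trust region and is excluded from the loss.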
@@ -142,6 +146,7 @@ def __init__(
         actor_network: LLMWrapperBase | None = None,
         *,
         clip_epsilon: float | tuple[float, float] = 0.2,
+        kl_mask_threshold: float | None = None,
         entropy_bonus: bool = True,
         samples_mc_entropy: int = 1,
         entropy_coeff: float = 0.01,
@@ -161,6 +166,7 @@ def __init__(
         self.samples_mc_entropy = samples_mc_entropy
         self.entropy_coeff = entropy_coeff
         self.reduction = reduction if reduction is not None else "mean"
+        self.kl_mask_threshold = kl_mask_threshold
 
         # Determine device and register clip epsilon as buffer
         if device is None:
@@ -335,6 +341,32 @@ def forward(self, tensordict: TensorDictBase) -> GRPOLossOutput:
             tensordict, adv_shape=advantage.shape[:-1]
         )
         mask = dist.mask
+
+        # Optional per-token trust-region filtering (KL-Mask) vs reference policy
+        if self.kl_mask_threshold is not None and self.kl_mask_threshold > 0:
+            try:
+                ref_log_prob = tensordict.get(
+                    self.tensor_keys.ref_log_probs,
+                    as_padded_tensor=True,
+                    padding_side="left",
+                    padding_value=0.0,
+                )
+            except KeyError:
+                ref_log_prob = None
+            cur_log_prob = tensordict.get("_cur_log_prob", None)
+            if (ref_log_prob is not None) and (cur_log_prob is not None):
+                # Align to valid tokens only (safety)
+                cur_log_prob_masked = torch.where(
+                    expand_as_right(mask, cur_log_prob), cur_log_prob, 0.0
+                )
+                ref_log_prob_masked = torch.where(
+                    expand_as_right(mask, ref_log_prob), ref_log_prob, 0.0
+                )
+                log_is_ref = cur_log_prob_masked - ref_log_prob_masked
+                kl_token = 0.5 * (log_is_ref**2)
+                tr_mask = kl_token <= self.kl_mask_threshold
+                # Combine with attention mask
+                mask = mask & tr_mask
         # ESS for logging
         with torch.no_grad():
             # In theory, ESS should be computed on particles sampled from the same source. Here we sample according
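The new block in `forward()` combines the attention mask with the trust-region mask on padded tensors. The sketch below mimics that mechanic with invented shapes, values, and threshold; `expand_as_right` is the tensordict utility the diff already relies on.

```python
import torch
from tensordict.utils import expand_as_right

# Invented example: batch of 2 left-padded sequences, 3 token slots each.
mask = torch.tensor([[False, True, True],
                     [True, True, True]])             # attention mask: valid tokens
cur_log_prob = torch.tensor([[0.00, -0.40, -1.10],
                             [-0.20, -0.55, -3.00]])
ref_log_prob = torch.tensor([[0.00, -0.35, -1.05],
                             [-0.25, -0.50, -0.60]])

# Zero out padded positions before differencing (mirrors the "safety" step above).
cur = torch.where(expand_as_right(mask, cur_log_prob), cur_log_prob, 0.0)
ref = torch.where(expand_as_right(mask, ref_log_prob), ref_log_prob, 0.0)

kl_token = 0.5 * (cur - ref) ** 2
tr_mask = kl_token <= 0.5            # made-up stand-in for kl_mask_threshold
combined = mask & tr_mask            # tokens that actually contribute to the loss
print(combined)
# tensor([[False,  True,  True],
#         [ True,  True, False]])
```

Padded positions stay masked regardless of the threshold, and valid tokens are additionally dropped when their squared log-ratio to the reference policy exceeds the threshold.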