121 changes: 120 additions & 1 deletion test/llm/test_objectives.py
@@ -23,6 +23,7 @@
GRPOLossOutput,
MCAdvantage,
)
from torchrl._utils import logger
from torchrl.objectives.llm.sft import SFTLoss

_has_transformers = importlib.util.find_spec("transformers") is not None
@@ -200,7 +201,7 @@ def test_grpo(self, mock_transformer_model, dapo):
)

# Create loss module
- loss_fn = GRPOLoss(actor_network, eps=eps)
+ loss_fn = GRPOLoss(actor_network, clip_epsilon=eps)

# Create fake data
data = _mock_data_grpo(vocab_size=vocab_size, device=device)
@@ -245,6 +246,124 @@ def test_grpo(self, mock_transformer_model, dapo):
0 <= loss_vals.clip_fraction <= 1
), f"clip_fraction out of range: {loss_vals.clip_fraction}"

def test_kl_mask_threshold(self, mock_transformer_model):
"""Test that kl_mask_threshold properly filters out high-KL tokens."""
torch.manual_seed(42)
vocab_size = 1024
device = (
torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
)

# Create mock model and wrap it
model = mock_transformer_model(vocab_size=vocab_size, device=device)
actor_network = TransformersWrapper(
model,
generate=False,
pad_output=True,
input_mode="history",
)

# Create fake data
data = _mock_data_grpo(vocab_size=vocab_size, device=device)

# First, test that the data works without any threshold
loss_fn_baseline = GRPOLoss(
actor_network, clip_epsilon=0.2, kl_mask_threshold=None
)

data_baseline = data.clone()
loss_baseline = loss_fn_baseline(data_baseline)
logger.info(f"Baseline loss (no threshold): {loss_baseline.loss_objective}")
logger.info(f"Baseline ESS: {loss_baseline.ESS}")

# Check baseline is valid; skip rather than fail if the mock data yields a non-finite loss
if not torch.isfinite(loss_baseline.loss_objective):
    pytest.skip(f"Baseline loss is not finite: {loss_baseline.loss_objective}")

# Now test with kl_mask_threshold enabled
# Use a very high threshold that should not mask any tokens
kl_threshold = 100.0 # Extremely high threshold to ensure no masking
loss_fn_with_threshold = GRPOLoss(
actor_network, clip_epsilon=0.2, kl_mask_threshold=kl_threshold
)

data_with_threshold = data.clone()
loss_with_threshold = loss_fn_with_threshold(data_with_threshold)

# Should produce valid output
assert isinstance(loss_with_threshold, GRPOLossOutput)

# Check that the loss is finite (with such a high threshold, it should be)
assert torch.isfinite(
loss_with_threshold.loss_objective
), f"loss_with_threshold is not finite: {loss_with_threshold.loss_objective}"
assert torch.isfinite(
loss_with_threshold.ESS
), f"ESS with threshold is not finite: {loss_with_threshold.ESS}"

logger.info(
f"Loss with high threshold (100.0): {loss_with_threshold.loss_objective}"
)
logger.info(f"ESS with high threshold: {loss_with_threshold.ESS}")

# The losses should be identical or very similar since we're not masking anything
# (the difference comes only from numerical precision)
assert torch.isclose(
loss_baseline.loss_objective, loss_with_threshold.loss_objective, rtol=1e-3
), f"Losses differ too much with high threshold: {loss_baseline.loss_objective} vs {loss_with_threshold.loss_objective}"

def test_failure_missing_entries(self, mock_transformer_model):
"""Test that GRPO fails when required keys are missing but works without optional keys."""
vocab_size = 1024
device = torch.device("cpu")

# Create mock model and wrap it
model = mock_transformer_model(vocab_size=vocab_size, device=device)
actor_network = TransformersWrapper(
model,
generate=False,
pad_output=True,
input_mode="history",
)

# Create loss module
loss_fn = GRPOLoss(actor_network, clip_epsilon=0.2)

# Create fake data
data = _mock_data_grpo(vocab_size=vocab_size, device=device)

# Test 1: Missing sample_log_prob (required) should fail
data_missing_sample_log_prob = data.clone()
data_missing_sample_log_prob.exclude(("log_probs", "full"), inplace=True)

with pytest.raises(KeyError, match="Couldn't find the log-prob"):
loss_fn(data_missing_sample_log_prob)

# Test 2: Missing ref_log_probs (optional when kl_to_ref_coeff is None) should work
data_missing_ref = data.clone()
# Remove the ref_log_probs key if it exists
if ("next", "ref_log_probs", "full") in data_missing_ref.keys(True):
data_missing_ref.exclude(("next", "ref_log_probs", "full"), inplace=True)

# Should work fine without ref_log_probs when kl_to_ref_coeff is None
loss_vals = loss_fn(data_missing_ref)
assert isinstance(loss_vals, GRPOLossOutput)
assert torch.isfinite(loss_vals.loss_objective)

# Test 3: Missing ref_log_probs when kl_to_ref_coeff is set should fail
loss_fn_with_kl = GRPOLoss(actor_network, clip_epsilon=0.2, kl_to_ref_coeff=0.1)

data_missing_ref_for_kl = data.clone()
if ("next", "ref_log_probs", "full") in data_missing_ref_for_kl.keys(True):
data_missing_ref_for_kl.exclude(
("next", "ref_log_probs", "full"), inplace=True
)

with pytest.raises(KeyError, match="Couldn't find the ref log-prob"):
loss_fn_with_kl(data_missing_ref_for_kl)

def test_cispo(self, mock_transformer_model):
"""Test CISPO loss computation with mock models."""
vocab_size = 1024
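Reviewer note: the rename from `eps` to `clip_epsilon` in the tests above also buys tuple support for DAPO-style asymmetric clipping. For reference, a self-contained sketch of the clipped surrogate with both variants (generic PPO math; the function and names here are illustrative, not the module's internals):

```python
import torch


def clipped_surrogate(
    ratio: torch.Tensor,
    advantage: torch.Tensor,
    clip_epsilon: float | tuple[float, float] = 0.2,
) -> torch.Tensor:
    """PPO-style clipped objective supporting DAPO's asymmetric bounds (sketch)."""
    if isinstance(clip_epsilon, tuple):
        eps_low, eps_high = clip_epsilon  # Clip-Higher: wider upper bound
    else:
        eps_low = eps_high = clip_epsilon  # symmetric default
    clipped = ratio.clamp(1.0 - eps_low, 1.0 + eps_high)
    # Pessimistic bound: take the smaller of the two surrogate terms per token
    return torch.minimum(ratio * advantage, clipped * advantage)


ratio = torch.tensor([0.5, 1.0, 1.5])
adv = torch.tensor([1.0, 1.0, -1.0])
print(clipped_surrogate(ratio, adv, clip_epsilon=(0.20, 0.28)))  # [0.5, 1.0, -1.5]
```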
32 changes: 32 additions & 0 deletions torchrl/objectives/llm/grpo.py
@@ -101,6 +101,10 @@ class GRPOLoss(LossModule):
- float x: symmetric clipping [1 - x, 1 + x] (default: 0.2)
- tuple (eps_low, eps_high): asymmetric clipping [1 - eps_low, 1 + eps_high] as in DAPO Clip-Higher
recommended defaults from DAPO: (0.20, 0.28); see Eq. (10) in the paper.
kl_mask_threshold (float | None, optional): enables token-wise trust-region filtering (KL-Mask).
When set, tokens with ``0.5 * (log(pi_theta / pi_ref))**2 > kl_mask_threshold`` are masked out of the loss.
This stabilizes updates by skipping tokens that have drifted too far from the reference distribution,
effectively enforcing a per-token trust region.
entropy_bonus (bool, optional): if ``True``, an entropy bonus will be added to the
loss to favour exploratory policies.
samples_mc_entropy (int, optional): if the distribution retrieved from the policy
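To make the `kl_mask_threshold` rule documented above concrete, here is a self-contained numerical sketch of the per-token filter (tensor shapes and names are illustrative; the real computation lives in `forward`, shown further down):

```python
import torch

torch.manual_seed(0)
B, T = 2, 6                                    # 2 sequences, 6 tokens each
cur_log_prob = 0.5 * torch.randn(B, T)         # log pi_theta per token
ref_log_prob = 0.5 * torch.randn(B, T)         # log-prob recorded at sampling time
attn_mask = torch.ones(B, T, dtype=torch.bool)
attn_mask[1, 4:] = False                       # second sequence is padded

kl_mask_threshold = 0.1
log_ratio = torch.where(attn_mask, cur_log_prob - ref_log_prob, 0.0)
kl_token = 0.5 * log_ratio**2                  # quadratic per-token KL proxy
tr_mask = kl_token <= kl_mask_threshold        # keep tokens inside the trust region
final_mask = attn_mask & tr_mask               # high-drift tokens drop out of the loss
print(final_mask)
```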
@@ -189,6 +193,7 @@ def __init__(
actor_network: LLMWrapperBase | None = None,
*,
clip_epsilon: float | tuple[float, float] = 0.2,
kl_mask_threshold: float | None = None,
entropy_bonus: bool = True,
samples_mc_entropy: int = 1,
entropy_coeff: float = 0.01,
@@ -208,6 +213,7 @@
self.samples_mc_entropy = samples_mc_entropy
self.entropy_coeff = entropy_coeff
self.reduction = reduction if reduction is not None else "mean"
self.kl_mask_threshold = kl_mask_threshold

# Determine device and register clip epsilon as buffer
if device is None:
@@ -382,6 +388,32 @@ def forward(self, tensordict: TensorDictBase) -> LLMOutputType:
tensordict, adv_shape=advantage.shape[:-1]
)
mask = dist.mask

# Optional per-token trust-region filtering (KL-Mask) vs the sampling policy's log-probs
if self.kl_mask_threshold is not None and self.kl_mask_threshold > 0:
try:
inference_log_prob = tensordict.get(
self.tensor_keys.sample_log_prob,
as_padded_tensor=True,
padding_side="left",
padding_value=0.0,
)
except KeyError:
inference_log_prob = None
cur_log_prob = tensordict.get("_cur_log_prob", None)
if (inference_log_prob is not None) and (cur_log_prob is not None):
# Align to valid tokens only (safety)
cur_log_prob_masked = torch.where(
expand_as_right(mask, cur_log_prob), cur_log_prob, 0.0
)
inference_log_prob_masked = torch.where(
expand_as_right(mask, inference_log_prob), inference_log_prob, 0.0
)
log_is_ref = cur_log_prob_masked - inference_log_prob_masked
kl_token = 0.5 * (log_is_ref**2)
tr_mask = kl_token <= self.kl_mask_threshold
# Combine with attention mask
mask = mask & tr_mask
# ESS for logging
with torch.no_grad():
# In theory, ESS should be computed on particles sampled from the same source. Here we sample according
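The hunk is truncated just as it begins the ESS logging step. For orientation, a generic effective-sample-size computation from importance log-weights (a sketch of the standard formula; the elided code may differ in detail):

```python
import torch


def effective_sample_size(log_weights: torch.Tensor) -> torch.Tensor:
    """ESS = (sum_i w_i)^2 / sum_i w_i^2, computed stably via log-space shift."""
    logw = log_weights - log_weights.max()  # subtract max for numerical stability
    w = logw.exp()
    return w.sum() ** 2 / (w * w).sum()


# A few example importance log-weights; ESS is at most the number of samples
lw = torch.tensor([0.0, -0.1, 0.2, -2.0])
print(effective_sample_size(lw))  # ~3.2: the low-weight sample contributes little
```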