diff --git a/GRPO_BUG_FIX_SUMMARY.md b/GRPO_BUG_FIX_SUMMARY.md
new file mode 100644
index 000000000..52a104420
--- /dev/null
+++ b/GRPO_BUG_FIX_SUMMARY.md
@@ -0,0 +1,92 @@
+# GRPO Bug Fix Summary
+
+## Critical Bugs Found and Fixed
+
+### Bug 1: Behavior Policy Logprobs Never Stored
+**Problem:** The GRPO algorithm requires logprobs from the **behavior policy** (the policy that generated the responses), but these were never being stored or used.
+
+**Impact:** The policy gradient had no signal: it was computing `exp(logprobs - logprobs) = exp(0) = 1`, so the importance sampling ratio was always 1 and the model was not actually learning from the policy gradient.
+
+**Fixes Applied:**
+1. Added `behavior_logprobs` field to `Episode` dataclass
+2. Extracted `behavior_logprobs` from `completion.logprobs` when creating episodes
+3. Added `behavior_logprobs_tensor` property to handle padding
+4. Updated `collate()` function to include `behavior_logprobs` in batches
+5. Updated loss functions to use `behavior_logprobs` instead of `logprobs.detach()`
+
+### Bug 2: Logprobs Not Enabled in vLLM
+**Problem:** The `sampling_params` configuration didn't request logprobs from vLLM, so `completion.logprobs` was always `None`.
+
+**Fix Applied:**
+- Added `logprobs: 1` to `sampling_params` in `qwen3_1_7b.yaml`
+
+### Bug 3: Incorrect Importance Sampling Ratio
+**Problem:** Both loss functions used:
+```python
+per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages
+```
+
+This is always `exp(0) = 1`, providing no learning signal!
+
+**Correct Formula:**
+```python
+per_token_policy_loss = torch.exp(logprobs - behavior_logprobs.detach()) * advantages
+```
+
+**Files Modified:**
+- `apps/grpo/main.py` - `simple_grpo_loss()` function
+- `src/forge/losses/grpo_loss.py` - `SimpleGRPOLoss.forward()` method
+
+## GRPO Algorithm Explanation
+
+GRPO (Group Relative Policy Optimization) uses **three** sets of logprobs:
+
+1. **Current Policy Logprobs** (`logprobs`): From the model being trained
+   - Used to compute gradients
+   - Computed on-the-fly during training
+
+2. **Behavior Policy Logprobs** (`behavior_logprobs`): From the policy that generated the responses
+   - Used for importance sampling: `ratio = exp(current - behavior)`
+   - Must be stored when responses are generated
+   - In an "off-by-n" setting, this is the policy from n steps ago
+
+3. **Reference Policy Logprobs** (`ref_logprobs`): From a frozen reference model
+   - Used for KL regularization to prevent the policy from diverging too much
+   - Computed from a frozen copy of the initial model
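+
+Putting the three together, a minimal sketch of the per-token objective (illustrative only; it mirrors the patched `simple_grpo_loss`, with all tensors assumed to have shape `[batch, seq]`):
+
+```python
+import torch
+
+def per_token_grpo_loss(logprobs, behavior_logprobs, ref_logprobs, advantages, beta=0.1):
+    # Importance sampling ratio: current policy vs. the policy that generated the data
+    ratio = torch.exp(logprobs - behavior_logprobs.detach())
+    # KL penalty against the frozen reference policy
+    kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1
+    # Maximize the advantage-weighted ratio while penalizing drift from the reference
+    return -(ratio * advantages - beta * kl)
+```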
+
+## Testing the Fix
+
+To verify the fix works:
+
+```bash
+# Train with learning rate 1e-5 (should learn)
+python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml trainer.optimizer.lr=1e-5
+
+# Train with learning rate 0 (should NOT learn - flat rewards)
+python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml trainer.optimizer.lr=0
+```
+
+**Expected behavior after fix:**
+- With lr=1e-5: Rewards should improve over time
+- With lr=0: Rewards should stay flat (no learning)
+- The two runs should have DIFFERENT reward trajectories
+
+**Before the fix:**
+- Both runs had identical reward patterns (no actual learning happening)
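+
+Beyond comparing reward curves, a tensor-level spot check can help. The helper below is illustrative and not part of the patch: applied to the tensors inside `simple_grpo_loss`, it reports the mean importance ratio over non-padding tokens, which should no longer sit at exactly 1.0 once behavior logprobs are wired through.
+
+```python
+import torch
+
+def mean_importance_ratio(
+    logprobs: torch.Tensor,
+    behavior_logprobs: torch.Tensor,
+    padding_mask: torch.Tensor,
+) -> float:
+    """Mean exp(current - behavior) over valid tokens; exactly 1.0 means no off-policy correction."""
+    with torch.no_grad():
+        ratio = torch.exp(logprobs - behavior_logprobs)
+        return ratio[padding_mask.bool()].mean().item()
+```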
+
+## Files Changed
+
+1. `apps/grpo/main.py`
+   - Updated `Episode` dataclass
+   - Updated `collate()` function
+   - Updated `simple_grpo_loss()` function
+   - Added behavior_logprobs extraction in rollout loop
+   - Added `behavior_logprobs_tensor` property
+
+2. `src/forge/losses/grpo_loss.py`
+   - Updated `SimpleGRPOLoss.forward()` signature and implementation
+   - Added documentation explaining the three logprob types
+
+3. `apps/grpo/qwen3_1_7b.yaml`
+   - Added `logprobs: 1` to sampling_params
+
diff --git a/apps/grpo/main.py b/apps/grpo/main.py
index 693dc8d81..a6e086627 100644
--- a/apps/grpo/main.py
+++ b/apps/grpo/main.py
@@ -7,6 +7,7 @@
 # Usage: python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml
 
 import asyncio
+import logging
 import time
 import uuid
 from dataclasses import dataclass
@@ -49,7 +50,8 @@ class Episode:
     target: Any | None = None
     # Processed data
     completion: Completion | None = None
-    ref_logprobs: torch.Tensor | None = None
+    behavior_logprobs: torch.Tensor | None = None  # Logprobs from the policy that generated the response
+    ref_logprobs: torch.Tensor | None = None  # Logprobs from the reference model
     reward: float | None = None
     advantage: float | None = None
 
@@ -72,6 +74,15 @@ def response_tensor(self) -> torch.Tensor:
             diff = self.response_len - tensor.shape[0]
             tensor = F.pad(tensor, (0, diff), value=self.pad_id)
         return tensor
+
+    @property
+    def behavior_logprobs_tensor(self) -> torch.Tensor:
+        """Get behavior logprobs padded to response_len."""
+        tensor: torch.Tensor = self.behavior_logprobs
+        if tensor.shape[0] < self.response_len:  # right pad with zeros (will be masked)
+            diff = self.response_len - tensor.shape[0]
+            tensor = F.pad(tensor, (0, diff), value=0.0)
+        return tensor
 
 
 # Represents the group (G) of episodes in GRPO
@@ -97,6 +108,9 @@ def collate(
     response = [e.response_tensor for e in batch]
     response = torch.stack(response)  # [b x s]
 
+    behavior_logprobs = [e.behavior_logprobs_tensor for e in batch]
+    behavior_logprobs = torch.stack(behavior_logprobs).squeeze()  # [b x s]
+
     ref_logprobs = [e.ref_logprobs for e in batch]
     ref_logprobs = torch.stack(ref_logprobs).squeeze()  # [b x s]
 
@@ -109,6 +123,7 @@ def collate(
     input = {"tokens": torch.cat([request, response], dim=1)}
     target = {
         "response": response,
+        "behavior_logprobs": behavior_logprobs,
        "ref_logprobs": ref_logprobs,
         "advantages": advantages,
         "padding_mask": mask,
@@ -119,17 +134,23 @@
 # Note: This is also available in losses.grpo_loss via `SimpleGRPOLoss`
 
 
+_grpo_logger = logging.getLogger(__name__)
+
 def simple_grpo_loss(
     logits: torch.Tensor,
     response: torch.Tensor,
+    behavior_logprobs: torch.Tensor,
     ref_logprobs: torch.Tensor,
     advantages: torch.Tensor,
     padding_mask: torch.Tensor,
     beta: float = 0.1,
 ) -> torch.Tensor:
+    # Sanity check: behavior_logprobs should not be None
+    assert behavior_logprobs is not None, "behavior_logprobs is None!"
+
     logprobs: torch.Tensor = compute_logprobs(logits, response)
     kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1
-    per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages
+    per_token_policy_loss = torch.exp(logprobs - behavior_logprobs.detach()) * advantages
     per_token_loss = -(per_token_policy_loss - beta * kl)
     loss = (
         ((per_token_loss * padding_mask).sum(dim=1))
@@ -380,6 +401,14 @@ async def continuous_rollouts():
                 dtype=torch.long,
             )
             for i, response in enumerate(responses):
+                # Extract behavior policy logprobs from the completion
+                behavior_logprobs = response.logprobs
+                if behavior_logprobs is None:
+                    raise ValueError(
+                        "Behavior logprobs are not available in the completion. "
+                        "Please enable logprobs in the sampling_params configuration."
+                    )
+
                 episode = Episode(
                     episode_id=str(uuid.uuid4()),
                     pad_id=pad_id,
@@ -387,6 +416,7 @@
                     response_len=max_res_tokens,
                     target=target,
                     completion=response,
+                    behavior_logprobs=behavior_logprobs,
                 )
                 episode.reward = await reward_actor.evaluate_response.route(
                     prompt=prompt, response=response.text, target=target
diff --git a/apps/grpo/qwen3_1_7b.yaml b/apps/grpo/qwen3_1_7b.yaml
index c6fc1613b..ddc8ccbce 100644
--- a/apps/grpo/qwen3_1_7b.yaml
+++ b/apps/grpo/qwen3_1_7b.yaml
@@ -42,6 +42,7 @@ policy:
     max_tokens: ${max_res_tokens}
     temperature: 1.0
     top_p: 1.0
+    logprobs: 1  # Enable logprobs for behavior policy
 
 # Trainer configuration
 trainer:
diff --git a/src/forge/losses/grpo_loss.py b/src/forge/losses/grpo_loss.py
index 220367b47..3e9eadd5e 100644
--- a/src/forge/losses/grpo_loss.py
+++ b/src/forge/losses/grpo_loss.py
@@ -12,15 +12,25 @@ class SimpleGRPOLoss(nn.Module):
     """Simplified GRPO Loss for simplified single step updates
     Inspired by the Hugging Face TRL implementation:
     https://github.com/huggingface/trl/blob/417915a3e4d3e3bc8d7b196594308b8eabf928be/trl/trainer/grpo_trainer.py#L1624.
+
+    Args:
+        beta: The KL divergence coefficient for the loss.
+
+    Forward args:
+        logprobs: Log probabilities from the current policy being trained
+        behavior_logprobs: Log probabilities from the policy that generated the responses (behavior policy)
+        ref_logprobs: Log probabilities from the reference model (for KL regularization)
+        advantages: Computed advantages for each token
+        padding_mask: Mask indicating valid tokens (1) vs padding (0)
     """
 
     def __init__(self, beta: float = 0.1):
         super().__init__()
         self.beta = beta
 
-    def forward(self, logprobs, ref_logprobs, advantages, padding_mask):
+    def forward(self, logprobs, behavior_logprobs, ref_logprobs, advantages, padding_mask):
         kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1
-        per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages
+        per_token_policy_loss = torch.exp(logprobs - behavior_logprobs.detach()) * advantages
         per_token_loss = -(per_token_policy_loss - self.beta * kl)
         loss = (
             ((per_token_loss * padding_mask).sum(dim=1))