92 changes: 92 additions & 0 deletions GRPO_BUG_FIX_SUMMARY.md
@@ -0,0 +1,92 @@
# GRPO Bug Fix Summary

## Critical Bugs Found and Fixed

### Bug 1: Behavior Policy Logprobs Never Stored
**Problem:** The GRPO algorithm requires logprobs from the **behavior policy** (the policy that generated the responses), but these were never being stored or used.

**Impact:** The policy-gradient term carried no importance-sampling signal: the loss computed `exp(logprobs - logprobs.detach()) = exp(0) = 1`, so the ratio was always 1 and the model wasn't actually learning from the policy gradient.

**Fixes Applied** (see the sketch after this list):
1. Added a `behavior_logprobs` field to the `Episode` dataclass
2. Extracted `behavior_logprobs` from `completion.logprobs` when creating episodes
3. Added a `behavior_logprobs_tensor` property to handle padding
4. Updated the `collate()` function to include `behavior_logprobs` in batches
5. Updated the loss functions to use `behavior_logprobs` instead of `logprobs.detach()`
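
Taken together, the fixes thread the behavior logprobs from generation through batching. A self-contained sketch of how fixes 1–4 fit together (names mirror `apps/grpo/main.py`, but the class here is trimmed down and illustrative, not the real module):

```python
from dataclasses import dataclass

import torch
import torch.nn.functional as F


@dataclass
class Episode:
    # Trimmed-down stand-in for the real Episode dataclass.
    response_len: int
    pad_id: int
    behavior_logprobs: torch.Tensor | None = None  # from the policy that generated the response
    ref_logprobs: torch.Tensor | None = None       # from the frozen reference model

    @property
    def behavior_logprobs_tensor(self) -> torch.Tensor:
        # Right-pad with zeros up to response_len; padded positions are masked out in the loss.
        t = self.behavior_logprobs
        if t.shape[0] < self.response_len:
            t = F.pad(t, (0, self.response_len - t.shape[0]), value=0.0)
        return t


# Rollout time (fix 2): behavior_logprobs comes from completion.logprobs.
episodes = [
    Episode(response_len=4, pad_id=0, behavior_logprobs=torch.randn(n), ref_logprobs=torch.randn(4))
    for n in (3, 4)
]

# Collation (fix 4): pad and stack so the loss sees a [batch, seq] tensor.
behavior_logprobs = torch.stack([e.behavior_logprobs_tensor for e in episodes])
print(behavior_logprobs.shape)  # torch.Size([2, 4])
```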

### Bug 2: Logprobs Not Enabled in vLLM
**Problem:** The sampling_params configuration didn't request logprobs from vLLM, so `completion.logprobs` was always `None`.

**Fix Applied:**
- Added `logprobs: 1` to `sampling_params` in `qwen3_1_7b.yaml`
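
For context, the same switch in plain vLLM (outside this repo's config layer) looks roughly like the sketch below: passing `logprobs=1` makes each generated token carry its sampled-token logprob, whereas leaving it unset yields `logprobs=None`. The output-field access is an assumption about vLLM's `CompletionOutput` objects, not about this project's `Completion` wrapper, and the model name is only an example.

```python
# Standalone vLLM sketch (assumes vLLM is installed and the model fits locally).
from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen3-1.7B")
params = SamplingParams(max_tokens=16, temperature=1.0, top_p=1.0, logprobs=1)

out = llm.generate(["What is 2 + 2?"], params)[0].outputs[0]
# out.logprobs is a per-token list of {token_id: Logprob}; it would be None without logprobs=1.
sampled_logprobs = [lp[tok].logprob for tok, lp in zip(out.token_ids, out.logprobs)]
```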

### Bug 3: Incorrect Importance Sampling Ratio
**Problem:** Both loss functions used:
```python
per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages
```

This is always `exp(0) = 1`, providing no learning signal!

**Correct Formula:**
```python
per_token_policy_loss = torch.exp(logprobs - behavior_logprobs.detach()) * advantages
```
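
A minimal standalone check of the difference: the old expression is identically 1 in the forward pass no matter what the model does, while the corrected ratio actually measures how far the current policy has moved from the behavior policy.

```python
import torch

logprobs = torch.randn(4, requires_grad=True)                # current policy
behavior_logprobs = logprobs.detach() - 0.3 * torch.rand(4)  # pretend the generator drifted a little

old_ratio = torch.exp(logprobs - logprobs.detach())  # always exactly 1.0
new_ratio = torch.exp(logprobs - behavior_logprobs)  # deviates from 1, reflecting the drift

print(old_ratio)  # tensor([1., 1., 1., 1.], grad_fn=<ExpBackward0>)
print(new_ratio)  # e.g. tensor([1.02, 1.17, 1.29, 1.08], grad_fn=<ExpBackward0>)
```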

**Files Modified:**
- `apps/grpo/main.py` - `simple_grpo_loss()` function
- `src/forge/losses/grpo_loss.py` - `SimpleGRPOLoss.forward()` method

## GRPO Algorithm Explanation

GRPO (Group Relative Policy Optimization) uses **three** sets of logprobs, which combine in the loss sketched after this list:

1. **Current Policy Logprobs** (`logprobs`): From the model being trained
- Used to compute gradients
- Computed on-the-fly during training

2. **Behavior Policy Logprobs** (`behavior_logprobs`): From the policy that generated the responses
- Used for importance sampling: `ratio = exp(current - behavior)`
- Must be stored when responses are generated
- In "off-by-n" setting, this is the policy from n steps ago

3. **Reference Policy Logprobs** (`ref_logprobs`): From a frozen reference model
- Used for KL regularization to prevent the policy from diverging too much
- Computed from a frozen copy of the initial model
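
Putting the three together, the loss has the shape below. This sketch mirrors the corrected `SimpleGRPOLoss.forward` in the diff; the masked-mean reduction at the end is assumed to follow the usual GRPO convention, since that part is truncated in the diff view.

```python
import torch


def grpo_loss(
    logprobs: torch.Tensor,           # current policy (requires grad)
    behavior_logprobs: torch.Tensor,  # policy that generated the responses (stored at rollout)
    ref_logprobs: torch.Tensor,       # frozen reference model
    advantages: torch.Tensor,
    padding_mask: torch.Tensor,
    beta: float = 0.1,
) -> torch.Tensor:
    # KL penalty toward the reference policy (the k3 estimator used by TRL/GRPO).
    kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1
    # Importance ratio between the current and behavior policies.
    ratio = torch.exp(logprobs - behavior_logprobs.detach())
    per_token_loss = -(ratio * advantages - beta * kl)
    # Masked mean over valid tokens, then averaged over the batch (assumed reduction).
    denom = padding_mask.sum(dim=1).clamp(min=1.0)
    return ((per_token_loss * padding_mask).sum(dim=1) / denom).mean()
```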

## Testing the Fix

To verify the fix works:

```bash
# Train with learning rate 1e-5 (should learn)
python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml trainer.optimizer.lr=1e-5

# Train with learning rate 0 (should NOT learn - flat rewards)
python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml trainer.optimizer.lr=0
```

**Expected behavior after fix:**
- With lr=1e-5: Rewards should improve over time
- With lr=0: Rewards should stay flat (no learning)
- The two runs should have DIFFERENT reward trajectories

**Before the fix:**
- Both runs had identical reward patterns (no actual learning happening)
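
Beyond the two end-to-end runs, a quick tensor-level regression check (a standalone sketch reusing the loss shape above, not a test that exists in the repo): with the corrected formula the loss reacts to the behavior logprobs, whereas the old formula ignores them entirely.

```python
import torch

torch.manual_seed(0)
b, s = 2, 4
logprobs = torch.randn(b, s, requires_grad=True)
ref_logprobs = torch.randn(b, s)
advantages = torch.randn(b, 1)
mask = torch.ones(b, s)


def loss_fn(behavior_logprobs: torch.Tensor) -> torch.Tensor:
    kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1
    ratio = torch.exp(logprobs - behavior_logprobs.detach())
    per_token = -(ratio * advantages - 0.1 * kl)
    return ((per_token * mask).sum(dim=1) / mask.sum(dim=1)).mean()


loss_a = loss_fn(torch.randn(b, s))
loss_b = loss_fn(torch.randn(b, s))
assert not torch.allclose(loss_a, loss_b)  # the behavior policy now matters
# With the old `exp(logprobs - logprobs.detach())`, both calls would return the same value.
```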

## Files Changed

1. `apps/grpo/main.py`
- Updated `Episode` dataclass
- Updated `collate()` function
- Updated `simple_grpo_loss()` function
   - Added `behavior_logprobs` extraction in the rollout loop
- Added `behavior_logprobs_tensor` property

2. `src/forge/losses/grpo_loss.py`
- Updated `SimpleGRPOLoss.forward()` signature and implementation
- Added documentation explaining the three logprob types

3. `apps/grpo/qwen3_1_7b.yaml`
- Added `logprobs: 1` to sampling_params

34 changes: 32 additions & 2 deletions apps/grpo/main.py
@@ -7,6 +7,7 @@
# Usage: python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml

import asyncio
import logging
import time
import uuid
from dataclasses import dataclass
@@ -49,7 +50,8 @@ class Episode:
target: Any | None = None
# Processed data
completion: Completion | None = None
ref_logprobs: torch.Tensor | None = None
behavior_logprobs: torch.Tensor | None = None # Logprobs from the policy that generated the response
ref_logprobs: torch.Tensor | None = None # Logprobs from the reference model
reward: float | None = None
advantage: float | None = None

@@ -72,6 +74,15 @@ def response_tensor(self) -> torch.Tensor:
diff = self.response_len - tensor.shape[0]
tensor = F.pad(tensor, (0, diff), value=self.pad_id)
return tensor

@property
def behavior_logprobs_tensor(self) -> torch.Tensor:
"""Get behavior logprobs padded to response_len."""
tensor: torch.Tensor = self.behavior_logprobs
if tensor.shape[0] < self.response_len: # right pad with zeros (will be masked)
diff = self.response_len - tensor.shape[0]
tensor = F.pad(tensor, (0, diff), value=0.0)
return tensor


# Represents the group (G) of episodes in GRPO
@@ -97,6 +108,9 @@ def collate(
response = [e.response_tensor for e in batch]
response = torch.stack(response) # [b x s]

behavior_logprobs = [e.behavior_logprobs_tensor for e in batch]
behavior_logprobs = torch.stack(behavior_logprobs).squeeze() # [b x s]

ref_logprobs = [e.ref_logprobs for e in batch]
ref_logprobs = torch.stack(ref_logprobs).squeeze() # [b x s]

@@ -109,6 +123,7 @@ def collate(
input = {"tokens": torch.cat([request, response], dim=1)}
target = {
"response": response,
"behavior_logprobs": behavior_logprobs,
"ref_logprobs": ref_logprobs,
"advantages": advantages,
"padding_mask": mask,
@@ -119,17 +134,23 @@


# Note: This is also available in losses.grpo_loss via `SimpleGRPOLoss`
_grpo_logger = logging.getLogger(__name__)

def simple_grpo_loss(
logits: torch.Tensor,
response: torch.Tensor,
behavior_logprobs: torch.Tensor,
ref_logprobs: torch.Tensor,
advantages: torch.Tensor,
padding_mask: torch.Tensor,
beta: float = 0.1,
) -> torch.Tensor:
# Sanity check: behavior_logprobs should not be None
assert behavior_logprobs is not None, "behavior_logprobs is None!"

logprobs: torch.Tensor = compute_logprobs(logits, response)
kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1
per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages
per_token_policy_loss = torch.exp(logprobs - behavior_logprobs.detach()) * advantages
per_token_loss = -(per_token_policy_loss - beta * kl)
loss = (
((per_token_loss * padding_mask).sum(dim=1))
@@ -380,13 +401,22 @@ async def continuous_rollouts():
dtype=torch.long,
)
for i, response in enumerate(responses):
# Extract behavior policy logprobs from the completion
behavior_logprobs = response.logprobs
if behavior_logprobs is None:
raise ValueError(
"Behavior logprobs are not available in the completion. "
"Please enable logprobs in the sampling_params configuration."
)

episode = Episode(
episode_id=str(uuid.uuid4()),
pad_id=pad_id,
request_len=max_req_tokens,
response_len=max_res_tokens,
target=target,
completion=response,
behavior_logprobs=behavior_logprobs,
)
episode.reward = await reward_actor.evaluate_response.route(
prompt=prompt, response=response.text, target=target
1 change: 1 addition & 0 deletions apps/grpo/qwen3_1_7b.yaml
@@ -42,6 +42,7 @@ policy:
max_tokens: ${max_res_tokens}
temperature: 1.0
top_p: 1.0
logprobs: 1 # Enable logprobs for behavior policy

# Trainer configuration
trainer:
14 changes: 12 additions & 2 deletions src/forge/losses/grpo_loss.py
@@ -12,15 +12,25 @@ class SimpleGRPOLoss(nn.Module):
"""Simplified GRPO Loss for simplified single step updates
Inspired by the Hugging Face TRL implementation:
https://github.com/huggingface/trl/blob/417915a3e4d3e3bc8d7b196594308b8eabf928be/trl/trainer/grpo_trainer.py#L1624.

Args:
beta: The KL divergence coefficient for the loss.

Forward args:
logprobs: Log probabilities from the current policy being trained
behavior_logprobs: Log probabilities from the policy that generated the responses (behavior policy)
ref_logprobs: Log probabilities from the reference model (for KL regularization)
advantages: Computed advantages for each token
padding_mask: Mask indicating valid tokens (1) vs padding (0)
"""

def __init__(self, beta: float = 0.1):
super().__init__()
self.beta = beta

def forward(self, logprobs, ref_logprobs, advantages, padding_mask):
def forward(self, logprobs, behavior_logprobs, ref_logprobs, advantages, padding_mask):
kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1
per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages
per_token_policy_loss = torch.exp(logprobs - behavior_logprobs.detach()) * advantages
per_token_loss = -(per_token_policy_loss - self.beta * kl)
loss = (
((per_token_loss * padding_mask).sum(dim=1))