🐛 Describe the bug
Summary
The GPTRewardModel.forward() method crashes with a division by zero error when processing batches with 0 or 1 samples.
Environment
- File: `examples/summarize_rlhf/reward_model/reward_model.py`
- Class: `GPTRewardModel`
- Method: `forward()`
Problem Description
When the input batch has 0 or 1 samples, the batch size calculation results in bs = 0:
bs = input_ids.shape[0] // 2 # If shape[0] = 1, then bs = 0This causes two critical crashes:
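Floor division maps both 0 and 1 to 0, which is easy to confirm in isolation (plain Python, nothing assumed beyond the expression above):

```python
# bs = shape[0] // 2 for the smallest batch sizes
for n in (0, 1, 2, 3):
    print(n, "->", n // 2)  # 0 -> 0, 1 -> 0, 2 -> 1, 3 -> 1
```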
This causes two critical crashes:

Crash #1: Division by Zero (Line 100)

```python
loss = loss / bs  # ZeroDivisionError when bs = 0
```

Crash #2: Empty Tensor Stack (Lines 103-104)

```python
chosen_end_scores = torch.stack(chosen_end_scores)  # RuntimeError: stack expects a non-empty list
rejected_end_scores = torch.stack(rejected_end_scores)
```

Reproduction Steps
```python
from reward_model import GPTRewardModel
import torch

model = GPTRewardModel("EleutherAI/gpt-j-6B")

# Test with a single sample (bs will be 0)
input_ids = torch.randint(0, 50000, (1, 512))
output = model(input_ids)  # 💥 CRASH
```

Error output:

```
ZeroDivisionError: division by zero
at line 100: loss = loss / bs
```
Root Cause
The model expects paired inputs (chosen + rejected) and splits them:

- Input shape: `(batch_size, seq_len)`
- After split: `bs = batch_size // 2`
- If `batch_size ∈ {0, 1}`: `bs = 0` → crash

The code doesn't handle the edge case where splitting results in zero samples per category (see the sketch below).
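For reference, a minimal sketch of the split as described above, assuming the usual first-half-chosen / second-half-rejected layout; the variable names are illustrative, not the file's own:

```python
bs = input_ids.shape[0] // 2
chosen_ids = input_ids[:bs]    # first bs rows: "chosen" responses
rejected_ids = input_ids[bs:]  # remaining rows: "rejected" responses

# With shape[0] = 1, bs = 0: chosen_ids is empty, the per-pair loop
# never runs, and the score lists stay empty, setting up both crashes.
```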
Proposed Solution
Add an early return guard after the batch size calculation:
```python
bs = input_ids.shape[0] // 2

# Handle empty batch edge case
if bs == 0:
    return {
        "loss": torch.tensor(0.0, device=input_ids.device),
        "chosen_end_scores": torch.tensor([], device=input_ids.device),
        "rejected_end_scores": torch.tensor([], device=input_ids.device),
    }
```

Benefits:
- ✅ Prevents division by zero
- ✅ Prevents empty tensor stack errors
- ✅ Returns consistent output format
- ✅ Maintains proper device placement
- ✅ Backward compatible (no breaking changes)
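With the guard applied, the reproduction script above should return cleanly instead of raising. A quick sanity check, reusing the model and shapes from the reproduction steps:

```python
output = model(torch.randint(0, 50000, (1, 512)))

print(output["loss"])                         # tensor(0.)
print(output["chosen_end_scores"].numel())    # 0
print(output["rejected_end_scores"].numel())  # 0
```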
Impact
Severity: Medium-High

- Crashes training/inference on edge cases
- May occur during:
  - Dataset iteration with uneven batches (see the sketch below)
  - Distributed training with small batch sizes
  - Testing/debugging with single samples

Affected Users: Anyone using GPTRewardModel with small or variable batch sizes
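To make the first failure mode concrete, here is one way an uneven trailing batch arises; the dataset and batch sizes are invented for illustration:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# 9 rows with batch_size=4 leaves a trailing batch of 1 row;
# inside GPTRewardModel.forward() that becomes bs = 1 // 2 = 0
dataset = TensorDataset(torch.randint(0, 50000, (9, 512)))
loader = DataLoader(dataset, batch_size=4, drop_last=False)

for (batch,) in loader:
    print(batch.shape[0])  # 4, 4, 1  <- the final batch triggers the crash
```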
Additional Context
The fix is straightforward and adds only 10 lines of defensive code. I have the implementation ready and tested if you'd like me to submit a PR.
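If a PR is wanted, it could ship with a regression test along these lines; this is a sketch assuming a pytest setup, with a hypothetical test name and the model name taken from the reproduction steps above:

```python
import pytest
import torch

from reward_model import GPTRewardModel


@pytest.mark.parametrize("n_rows", [0, 1])
def test_forward_survives_tiny_batches(n_rows):
    # Loading the full GPT-J checkpoint makes this an integration-level test
    model = GPTRewardModel("EleutherAI/gpt-j-6B")
    input_ids = torch.randint(0, 50000, (n_rows, 512))

    output = model(input_ids)  # must not raise once the guard is in place

    assert output["loss"].item() == 0.0
    assert output["chosen_end_scores"].numel() == 0
    assert output["rejected_end_scores"].numel() == 0
```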
Affected code location: `examples/summarize_rlhf/reward_model/reward_model.py`, lines 51-57
Checklist
- [x] Bug identified and root cause analyzed
- [x] Reproduction steps provided
- [x] Solution proposed with code snippet
- [x] Impact assessment completed
Which trlX version are you using?
latest
Additional system and package information
Python=3.12.3, transformers=4.57.1, NVIDIA GB10 (driver 580.95.05), CUDA=12.1