104 changes: 97 additions & 7 deletions verifiers/trainers/grpo_trainer.py
@@ -1,5 +1,14 @@
# adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py

import os
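# Default several Weave-related environment variables to "false" unless the
# user has already set them; this runs before any weave import so the
# defaults are in place when weave reads its configuration.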
os.environ["WEAVE_PRINT_CALL_LINK"] = os.environ.get("WEAVE_PRINT_CALL_LINK", "false")
os.environ["WEAVE_CAPTURE_CODE"] = os.environ.get("WEAVE_CAPTURE_CODE","false")
os.environ["WEAVE_TRACE_LANGCHAIN"] = os.environ.get("WEAVE_TRACE_LANGCHAIN","false")
os.environ["WEAVE_DEBUG_HTTP"] = os.environ.get("WEAVE_DEBUG_HTTP","false")
os.environ["WEAVE_IMPLICITLY_PATCH_INTEGRATIONS"] = os.environ.get(
"WEAVE_IMPLICITLY_PATCH_INTEGRATIONS","false"
)

import logging
import time
from collections import defaultdict, deque
@@ -52,6 +61,60 @@
from verifiers.utils.logging_utils import print_prompt_completions_sample


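# Weave is an optional dependency: rollout tracing below is only wired up
# when the import succeeds; otherwise the trainer falls back to its normal
# logging path.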
try:
import weave # type: ignore[unresolved-import]
WEAVE_AVAILABLE = True
except Exception:
WEAVE_AVAILABLE = False

# For tracing the rollouts using weave!
if WEAVE_AVAILABLE:
    def _fmt_chat_block(msg: str | list) -> str:
        """Convert chat/message structures into a single readable block.

        Args:
            msg: A raw string, or a list of chat messages in role/content
                format; any other type is converted with str().

        Returns:
            A single formatted string suitable for logging.
        """

if isinstance(msg, str):
return msg
if isinstance(msg, list):
lines = []
for m in msg:
role = m.get("role", "user")
content = m.get("content", "")
if isinstance(content, list):
content = " ".join(
(part.get("text", str(part)) if isinstance(part, dict) else str(part))
for part in content
)
lines.append(f"{role}: {content}")
return "\n".join(lines)
return str(msg)


    @weave.op()
    def trace_rollouts(step: int, mode: str, _inputs: dict) -> dict:
        """Trace rollouts during the training and evaluation phases.

        Args:
            step: Current train/eval step.
            mode: Phase name, either "train" or "eval".
            _inputs: Dictionary containing the prompt, completion, and rewards
                for a single rollout.

        Returns:
            Dictionary with the same values as _inputs, with the prompt and
            completion flattened into readable text for display.
        """

prompt = _fmt_chat_block(_inputs.pop("prompt"))
completion = _fmt_chat_block(_inputs.pop("completion"))
return {"prompt": prompt, "completion": completion, **_inputs}

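# Note: trace_rollouts only records to the Weave UI once the training script
# has initialized a project (e.g. weave.init("my-project")); without that,
# the decorated function simply executes as a plain call.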

class RepeatSampler(Sampler):
"""
Sampler that repeats the indices of a dataset in a structured manner.
@@ -1455,6 +1518,16 @@ def evaluate(
):
import pandas as pd

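# Record each evaluation rollout (prompt, completion, and per-function
# rewards) as its own Weave call, tagged with the current global step.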
if WEAVE_AVAILABLE:
_full_inputs = {
"prompt": prompts, "completion": completions, **reward_dict
}
for idx in range(len(completions)):
_inputs = {k: v[idx] for k,v in _full_inputs.items()}
_ = trace_rollouts(
int(self.state.global_step), "eval", _inputs
)

table_data = {
"step": [str(self.state.global_step)] * len(prompts),
"prompt": prompts,
@@ -1509,14 +1582,31 @@ def log(self, logs: dict[str, float], start_time: float | None = None) -> None:
):
import pandas as pd

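# Collect the logged prompts, sanitized completions, and per-reward-function
# values once so the same data can feed both the Weave rollout traces and
# the wandb table below.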
_prompts = self._textual_logs["prompt"]
_completions = [
self._sanitize_tool_calls(c)
for c in self._textual_logs["completion"]
]
_rewards = {
k: list(v) for k, v in self._textual_logs["rewards"].items()
}
_full_inputs = {
"prompt": _prompts, "completion": _completions, **_rewards
}

# If user has weave installed in their environment, then use
# weave to trace rollouts, otherwise just stick to the original
if WEAVE_AVAILABLE:
for idx in range(len(_completions)):
_inputs = {k: v[idx] for k,v in _full_inputs.items()}
_ = trace_rollouts(
int(self.state.global_step), "train", _inputs
)

table = {
"step": [str(self.state.global_step)]
* len(self._textual_logs["prompt"]),
"prompt": list(self._textual_logs["prompt"]),
"completion": [
self._sanitize_tool_calls(c)
for c in self._textual_logs["completion"]
],
"step": [str(self.state.global_step)] * len(_prompts),
"prompt": list(_prompts),
"completion": _completions,
**{k: list(v) for k, v in self._textual_logs["rewards"].items()},
}
if len(table["prompt"]) > 0: