From 9dc55f6c61be0667e7f5a055f767c6d19d33c90f Mon Sep 17 00:00:00 2001
From: AakashKumarNain
Date: Fri, 10 Oct 2025 21:44:25 +0530
Subject: [PATCH] add tracing with weave if available

---
 verifiers/trainers/grpo_trainer.py | 136 +++++++++++++++++++++++++++--
 1 file changed, 128 insertions(+), 8 deletions(-)

diff --git a/verifiers/trainers/grpo_trainer.py b/verifiers/trainers/grpo_trainer.py
index c8887e45e..2aafe1e5f 100644
--- a/verifiers/trainers/grpo_trainer.py
+++ b/verifiers/trainers/grpo_trainer.py
@@ -1,5 +1,15 @@
 # adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py
 
+import os
+
+# Default these weave settings to "false" before weave is imported further
+# down; setdefault() keeps any value that is already present in the environment.
+os.environ.setdefault("WEAVE_PRINT_CALL_LINK", "false")
+os.environ.setdefault("WEAVE_CAPTURE_CODE", "false")
+os.environ.setdefault("WEAVE_TRACE_LANGCHAIN", "false")
+os.environ.setdefault("WEAVE_DEBUG_HTTP", "false")
+os.environ.setdefault("WEAVE_IMPLICITLY_PATCH_INTEGRATIONS", "false")
+
 import logging
 import time
 from collections import defaultdict, deque
@@ -52,6 +62,81 @@
 from verifiers.utils.logging_utils import print_prompt_completions_sample
 
 
+try:
+    import weave  # type: ignore[unresolved-import]
+
+    WEAVE_AVAILABLE = True
+except Exception:
+    # weave is optional: any failure while importing it only disables
+    # rollout tracing instead of breaking the trainer import.
+    WEAVE_AVAILABLE = False
+
+# For tracing the rollouts using weave!
+if WEAVE_AVAILABLE:
+
+    def _fmt_chat_block(msg: object) -> str:
+        """Convert chat/message structures into a single readable block.
+
+        Args:
+            msg: A string, a list of chat messages, or any other object.
+
+        Returns:
+            ``msg`` itself if it is a string, one "role: content" line per
+            message if it is a list, and ``str(msg)`` otherwise.
+        """
+        if isinstance(msg, str):
+            return msg
+        if not isinstance(msg, list):
+            return str(msg)
+
+        lines = []
+        for m in msg:
+            if not isinstance(m, dict):
+                # Non-dict entries are rendered via ``str`` instead of
+                # raising AttributeError on ``.get``.
+                lines.append(str(m))
+                continue
+            role = m.get("role", "user")
+            content = m.get("content", "")
+            if isinstance(content, list):
+                # Multi-part content: use the "text" of dict parts and the
+                # string form of everything else.
+                content = " ".join(
+                    str(part.get("text", part)) if isinstance(part, dict) else str(part)
+                    for part in content
+                )
+            lines.append(f"{role}: {content}")
+        return "\n".join(lines)
+
+    @weave.op()
+    def trace_rollouts(step: int, mode: str, _inputs: dict) -> dict:
+        """Trace the rollouts during training and evaluation phases.
+
+        Args:
+            step: Current train/eval step. Unused in the body; presumably
+                recorded by weave as a call input (confirm in weave docs).
+            mode: 'train' or 'eval' phase. Unused in the body, like ``step``.
+            _inputs: Dictionary containing prompt, completion, and rewards.
+                It is left unmodified.
+
+        Returns:
+            New dictionary with the formatted "prompt" and "completion" first,
+            followed by the remaining entries of ``_inputs`` unchanged.
+
+        Raises:
+            KeyError: If "prompt" or "completion" is missing from ``_inputs``.
+        """
+        # Build a new dict rather than pop() from the caller's one.
+        extras = {
+            k: v for k, v in _inputs.items() if k not in ("prompt", "completion")
+        }
+        return {
+            "prompt": _fmt_chat_block(_inputs["prompt"]),
+            "completion": _fmt_chat_block(_inputs["completion"]),
+            **extras,
+        }
+
+
 class RepeatSampler(Sampler):
     """
     Sampler that repeats the indices of a dataset in a structured manner.
@@ -1455,6 +1540,21 @@ def evaluate(
             ):
                 import pandas as pd
 
+                if WEAVE_AVAILABLE:
+                    # One trace per rollout: pair every column (prompt,
+                    # completion and each reward) row by row.
+                    _columns = {
+                        "prompt": prompts,
+                        "completion": completions,
+                        **reward_dict,
+                    }
+                    for _row in zip(*_columns.values()):
+                        trace_rollouts(
+                            int(self.state.global_step),
+                            "eval",
+                            dict(zip(_columns, _row)),
+                        )
+
                 table_data = {
                     "step": [str(self.state.global_step)] * len(prompts),
                     "prompt": prompts,
@@ -1509,14 +1609,34 @@ def log(self, logs: dict[str, float], start_time: float | None = None) -> None:
             ):
                 import pandas as pd
 
+                _prompts = self._textual_logs["prompt"]
+                _completions = [
+                    self._sanitize_tool_calls(c)
+                    for c in self._textual_logs["completion"]
+                ]
+                _rewards = {
+                    k: list(v) for k, v in self._textual_logs["rewards"].items()
+                }
+
+                # If user has weave installed in their environment, then use
+                # weave to trace rollouts, otherwise just stick to the original
+                if WEAVE_AVAILABLE:
+                    _columns = {
+                        "prompt": _prompts,
+                        "completion": _completions,
+                        **_rewards,
+                    }
+                    for _row in zip(*_columns.values()):
+                        trace_rollouts(
+                            int(self.state.global_step),
+                            "train",
+                            dict(zip(_columns, _row)),
+                        )
+
                 table = {
-                    "step": [str(self.state.global_step)]
-                    * len(self._textual_logs["prompt"]),
-                    "prompt": list(self._textual_logs["prompt"]),
-                    "completion": [
-                        self._sanitize_tool_calls(c)
-                        for c in self._textual_logs["completion"]
-                    ],
-                    **{k: list(v) for k, v in self._textual_logs["rewards"].items()},
+                    "step": [str(self.state.global_step)] * len(_prompts),
+                    "prompt": list(_prompts),
+                    "completion": _completions,
+                    **_rewards,
                 }
                 if len(table["prompt"]) > 0: