diff --git a/environments/sv-env-netlogs-judge/pyproject.toml b/environments/sv-env-netlogs-judge/pyproject.toml
index a23fa76..947a424 100644
--- a/environments/sv-env-netlogs-judge/pyproject.toml
+++ b/environments/sv-env-netlogs-judge/pyproject.toml
@@ -15,7 +15,7 @@ tags = [
     "train",
     "eval",
 ]
-version = "0.2.17"
+version = "0.2.18"
 requires-python = ">=3.12"
 dependencies = [
     "verifiers>=0.1.9",
diff --git a/environments/sv-env-netlogs-judge/sv_netlogs_judge_impl.py b/environments/sv-env-netlogs-judge/sv_netlogs_judge_impl.py
index 5a7f40b..1445fc6 100644
--- a/environments/sv-env-netlogs-judge/sv_netlogs_judge_impl.py
+++ b/environments/sv-env-netlogs-judge/sv_netlogs_judge_impl.py
@@ -15,6 +15,7 @@
 from __future__ import annotations
 
+import json
 import logging as _logging
 import os
 import sys
@@ -22,6 +23,8 @@
 import verifiers as vf
 from datasets import Dataset
+from openai import APIError, APITimeoutError, RateLimitError
+from verifiers.utils.async_utils import maybe_await
 
 REPO_ROOT = str(Path(__file__).resolve().parents[2])
@@ -39,6 +42,7 @@
         DatasetSource,
         JsonClassificationParser,
         RolloutLogger,
+        get_response_text,
         load_dataset_with_fallback,
     )
 except ImportError:
@@ -49,6 +53,7 @@
         DatasetSource,
         JsonClassificationParser,
         RolloutLogger,
+        get_response_text,
         load_dataset_with_fallback,
     )
 
@@ -111,6 +116,128 @@ def __init__(self) -> None:
         super().__init__(allowed_labels=["Benign", "Malicious", "Abstain"])
 
+
+def format_completion_for_judge(parser: NetworkLogParser, completion) -> str:
+    """Format a completion for judge prompts with full structured context.
+
+    JudgeRubric normally passes parser.parse_answer(completion) into the judge
+    prompt, which only yields the label (for example "Benign"). That erased the
+    confidence/format information the judge prompt explicitly asks about, causing
+    the judge to reject otherwise-correct responses as not valid JSON.
+
+    We instead pass a canonical JSON rendering of the parsed response so the
+    judge can evaluate label, confidence, and structure together. If parsing
+    fails, fall back to the raw response text for debugging resilience.
+    """
+    data = parser._parse_json(completion)
+    if data:
+        canonical = {key: data[key] for key in ("label", "confidence", "rationale") if key in data}
+        if canonical:
+            return json.dumps(canonical, sort_keys=True)
+    return get_response_text(completion)
+
+
+class _JudgePromptParser:
+    """Thin parser wrapper used only when rendering the judge prompt."""
+
+    def __init__(self, base_parser: NetworkLogParser) -> None:
+        self.base_parser = base_parser
+
+    def parse_answer(self, completion) -> str:
+        return format_completion_for_judge(self.base_parser, completion)
+
+    def __getattr__(self, name: str):
+        return getattr(self.base_parser, name)
+
+
+class StructuredResponseJudgeRubric(vf.JudgeRubric):
+    """JudgeRubric that feeds structured JSON to the judge prompt.
+
+    Upstream JudgeRubric calls parser.parse_answer(completion), which is too
+    lossy for this environment because the judge prompt evaluates JSON validity
+    and confidence as well as the label. We compute the structured response
+    directly inside judge() so concurrent scoring cannot mutate shared rubric
+    state.
+    """
+
+    def __init__(self, parser: NetworkLogParser, **kwargs) -> None:
+        super().__init__(parser=parser, **kwargs)
+        self._prompt_parser = _JudgePromptParser(parser)
+
+    async def judge(self, prompt, completion, answer, state=None) -> str:
+        if isinstance(prompt, list):
+            last_msg = prompt[-1]
+            if isinstance(last_msg, dict) and "content" in last_msg:
+                question = str(last_msg["content"])
+            else:
+                question = ""
+        else:
+            question = str(prompt)
+
+        response = self._prompt_parser.parse_answer(completion)
+        judge_prompt = self.judge_prompt.format(question=question, answer=answer, response=response)
+        cached = state.get("judge_response") if state else None
+        if isinstance(cached, dict) and judge_prompt in cached:
+            return cached[judge_prompt]
+
+        judge_args = dict(self.judge_sampling_args or {})
+        if "max_tokens" in judge_args:
+            if judge_args["max_tokens"] is None:
+                judge_args.pop("max_tokens")
+            else:
+                judge_args["max_completion_tokens"] = judge_args.pop("max_tokens")
+        if "max_completion_tokens" in judge_args and judge_args["max_completion_tokens"] is None:
+            judge_args.pop("max_completion_tokens")
+        judge_args = {k: v for k, v in judge_args.items() if v is not None}
+
+        try:
+            judge_response = await maybe_await(
+                self.judge_client.chat.completions.create,
+                model=self.judge_model,
+                messages=[{"role": "user", "content": judge_prompt}],
+                **judge_args,
+            )
+            judge_response = str(judge_response.choices[0].message.content)
+        except RateLimitError as e:
+            self.logger.warning(
+                f"Rate limit exceeded when calling judge model '{self.judge_model}'. "
+                f"Try reducing concurrency or waiting before retrying. Error: {str(e)}"
+            )
+            raise RuntimeError(
+                f"Judge model rate limit exceeded. Try reducing concurrency or waiting before retrying. "
+                f"Model: {self.judge_model}, Error: {str(e)}"
+            ) from e
+        except APITimeoutError as e:
+            self.logger.warning(
+                f"Timeout when calling judge model '{self.judge_model}'. "
+                f"Increase timeout in judge_sampling_args or check model responsiveness. Error: {str(e)}"
+            )
+            raise RuntimeError(
+                f"Judge model timeout. Increase timeout in judge_sampling_args or check model responsiveness. "
+                f"Model: {self.judge_model}, Error: {str(e)}"
+            ) from e
+        except APIError as e:
+            self.logger.warning(
+                f"API error when calling judge model '{self.judge_model}'. "
+                f"Check model availability and API key. Error: {str(e)}"
+            )
+            raise RuntimeError(
+                f"Judge model API error. Check model availability and API key. "
+                f"Model: {self.judge_model}, Error: {str(e)}"
+            ) from e
+        except Exception as e:
+            self.logger.warning(f"Unexpected error when calling judge model '{self.judge_model}'. Error: {str(e)}")
+            raise RuntimeError(
+                f"Unexpected error when calling judge model '{self.judge_model}'. Error: {str(e)}"
+            ) from e
+
+        if state:
+            if not isinstance(cached, dict):
+                cached = {}
+            cached[judge_prompt] = judge_response
+            state["judge_response"] = cached
+        return judge_response
+
+
 DEFAULT_ENV_NAME = "sv-env-network-logs-judge"
@@ -260,7 +387,7 @@ def _create_synthetic_dataset():
 
     # Use JudgeRubric instead of executable Rubric
     # The judge_reward function uses the judge callable injected by JudgeRubric
-    rubric = vf.JudgeRubric(
+    rubric = StructuredResponseJudgeRubric(
         parser=parser,
         judge_model=judge_model,
         judge_prompt=JUDGE_PROMPT,
diff --git a/environments/sv-env-netlogs-judge/sv_netlogs_judge_test.py b/environments/sv-env-netlogs-judge/sv_netlogs_judge_test.py
index a7795ee..fb3de72 100644
--- a/environments/sv-env-netlogs-judge/sv_netlogs_judge_test.py
+++ b/environments/sv-env-netlogs-judge/sv_netlogs_judge_test.py
@@ -157,6 +157,36 @@ def test_custom_judge_model(self) -> None:
         judge_rubric = self._get_judge_rubric(env)
         assert judge_rubric.judge_model == "gpt-4.1-mini"
 
+    def test_judge_prompt_receives_structured_json_response(self) -> None:
+        env = load_environment(dataset_name="synthetic", max_examples=1)
+        judge_rubric = self._get_judge_rubric(env)
+
+        captured = {}
+
+        async def fake_create(**kwargs):
+            captured["prompt"] = kwargs["messages"][0]["content"]
+            return Mock(choices=[Mock(message=Mock(content="yes"))])
+
+        judge_rubric.judge_client = Mock()
+        judge_rubric.judge_client.chat = Mock()
+        judge_rubric.judge_client.chat.completions = Mock()
+        judge_rubric.judge_client.chat.completions.create = fake_create
+
+        result = asyncio.run(
+            judge_rubric.judge(
+                prompt=[{"role": "user", "content": "sample log"}],
+                completion='{"label": "Benign", "confidence": 0.95, "rationale": "normal traffic"}',
+                answer="Benign",
+                state={},
+            )
+        )
+
+        assert result == "yes"
+        prompt_text = captured["prompt"]
+        assert '"label": "Benign"' in prompt_text
+        assert '"confidence": 0.95' in prompt_text
+        assert "Model's parsed response: Benign" not in prompt_text
+
     def test_custom_env_name_updates_logging_and_ids(self) -> None:
         logger = Mock()
         logger.enabled = True
diff --git a/environments/sv-env-network-logs/pyproject.toml b/environments/sv-env-network-logs/pyproject.toml
index c47ba1b..b1e362e 100644
--- a/environments/sv-env-network-logs/pyproject.toml
+++ b/environments/sv-env-network-logs/pyproject.toml
@@ -14,7 +14,7 @@ tags = [
     "train",
     "eval",
 ]
-version = "0.2.14"
+version = "0.2.15"
 requires-python = ">=3.12"
 dependencies = [
     "verifiers>=0.1.9",
diff --git a/environments/sv-env-network-logs/sv_env_network_logs_judge.py b/environments/sv-env-network-logs/sv_env_network_logs_judge.py
index 03ed745..29f5249 100644
--- a/environments/sv-env-network-logs/sv_env_network_logs_judge.py
+++ b/environments/sv-env-network-logs/sv_env_network_logs_judge.py
@@ -15,8 +15,12 @@
 from __future__ import annotations
 
+import json
 from pathlib import Path
 
+from openai import APIError, APITimeoutError, RateLimitError
+from verifiers.utils.async_utils import maybe_await
+
 try:  # Try importing from installed package first
     from sv_shared import weave_init  # type: ignore  # noqa: F401
@@ -38,6 +42,7 @@
         DatasetSource,
         JsonClassificationParser,
         RolloutLogger,
+        get_response_text,
         load_dataset_with_fallback,
     )
 except ImportError:
@@ -49,6 +54,7 @@
         DatasetSource,
         JsonClassificationParser,
         RolloutLogger,
+        get_response_text,
         load_dataset_with_fallback,
     )
 
@@ -111,6 +117,128 @@ def __init__(self) -> None:
         super().__init__(allowed_labels=["Benign", "Malicious", "Abstain"])
 
+
+def format_completion_for_judge(parser: NetworkLogParser, completion) -> str:
+    """Format a completion for judge prompts with full structured context.
+
+    JudgeRubric normally passes parser.parse_answer(completion) into the judge
+    prompt, which only yields the label (for example "Benign"). That erased the
+    confidence/format information the judge prompt explicitly asks about, causing
+    the judge to reject otherwise-correct responses as not valid JSON.
+
+    We instead pass a canonical JSON rendering of the parsed response so the
+    judge can evaluate label, confidence, and structure together. If parsing
+    fails, fall back to the raw response text for debugging resilience.
+    """
+    data = parser._parse_json(completion)
+    if data:
+        canonical = {key: data[key] for key in ("label", "confidence", "rationale") if key in data}
+        if canonical:
+            return json.dumps(canonical, sort_keys=True)
+    return get_response_text(completion)
+
+
+class _JudgePromptParser:
+    """Thin parser wrapper used only when rendering the judge prompt."""
+
+    def __init__(self, base_parser: NetworkLogParser) -> None:
+        self.base_parser = base_parser
+
+    def parse_answer(self, completion) -> str:
+        return format_completion_for_judge(self.base_parser, completion)
+
+    def __getattr__(self, name: str):
+        return getattr(self.base_parser, name)
+
+
+class StructuredResponseJudgeRubric(vf.JudgeRubric):
+    """JudgeRubric that feeds structured JSON to the judge prompt.
+
+    Upstream JudgeRubric calls parser.parse_answer(completion), which is too
+    lossy for this environment because the judge prompt evaluates JSON validity
+    and confidence as well as the label. We compute the structured response
+    directly inside judge() so concurrent scoring cannot mutate shared rubric
+    state.
+    """
+
+    def __init__(self, parser: NetworkLogParser, **kwargs) -> None:
+        super().__init__(parser=parser, **kwargs)
+        self._prompt_parser = _JudgePromptParser(parser)
+
+    async def judge(self, prompt, completion, answer, state=None) -> str:
+        if isinstance(prompt, list):
+            last_msg = prompt[-1]
+            if isinstance(last_msg, dict) and "content" in last_msg:
+                question = str(last_msg["content"])
+            else:
+                question = ""
+        else:
+            question = str(prompt)
+
+        response = self._prompt_parser.parse_answer(completion)
+        judge_prompt = self.judge_prompt.format(question=question, answer=answer, response=response)
+        cached = state.get("judge_response") if state else None
+        if isinstance(cached, dict) and judge_prompt in cached:
+            return cached[judge_prompt]
+
+        judge_args = dict(self.judge_sampling_args or {})
+        if "max_tokens" in judge_args:
+            if judge_args["max_tokens"] is None:
+                judge_args.pop("max_tokens")
+            else:
+                judge_args["max_completion_tokens"] = judge_args.pop("max_tokens")
+        if "max_completion_tokens" in judge_args and judge_args["max_completion_tokens"] is None:
+            judge_args.pop("max_completion_tokens")
+        judge_args = {k: v for k, v in judge_args.items() if v is not None}
+
+        try:
+            judge_response = await maybe_await(
+                self.judge_client.chat.completions.create,
+                model=self.judge_model,
+                messages=[{"role": "user", "content": judge_prompt}],
+                **judge_args,
+            )
+            judge_response = str(judge_response.choices[0].message.content)
+        except RateLimitError as e:
+            self.logger.warning(
+                f"Rate limit exceeded when calling judge model '{self.judge_model}'. "
+                f"Try reducing concurrency or waiting before retrying. Error: {str(e)}"
+            )
+            raise RuntimeError(
+                f"Judge model rate limit exceeded. Try reducing concurrency or waiting before retrying. "
+                f"Model: {self.judge_model}, Error: {str(e)}"
+            ) from e
+        except APITimeoutError as e:
+            self.logger.warning(
+                f"Timeout when calling judge model '{self.judge_model}'. "
+                f"Increase timeout in judge_sampling_args or check model responsiveness. Error: {str(e)}"
+            )
+            raise RuntimeError(
+                f"Judge model timeout. Increase timeout in judge_sampling_args or check model responsiveness. "
+                f"Model: {self.judge_model}, Error: {str(e)}"
+            ) from e
+        except APIError as e:
+            self.logger.warning(
+                f"API error when calling judge model '{self.judge_model}'. "
+                f"Check model availability and API key. Error: {str(e)}"
+            )
+            raise RuntimeError(
+                f"Judge model API error. Check model availability and API key. "
+                f"Model: {self.judge_model}, Error: {str(e)}"
+            ) from e
+        except Exception as e:
+            self.logger.warning(f"Unexpected error when calling judge model '{self.judge_model}'. Error: {str(e)}")
+            raise RuntimeError(
+                f"Unexpected error when calling judge model '{self.judge_model}'. Error: {str(e)}"
+            ) from e
+
+        if state:
+            if not isinstance(cached, dict):
+                cached = {}
+            cached[judge_prompt] = judge_response
+            state["judge_response"] = cached
+        return judge_response
+
+
 def load_environment(
     dataset_name: str = "iot23-train-dev-test-v1.jsonl",
     dataset_source: DatasetSource = "auto",
@@ -254,7 +382,7 @@ def _create_synthetic_dataset():
 
     # Use JudgeRubric instead of executable Rubric
    # The judge_reward function uses the judge callable injected by JudgeRubric
-    rubric = vf.JudgeRubric(
+    rubric = StructuredResponseJudgeRubric(
         parser=parser,
         judge_model=judge_model,
         judge_prompt=JUDGE_PROMPT,
diff --git a/environments/sv-env-network-logs/sv_env_network_logs_judge_test.py b/environments/sv-env-network-logs/sv_env_network_logs_judge_test.py
index fede0d1..fb1bb4d 100644
--- a/environments/sv-env-network-logs/sv_env_network_logs_judge_test.py
+++ b/environments/sv-env-network-logs/sv_env_network_logs_judge_test.py
@@ -152,3 +152,33 @@ def test_custom_judge_model(self) -> None:
         env = load_environment(dataset_name="synthetic", max_examples=5, judge_model="gpt-4.1-mini")
         judge_rubric = self._get_judge_rubric(env)
         assert judge_rubric.judge_model == "gpt-4.1-mini"
+
+    def test_judge_prompt_receives_structured_json_response(self) -> None:
+        env = load_environment(dataset_name="synthetic", max_examples=1)
+        judge_rubric = self._get_judge_rubric(env)
+
+        captured = {}
+
+        async def fake_create(**kwargs):
+            captured["prompt"] = kwargs["messages"][0]["content"]
+            return Mock(choices=[Mock(message=Mock(content="yes"))])
+
+        judge_rubric.judge_client = Mock()
+        judge_rubric.judge_client.chat = Mock()
+        judge_rubric.judge_client.chat.completions = Mock()
+        judge_rubric.judge_client.chat.completions.create = fake_create
+
+        result = asyncio.run(
+            judge_rubric.judge(
+                prompt=[{"role": "user", "content": "sample log"}],
+                completion='{"label": "Benign", "confidence": 0.95, "rationale": "normal traffic"}',
+                answer="Benign",
+                state={},
+            )
+        )
+
+        assert result == "yes"
+        prompt_text = captured["prompt"]
+        assert '"label": "Benign"' in prompt_text
+        assert '"confidence": 0.95' in prompt_text
+        assert "Model's parsed response: Benign" not in prompt_text