2 changes: 1 addition & 1 deletion environments/sv-env-netlogs-judge/pyproject.toml
@@ -15,7 +15,7 @@ tags = [
"train",
"eval",
]
version = "0.2.17"
version = "0.2.18"
requires-python = ">=3.12"
dependencies = [
"verifiers>=0.1.9",
129 changes: 128 additions & 1 deletion environments/sv-env-netlogs-judge/sv_netlogs_judge_impl.py
@@ -15,13 +15,16 @@

from __future__ import annotations

import json
import logging as _logging
import os
import sys
from pathlib import Path

import verifiers as vf
from datasets import Dataset
from openai import APIError, APITimeoutError, RateLimitError
from verifiers.utils.async_utils import maybe_await

REPO_ROOT = str(Path(__file__).resolve().parents[2])

@@ -39,6 +42,7 @@
        DatasetSource,
        JsonClassificationParser,
        RolloutLogger,
        get_response_text,
        load_dataset_with_fallback,
    )
except ImportError:
@@ -49,6 +53,7 @@
        DatasetSource,
        JsonClassificationParser,
        RolloutLogger,
        get_response_text,
        load_dataset_with_fallback,
    )

@@ -111,6 +116,128 @@ def __init__(self) -> None:
        super().__init__(allowed_labels=["Benign", "Malicious", "Abstain"])


def format_completion_for_judge(parser: NetworkLogParser, completion) -> str:
    """Format a completion for judge prompts with full structured context.

    JudgeRubric normally passes parser.parse_answer(completion) into the judge
    prompt, which yields only the label (for example "Benign"). That erases the
    confidence and format information the judge prompt explicitly asks about,
    causing the judge to reject otherwise-correct responses as invalid JSON.

    We instead pass a canonical JSON rendering of the parsed response so the
    judge can evaluate label, confidence, and structure together. If parsing
    fails, fall back to the raw response text so the judge still sees the
    model's output and failures stay debuggable.
    """
    data = parser._parse_json(completion)
    if data:
        canonical = {key: data[key] for key in ("label", "confidence", "rationale") if key in data}
        if canonical:
            return json.dumps(canonical, sort_keys=True)
    return get_response_text(completion)
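
# Example (hypothetical model output): a completion of
#   '{"label": "Benign", "confidence": 0.95, "rationale": "normal traffic"}'
# is rendered for the judge as the key-sorted canonical form
#   '{"confidence": 0.95, "label": "Benign", "rationale": "normal traffic"}'
# whereas unparseable output falls through to get_response_text unchanged.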


class _JudgePromptParser:
    """Thin parser wrapper used only when rendering the judge prompt."""

    def __init__(self, base_parser: NetworkLogParser) -> None:
        self.base_parser = base_parser

    def parse_answer(self, completion) -> str:
        return format_completion_for_judge(self.base_parser, completion)

    def __getattr__(self, name: str):
        return getattr(self.base_parser, name)


class StructuredResponseJudgeRubric(vf.JudgeRubric):
    """JudgeRubric that feeds structured JSON to the judge prompt.

    Upstream JudgeRubric calls parser.parse_answer(completion), which is too
    lossy for this environment because the judge prompt evaluates JSON validity
    and confidence as well as the label. We compute the structured response
    directly inside judge() so concurrent scoring cannot mutate shared rubric
    state.
    """

    def __init__(self, parser: NetworkLogParser, **kwargs) -> None:
        super().__init__(parser=parser, **kwargs)
        self._prompt_parser = _JudgePromptParser(parser)

    async def judge(self, prompt, completion, answer, state=None) -> str:
        if isinstance(prompt, list):
            last_msg = prompt[-1]
            if isinstance(last_msg, dict) and "content" in last_msg:
                question = str(last_msg["content"])
            else:
                question = ""
        else:
            question = str(prompt)

        response = self._prompt_parser.parse_answer(completion)
        judge_prompt = self.judge_prompt.format(question=question, answer=answer, response=response)
        cached = state.get("judge_response") if state is not None else None
        if isinstance(cached, dict) and judge_prompt in cached:
            return cached[judge_prompt]

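        # Callers may supply max_tokens, but newer OpenAI chat endpoints expect
        # max_completion_tokens; rename the limit and drop None-valued sampling
        # args before issuing the request.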
        judge_args = dict(self.judge_sampling_args or {})
        if "max_tokens" in judge_args:
            if judge_args["max_tokens"] is None:
                judge_args.pop("max_tokens")
            else:
                judge_args["max_completion_tokens"] = judge_args.pop("max_tokens")
        if "max_completion_tokens" in judge_args and judge_args["max_completion_tokens"] is None:
            judge_args.pop("max_completion_tokens")
        judge_args = {k: v for k, v in judge_args.items() if v is not None}

        try:
            judge_response = await maybe_await(
                self.judge_client.chat.completions.create,
                model=self.judge_model,
                messages=[{"role": "user", "content": judge_prompt}],
                **judge_args,
            )
            judge_response = str(judge_response.choices[0].message.content)
        except RateLimitError as e:
            self.logger.warning(
                f"Rate limit exceeded when calling judge model '{self.judge_model}'. "
                f"Try reducing concurrency or waiting before retrying. Error: {str(e)}"
            )
            raise RuntimeError(
                f"Judge model rate limit exceeded. Try reducing concurrency or waiting before retrying. "
                f"Model: {self.judge_model}, Error: {str(e)}"
            ) from e
        except APITimeoutError as e:
            self.logger.warning(
                f"Timeout when calling judge model '{self.judge_model}'. "
                f"Increase timeout in judge_sampling_args or check model responsiveness. Error: {str(e)}"
            )
            raise RuntimeError(
                f"Judge model timeout. Increase timeout in judge_sampling_args or check model responsiveness. "
                f"Model: {self.judge_model}, Error: {str(e)}"
            ) from e
        except APIError as e:
            self.logger.warning(
                f"API error when calling judge model '{self.judge_model}'. "
                f"Check model availability and API key. Error: {str(e)}"
            )
            raise RuntimeError(
                f"Judge model API error. Check model availability and API key. "
                f"Model: {self.judge_model}, Error: {str(e)}"
            ) from e
        except Exception as e:
            self.logger.warning(f"Unexpected error when calling judge model '{self.judge_model}'. Error: {str(e)}")
            raise RuntimeError(
                f"Unexpected error when calling judge model '{self.judge_model}'. Error: {str(e)}"
            ) from e

        if state is not None:
            if not isinstance(cached, dict):
                cached = {}
            cached[judge_prompt] = judge_response
            state["judge_response"] = cached
        return judge_response
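
# Usage sketch (hypothetical settings; the real wiring happens in
# load_environment below):
#
#   rubric = StructuredResponseJudgeRubric(
#       parser=NetworkLogParser(),
#       judge_model="gpt-4.1-mini",
#       judge_prompt=JUDGE_PROMPT,
#   )
#   verdict = await rubric.judge(prompt, completion, answer, state={})
#
# Identical judge prompts within one rollout are answered from
# state["judge_response"] rather than by re-querying the judge model.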


DEFAULT_ENV_NAME = "sv-env-network-logs-judge"


@@ -260,7 +387,7 @@ def _create_synthetic_dataset():

    # Use JudgeRubric instead of executable Rubric
    # The judge_reward function uses the judge callable injected by JudgeRubric
-    rubric = vf.JudgeRubric(
+    rubric = StructuredResponseJudgeRubric(
        parser=parser,
        judge_model=judge_model,
        judge_prompt=JUDGE_PROMPT,
30 changes: 30 additions & 0 deletions environments/sv-env-netlogs-judge/sv_netlogs_judge_test.py
@@ -157,6 +157,36 @@ def test_custom_judge_model(self) -> None:
        judge_rubric = self._get_judge_rubric(env)
        assert judge_rubric.judge_model == "gpt-4.1-mini"

    def test_judge_prompt_receives_structured_json_response(self) -> None:
        env = load_environment(dataset_name="synthetic", max_examples=1)
        judge_rubric = self._get_judge_rubric(env)

        captured = {}

        async def fake_create(**kwargs):
            captured["prompt"] = kwargs["messages"][0]["content"]
            return Mock(choices=[Mock(message=Mock(content="yes"))])

        judge_rubric.judge_client = Mock()
        judge_rubric.judge_client.chat = Mock()
        judge_rubric.judge_client.chat.completions = Mock()
        judge_rubric.judge_client.chat.completions.create = fake_create

        result = asyncio.run(
            judge_rubric.judge(
                prompt=[{"role": "user", "content": "sample log"}],
                completion='{"label": "Benign", "confidence": 0.95, "rationale": "normal traffic"}',
                answer="Benign",
                state={},
            )
        )

        assert result == "yes"
        prompt_text = captured["prompt"]
        assert '"label": "Benign"' in prompt_text
        assert '"confidence": 0.95' in prompt_text
        assert "Model's parsed response: Benign" not in prompt_text

    def test_custom_env_name_updates_logging_and_ids(self) -> None:
        logger = Mock()
        logger.enabled = True
2 changes: 1 addition & 1 deletion environments/sv-env-network-logs/pyproject.toml
@@ -14,7 +14,7 @@ tags = [
"train",
"eval",
]
version = "0.2.14"
version = "0.2.15"
requires-python = ">=3.12"
dependencies = [
"verifiers>=0.1.9",
130 changes: 129 additions & 1 deletion environments/sv-env-network-logs/sv_env_network_logs_judge.py
@@ -15,8 +15,12 @@

from __future__ import annotations

import json
from pathlib import Path

from openai import APIError, APITimeoutError, RateLimitError
from verifiers.utils.async_utils import maybe_await

try:
    # Try importing from installed package first
    from sv_shared import weave_init  # type: ignore # noqa: F401
@@ -38,6 +42,7 @@
        DatasetSource,
        JsonClassificationParser,
        RolloutLogger,
        get_response_text,
        load_dataset_with_fallback,
    )
except ImportError:
@@ -49,6 +54,7 @@
        DatasetSource,
        JsonClassificationParser,
        RolloutLogger,
        get_response_text,
        load_dataset_with_fallback,
    )

@@ -111,6 +117,128 @@ def __init__(self) -> None:
        super().__init__(allowed_labels=["Benign", "Malicious", "Abstain"])


def format_completion_for_judge(parser: NetworkLogParser, completion) -> str:
    """Format a completion for judge prompts with full structured context.

    JudgeRubric normally passes parser.parse_answer(completion) into the judge
    prompt, which yields only the label (for example "Benign"). That erases the
    confidence and format information the judge prompt explicitly asks about,
    causing the judge to reject otherwise-correct responses as invalid JSON.

    We instead pass a canonical JSON rendering of the parsed response so the
    judge can evaluate label, confidence, and structure together. If parsing
    fails, fall back to the raw response text so the judge still sees the
    model's output and failures stay debuggable.
    """
    data = parser._parse_json(completion)
    if data:
        canonical = {key: data[key] for key in ("label", "confidence", "rationale") if key in data}
        if canonical:
            return json.dumps(canonical, sort_keys=True)
    return get_response_text(completion)


class _JudgePromptParser:
    """Thin parser wrapper used only when rendering the judge prompt."""

    def __init__(self, base_parser: NetworkLogParser) -> None:
        self.base_parser = base_parser

    def parse_answer(self, completion) -> str:
        return format_completion_for_judge(self.base_parser, completion)

    def __getattr__(self, name: str):
        return getattr(self.base_parser, name)
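
# Attributes not defined on the wrapper (for example the parser's allowed
# labels or format helpers) are delegated to the wrapped NetworkLogParser via
# __getattr__, so JudgeRubric can treat _JudgePromptParser as a drop-in parser.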


class StructuredResponseJudgeRubric(vf.JudgeRubric):
    """JudgeRubric that feeds structured JSON to the judge prompt.

    Upstream JudgeRubric calls parser.parse_answer(completion), which is too
    lossy for this environment because the judge prompt evaluates JSON validity
    and confidence as well as the label. We compute the structured response
    directly inside judge() so concurrent scoring cannot mutate shared rubric
    state.
    """

    def __init__(self, parser: NetworkLogParser, **kwargs) -> None:
        super().__init__(parser=parser, **kwargs)
        self._prompt_parser = _JudgePromptParser(parser)

    async def judge(self, prompt, completion, answer, state=None) -> str:
        if isinstance(prompt, list):
            last_msg = prompt[-1]
            if isinstance(last_msg, dict) and "content" in last_msg:
                question = str(last_msg["content"])
            else:
                question = ""
        else:
            question = str(prompt)

        response = self._prompt_parser.parse_answer(completion)
        judge_prompt = self.judge_prompt.format(question=question, answer=answer, response=response)
        cached = state.get("judge_response") if state is not None else None
        if isinstance(cached, dict) and judge_prompt in cached:
            return cached[judge_prompt]

        judge_args = dict(self.judge_sampling_args or {})
        if "max_tokens" in judge_args:
            if judge_args["max_tokens"] is None:
                judge_args.pop("max_tokens")
            else:
                judge_args["max_completion_tokens"] = judge_args.pop("max_tokens")
        if "max_completion_tokens" in judge_args and judge_args["max_completion_tokens"] is None:
            judge_args.pop("max_completion_tokens")
        judge_args = {k: v for k, v in judge_args.items() if v is not None}

        try:
            judge_response = await maybe_await(
                self.judge_client.chat.completions.create,
                model=self.judge_model,
                messages=[{"role": "user", "content": judge_prompt}],
                **judge_args,
            )
            judge_response = str(judge_response.choices[0].message.content)
        except RateLimitError as e:
            self.logger.warning(
                f"Rate limit exceeded when calling judge model '{self.judge_model}'. "
                f"Try reducing concurrency or waiting before retrying. Error: {str(e)}"
            )
            raise RuntimeError(
                f"Judge model rate limit exceeded. Try reducing concurrency or waiting before retrying. "
                f"Model: {self.judge_model}, Error: {str(e)}"
            ) from e
        except APITimeoutError as e:
            self.logger.warning(
                f"Timeout when calling judge model '{self.judge_model}'. "
                f"Increase timeout in judge_sampling_args or check model responsiveness. Error: {str(e)}"
            )
            raise RuntimeError(
                f"Judge model timeout. Increase timeout in judge_sampling_args or check model responsiveness. "
                f"Model: {self.judge_model}, Error: {str(e)}"
            ) from e
        except APIError as e:
            self.logger.warning(
                f"API error when calling judge model '{self.judge_model}'. "
                f"Check model availability and API key. Error: {str(e)}"
            )
            raise RuntimeError(
                f"Judge model API error. Check model availability and API key. "
                f"Model: {self.judge_model}, Error: {str(e)}"
            ) from e
        except Exception as e:
            self.logger.warning(f"Unexpected error when calling judge model '{self.judge_model}'. Error: {str(e)}")
            raise RuntimeError(
                f"Unexpected error when calling judge model '{self.judge_model}'. Error: {str(e)}"
            ) from e

        if state is not None:
            if not isinstance(cached, dict):
                cached = {}
            cached[judge_prompt] = judge_response
            state["judge_response"] = cached
        return judge_response
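
# Cache shape sketch (hypothetical values): after one successful call the
# rollout state holds
#
#   state["judge_response"] == {"<formatted judge prompt>": "yes"}
#
# so a later judge() call that renders the identical prompt string returns the
# cached verdict without another model request.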


def load_environment(
    dataset_name: str = "iot23-train-dev-test-v1.jsonl",
    dataset_source: DatasetSource = "auto",
@@ -254,7 +382,7 @@ def _create_synthetic_dataset():

    # Use JudgeRubric instead of executable Rubric
    # The judge_reward function uses the judge callable injected by JudgeRubric
-    rubric = vf.JudgeRubric(
+    rubric = StructuredResponseJudgeRubric(
        parser=parser,
        judge_model=judge_model,
        judge_prompt=JUDGE_PROMPT,