156 changes: 151 additions & 5 deletions ais_bench/benchmark/cli/workers.py
@@ -1,5 +1,7 @@
import os
import os.path as osp
import copy
import shutil
from abc import ABC, abstractmethod
from collections import defaultdict

@@ -8,12 +10,15 @@
from ais_bench.benchmark.registry import PARTITIONERS, RUNNERS, build_from_cfg
from ais_bench.benchmark.utils.config.run import get_config_type
from ais_bench.benchmark.utils.logging.logger import AISLogger
from ais_bench.benchmark.utils.logging.exceptions import PredictionInvalidException
from ais_bench.benchmark.utils.logging.error_codes import TMAN_CODES
from ais_bench.benchmark.partitioners import NaivePartitioner
from ais_bench.benchmark.runners import LocalRunner
from ais_bench.benchmark.tasks import OpenICLEvalTask, OpenICLApiInferTask, OpenICLInferTask
from ais_bench.benchmark.summarizers import DefaultSummarizer, DefaultPerfSummarizer
from ais_bench.benchmark.calculators import DefaultPerfMetricCalculator
from ais_bench.benchmark.cli.utils import fill_model_path_if_datasets_need
from ais_bench.benchmark.utils.file.file import load_jsonl, dump_jsonl

logger = AISLogger()

@@ -108,6 +113,117 @@ def _update_tasks_cfg(self, tasks, cfg: ConfigDict):
task.attack = cfg.attack


class JudgeInfer(BaseWorker):
def update_cfg(self, cfg: ConfigDict) -> None:
def get_task_type() -> str:
if cfg["models"][0]["attr"] == "service":
return get_config_type(OpenICLApiInferTask)
else:
return get_config_type(OpenICLInferTask)

new_cfg = dict(
judge_infer=dict(
partitioner=dict(type=get_config_type(NaivePartitioner)),
runner=dict(
max_num_workers=self.args.max_num_workers,
max_workers_per_gpu=self.args.max_workers_per_gpu,
debug=self.args.debug,
task=dict(type=get_task_type()),
type=get_config_type(LocalRunner),
),
),
)

cfg.merge_from_dict(new_cfg)
if cfg.cli_args.debug:
cfg.judge_infer.runner.debug = True
cfg.judge_infer.partitioner["out_dir"] = osp.join(cfg["work_dir"], "predictions/")
return cfg

def do_work(self, cfg: ConfigDict):
partitioner = PARTITIONERS.build(cfg.judge_infer.partitioner)
logger.info("Starting inference tasks...")
tasks = partitioner(cfg)

# delete the tasks without judge_infer_cfg
new_tasks = []
for task in tasks:
if task["datasets"][0][0].get("judge_infer_cfg"):
new_tasks.append(task)
tasks = new_tasks
if len(tasks) == 0:
return

# update tasks cfg before run
self._update_tasks_cfg(tasks, cfg)

if (
cfg.get("cli_args", {}).get("merge_ds", False)
or cfg.get("cli_args", {}).get("mode") == "perf" # performance mode will enable merge datasets by default
):
logger.info("Merging datasets with the same model and inferencer...")
tasks = self._merge_datasets(tasks)

runner = RUNNERS.build(cfg.judge_infer.runner)
runner(tasks)
self._result_post_process(tasks, cfg)
logger.info("Inference tasks completed.")

def _merge_datasets(self, tasks):
# merge datasets with the same model, dataset type and inferencer
task_groups = defaultdict(list)
for task in tasks:
key = (
task["models"][0]["abbr"] # same model
+ "_"
+ str(task['datasets'][0][0]['type']) # same dataset type
+ "_"
+ str(task["datasets"][0][0]["infer_cfg"]["inferencer"]) # same inferencer with the same args
)
Comment on lines +176 to +182


medium

The current method of generating a key for grouping tasks by concatenating strings with _ is not robust. If any of the components (such as task["models"][0]["abbr"]) contains an underscore, two different tasks can produce the same key and be grouped together incorrectly. Using a tuple as the key would be a safer and more reliable approach.

            key = (
                task["models"][0]["abbr"],  # same model
                str(task['datasets'][0][0]['type']),  # same dataset type
                str(task["datasets"][0][0]["infer_cfg"]["inferencer"]),  # same inferencer with the same args
            )
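
As a small illustration of the collision described above (the model abbr and dataset type names here are invented), string concatenation can map two distinct tasks onto the same key, while a tuple key keeps the components apart:

key_a = "model_a" + "_" + "gsm8k"    # abbr "model_a", dataset type "gsm8k"
key_b = "model" + "_" + "a_gsm8k"    # abbr "model",   dataset type "a_gsm8k"
assert key_a == key_b                # both become "model_a_gsm8k" -> wrongly merged

key_a = ("model_a", "gsm8k")
key_b = ("model", "a_gsm8k")
assert key_a != key_b                # the tuple preserves the component boundaries

Tuples of strings are hashable, so the defaultdict(list) in _merge_datasets accepts them as keys without any other change.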

task_groups[key].append(task)
new_tasks = []
for key, task_group in task_groups.items():
new_task = copy.deepcopy(task_group[0])
if len(task_group) > 1:
for t in task_group[1:]:
new_task["datasets"][0].extend(t["datasets"][0])
new_tasks.append(new_task)
return new_tasks

def _update_tasks_cfg(self, tasks, cfg: ConfigDict):
# update parameters to correct sub cfg
if hasattr(cfg, "attack"):
for task in tasks:
cfg.attack.dataset = task.datasets[0][0].abbr
task.attack = cfg.attack

# update judge cfgs to model cfgs and data
for task in tasks:
task["datasets"][0][0]["predictions_path"] = osp.join(cfg.judge_infer.partitioner.out_dir, task["models"][0]["abbr"], f'{task["datasets"][0][0]["abbr"]}.jsonl')
if not osp.exists(task["datasets"][0][0]["predictions_path"]):
raise PredictionInvalidException(TMAN_CODES.UNKNOWN_ERROR, f"Predictions path {task['datasets'][0][0]['predictions_path']} does not exist.")
task["datasets"][0][0]["abbr"] = f'{task["datasets"][0][0]["abbr"]}-{task["datasets"][0][0]["judge_infer_cfg"]["judge_model"]["abbr"]}'
model_abbr = task["models"][0]["abbr"]
task["models"][0] = task["datasets"][0][0]["judge_infer_cfg"].pop("judge_model")
task["models"][0]["abbr"] = model_abbr
task["datasets"][0][0]["type"] = task["datasets"][0][0]["judge_infer_cfg"].pop("judge_dataset_type")
task["datasets"][0][0]["reader_cfg"] = task["datasets"][0][0]["judge_infer_cfg"].pop("judge_reader_cfg")
task["datasets"][0][0]["infer_cfg"] = task["datasets"][0][0].pop("judge_infer_cfg")

def _result_post_process(self, tasks, cfg: ConfigDict):
# Reconstruct the judge infer predictions to normal predictions format
for task in tasks:
model_org_prediction_path = task["datasets"][0][0]["predictions_path"]
model_preds: dict = {item["uuid"]: item for item in load_jsonl(model_org_prediction_path)}
judge_org_prediction_path = osp.join(cfg.judge_infer.partitioner.out_dir, task["models"][0]["abbr"], f'{task["datasets"][0][0]["abbr"]}.jsonl')
judge_preds: list = load_jsonl(judge_org_prediction_path)
for i, pred in enumerate(judge_preds):
uuid = pred["gold"]
judge_preds[i]["id"] = model_preds[uuid]["id"]
os.remove(judge_org_prediction_path)
dump_jsonl(judge_preds, judge_org_prediction_path)
Comment on lines +223 to +224


high

The current implementation of updating the prediction file by removing it and then writing to the same path is unsafe. If the dump_jsonl operation fails for any reason (e.g., disk full, permission error), the original prediction file will be lost. A safer pattern is to write the new content to a temporary file and then atomically rename it to the final destination.

            temp_judge_org_prediction_path = judge_org_prediction_path + ".tmp"
            dump_jsonl(judge_preds, temp_judge_org_prediction_path)
            os.replace(temp_judge_org_prediction_path, judge_org_prediction_path)
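
A slightly fuller sketch of the same write-then-swap pattern, assuming dump_jsonl takes (records, path) as in the diff above; the helper name atomic_dump_jsonl is made up for illustration:

import os

def atomic_dump_jsonl(records, path, dump_fn):
    # Write to a sibling temp file first and only swap it into place once the
    # write has succeeded, so a failed dump cannot destroy the file at `path`.
    tmp_path = path + ".tmp"
    try:
        dump_fn(records, tmp_path)
        os.replace(tmp_path, path)  # atomic when both paths are on the same filesystem
    except Exception:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

With such a helper, _result_post_process could call atomic_dump_jsonl(judge_preds, judge_org_prediction_path, dump_jsonl) and drop the preceding os.remove entirely.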



class Eval(BaseWorker):
def update_cfg(self, cfg: ConfigDict) -> None:
new_cfg = dict(
Expand Down Expand Up @@ -138,7 +254,7 @@ def do_work(self, cfg: ConfigDict):
logger.info("Starting evaluation tasks...")
tasks = partitioner(cfg)

# update tasks cfg before run
# Update tasks cfg before run
self._update_tasks_cfg(tasks, cfg)

runner = RUNNERS.build(cfg.eval.runner)
@@ -148,11 +264,33 @@ def do_work(self, cfg: ConfigDict):
runner(task_part)
else:
runner(tasks)
self._result_post_process(tasks, cfg)
logger.info("Evaluation tasks completed.")

def _update_tasks_cfg(self, tasks, cfg: ConfigDict):
# update parameters to correct sub cfg
pass
# Replace default model config to judge model config
self.judge_result_paths = {}
for task in tasks:
if task["datasets"][0][0].get("judge_infer_cfg"):
new_dataset_abbr = f'{task["datasets"][0][0]["abbr"]}-{task["datasets"][0][0]["judge_infer_cfg"]["judge_model"]["abbr"]}'
org_dataset_abbr = task["datasets"][0][0]["abbr"]
self.judge_result_paths[new_dataset_abbr] = org_dataset_abbr
task["datasets"][0][0]["abbr"] = new_dataset_abbr
task["datasets"][0][0].pop("judge_infer_cfg")

def _result_post_process(self, tasks, cfg: ConfigDict):
# Copy judge infer result to normal name

for task in tasks:
if task["datasets"][0][0]["abbr"] in self.judge_result_paths.keys():
cur_results_path = osp.join(cfg.eval.partitioner.out_dir, task["models"][0]["abbr"], f'{task["datasets"][0][0]["abbr"]}.jsonl')
final_org_results_path = osp.join(cfg.eval.partitioner.out_dir, task["models"][0]["abbr"], f'{self.judge_result_paths[task["datasets"][0][0]["abbr"]]}.jsonl')
if os.path.exists(final_org_results_path):
os.remove(final_org_results_path)

if os.path.exists(cur_results_path):
# 基于cur_results_path的文件复制一份final_org_results_path


medium

This comment is in Chinese. To maintain consistency and readability for all contributors, please write comments in English.

                    # Copy the current results file to the final original results path

shutil.copy(cur_results_path, final_org_results_path)


class AccViz(BaseWorker):
@@ -171,6 +309,7 @@ def update_cfg(self, cfg: ConfigDict) -> None:
def do_work(self, cfg: ConfigDict) -> int:
logger.info("Summarizing evaluation results...")
summarizer_cfg = cfg.get("summarizer", {})
cfg = self._cfg_pre_process(cfg)

# For subjective summarizer
if summarizer_cfg.get("function", None):
@@ -203,6 +342,13 @@ def do_work(self, cfg: ConfigDict) -> int:
summarizer = build_from_cfg(summarizer_cfg)
summarizer.summarize(time_str=self.args.cfg_time_str)

def _cfg_pre_process(self, cfg: ConfigDict) -> None:
for i, dataset in enumerate(cfg.datasets):
if dataset.get("judge_infer_cfg"):
cfg.datasets[i]["abbr"] = f'{cfg.datasets[i]["abbr"]}-{cfg.datasets[i]["judge_infer_cfg"]["judge_model"]["abbr"]}'
cfg.datasets[i].pop("judge_infer_cfg")
return cfg


class PerfViz(BaseWorker):
def update_cfg(self, cfg: ConfigDict) -> None:
@@ -233,9 +379,9 @@ def do_work(self, cfg: ConfigDict) -> int:


WORK_FLOW = dict(
all=[Infer, Eval, AccViz],
all=[Infer, JudgeInfer, Eval, AccViz],
infer=[Infer],
eval=[Eval, AccViz],
eval=[JudgeInfer, Eval, AccViz],
viz=[AccViz],
perf=[Infer, PerfViz],
perf_viz=[PerfViz],
@@ -0,0 +1,118 @@
from ais_bench.benchmark.openicl.icl_prompt_template import PromptTemplate
from ais_bench.benchmark.openicl.icl_retriever import ZeroRetriever
from ais_bench.benchmark.openicl.icl_inferencer import GenInferencer
from ais_bench.benchmark.models import VLLMCustomAPIChat
from ais_bench.benchmark.utils.postprocess.model_postprocessors import extract_non_reasoning_content
from ais_bench.benchmark.datasets import (
Aime2025Dataset,
Aime2025JDGDataset,
)
from ais_bench.benchmark.datasets.utils.llm_judge import get_a_or_b, LLMJudgeCorrectEvaluator


aime2025_reader_cfg = dict(input_columns=["question"], output_column="answer")


aime2025_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(
round=[
dict(
role="HUMAN",
prompt="{question}\nRemember to put your final answer within \\boxed{}.",
),
],
),
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer),
)

GRADER_TEMPLATE = """
Please as a grading expert, judge whether the final answers given by the candidates below are consistent with the standard answers, that is, whether the candidates answered correctly.

Here are some evaluation criteria:
1. Please refer to the given standard answer. You don't need to re-generate the answer to the question because the standard answer has been given. You only need to judge whether the candidate's answer is consistent with the standard answer according to the form of the question. Don't try to answer the original question. You can assume that the standard answer is definitely correct.
2. Because the candidate's answer may be different from the standard answer in the form of expression, before making a judgment, please understand the question and the standard answer first, and then judge whether the candidate's answer is correct, but be careful not to try to answer the original question.
3. Some answers may contain multiple items, such as multiple-choice questions, multiple-select questions, fill-in-the-blank questions, etc. As long as the answer is the same as the standard answer, it is enough. For multiple-select questions and multiple-blank fill-in-the-blank questions, the candidate needs to answer all the corresponding options or blanks correctly to be considered correct.
4. Some answers may be expressed in different ways, such as some answers may be a mathematical expression, some answers may be a textual description, as long as the meaning expressed is the same. And some formulas are expressed in different ways, but they are equivalent and correct.
5. If the prediction is given with \\boxed{}, please ignore the \\boxed{} and only judge whether the candidate's answer is consistent with the standard answer.
6. If the candidate's answer is semantically incomplete at the end, please judge it as inconsistent.

Please judge whether the following answers are consistent with the standard answer based on the above criteria. Grade the predicted answer of this new question as one of:
A: Means the answer is consistent with the standard answer.
B: Means the answer is inconsistent with the standard answer.
Just return the letters "A" or "B", with no text around it.

Here is your task. Simply reply with either CORRECT, INCORRECT. Don't apologize or correct yourself if there was a mistake; we are just trying to grade the answer.


<Original Question Begin>: \n{question}\n<Original Question End>\n\n
<Gold Target Begin>: \n{answer}\n<Gold Target End>\n\n
<Predicted Answer Begin>: \n{model_answer}\n<Predicted End>\n\n


security-medium medium

This GRADER_TEMPLATE is susceptible to prompt injection due to the direct embedding of the untrusted {model_answer}. A malicious model could exploit this to manipulate evaluation results. It is critical to use clear delimiters and sanitize the model output to prevent such attacks. Furthermore, the template contains conflicting instructions, asking for "A" or "B" in multiple places but also "CORRECT, INCORRECT" on line 48. This inconsistency can confuse the language model and lead to unreliable evaluations. The instructions should be aligned to consistently expect "A" or "B" to match the post-processing logic.


Judging the correctness of candidates' answers, please return the the letters "A" or "B" first before your thinking:
""".strip()

aime2025_judge_infer_cfg = dict(
judge_reader_cfg = dict(input_columns=["question", "answer", "model_answer"], output_column="model_pred_uuid"),
judge_model=dict(
attr="service",
type=VLLMCustomAPIChat,
abbr="judge", # Be added after dataset abbr
path="",
model="",
stream=True,
request_rate=0,
use_timestamp=False,
retry=2,
api_key="",
host_ip="localhost",
host_port=8080,
url="",
max_out_len=512,
batch_size=1,
trust_remote_code=False,
generation_kwargs=dict(
temperature=0.01,
ignore_eos=False,
),
pred_postprocessor=dict(type=extract_non_reasoning_content),
),
judge_dataset_type=Aime2025JDGDataset,
prompt_template=dict(
type=PromptTemplate,
template=dict(
begin=[
dict(
role='SYSTEM',
fallback_role='HUMAN',
prompt="You are a helpful assistant who evaluates the correctness and quality of models' outputs.",
)
],
round=[
dict(role='HUMAN', prompt=GRADER_TEMPLATE),
],
),
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer),
)

aime2025_eval_cfg = dict(
evaluator=dict(type=LLMJudgeCorrectEvaluator),
pred_postprocessor=dict(type=get_a_or_b),
)

aime2025_datasets = [
dict(
abbr="aime2025",
type=Aime2025Dataset,
path="ais_bench/datasets/aime2025/aime2025.jsonl",
reader_cfg=aime2025_reader_cfg,
infer_cfg=aime2025_infer_cfg,
judge_infer_cfg=aime2025_judge_infer_cfg,
eval_cfg=aime2025_eval_cfg,
)
]
13 changes: 9 additions & 4 deletions ais_bench/benchmark/datasets/aime2025.py
@@ -1,16 +1,15 @@
import json

import json
from datasets import Dataset

from ais_bench.benchmark.registry import LOAD_DATASET
from ais_bench.benchmark.datasets.utils.datasets import get_data_path
from ais_bench.benchmark.datasets.utils.llm_judge import LLMJudgeDataset

from .base import BaseDataset

from ais_bench.benchmark.datasets.base import BaseDataset

@LOAD_DATASET.register_module()
class Aime2025Dataset(BaseDataset):

@staticmethod
def load(path, **kwargs):
path = get_data_path(path)
@@ -20,3 +19,9 @@ def load(path, **kwargs):
line = json.loads(line.strip())
dataset.append(line)
return Dataset.from_list(dataset)


@LOAD_DATASET.register_module()
class Aime2025JDGDataset(LLMJudgeDataset):
def _get_dataset_class(self):
return Aime2025Dataset