
Commit dcb0ab5

Create custom task (#621)
* init
* commit
* Apply suggestions from code review
* commit
* commit
* commit
1 parent 9c05a83 commit dcb0ab5

File tree

3 files changed: +338 -2 lines changed
@@ -0,0 +1,269 @@
# MIT License

# Copyright (c) 2024 The HuggingFace Team

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.


import logging
import re

import numpy as np
from aenum import extend_enum

from lighteval.metrics.metrics import Metrics
from lighteval.metrics.metrics_sample import JudgeLLM
from lighteval.metrics.utils.metric_utils import (
    CorpusLevelMetricGrouping,
    MetricCategory,
    MetricUseCase,
)
from lighteval.tasks.lighteval_task import LightevalTaskConfig
from lighteval.tasks.requests import Doc


logger = logging.getLogger(__name__)

JUDGE_ANSWER_SYSTEM_PROMPT = """You will be provided with the summary of a document, a piece of text, a question generated from that text, and the correct or "gold" answer to the question. Additionally, you will receive a model answer. Your task is to determine whether the model answer is correct using the provided "gold" answer as a reference.

# Steps

1. **Document Understanding**:
   - Analyze the provided document summary to grasp the context and main themes.

2. **Chunk Understanding**:
   - Examine the provided text (chunk) to understand its content.

3. **Question Understanding**:
   - Interpret the given question to fully comprehend what is being asked.

4. **Ground Truth Answer Understanding**:
   - Understand the provided ground truth answer, identifying its key points.

5. **Answer Understanding**:
   - Examine the Model Answer, identifying key points and assessing accuracy and factuality.

6. **Final Answer**:
   - 0 or 1 (0 if the model answer is incorrect, 1 if it is correct).

# Output Format

- Provide your final evaluation of whether the answer is correct within `<final_answer>` XML tags.
- Include a detailed analysis for each part within the designated XML tags: `<document_understanding>`, `<chunk_understanding>`, `<question_understanding>`, `<ground_truth_answer_understanding>`, `<model_answer_understanding>`, and `<final_answer>`.

# Examples

**Input**:
```xml
<document_summary>
[Summary]
</document_summary>

<piece_of_text>
[Text]
</piece_of_text>

<question>
[Question]
</question>

<gold_answer>
[Gold Answer]
</gold_answer>

<model_answer>
[Model Answer]
</model_answer>
```
**Output**:
```xml

<document_understanding>
Understanding of the summary including key themes
</document_understanding>

<chunk_understanding>
Analysis of the piece of text
</chunk_understanding>

<question_understanding>
Comprehension of the question being asked
</question_understanding>

<ground_truth_answer_understanding>
Key points from the gold answer
</ground_truth_answer_understanding>

<model_answer_understanding>
Key points and accuracy of the model answer
</model_answer_understanding>

<final_answer>
1 or 0 (1 if the model answer is correct, 0 if it is incorrect)
</final_answer>
```

# Notes

- Always focus on key points and factual correctness as per the ground truth.
- Avoid any biases and rely solely on the evidence presented.
- Enclose all evaluations and analyses in the specified XML tags for clarity and structure."""
JUDGE_ANSWER_USER_PROMPT = """<document_summary>
{summary}
</document_summary>

<piece_of_text>
{chunk}
</piece_of_text>

<question>
{question}
</question>

<gold_answer>
{oracle_answer}
</gold_answer>

<model_answer>
{model_answer}
</model_answer>"""


def get_judge_prompt(question: str, answer: str, gold: str, **kwargs):
    chunk = kwargs.get("chunks", "")
    summary = kwargs.get("documents", "")

    return [
        {"role": "system", "content": JUDGE_ANSWER_SYSTEM_PROMPT},
        {
            "role": "user",
            "content": JUDGE_ANSWER_USER_PROMPT.format(
                summary=summary, chunk=chunk, question=question, oracle_answer=gold, model_answer=answer
            ),
        },
    ]


def process_judge_response_yourbench(response):
    # Extract the final 0/1 verdict from the <final_answer> XML tag of the judge response.
    try:
        answer = re.search(r"<final_answer>(.*?)</final_answer>", response, re.DOTALL).group(1)
        return int(answer)
    except Exception as e:
        logger.error(f"Error processing judge response: {e}")
        return 0


class JudgeLLMYourBench(JudgeLLM):
    def __init__(self):
        super().__init__(
            judge_model_name="gpt-4o-2024-08-06",
            template=get_judge_prompt,
            process_judge_response=process_judge_response_yourbench,
            judge_backend="openai",
            short_judge_name="yourbench_judge",
        )

    def compute(self, sample_ids: list[str], responses: list, formatted_docs: list[Doc]) -> list[dict[str, float]]:
        # The formatted docs are expected to carry the question, chunks and document in their `specific` field.
        questions = [formatted_doc.specific["question"] for formatted_doc in formatted_docs]
        golds = [formatted_doc.get_golds()[0] for formatted_doc in formatted_docs]
        predictions = [response[0].result[0] for response in responses]
        options = [None] * len(questions)
        chunks = [formatted_doc.specific["chunks"][0] for formatted_doc in formatted_docs]
        documents = [formatted_doc.specific["document"] for formatted_doc in formatted_docs]

        score, _, _ = self.judge.evaluate_answer_batch(
            questions, predictions, options, golds, chunks=chunks, documents=documents
        )

        metrics = []
        for i in range(len(sample_ids)):
            metrics.append(
                {
                    "accuracy": score[i],
                }
            )

        return metrics
ZEROSHOT_QA_USER_PROMPT = """Answer the following question:

<question>
{question}
</question>

Enclose your full answer in <answer> XML tags. For example:

<answer>
[your answer here]
</answer>"""


def yourbench_prompt(line, task_name: str = ""):
    return Doc(
        task_name=task_name,
        query=ZEROSHOT_QA_USER_PROMPT.format(question=line["question"]),
        choices=[line["ground_truth_answer"]],
        gold_index=0,
        specific={
            "question_category": line["question_category"],
            "kind": line["kind"],
            "estimated_difficulty": line["estimated_difficulty"],
            "document_id": line["document_id"],
            "question_generating_model": line["question_generating_model"],
            "chunks": line["chunks"],
            "question": line["question"],
            "document": line["document"],
        },
    )


yourbench_metrics = CorpusLevelMetricGrouping(
    metric_name=["accuracy"],
    higher_is_better={"accuracy": True},
    category=MetricCategory.LLM_AS_JUDGE,
    use_case=MetricUseCase.ACCURACY,
    sample_level_fn=JudgeLLMYourBench().compute,
    corpus_level_fn={"accuracy": np.mean},
)
extend_enum(Metrics, "yourbench_metrics", yourbench_metrics)

yourbench = LightevalTaskConfig(
    name=HF_TASK_NAME,  # noqa: F821
    suite=["custom"],
    prompt_function=yourbench_prompt,
    hf_repo=HF_DATASET_NAME,  # noqa: F821
    hf_subset="lighteval_single_shot_questions",
    hf_avail_splits=["train"],
    evaluation_splits=["train"],
    few_shots_split=None,
    few_shots_select=None,
    generation_size=8192,
    metric=[Metrics.yourbench_metrics],
    stop_sequence=[],
    trust_dataset=True,
    version=0,
)


TASKS_TABLE = [yourbench]
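
For reference, here is a minimal sanity-check sketch for the two judge helpers above. It is not part of the commit: the question, answers, and judge reply are invented, and it assumes the definitions from the template above are in scope.

# Hypothetical sanity check for get_judge_prompt and process_judge_response_yourbench.
messages = get_judge_prompt(
    question="In which year was the library first released?",
    answer="It was first released in 2024.",
    gold="2024",
    chunks="The library was first released in 2024.",
    documents="Summary of a short changelog document.",
)
assert messages[0]["role"] == "system"
assert "2024" in messages[1]["content"]

# A made-up judge reply; only the <final_answer> block matters for the score.
fake_judge_response = "<model_answer_understanding>...</model_answer_understanding>\n<final_answer>\n1\n</final_answer>"
assert process_judge_response_yourbench(fake_judge_response) == 1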

src/lighteval/main_tasks.py

+22
@@ -19,6 +19,7 @@
 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.
+import logging
 import os
 from typing import Optional

@@ -75,3 +76,24 @@ def list(custom_tasks: Annotated[Optional[str], Option(help="Path to a file with

     registry = Registry(cache_dir=CACHE_DIR, custom_tasks=custom_tasks)
     registry.print_all_tasks()
+
+
+@app.command()
+def create(template: str, task_name: str, dataset_name: str):
+    """
+    Create a new task
+    """
+    logger = logging.getLogger(__name__)
+
+    logger.info(f"Creating task for dataset {dataset_name}")
+
+    with open(template, "r") as f:
+        content = f.read()
+
+    content = content.replace("HF_TASK_NAME", task_name)
+    content = content.replace("HF_DATASET_NAME", dataset_name)
+
+    with open(f"custom_{task_name}_task.py", "w+") as f:
+        f.write(content)
+
+    logger.info(f"Task created in custom_{task_name}_task.py")

src/lighteval/metrics/llm_as_judge.py

+47-2
@@ -142,19 +142,64 @@ def __lazy_load_client(self):
             case _:
                 return lambda x: x

+    def dict_of_lists_to_list_of_dicts(self, dict_of_lists):
+        """
+        Transform a dictionary of lists into a list of dictionaries.
+
+        Each dictionary in the output list will contain one element from each list in the input dictionary,
+        with the same keys as the input dictionary.
+
+        Args:
+            dict_of_lists: A dictionary where each value is a list.
+                All lists are expected to have the same length.
+
+        Returns:
+            A list of dictionaries.
+
+        Example:
+            >>> dict_of_lists_to_list_of_dicts({'k': [1, 2, 3], 'k2': ['a', 'b', 'c']})
+            [{'k': 1, 'k2': 'a'}, {'k': 2, 'k2': 'b'}, {'k': 3, 'k2': 'c'}]
+        """
+        # Check if input is empty
+        if not dict_of_lists:
+            return None
+
+        # Get all list lengths to ensure they match
+        list_lengths = [len(values) for values in dict_of_lists.values()]
+
+        # Ensure all lists have the same length
+        if len(set(list_lengths)) > 1:
+            raise ValueError("All lists in the input dictionary must have the same length")
+
+        # Get the length of the lists
+        n = list_lengths[0] if list_lengths else 0
+
+        # Create list of dictionaries
+        result = []
+        for i in range(n):
+            new_dict = {key: values[i] for key, values in dict_of_lists.items()}
+            result.append(new_dict)
+
+        return result
+
     def evaluate_answer_batch(
         self,
         questions: list[str],
         answers: list[str],
         options: list[list[str]] | list[None],
         golds: list[str] | list[None],
+        **kwargs,
     ):
         judge_function = self.__lazy_load_client()

+        kwargss = self.dict_of_lists_to_list_of_dicts(kwargs)
+        if kwargss is None:
+            kwargss = [{} for _ in range(len(questions))]
+
         # enumerate over questions, answers, options and golds to build the judge prompts
         prompts = [
-            self.template(question=q, answer=a, options=o, gold=g)
-            for q, a, o, g in zip(questions, answers, options, golds)
+            self.template(question=q, answer=a, options=o, gold=g, **k)
+            for q, a, o, g, k in zip(questions, answers, options, golds, kwargss)
         ]
         responses = judge_function(prompts)
         scores = [self.process_judge_response(response) for response in responses]
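
A standalone sketch of the per-sample expansion the new helper performs, re-implemented outside the class so it can run on its own; the sample `chunks` and `documents` values are made up.

# Standalone re-implementation of the transformation, with made-up sample data.
def dict_of_lists_to_list_of_dicts(dict_of_lists):
    if not dict_of_lists:
        return None
    lengths = {len(values) for values in dict_of_lists.values()}
    if len(lengths) > 1:
        raise ValueError("All lists in the input dictionary must have the same length")
    return [{key: values[i] for key, values in dict_of_lists.items()} for i in range(lengths.pop())]


# evaluate_answer_batch receives e.g. chunks=[...] and documents=[...] from
# JudgeLLMYourBench.compute and turns them into one kwargs dict per sample:
per_sample = dict_of_lists_to_list_of_dicts(
    {"chunks": ["chunk 1", "chunk 2"], "documents": ["summary 1", "summary 2"]}
)
print(per_sample)
# [{'chunks': 'chunk 1', 'documents': 'summary 1'},
#  {'chunks': 'chunk 2', 'documents': 'summary 2'}]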
