evaluate.py
import argparse
import json

from tqdm import tqdm
from rouge import Rouge

from metrics import metric_max_over_ground_truths, exact_match_score, match, rouge, f1

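# Note: metrics.py is not included here. Based on the call sites below, its assumed
# interface is roughly the following (illustrative sketch, not the actual implementation):
#
#     def metric_max_over_ground_truths(score_fn, prediction, ground_truths):
#         # Score the prediction against every ground truth and keep the best score.
#         return max(score_fn(prediction, gt) for gt in ground_truths)
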
def load_file(file_path):
    """
    Load data from a JSON or JSONL file.

    Args:
        file_path (str): Path to the file to load.

    Returns:
        list: List of dictionaries loaded from the file.

    Raises:
        ValueError: If the file format is not supported.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            if file_path.endswith('.json'):
                data = json.load(f)
                if isinstance(data, dict):
                    return [data]
                elif isinstance(data, list):
                    return data
                else:
                    raise ValueError("Unsupported JSON structure. Expecting list or dict.")
            elif file_path.endswith('.jsonl'):
                data = [json.loads(line.strip()) for line in f if line.strip()]
                return data
            else:
                raise ValueError(f"Unsupported file format: {file_path}")
    except Exception as e:
        print(f"Error loading file {file_path}: {e}")
        raise
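
# Example accepted inputs (file names and fields are illustrative):
#     results.json  -> [{"question": ..., "response": [...], "answers": [...]}, ...]
#     results.jsonl -> one such JSON object per line
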
def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--results_file", type=str, required=True,
                        help="File containing results with both generated responses and ground truth answers")
    parser.add_argument("--metric", type=str, required=True,
                        choices=["em", "accuracy", "match", "rouge", "f1"],
                        help="Metric to use for evaluation")
    return parser.parse_args()
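
# Example invocation (file name is illustrative):
#     python evaluate.py --results_file results.jsonl --metric em
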
def main():
    args = get_args()
    results = load_file(args.results_file)
    scores = []
    rouge_scorer = Rouge()
    for item in tqdm(results):
        response = item['response'][0]
        question = item['question']
        # ASQA items store gold answers as short_answers inside qa_pairs;
        # other datasets provide a flat list under 'answers'.
        if 'asqa' in args.results_file:
            answers = []
            for qa_pair in item['qa_pairs']:
                answers.extend(qa_pair['short_answers'])
        else:
            answers = item['answers']
        if not answers:
            print(f"Warning: No answers provided for ID {item.get('id', 'unknown')}")
            continue
        if args.metric == "em":
            metric_result = metric_max_over_ground_truths(
                exact_match_score, response, answers
            )
        elif args.metric == "accuracy":
            # Compare the first character of the cleaned response (e.g. a multiple-choice
            # letter) against the first character of the first gold answer.
            response = response.replace('\n', '').strip()
            metric_result = 1.0 if response and response[0] == answers[0][0] else 0.0
        elif args.metric == "match":
            metric_result = match(question, response, answers)
        elif args.metric == "rouge":
            metric_result = rouge(rouge_scorer, response, answers)
        elif args.metric == "f1":
            metric_result = f1(response, answers)
        else:
            raise NotImplementedError(f"Metric {args.metric} is not implemented.")
        scores.append(metric_result)
    if scores:
        print(f'Overall result: {sum(scores) / len(scores)}')
    else:
        print("No scores were calculated. Please check your input file.")

if __name__ == "__main__":
    main()
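
# Expected input format (inferred from the fields accessed above): each record in the
# results file should contain at least 'response' (a list of generated strings) and
# 'question', plus either 'answers' or, for ASQA-style files, 'qa_pairs' with 'short_answers'.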