-
Notifications
You must be signed in to change notification settings - Fork 124
Expand file tree
/
Copy pathutils.py
More file actions
293 lines (247 loc) · 10.4 KB
/
utils.py
File metadata and controls
293 lines (247 loc) · 10.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
#!/usr/bin/env python3
import os
import re
import json
import openai
import tiktoken
from dotenv import load_dotenv
from typing import List, Dict, Any, Tuple
from concurrent.futures import ThreadPoolExecutor, as_completed
# Load environment variables from .env file
load_dotenv()
def initialize_clients(api_provider):
    """Initialize separate OpenAI-compatible clients for generator, reflector, and curator.

    Args:
        api_provider: One of 'sambanova', 'together', 'openai', or 'commonstack'.

    Returns:
        Tuple of (generator_client, reflector_client, curator_client). All three
        share the same API key and base URL; three distinct objects are created
        so each role can later be reconfigured independently.

    Raises:
        ValueError: If api_provider is not a known provider, or the provider's
            API key environment variable is unset/empty.
    """
    # Provider name -> (base URL, env var holding the key, display name for errors).
    # Replaces the original four copy-pasted if/elif branches with one table.
    providers = {
        "sambanova": ("https://api.sambanova.ai/v1", "SAMBANOVA_API_KEY", "SambaNova"),
        "together": ("https://api.together.xyz/v1", "TOGETHER_API_KEY", "Together"),
        "openai": ("https://api.openai.com/v1", "OPENAI_API_KEY", "OpenAI"),
        "commonstack": ("https://api.commonstack.ai/v1", "COMMONSTACK_API_KEY", "Commonstack"),
    }
    if api_provider not in providers:
        raise ValueError(
            f"Invalid api_provider name: {api_provider}. Must be 'sambanova', 'together', 'openai', or 'commonstack'"
        )
    base_url, env_var, display_name = providers[api_provider]
    api_key = os.getenv(env_var, '')
    if not api_key:
        raise ValueError(f"{display_name} api key not found in environment variables")
    # One client per role; all point at the same endpoint with the same key.
    generator_client = openai.OpenAI(api_key=api_key, base_url=base_url)
    reflector_client = openai.OpenAI(api_key=api_key, base_url=base_url)
    curator_client = openai.OpenAI(api_key=api_key, base_url=base_url)
    print(f"Using {api_provider} API for all models")
    return generator_client, reflector_client, curator_client
def get_section_slug(section_name):
    """Convert section name to slug format (3-5 chars)"""
    # Known playbook sections map directly to short, fixed slugs.
    known_slugs = {
        "financial_strategies_and_insights": "fin",
        "formulas_and_calculations": "calc",
        "code_snippets_and_templates": "code",
        "common_mistakes_to_avoid": "err",
        "problem_solving_heuristics": "prob",
        "context_clues_and_indicators": "ctx",
        "others": "misc",
        "meta_strategies": "meta",
    }
    # Normalize to snake_case ("A & B" -> "a_and_b") before looking it up.
    normalized = section_name.lower().strip().replace(" ", "_").replace("&", "and")
    mapped = known_slugs.get(normalized)
    if mapped is not None:
        return mapped
    # Unknown section: derive a slug from the words themselves.
    parts = normalized.split("_")
    if len(parts) > 1:
        # Acronym from the first letters of up to five words.
        return "".join(word[0] for word in parts[:5])
    # Single word: truncate to four characters.
    return parts[0][:4]
def extract_boxed_content(text):
    """Return the content of the first \\boxed{...} in *text*, or None.

    Scans forward with a brace-depth counter so nested braces inside the
    box (e.g. \\boxed{\\frac{1}{2}}) are preserved intact. Returns None
    when no \\boxed{ marker exists or its braces never balance.
    """
    marker = re.search(r'\\boxed\{', text)
    if marker is None:
        return None
    open_pos = marker.end() - 1  # index of the '{' that opens the box
    depth = 0
    for idx in range(open_pos, len(text)):
        ch = text[idx]
        if ch == '{':
            depth += 1
        elif ch == '}':
            depth -= 1
            if depth == 0:
                # Slice between the outermost brace pair.
                return text[open_pos + 1:idx]
    return None
def extract_answer(response):
    """Extract the final answer string from a model response.

    Tries strategies in a fixed priority order and returns the first hit:
      1. Parse the whole response as JSON and read its "final_answer" key.
      2. The last Finish[...] occurrence.
      3. Regex over JSON-style text: "final_answer": "..." (double quotes),
         then '...' (single quotes), then an unquoted value up to , or }.
      4. "The final answer is: \\boxed{...}" (brace-aware extraction).
      5. "The final answer is ..." up to the next newline or period,
         stripping any surrounding \\boxed{} / $ formatting.
    Returns "No final answer found" when nothing matches.

    NOTE: if the response parses as a JSON dict that simply lacks
    "final_answer", step 1 returns the default immediately and the regex
    fallbacks are never tried.
    """
    try:
        # First try JSON parsing
        parsed = json.loads(response)
        # AttributeError is raised (and caught below) when the parsed value
        # is not a dict (e.g. a bare number or list) and has no .get().
        answer = str(parsed.get("final_answer", "No final answer found"))
        return answer
    except (json.JSONDecodeError, KeyError, AttributeError):
        # JSON parsing failed, use fallback logic
        matches = re.findall(r"Finish\[(.*?)\]", response)
        if matches:
            # Last occurrence wins throughout the fallbacks below too.
            answer = matches[-1]
            return answer
        # Try to get final answer from JSON style response with regex matching
        # Try double quotes first
        matches = re.findall(r'"final_answer"\s*:\s*"([^"]*)"', response)
        if matches:
            answer = matches[-1]
            return answer
        # Try single quotes
        matches = re.findall(r"'final_answer'\s*:\s*'([^']*)'", response)
        if matches:
            answer = matches[-1]
            return answer
        # Handle JSON format without quotes (for simple expressions)
        matches = re.findall(r'[\'"]final_answer[\'"]\s*:\s*([^,}]+)', response)
        if matches:
            answer = matches[-1].strip()
            # Clean up trailing characters
            answer = re.sub(r'[,}]*$', '', answer)
            return answer
        # Fallback for "The final answer is: X" pattern with boxed
        final_answer_pattern = r'[Tt]he final answer is:?\s*\$?\\boxed\{'
        match = re.search(final_answer_pattern, response)
        if match:
            # Extract boxed content starting from this match
            remaining_text = response[match.start():]
            boxed_content = extract_boxed_content(remaining_text)
            if boxed_content:
                return boxed_content
        # More general pattern for "final answer is X"
        matches = re.findall(r'[Tt]he final answer is:?\s*([^\n.]+)', response)
        if matches:
            answer = matches[-1].strip()
            # Clean up common formatting
            answer = re.sub(r'^\$?\\boxed\{([^}]+)\}\$?$', r'\1', answer)
            answer = answer.replace('$', '').strip()
            if answer:
                return answer
        return "No final answer found"
# Module-level tokenizer shared by count_tokens; cl100k_base is the
# encoding used by tiktoken for OpenAI chat models.
enc = tiktoken.get_encoding("cl100k_base")
def count_tokens(prompt: str) -> int:
    """Return the number of cl100k_base tokens in *prompt*."""
    return len(enc.encode(prompt))
def evaluate_single_test_sample(args_tuple, data_processor) -> Tuple[Dict, str]:
    """
    Evaluate a single test sample - task-agnostic implementation.
    Args:
        args_tuple: Tuple of (index, task_dict, generator, playbook, max_tokens, log_dir, use_json_mode)
        data_processor: DataProcessor instance with answer_is_correct method
    Returns:
        (result_dict, None) on success, or (None, error_message) when any
        exception was raised while generating or scoring.
    """
    # NOTE(review): max_tokens is unpacked here but never forwarded to
    # generator.generate — confirm whether the generator applies its own limit.
    idx, task, gen, pb, _max_tokens, logs, json_mode = args_tuple
    try:
        context = task["context"]
        question = task["question"]
        target = task["target"]
        response, _bullet_ids, _call_info = gen.generate(
            question=question,
            playbook=pb,
            context=context,
            reflection="(empty)",
            use_json_mode=json_mode,
            call_id=f"test_eval_{idx}",
            log_dir=logs
        )
        prediction = extract_answer(response)
        outcome = {
            "index": idx,
            "final_answer": prediction,
            "target": target,
            "is_correct": data_processor.answer_is_correct(prediction, target),
            "success": True
        }
        return outcome, None
    except Exception as e:
        # Report rather than raise so one bad sample doesn't kill the pool.
        return None, f"Error evaluating sample {idx}: {type(e).__name__}: {str(e)}"
def evaluate_test_set(data_processor, generator, playbook, test_samples,
                      max_tokens=4096, log_dir=None, max_workers=20,
                      use_json_mode=False) -> Tuple[Dict, Dict]:
    """
    Parallel evaluation of test set - task-agnostic implementation.
    Args:
        data_processor: DataProcessor instance with answer_is_correct and evaluate_accuracy methods
        generator: Generator instance
        playbook: Current playbook string
        test_samples: List of test samples
        max_tokens: Max tokens for generation
        log_dir: Directory for logs
        max_workers: Number of parallel workers
        use_json_mode: Whether to use JSON mode
    Returns:
        Tuple of (results_dict, error_logs_dict). results_dict always has
        "accuracy", "correct", "total", and "no_answer" keys; error_logs_dict
        is empty when no samples produced answers.
    """
    print(f"\n{'='*40}")
    print(f"EVALUATING TEST SET - {len(test_samples)} samples, {max_workers} workers")
    print(f"{'='*40}")
    args_list = [
        (i, sample, generator, playbook, max_tokens, log_dir, use_json_mode)
        for i, sample in enumerate(test_samples)
    ]
    # Running tallies; "errors" collects per-sample mismatch details.
    results = {
        "correct": 0, "total": 0, "no_answer": 0,
        "answers": [], "targets": [], "errors": []
    }
    # Use a wrapper to pass data_processor to the evaluation function
    def eval_wrapper(args_tuple):
        return evaluate_single_test_sample(args_tuple, data_processor)
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_to_args = {
            executor.submit(eval_wrapper, args): args
            for args in args_list
        }
        for i, future in enumerate(as_completed(future_to_args), 1):
            result, error = future.result()
            if error:
                # Sample-level failures are reported and skipped, not fatal.
                print(error)
                continue
            if result and result["success"]:
                results["correct"] += (1 if result["is_correct"] else 0)
                results["total"] += 1
                results["answers"].append(result["final_answer"])
                results["targets"].append(result["target"])
                if not result["is_correct"]:
                    results["errors"].append({
                        "index": result["index"],
                        "prediction": result["final_answer"],
                        "ground_truth": result["target"]
                    })
                if result["final_answer"] == "No final answer found":
                    results["no_answer"] += 1
            if i % 50 == 0:
                curr_acc = results["correct"] / results["total"] if results["total"] > 0 else 0
                print(f"Progress: {i}/{len(args_list)}, Accuracy: {curr_acc:.3f}")
    if results["answers"] and results["targets"]:
        accuracy = data_processor.evaluate_accuracy(results["answers"], results["targets"])
        final_results = {
            "accuracy": accuracy,
            "correct": results["correct"],
            "total": results["total"],
            "no_answer": results["no_answer"]
        }
        error_logs = {
            "accuracy": accuracy,
            "errors": results["errors"]
        }
        print(f"\n📊 Final Accuracy: {accuracy:.3f} ({results['correct']}/{results['total']})")
    else:
        # BUG FIX: the original reassigned `results` here but returned the
        # never-defined `final_results`, raising NameError on empty or fully
        # failed runs. Assign final_results (with the same keys as the
        # success branch) so the function returns cleanly.
        final_results = {
            "accuracy": 0.0,
            "correct": 0,
            "total": 0,
            "no_answer": results["no_answer"]
        }
        error_logs = {}
        print(f"\n📊 No valid results!")
    return final_results, error_logs