Skip to content

Commit bc12d81

Browse files
committed
create unitxt files on the fly
Signed-off-by: Roni Friedman-Melamed <Roni.friedman-melamed@il.ibm.com>
1 parent eb2b20b commit bc12d81

File tree

3 files changed

+77
-27
lines changed

3 files changed

+77
-27
lines changed

src/instructlab/eval/mmlu.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ def run(self, server_url: str | None = None) -> tuple:
153153

154154
return overall_score, individual_scores
155155

156-
def _run_mmlu(self, server_url: str | None = None) -> dict:
156+
def _run_mmlu(self, server_url: str | None = None, return_all_results:bool = False) -> dict:
157157
if server_url is not None:
158158
# Requires lm_eval >= 0.4.4
159159
model_args = f"base_url={server_url}/completions,model={self.model_path},tokenizer_backend=huggingface"
@@ -177,7 +177,10 @@ def _run_mmlu(self, server_url: str | None = None) -> dict:
177177
device=self.device,
178178
task_manager=tm,
179179
)
180-
results = mmlu_output["results"]
180+
if return_all_results:
181+
results = mmlu_output
182+
else:
183+
results = mmlu_output["results"]
181184
return results
182185

183186
# This method converts general errors from simple_evaluate

src/instructlab/eval/unitxt.py

Lines changed: 68 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,12 @@
55
"""
66

77
# Standard
8-
import os
8+
import os, shutil
9+
import yaml
10+
from uuid import uuid4
11+
12+
# Third Party
13+
from lm_eval.tasks.unitxt import task
914

1015
# First Party
1116
from instructlab.eval.mmlu import MMLUBranchEvaluator
@@ -16,26 +21,43 @@
1621
logger = setup_logger(__name__)
1722

1823
class UnitxtEvaluator(MMLUBranchEvaluator):
24+
"""
25+
An evaluator class, running Unitxt evaluation
26+
27+
Attributes:
28+
model_path absolute path to or name of a huggingface model
29+
unitxt_recipe unitxt recipe (see unitxt.ai for more information)
30+
A Recipe holds a complete specification of a unitxt pipeline
31+
Example: card=cards.wnli,template=templates.classification.multi_class.relation.default,max_train_instances=5,loader_limit=20,num_demos=3,demos_pool_size=10
32+
33+
"""
1934
name = "unitxt"
2035
def __init__(
2136
self,
22-
model_path,
23-
tasks_dir: str,
24-
tasks: list[str],
25-
# unitxt_recipe: str,
37+
model_path,
38+
unitxt_recipe: str,
2639
):
27-
# tasks,tasks_dir = self.prepare_files(unitxt_recipe)
40+
tasks,tasks_dir = self.prepare_unitxt_files(unitxt_recipe)
2841
super().__init__(
2942
model_path = model_path,
3043
tasks_dir = tasks_dir,
3144
tasks = tasks,
3245
few_shots = 0
3346
)
3447

35-
def prepare_files(self, unitxt_recipe)->tuple:
36-
tasks = ''
37-
tasks_dir = ''
38-
return tasks,tasks_dir
48+
def prepare_unitxt_files(self, unitxt_recipe) -> tuple:
    """
    Write the temporary lm-eval task files for a unitxt recipe.

    A random UUID serves both as the task name and as the suffix of the
    tasks directory (``unitxt_temp_<uuid>``). The directory receives the
    unitxt pointer file and a ``<uuid>.yaml`` task file holding the recipe.

    Attributes:
        unitxt_recipe   unitxt recipe string written into the task yaml

    Returns:
        task_name       the generated task name (a plain string)
        tasks_dir       the directory holding the generated files
    """
    task_id = str(uuid4())
    tasks_dir = f'unitxt_temp_{task_id}'
    recipe_yaml = os.path.join(tasks_dir, f"{task_id}.yaml")

    # The pointer goes first: writing it also creates the directory
    # that the yaml file is then written into.
    create_unitxt_pointer(tasks_dir)
    create_unitxt_yaml(
        yaml_file=recipe_yaml,
        unitxt_recipe=unitxt_recipe,
        task_name=task_id,
    )
    return task_id, tasks_dir
55+
56+
def remove_temp_files(self):
    """
    Delete the temporary unitxt tasks directory of this evaluator.

    The directory is removed only when its final path component starts
    with ``unitxt_temp_``, the prefix prepare_unitxt_files gives to the
    directories it creates. Any other directory is left untouched and a
    warning is logged instead.
    """
    # Compare the last path component so that absolute paths and paths
    # with a trailing separator are recognised as well.
    dir_name = os.path.basename(os.path.normpath(self.tasks_dir))
    if dir_name.startswith('unitxt_temp_'):  # to avoid unintended deletion if this class is inherited
        shutil.rmtree(self.tasks_dir)
    else:
        logger.warning("unitxt tasks dir did not start with 'unitxt_temp_' and therefore was not deleted")
3961

4062
def run(self,server_url: str | None = None) -> tuple:
4163
"""
@@ -47,19 +69,44 @@ def run(self,server_url: str | None = None) -> tuple:
4769
"""
4870
logger.debug(locals())
4971
os.environ["TOKENIZERS_PARALLELISM"] = "true"
50-
results = self._run_mmlu(server_url=server_url)
72+
results = self._run_mmlu(server_url=server_url, return_all_results=True)
5173
with open('my_tasks/output.txt', 'w') as f:
5274
print(results, file=f)
5375
taskname = self.tasks[0]
54-
global_scores = results[taskname]
76+
global_scores = results['results'][taskname]
5577
global_scores.pop('alias')
56-
instance_scores = None
57-
# instances = results['samples'][taskname]
58-
# instance_scores = {}
59-
# metrics = [metric.replace('metrics.','') for metric in instances[0]['doc']['metrics']]
60-
# for i,instance in enumerate(instances):
61-
# scores = {}
62-
# for metric in metrics:
63-
# scores[metric] = instance[metric][0]
64-
# instance_scores[i] = scores
78+
try:
79+
instances = results['samples'][taskname]
80+
instance_scores = {}
81+
metrics = [metric.replace('metrics.','') for metric in instances[0]['doc']['metrics']]
82+
for i,instance in enumerate(instances):
83+
scores = {}
84+
for metric in metrics:
85+
scores[metric] = instance[metric][0]
86+
instance_scores[i] = scores
87+
except Exception as e:
88+
logger.error("Error in extracting single instance scores")
89+
logger.error(e)
90+
logger.error(e.__traceback__)
91+
instance_scores = None
92+
self.remove_temp_files()
6593
return global_scores,instance_scores
94+
95+
96+
def create_unitxt_yaml(yaml_file, unitxt_recipe, task_name):
    """
    Write an lm-eval task yaml file describing a unitxt recipe.

    The file holds three keys: ``task`` (the task name), ``include``
    (always ``unitxt``) and ``recipe`` (the recipe string).

    Attributes:
        yaml_file       path of the yaml file to write
        unitxt_recipe   unitxt recipe stored under the ``recipe`` key
        task_name       name stored under the ``task`` key
    """
    data = {
        'task': f'{task_name}',
        'include': 'unitxt',
        'recipe': f'{unitxt_recipe}',
    }
    with open(yaml_file, 'w') as file:
        yaml.dump(data, file, default_flow_style=False)
    # Log the task name itself; `task` is the imported lm_eval module.
    logger.info(f"task {task_name} unitxt recipe written to {yaml_file}")
105+
106+
def create_unitxt_pointer(tasks_dir):
    """
    Write the ``unitxt`` pointer file into the given tasks directory.

    The file holds a single ``class: !function <path>`` line, where the
    path is the location of the imported lm_eval unitxt ``task`` module
    with ``task.py`` replaced by ``task.Unitxt``. The directory is
    created when it does not exist yet.

    Attributes:
        tasks_dir       directory in which the pointer file is written
    """
    unitxt_class_ref = task.__file__.replace("task.py", "task.Unitxt")
    pointer_path = os.path.join(tasks_dir, 'unitxt')
    parent_dir = os.path.dirname(pointer_path)
    os.makedirs(parent_dir, exist_ok=True)
    with open(pointer_path, 'w') as pointer_file:
        pointer_file.write("class: !function " + unitxt_class_ref)
    logger.info(f"Unitxt task pointer written to {pointer_path}")

tests/test_unitxt.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
# First Party
2-
from instructlab.eval.unitxt import UnitxtEvaluator
2+
from instruclab.eval.unitxt import UnitxtEvaluator
33

44

55
def test_unitxt():
66
print("===> Executing 'test_unitxt'...")
77
try:
88
model_path = "instructlab/granite-7b-lab"
9-
tasks = ["my_task"]
9+
unitxt_recipe = "card=cards.wnli,template=templates.classification.multi_class.relation.default,max_train_instances=5,loader_limit=20,num_demos=3,demos_pool_size=10"
1010
unitxt = UnitxtEvaluator(
11-
model_path=model_path, tasks_dir='./my_tasks/', tasks=tasks
11+
model_path=model_path, unitxt_recipe=unitxt_recipe
1212
)
13-
overall_score, _ = unitxt.run()
13+
overall_score, single_scores = unitxt.run()
1414
print(overall_score)
1515
except Exception as exc:
1616
print(f"'test_unitxt_branch' failed: {exc}")

0 commit comments

Comments
 (0)