55"""
66
77# Standard
8- import os
8+ import os , shutil
9+ import yaml
10+ from uuid import uuid4
11+
12+ # Third Party
13+ from lm_eval .tasks .unitxt import task
914
1015# First Party
1116from instructlab .eval .mmlu import MMLUBranchEvaluator
1621logger = setup_logger (__name__ )
1722
1823class UnitxtEvaluator (MMLUBranchEvaluator ):
24+ """
25+ An evaluator class, running Unitxt evaluation
26+
27+ Attributes:
28+ model_path absolute path to or name of a huggingface model
29+ unitxt_recipe unitxt recipe (see unitxt.ai for more information)
30+ A Recipe holds a complete specification of a unitxt pipeline
31+ Example: card=cards.wnli,template=templates.classification.multi_class.relation.default,max_train_instances=5,loader_limit=20,num_demos=3,demos_pool_size=10
32+
33+ """
1934 name = "unitxt"
2035 def __init__ (
2136 self ,
22- model_path ,
23- tasks_dir : str ,
24- tasks : list [str ],
25- # unitxt_recipe: str,
37+ model_path ,
38+ unitxt_recipe : str ,
2639 ):
27- # tasks,tasks_dir = self.prepare_files (unitxt_recipe)
40+ tasks ,tasks_dir = self .prepare_unitxt_files (unitxt_recipe )
2841 super ().__init__ (
2942 model_path = model_path ,
3043 tasks_dir = tasks_dir ,
3144 tasks = tasks ,
3245 few_shots = 0
3346 )
3447
35- def prepare_files (self , unitxt_recipe )-> tuple :
36- tasks = ''
37- tasks_dir = ''
38- return tasks ,tasks_dir
48+ def prepare_unitxt_files (self , unitxt_recipe )-> tuple :
49+ temp_task = str (uuid4 ())
50+ temp_tasks_dir = f'unitxt_temp_{ temp_task } '
51+ yaml_file = os .path .join (temp_tasks_dir ,f"{ temp_task } .yaml" )
52+ create_unitxt_pointer (temp_tasks_dir )
53+ create_unitxt_yaml (yaml_file = yaml_file , unitxt_recipe = unitxt_recipe , task_name = temp_task )
54+ return temp_task ,temp_tasks_dir
55+
56+ def remove_temp_files (self ):
57+ if self .tasks_dir .startswith ('temp_' ): #to avoid unintended deletion if this class is inherited
58+ shutil .rmtree (self .tasks_dir )
59+ else :
60+ logger .warning ("unitxt tasks dir did not start with 'temp_' and therefor was not deleted" )
3961
4062 def run (self ,server_url : str | None = None ) -> tuple :
4163 """
@@ -47,19 +69,44 @@ def run(self,server_url: str | None = None) -> tuple:
4769 """
4870 logger .debug (locals ())
4971 os .environ ["TOKENIZERS_PARALLELISM" ] = "true"
50- results = self ._run_mmlu (server_url = server_url )
72+ results = self ._run_mmlu (server_url = server_url , return_all_results = True )
5173 with open ('my_tasks/output.txt' , 'w' ) as f :
5274 print (results , file = f )
5375 taskname = self .tasks [0 ]
54- global_scores = results [taskname ]
76+ global_scores = results ['results' ][ taskname ]
5577 global_scores .pop ('alias' )
56- instance_scores = None
57- # instances = results['samples'][taskname]
58- # instance_scores = {}
59- # metrics = [metric.replace('metrics.','') for metric in instances[0]['doc']['metrics']]
60- # for i,instance in enumerate(instances):
61- # scores = {}
62- # for metric in metrics:
63- # scores[metric] = instance[metric][0]
64- # instance_scores[i] = scores
78+ try :
79+ instances = results ['samples' ][taskname ]
80+ instance_scores = {}
81+ metrics = [metric .replace ('metrics.' ,'' ) for metric in instances [0 ]['doc' ]['metrics' ]]
82+ for i ,instance in enumerate (instances ):
83+ scores = {}
84+ for metric in metrics :
85+ scores [metric ] = instance [metric ][0 ]
86+ instance_scores [i ] = scores
87+ except Exception as e :
88+ logger .error ("Error in extracting single instance scores" )
89+ logger .error (e )
90+ logger .error (e .__traceback__ )
91+ instance_scores = None
92+ self .remove_temp_files ()
6593 return global_scores ,instance_scores
94+
95+
96+ def create_unitxt_yaml (yaml_file ,unitxt_recipe , task_name ):
97+ data = {
98+ 'task' : f'{ task_name } ' ,
99+ 'include' : 'unitxt' ,
100+ 'recipe' : f'{ unitxt_recipe } '
101+ }
102+ with open (yaml_file , 'w' ) as file :
103+ yaml .dump (data , file , default_flow_style = False )
104+ logger .info (f"task { task } unitxt recipe written to { yaml_file } " )
105+
106+ def create_unitxt_pointer (tasks_dir ):
107+ class_line = "class: !function " + task .__file__ .replace ("task.py" , "task.Unitxt" )
108+ output_file = os .path .join (tasks_dir ,'unitxt' )
109+ os .makedirs (os .path .dirname (output_file ), exist_ok = True )
110+ with open (output_file , 'w' ) as f :
111+ f .write (class_line )
112+ logger .info (f"Unitxt task pointer written to { output_file } " )
0 commit comments