@@ -39,23 +39,29 @@ def __init__(
3939 model_path ,
4040 unitxt_recipe : str ,
4141 ):
42- task ,tasks_dir = self .prepare_unitxt_files (unitxt_recipe )
42+ task = self .assign_task_name ()
43+ tasks_dir = self .assign_tasks_dir (task )
4344 super ().__init__ (
4445 model_path = model_path ,
4546 tasks_dir = tasks_dir ,
4647 tasks = [task ],
4748 few_shots = 0
4849 )
50+ self .unitxt_recipe = unitxt_recipe
4951
50- def prepare_unitxt_files (self , unitxt_recipe )-> tuple :
51- temp_task = str (uuid4 ())
52- temp_tasks_dir = f'{ TEMP_DIR_PREFIX } _{ temp_task } '
53- yaml_file = os .path .join (temp_tasks_dir ,f"{ temp_task } .yaml" )
54- create_unitxt_pointer (temp_tasks_dir )
55- create_unitxt_yaml (yaml_file = yaml_file , unitxt_recipe = unitxt_recipe , task_name = temp_task )
56- return temp_task ,temp_tasks_dir
52+ def assign_tasks_dir (self , task ):
53+ return f'{ TEMP_DIR_PREFIX } _{ task } '
5754
58- def remove_temp_files (self ):
55+ def assign_task_name (self ):
56+ return str (uuid4 ())
57+
58+ def prepare_unitxt_files (self )-> tuple :
59+ task = self .tasks [0 ]
60+ yaml_file = os .path .join (self .tasks_dir ,f"{ task } .yaml" )
61+ create_unitxt_pointer (self .tasks_dir )
62+ create_unitxt_yaml (yaml_file = yaml_file , unitxt_recipe = self .unitxt_recipe , task_name = task )
63+
64+ def remove_unitxt_files (self ):
5965 if self .tasks_dir .startswith (TEMP_DIR_PREFIX ): #to avoid unintended deletion if this class is inherited
6066 shutil .rmtree (self .tasks_dir )
6167 else :
@@ -69,6 +75,7 @@ def run(self,server_url: str | None = None) -> tuple:
6975 overall_scores Average scores for the task group
7076 individual_scores Individual scores for each task in the task group
7177 """
78+ self .prepare_unitxt_files ()
7279 logger .debug (locals ())
7380 os .environ ["TOKENIZERS_PARALLELISM" ] = "true"
7481 results = self ._run_mmlu (server_url = server_url , return_all_results = True )
@@ -89,7 +96,7 @@ def run(self,server_url: str | None = None) -> tuple:
8996 logger .error (e )
9097 logger .error (e .__traceback__ )
9198 instance_scores = None
92- self .remove_temp_files ()
99+ self .remove_unitxt_files ()
93100 return global_scores ,instance_scores
94101
95102
@@ -101,12 +108,12 @@ def create_unitxt_yaml(yaml_file,unitxt_recipe, task_name):
101108 }
102109 with open (yaml_file , 'w' ) as file :
103110 yaml .dump (data , file , default_flow_style = False )
104- logger .info (f"task { task } unitxt recipe written to { yaml_file } " )
111+ logger .debug (f"task { task } unitxt recipe written to { yaml_file } " )
105112
106113def create_unitxt_pointer (tasks_dir ):
107114 class_line = "class: !function " + task .__file__ .replace ("task.py" , "task.Unitxt" )
108115 output_file = os .path .join (tasks_dir ,'unitxt' )
109116 os .makedirs (os .path .dirname (output_file ), exist_ok = True )
110117 with open (output_file , 'w' ) as f :
111118 f .write (class_line )
112- logger .info (f"Unitxt task pointer written to { output_file } " )
119+ logger .debug (f"Unitxt task pointer written to { output_file } " )
0 commit comments