 import pandas as pd
 import sys
-
+import traceback
 from datetime import datetime
 from chebai.loss.semantic import DisjointLoss
 from chebai.preprocessing.datasets.chebi import ChEBIOver100
@@ -59,15 +59,15 @@ def _sort_results_by_label(n_labels, results, filter):
 def get_best_epoch(run):
     files = run.files()
     best_ep = None
-    best_val_loss = 0
+    best_micro_f1 = 0
     for file in files:
         if file.name.startswith("checkpoints/best_epoch"):
-            val_loss = float(file.name.split("=")[2].split("_")[0])
-            if val_loss < best_val_loss or best_ep is None:
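+            # checkpoint names are assumed to end with "<metric>=<value>.ckpt":
+            # take the value after the last "=" and strip the ".ckpt" suffix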
+            micro_f1 = float(file.name.split("=")[-1][:-5])
+            if micro_f1 > best_micro_f1 or best_ep is None:
                 best_ep = int(file.name.split("=")[1].split("_")[0])
-                best_val_loss = val_loss
+                best_micro_f1 = micro_f1
     if best_ep is None:
-        raise Exception("Could not find any 'best' checkpoint")
+        raise Exception(f"Could not find any 'best' checkpoint for run {run.name}")
     else:
         print(f"Best epoch for run {run.name}: {best_ep}")
     return best_ep
@@ -88,7 +88,42 @@ def load_preds_labels_from_wandb(
8888 f"{ data_module .__class__ .__name__ } _{ kind } " ,
8989 )
9090
91- model = get_checkpoint_from_wandb (epoch , run )
91+ model = get_checkpoint_from_wandb (epoch , run , map_device_to = "cuda:0" )
92+ print (f"Calculating predictions..." )
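+    # run the model on the dataset and buffer predictions/labels under buffer_dir;
+    # skip_existing_preds reuses a buffer from an earlier call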
+    evaluate_model(
+        model,
+        data_module,
+        buffer_dir=buffer_dir,
+        filename=f"{kind}.pt",
+        skip_existing_preds=True,
+    )
+    preds, labels = load_results_from_buffer(buffer_dir, device=DEVICE)
+    del model
+    gc.collect()
+
+    return preds, labels
+
+
+def load_preds_labels_from_nonwandb(
+    name, epoch, chebi_version, test_on_data_cls=ChEBIOver100, kind="test"
+):
+    data_module = test_on_data_cls(chebi_version=chebi_version)
+
+    buffer_dir = os.path.join(
+        "results_buffer",
+        f"{name}_ep{epoch}",
+        f"{data_module.__class__.__name__}_{kind}",
+    )
+    ckpt_path = None
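+    # look for a locally downloaded checkpoint for this epoch under logs/downloaded_ckpts/<name>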
+    for file in os.listdir(os.path.join("logs", "downloaded_ckpts", name)):
+        if file.startswith(f"best_epoch={epoch}"):
+            ckpt_path = os.path.join(
+                os.path.join("logs", "downloaded_ckpts", name, file)
+            )
+    assert (
+        ckpt_path is not None
+    ), f"Could not find ckpt for epoch {epoch} in directory {os.path.join('logs', 'downloaded_ckpts', name)}"
+    model = Electra.load_from_checkpoint(ckpt_path, map_location="cuda:0", strict=False)
     print(f"Calculating predictions...")
     evaluate_model(
         model,
@@ -130,7 +165,6 @@ def analyse_run(
         (dl.implication_filter_l, dl.implication_filter_r, "impl"),
         (dl.disjoint_filter_l, dl.disjoint_filter_r, "disj"),
     ]:
-        print(f"Calculating on {filter_type} loss")
         # prepare predictions
         n_loss_terms = dl_filter_l.shape[0]
         preds_exp = preds.unsqueeze(2).expand((-1, -1, n_loss_terms)).swapaxes(1, 2)
@@ -218,34 +252,135 @@ def analyse_run(
     gc.collect()


-def run_all(run_ids, datasets=None, chebi_version=231):
+def run_all(
+    run_ids,
+    datasets=None,
+    chebi_version=231,
+    skip_analyse=False,
+    skip_preds=False,
+    nonwandb_runs=None,
+):
     # evaluate a list of runs on Hazardous and ChEBIOver100 datasets
     if datasets is None:
         datasets = [(Hazardous, "all"), (ChEBIOver100, "test")]
     timestamp = datetime.now().strftime("%y%m%d-%H%M")
     results_path = os.path.join(
         "_semloss_eval", f"semloss_results_pc-dis-200k_{timestamp}.csv"
     )
-
+    api = wandb.Api()
     for run_id in run_ids:
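+        # one failing run should not abort the evaluation of the remaining runs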
+        try:
+            run = api.run(f"chebai/chebai/{run_id}")
+            epoch = get_best_epoch(run)
+            for test_on, kind in datasets:
+                df = {
+                    "run-id": run_id,
+                    "epoch": int(epoch),
+                    "kind": kind,
+                    "data_module": test_on.__name__,
+                    "chebi_version": chebi_version,
+                }
+                if not skip_preds:
+                    preds, labels = load_preds_labels_from_wandb(
+                        run, epoch, chebi_version, test_on, kind
+                    )
+                else:
+                    buffer_dir = os.path.join(
+                        "results_buffer",
+                        f"{run.name}_ep{epoch}",
+                        f"{test_on.__name__}_{kind}",
+                    )
+                    preds, labels = load_results_from_buffer(buffer_dir, device=DEVICE)
+                if not skip_analyse:
+                    print(
+                        f"Calculating metrics for run {run.name} on {test_on.__name__} ({kind})"
+                    )
+                    analyse_run(
+                        preds,
+                        labels,
+                        df_hyperparams=df,
+                        chebi_version=chebi_version,
+                        results_path=results_path,
+                    )
+        except Exception as e:
+            print(f"Failed for run {run_id}: {e}")
+            print(traceback.format_exc())
+
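+    # runs not logged to wandb: load their checkpoints from logs/downloaded_ckpts instead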
+    if nonwandb_runs:
+        for run_name, epoch in nonwandb_runs:
+            try:
+                for test_on, kind in datasets:
+                    df = {
+                        "run-id": run_name,
+                        "epoch": int(epoch),
+                        "kind": kind,
+                        "data_module": test_on.__name__,
+                        "chebi_version": chebi_version,
+                    }
+                    if not skip_preds:
+                        preds, labels = load_preds_labels_from_nonwandb(
+                            run_name, epoch, chebi_version, test_on, kind
+                        )
+                    else:
+                        buffer_dir = os.path.join(
+                            "results_buffer",
+                            f"{run_name}_ep{epoch}",
+                            f"{test_on.__name__}_{kind}",
+                        )
+                        preds, labels = load_results_from_buffer(
+                            buffer_dir, device=DEVICE
+                        )
+                    if not skip_analyse:
+                        print(
+                            f"Calculating metrics for run {run_name} on {test_on.__name__} ({kind})"
+                        )
+                        analyse_run(
+                            preds,
+                            labels,
+                            df_hyperparams=df,
+                            chebi_version=chebi_version,
+                            results_path=results_path,
+                        )
+            except Exception as e:
+                print(f"Failed for run {run_name}: {e}")
+                print(traceback.format_exc())
+
+
+def run_semloss_eval(mode="eval"):
+    non_wandb_runs = (
+        []
+    )  # ("chebi100_semprodk2_weighted_v231_pc_200k_dis_24042-2000", 195)]
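+    # "preds": run inference for all wandb runs tagged "eval_semloss_paper" and buffer the results
+    # "eval": compute metrics from buffered predictions for a hand-picked list of run ids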
+    if mode == "preds":
         api = wandb.Api()
-        run = api.run(f"chebai/chebai/{run_id}")
-        epoch = get_best_epoch(run)
-        for test_on, kind in datasets:
-            df = {
-                "run-id": run_id,
-                "epoch": int(epoch),
-                "kind": kind,
-                "data_module": test_on.__class__.__name__,
-                "chebi_version": chebi_version,
-            }
-            preds, labels = load_preds_labels_from_wandb(
-                run, epoch, chebi_version, test_on, kind
-            )
-            analyse_run(
-                preds,
-                labels,
-                df_hyperparams=df,
-                chebi_version=chebi_version,
-                results_path=results_path,
-            )
+        runs = api.runs("chebai/chebai", filters={"tags": "eval_semloss_paper"})
+        print(f"Found {len(runs)} tagged wandb runs")
+        ids = [run.id for run in runs]
+        run_all(ids, skip_analyse=True, nonwandb_runs=non_wandb_runs)
+
+    if mode == "eval":
+        new_14 = [
+            "e4ba0ff8",
+            "5ko8knb4",
+            "hk8555ff",
+            "r50ioujs",
+            "w0h3zr5s",
+            "e0lxw8py",
+            "0c0s48nh",
+            "lfg384bp",
+            "75o8bc3h",
+            "lig23cmg",
+            "qeghvubh",
+            "uke62a8m",
+            "061fd85t",
+            "tk15yznc",
+        ]
+        baseline = ["i4wtz1k4", "zd020wkv", "rc1q3t49"]
+        ids = baseline
+        run_all(ids, skip_preds=True, nonwandb_runs=non_wandb_runs)
+
+
+if __name__ == "__main__":
+    if len(sys.argv) > 1:
+        run_semloss_eval(sys.argv[1])
+    else:
+        run_semloss_eval()