@@ -36,6 +36,20 @@ def get_checkpoint_from_wandb(
3636 return None
3737
3838
39+ def _run_batch (batch , model , collate ):
40+ collated = collate (batch )
41+ collated .x = collated .to_x (model .device )
42+ if collated .y is not None :
43+ collated .y = collated .to_y (model .device )
44+ processable_data = model ._process_batch (collated , 0 )
45+ del processable_data ["loss_kwargs" ]
46+ model_output = model (processable_data , ** processable_data ["model_kwargs" ])
47+ preds , labels = model ._get_prediction_and_labels (
48+ processable_data , processable_data ["labels" ], model_output
49+ )
50+ return preds , labels
51+
52+
3953def evaluate_model (
4054 model : ChebaiBaseNet ,
4155 data_module : XYBaseDataModule ,
@@ -57,7 +71,7 @@ def evaluate_model(
5771 if buffer_dir is not None :
5872 os .makedirs (buffer_dir , exist_ok = True )
5973 save_ind = 0
60- save_batch_size = 4
74+ save_batch_size = 128
6175 n_saved = 1
6276
6377 print (f"" )
@@ -66,32 +80,24 @@ def evaluate_model(
6680 skip_existing_preds
6781 and os .path .isfile (os .path .join (buffer_dir , f"preds{ save_ind :03d} .pt" ))
6882 ):
69- collated = collate (data_list [i : min (i + batch_size , len (data_list ) - 1 )])
70- collated .x = collated .to_x (model .device )
71- if collated .y is not None :
72- collated .y = collated .to_y (model .device )
73- processable_data = model ._process_batch (collated , 0 )
74- del processable_data ["loss_kwargs" ]
75- model_output = model (processable_data , ** processable_data ["model_kwargs" ])
76- preds , labels = model ._get_prediction_and_labels (
77- processable_data , processable_data ["labels" ], model_output
78- )
83+ preds , labels = _run_batch (data_list [i : i + batch_size ], model , collate )
7984 preds_list .append (preds )
8085 labels_list .append (labels )
86+
8187 if buffer_dir is not None :
82- if n_saved >= save_batch_size :
88+ if n_saved * batch_size >= save_batch_size :
8389 torch .save (
8490 torch .cat (preds_list ),
8591 os .path .join (buffer_dir , f"preds{ save_ind :03d} .pt" ),
8692 )
87- if collated . y is not None :
93+ if labels_list [ 0 ] is not None :
8894 torch .save (
8995 torch .cat (labels_list ),
9096 os .path .join (buffer_dir , f"labels{ save_ind :03d} .pt" ),
9197 )
9298 preds_list = []
9399 labels_list = []
94- if n_saved >= save_batch_size :
100+ if n_saved * batch_size >= save_batch_size :
95101 save_ind += 1
96102 n_saved = 0
97103 n_saved += 1
@@ -103,6 +109,16 @@ def evaluate_model(
103109
104110 return test_preds , test_labels
105111 return test_preds , None
112+ else :
113+ torch .save (
114+ torch .cat (preds_list ),
115+ os .path .join (buffer_dir , f"preds{ save_ind :03d} .pt" ),
116+ )
117+ if labels_list [0 ] is not None :
118+ torch .save (
119+ torch .cat (labels_list ),
120+ os .path .join (buffer_dir , f"labels{ save_ind :03d} .pt" ),
121+ )
106122
107123
108124def load_results_from_buffer (buffer_dir , device ):
@@ -144,3 +160,16 @@ def load_results_from_buffer(buffer_dir, device):
144160 test_labels = None
145161
146162 return test_preds , test_labels
163+
164+
165+ if __name__ == "__main__" :
166+ import sys
167+
168+ buffer_dir = os .path .join ("results_buffer" , sys .argv [1 ], "ChEBIOver100_train" )
169+ buffer_dir_concat = os .path .join (
170+ "results_buffer" , "concatenated" , sys .argv [1 ], "ChEBIOver100_train"
171+ )
172+ os .makedirs (buffer_dir_concat , exist_ok = True )
173+ preds , labels = load_results_from_buffer (buffer_dir , "cpu" )
174+ torch .save (preds , os .path .join (buffer_dir_concat , f"preds000.pt" ))
175+ torch .save (labels , os .path .join (buffer_dir_concat , f"labels000.pt" ))
0 commit comments