
Commit 2288b83

Author: sfluegel

minor changes / fixes for prediction generation, semantic loss

1 parent 392ea40 · commit 2288b83

File tree: 2 files changed (+52 −20 lines)


chebai/result/analyse_sem.py

Lines changed: 9 additions & 6 deletions
@@ -97,6 +97,7 @@ def load_preds_labels_from_wandb(
         buffer_dir=buffer_dir,
         filename=f"{kind}.pt",
         skip_existing_preds=True,
+        batch_size=1,
     )
     preds, labels = load_results_from_buffer(buffer_dir, device=DEVICE)
     del model
@@ -394,11 +395,13 @@ def run_all(
         "_semloss_eval",
         f"semloss_results_pc-dis-200k_{timestamp}{'_violations_removed' if remove_violations else ''}.csv",
     )
-    label_names = get_label_names(ChEBIOver100(chebi_version=chebi_version))
-    chebi_graph = get_chebi_graph(
-        ChEBIOver100(chebi_version=chebi_version), label_names
-    )
-    disjoint_groups = get_disjoint_groups()
+
+    if remove_violations:
+        label_names = get_label_names(ChEBIOver100(chebi_version=chebi_version))
+        chebi_graph = get_chebi_graph(
+            ChEBIOver100(chebi_version=chebi_version), label_names
+        )
+        disjoint_groups = get_disjoint_groups()

     api = wandb.Api()
     for run_id in run_ids:
@@ -472,7 +475,7 @@ def run_all(
             )
         except Exception as e:
             print(f"Failed for run {run_id}: {e}")
-            print(traceback.format_exc())
+            # print(traceback.format_exc())

     if nonwandb_runs:
         for run_name, epoch in nonwandb_runs:
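In summary, this file now calls `evaluate_model` with an explicit `batch_size=1`, silences the full traceback in favour of the one-line failure message, and only loads the ChEBI label names, class graph, and disjoint groups when `remove_violations` is set, since they are presumably only needed for filtering out violating predictions.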

chebai/result/utils.py

Lines changed: 43 additions & 14 deletions
@@ -36,6 +36,20 @@ def get_checkpoint_from_wandb(
     return None


+def _run_batch(batch, model, collate):
+    collated = collate(batch)
+    collated.x = collated.to_x(model.device)
+    if collated.y is not None:
+        collated.y = collated.to_y(model.device)
+    processable_data = model._process_batch(collated, 0)
+    del processable_data["loss_kwargs"]
+    model_output = model(processable_data, **processable_data["model_kwargs"])
+    preds, labels = model._get_prediction_and_labels(
+        processable_data, processable_data["labels"], model_output
+    )
+    return preds, labels
+
+
 def evaluate_model(
     model: ChebaiBaseNet,
     data_module: XYBaseDataModule,
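The new `_run_batch` helper bundles the collate → device-transfer → forward-pass → prediction-extraction sequence that was previously inlined in `evaluate_model` (see the `@@ -66,32 +80,24 @@` hunk below). A minimal sketch of calling it directly, assuming `model`, `collate`, and `data_list` are set up the same way `evaluate_model` sets them up; nothing in this sketch is part of the commit itself:

# Hypothetical standalone use of the new helper: run the first few processed
# samples through a trained ChebaiBaseNet and inspect the output shapes.
batch = data_list[:8]                              # first 8 processed samples
preds, labels = _run_batch(batch, model, collate)  # forward pass on model.device
print(preds.shape, None if labels is None else labels.shape)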
@@ -57,7 +71,7 @@ def evaluate_model(
     if buffer_dir is not None:
         os.makedirs(buffer_dir, exist_ok=True)
         save_ind = 0
-        save_batch_size = 4
+        save_batch_size = 128
         n_saved = 1

     print(f"")
@@ -66,32 +80,24 @@ def evaluate_model(
             skip_existing_preds
             and os.path.isfile(os.path.join(buffer_dir, f"preds{save_ind:03d}.pt"))
         ):
-            collated = collate(data_list[i : min(i + batch_size, len(data_list) - 1)])
-            collated.x = collated.to_x(model.device)
-            if collated.y is not None:
-                collated.y = collated.to_y(model.device)
-            processable_data = model._process_batch(collated, 0)
-            del processable_data["loss_kwargs"]
-            model_output = model(processable_data, **processable_data["model_kwargs"])
-            preds, labels = model._get_prediction_and_labels(
-                processable_data, processable_data["labels"], model_output
-            )
+            preds, labels = _run_batch(data_list[i : i + batch_size], model, collate)
             preds_list.append(preds)
             labels_list.append(labels)
+
         if buffer_dir is not None:
-            if n_saved >= save_batch_size:
+            if n_saved * batch_size >= save_batch_size:
                 torch.save(
                     torch.cat(preds_list),
                     os.path.join(buffer_dir, f"preds{save_ind:03d}.pt"),
                 )
-                if collated.y is not None:
+                if labels_list[0] is not None:
                     torch.save(
                         torch.cat(labels_list),
                         os.path.join(buffer_dir, f"labels{save_ind:03d}.pt"),
                     )
                 preds_list = []
                 labels_list = []
-            if n_saved >= save_batch_size:
+            if n_saved * batch_size >= save_batch_size:
                 save_ind += 1
                 n_saved = 0
             n_saved += 1
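The per-batch work now delegates to `_run_batch`, and the buffer-flush condition changes from counting batches (`n_saved >= save_batch_size`) to counting samples (`n_saved * batch_size >= save_batch_size`). With `batch_size=1` as now passed from `analyse_sem.py` and the new `save_batch_size = 128`, a `preds{NNN}.pt` / `labels{NNN}.pt` pair is written roughly every 128 samples instead of every 4 batches. A self-contained toy illustration of the new cadence (values only; not part of the commit):

# Toy sketch of how often a buffer file would be written under the new condition.
batch_size = 1          # as passed from analyse_sem.py after this commit
save_batch_size = 128   # new value in evaluate_model
n_saved = 1
flush_points = []
for i in range(512):                              # pretend we iterate over 512 samples
    if n_saved * batch_size >= save_batch_size:
        flush_points.append(i)                    # a buffer file would be written here
        n_saved = 0
    n_saved += 1
print(flush_points)                               # [127, 255, 383, 511] -> every 128 samples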
@@ -103,6 +109,16 @@ def evaluate_model(

             return test_preds, test_labels
         return test_preds, None
+    else:
+        torch.save(
+            torch.cat(preds_list),
+            os.path.join(buffer_dir, f"preds{save_ind:03d}.pt"),
+        )
+        if labels_list[0] is not None:
+            torch.save(
+                torch.cat(labels_list),
+                os.path.join(buffer_dir, f"labels{save_ind:03d}.pt"),
+            )


 def load_results_from_buffer(buffer_dir, device):
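The new `else` branch flushes whatever is still sitting in `preds_list` / `labels_list` when a `buffer_dir` is given, so a final partial buffer at the end of the data is also written to disk (previously, it appears, it was never saved).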
@@ -144,3 +160,16 @@ def load_results_from_buffer(buffer_dir, device):
         test_labels = None

     return test_preds, test_labels
+
+
+if __name__ == "__main__":
+    import sys
+
+    buffer_dir = os.path.join("results_buffer", sys.argv[1], "ChEBIOver100_train")
+    buffer_dir_concat = os.path.join(
+        "results_buffer", "concatenated", sys.argv[1], "ChEBIOver100_train"
+    )
+    os.makedirs(buffer_dir_concat, exist_ok=True)
+    preds, labels = load_results_from_buffer(buffer_dir, "cpu")
+    torch.save(preds, os.path.join(buffer_dir_concat, f"preds000.pt"))
+    torch.save(labels, os.path.join(buffer_dir_concat, f"labels000.pt"))
