
Commit 1f6a5ff

Author: sfluegel
Message: improve semantic loss eval
Parent: 3dd1d39

File tree: 3 files changed, +178 −39 lines

chebai/result/analyse_sem.py

Lines changed: 165 additions & 30 deletions
@@ -1,6 +1,6 @@
 import pandas as pd
 import sys
-
+import traceback
 from datetime import datetime
 from chebai.loss.semantic import DisjointLoss
 from chebai.preprocessing.datasets.chebi import ChEBIOver100
@@ -59,15 +59,15 @@ def _sort_results_by_label(n_labels, results, filter):
 def get_best_epoch(run):
     files = run.files()
     best_ep = None
-    best_val_loss = 0
+    best_micro_f1 = 0
     for file in files:
         if file.name.startswith("checkpoints/best_epoch"):
-            val_loss = float(file.name.split("=")[2].split("_")[0])
-            if val_loss < best_val_loss or best_ep is None:
+            micro_f1 = float(file.name.split("=")[-1][:-5])
+            if micro_f1 > best_micro_f1 or best_ep is None:
                 best_ep = int(file.name.split("=")[1].split("_")[0])
-                best_val_loss = val_loss
+                best_micro_f1 = micro_f1
     if best_ep is None:
-        raise Exception("Could not find any 'best' checkpoint")
+        raise Exception(f"Could not find any 'best' checkpoint for run {run.name}")
     else:
         print(f"Best epoch for run {run.name}: {best_ep}")
     return best_ep
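
Checkpoint selection now keys on micro-F1 (higher is better) instead of validation loss. A standalone sketch of the new filename parsing, assuming checkpoint names of the form implied by the split indices above (the exact naming pattern is an assumption, not part of this commit):

    # Hypothetical checkpoint name following a "best_epoch=<ep>_..._<metric>=<value>.ckpt" scheme.
    name = "checkpoints/best_epoch=85_val_micro-f1=0.9021.ckpt"

    epoch = int(name.split("=")[1].split("_")[0])  # "85_val_micro-f1" -> 85
    micro_f1 = float(name.split("=")[-1][:-5])     # "0.9021.ckpt" -> 0.9021 ([:-5] strips ".ckpt")

    assert (epoch, micro_f1) == (85, 0.9021)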
@@ -88,7 +88,42 @@ def load_preds_labels_from_wandb(
         f"{data_module.__class__.__name__}_{kind}",
     )
 
-    model = get_checkpoint_from_wandb(epoch, run)
+    model = get_checkpoint_from_wandb(epoch, run, map_device_to="cuda:0")
+    print(f"Calculating predictions...")
+    evaluate_model(
+        model,
+        data_module,
+        buffer_dir=buffer_dir,
+        filename=f"{kind}.pt",
+        skip_existing_preds=True,
+    )
+    preds, labels = load_results_from_buffer(buffer_dir, device=DEVICE)
+    del model
+    gc.collect()
+
+    return preds, labels
+
+
+def load_preds_labels_from_nonwandb(
+    name, epoch, chebi_version, test_on_data_cls=ChEBIOver100, kind="test"
+):
+    data_module = test_on_data_cls(chebi_version=chebi_version)
+
+    buffer_dir = os.path.join(
+        "results_buffer",
+        f"{name}_ep{epoch}",
+        f"{data_module.__class__.__name__}_{kind}",
+    )
+    ckpt_path = None
+    for file in os.listdir(os.path.join("logs", "downloaded_ckpts", name)):
+        if file.startswith(f"best_epoch={epoch}"):
+            ckpt_path = os.path.join(
+                os.path.join("logs", "downloaded_ckpts", name, file)
+            )
+    assert (
+        ckpt_path is not None
+    ), f"Could not find ckpt for epoch {epoch} in directory {os.path.join('logs', 'downloaded_ckpts', name)}"
+    model = Electra.load_from_checkpoint(ckpt_path, map_location="cuda:0", strict=False)
     print(f"Calculating predictions...")
     evaluate_model(
         model,
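
The new load_preds_labels_from_nonwandb mirrors the wandb path for local checkpoints stored under logs/downloaded_ckpts/<name>/. A usage sketch with a hypothetical run name (only the commented-out run in run_semloss_eval below hints at the real naming):

    # Hypothetical local run; expects a checkpoint file matching
    # logs/downloaded_ckpts/my_local_run/best_epoch=195*.ckpt
    preds, labels = load_preds_labels_from_nonwandb(
        name="my_local_run",
        epoch=195,
        chebi_version=231,
        test_on_data_cls=ChEBIOver100,
        kind="test",
    )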
@@ -130,7 +165,6 @@ def analyse_run(
         (dl.implication_filter_l, dl.implication_filter_r, "impl"),
         (dl.disjoint_filter_l, dl.disjoint_filter_r, "disj"),
     ]:
-        print(f"Calculating on {filter_type} loss")
         # prepare predictions
         n_loss_terms = dl_filter_l.shape[0]
         preds_exp = preds.unsqueeze(2).expand((-1, -1, n_loss_terms)).swapaxes(1, 2)
@@ -218,34 +252,135 @@ def analyse_run(
     gc.collect()
 
 
-def run_all(run_ids, datasets=None, chebi_version=231):
+def run_all(
+    run_ids,
+    datasets=None,
+    chebi_version=231,
+    skip_analyse=False,
+    skip_preds=False,
+    nonwandb_runs=None,
+):
     # evaluate a list of runs on Hazardous and ChEBIOver100 datasets
     if datasets is None:
         datasets = [(Hazardous, "all"), (ChEBIOver100, "test")]
     timestamp = datetime.now().strftime("%y%m%d-%H%M")
     results_path = os.path.join(
         "_semloss_eval", f"semloss_results_pc-dis-200k_{timestamp}.csv"
     )
-
+    api = wandb.Api()
     for run_id in run_ids:
+        try:
+            run = api.run(f"chebai/chebai/{run_id}")
+            epoch = get_best_epoch(run)
+            for test_on, kind in datasets:
+                df = {
+                    "run-id": run_id,
+                    "epoch": int(epoch),
+                    "kind": kind,
+                    "data_module": test_on.__name__,
+                    "chebi_version": chebi_version,
+                }
+                if not skip_preds:
+                    preds, labels = load_preds_labels_from_wandb(
+                        run, epoch, chebi_version, test_on, kind
+                    )
+                else:
+                    buffer_dir = os.path.join(
+                        "results_buffer",
+                        f"{run.name}_ep{epoch}",
+                        f"{test_on.__name__}_{kind}",
+                    )
+                    preds, labels = load_results_from_buffer(buffer_dir, device=DEVICE)
+                if not skip_analyse:
+                    print(
+                        f"Calculating metrics for run {run.name} on {test_on.__name__} ({kind})"
+                    )
+                    analyse_run(
+                        preds,
+                        labels,
+                        df_hyperparams=df,
+                        chebi_version=chebi_version,
+                        results_path=results_path,
+                    )
+        except Exception as e:
+            print(f"Failed for run {run_id}: {e}")
+            print(traceback.format_exc())
+
+    if nonwandb_runs:
+        for run_name, epoch in nonwandb_runs:
+            try:
+                for test_on, kind in datasets:
+                    df = {
+                        "run-id": run_name,
+                        "epoch": int(epoch),
+                        "kind": kind,
+                        "data_module": test_on.__name__,
+                        "chebi_version": chebi_version,
+                    }
+                    if not skip_preds:
+                        preds, labels = load_preds_labels_from_nonwandb(
+                            run_name, epoch, chebi_version, test_on, kind
+                        )
+                    else:
+                        buffer_dir = os.path.join(
+                            "results_buffer",
+                            f"{run_name}_ep{epoch}",
+                            f"{test_on.__name__}_{kind}",
+                        )
+                        preds, labels = load_results_from_buffer(
+                            buffer_dir, device=DEVICE
+                        )
+                    if not skip_analyse:
+                        print(
+                            f"Calculating metrics for run {run_name} on {test_on.__name__} ({kind})"
+                        )
+                        analyse_run(
+                            preds,
+                            labels,
+                            df_hyperparams=df,
+                            chebi_version=chebi_version,
+                            results_path=results_path,
+                        )
+            except Exception as e:
+                print(f"Failed for run {run_name}: {e}")
+                print(traceback.format_exc())
+
+
+def run_semloss_eval(mode="eval"):
+    non_wandb_runs = (
+        []
+    )  # ("chebi100_semprodk2_weighted_v231_pc_200k_dis_24042-2000", 195)]
+    if mode == "preds":
         api = wandb.Api()
-        run = api.run(f"chebai/chebai/{run_id}")
-        epoch = get_best_epoch(run)
-        for test_on, kind in datasets:
-            df = {
-                "run-id": run_id,
-                "epoch": int(epoch),
-                "kind": kind,
-                "data_module": test_on.__class__.__name__,
-                "chebi_version": chebi_version,
-            }
-            preds, labels = load_preds_labels_from_wandb(
-                run, epoch, chebi_version, test_on, kind
-            )
-            analyse_run(
-                preds,
-                labels,
-                df_hyperparams=df,
-                chebi_version=chebi_version,
-                results_path=results_path,
-            )
+        runs = api.runs("chebai/chebai", filters={"tags": "eval_semloss_paper"})
+        print(f"Found {len(runs)} tagged wandb runs")
+        ids = [run.id for run in runs]
+        run_all(ids, skip_analyse=True, nonwandb_runs=non_wandb_runs)
+
+    if mode == "eval":
+        new_14 = [
+            "e4ba0ff8",
+            "5ko8knb4",
+            "hk8555ff",
+            "r50ioujs",
+            "w0h3zr5s",
+            "e0lxw8py",
+            "0c0s48nh",
+            "lfg384bp",
+            "75o8bc3h",
+            "lig23cmg",
+            "qeghvubh",
+            "uke62a8m",
+            "061fd85t",
+            "tk15yznc",
+        ]
+        baseline = ["i4wtz1k4", "zd020wkv", "rc1q3t49"]
+        ids = baseline
+        run_all(ids, skip_preds=True, nonwandb_runs=non_wandb_runs)
+
+
+if __name__ == "__main__":
+    if len(sys.argv) > 1:
+        run_semloss_eval(sys.argv[1])
+    else:
+        run_semloss_eval()
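
With the new __main__ hook the evaluation splits into two passes selected by the first CLI argument. A plausible invocation from the repository root (the script path follows the file location shown above):

    python chebai/result/analyse_sem.py preds   # download checkpoints, buffer predictions
    python chebai/result/analyse_sem.py eval    # compute metrics from the buffered predictions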

chebai/result/utils.py

Lines changed: 12 additions & 6 deletions
@@ -9,7 +9,11 @@
 
 
 def get_checkpoint_from_wandb(
-    epoch, run, root=os.path.join("logs", "downloaded_ckpts"), model_class=None
+    epoch,
+    run,
+    root=os.path.join("logs", "downloaded_ckpts"),
+    model_class=None,
+    map_device_to=None,
 ):
     """Gets wandb checkpoint based on run and epoch, downloads it if necessary"""
     api = wandb.Api()
@@ -26,7 +30,7 @@ def get_checkpoint_from_wandb(
         print(f"Downloading checkpoint to {dest_path}")
         wandb_util.download_file_from_url(dest_path, file.url, api.api_key)
         return model_class.load_from_checkpoint(
-            dest_path, strict=False, map_location="cuda:0"
+            dest_path, strict=False, map_location=map_device_to
         )
     print(f"No model found for epoch {epoch}")
     return None
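
map_device_to is forwarded straight to load_from_checkpoint as map_location, so passing None keeps the device placement stored in the checkpoint. A hedged calling sketch (the run id is reused from the eval list above; the epoch and the Electra import path are assumptions):

    import wandb
    from chebai.models.electra import Electra  # import path assumed

    run = wandb.Api().run("chebai/chebai/e4ba0ff8")
    model = get_checkpoint_from_wandb(
        epoch=85,  # hypothetical epoch
        run=run,
        model_class=Electra,
        map_device_to="cpu",  # e.g. for a GPU-less machine; None keeps checkpoint devices
    )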
@@ -54,8 +58,9 @@ def evaluate_model(
         os.makedirs(buffer_dir, exist_ok=True)
         save_ind = 0
         save_batch_size = 4
-        n_saved = 0
+        n_saved = 1
 
+    print(f"")
     for i in tqdm.tqdm(range(0, len(data_list), batch_size)):
         if not (
             skip_existing_preds
@@ -74,7 +79,6 @@ def evaluate_model(
             preds_list.append(preds)
             labels_list.append(labels)
             if buffer_dir is not None:
-                n_saved += 1
                 if n_saved >= save_batch_size:
                     torch.save(
                         torch.cat(preds_list),
@@ -87,8 +91,10 @@ def evaluate_model(
                     )
                     preds_list = []
                     labels_list = []
-                    save_ind += 1
-                    n_saved = 0
+                if n_saved >= save_batch_size:
+                    save_ind += 1
+                    n_saved = 0
+                n_saved += 1
 
     if buffer_dir is None:
         test_preds = torch.cat(preds_list)
configs/data/chebi100.yml

Lines changed: 1 addition & 3 deletions
@@ -1,3 +1 @@
-class_path: chebai.preprocessing.datasets.chebi.ChEBIOver100
-init_args:
-  chebi_version: 231
+class_path: chebai.preprocessing.datasets.chebi.ChEBIOver100
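
With the pinned init_args gone, the data module falls back to its class default unless a version is supplied by the caller, which is what the eval script now does. A minimal sketch of the programmatic path used in analyse_sem.py:

    from chebai.preprocessing.datasets.chebi import ChEBIOver100

    # Version supplied by the caller instead of the yml config.
    data_module = ChEBIOver100(chebi_version=231)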
