Skip to content

Commit d58b097

Browse files
committed
precalculate graph and successor relationship for more efficient call-time performance
1 parent bda3ed7 commit d58b097

File tree

2 files changed

+34
-33
lines changed

2 files changed

+34
-33
lines changed

chebai/result/analyse_sem.py

Lines changed: 32 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,8 @@ def get_chebi_graph(data_module, label_names):
135135
chebi_graph = data_module._extract_class_hierarchy(
136136
os.path.join(data_module.raw_dir, "chebi.obo")
137137
)
138+
if label_names is None:
139+
return chebi_graph
138140
return chebi_graph.subgraph([int(n) for n in label_names])
139141
print(
140142
f"Failed to retrieve ChEBI graph, {os.path.join(data_module.raw_dir, 'chebi.obo')} not found"
@@ -190,32 +192,38 @@ class PredictionSmoother:
190192
"""Removes implication and disjointness violations from predictions"""
191193

192194
def __init__(self, dataset, label_names=None, disjoint_files=None):
    """Set up the smoother.

    Loads the complete ChEBI class hierarchy for the dataset, precomputes
    the successor-relationship mask for the given labels, and collects the
    groups of mutually disjoint classes.

    :param dataset: data module providing access to the ChEBI ontology file
    :param label_names: iterable of class identifiers (or None to defer via
        a later ``set_label_names`` call)
    :param disjoint_files: files describing disjointness axioms
    """
    # disjointness groups are independent of the graph / label setup
    self.disjoint_groups = get_disjoint_groups(disjoint_files)
    # keep the *full* hierarchy; set_label_names() restricts it as needed
    self.chebi_graph = get_chebi_graph(dataset, None)
    self.set_label_names(label_names)
199198

199+
def set_label_names(self, label_names):
200+
if label_names is not None:
201+
self.label_names = [int(label) for label in label_names]
202+
chebi_subgraph = self.chebi_graph.subgraph(self.label_names)
203+
self.label_successors = torch.zeros(
204+
(len(self.label_names), len(self.label_names)), dtype=torch.bool
205+
)
206+
for i, label in enumerate(self.label_names):
207+
self.label_successors[i, i] = 1
208+
for p in chebi_subgraph.successors(label):
209+
if p in self.label_names:
210+
self.label_successors[i, self.label_names.index(p)] = 1
211+
self.label_successors = self.label_successors.unsqueeze(0)
212+
200213
def __call__(self, preds):
201214
preds_sum_orig = torch.sum(preds)
202-
for i, label in enumerate(self.label_names):
203-
succs = [
204-
self.label_names.index(str(p))
205-
for p in self.chebi_graph.successors(int(label))
206-
] + [i]
207-
if len(succs) > 0:
208-
preds[:, i] = torch.max(preds[:, succs], dim=1).values
215+
# step 1: apply implications: for each class, set prediction to max of itself and all successors
216+
preds = preds.unsqueeze(1)
217+
preds_masked_succ = torch.where(self.label_successors, preds, 0)
218+
preds = preds_masked_succ.max(dim=2).values
209219
if torch.sum(preds) != preds_sum_orig:
210220
print(f"Preds change (step 1): {torch.sum(preds) - preds_sum_orig}")
211221
preds_sum_orig = torch.sum(preds)
212222
# step 2: eliminate disjointness violations: for group of disjoint classes, set all except max to 0.49 (if it is not already lower)
213223
preds_bounded = torch.min(preds, torch.ones_like(preds) * 0.49)
214224
for disj_group in self.disjoint_groups:
215225
disj_group = [
216-
self.label_names.index(str(g))
217-
for g in disj_group
218-
if g in self.label_names
226+
self.label_names.index(g) for g in disj_group if g in self.label_names
219227
]
220228
if len(disj_group) > 1:
221229
old_preds = preds[:, disj_group]
@@ -232,26 +240,17 @@ def __call__(self, preds):
232240
print(
233241
f"disjointness group {[self.label_names[d] for d in disj_group]} changed {samples_changed} samples"
234242
)
243+
if torch.sum(preds) != preds_sum_orig:
244+
print(f"Preds change (step 2): {torch.sum(preds) - preds_sum_orig}")
235245
preds_sum_orig = torch.sum(preds)
236246
# step 3: disjointness violation removal may have caused new implication inconsistencies -> set each prediction to min of predecessors
237-
for i, label in enumerate(self.label_names):
238-
predecessors = [i] + [
239-
self.label_names.index(str(p))
240-
for p in self.chebi_graph.predecessors(int(label))
241-
]
242-
lowest_predecessors = torch.min(preds[:, predecessors], dim=1)
243-
preds[:, i] = lowest_predecessors.values
244-
for idx_idx, idx in enumerate(lowest_predecessors.indices):
245-
if idx > 0:
246-
print(
247-
f"class {label}: changed prediction of sample {idx_idx} to value of class "
248-
f"{self.label_names[predecessors[idx]]} ({preds[idx_idx, i].item():.2f})"
249-
)
250-
if torch.sum(preds) != preds_sum_orig:
251-
print(
252-
f"Preds change (step 3) for {label}: {torch.sum(preds) - preds_sum_orig}"
253-
)
254-
preds_sum_orig = torch.sum(preds)
247+
preds = preds.unsqueeze(1)
248+
preds_masked_predec = torch.where(
249+
torch.transpose(self.label_successors, 1, 2), preds, 1
250+
)
251+
preds = preds_masked_predec.min(dim=2).values
252+
if torch.sum(preds) != preds_sum_orig:
253+
print(f"Preds change (step 3): {torch.sum(preds) - preds_sum_orig}")
255254
return preds
256255

257256

setup.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
license="",
1313
author="MGlauer",
1414
author_email="martin.glauer@ovgu.de",
15+
maintainer="sfluegel05",
16+
maintainer_email="simon.fluegel@uni-osnabrueck.de",
1517
description="",
1618
zip_safe=False,
1719
python_requires=">=3.9, <3.13",

0 commit comments

Comments
 (0)