
Commit a8cc1dc

sfluegel committed

Merge remote-tracking branch 'origin/dev' into code_documentation

# Conflicts:
#	chebai/models/electra.py

2 parents: 7f22f83 + 290348d

24 files changed: +784 -206 lines
Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
+from lightning.pytorch.callbacks import BasePredictionWriter
+import torch
+import os
+import pickle
+
+
+class PredictionWriter(BasePredictionWriter):
+    def __init__(self, output_dir, write_interval):
+        super().__init__(write_interval)
+        self.output_dir = output_dir
+        self.prediction_file_name = "predictions.pkl"
+
+    def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
+        results = [
+            dict(
+                ident=row["data"]["idents"][0],
+                predictions=torch.sigmoid(row["output"]["logits"]).numpy(),
+                labels=row["labels"][0].numpy() if row["labels"] is not None else None,
+            )
+            for row in predictions
+        ]
+        with open(
+            os.path.join(self.output_dir, self.prediction_file_name), "wb"
+        ) as fout:
+            pickle.dump(results, fout)
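
A minimal usage sketch for the new callback (not part of the diff; it assumes the PredictionWriter class defined above is importable, and the output directory and predict call are placeholders):

from lightning.pytorch import Trainer

writer = PredictionWriter(output_dir="results", write_interval="epoch")
trainer = Trainer(callbacks=[writer])
# trainer.predict(model, datamodule=dm) would then pickle a list of
# {"ident", "predictions", "labels"} dicts to results/predictions.pkl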

chebai/loss/bce_weighted.py

Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
+import torch
+from chebai.preprocessing.datasets.chebi import _ChEBIDataExtractor
+import pandas as pd
+import os
+import pickle
+
+
+class BCEWeighted(torch.nn.BCEWithLogitsLoss):
+    """BCEWithLogitsLoss with weights automatically computed according to beta parameter (formula from
+    https://openaccess.thecvf.com/content_CVPR_2019/papers/Cui_Class-Balanced_Loss_Based_on_Effective_Number_of_Samples_CVPR_2019_paper.pdf)
+    """
+
+    def __init__(self, beta: float = None, data_extractor: _ChEBIDataExtractor = None):
+        self.beta = beta
+        self.data_extractor = data_extractor
+        super().__init__()
+
+    def set_pos_weight(self, input):
+        if (
+            self.beta is not None
+            and self.data_extractor is not None
+            and all(
+                os.path.exists(os.path.join(self.data_extractor.raw_dir, raw_file))
+                for raw_file in self.data_extractor.raw_file_names
+            )
+            and self.pos_weight is None
+        ):
+            complete_data = pd.concat(
+                [
+                    pickle.load(
+                        open(
+                            os.path.join(
+                                self.data_extractor.raw_dir,
+                                self.data_extractor.raw_file_names_dict[set],
+                            ),
+                            "rb",
+                        )
+                    )
+                    for set in ["train", "validation", "test"]
+                ]
+            )
+            value_counts = []
+            for c in complete_data.columns[3:]:
+                value_counts.append(len([v for v in complete_data[c] if v]))
+            weights = [
+                (1 - self.beta) / (1 - pow(self.beta, value)) for value in value_counts
+            ]
+            mean = sum(weights) / len(weights)
+            self.pos_weight = torch.tensor(
+                [w / mean for w in weights], device=input.device
+            )
+
+    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+        self.set_pos_weight(input)
+        return super().forward(input, target)
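
For reference, the per-class weight computed above follows the class-balanced loss of Cui et al. (2019): w_c = (1 - beta) / (1 - beta^n_c), where n_c is the number of positive examples of class c, and the weights are then normalised by their mean. A standalone sketch with made-up counts (not part of the diff):

beta = 0.99
value_counts = [500, 50, 5]  # toy per-class positive counts
weights = [(1 - beta) / (1 - beta**n) for n in value_counts]
mean = sum(weights) / len(weights)
pos_weight = [w / mean for w in weights]
# rare classes end up with pos_weight > 1, frequent classes with pos_weight < 1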

chebai/loss/semantic.py

Lines changed: 88 additions & 37 deletions
@@ -2,24 +2,40 @@
 import os
 import pickle
 
+import math
 import torch
+from typing import Literal
 
-from chebai.models.electra import extract_class_hierarchy
-
-IMPLICATION_CACHE_FILE = "chebi.cache"
+from chebai.preprocessing.datasets.chebi import _ChEBIDataExtractor, ChEBIOver100
 
 
 class ImplicationLoss(torch.nn.Module):
     def __init__(
-        self, path_to_chebi, path_to_label_names, base_loss: torch.nn.Module = None
+        self,
+        data_extractor: _ChEBIDataExtractor,
+        base_loss: torch.nn.Module = None,
+        tnorm: Literal["product", "lukasiewicz", "xu19"] = "product",
+        impl_loss_weight=0.1,  # weight of implication loss in relation to base_loss
+        pos_scalar=1,
+        pos_epsilon=0.01,
     ):
         super().__init__()
+        self.data_extractor = data_extractor
         self.base_loss = base_loss
-        label_names = _load_label_names(path_to_label_names)
-        hierarchy = _load_implications(path_to_chebi)
-        implication_filter = _build_implication_filter(label_names, hierarchy)
+        self.implication_cache_file = f"implications_{self.data_extractor.name}.cache"
+        self.label_names = _load_label_names(
+            os.path.join(data_extractor.raw_dir, "classes.txt")
+        )
+        self.hierarchy = self._load_implications(
+            os.path.join(data_extractor.raw_dir, "chebi.obo")
+        )
+        implication_filter = _build_implication_filter(self.label_names, self.hierarchy)
         self.implication_filter_l = implication_filter[:, 0]
         self.implication_filter_r = implication_filter[:, 1]
+        self.tnorm = tnorm
+        self.impl_weight = impl_loss_weight
+        self.pos_scalar = pos_scalar
+        self.eps = pos_epsilon
 
     def forward(self, input, target, **kwargs):
         nnl = kwargs.pop("non_null_labels", None)
@@ -36,40 +52,77 @@ def forward(self, input, target, **kwargs):
         r = pred[:, self.implication_filter_r]
         # implication_loss = torch.sqrt(torch.mean(torch.sum(l*(1-r), dim=-1), dim=0))
         implication_loss = self._calculate_implication_loss(l, r)
-        return base_loss + implication_loss
+
+        return (
+            base_loss + self.impl_weight * implication_loss,
+            base_loss,
+            implication_loss,
+        )
 
     def _calculate_implication_loss(self, l, r):
-        capped_difference = torch.relu(l - r)
+        assert not l.isnan().any()
+        assert not r.isnan().any()
+        if self.pos_scalar != 1:
+            l = (
+                torch.pow(l + self.eps, 1 / self.pos_scalar)
+                - math.pow(self.eps, 1 / self.pos_scalar)
+            ) / (
+                math.pow(1 + self.eps, 1 / self.pos_scalar)
+                - math.pow(self.eps, 1 / self.pos_scalar)
+            )
+            r = torch.pow(r, self.pos_scalar)
+        if self.tnorm == "product":
+            individual_loss = l * (1 - r)
+        elif self.tnorm == "xu19":
+            individual_loss = -torch.log(1 - l * (1 - r))
+        elif self.tnorm == "lukasiewicz":
+            individual_loss = torch.relu(l - r)
+        else:
+            raise NotImplementedError(f"Unknown tnorm {self.tnorm}")
+
         return torch.mean(
-            torch.sum(
-                (torch.softmax(capped_difference, dim=-1) * capped_difference), dim=-1
-            ),
+            torch.sum(individual_loss, dim=-1),
             dim=0,
         )
 
+    def _load_implications(self, path_to_chebi):
+        if os.path.isfile(self.implication_cache_file):
+            with open(self.implication_cache_file, "rb") as fin:
+                hierarchy = pickle.load(fin)
+        else:
+            hierarchy = self.data_extractor.extract_class_hierarchy(path_to_chebi)
+            with open(self.implication_cache_file, "wb") as fout:
+                pickle.dump(hierarchy, fout)
+        return hierarchy
+
 
 class DisjointLoss(ImplicationLoss):
     def __init__(
         self,
-        path_to_chebi,
-        path_to_label_names,
-        path_to_disjointedness,
+        path_to_disjointness,
+        data_extractor: _ChEBIDataExtractor,
         base_loss: torch.nn.Module = None,
+        disjoint_loss_weight=100,
+        **kwargs,
     ):
-        super().__init__(path_to_chebi, path_to_label_names, base_loss)
-        label_names = _load_label_names(path_to_label_names)
-        hierarchy = _load_implications(path_to_chebi)
+        super().__init__(data_extractor, base_loss, **kwargs)
        self.disjoint_filter_l, self.disjoint_filter_r = _build_disjointness_filter(
-            path_to_disjointedness, label_names, hierarchy
+            path_to_disjointness, self.label_names, self.hierarchy
         )
+        self.disjoint_weight = disjoint_loss_weight
 
     def forward(self, input, target, **kwargs):
-        loss = super().forward(input, target, **kwargs)
+        loss, base_loss, impl_loss = super().forward(input, target, **kwargs)
         pred = torch.sigmoid(input)
         l = pred[:, self.disjoint_filter_l]
         r = pred[:, self.disjoint_filter_r]
         disjointness_loss = self._calculate_implication_loss(l, 1 - r)
-        return loss + disjointness_loss
+        return (
+            loss + self.disjoint_weight * disjointness_loss,
+            base_loss,
+            impl_loss,
+            disjointness_loss,
+        )
 
 
 def _load_label_names(path_to_label_names):
@@ -78,17 +131,6 @@ def _load_label_names(path_to_label_names):
     return label_names
 
 
-def _load_implications(path_to_chebi, implication_cache=IMPLICATION_CACHE_FILE):
-    if os.path.isfile(implication_cache):
-        with open(implication_cache, "rb") as fin:
-            hierarchy = pickle.load(fin)
-    else:
-        hierarchy = extract_class_hierarchy(path_to_chebi)
-        with open(implication_cache, "wb") as fout:
-            pickle.dump(hierarchy, fout)
-    return hierarchy
-
-
 def _build_implication_filter(label_names, hierarchy):
     return torch.tensor(
         [
@@ -100,24 +142,33 @@ def _build_implication_filter(label_names, hierarchy):
     )
 
 
-def _build_disjointness_filter(path_to_disjointedness, label_names, hierarchy):
+def _build_disjointness_filter(path_to_disjointness, label_names, hierarchy):
     disjoints = set()
     label_dict = dict(map(reversed, enumerate(label_names)))
 
-    with open(path_to_disjointedness, "rt") as fin:
+    with open(path_to_disjointness, "rt") as fin:
         reader = csv.reader(fin)
         for l1_raw, r1_raw in reader:
             l1 = int(l1_raw)
             r1 = int(r1_raw)
+            if l1 == 36233 and r1 == 63353:
+                # ignore disaccharide-disaccharide derivative disjointness axiom
+                continue
             disjoints.update(
                 {
                     (label_dict[l2], label_dict[r2])
-                    for r2 in hierarchy.succ[r1]
+                    for r2 in list(hierarchy.succ[r1]) + [r1]
                     if r2 in label_names
-                    for l2 in hierarchy.succ[l1]
-                    if l2 in label_names and l2 < r2
+                    for l2 in list(hierarchy.succ[l1]) + [l1]
+                    if l2 in label_names
                 }
            )
 
     dis_filter = torch.tensor(list(disjoints))
     return dis_filter[:, 0], dis_filter[:, 1]
+
+
+if __name__ == "__main__":
+    loss = DisjointLoss(
+        os.path.join("data", "disjoint.csv"), ChEBIOver100(chebi_version=227)
+    )
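
The t-norm variants above all penalise predictions that violate an implication taken from the ChEBI hierarchy (roughly: "subclass implies superclass"). With the default product t-norm the per-pair penalty is l * (1 - r), which is large only when the subclass is predicted confidently while the superclass is not; "lukasiewicz" uses relu(l - r) and "xu19" uses -log(1 - l * (1 - r)). A standalone toy example (made-up values, not part of the diff):

import torch

l = torch.tensor([0.9, 0.9, 0.1])  # predicted probability of the subclass
r = torch.tensor([0.8, 0.1, 0.8])  # predicted probability of the superclass
print(l * (1 - r))  # tensor([0.1800, 0.8100, 0.0200]); only the middle pair strongly violates the implication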

chebai/models/base.py

Lines changed: 15 additions & 0 deletions
@@ -133,6 +133,21 @@ def _execute(self, batch, batch_idx, metrics, prefix="", log=True, sync_dist=Fal
         if self.pass_loss_kwargs:
             loss_kwargs = loss_kwargs_candidates
         loss = self.criterion(loss_data, loss_labels, **loss_kwargs)
+        if isinstance(loss, tuple):
+            loss_additional = loss[1:]
+            for i, loss_add in enumerate(loss_additional):
+                self.log(
+                    f"{prefix}loss_{i}",
+                    loss_add if isinstance(loss_add, int) else loss_add.item(),
+                    batch_size=len(batch),
+                    on_step=True,
+                    on_epoch=False,
+                    prog_bar=False,
+                    logger=True,
+                    sync_dist=sync_dist,
+                )
+            loss = loss[0]
+
         d["loss"] = loss
         self.log(
             f"{prefix}loss",

chebai/models/electra.py

Lines changed: 5 additions & 17 deletions
@@ -18,6 +18,8 @@
 
 logging.getLogger("pysmiles").setLevel(logging.CRITICAL)
 
+from chebai.loss.semantic import DisjointLoss as ElectraChEBIDisjointLoss  # noqa
+
 
 class ElectraPre(ChebaiBaseNet):
     """
@@ -245,14 +247,9 @@ def _process_for_loss(self, model_output, labels, loss_kwargs):
 
         """
         kwargs_copy = dict(loss_kwargs)
-        mask = kwargs_copy.pop("target_mask", None)
-        if mask is not None:
-            d = model_output["logits"] * mask - 100 * ~mask
-        else:
-            d = model_output["logits"]
         if labels is not None:
             labels = labels.float()
-        return d, labels, kwargs_copy
+        return model_output["logits"], labels, kwargs_copy
 
     def _get_prediction_and_labels(self, data, labels, model_output):
         """
@@ -267,16 +264,12 @@ def _get_prediction_and_labels(self, data, labels, model_output):
             tuple: A tuple containing the predictions and labels.
 
         """
-        mask = model_output.get("target_mask")
-        if mask is not None:
-            d = model_output["logits"] * mask - 100 * ~mask
-        else:
-            d = model_output["logits"]
+        d = model_output["logits"]
         loss_kwargs = data.get("loss_kwargs", dict())
         if "non_null_labels" in loss_kwargs:
             n = loss_kwargs["non_null_labels"]
             d = d[n]
-        return torch.sigmoid(d), labels.int()
+        return torch.sigmoid(d), labels.int() if labels is not None else None
 
     def forward(self, data, **kwargs):
         """
@@ -303,7 +296,6 @@ def forward(self, data, **kwargs):
         return dict(
             logits=self.output(d),
             attentions=electra.attentions,
-            target_mask=data.get("target_mask"),
         )
 
 
@@ -359,7 +351,6 @@ def _process_batch(self, batch, batch_idx):
             features=torch.cat((cls_tokens, batch.x), dim=1),
             labels=batch.y,
             model_kwargs=dict(attention_mask=mask),
-            target_mask=batch.target_mask,
         )
 
     @property
@@ -418,7 +409,6 @@ def __init__(self, cone_dimensions=20, **kwargs):
         )
 
     def _get_data_for_loss(self, model_output, labels):
-        mask = model_output.get("target_mask")
         d = model_output["predicted_vectors"]
         return dict(
             input=dict(
@@ -428,7 +418,6 @@ def _get_data_for_loss(self, model_output, labels):
         )
 
     def _get_prediction_and_labels(self, data, labels, model_output):
-        mask = model_output.get("target_mask")
         d = model_output["predicted_vectors"].unsqueeze(1)
 
         d = in_cone_parts(d, self.cone_axes, self.cone_arcs)
@@ -444,7 +433,6 @@ def forward(self, data, **kwargs):
         return dict(
             predicted_vectors=self.line_embedding(d),
             attentions=electra.attentions,
-            target_mask=data.get("target_mask"),
         )
 
 
chebai/preprocessing/bin/smiles_token/tokens.txt

Lines changed: 2 additions & 0 deletions
@@ -767,3 +767,5 @@ p
 [Nd]
 [Ti+3]
 [14CH3]
+[HH]
+[CH3-]
