Commit 97a134f

Author: sfluegel
Merge remote-tracking branch 'origin/dev' into tutorial
2 parents d8c8edb + c06b058

25 files changed: +1129 / -207 lines changed
Lines changed: 25 additions & 0 deletions

@@ -0,0 +1,25 @@
+from lightning.pytorch.callbacks import BasePredictionWriter
+import torch
+import os
+import pickle
+
+
+class PredictionWriter(BasePredictionWriter):
+    def __init__(self, output_dir, write_interval):
+        super().__init__(write_interval)
+        self.output_dir = output_dir
+        self.prediction_file_name = "predictions.pkl"
+
+    def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
+        results = [
+            dict(
+                ident=row["data"]["idents"][0],
+                predictions=torch.sigmoid(row["output"]["logits"]).numpy(),
+                labels=row["labels"][0].numpy() if row["labels"] is not None else None,
+            )
+            for row in predictions
+        ]
+        with open(
+            os.path.join(self.output_dir, self.prediction_file_name), "wb"
+        ) as fout:
+            pickle.dump(results, fout)
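
Note: this callback only writes predictions once it is registered with a Lightning Trainer. A minimal usage sketch follows; the `model` and `datamodule` objects are placeholders for things defined elsewhere in the project and are not part of this commit:

from lightning.pytorch import Trainer

# write_interval="epoch" matches the write_on_epoch_end hook implemented above;
# the output directory is assumed to exist already
writer = PredictionWriter(output_dir="results", write_interval="epoch")
trainer = Trainer(callbacks=[writer], accelerator="auto", devices=1)

# `model` and `datamodule` are hypothetical placeholders for a trained ChebaiBaseNet
# and a matching data module
trainer.predict(model, datamodule=datamodule)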

chebai/loss/bce_weighted.py

Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
+import torch
+from chebai.preprocessing.datasets.chebi import _ChEBIDataExtractor
+import pandas as pd
+import os
+import pickle
+
+
+class BCEWeighted(torch.nn.BCEWithLogitsLoss):
+    """BCEWithLogitsLoss with weights automatically computed according to beta parameter (formula from
+    https://openaccess.thecvf.com/content_CVPR_2019/papers/Cui_Class-Balanced_Loss_Based_on_Effective_Number_of_Samples_CVPR_2019_paper.pdf)
+    """
+
+    def __init__(self, beta: float = None, data_extractor: _ChEBIDataExtractor = None):
+        self.beta = beta
+        self.data_extractor = data_extractor
+        super().__init__()
+
+    def set_pos_weight(self, input):
+        if (
+            self.beta is not None
+            and self.data_extractor is not None
+            and all(
+                os.path.exists(os.path.join(self.data_extractor.raw_dir, raw_file))
+                for raw_file in self.data_extractor.raw_file_names
+            )
+            and self.pos_weight is None
+        ):
+            complete_data = pd.concat(
+                [
+                    pickle.load(
+                        open(
+                            os.path.join(
+                                self.data_extractor.raw_dir,
+                                self.data_extractor.raw_file_names_dict[set],
+                            ),
+                            "rb",
+                        )
+                    )
+                    for set in ["train", "validation", "test"]
+                ]
+            )
+            value_counts = []
+            for c in complete_data.columns[3:]:
+                value_counts.append(len([v for v in complete_data[c] if v]))
+            weights = [
+                (1 - self.beta) / (1 - pow(self.beta, value)) for value in value_counts
+            ]
+            mean = sum(weights) / len(weights)
+            self.pos_weight = torch.tensor(
+                [w / mean for w in weights], device=input.device
+            )
+
+    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+        self.set_pos_weight(input)
+        return super().forward(input, target)
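
The weights follow the class-balanced loss of Cui et al. (linked in the docstring): a class with n positive samples gets weight (1 - beta) / (1 - beta^n), and the weights are rescaled so their mean is 1. A small self-contained sketch with made-up class counts:

import torch

beta = 0.99
value_counts = [500, 50, 5]  # hypothetical number of positive samples per class

weights = [(1 - beta) / (1 - beta ** n) for n in value_counts]
mean = sum(weights) / len(weights)
pos_weight = torch.tensor([w / mean for w in weights])
# rare classes get pos_weight > 1, frequent classes < 1 (here roughly 0.13, 0.32, 2.56)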

chebai/loss/semantic.py

Lines changed: 88 additions & 37 deletions
@@ -2,24 +2,40 @@
 import os
 import pickle
 
+import math
 import torch
+from typing import Literal
 
-from chebai.models.electra import extract_class_hierarchy
-
-IMPLICATION_CACHE_FILE = "chebi.cache"
+from chebai.preprocessing.datasets.chebi import _ChEBIDataExtractor, ChEBIOver100
 
 
 class ImplicationLoss(torch.nn.Module):
     def __init__(
-        self, path_to_chebi, path_to_label_names, base_loss: torch.nn.Module = None
+        self,
+        data_extractor: _ChEBIDataExtractor,
+        base_loss: torch.nn.Module = None,
+        tnorm: Literal["product", "lukasiewicz", "xu19"] = "product",
+        impl_loss_weight=0.1,  # weight of implication loss in relation to base_loss
+        pos_scalar=1,
+        pos_epsilon=0.01,
     ):
         super().__init__()
+        self.data_extractor = data_extractor
         self.base_loss = base_loss
-        label_names = _load_label_names(path_to_label_names)
-        hierarchy = _load_implications(path_to_chebi)
-        implication_filter = _build_implication_filter(label_names, hierarchy)
+        self.implication_cache_file = f"implications_{self.data_extractor.name}.cache"
+        self.label_names = _load_label_names(
+            os.path.join(data_extractor.raw_dir, "classes.txt")
+        )
+        self.hierarchy = self._load_implications(
+            os.path.join(data_extractor.raw_dir, "chebi.obo")
+        )
+        implication_filter = _build_implication_filter(self.label_names, self.hierarchy)
         self.implication_filter_l = implication_filter[:, 0]
         self.implication_filter_r = implication_filter[:, 1]
+        self.tnorm = tnorm
+        self.impl_weight = impl_loss_weight
+        self.pos_scalar = pos_scalar
+        self.eps = pos_epsilon
 
     def forward(self, input, target, **kwargs):
         nnl = kwargs.pop("non_null_labels", None)

@@ -36,40 +52,77 @@ def forward(self, input, target, **kwargs):
         r = pred[:, self.implication_filter_r]
         # implication_loss = torch.sqrt(torch.mean(torch.sum(l*(1-r), dim=-1), dim=0))
         implication_loss = self._calculate_implication_loss(l, r)
-        return base_loss + implication_loss
+
+        return (
+            base_loss + self.impl_weight * implication_loss,
+            base_loss,
+            implication_loss,
+        )
 
     def _calculate_implication_loss(self, l, r):
-        capped_difference = torch.relu(l - r)
+        assert not l.isnan().any()
+        assert not r.isnan().any()
+        if self.pos_scalar != 1:
+            l = (
+                torch.pow(l + self.eps, 1 / self.pos_scalar)
+                - math.pow(self.eps, 1 / self.pos_scalar)
+            ) / (
+                math.pow(1 + self.eps, 1 / self.pos_scalar)
+                - math.pow(self.eps, 1 / self.pos_scalar)
+            )
+            r = torch.pow(r, self.pos_scalar)
+        if self.tnorm == "product":
+            individual_loss = l * (1 - r)
+        elif self.tnorm == "xu19":
+            individual_loss = -torch.log(1 - l * (1 - r))
+        elif self.tnorm == "lukasiewicz":
+            individual_loss = torch.relu(l - r)
+        else:
+            raise NotImplementedError(f"Unknown tnorm {self.tnorm}")
+
         return torch.mean(
-            torch.sum(
-                (torch.softmax(capped_difference, dim=-1) * capped_difference), dim=-1
-            ),
+            torch.sum(individual_loss, dim=-1),
            dim=0,
        )
 
+    def _load_implications(self, path_to_chebi):
+        if os.path.isfile(self.implication_cache_file):
+            with open(self.implication_cache_file, "rb") as fin:
+                hierarchy = pickle.load(fin)
+        else:
+            hierarchy = self.data_extractor.extract_class_hierarchy(path_to_chebi)
+            with open(self.implication_cache_file, "wb") as fout:
+                pickle.dump(hierarchy, fout)
+        return hierarchy
+
 
 class DisjointLoss(ImplicationLoss):
     def __init__(
         self,
-        path_to_chebi,
-        path_to_label_names,
-        path_to_disjointedness,
+        path_to_disjointness,
+        data_extractor: _ChEBIDataExtractor,
         base_loss: torch.nn.Module = None,
+        disjoint_loss_weight=100,
+        **kwargs,
     ):
-        super().__init__(path_to_chebi, path_to_label_names, base_loss)
-        label_names = _load_label_names(path_to_label_names)
-        hierarchy = _load_implications(path_to_chebi)
+        super().__init__(data_extractor, base_loss, **kwargs)
         self.disjoint_filter_l, self.disjoint_filter_r = _build_disjointness_filter(
-            path_to_disjointedness, label_names, hierarchy
+            path_to_disjointness, self.label_names, self.hierarchy
        )
+        self.disjoint_weight = disjoint_loss_weight
 
     def forward(self, input, target, **kwargs):
-        loss = super().forward(input, target, **kwargs)
+        loss, base_loss, impl_loss = super().forward(input, target, **kwargs)
         pred = torch.sigmoid(input)
         l = pred[:, self.disjoint_filter_l]
         r = pred[:, self.disjoint_filter_r]
         disjointness_loss = self._calculate_implication_loss(l, 1 - r)
-        return loss + disjointness_loss
+        return (
+            loss + self.disjoint_weight * disjointness_loss,
+            base_loss,
+            impl_loss,
+            disjointness_loss,
+        )
 
 
 def _load_label_names(path_to_label_names):

@@ -78,17 +131,6 @@ def _load_label_names(path_to_label_names):
     return label_names
 
 
-def _load_implications(path_to_chebi, implication_cache=IMPLICATION_CACHE_FILE):
-    if os.path.isfile(implication_cache):
-        with open(implication_cache, "rb") as fin:
-            hierarchy = pickle.load(fin)
-    else:
-        hierarchy = extract_class_hierarchy(path_to_chebi)
-        with open(implication_cache, "wb") as fout:
-            pickle.dump(hierarchy, fout)
-    return hierarchy
-
-
 def _build_implication_filter(label_names, hierarchy):
     return torch.tensor(
         [

@@ -100,24 +142,33 @@ def _build_implication_filter(label_names, hierarchy):
     )
 
 
-def _build_disjointness_filter(path_to_disjointedness, label_names, hierarchy):
+def _build_disjointness_filter(path_to_disjointness, label_names, hierarchy):
     disjoints = set()
     label_dict = dict(map(reversed, enumerate(label_names)))
 
-    with open(path_to_disjointedness, "rt") as fin:
+    with open(path_to_disjointness, "rt") as fin:
         reader = csv.reader(fin)
         for l1_raw, r1_raw in reader:
             l1 = int(l1_raw)
             r1 = int(r1_raw)
+            if l1 == 36233 and r1 == 63353:
+                # ignore disaccharide-disaccharide derivative disjointness axiom
+                continue
             disjoints.update(
                 {
                     (label_dict[l2], label_dict[r2])
-                    for r2 in hierarchy.succ[r1]
+                    for r2 in list(hierarchy.succ[r1]) + [r1]
                     if r2 in label_names
-                    for l2 in hierarchy.succ[l1]
-                    if l2 in label_names and l2 < r2
+                    for l2 in list(hierarchy.succ[l1]) + [l1]
+                    if l2 in label_names
                 }
            )
 
     dis_filter = torch.tensor(list(disjoints))
     return dis_filter[:, 0], dis_filter[:, 1]
+
+
+if __name__ == "__main__":
+    loss = DisjointLoss(
+        os.path.join("data", "disjoint.csv"), ChEBIOver100(chebi_version=227)
+    )
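
For intuition about what ImplicationLoss penalises: for each (subclass, superclass) pair taken from the ChEBI hierarchy, a prediction is inconsistent when the subclass is scored high while its superclass is scored low. With the default product t-norm the per-pair penalty is l * (1 - r). A toy, self-contained illustration (the tensors are invented and no data extractor is involved):

import torch

# sigmoid scores for three (subclass, superclass) pairs
l = torch.tensor([0.9, 0.9, 0.1])  # subclass
r = torch.tensor([0.9, 0.1, 0.1])  # superclass

penalty = l * (1 - r)  # product t-norm, as in _calculate_implication_loss
print(penalty)  # approximately [0.09, 0.81, 0.09]; only the inconsistent middle pair
                # (subclass high, superclass low) is penalised heavily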

chebai/models/base.py

Lines changed: 58 additions & 0 deletions
@@ -13,6 +13,24 @@
 
 
 class ChebaiBaseNet(LightningModule):
+    """
+    Base class for Chebai neural network models inheriting from PyTorch Lightning's LightningModule.
+
+    Args:
+        criterion (torch.nn.Module, optional): The loss criterion for the model. Defaults to None.
+        out_dim (int, optional): The output dimension of the model. Defaults to None.
+        train_metrics (torch.nn.Module, optional): The metrics to be used during training. Defaults to None.
+        val_metrics (torch.nn.Module, optional): The metrics to be used during validation. Defaults to None.
+        test_metrics (torch.nn.Module, optional): The metrics to be used during testing. Defaults to None.
+        pass_loss_kwargs (bool, optional): Whether to pass loss kwargs to the criterion. Defaults to True.
+        optimizer_kwargs (typing.Dict, optional): Additional keyword arguments for the optimizer. Defaults to None.
+        **kwargs: Additional keyword arguments.
+
+    Attributes:
+        NAME (str): The name of the model.
+        LOSS (torch.nn.Module): The loss function used by the model.
+    """
+
     NAME = None
     LOSS = torch.nn.BCEWithLogitsLoss
 

@@ -85,6 +103,20 @@ def predict_step(self, batch, batch_idx, **kwargs):
         return self._execute(batch, batch_idx, self.test_metrics, prefix="", log=False)
 
     def _execute(self, batch, batch_idx, metrics, prefix="", log=True, sync_dist=False):
+        """
+        Executes the model on a batch of data and returns the model output and predictions.
+
+        Args:
+            batch (XYData): The input batch of data.
+            batch_idx (int): The index of the current batch.
+            metrics (dict): A dictionary of metrics to track.
+            prefix (str, optional): A prefix to add to the metric names. Defaults to "".
+            log (bool, optional): Whether to log the metrics. Defaults to True.
+            sync_dist (bool, optional): Whether to synchronize distributed training. Defaults to False.
+
+        Returns:
+            dict: A dictionary containing the processed data, labels, model_output, predictions, and loss (if applicable).
+        """
         assert isinstance(batch, XYData)
         batch = batch.to(self.device)
         data = self._process_batch(batch, batch_idx)

@@ -101,6 +133,21 @@ def _execute(self, batch, batch_idx, metrics, prefix="", log=True, sync_dist=False):
             if self.pass_loss_kwargs:
                 loss_kwargs = loss_kwargs_candidates
             loss = self.criterion(loss_data, loss_labels, **loss_kwargs)
+            if isinstance(loss, tuple):
+                loss_additional = loss[1:]
+                for i, loss_add in enumerate(loss_additional):
+                    self.log(
+                        f"{prefix}loss_{i}",
+                        loss_add if isinstance(loss_add, int) else loss_add.item(),
+                        batch_size=len(batch),
+                        on_step=True,
+                        on_epoch=False,
+                        prog_bar=False,
+                        logger=True,
+                        sync_dist=sync_dist,
+                    )
+                loss = loss[0]
+
             d["loss"] = loss
             self.log(
                 f"{prefix}loss",

@@ -119,6 +166,17 @@ def _execute(self, batch, batch_idx, metrics, prefix="", log=True, sync_dist=False):
         return d
 
     def _log_metrics(self, prefix, metrics, batch_size):
+        """
+        Logs the metrics for the given prefix.
+
+        Args:
+            prefix (str): The prefix to be added to the metric names.
+            metrics (dict): A dictionary containing the metrics to be logged.
+            batch_size (int): The batch size used for logging.
+
+        Returns:
+            None
+        """
         # don't use sync_dist=True if the metric is a torchmetrics-metric
         # (see https://github.com/Lightning-AI/pytorch-lightning/discussions/6501#discussioncomment-569757)
         for metric_name, metric in metrics.items():
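
The new branch in _execute supports criteria that return a tuple (total_loss, *components), as DisjointLoss now does: the extra components are logged as "{prefix}loss_0", "{prefix}loss_1", and so on, and only the first element is kept for backpropagation. A minimal sketch of a criterion that follows this contract (illustrative only, not part of the repository):

import torch

class ToyCompositeLoss(torch.nn.Module):
    """Returns a (total, base, extra) tuple, analogous to DisjointLoss.forward."""

    def __init__(self):
        super().__init__()
        self.bce = torch.nn.BCEWithLogitsLoss()

    def forward(self, input, target, **kwargs):
        base = self.bce(input, target)
        extra = torch.sigmoid(input).mean()  # stand-in for an implication-style penalty
        return base + 0.1 * extra, base, extra

# With this criterion, _execute would backpropagate base + 0.1 * extra and log
# base as "{prefix}loss_0" and extra as "{prefix}loss_1".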
