Commit 836af9e

Merge pull request #130 from schnamo/dev
New regression and classification datasets for ontology pre-training
2 parents: 8c5ebcd + 01f9f5f

66 files changed: +3057, -74 lines

README.md

Lines changed: 26 additions & 5 deletions
@@ -3,14 +3,30 @@
 ChEBai is a deep learning library designed for the integration of deep learning methods with chemical ontologies, particularly ChEBI.
 The library emphasizes the incorporation of the semantic qualities of the ontology into the learning process.
 
-## Installation
+## News
+
+We now support regression tasks!
+
+## Note for developers
 
-You can install ChEBai via pip:
+If you have used ChEBai before PR #39, the file structure in which your ChEBI data is saved has changed. This means that
+datasets will be freshly generated. The data itself, however, is the same. If you want to keep the old data (including the old
+splits), you can use a migration script. It copies the old data to the new location for a specific ChEBI class
+(including the ChEBI version and other parameters). The script can be called by specifying the data module from a config file:
 ```
-pip install chebai
+python chebai/preprocessing/migration/chebi_data_migration.py migrate --datamodule=[path-to-data-config]
+```
+or by specifying the class name (e.g. `ChEBIOver50`) and its arguments separately:
 ```
+python chebai/preprocessing/migration/chebi_data_migration.py migrate --class_name=[data-class] [--chebi_version=[version]]
+```
+By default, the new dataset will generate random data splits (with a given seed).
+To reuse a fixed data split, provide the path to the CSV file generated during the migration:
+`--data.init_args.splits_file_path=[path-to-processed_data]/splits.csv`
 
-Alternatively, you can get the latest development version directly from GitHub:
+## Installation
+
+To install ChEBai, follow these steps:
 
 1. Clone the repository:
 ```
@@ -63,11 +79,16 @@ A command with additional options may look like this:
 python3 -m chebai fit --trainer=configs/training/default_trainer.yml --model=configs/model/electra.yml --model.train_metrics=configs/metrics/micro-macro-f1.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --model.val_metrics=configs/metrics/micro-macro-f1.yml --model.pretrained_checkpoint=electra_pretrained.ckpt --model.load_prefix=generator. --data=configs/data/chebi50.yml --model.criterion=configs/loss/bce.yml --data.init_args.batch_size=10 --trainer.logger.init_args.name=chebi50_bce_unweighted --data.init_args.num_workers=9 --model.pass_loss_kwargs=false --data.init_args.chebi_version=231 --data.init_args.data_limit=1000
 ```
 
-### Fine-tuning for Toxicity prediction
+### Fine-tuning for classification tasks, e.g. toxicity prediction
 ```
 python -m chebai fit --config=[path-to-your-tox21-config] --trainer.callbacks=configs/training/default_callbacks.yml --model.pretrained_checkpoint=[path-to-pretrained-model]
 ```
 
+### Fine-tuning for regression tasks, e.g. solubility prediction
+```
+python -m chebai fit --config=[path-to-your-esol-config] --trainer.callbacks=configs/training/solCur_callbacks.yml --model.pretrained_checkpoint=[path-to-pretrained-model]
+```
+
 ### Predicting classes given SMILES strings
 ```
 python3 -m chebai predict_from_file --model=[path-to-model-config] --checkpoint_path=[path-to-model] --input_path=[path-to-file-containing-smiles] [--classes_path=[path-to-classes-file]] [--save_to=[path-to-output]]

chebai/cli.py

Lines changed: 26 additions & 1 deletion
@@ -60,15 +60,40 @@ def call_data_methods(data: Type[XYBaseDataModule]):
         )
 
         for kind in ("train", "val", "test"):
-            for average in ("micro-f1", "macro-f1", "balanced-accuracy"):
+            for average in (
+                "micro-f1",
+                "macro-f1",
+                "balanced-accuracy",
+                "roc-auc",
+                "f1",
+                "mse",
+                "rmse",
+                "r2",
+            ):
+                # When using lightning > 2.5.1, only the metrics that are actually used may be linked; uncomment the line matching your task instead:
+                # for average in ("mse", "rmse", "r2"):  # for regression
+                # for average in ("f1", "roc-auc"):  # for binary classification
+                # for average in ("micro-f1", "macro-f1", "roc-auc"):  # for multilabel classification
+                # for average in ("micro-f1", "macro-f1", "balanced-accuracy", "roc-auc"):  # for multilabel classification using balanced accuracy
                 parser.link_arguments(
                     "data.num_of_labels",
                     f"model.init_args.{kind}_metrics.init_args.metrics.{average}.init_args.num_labels",
                     apply_on="instantiate",
                 )
+
         parser.link_arguments(
             "data.num_of_labels", "trainer.callbacks.init_args.num_labels"
         )
+        # parser.link_arguments(
+        #     "model.init_args.out_dim", "trainer.callbacks.init_args.num_labels"
+        # )
+        # parser.link_arguments(
+        #     "data", "model.init_args.criterion.init_args.data_extractor"
+        # )
+        # parser.link_arguments(
+        #     "data.init_args.chebi_version",
+        #     "model.init_args.criterion.init_args.data_extractor.init_args.chebi_version",
+        # )
 
         parser.link_arguments(
             "data", "model.init_args.criterion.init_args.data_extractor"

chebai/loss/focal_loss.py

Lines changed: 152 additions & 0 deletions
@@ -0,0 +1,152 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+# from https://github.com/itakurah/Focal-loss-PyTorch
+
+
+class FocalLoss(nn.Module):
+    def __init__(
+        self,
+        gamma=2,
+        alpha=None,
+        reduction="mean",
+        task_type="binary",
+        num_classes=None,
+    ):
+        """
+        Unified Focal Loss class for binary, multi-class, and multi-label classification tasks.
+        :param gamma: Focusing parameter, controls the strength of the modulating factor (1 - p_t)^gamma
+        :param alpha: Balancing factor, can be a scalar or a tensor for class-wise weights. If None, no class balancing is used.
+        :param reduction: Specifies the reduction method: 'none' | 'mean' | 'sum'
+        :param task_type: Specifies the type of task: 'binary', 'multi-class', or 'multi-label'
+        :param num_classes: Number of classes (only required for multi-class classification)
+        """
+        super(FocalLoss, self).__init__()
+        self.gamma = gamma
+        self.alpha = alpha
+        self.reduction = reduction
+        self.task_type = task_type
+        self.num_classes = num_classes
+
+        # Handle alpha for class balancing in multi-class tasks
+        if (
+            task_type == "multi-class"
+            and alpha is not None
+            and isinstance(alpha, (list, torch.Tensor))
+        ):
+            assert (
+                num_classes is not None
+            ), "num_classes must be specified for multi-class classification"
+            if isinstance(alpha, list):
+                self.alpha = torch.Tensor(alpha)
+            else:
+                self.alpha = alpha
+
+    def forward(self, inputs, targets):
+        """
+        Forward pass to compute the Focal Loss based on the specified task type.
+        :param inputs: Predictions (logits) from the model.
+                       Shape:
+                       - binary/multi-label: (batch_size, num_classes)
+                       - multi-class: (batch_size, num_classes)
+        :param targets: Ground truth labels.
+                        Shape:
+                        - binary: (batch_size,)
+                        - multi-label: (batch_size, num_classes)
+                        - multi-class: (batch_size,)
+        """
+        if self.task_type == "binary":
+            return self.binary_focal_loss(inputs, targets)
+        elif self.task_type == "multi-class":
+            return self.multi_class_focal_loss(inputs, targets)
+        elif self.task_type == "multi-label":
+            return self.multi_label_focal_loss(inputs, targets)
+        else:
+            raise ValueError(
+                f"Unsupported task_type '{self.task_type}'. Use 'binary', 'multi-class', or 'multi-label'."
+            )
+
+    def binary_focal_loss(self, inputs, targets):
+        """Focal loss for binary classification."""
+        probs = torch.sigmoid(inputs)
+        targets = targets.float()
+
+        # Compute binary cross entropy
+        bce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
+
+        # Compute focal weight
+        p_t = probs * targets + (1 - probs) * (1 - targets)
+        focal_weight = (1 - p_t) ** self.gamma
+
+        # Apply alpha if provided
+        if self.alpha is not None:
+            alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
+            bce_loss = alpha_t * bce_loss
+
+        # Apply focal loss weighting
+        loss = focal_weight * bce_loss
+
+        if self.reduction == "mean":
+            return loss.mean()
+        elif self.reduction == "sum":
+            return loss.sum()
+        return loss
+
+    def multi_class_focal_loss(self, inputs, targets):
+        """Focal loss for multi-class classification."""
+        if self.alpha is not None:
+            alpha = self.alpha.to(inputs.device)
+
+        # Convert logits to probabilities with softmax
+        probs = F.softmax(inputs, dim=1)
+
+        # One-hot encode the targets
+        targets_one_hot = F.one_hot(targets, num_classes=self.num_classes).float()
+
+        # Compute cross-entropy for each class
+        ce_loss = -targets_one_hot * torch.log(probs)
+
+        # Compute focal weight
+        p_t = torch.sum(probs * targets_one_hot, dim=1)  # p_t for each sample
+        focal_weight = (1 - p_t) ** self.gamma
+
+        # Apply alpha if provided (per-class weighting)
+        if self.alpha is not None:
+            alpha_t = alpha.gather(0, targets)
+            ce_loss = alpha_t.unsqueeze(1) * ce_loss
+
+        # Apply focal loss weight
+        loss = focal_weight.unsqueeze(1) * ce_loss
+
+        if self.reduction == "mean":
+            return loss.mean()
+        elif self.reduction == "sum":
+            return loss.sum()
+        return loss
+
+    def multi_label_focal_loss(self, inputs, targets):
+        """Focal loss for multi-label classification."""
+        probs = torch.sigmoid(inputs)
+
+        # Compute binary cross entropy
+        bce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
+
+        # Compute focal weight
+        p_t = probs * targets + (1 - probs) * (1 - targets)
+        focal_weight = (1 - p_t) ** self.gamma
+
+        # Apply alpha if provided
+        if self.alpha is not None:
+            alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
+            bce_loss = alpha_t * bce_loss
+
+        # Apply focal loss weight
+        loss = focal_weight * bce_loss
+
+        if self.reduction == "mean":
+            return loss.mean()
+        elif self.reduction == "sum":
+            return loss.sum()
+        return loss
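
As a quick orientation, a minimal usage sketch for the new `FocalLoss` module; the class and its constructor arguments come from the file above, while the tensors and hyperparameter values are made up for illustration:

```
import torch

from chebai.loss.focal_loss import FocalLoss

# Illustrative multi-label example; shapes and values are arbitrary.
logits = torch.randn(4, 3, requires_grad=True)  # (batch_size, num_labels) raw logits
labels = torch.randint(0, 2, (4, 3)).float()    # multi-label targets in {0, 1}

criterion = FocalLoss(gamma=2, alpha=0.25, reduction="mean", task_type="multi-label")
loss = criterion(logits, labels)
loss.backward()  # behaves like any other nn.Module loss
print(loss.item())
```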

chebai/loss/semantic.py

Lines changed: 2 additions & 2 deletions
@@ -2,7 +2,7 @@
 import math
 import os
 import pickle
-from typing import TYPE_CHECKING, List, Literal, Union
+from typing import TYPE_CHECKING, List, Literal, Union, Tuple
 
 import torch
 
@@ -62,7 +62,7 @@ def __init__(
         pos_epsilon: float = 0.01,
         multiply_by_softmax: bool = False,
         use_sigmoidal_implication: bool = False,
-        weight_epoch_dependent: Union[bool | tuple[int, int]] = False,
+        weight_epoch_dependent: Union[bool, Tuple[int, int]] = False,
         start_at_epoch: int = 0,
         violations_per_cls_aggregator: Literal[
             "sum", "max", "mean", "log-sum", "log-max", "log-mean"

chebai/models/base.py

Lines changed: 2 additions & 2 deletions
@@ -42,7 +42,8 @@ def __init__(
         exclude_hyperparameter_logging: Optional[Iterable[str]] = None,
         **kwargs,
     ):
-        super().__init__()
+        super().__init__(**kwargs)
+        # super().__init__()
         if exclude_hyperparameter_logging is None:
             exclude_hyperparameter_logging = tuple()
         self.criterion = criterion
@@ -277,7 +278,6 @@ def _execute(
         loss_kwargs = dict()
         if self.pass_loss_kwargs:
             loss_kwargs = loss_kwargs_candidates
-            loss_kwargs["current_epoch"] = self.trainer.current_epoch
         loss = self.criterion(loss_data, loss_labels, **loss_kwargs)
         if isinstance(loss, tuple):
             unnamed_loss_index = 1

chebai/models/electra.py

Lines changed: 32 additions & 2 deletions
@@ -19,6 +19,7 @@
 
 logging.getLogger("pysmiles").setLevel(logging.CRITICAL)
 
+
 from chebai.loss.semantic import DisjointLoss as ElectraChEBIDisjointLoss  # noqa
 
 
@@ -40,6 +41,7 @@ class ElectraPre(ChebaiBaseNet):
 
     def __init__(self, config: Dict[str, Any] = None, **kwargs: Any):
         super().__init__(config=config, **kwargs)
+
         self.generator_config = ElectraConfig(**config["generator"])
         self.generator = ElectraForMaskedLM(self.generator_config)
         self.discriminator_config = ElectraConfig(**config["discriminator"])
@@ -224,6 +226,7 @@ def __init__(
         config: Optional[Dict[str, Any]] = None,
         pretrained_checkpoint: Optional[str] = None,
         load_prefix: Optional[str] = None,
+        model_type="classification",
         freeze_electra: bool = False,
         **kwargs: Any,
     ):
@@ -237,6 +240,8 @@ def __init__(
         config["num_labels"] = self.out_dim
         self.config = ElectraConfig(**config, output_attentions=True)
         self.word_dropout = nn.Dropout(config.get("word_dropout", 0))
+        self.model_type = model_type
+        self.pass_loss_kwargs = True
 
         in_d = self.config.hidden_size
         self.output = nn.Sequential(
@@ -285,9 +290,16 @@ def _process_for_loss(
             tuple: A tuple containing the processed model output, labels, and loss arguments.
         """
         kwargs_copy = dict(loss_kwargs)
+        output = model_output["logits"]
         if labels is not None:
             labels = labels.float()
-        return model_output["logits"], labels, kwargs_copy
+            if "missing_labels" in kwargs_copy:
+                missing_labels = kwargs_copy.pop("missing_labels")
+                output = output * (~missing_labels).int() - 10000 * missing_labels.int()
+                labels = labels * (~missing_labels).int()
+            if self.model_type == "classification":
+                assert ((labels <= torch.tensor(1.0)) & (labels >= torch.tensor(0.0))).all()
+        return output, labels, kwargs_copy
 
     def _get_prediction_and_labels(
         self, data: Dict[str, Any], labels: Tensor, model_output: Dict[str, Tensor]
@@ -308,7 +320,25 @@ def _get_prediction_and_labels(
         if "non_null_labels" in loss_kwargs:
             n = loss_kwargs["non_null_labels"]
             d = d[n]
-        return torch.sigmoid(d), labels.int() if labels is not None else None
+        if self.model_type == "classification":
+            # print(self.model_type, ' in electra 324')
+            # for multiclass, use softmax here instead of sigmoid
+            d = torch.sigmoid(
+                d
+            )  # changing this made a difference for the roc-auc but not the f1, why?
+            if "missing_labels" in loss_kwargs:
+                missing_labels = loss_kwargs["missing_labels"]
+                d = d * (~missing_labels).int().to(
+                    device=d.device
+                )  # we set the prob of missing labels to 0
+                labels = labels * (~missing_labels).int().to(
+                    device=d.device
+                )  # we set the labels of missing labels to 0
+            return d, labels.int() if labels is not None else None
+        elif self.model_type == "regression":
+            return d, labels
+        else:
+            raise ValueError("Please specify a valid model type in your model config.")
 
     def forward(self, data: Dict[str, Tensor], **kwargs: Any) -> Dict[str, Any]:
         """

chebai/preprocessing/bin/smiles_token/tokens.txt

Lines changed: 2 additions & 0 deletions
@@ -4371,3 +4371,5 @@ b
 [90Sr]
 [32PH2]
 [CaH2]
+[NH3]
+[OH2]
