Skip to content

Commit 0584345

Browse files
authored
Merge branch 'dev' into feature/testing_framework
2 parents cd03023 + 7e5f801 commit 0584345

File tree

4 files changed

+16
-12
lines changed

4 files changed

+16
-12
lines changed

chebai/loss/semantic.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22
import math
33
import os
44
import pickle
5-
from typing import Literal
65

76
import torch
87

8+
from typing import Literal, Union
9+
910
from chebai.loss.bce_weighted import BCEWeighted
1011
from chebai.preprocessing.datasets.chebi import ChEBIOver100, _ChEBIDataExtractor
1112
from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed
@@ -14,7 +15,7 @@
1415
class ImplicationLoss(torch.nn.Module):
1516
def __init__(
1617
self,
17-
data_extractor: _ChEBIDataExtractor | LabeledUnlabeledMixed,
18+
data_extractor: Union[_ChEBIDataExtractor, LabeledUnlabeledMixed],
1819
base_loss: torch.nn.Module = None,
1920
tnorm: Literal["product", "lukasiewicz", "xu19"] = "product",
2021
impl_loss_weight=0.1, # weight of implication loss in relation to base_loss
@@ -114,7 +115,7 @@ class DisjointLoss(ImplicationLoss):
114115
def __init__(
115116
self,
116117
path_to_disjointness,
117-
data_extractor: _ChEBIDataExtractor | LabeledUnlabeledMixed,
118+
data_extractor: Union[_ChEBIDataExtractor, LabeledUnlabeledMixed],
118119
base_loss: torch.nn.Module = None,
119120
disjoint_loss_weight=100,
120121
**kwargs,

chebai/result/analyse_sem.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,11 @@
77
import pandas as pd
88
import torch
99
import wandb
10+
1011
from torchmetrics.functional.classification import multilabel_auroc, multilabel_f1_score
12+
import gc
13+
from typing import List, Union
14+
1115
from utils import *
1216

1317
from chebai.loss.semantic import DisjointLoss
@@ -245,7 +249,7 @@ def analyse_run(
245249
labeled_data_cls=ChEBIOver100, # use labels from this dataset for violations
246250
chebi_version=231,
247251
results_path=os.path.join("_semantic", "eval_results.csv"),
248-
violation_metrics: [str | list[callable]] = "all",
252+
violation_metrics: Union[str, List[callable]] = "all",
249253
verbose_violation_output=False,
250254
):
251255
"""Calculates all semantic metrics for given predictions (and supervised metrics if labels are provided),

chebai/trainer/CustomTrainer.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def predict_from_file(
6464
smiles_strings = [inp.strip() for inp in input.readlines()]
6565
loaded_model.eval()
6666
predictions = self._predict_smiles(loaded_model, smiles_strings)
67-
predictions_df = pd.DataFrame(predictions.detach().numpy())
67+
predictions_df = pd.DataFrame(predictions.detach().cpu().numpy())
6868
if classes_path is not None:
6969
with open(classes_path, "r") as f:
7070
predictions_df.columns = [cls.strip() for cls in f.readlines()]
@@ -74,7 +74,10 @@ def predict_from_file(
7474
def _predict_smiles(self, model: LightningModule, smiles: List[str]):
7575
reader = ChemDataReader()
7676
parsed_smiles = [reader._read_data(s) for s in smiles]
77-
x = pad_sequence([torch.tensor(a) for a in parsed_smiles], batch_first=True)
77+
x = pad_sequence(
78+
[torch.tensor(a, device=model.device) for a in parsed_smiles],
79+
batch_first=True,
80+
)
7881
cls_tokens = (
7982
torch.ones(x.shape[0], dtype=torch.int, device=model.device).unsqueeze(-1)
8083
* CLS_TOKEN

setup.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,7 @@
1414
author_email="martin.glauer@ovgu.de",
1515
description="",
1616
zip_safe=False,
17-
# `|` operator used for type hint in chebai/loss/semantic.py is only supported for python_version >= 3.10.
18-
# https://stackoverflow.com/questions/76712720/typeerror-unsupported-operand-types-for-type-and-nonetype
19-
# python_requires="<=3.11.8",
20-
python_requires=">=3.10.0, <3.11.8",
17+
python_requires="<=3.11.8",
2118
install_requires=[
2219
"certifi",
2320
"idna",
@@ -51,8 +48,7 @@
5148
"iterative-stratification",
5249
"wandb",
5350
"chardet",
54-
# --- commented below due to strange dependency error while setting up new env
55-
# "yaml",`
51+
"pyyaml",
5652
"torchmetrics",
5753
],
5854
extras_require={"dev": ["black", "isort", "pre-commit"]},

0 commit comments

Comments
 (0)