File tree Expand file tree Collapse file tree 4 files changed +16
-12
lines changed
Expand file tree Collapse file tree 4 files changed +16
-12
lines changed Original file line number Diff line number Diff line change 22import math
33import os
44import pickle
5- from typing import Literal
65
76import torch
87
8+ from typing import Literal , Union
9+
910from chebai .loss .bce_weighted import BCEWeighted
1011from chebai .preprocessing .datasets .chebi import ChEBIOver100 , _ChEBIDataExtractor
1112from chebai .preprocessing .datasets .pubchem import LabeledUnlabeledMixed
1415class ImplicationLoss (torch .nn .Module ):
1516 def __init__ (
1617 self ,
17- data_extractor : _ChEBIDataExtractor | LabeledUnlabeledMixed ,
18+ data_extractor : Union [ _ChEBIDataExtractor , LabeledUnlabeledMixed ] ,
1819 base_loss : torch .nn .Module = None ,
1920 tnorm : Literal ["product" , "lukasiewicz" , "xu19" ] = "product" ,
2021 impl_loss_weight = 0.1 , # weight of implication loss in relation to base_loss
@@ -114,7 +115,7 @@ class DisjointLoss(ImplicationLoss):
114115 def __init__ (
115116 self ,
116117 path_to_disjointness ,
117- data_extractor : _ChEBIDataExtractor | LabeledUnlabeledMixed ,
118+ data_extractor : Union [ _ChEBIDataExtractor , LabeledUnlabeledMixed ] ,
118119 base_loss : torch .nn .Module = None ,
119120 disjoint_loss_weight = 100 ,
120121 ** kwargs ,
Original file line number Diff line number Diff line change 77import pandas as pd
88import torch
99import wandb
10+
1011from torchmetrics .functional .classification import multilabel_auroc , multilabel_f1_score
12+ import gc
13+ from typing import List , Union
14+
1115from utils import *
1216
1317from chebai .loss .semantic import DisjointLoss
@@ -245,7 +249,7 @@ def analyse_run(
245249 labeled_data_cls = ChEBIOver100 , # use labels from this dataset for violations
246250 chebi_version = 231 ,
247251 results_path = os .path .join ("_semantic" , "eval_results.csv" ),
248- violation_metrics : [str | list [callable ]] = "all" ,
252+ violation_metrics : Union [str , List [callable ]] = "all" ,
249253 verbose_violation_output = False ,
250254):
251255 """Calculates all semantic metrics for given predictions (and supervised metrics if labels are provided),
Original file line number Diff line number Diff line change @@ -64,7 +64,7 @@ def predict_from_file(
6464 smiles_strings = [inp .strip () for inp in input .readlines ()]
6565 loaded_model .eval ()
6666 predictions = self ._predict_smiles (loaded_model , smiles_strings )
67- predictions_df = pd .DataFrame (predictions .detach ().numpy ())
67+ predictions_df = pd .DataFrame (predictions .detach ().cpu (). numpy ())
6868 if classes_path is not None :
6969 with open (classes_path , "r" ) as f :
7070 predictions_df .columns = [cls .strip () for cls in f .readlines ()]
@@ -74,7 +74,10 @@ def predict_from_file(
7474 def _predict_smiles (self , model : LightningModule , smiles : List [str ]):
7575 reader = ChemDataReader ()
7676 parsed_smiles = [reader ._read_data (s ) for s in smiles ]
77- x = pad_sequence ([torch .tensor (a ) for a in parsed_smiles ], batch_first = True )
77+ x = pad_sequence (
78+ [torch .tensor (a , device = model .device ) for a in parsed_smiles ],
79+ batch_first = True ,
80+ )
7881 cls_tokens = (
7982 torch .ones (x .shape [0 ], dtype = torch .int , device = model .device ).unsqueeze (- 1 )
8083 * CLS_TOKEN
Original file line number Diff line number Diff line change 1414 author_email = "martin.glauer@ovgu.de" ,
1515 description = "" ,
1616 zip_safe = False ,
17- # `|` operator used for type hint in chebai/loss/semantic.py is only supported for python_version >= 3.10.
18- # https://stackoverflow.com/questions/76712720/typeerror-unsupported-operand-types-for-type-and-nonetype
19- # python_requires="<=3.11.8",
20- python_requires = ">=3.10.0, <3.11.8" ,
17+ python_requires = "<=3.11.8" ,
2118 install_requires = [
2219 "certifi" ,
2320 "idna" ,
5148 "iterative-stratification" ,
5249 "wandb" ,
5350 "chardet" ,
54- # --- commented below due to strange dependency error while setting up new env
55- # "yaml",`
51+ "pyyaml" ,
5652 "torchmetrics" ,
5753 ],
5854 extras_require = {"dev" : ["black" , "isort" , "pre-commit" ]},
You can’t perform that action at this time.
0 commit comments