Skip to content

Commit 1bd6a0b

Browse files
committed
dynamic imports for rest of predictors
1 parent 68ea577 commit 1bd6a0b

File tree

6 files changed

+62
-36
lines changed

6 files changed

+62
-36
lines changed
Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from .base_predictor import BasePredictor
2-
from .chemlog_predictor import ChemlogPeptidesPredictor, ChemlogExtraPredictor
2+
from .c3p_predictor import C3PPredictor
3+
from .chebi_lookup import ChEBILookupPredictor
4+
from .chemlog_predictor import ChemlogExtraPredictor, ChemlogPeptidesPredictor
35
from .electra_predictor import ElectraPredictor
46
from .gnn_predictor import ResGatedPredictor
5-
from .chebi_lookup import ChEBILookupPredictor
67

78
__all__ = [
89
"BasePredictor",
@@ -11,4 +12,5 @@
1112
"ResGatedPredictor",
1213
"ChEBILookupPredictor",
1314
"ChemlogExtraPredictor",
15+
"C3PPredictor",
1416
]

chebifier/prediction_models/c3p_predictor.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
from functools import lru_cache
2-
from typing import Optional, List
32
from pathlib import Path
4-
5-
from c3p import classifier as c3p_classifier
3+
from typing import List, Optional
64

75
from chebifier.prediction_models import BasePredictor
86

@@ -26,6 +24,8 @@ def __init__(
2624

2725
@lru_cache(maxsize=100)
2826
def predict_smiles_tuple(self, smiles_list: tuple[str]) -> list:
27+
from c3p import classifier as c3p_classifier
28+
2929
result_list = c3p_classifier.classify(
3030
list(smiles_list),
3131
self.program_directory,
@@ -50,6 +50,8 @@ def explain_smiles(self, smiles):
5050
C3P provides natural language explanations for each prediction (positive or negative). Since there are more
5151
than 300 classes, only take the positive ones.
5252
"""
53+
from c3p import classifier as c3p_classifier
54+
5355
highlights = []
5456
result_list = c3p_classifier.classify(
5557
[smiles], self.program_directory, self.chemical_classes, strict=False

chebifier/prediction_models/chebi_lookup.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
1+
import json
2+
import os
13
from functools import lru_cache
24
from typing import Optional
35

4-
from chebifier.prediction_models import BasePredictor
5-
import os
6-
import networkx as nx
76
from rdkit import Chem
8-
import json
97

8+
from chebifier.prediction_models import BasePredictor
109

11-
class ChEBILookupPredictor(BasePredictor):
1210

11+
class ChEBILookupPredictor(BasePredictor):
1312
def __init__(
1413
self,
1514
model_name: str,
@@ -49,6 +48,8 @@ def get_smiles_lookup(self):
4948
return smiles_lookup
5049

5150
def build_smiles_lookup(self):
51+
import networkx as nx
52+
5253
smiles_lookup = dict()
5354
for chebi_id, smiles in nx.get_node_attributes(
5455
self.chebi_graph, "smiles"
@@ -152,7 +153,8 @@ def explain_smiles(self, smiles: str) -> dict:
152153
# Example usage
153154
smiles_list = [
154155
"CCO",
155-
"C1=CC=CC=C1" "*C(=O)OC[C@H](COP(=O)([O-])OCC[N+](C)(C)C)OC(*)=O",
156+
"C1=CC=CC=C1",
157+
"*C(=O)OC[C@H](COP(=O)([O-])OCC[N+](C)(C)C)OC(*)=O",
156158
] # SMILES with 251 matches in ChEBI
157159
predictions = predictor.predict_smiles_list(smiles_list)
158160
print(predictions)

chebifier/prediction_models/chemlog_predictor.py

Lines changed: 33 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,7 @@
1+
from functools import lru_cache
12
from typing import Optional
23

34
import tqdm
4-
from chemlog.alg_classification.charge_classifier import get_charge_category
5-
from chemlog.alg_classification.peptide_size_classifier import get_n_amino_acid_residues
6-
from chemlog.alg_classification.proteinogenics_classifier import (
7-
get_proteinogenic_amino_acids,
8-
)
9-
from chemlog.alg_classification.substructure_classifier import (
10-
is_diketopiperazine,
11-
is_emericellamide,
12-
)
13-
from chemlog.cli import CLASSIFIERS, _smiles_to_mol, strategy_call
14-
from chemlog_extra.alg_classification.by_element_classification import (
15-
XMolecularEntityClassifier,
16-
OrganoXCompoundClassifier,
17-
)
18-
from functools import lru_cache
195

206
from .base_predictor import BasePredictor
217

@@ -47,15 +33,14 @@
4733

4834

4935
class ChemlogExtraPredictor(BasePredictor):
50-
51-
CHEMLOG_CLASSIFIER = None
52-
5336
def __init__(self, model_name: str, **kwargs):
5437
super().__init__(model_name, **kwargs)
5538
self.chebi_graph = kwargs.get("chebi_graph", None)
56-
self.classifier = self.CHEMLOG_CLASSIFIER()
39+
self.classifier = None
5740

5841
def predict_smiles_tuple(self, smiles_list: tuple[str]) -> list:
42+
from chemlog.cli import _smiles_to_mol
43+
5944
mol_list = [_smiles_to_mol(smiles) for smiles in smiles_list]
6045
res = self.classifier.classify(mol_list)
6146
if self.chebi_graph is not None:
@@ -72,17 +57,29 @@ def predict_smiles_tuple(self, smiles_list: tuple[str]) -> list:
7257

7358

7459
class ChemlogXMolecularEntityPredictor(ChemlogExtraPredictor):
60+
def __init__(self, model_name: str, **kwargs):
61+
from chemlog_extra.alg_classification.by_element_classification import (
62+
XMolecularEntityClassifier,
63+
)
7564

76-
CHEMLOG_CLASSIFIER = XMolecularEntityClassifier
65+
super().__init__(model_name, **kwargs)
66+
self.classifier = XMolecularEntityClassifier()
7767

7868

7969
class ChemlogOrganoXCompoundPredictor(ChemlogExtraPredictor):
70+
def __init__(self, model_name: str, **kwargs):
71+
from chemlog_extra.alg_classification.by_element_classification import (
72+
OrganoXCompoundClassifier,
73+
)
8074

81-
CHEMLOG_CLASSIFIER = OrganoXCompoundClassifier
75+
super().__init__(model_name, **kwargs)
76+
self.classifier = OrganoXCompoundClassifier()
8277

8378

8479
class ChemlogPeptidesPredictor(BasePredictor):
8580
def __init__(self, model_name: str, **kwargs):
81+
from chemlog.cli import CLASSIFIERS
82+
8683
super().__init__(model_name, **kwargs)
8784
self.strategy = "algo"
8885
self.chebi_graph = kwargs.get("chebi_graph", None)
@@ -99,6 +96,8 @@ def __init__(self, model_name: str, **kwargs):
9996

10097
@lru_cache(maxsize=100)
10198
def predict_smiles(self, smiles: str) -> Optional[dict]:
99+
from chemlog.cli import _smiles_to_mol, strategy_call
100+
102101
mol = _smiles_to_mol(smiles)
103102
if mol is None:
104103
return None
@@ -134,6 +133,19 @@ def predict_smiles_tuple(self, smiles_list: tuple[str]) -> list:
134133

135134
def get_chemlog_result_info(self, smiles):
136135
"""Get classification for single molecule with additional information."""
136+
from chemlog.alg_classification.charge_classifier import get_charge_category
137+
from chemlog.alg_classification.peptide_size_classifier import (
138+
get_n_amino_acid_residues,
139+
)
140+
from chemlog.alg_classification.proteinogenics_classifier import (
141+
get_proteinogenic_amino_acids,
142+
)
143+
from chemlog.alg_classification.substructure_classifier import (
144+
is_diketopiperazine,
145+
is_emericellamide,
146+
)
147+
from chemlog.cli import _smiles_to_mol
148+
137149
mol = _smiles_to_mol(smiles)
138150
if mol is None or not smiles:
139151
return {"error": "Failed to parse SMILES"}

chebifier/prediction_models/electra_predictor.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
1+
from typing import TYPE_CHECKING
2+
13
import numpy as np
24

35
from .nn_predictor import NNPredictor
46

7+
if TYPE_CHECKING:
8+
from chebai.models.electra import Electra
9+
510

611
def build_graph_from_attention(att, node_labels, token_labels, threshold=0.0):
712
n_nodes = len(node_labels)
@@ -40,7 +45,7 @@ def __init__(self, model_name: str, ckpt_path: str, **kwargs):
4045
super().__init__(model_name, ckpt_path, reader_cls=ChemDataReader, **kwargs)
4146
print(f"Initialised Electra model {self.model_name} (device: {self.device})")
4247

43-
def init_model(self, ckpt_path: str, **kwargs) -> "Electra": # noqa: F821
48+
def init_model(self, ckpt_path: str, **kwargs) -> Electra:
4449
from chebai.models.electra import Electra
4550

4651
model = Electra.load_from_checkpoint(

chebifier/prediction_models/gnn_predictor.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
1+
from typing import TYPE_CHECKING
2+
13
from .nn_predictor import NNPredictor
24

5+
if TYPE_CHECKING:
6+
from chebai_graph.models.graph import ResGatedGraphConvNetGraphPred
7+
38

49
class ResGatedPredictor(NNPredictor):
510
def __init__(self, model_name: str, ckpt_path: str, molecular_properties, **kwargs):
@@ -28,9 +33,7 @@ def load_class(self, class_path: str):
2833
module = __import__(module_path, fromlist=[class_name])
2934
return getattr(module, class_name)
3035

31-
def init_model(
32-
self, ckpt_path: str, **kwargs
33-
) -> "ResGatedGraphConvNetGraphPred": # noqa: F821
36+
def init_model(self, ckpt_path: str, **kwargs) -> ResGatedGraphConvNetGraphPred:
3437
import torch
3538
from chebai_graph.models.graph import ResGatedGraphConvNetGraphPred
3639

0 commit comments

Comments
 (0)