Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions chebifier/model_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,14 @@
ChEBILookupPredictor,
ChemlogPeptidesPredictor,
ElectraPredictor,
ResGatedPredictor,
GNNPredictor,
)
from chebifier.prediction_models.c3p_predictor import C3PPredictor
from chebifier.prediction_models.chemlog_predictor import (
ChemlogAllPredictor,
ChemlogOrganoXCompoundPredictor,
ChemlogXMolecularEntityPredictor,
)
from chebifier.prediction_models.gnn_predictor import GATPredictor

ENSEMBLES = {
"mv": BaseEnsemble,
Expand All @@ -26,8 +25,8 @@

MODEL_TYPES = {
"electra": ElectraPredictor,
"resgated": ResGatedPredictor,
"gat": GATPredictor,
"resgated": GNNPredictor,
"gat": GNNPredictor,
"chemlog": ChemlogAllPredictor,
"chemlog_peptides": ChemlogPeptidesPredictor,
"chebi_lookup": ChEBILookupPredictor,
Expand Down
67 changes: 0 additions & 67 deletions chebifier/model_registry.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,6 @@ gat_chebi50_v244:
ckpt_path: gat_chebi50_v244_0nfi19qt_epoch=198.ckpt
target_labels_path: classes.txt
classwise_weights_path: gat_chebi50_v244_0nfi19qt_epoch=198_trust_3star.json
dataset_cls: chebai_graph.preprocessing.datasets.ChEBI50GraphProperties
molecular_properties:
- chebai_graph.preprocessing.properties.AtomType
- chebai_graph.preprocessing.properties.NumAtomBonds
- chebai_graph.preprocessing.properties.AtomCharge
- chebai_graph.preprocessing.properties.AtomAromaticity
- chebai_graph.preprocessing.properties.AtomHybridization
- chebai_graph.preprocessing.properties.AtomNumHs
- chebai_graph.preprocessing.properties.BondType
- chebai_graph.preprocessing.properties.BondInRing
- chebai_graph.preprocessing.properties.BondAromaticity
- chebai_graph.preprocessing.properties.RDKit2DNormalized
gat-aug_chebi50_v244:
type: gat
hugging_face:
Expand All @@ -34,28 +22,6 @@ gat-aug_chebi50_v244:
ckpt_path: gat-aug_chebi50_v244_8fky8tru_epoch=192.ckpt
target_labels_path: classes.txt
classwise_weights_path: gat-aug_chebi50_v244_8fky8tru_epoch=192_trust_3star.json
dataset_cls: chebai_graph.preprocessing.datasets.ChEBI50_WFGE_WGN_AsPerNodeType
molecular_properties:
- chebai_graph.preprocessing.properties.AtomNodeLevel
# Atom Node type properties
- chebai_graph.preprocessing.properties.AugAtomAromaticity
- chebai_graph.preprocessing.properties.AugAtomCharge
- chebai_graph.preprocessing.properties.AugAtomHybridization
- chebai_graph.preprocessing.properties.AugAtomNumHs
- chebai_graph.preprocessing.properties.AugAtomType
- chebai_graph.preprocessing.properties.AugNumAtomBonds
# FG Node type properties
- chebai_graph.preprocessing.properties.AtomFunctionalGroup
- chebai_graph.preprocessing.properties.IsHydrogenBondDonorFG
- chebai_graph.preprocessing.properties.IsHydrogenBondAcceptorFG
- chebai_graph.preprocessing.properties.IsFGAlkyl
# Graph Node type properties
- chebai_graph.preprocessing.properties.AugRDKit2DNormalized
# Bond properties
- chebai_graph.preprocessing.properties.BondLevel
- chebai_graph.preprocessing.properties.AugBondAromaticity
- chebai_graph.preprocessing.properties.AugBondInRing
- chebai_graph.preprocessing.properties.AugBondType
resgated-aug_chebi50-3star_v244:
type: resgated
hugging_face:
Expand All @@ -64,28 +30,6 @@ resgated-aug_chebi50-3star_v244:
ckpt_path: resgated-aug_chebi50-3star_v244_w0rhmajx_epoch=190.ckpt
target_labels_path: classes.txt
classwise_weights_path: resgated-aug_chebi50-3star_v244_w0rhmajx_epoch=190_trust_3star.json
dataset_cls: chebai_graph.preprocessing.datasets.ChEBI50_WFGE_WGN_AsPerNodeType
molecular_properties:
- chebai_graph.preprocessing.properties.AtomNodeLevel
# Atom Node type properties
- chebai_graph.preprocessing.properties.AugAtomAromaticity
- chebai_graph.preprocessing.properties.AugAtomCharge
- chebai_graph.preprocessing.properties.AugAtomHybridization
- chebai_graph.preprocessing.properties.AugAtomNumHs
- chebai_graph.preprocessing.properties.AugAtomType
- chebai_graph.preprocessing.properties.AugNumAtomBonds
# FG Node type properties
- chebai_graph.preprocessing.properties.AtomFunctionalGroup
- chebai_graph.preprocessing.properties.IsHydrogenBondDonorFG
- chebai_graph.preprocessing.properties.IsHydrogenBondAcceptorFG
- chebai_graph.preprocessing.properties.IsFGAlkyl
# Graph Node type properties
- chebai_graph.preprocessing.properties.AugRDKit2DNormalized
# Bond properties
- chebai_graph.preprocessing.properties.BondLevel
- chebai_graph.preprocessing.properties.AugBondAromaticity
- chebai_graph.preprocessing.properties.AugBondInRing
- chebai_graph.preprocessing.properties.AugBondType
electra_chebi50_v241:
type: electra
hugging_face:
Expand All @@ -102,17 +46,6 @@ resgated_chebi50_v241:
ckpt_path: 0ps1g189_epoch=122.ckpt
target_labels_path: classes.txt
classwise_weights_path: metrics_0ps1g189_80-10-10_short.json
molecular_properties:
- chebai_graph.preprocessing.properties.AtomType
- chebai_graph.preprocessing.properties.NumAtomBonds
- chebai_graph.preprocessing.properties.AtomCharge
- chebai_graph.preprocessing.properties.AtomAromaticity
- chebai_graph.preprocessing.properties.AtomHybridization
- chebai_graph.preprocessing.properties.AtomNumHs
- chebai_graph.preprocessing.properties.BondType
- chebai_graph.preprocessing.properties.BondInRing
- chebai_graph.preprocessing.properties.BondAromaticity
- chebai_graph.preprocessing.properties.RDKit2DNormalized
c3p_with_weights:
type: c3p
hugging_face:
Expand Down
4 changes: 2 additions & 2 deletions chebifier/prediction_models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
from .chebi_lookup import ChEBILookupPredictor
from .chemlog_predictor import ChemlogExtraPredictor, ChemlogPeptidesPredictor
from .electra_predictor import ElectraPredictor
from .gnn_predictor import ResGatedPredictor
from .gnn_predictor import GNNPredictor

__all__ = [
"BasePredictor",
"ChemlogPeptidesPredictor",
"ElectraPredictor",
"ResGatedPredictor",
"GNNPredictor",
"ChEBILookupPredictor",
"ChemlogExtraPredictor",
"C3PPredictor",
Expand Down
32 changes: 7 additions & 25 deletions chebifier/prediction_models/electra_predictor.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
from typing import TYPE_CHECKING

import numpy as np

from .nn_predictor import NNPredictor

if TYPE_CHECKING:
from chebai.models.electra import Electra


def build_graph_from_attention(att, node_labels, token_labels, threshold=0.0):
n_nodes = len(node_labels)
Expand Down Expand Up @@ -40,37 +35,24 @@ def build_graph_from_attention(att, node_labels, token_labels, threshold=0.0):

class ElectraPredictor(NNPredictor):
def __init__(self, model_name: str, ckpt_path: str, **kwargs):
from chebai.preprocessing.reader import ChemDataReader

super().__init__(model_name, ckpt_path, reader_cls=ChemDataReader, **kwargs)
print(f"Initialised Electra model {self.model_name} (device: {self.device})")

def init_model(self, ckpt_path: str, **kwargs) -> "Electra":
from chebai.models.electra import Electra

model = Electra.load_from_checkpoint(
ckpt_path,
map_location=self.device,
criterion=None,
strict=False,
metrics=dict(train=dict(), test=dict(), validation=dict()),
pretrained_checkpoint=None,
super().__init__(model_name, ckpt_path, **kwargs)
print(
f"Initialised Electra model {self.model_name} (device: {self._predictor.device})"
)
model.eval()
return model

def explain_smiles(self, smiles) -> dict:
from chebai.preprocessing.reader import EMBEDDING_OFFSET

reader = self.reader_cls()
token_dict = reader.to_data(dict(features=smiles, labels=None))
token_dict = self._predictor._dm.reader.to_data(
dict(features=smiles, labels=None)
)
tokens = np.array(token_dict["features"]).astype(int).tolist()
result = self.calculate_results([token_dict])

token_labels = (
["[CLR]"]
+ [None for _ in range(EMBEDDING_OFFSET - 1)]
+ list(reader.cache.keys())
+ list(self._predictor._dm.reader.cache.keys())
)

graphs = [
Expand Down
114 changes: 4 additions & 110 deletions chebifier/prediction_models/gnn_predictor.py
Original file line number Diff line number Diff line change
@@ -1,120 +1,14 @@
from typing import TYPE_CHECKING, Optional

import torch

from .nn_predictor import NNPredictor

if TYPE_CHECKING:
from chebai_graph.models.gat import GATGraphPred
from chebai_graph.models.resgated import ResGatedGraphPred


class ResGatedPredictor(NNPredictor):
class GNNPredictor(NNPredictor):
def __init__(
self,
model_name: str,
ckpt_path: str,
molecular_properties,
dataset_cls: Optional[str] = None,
**kwargs,
):
from chebai_graph.preprocessing.datasets.chebi import (
ChEBI50GraphProperties,
GraphPropertiesMixIn,
)
from chebai_graph.preprocessing.properties import MolecularProperty

# molecular_properties is a list of class paths
if molecular_properties is not None:
properties = [self.load_class(prop)() for prop in molecular_properties]
properties = sorted(
properties, key=lambda prop: f"{prop.name}_{prop.encoder.name}"
)
else:
properties = []
for property in properties:
property.encoder.eval = True
self.molecular_properties = properties
assert isinstance(self.molecular_properties, list) and all(
isinstance(prop, MolecularProperty) for prop in self.molecular_properties
)
# TODO it should not be necessary to refer to the whole dataset class, disentangle dataset and molecule reading
self.dataset_cls = (
self.load_class(dataset_cls)
if dataset_cls is not None
else ChEBI50GraphProperties
)
self.dataset: Optional[GraphPropertiesMixIn] = self.dataset_cls(
properties=molecular_properties
)

super().__init__(
model_name, ckpt_path, reader_cls=self.dataset.READER, **kwargs
)

print(f"Initialised GNN model {self.model_name} (device: {self.device})")

def load_class(self, class_path: str):
module_path, class_name = class_path.rsplit(".", 1)
module = __import__(module_path, fromlist=[class_name])
return getattr(module, class_name)

def init_model(self, ckpt_path: str, **kwargs) -> "ResGatedGraphPred":
import torch
from chebai_graph.models.resgated import ResGatedGraphPred

model = ResGatedGraphPred.load_from_checkpoint(
ckpt_path,
map_location=torch.device(self.device),
criterion=None,
strict=False,
metrics=dict(train=dict(), test=dict(), validation=dict()),
pretrained_checkpoint=None,
)
model.eval()
return model

def read_smiles(self, smiles):
from chebai_graph.preprocessing.datasets.chebi import GraphPropAsPerNodeType

d = self.dataset.READER().to_data(dict(features=smiles, labels=None))
property_data = d
# TODO merge props into base should not be a method of a dataset (or at least static)
for property in self.dataset.properties:
property.encoder.eval = True
property_value = self.reader.read_property(smiles, property)
if property_value is None or len(property_value) == 0:
encoded_value = None
else:
encoded_value = torch.stack(
[property.encoder.encode(v) for v in property_value]
)
if len(encoded_value.shape) == 3:
encoded_value = encoded_value.squeeze(0)
property_data[property.name] = encoded_value
# Augmented graphs need an additional argument
if isinstance(self.dataset, GraphPropAsPerNodeType):
d["features"] = self.dataset._merge_props_into_base(
property_data, max_len_node_properties=self.model.gnn.in_channels
)
else:
d["features"] = self.dataset._merge_props_into_base(property_data)
return d


class GATPredictor(ResGatedPredictor):

def init_model(self, ckpt_path: str, **kwargs) -> "GATGraphPred":
import torch
from chebai_graph.models.gat import GATGraphPred

model = GATGraphPred.load_from_checkpoint(
ckpt_path,
map_location=torch.device(self.device),
criterion=None,
strict=False,
metrics=dict(train=dict(), test=dict(), validation=dict()),
pretrained_checkpoint=None,
super().__init__(model_name, ckpt_path, **kwargs)
print(
f"Initialised GNN model {self.model_name} (device: {self._predictor.device})"
)
model.eval()
return model
Loading