Skip to content

Commit 802ab0c

Browse files
committed
add support for augmented gnns
1 parent 7aaf46f commit 802ab0c

File tree

3 files changed

+73
-74
lines changed

3 files changed

+73
-74
lines changed

chebifier/ensemble/base_ensemble.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def __init__(
8383

8484
def gather_predictions(self, smiles_list):
8585
# get predictions from all models for the SMILES list
86-
# order them by alphabetically by label class
86+
# order them alphabetically by label class
8787
model_predictions = []
8888
predicted_classes = set()
8989
for model in self.models:
Lines changed: 70 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,29 @@
1-
from typing import TYPE_CHECKING
1+
from typing import TYPE_CHECKING, Optional
2+
3+
import torch
24

35
from .nn_predictor import NNPredictor
46

57
if TYPE_CHECKING:
6-
from chebai_graph.models.graph import ResGatedGraphConvNetGraphPred
8+
from chebai_graph.models.gat import GATGraphPred
9+
from chebai_graph.models.resgated import ResGatedGraphPred
710

811

912
class ResGatedPredictor(NNPredictor):
10-
def __init__(self, model_name: str, ckpt_path: str, molecular_properties, **kwargs):
13+
def __init__(
14+
self,
15+
model_name: str,
16+
ckpt_path: str,
17+
molecular_properties,
18+
dataset_cls: Optional[str] = None,
19+
**kwargs,
20+
):
21+
from chebai_graph.preprocessing.datasets.chebi import (
22+
ChEBI50GraphProperties,
23+
GraphPropertiesMixIn,
24+
)
1125
from chebai_graph.preprocessing.properties import MolecularProperty
12-
from chebai_graph.preprocessing.reader import GraphPropertyReader
1326

14-
super().__init__(
15-
model_name, ckpt_path, reader_cls=GraphPropertyReader, **kwargs
16-
)
1727
# molecular_properties is a list of class paths
1828
if molecular_properties is not None:
1929
properties = [self.load_class(prop)() for prop in molecular_properties]
@@ -22,22 +32,38 @@ def __init__(self, model_name: str, ckpt_path: str, molecular_properties, **kwar
2232
)
2333
else:
2434
properties = []
35+
for property in properties:
36+
property.encoder.eval = True
2537
self.molecular_properties = properties
2638
assert isinstance(self.molecular_properties, list) and all(
2739
isinstance(prop, MolecularProperty) for prop in self.molecular_properties
2840
)
41+
# TODO it should not be necessary to refer to the whole dataset class, disentangle dataset and molecule reading
42+
self.dataset_cls = (
43+
self.load_class(dataset_cls)
44+
if dataset_cls is not None
45+
else ChEBI50GraphProperties
46+
)
47+
self.dataset: Optional[GraphPropertiesMixIn] = self.dataset_cls(
48+
properties=molecular_properties
49+
)
50+
51+
super().__init__(
52+
model_name, ckpt_path, reader_cls=self.dataset.READER, **kwargs
53+
)
54+
2955
print(f"Initialised GNN model {self.model_name} (device: {self.device})")
3056

3157
def load_class(self, class_path: str):
    """Resolve a dotted path such as ``"package.module.Name"`` to the object it names.

    Args:
        class_path: Fully qualified dotted path. Everything before the last
            dot is imported as a module; the part after it is looked up as an
            attribute of that module.

    Returns:
        The attribute found on the imported module (typically a class).

    Raises:
        ValueError: If ``class_path`` has no module part (no dot, or a
            leading dot only).
        ImportError: If the module cannot be imported.
        AttributeError: If the module has no attribute with the given name.
    """
    # Local import keeps this method self-contained; importlib.import_module
    # is the documented replacement for calling __import__ directly.
    from importlib import import_module

    module_path, _, class_name = class_path.rpartition(".")
    if not module_path:
        # Same exception type as before (tuple-unpacking / empty module name),
        # but with a message that points at the offending value.
        raise ValueError(
            f"Expected a dotted path like 'package.module.ClassName', got {class_path!r}"
        )
    return getattr(import_module(module_path), class_name)
3561

36-
def init_model(self, ckpt_path: str, **kwargs) -> "ResGatedGraphConvNetGraphPred":
62+
def init_model(self, ckpt_path: str, **kwargs) -> "ResGatedGraphPred":
3763
import torch
38-
from chebai_graph.models.graph import ResGatedGraphConvNetGraphPred
64+
from chebai_graph.models.resgated import ResGatedGraphPred
3965

40-
model = ResGatedGraphConvNetGraphPred.load_from_checkpoint(
66+
model = ResGatedGraphPred.load_from_checkpoint(
4167
ckpt_path,
4268
map_location=torch.device(self.device),
4369
criterion=None,
@@ -49,67 +75,40 @@ def init_model(self, ckpt_path: str, **kwargs) -> "ResGatedGraphConvNetGraphPred
4975
return model
5076

5177
def read_smiles(self, smiles):
    """Read one SMILES string and attach its encoded molecular properties.

    A fresh ``self.dataset.READER()`` converts the SMILES into the base data
    object. Every property of ``self.dataset`` is then read through
    ``self.reader``, encoded with the property's own encoder and stored on
    the data object under the property's name (``None`` when the reader
    yields no values). Finally the dataset merges the properties into the
    base features.

    Args:
        smiles: SMILES string of the molecule to read.

    Returns:
        The data object produced by the reader, with ``"features"`` replaced
        by the result of ``self.dataset._merge_props_into_base``.
    """
    data = self.dataset.READER().to_data(dict(features=smiles, labels=None))
    # TODO merge props into base should not be a method of a dataset (or at least static)
    for prop in self.dataset.properties:
        # Flag the encoder as being in evaluation mode before encoding.
        prop.encoder.eval = True
        raw_values = self.reader.read_property(smiles, prop)
        if raw_values is not None and len(raw_values) != 0:
            encoded = torch.stack([prop.encoder.encode(v) for v in raw_values])
            # A rank-3 stack is squeezed on its first axis
            # (presumably a singleton wrapper dimension -- TODO confirm).
            if encoded.dim() == 3:
                encoded = encoded.squeeze(0)
        else:
            encoded = None
        data[prop.name] = encoded
    data["features"] = self.dataset._merge_props_into_base(
        data, max_len_node_properties=self.model.gnn.in_channels
    )
    return data
97+
98+
99+
class GATPredictor(ResGatedPredictor):
    """Predictor backed by a ``GATGraphPred`` checkpoint.

    Only model loading is overridden; all other behaviour is inherited from
    ``ResGatedPredictor``.
    """

    def init_model(self, ckpt_path: str, **kwargs) -> "GATGraphPred":
        """Load the GAT model from ``ckpt_path`` and switch it to eval mode.

        Args:
            ckpt_path: Path of the checkpoint to load.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            The loaded ``GATGraphPred`` instance after ``eval()`` was called.
        """
        import torch
        from chebai_graph.models.gat import GATGraphPred

        # Training-only arguments (criterion, metrics, pretrained checkpoint)
        # are overridden with empty values for loading.
        load_options = {
            "map_location": torch.device(self.device),
            "criterion": None,
            "strict": False,
            "metrics": {"train": {}, "test": {}, "validation": {}},
            "pretrained_checkpoint": None,
        }
        gat_model = GATGraphPred.load_from_checkpoint(ckpt_path, **load_options)
        gat_model.eval()
        return gat_model

chebifier/prediction_models/nn_predictor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def __init__(
2020

2121
super().__init__(model_name, **kwargs)
2222
self.reader_cls = reader_cls
23+
self.reader = reader_cls()
2324

2425
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
2526
self.model = self.init_model(ckpt_path=ckpt_path)
@@ -49,8 +50,7 @@ def batchify(self, batch):
4950
yield cache
5051

5152
def read_smiles(self, smiles):
    """Convert a SMILES string with the predictor's shared reader instance.

    Args:
        smiles: SMILES string of the molecule to read.

    Returns:
        Whatever ``self.reader.to_data`` produces for the given SMILES with
        ``labels`` set to ``None``.
    """
    return self.reader.to_data({"features": smiles, "labels": None})
5555

5656
@modelwise_smiles_lru_cache.batch_decorator

0 commit comments

Comments
 (0)