1- from typing import TYPE_CHECKING
1+ from typing import TYPE_CHECKING , Optional
2+
3+ import torch
24
35from .nn_predictor import NNPredictor
46
57if TYPE_CHECKING :
6- from chebai_graph .models .graph import ResGatedGraphConvNetGraphPred
8+ from chebai_graph .models .gat import GATGraphPred
9+ from chebai_graph .models .resgated import ResGatedGraphPred
710
811
912class ResGatedPredictor (NNPredictor ):
10- def __init__ (self , model_name : str , ckpt_path : str , molecular_properties , ** kwargs ):
13+ def __init__ (
14+ self ,
15+ model_name : str ,
16+ ckpt_path : str ,
17+ molecular_properties ,
18+ dataset_cls : Optional [str ] = None ,
19+ ** kwargs ,
20+ ):
21+ from chebai_graph .preprocessing .datasets .chebi import (
22+ ChEBI50GraphProperties ,
23+ GraphPropertiesMixIn ,
24+ )
1125 from chebai_graph .preprocessing .properties import MolecularProperty
12- from chebai_graph .preprocessing .reader import GraphPropertyReader
1326
14- super ().__init__ (
15- model_name , ckpt_path , reader_cls = GraphPropertyReader , ** kwargs
16- )
1727 # molecular_properties is a list of class paths
1828 if molecular_properties is not None :
1929 properties = [self .load_class (prop )() for prop in molecular_properties ]
@@ -22,22 +32,38 @@ def __init__(self, model_name: str, ckpt_path: str, molecular_properties, **kwar
2232 )
2333 else :
2434 properties = []
35+ for property in properties :
36+ property .encoder .eval = True
2537 self .molecular_properties = properties
2638 assert isinstance (self .molecular_properties , list ) and all (
2739 isinstance (prop , MolecularProperty ) for prop in self .molecular_properties
2840 )
41+ # TODO it should not be necessary to refer to the whole dataset class, disentangle dataset and molecule reading
42+ self .dataset_cls = (
43+ self .load_class (dataset_cls )
44+ if dataset_cls is not None
45+ else ChEBI50GraphProperties
46+ )
47+ self .dataset : Optional [GraphPropertiesMixIn ] = self .dataset_cls (
48+ properties = molecular_properties
49+ )
50+
51+ super ().__init__ (
52+ model_name , ckpt_path , reader_cls = self .dataset .READER , ** kwargs
53+ )
54+
2955 print (f"Initialised GNN model { self .model_name } (device: { self .device } )" )
3056
3157 def load_class (self , class_path : str ):
3258 module_path , class_name = class_path .rsplit ("." , 1 )
3359 module = __import__ (module_path , fromlist = [class_name ])
3460 return getattr (module , class_name )
3561
36- def init_model (self , ckpt_path : str , ** kwargs ) -> "ResGatedGraphConvNetGraphPred " :
62+ def init_model (self , ckpt_path : str , ** kwargs ) -> "ResGatedGraphPred " :
3763 import torch
38- from chebai_graph .models .graph import ResGatedGraphConvNetGraphPred
64+ from chebai_graph .models .resgated import ResGatedGraphPred
3965
40- model = ResGatedGraphConvNetGraphPred .load_from_checkpoint (
66+ model = ResGatedGraphPred .load_from_checkpoint (
4167 ckpt_path ,
4268 map_location = torch .device (self .device ),
4369 criterion = None ,
@@ -49,67 +75,40 @@ def init_model(self, ckpt_path: str, **kwargs) -> "ResGatedGraphConvNetGraphPred
4975 return model
5076
5177 def read_smiles (self , smiles ):
52- import torch
53- from chebai_graph .preprocessing .properties import AtomProperty , BondProperty
54- from chebai_graph .preprocessing .property_encoder import (
55- IndexEncoder ,
56- OneHotEncoder ,
57- )
58- from torch_geometric .data .data import Data as GeomData
59-
60- reader = self .reader_cls ()
61- d = reader .to_data (dict (features = smiles , labels = None ))
62- geom_data = d ["features" ]
63- edge_attr = geom_data .edge_attr
64- x = geom_data .x
65- molecule_attr = torch .empty ((1 , 0 ))
66- for prop in self .molecular_properties :
67- property_values = reader .read_property (smiles , prop )
68- encoded_values = []
69- for value in property_values :
70- # cant use standard encode for index encoder because model has been trained on a certain range of values
71- # use default value if we meet an unseen value
72- if isinstance (prop .encoder , IndexEncoder ):
73- if str (value ) in prop .encoder .cache :
74- index = prop .encoder .cache [str (value )] + prop .encoder .offset
75- else :
76- index = 0
77- print (
78- f"Unknown property value { value } for property { prop } at smiles { smiles } "
79- )
80- if isinstance (prop .encoder , OneHotEncoder ):
81- encoded_values .append (
82- torch .nn .functional .one_hot (
83- torch .tensor (index ),
84- num_classes = prop .encoder .get_encoding_length (),
85- )
86- )
87- else :
88- encoded_values .append (torch .tensor ([index ]))
89-
90- else :
91- encoded_values .append (prop .encoder .encode (value ))
92- if len (encoded_values ) > 0 :
93- encoded_values = torch .stack (encoded_values )
94-
95- if isinstance (encoded_values , torch .Tensor ):
96- if len (encoded_values .size ()) == 0 :
97- encoded_values = encoded_values .unsqueeze (0 )
98- if len (encoded_values .size ()) == 1 :
99- encoded_values = encoded_values .unsqueeze (1 )
100- else :
101- encoded_values = torch .zeros ((0 , prop .encoder .get_encoding_length ()))
102- if isinstance (prop , AtomProperty ):
103- x = torch .cat ([x , encoded_values ], dim = 1 )
104- elif isinstance (prop , BondProperty ):
105- edge_attr = torch .cat ([edge_attr , encoded_values ], dim = 1 )
78+ d = self .dataset .READER ().to_data (dict (features = smiles , labels = None ))
79+ property_data = d
80+ # TODO merge props into base should not be a method of a dataset (or at least static)
81+ for property in self .dataset .properties :
82+ property .encoder .eval = True
83+ property_value = self .reader .read_property (smiles , property )
84+ if property_value is None or len (property_value ) == 0 :
85+ encoded_value = None
10686 else :
107- molecule_attr = torch .cat ([molecule_attr , encoded_values [0 ]], dim = 1 )
108-
109- d ["features" ] = GeomData (
110- x = x ,
111- edge_index = geom_data .edge_index ,
112- edge_attr = edge_attr ,
113- molecule_attr = molecule_attr ,
87+ encoded_value = torch .stack (
88+ [property .encoder .encode (v ) for v in property_value ]
89+ )
90+ if len (encoded_value .shape ) == 3 :
91+ encoded_value = encoded_value .squeeze (0 )
92+ property_data [property .name ] = encoded_value
93+ d ["features" ] = self .dataset ._merge_props_into_base (
94+ property_data , max_len_node_properties = self .model .gnn .in_channels
11495 )
11596 return d
97+
98+
99+ class GATPredictor (ResGatedPredictor ):
100+
101+ def init_model (self , ckpt_path : str , ** kwargs ) -> "GATGraphPred" :
102+ import torch
103+ from chebai_graph .models .gat import GATGraphPred
104+
105+ model = GATGraphPred .load_from_checkpoint (
106+ ckpt_path ,
107+ map_location = torch .device (self .device ),
108+ criterion = None ,
109+ strict = False ,
110+ metrics = dict (train = dict (), test = dict (), validation = dict ()),
111+ pretrained_checkpoint = None ,
112+ )
113+ model .eval ()
114+ return model
0 commit comments