
Commit b3a758c

fix pred error
1 parent: 1ef0848

3 files changed: +37, -4 lines

chebifier/prediction_models/electra_predictor.py

Lines changed: 3 additions & 1 deletion
@@ -36,7 +36,9 @@ def build_graph_from_attention(att, node_labels, token_labels, threshold=0.0):
 class ElectraPredictor(NNPredictor):
     def __init__(self, model_name: str, ckpt_path: str, **kwargs):
         super().__init__(model_name, ckpt_path, **kwargs)
-        print(f"Initialised Electra model {self.model_name}")
+        print(
+            f"Initialised Electra model {self.model_name} (device: {self._predictor.device})"
+        )
 
     def explain_smiles(self, smiles) -> dict:
         from chebai.preprocessing.reader import EMBEDDING_OFFSET

chebifier/prediction_models/gnn_predictor.py

Lines changed: 3 additions & 1 deletion
@@ -9,4 +9,6 @@ def __init__(
         **kwargs,
     ):
         super().__init__(model_name, ckpt_path, **kwargs)
-        print(f"Initialised GNN model {self.model_name}")
+        print(
+            f"Initialised GNN model {self.model_name} (device: {self._predictor.device})"
+        )
Lines changed: 31 additions & 2 deletions
@@ -1,29 +1,58 @@
 from abc import ABC
+from typing import TYPE_CHECKING
 
 from chebai.result.prediction import Predictor
 
 from chebifier import modelwise_smiles_lru_cache
 
 from .base_predictor import BasePredictor
 
+if TYPE_CHECKING:
+    from torch import Tensor
+
 
 class NNPredictor(BasePredictor, ABC):
     def __init__(
         self,
         model_name: str,
         ckpt_path: str,
+        target_labels_path: str,
         **kwargs,
     ):
+        super().__init__(model_name, **kwargs)
         self.batch_size = kwargs.get("batch_size", None)
         # If batch_size is not provided, it will be set to default batch size used during training in Predictor
         self._predictor: Predictor = Predictor(ckpt_path, self.batch_size)
+        self.target_labels = [
+            line.strip() for line in open(target_labels_path, encoding="utf-8")
+        ]
 
-        super().__init__(model_name, **kwargs)
+        # Sanity check - ensure that the number of classes predicted by the model matches the number of target labels
+        # TODO: In future, we can include the target labels in the model metadata and avoid this.
+        raw_preds = self._predictor.predict_smiles(["CO"])
+        assert len(raw_preds[0]) == len(
+            self.target_labels
+        ), "Number of predicted classes does not match number of target labels."
 
     @modelwise_smiles_lru_cache.batch_decorator
     def predict_smiles_list(self, smiles_list: list[str]) -> list:
         """
         Returns a list with the length of smiles_list, each element is
         either None (=failure) or a dictionary of classes and predicted values.
         """
-        return self._predictor.predict_smiles(smiles_list)
+        raw_preds: Tensor = self._predictor.predict_smiles(smiles_list)
+        if raw_preds is not None:
+            preds = [
+                (
+                    {
+                        label: pred
+                        for label, pred in zip(
+                            self.target_labels, raw_preds[i].tolist()
+                        )
+                    }
+                )
+                for i in range(len(smiles_list))
+            ]
+            return preds
+        else:
+            return [None for _ in smiles_list]
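
For context, a minimal usage sketch of the reshaped output (not part of the commit). It assumes that target_labels_path can be forwarded to NNPredictor through the subclass's **kwargs, that the import path follows the file path shown in this diff, and that the checkpoint and label-file paths are placeholders. After this change, each element returned by predict_smiles_list is a dictionary mapping target labels to predicted values, or None on failure, rather than a raw tensor row.

# Sketch only: paths are hypothetical; passing target_labels_path via **kwargs
# is an assumption based on the signatures shown in this diff.
from chebifier.prediction_models.electra_predictor import ElectraPredictor

predictor = ElectraPredictor(
    model_name="electra",
    ckpt_path="models/electra.ckpt",          # placeholder checkpoint path
    target_labels_path="models/classes.txt",  # one target label per line
)

smiles = ["CO", "CCO"]
# Results are cached by the modelwise_smiles_lru_cache batch decorator.
preds = predictor.predict_smiles_list(smiles)
for smi, pred in zip(smiles, preds):
    if pred is None:
        print(f"{smi}: prediction failed")
    else:
        # pred maps each target label to the model's predicted value
        top = sorted(pred.items(), key=lambda kv: kv[1], reverse=True)[:3]
        print(smi, top)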
