|
1 | 1 | from abc import ABC |
| 2 | +from typing import TYPE_CHECKING |
2 | 3 |
|
3 | 4 | from chebai.result.prediction import Predictor |
4 | 5 |
|
5 | 6 | from chebifier import modelwise_smiles_lru_cache |
6 | 7 |
|
7 | 8 | from .base_predictor import BasePredictor |
8 | 9 |
|
| 10 | +if TYPE_CHECKING: |
| 11 | + from torch import Tensor |
| 12 | + |
9 | 13 |
|
class NNPredictor(BasePredictor, ABC):
    """Base class for predictors backed by a ``chebai`` model checkpoint.

    Wraps a ``Predictor`` built from a checkpoint and pairs every raw
    prediction row with the labels read from a target-labels file.
    """

    def __init__(
        self,
        model_name: str,
        ckpt_path: str,
        target_labels_path: str,
        **kwargs,
    ):
        """Load the checkpoint and target labels, then cross-check them.

        Args:
            model_name: Forwarded to ``BasePredictor.__init__``.
            ckpt_path: Checkpoint path handed to ``Predictor``.
            target_labels_path: UTF-8 text file read line by line; each
                stripped line becomes one entry of ``self.target_labels``.
                NOTE(review): label order presumably has to match the
                model's output order - confirm against the training setup.
            **kwargs: Forwarded to ``BasePredictor.__init__``. The optional
                ``batch_size`` entry is additionally read here.

        Raises:
            AssertionError: If the number of values predicted for a probe
                molecule differs from the number of target labels.
        """
        super().__init__(model_name, **kwargs)
        self.batch_size = kwargs.get("batch_size", None)
        # If batch_size is not provided, it will be set to default batch size used during training in Predictor
        self._predictor: Predictor = Predictor(ckpt_path, self.batch_size)

        # Use a context manager so the labels file is closed deterministically
        # instead of waiting for garbage collection.
        with open(target_labels_path, encoding="utf-8") as labels_file:
            self.target_labels = [line.strip() for line in labels_file]

        # Sanity check - ensure that the number of classes predicted by the model matches the number of target labels
        # TODO: In future, we can include the target labels in the model metadata and avoid this.
        raw_preds = self._predictor.predict_smiles(["CO"])
        # Raised explicitly (rather than via an ``assert`` statement) so the
        # check still runs under ``python -O`` while callers keep seeing the
        # same exception type and message.
        if len(raw_preds[0]) != len(self.target_labels):
            raise AssertionError(
                "Number of predicted classes does not match number of target labels."
            )

    @modelwise_smiles_lru_cache.batch_decorator
    def predict_smiles_list(self, smiles_list: list[str]) -> list:
        """
        Returns a list with the length of smiles_list, each element is
        either None (=failure) or a dictionary of classes and predicted values.
        """
        raw_preds: "Tensor | None" = self._predictor.predict_smiles(smiles_list)
        if raw_preds is None:
            # The whole batch failed: report a failure for every input.
            return [None for _ in smiles_list]
        # Index by input position (not by iterating raw_preds) so the result
        # always has exactly one entry per element of smiles_list.
        return [
            dict(zip(self.target_labels, raw_preds[i].tolist()))
            for i in range(len(smiles_list))
        ]
0 commit comments