From 465b651fbfcf346569674002799879436c3a176a Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 16 Nov 2025 23:29:17 +0100 Subject: [PATCH 01/18] use instantiate model to load data and model from ckpt --- chebai/trainer/CustomTrainer.py | 69 +++++++++++++++++++++++---------- 1 file changed, 49 insertions(+), 20 deletions(-) diff --git a/chebai/trainer/CustomTrainer.py b/chebai/trainer/CustomTrainer.py index f7fbce26..e84aad4c 100644 --- a/chebai/trainer/CustomTrainer.py +++ b/chebai/trainer/CustomTrainer.py @@ -1,4 +1,5 @@ import logging +import os from typing import Any, List, Optional, Tuple import pandas as pd @@ -6,13 +7,17 @@ from lightning import LightningModule, Trainer from lightning.fabric.utilities.data import _set_sampler_epoch from lightning.fabric.utilities.types import _PATH +from lightning.pytorch.cli import instantiate_module from lightning.pytorch.loggers import WandbLogger from lightning.pytorch.loops.fit_loop import _FitLoop from lightning.pytorch.trainer import call from torch.nn.utils.rnn import pad_sequence +from build.lib.chebai.preprocessing.datasets.base import XYBaseDataModule from chebai.loggers.custom import CustomLogger -from chebai.preprocessing.reader import CLS_TOKEN, ChemDataReader +from chebai.models.base import ChebaiBaseNet +from chebai.preprocessing.datasets.base import XYBaseDataModule +from chebai.preprocessing.reader import CLS_TOKEN log = logging.getLogger(__name__) @@ -44,6 +49,7 @@ def __init__(self, *args, **kwargs): # use custom fit loop (https://lightning.ai/docs/pytorch/LTS/extensions/loops.html#overriding-the-default-loops) self.fit_loop = LoadDataLaterFitLoop(self, self.min_epochs, self.max_epochs) + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def _resolve_logging_argument(self, key: str, value: Any) -> Tuple[str, Any]: """ @@ -76,12 +82,10 @@ def _resolve_logging_argument(self, key: str, value: Any) -> Tuple[str, Any]: def predict_from_file( self, - model: LightningModule, checkpoint_path: _PATH, input_path: _PATH, save_to: _PATH = "predictions.csv", classes_path: Optional[_PATH] = None, - **kwargs, ) -> None: """ Loads a model from a checkpoint and makes predictions on input data from a file. @@ -93,20 +97,21 @@ def predict_from_file( save_to: Path to save the predictions CSV file. classes_path: Optional path to a file containing class names (if no class names are provided, columns will be numbered). """ - loaded_model = model.__class__.load_from_checkpoint(checkpoint_path) with open(input_path, "r") as input: smiles_strings = [inp.strip() for inp in input.readlines()] - loaded_model.eval() - predictions = self._predict_smiles(loaded_model, smiles_strings) - predictions_df = pd.DataFrame(predictions.detach().cpu().numpy()) - if classes_path is not None: - with open(classes_path, "r") as f: - predictions_df.columns = [cls.strip() for cls in f.readlines()] - predictions_df.index = smiles_strings - predictions_df.to_csv(save_to) + self._predict_smiles( + checkpoint_path, + smiles=smiles_strings, + classes_path=classes_path, + save_to=save_to, + ) def _predict_smiles( - self, model: LightningModule, smiles: List[str] + self, + checkpoint_path: _PATH, + smiles: List[str], + classes_path: Optional[_PATH] = None, + save_to: _PATH = "predictions.csv", ) -> torch.Tensor: """ Predicts the output for a list of SMILES strings using the model. @@ -118,22 +123,47 @@ def _predict_smiles( Returns: A tensor containing the predictions. """ - reader = ChemDataReader() - parsed_smiles = [reader._read_data(s) for s in smiles] + ckpt_file = torch.load( + checkpoint_path, map_location=self.device, weights_only=False + ) + + ckpt_file["datamodule_hyper_parameters"].pop("splits_file_path") + dm: XYBaseDataModule = instantiate_module( + XYBaseDataModule, ckpt_file["datamodule_hyper_parameters"] + ) + + model: ChebaiBaseNet = instantiate_module( + ChebaiBaseNet, ckpt_file["hyper_parameters"] + ) + model.to(self.device) + model.eval() + + parsed_smiles = [dm.reader._read_data(s) for s in smiles] x = pad_sequence( - [torch.tensor(a, device=model.device) for a in parsed_smiles], + [torch.tensor(a, device=self.device) for a in parsed_smiles], batch_first=True, ) cls_tokens = ( - torch.ones(x.shape[0], dtype=torch.int, device=model.device).unsqueeze(-1) + torch.ones(x.shape[0], dtype=torch.int, device=self.device).unsqueeze(-1) * CLS_TOKEN ) features = torch.cat((cls_tokens, x), dim=1) model_output = model({"features": features}) preds = torch.sigmoid(model_output["logits"]) - print(preds.shape) - return preds + predictions_df = pd.DataFrame(preds.detach().cpu().numpy()) + + def _add_class_columns(class_file_path: _PATH): + with open(class_file_path, "r") as f: + predictions_df.columns = [cls.strip() for cls in f.readlines()] + + if classes_path is not None: + _add_class_columns(classes_path) + elif os.path.exists(dm.classes_txt_file_path): + _add_class_columns(dm.classes_txt_file_path) + + predictions_df.index = smiles + predictions_df.to_csv(save_to) @property def log_dir(self) -> Optional[str]: @@ -157,7 +187,6 @@ def log_dir(self) -> Optional[str]: class LoadDataLaterFitLoop(_FitLoop): - def on_advance_start(self) -> None: """Calls the hook ``on_train_epoch_start`` **before** the dataloaders are setup. This is necessary so that the dataloaders can get information from the model. For example: The on_train_epoch_start From 2acd166b87f03df14839fa9c5c381b1669bc6e84 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 16 Nov 2025 23:29:38 +0100 Subject: [PATCH 02/18] update readme --- README.md | 2 +- chebai/preprocessing/datasets/base.py | 38 ++++++++++++++++++--------- 2 files changed, 26 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index eeecd714..2555c0a6 100644 --- a/README.md +++ b/README.md @@ -70,7 +70,7 @@ python -m chebai fit --config=[path-to-your-tox21-config] --trainer.callbacks=co ### Predicting classes given SMILES strings ``` -python3 -m chebai predict_from_file --model=[path-to-model-config] --checkpoint_path=[path-to-model] --input_path={path-to-file-containing-smiles] [--classes_path=[path-to-classes-file]] [--save_to=[path-to-output]] +python3 -m chebai predict_from_file --checkpoint_path=[path-to-model] --input_path=[path-to-file-containing-smiles] [--classes_path=[path-to-classes-file]] [--save_to=[path-to-output]] ``` The input files should contain a list of line-separated SMILES strings. This generates a CSV file that contains the one row for each SMILES string and one column for each class. diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py index 68254007..f1357c88 100644 --- a/chebai/preprocessing/datasets/base.py +++ b/chebai/preprocessing/datasets/base.py @@ -96,9 +96,9 @@ def __init__( self.prediction_kind = prediction_kind self.data_limit = data_limit self.label_filter = label_filter - assert (balance_after_filter is not None) or ( - self.label_filter is None - ), "Filter balancing requires a filter" + assert (balance_after_filter is not None) or (self.label_filter is None), ( + "Filter balancing requires a filter" + ) self.balance_after_filter = balance_after_filter self.num_workers = num_workers self.persistent_workers: bool = bool(persistent_workers) @@ -108,13 +108,13 @@ def __init__( self.use_inner_cross_validation = ( inner_k_folds > 1 ) # only use cv if there are at least 2 folds - assert ( - fold_index is None or self.use_inner_cross_validation is not None - ), "fold_index can only be set if cross validation is used" + assert fold_index is None or self.use_inner_cross_validation is not None, ( + "fold_index can only be set if cross validation is used" + ) if fold_index is not None and self.inner_k_folds is not None: - assert ( - fold_index < self.inner_k_folds - ), "fold_index can't be larger than the total number of folds" + assert fold_index < self.inner_k_folds, ( + "fold_index can't be larger than the total number of folds" + ) self.fold_index = fold_index self._base_dir = base_dir self.n_token_limit = n_token_limit @@ -137,9 +137,9 @@ def num_of_labels(self): @property def feature_vector_size(self): - assert ( - self._feature_vector_size is not None - ), "size of feature vector must be set" + assert self._feature_vector_size is not None, ( + "size of feature vector must be set" + ) return self._feature_vector_size @property @@ -1190,7 +1190,8 @@ def _retrieve_splits_from_csv(self) -> None: print(f"Applying label filter from {self.apply_label_filter}...") with open(self.apply_label_filter, "r") as f: label_filter = [line.strip() for line in f] - with open(os.path.join(self.processed_dir_main, "classes.txt"), "r") as cf: + + with open(self.classes_txt_file_path, "r") as cf: classes = [line.strip() for line in cf] # reorder labels old_labels = np.stack(df_data["labels"]) @@ -1315,3 +1316,14 @@ def processed_file_names_dict(self) -> dict: if self.n_token_limit is not None: return {"data": f"data_maxlen{self.n_token_limit}.pt"} return {"data": "data.pt"} + + @property + def classes_txt_file_path(self) -> str: + """ + Returns the filename for the classes text file. + + Returns: + str: The filename for the classes text file. + """ + # This property also used in custom trainer `chebai/trainer/CustomTrainer.py` + return os.path.join(self.processed_dir_main, "classes.txt") From cfbf392f13773c0ce92719b24a83ef231cea2e32 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 17 Nov 2025 14:50:16 +0100 Subject: [PATCH 03/18] set no grad for predict --- chebai/trainer/CustomTrainer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/chebai/trainer/CustomTrainer.py b/chebai/trainer/CustomTrainer.py index e84aad4c..6f1a542c 100644 --- a/chebai/trainer/CustomTrainer.py +++ b/chebai/trainer/CustomTrainer.py @@ -80,6 +80,7 @@ def _resolve_logging_argument(self, key: str, value: Any) -> Tuple[str, Any]: else: return key, value + @torch.no_grad() def predict_from_file( self, checkpoint_path: _PATH, @@ -106,6 +107,7 @@ def predict_from_file( save_to=save_to, ) + @torch.no_grad() def _predict_smiles( self, checkpoint_path: _PATH, From 82b365ca31698da387f4077c13ea345d9572aa04 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 26 Nov 2025 16:36:15 +0100 Subject: [PATCH 04/18] predict pipeline in dm and lm --- chebai/models/base.py | 8 ++++++- chebai/preprocessing/datasets/base.py | 31 ++++++++++++++++++++++----- chebai/trainer/CustomTrainer.py | 4 +++- 3 files changed, 36 insertions(+), 7 deletions(-) diff --git a/chebai/models/base.py b/chebai/models/base.py index 7653f13c..808ea59e 100644 --- a/chebai/models/base.py +++ b/chebai/models/base.py @@ -232,7 +232,13 @@ def predict_step( Returns: Dict[str, Union[torch.Tensor, Any]]: The result of the prediction step. """ - return self._execute(batch, batch_idx, self.test_metrics, prefix="", log=False) + assert isinstance(batch, XYData) + batch = batch.to(self.device) + data = self._process_batch(batch, batch_idx) + labels = data["labels"] + model_output = self(data, **data.get("model_kwargs", dict())) + pr, _ = self._get_prediction_and_labels(data, labels, model_output) + return pr def _execute( self, diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py index f1357c88..e2df794d 100644 --- a/chebai/preprocessing/datasets/base.py +++ b/chebai/preprocessing/datasets/base.py @@ -339,8 +339,14 @@ def _load_data_from_file(self, path: str) -> List[Dict[str, Any]]: for d in tqdm.tqdm(self._load_dict(path), total=lines) if d["features"] is not None ] + + return self._filter_to_token_limit(data) + + def _filter_to_token_limit( + self, data: List[Dict[str, Any]] + ) -> List[Dict[str, Any]]: # filter for missing features in resulting data, keep features length below token limit - data = [ + return [ val for val in data if val["features"] is not None @@ -349,8 +355,6 @@ def _load_data_from_file(self, path: str) -> List[Dict[str, Any]]: ) ] - return data - def train_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]: """ Returns the train DataLoader. @@ -400,10 +404,13 @@ def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader] Returns: Union[DataLoader, List[DataLoader]]: A DataLoader object for test data. """ + return self.dataloader("test", shuffle=False, **kwargs) def predict_dataloader( - self, *args, **kwargs + self, + smiles_list: List[str], + **kwargs, ) -> Union[DataLoader, List[DataLoader]]: """ Returns the predict DataLoader. @@ -415,7 +422,21 @@ def predict_dataloader( Returns: Union[DataLoader, List[DataLoader]]: A DataLoader object for prediction data. """ - return self.dataloader(self.prediction_kind, shuffle=False, **kwargs) + + data = [ + self.reader.to_data( + {"id": f"smiles_{idx}", "features": smiles, "labels": None} + ) + for idx, smiles in enumerate(smiles_list) + ] + data = self._filter_to_token_limit(data) + + return DataLoader( + data, + collate_fn=self.reader.collator, + batch_size=self.batch_size, + **kwargs, + ) def prepare_data(self, *args, **kwargs) -> None: if self._prepare_data_flag != 1: diff --git a/chebai/trainer/CustomTrainer.py b/chebai/trainer/CustomTrainer.py index 6f1a542c..acd468f2 100644 --- a/chebai/trainer/CustomTrainer.py +++ b/chebai/trainer/CustomTrainer.py @@ -4,7 +4,7 @@ import pandas as pd import torch -from lightning import LightningModule, Trainer +from lightning import Trainer from lightning.fabric.utilities.data import _set_sampler_epoch from lightning.fabric.utilities.types import _PATH from lightning.pytorch.cli import instantiate_module @@ -87,6 +87,7 @@ def predict_from_file( input_path: _PATH, save_to: _PATH = "predictions.csv", classes_path: Optional[_PATH] = None, + **kwargs, ) -> None: """ Loads a model from a checkpoint and makes predictions on input data from a file. @@ -114,6 +115,7 @@ def _predict_smiles( smiles: List[str], classes_path: Optional[_PATH] = None, save_to: _PATH = "predictions.csv", + **kwargs, ) -> torch.Tensor: """ Predicts the output for a list of SMILES strings using the model. From fa6f1b521d05d38261973f602ff53232ec5eabb3 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 26 Nov 2025 18:16:36 +0100 Subject: [PATCH 05/18] there is no need that predict func must depend on trainer --- chebai/result/prediction.py | 111 ++++++++++++++++++++++++++++++++ chebai/trainer/CustomTrainer.py | 107 ++++-------------------------- 2 files changed, 122 insertions(+), 96 deletions(-) create mode 100644 chebai/result/prediction.py diff --git a/chebai/result/prediction.py b/chebai/result/prediction.py new file mode 100644 index 00000000..cb3e1415 --- /dev/null +++ b/chebai/result/prediction.py @@ -0,0 +1,111 @@ +import os +from typing import List, Optional + +import pandas as pd +import torch +from jsonargparse import CLI +from lightning.fabric.utilities.types import _PATH +from lightning.pytorch.cli import instantiate_module +from torch.utils.data import DataLoader + +from chebai.models.base import ChebaiBaseNet +from chebai.preprocessing.datasets.base import XYBaseDataModule + + +class Predictor: + def __init__(self): + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + super().__init__() + + def predict_from_file( + self, + checkpoint_path: _PATH, + input_path: _PATH, + save_to: _PATH = "predictions.csv", + classes_path: Optional[_PATH] = None, + ) -> None: + """ + Loads a model from a checkpoint and makes predictions on input data from a file. + + Args: + model: The model to use for predictions. + checkpoint_path: Path to the model checkpoint. + input_path: Path to the input file containing SMILES strings. + save_to: Path to save the predictions CSV file. + classes_path: Optional path to a file containing class names (if no class names are provided, columns will be numbered). + """ + with open(input_path, "r") as input: + smiles_strings = [inp.strip() for inp in input.readlines()] + + self.predict_smiles( + checkpoint_path, + smiles=smiles_strings, + classes_path=classes_path, + save_to=save_to, + ) + + @torch.inference_mode() + def predict_smiles( + self, + checkpoint_path: _PATH, + smiles: List[str], + classes_path: Optional[_PATH] = None, + save_to: Optional[_PATH] = None, + **kwargs, + ) -> torch.Tensor | None: + """ + Predicts the output for a list of SMILES strings using the model. + + Args: + model: The model to use for predictions. + smiles: A list of SMILES strings. + + Returns: + A tensor containing the predictions. + """ + ckpt_file = torch.load( + checkpoint_path, map_location=self.device, weights_only=False + ) + + ckpt_file["datamodule_hyper_parameters"].pop("splits_file_path") + dm: XYBaseDataModule = instantiate_module( + XYBaseDataModule, ckpt_file["datamodule_hyper_parameters"] + ) + print(f"Loaded datamodule class: {dm.__class__.__name__}") + + pred_dl: DataLoader = dm.predict_dataloader(smiles_list=smiles) + + model: ChebaiBaseNet = instantiate_module( + ChebaiBaseNet, ckpt_file["hyper_parameters"] + ) + model.eval() + # model = torch.compile(model) + + print(f"Loaded model class: {model.__class__.__name__}") + + preds = [] + for batch_idx, batch in enumerate(pred_dl): + preds.append(model.predict_step(batch, batch_idx)) + + if not save_to: + # If no save path is provided, return the predictions tensor + return torch.cat(preds) + + predictions_df = pd.DataFrame(torch.cat(preds).detach().cpu().numpy()) + + def _add_class_columns(class_file_path: _PATH): + with open(class_file_path, "r") as f: + predictions_df.columns = [cls.strip() for cls in f.readlines()] + + if classes_path is not None: + _add_class_columns(classes_path) + elif os.path.exists(dm.classes_txt_file_path): + _add_class_columns(dm.classes_txt_file_path) + + predictions_df.index = smiles + predictions_df.to_csv(save_to) + + +if __name__ == "__main__": + # python chebai/result/prediction.py predict_from_file --help + CLI(Predictor, as_positional=False) diff --git a/chebai/trainer/CustomTrainer.py b/chebai/trainer/CustomTrainer.py index acd468f2..11ade921 100644 --- a/chebai/trainer/CustomTrainer.py +++ b/chebai/trainer/CustomTrainer.py @@ -1,23 +1,14 @@ import logging -import os -from typing import Any, List, Optional, Tuple +from typing import Any, Optional, Tuple -import pandas as pd import torch from lightning import Trainer from lightning.fabric.utilities.data import _set_sampler_epoch -from lightning.fabric.utilities.types import _PATH -from lightning.pytorch.cli import instantiate_module from lightning.pytorch.loggers import WandbLogger from lightning.pytorch.loops.fit_loop import _FitLoop from lightning.pytorch.trainer import call -from torch.nn.utils.rnn import pad_sequence -from build.lib.chebai.preprocessing.datasets.base import XYBaseDataModule from chebai.loggers.custom import CustomLogger -from chebai.models.base import ChebaiBaseNet -from chebai.preprocessing.datasets.base import XYBaseDataModule -from chebai.preprocessing.reader import CLS_TOKEN log = logging.getLogger(__name__) @@ -80,94 +71,18 @@ def _resolve_logging_argument(self, key: str, value: Any) -> Tuple[str, Any]: else: return key, value - @torch.no_grad() - def predict_from_file( + def predict( self, - checkpoint_path: _PATH, - input_path: _PATH, - save_to: _PATH = "predictions.csv", - classes_path: Optional[_PATH] = None, - **kwargs, - ) -> None: - """ - Loads a model from a checkpoint and makes predictions on input data from a file. - - Args: - model: The model to use for predictions. - checkpoint_path: Path to the model checkpoint. - input_path: Path to the input file containing SMILES strings. - save_to: Path to save the predictions CSV file. - classes_path: Optional path to a file containing class names (if no class names are provided, columns will be numbered). - """ - with open(input_path, "r") as input: - smiles_strings = [inp.strip() for inp in input.readlines()] - self._predict_smiles( - checkpoint_path, - smiles=smiles_strings, - classes_path=classes_path, - save_to=save_to, - ) - - @torch.no_grad() - def _predict_smiles( - self, - checkpoint_path: _PATH, - smiles: List[str], - classes_path: Optional[_PATH] = None, - save_to: _PATH = "predictions.csv", - **kwargs, - ) -> torch.Tensor: - """ - Predicts the output for a list of SMILES strings using the model. - - Args: - model: The model to use for predictions. - smiles: A list of SMILES strings. - - Returns: - A tensor containing the predictions. - """ - ckpt_file = torch.load( - checkpoint_path, map_location=self.device, weights_only=False - ) - - ckpt_file["datamodule_hyper_parameters"].pop("splits_file_path") - dm: XYBaseDataModule = instantiate_module( - XYBaseDataModule, ckpt_file["datamodule_hyper_parameters"] - ) - - model: ChebaiBaseNet = instantiate_module( - ChebaiBaseNet, ckpt_file["hyper_parameters"] - ) - model.to(self.device) - model.eval() - - parsed_smiles = [dm.reader._read_data(s) for s in smiles] - x = pad_sequence( - [torch.tensor(a, device=self.device) for a in parsed_smiles], - batch_first=True, - ) - cls_tokens = ( - torch.ones(x.shape[0], dtype=torch.int, device=self.device).unsqueeze(-1) - * CLS_TOKEN + model=None, + dataloaders=None, + datamodule=None, + return_predictions=None, + ckpt_path=None, + ): + raise NotImplementedError( + "CustomTrainer.predict is not implemented." + "Use `Prediction.predict_from_file` or `Prediction.predict_smiles` from `chebai/result/prediction.py` instead." ) - features = torch.cat((cls_tokens, x), dim=1) - model_output = model({"features": features}) - preds = torch.sigmoid(model_output["logits"]) - - predictions_df = pd.DataFrame(preds.detach().cpu().numpy()) - - def _add_class_columns(class_file_path: _PATH): - with open(class_file_path, "r") as f: - predictions_df.columns = [cls.strip() for cls in f.readlines()] - - if classes_path is not None: - _add_class_columns(classes_path) - elif os.path.exists(dm.classes_txt_file_path): - _add_class_columns(dm.classes_txt_file_path) - - predictions_df.index = smiles - predictions_df.to_csv(save_to) @property def log_dir(self) -> Optional[str]: From ae47608eb26918d497ef2cd60e750c4c2d585782 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 27 Nov 2025 13:14:18 +0100 Subject: [PATCH 06/18] model hparams for data predict pipeline and vice versa --- README.md | 2 +- chebai/preprocessing/datasets/base.py | 53 +++++++++++++++++---------- chebai/result/prediction.py | 23 ++++++------ 3 files changed, 47 insertions(+), 31 deletions(-) diff --git a/README.md b/README.md index 2555c0a6..713a9d42 100644 --- a/README.md +++ b/README.md @@ -81,7 +81,7 @@ The `classes_path` is the path to the dataset's `raw/classes.txt` file that cont You can evaluate a model trained on the ontology extension task in one of two ways: ### 1. Using the Jupyter Notebook -An example notebook is provided at `tutorials/eval_model_basic.ipynb`. +An example notebook is provided at `tutorials/eval_model_basic.ipynb`. - Load your finetuned model and run the evaluation cells to compute metrics on the test set. ### 2. Using the Lightning CLI diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py index e2df794d..5e3064b3 100644 --- a/chebai/preprocessing/datasets/base.py +++ b/chebai/preprocessing/datasets/base.py @@ -96,9 +96,9 @@ def __init__( self.prediction_kind = prediction_kind self.data_limit = data_limit self.label_filter = label_filter - assert (balance_after_filter is not None) or (self.label_filter is None), ( - "Filter balancing requires a filter" - ) + assert (balance_after_filter is not None) or ( + self.label_filter is None + ), "Filter balancing requires a filter" self.balance_after_filter = balance_after_filter self.num_workers = num_workers self.persistent_workers: bool = bool(persistent_workers) @@ -108,13 +108,13 @@ def __init__( self.use_inner_cross_validation = ( inner_k_folds > 1 ) # only use cv if there are at least 2 folds - assert fold_index is None or self.use_inner_cross_validation is not None, ( - "fold_index can only be set if cross validation is used" - ) + assert ( + fold_index is None or self.use_inner_cross_validation is not None + ), "fold_index can only be set if cross validation is used" if fold_index is not None and self.inner_k_folds is not None: - assert fold_index < self.inner_k_folds, ( - "fold_index can't be larger than the total number of folds" - ) + assert ( + fold_index < self.inner_k_folds + ), "fold_index can't be larger than the total number of folds" self.fold_index = fold_index self._base_dir = base_dir self.n_token_limit = n_token_limit @@ -137,9 +137,9 @@ def num_of_labels(self): @property def feature_vector_size(self): - assert self._feature_vector_size is not None, ( - "size of feature vector must be set" - ) + assert ( + self._feature_vector_size is not None + ), "size of feature vector must be set" return self._feature_vector_size @property @@ -410,6 +410,7 @@ def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader] def predict_dataloader( self, smiles_list: List[str], + model_hparams: Optional[dict] = None, **kwargs, ) -> Union[DataLoader, List[DataLoader]]: """ @@ -423,6 +424,26 @@ def predict_dataloader( Union[DataLoader, List[DataLoader]]: A DataLoader object for prediction data. """ + data = self._process_input_for_prediction(smiles_list, model_hparams) + return DataLoader( + data, + collate_fn=self.reader.collator, + batch_size=self.batch_size, + **kwargs, + ) + + def _process_input_for_prediction( + self, smiles_list: list[str], model_hparams: Optional[dict] = None + ) -> list: + """ + Process input data for prediction. + + Args: + smiles_list (List[str]): List of SMILES strings. + + Returns: + List[Dict[str, Any]]: Processed input data. + """ data = [ self.reader.to_data( {"id": f"smiles_{idx}", "features": smiles, "labels": None} @@ -430,13 +451,7 @@ def predict_dataloader( for idx, smiles in enumerate(smiles_list) ] data = self._filter_to_token_limit(data) - - return DataLoader( - data, - collate_fn=self.reader.collator, - batch_size=self.batch_size, - **kwargs, - ) + return data def prepare_data(self, *args, **kwargs) -> None: if self._prepare_data_flag != 1: diff --git a/chebai/result/prediction.py b/chebai/result/prediction.py index cb3e1415..250d68b6 100644 --- a/chebai/result/prediction.py +++ b/chebai/result/prediction.py @@ -67,25 +67,26 @@ def predict_smiles( checkpoint_path, map_location=self.device, weights_only=False ) - ckpt_file["datamodule_hyper_parameters"].pop("splits_file_path") - dm: XYBaseDataModule = instantiate_module( - XYBaseDataModule, ckpt_file["datamodule_hyper_parameters"] - ) + dm_hparams = ckpt_file["datamodule_hyper_parameters"] + dm_hparams.pop("splits_file_path") + dm: XYBaseDataModule = instantiate_module(XYBaseDataModule, dm_hparams) print(f"Loaded datamodule class: {dm.__class__.__name__}") - pred_dl: DataLoader = dm.predict_dataloader(smiles_list=smiles) - - model: ChebaiBaseNet = instantiate_module( - ChebaiBaseNet, ckpt_file["hyper_parameters"] - ) + model_hparams = ckpt_file["hyper_parameters"] + model: ChebaiBaseNet = instantiate_module(ChebaiBaseNet, model_hparams) model.eval() # model = torch.compile(model) - print(f"Loaded model class: {model.__class__.__name__}") + # For certain data prediction piplines, we may need model hyperparameters + pred_dl: DataLoader = dm.predict_dataloader( + smiles_list=smiles, model_hparams=model_hparams + ) + preds = [] for batch_idx, batch in enumerate(pred_dl): - preds.append(model.predict_step(batch, batch_idx)) + # For certain model prediction pipelines, we may need data module hyperparameters + preds.append(model.predict_step(batch, batch_idx, dm_hparams=dm_hparams)) if not save_to: # If no save path is provided, return the predictions tensor From 517a5a2061927f95a6e2ceb4661d1948e3d1defd Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 27 Nov 2025 20:19:55 +0100 Subject: [PATCH 07/18] fix reader ident error --- chebai/preprocessing/reader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chebai/preprocessing/reader.py b/chebai/preprocessing/reader.py index 22b91a0e..7e41510c 100644 --- a/chebai/preprocessing/reader.py +++ b/chebai/preprocessing/reader.py @@ -51,7 +51,7 @@ def _get_raw_label(self, row: Dict[str, Any]) -> Any: def _get_raw_id(self, row: Dict[str, Any]) -> Any: """Get raw ID from the row.""" - return row.get("ident", row["features"]) + return row.get("ident", row["id"]) def _get_raw_group(self, row: Dict[str, Any]) -> Any: """Get raw group from the row.""" From 40491e55f97d94d75e2f987ffd99f4f68e02f616 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 27 Nov 2025 20:23:37 +0100 Subject: [PATCH 08/18] fix label None error --- chebai/models/base.py | 4 +++- chebai/result/prediction.py | 32 ++++++++++++++++++++++++-------- 2 files changed, 27 insertions(+), 9 deletions(-) diff --git a/chebai/models/base.py b/chebai/models/base.py index 808ea59e..c6c347a4 100644 --- a/chebai/models/base.py +++ b/chebai/models/base.py @@ -235,8 +235,10 @@ def predict_step( assert isinstance(batch, XYData) batch = batch.to(self.device) data = self._process_batch(batch, batch_idx) - labels = data["labels"] model_output = self(data, **data.get("model_kwargs", dict())) + + # Dummy labels to avoid errors in _get_prediction_and_labels + labels = torch.zeros((len(batch), self.out_dim)).to(self.device) pr, _ = self._get_prediction_and_labels(data, labels, model_output) return pr diff --git a/chebai/result/prediction.py b/chebai/result/prediction.py index 250d68b6..6f8e41c1 100644 --- a/chebai/result/prediction.py +++ b/chebai/result/prediction.py @@ -20,21 +20,25 @@ def __init__(self): def predict_from_file( self, checkpoint_path: _PATH, - input_path: _PATH, + smiles_file_path: _PATH, save_to: _PATH = "predictions.csv", classes_path: Optional[_PATH] = None, + batch_size: Optional[int] = None, ) -> None: """ Loads a model from a checkpoint and makes predictions on input data from a file. Args: - model: The model to use for predictions. checkpoint_path: Path to the model checkpoint. - input_path: Path to the input file containing SMILES strings. + smiles_file_path: Path to the input file containing SMILES strings. save_to: Path to save the predictions CSV file. - classes_path: Optional path to a file containing class names (if no class names are provided, columns will be numbered). + classes_path: Optional path to a file containing class names: + if no class names are provided, code will try to get the class path + from the datamodule, else the columns will be numbered. + batch_size: Optional batch size for the DataLoader. If not provided, + the default from the datamodule will be used. """ - with open(input_path, "r") as input: + with open(smiles_file_path, "r") as input: smiles_strings = [inp.strip() for inp in input.readlines()] self.predict_smiles( @@ -42,6 +46,7 @@ def predict_from_file( smiles=smiles_strings, classes_path=classes_path, save_to=save_to, + batch_size=batch_size, ) @torch.inference_mode() @@ -51,16 +56,24 @@ def predict_smiles( smiles: List[str], classes_path: Optional[_PATH] = None, save_to: Optional[_PATH] = None, + batch_size: Optional[int] = None, **kwargs, ) -> torch.Tensor | None: """ Predicts the output for a list of SMILES strings using the model. Args: - model: The model to use for predictions. + checkpoint_path: Path to the model checkpoint. smiles: A list of SMILES strings. - - Returns: + classes_path: Optional path to a file containing class names. If no class + names are provided, code will try to get the class path from the datamodule, + else the columns will be numbered. + save_to: Optional path to save the predictions CSV file. If not provided, + predictions will be returned as a tensor. + batch_size: Optional batch size for the DataLoader. If not provided, the default + from the datamodule will be used. + + Returns: (if save_to is None) A tensor containing the predictions. """ ckpt_file = torch.load( @@ -71,10 +84,13 @@ def predict_smiles( dm_hparams.pop("splits_file_path") dm: XYBaseDataModule = instantiate_module(XYBaseDataModule, dm_hparams) print(f"Loaded datamodule class: {dm.__class__.__name__}") + if batch_size is not None: + dm.batch_size = batch_size model_hparams = ckpt_file["hyper_parameters"] model: ChebaiBaseNet = instantiate_module(ChebaiBaseNet, model_hparams) model.eval() + # TODO: Enable torch.compile when supported # model = torch.compile(model) print(f"Loaded model class: {model.__class__.__name__}") From 7b7e48f7bb581b87f89c162b0b3d4931a134872b Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 27 Nov 2025 20:50:16 +0100 Subject: [PATCH 09/18] fix cli predict_from_file error --- chebai/cli.py | 1 - 1 file changed, 1 deletion(-) diff --git a/chebai/cli.py b/chebai/cli.py index 96262447..de48f615 100644 --- a/chebai/cli.py +++ b/chebai/cli.py @@ -83,7 +83,6 @@ def subcommands() -> Dict[str, Set[str]]: "validate": {"model", "dataloaders", "datamodule"}, "test": {"model", "dataloaders", "datamodule"}, "predict": {"model", "dataloaders", "datamodule"}, - "predict_from_file": {"model"}, } From 6a383177eff88ba768076bd094c660aeeba27102 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 27 Nov 2025 20:50:32 +0100 Subject: [PATCH 10/18] update readme for new prediction method --- README.md | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 713a9d42..7af8aad4 100644 --- a/README.md +++ b/README.md @@ -70,11 +70,19 @@ python -m chebai fit --config=[path-to-your-tox21-config] --trainer.callbacks=co ### Predicting classes given SMILES strings ``` -python3 -m chebai predict_from_file --checkpoint_path=[path-to-model] --input_path=[path-to-file-containing-smiles] [--classes_path=[path-to-classes-file]] [--save_to=[path-to-output]] +python3 chebai/result/prediction.py predict_from_file --checkpoint_path=[path-to-model] ----smiles_file_path=[path-to-file-containing-smiles] [--classes_path=[path-to-classes-file]] [--save_to=[path-to-output]] ``` -The input files should contain a list of line-separated SMILES strings. This generates a CSV file that contains the -one row for each SMILES string and one column for each class. -The `classes_path` is the path to the dataset's `raw/classes.txt` file that contains the relationship between model output and ChEBI-IDs. + +* **`--checkpoint_path`**: Path to the Lightning checkpoint file (must end with `.ckpt`). + +* **`--smiles_file_path`**: Path to a text file containing one SMILES string per line. + +* **`--save_to`** *(optional)*: Predictions will be saved to the path as CSV file. The CSV will contain one row per SMILES string and one column per predicted class. + +* **`--classes_path`** *(optional)*: Path to the dataset’s `raw/classes.txt` file, which maps model output indices to ChEBI IDs. + + * If provided, the CSV columns will be named using the ChEBI IDs. + * If omitted, then script will located the file automatically. If unable to locate then the columns will be numbered sequentially. ## Evaluation From 31b12dbf2ce85b685d2ee8bd721bf4f97e995db6 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 27 Nov 2025 23:26:10 +0100 Subject: [PATCH 11/18] Revert "fix reader ident error" This reverts commit 517a5a2061927f95a6e2ceb4661d1948e3d1defd. --- chebai/preprocessing/reader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chebai/preprocessing/reader.py b/chebai/preprocessing/reader.py index 7e41510c..22b91a0e 100644 --- a/chebai/preprocessing/reader.py +++ b/chebai/preprocessing/reader.py @@ -51,7 +51,7 @@ def _get_raw_label(self, row: Dict[str, Any]) -> Any: def _get_raw_id(self, row: Dict[str, Any]) -> Any: """Get raw ID from the row.""" - return row.get("ident", row["id"]) + return row.get("ident", row["features"]) def _get_raw_group(self, row: Dict[str, Any]) -> Any: """Get raw group from the row.""" From c6e8b6137c6bc0ca0e97481587361a57d3ccbcda Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Fri, 28 Nov 2025 15:42:48 +0100 Subject: [PATCH 12/18] modify pred logic to store model and dm as instance var --- chebai/result/prediction.py | 140 ++++++++++++++++++++---------------- 1 file changed, 78 insertions(+), 62 deletions(-) diff --git a/chebai/result/prediction.py b/chebai/result/prediction.py index 6f8e41c1..b84e59ad 100644 --- a/chebai/result/prediction.py +++ b/chebai/result/prediction.py @@ -13,116 +13,132 @@ class Predictor: - def __init__(self): + def __init__(self, checkpoint_path: _PATH, batch_size: Optional[int] = None): + """Initializes the Predictor with a model loaded from the checkpoint. + + Args: + checkpoint_path: Path to the model checkpoint. + batch_size: Optional batch size for the DataLoader. If not provided, + the default from the datamodule will be used. + """ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - super().__init__() + ckpt_file = torch.load( + checkpoint_path, map_location=self.device, weights_only=False + ) + + self._dm_hparams = ckpt_file["datamodule_hyper_parameters"] + self._dm_hparams.pop("splits_file_path") + self._dm: XYBaseDataModule = instantiate_module( + XYBaseDataModule, self._dm_hparams + ) + print(f"Loaded datamodule class: {self._dm.__class__.__name__}") + if batch_size is not None and int(batch_size) > 0: + self._dm.batch_size = int(batch_size) + + self._model_hparams = ckpt_file["hyper_parameters"] + self._model: ChebaiBaseNet = instantiate_module( + ChebaiBaseNet, self._model_hparams + ) + self._model.eval() + # TODO: Enable torch.compile when supported + # model = torch.compile(model) + print(f"Loaded model class: {self._model.__class__.__name__}") def predict_from_file( self, - checkpoint_path: _PATH, smiles_file_path: _PATH, save_to: _PATH = "predictions.csv", classes_path: Optional[_PATH] = None, - batch_size: Optional[int] = None, ) -> None: """ Loads a model from a checkpoint and makes predictions on input data from a file. Args: - checkpoint_path: Path to the model checkpoint. smiles_file_path: Path to the input file containing SMILES strings. save_to: Path to save the predictions CSV file. classes_path: Optional path to a file containing class names: if no class names are provided, code will try to get the class path from the datamodule, else the columns will be numbered. - batch_size: Optional batch size for the DataLoader. If not provided, - the default from the datamodule will be used. """ with open(smiles_file_path, "r") as input: smiles_strings = [inp.strip() for inp in input.readlines()] - self.predict_smiles( - checkpoint_path, + preds: torch.Tensor = self.predict_smiles( smiles=smiles_strings, classes_path=classes_path, save_to=save_to, - batch_size=batch_size, ) + predictions_df = pd.DataFrame(torch.cat(preds).detach().cpu().numpy()) + + def _add_class_columns(class_file_path: _PATH): + with open(class_file_path, "r") as f: + predictions_df.columns = [cls.strip() for cls in f.readlines()] + + if classes_path is not None: + _add_class_columns(classes_path) + elif os.path.exists(self._dm.classes_txt_file_path): + _add_class_columns(self._dm.classes_txt_file_path) + + predictions_df.index = smiles_strings + predictions_df.to_csv(save_to) + @torch.inference_mode() def predict_smiles( self, - checkpoint_path: _PATH, smiles: List[str], - classes_path: Optional[_PATH] = None, - save_to: Optional[_PATH] = None, - batch_size: Optional[int] = None, - **kwargs, - ) -> torch.Tensor | None: + ) -> torch.Tensor: """ Predicts the output for a list of SMILES strings using the model. Args: - checkpoint_path: Path to the model checkpoint. smiles: A list of SMILES strings. - classes_path: Optional path to a file containing class names. If no class - names are provided, code will try to get the class path from the datamodule, - else the columns will be numbered. - save_to: Optional path to save the predictions CSV file. If not provided, - predictions will be returned as a tensor. - batch_size: Optional batch size for the DataLoader. If not provided, the default - from the datamodule will be used. - - Returns: (if save_to is None) + + Returns: A tensor containing the predictions. """ - ckpt_file = torch.load( - checkpoint_path, map_location=self.device, weights_only=False - ) - - dm_hparams = ckpt_file["datamodule_hyper_parameters"] - dm_hparams.pop("splits_file_path") - dm: XYBaseDataModule = instantiate_module(XYBaseDataModule, dm_hparams) - print(f"Loaded datamodule class: {dm.__class__.__name__}") - if batch_size is not None: - dm.batch_size = batch_size - - model_hparams = ckpt_file["hyper_parameters"] - model: ChebaiBaseNet = instantiate_module(ChebaiBaseNet, model_hparams) - model.eval() - # TODO: Enable torch.compile when supported - # model = torch.compile(model) - print(f"Loaded model class: {model.__class__.__name__}") - # For certain data prediction piplines, we may need model hyperparameters - pred_dl: DataLoader = dm.predict_dataloader( - smiles_list=smiles, model_hparams=model_hparams + pred_dl: DataLoader = self._dm.predict_dataloader( + smiles_list=smiles, model_hparams=self._model_hparams ) preds = [] for batch_idx, batch in enumerate(pred_dl): # For certain model prediction pipelines, we may need data module hyperparameters - preds.append(model.predict_step(batch, batch_idx, dm_hparams=dm_hparams)) + preds.append( + self._model.predict_step(batch, batch_idx, dm_hparams=self._dm_hparams) + ) - if not save_to: - # If no save path is provided, return the predictions tensor - return torch.cat(preds) + return torch.cat(preds) - predictions_df = pd.DataFrame(torch.cat(preds).detach().cpu().numpy()) - - def _add_class_columns(class_file_path: _PATH): - with open(class_file_path, "r") as f: - predictions_df.columns = [cls.strip() for cls in f.readlines()] - if classes_path is not None: - _add_class_columns(classes_path) - elif os.path.exists(dm.classes_txt_file_path): - _add_class_columns(dm.classes_txt_file_path) +class MainPredictor: + @staticmethod + def predict_from_file( + checkpoint_path: _PATH, + smiles_file_path: _PATH, + save_to: _PATH = "predictions.csv", + classes_path: Optional[_PATH] = None, + batch_size: Optional[int] = None, + ) -> None: + predictor = Predictor(checkpoint_path, batch_size) + predictor.predict_from_file( + smiles_file_path, + save_to, + classes_path, + ) - predictions_df.index = smiles - predictions_df.to_csv(save_to) + @staticmethod + def predict_smiles( + checkpoint_path: _PATH, + smiles: List[str], + batch_size: Optional[int] = None, + ) -> torch.Tensor: + predictor = Predictor(checkpoint_path, batch_size) + return predictor.predict_smiles(smiles) if __name__ == "__main__": # python chebai/result/prediction.py predict_from_file --help - CLI(Predictor, as_positional=False) + # python chebai/result/prediction.py predict_smiles --help + CLI(MainPredictor, as_positional=False) From 63670ddada766d84c706752319841e2b7d5d4c60 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 6 Dec 2025 19:32:14 +0100 Subject: [PATCH 13/18] fix for unwanted args to predict_smiles --- chebai/result/prediction.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/chebai/result/prediction.py b/chebai/result/prediction.py index b84e59ad..c558d9c7 100644 --- a/chebai/result/prediction.py +++ b/chebai/result/prediction.py @@ -27,7 +27,7 @@ def __init__(self, checkpoint_path: _PATH, batch_size: Optional[int] = None): ) self._dm_hparams = ckpt_file["datamodule_hyper_parameters"] - self._dm_hparams.pop("splits_file_path") + # self._dm_hparams.pop("splits_file_path") self._dm: XYBaseDataModule = instantiate_module( XYBaseDataModule, self._dm_hparams ) @@ -63,11 +63,7 @@ def predict_from_file( with open(smiles_file_path, "r") as input: smiles_strings = [inp.strip() for inp in input.readlines()] - preds: torch.Tensor = self.predict_smiles( - smiles=smiles_strings, - classes_path=classes_path, - save_to=save_to, - ) + preds: torch.Tensor = self.predict_smiles(smiles=smiles_strings) predictions_df = pd.DataFrame(torch.cat(preds).detach().cpu().numpy()) From d906ad404ecac72fe1236193de462391f9da607b Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 6 Dec 2025 20:22:18 +0100 Subject: [PATCH 14/18] avoid non_null_labels key in loss kwargs --- chebai/preprocessing/datasets/base.py | 5 ++++- chebai/result/prediction.py | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py index 5e3064b3..ee60de99 100644 --- a/chebai/preprocessing/datasets/base.py +++ b/chebai/preprocessing/datasets/base.py @@ -444,9 +444,12 @@ def _process_input_for_prediction( Returns: List[Dict[str, Any]]: Processed input data. """ + # Add dummy labels because the collate function requires them. + # Note: If labels are set to `None`, the collator will insert a `non_null_labels` entry into `loss_kwargs`, + # which later causes `_get_prediction_and_labels` method in the prediction pipeline to treat the data as empty. data = [ self.reader.to_data( - {"id": f"smiles_{idx}", "features": smiles, "labels": None} + {"id": f"smiles_{idx}", "features": smiles, "labels": [1, 2]} ) for idx, smiles in enumerate(smiles_list) ] diff --git a/chebai/result/prediction.py b/chebai/result/prediction.py index c558d9c7..a0a01050 100644 --- a/chebai/result/prediction.py +++ b/chebai/result/prediction.py @@ -65,7 +65,7 @@ def predict_from_file( preds: torch.Tensor = self.predict_smiles(smiles=smiles_strings) - predictions_df = pd.DataFrame(torch.cat(preds).detach().cpu().numpy()) + predictions_df = pd.DataFrame(preds.detach().cpu().numpy()) def _add_class_columns(class_file_path: _PATH): with open(class_file_path, "r") as f: From ba5884a996ed663ff2784d7dd3779c697df1540f Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 6 Dec 2025 20:37:58 +0100 Subject: [PATCH 15/18] compile model --- chebai/result/prediction.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/chebai/result/prediction.py b/chebai/result/prediction.py index a0a01050..212a95e9 100644 --- a/chebai/result/prediction.py +++ b/chebai/result/prediction.py @@ -13,36 +13,47 @@ class Predictor: - def __init__(self, checkpoint_path: _PATH, batch_size: Optional[int] = None): + def __init__( + self, + checkpoint_path: _PATH, + batch_size: Optional[int] = None, + compile_model: bool = True, + ): """Initializes the Predictor with a model loaded from the checkpoint. Args: checkpoint_path: Path to the model checkpoint. batch_size: Optional batch size for the DataLoader. If not provided, the default from the datamodule will be used. + compile_model: Whether to compile the model using torch.compile. Default is True. """ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") ckpt_file = torch.load( checkpoint_path, map_location=self.device, weights_only=False ) + print("-" * 50) + print(f"For Loaded checkpoint from: {checkpoint_path}") + print("Below are the modules loaded from the checkpoint:") self._dm_hparams = ckpt_file["datamodule_hyper_parameters"] # self._dm_hparams.pop("splits_file_path") self._dm: XYBaseDataModule = instantiate_module( XYBaseDataModule, self._dm_hparams ) - print(f"Loaded datamodule class: {self._dm.__class__.__name__}") if batch_size is not None and int(batch_size) > 0: self._dm.batch_size = int(batch_size) + print("*" * 10, f"Loaded datamodule class: {self._dm.__class__.__name__}") self._model_hparams = ckpt_file["hyper_parameters"] self._model: ChebaiBaseNet = instantiate_module( ChebaiBaseNet, self._model_hparams ) + print("*" * 10, f"Loaded model class: {self._model.__class__.__name__}") + + if compile_model: + self._model = torch.compile(self._model) self._model.eval() - # TODO: Enable torch.compile when supported - # model = torch.compile(model) - print(f"Loaded model class: {self._model.__class__.__name__}") + print("-" * 50) def predict_from_file( self, From 5a17f722df61ce6e51b24c475a002e4cfbc75ffc Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 6 Dec 2025 20:43:48 +0100 Subject: [PATCH 16/18] revert the comment line for splits file path --- chebai/result/prediction.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chebai/result/prediction.py b/chebai/result/prediction.py index 212a95e9..e904c57d 100644 --- a/chebai/result/prediction.py +++ b/chebai/result/prediction.py @@ -36,7 +36,7 @@ def __init__( print("Below are the modules loaded from the checkpoint:") self._dm_hparams = ckpt_file["datamodule_hyper_parameters"] - # self._dm_hparams.pop("splits_file_path") + self._dm_hparams.pop("splits_file_path") self._dm: XYBaseDataModule = instantiate_module( XYBaseDataModule, self._dm_hparams ) From bd59d5924f1de6fae1d4a2d54e2cf6a4bb3a67a1 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 13 Dec 2025 00:16:42 +0100 Subject: [PATCH 17/18] handle augment electra and old ckpt files --- chebai/result/prediction.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/chebai/result/prediction.py b/chebai/result/prediction.py index e904c57d..a3b82364 100644 --- a/chebai/result/prediction.py +++ b/chebai/result/prediction.py @@ -37,6 +37,13 @@ def __init__( self._dm_hparams = ckpt_file["datamodule_hyper_parameters"] self._dm_hparams.pop("splits_file_path") + self._dm_hparams.pop("augment_smiles", None) + self._dm_hparams.pop("aug_smiles_variations", None) + assert "_class_path" in self._dm_hparams, ( + "Datamodule hyperparameters must include a '_class_path' key.\n" + "Hence, either the checkpoint is corrupted or " + "it was not saved properly with latest lightning version" + ) self._dm: XYBaseDataModule = instantiate_module( XYBaseDataModule, self._dm_hparams ) From 6b8ae0992e96a4ae77756521b9d5b1c54877d093 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 13 Dec 2025 19:58:48 +0100 Subject: [PATCH 18/18] remove unnec device --- chebai/trainer/CustomTrainer.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/chebai/trainer/CustomTrainer.py b/chebai/trainer/CustomTrainer.py index 11ade921..b9e4b0f3 100644 --- a/chebai/trainer/CustomTrainer.py +++ b/chebai/trainer/CustomTrainer.py @@ -1,7 +1,6 @@ import logging from typing import Any, Optional, Tuple -import torch from lightning import Trainer from lightning.fabric.utilities.data import _set_sampler_epoch from lightning.pytorch.loggers import WandbLogger @@ -40,7 +39,6 @@ def __init__(self, *args, **kwargs): # use custom fit loop (https://lightning.ai/docs/pytorch/LTS/extensions/loops.html#overriding-the-default-loops) self.fit_loop = LoadDataLaterFitLoop(self, self.min_epochs, self.max_epochs) - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def _resolve_logging_argument(self, key: str, value: Any) -> Tuple[str, Any]: """