diff --git a/README.md b/README.md
index eeecd714..7af8aad4 100644
--- a/README.md
+++ b/README.md
@@ -70,18 +70,26 @@ python -m chebai fit --config=[path-to-your-tox21-config] --trainer.callbacks=co
 ### Predicting classes given SMILES strings
 
 ```
-python3 -m chebai predict_from_file --model=[path-to-model-config] --checkpoint_path=[path-to-model] --input_path={path-to-file-containing-smiles] [--classes_path=[path-to-classes-file]] [--save_to=[path-to-output]]
+python3 chebai/result/prediction.py predict_from_file --checkpoint_path=[path-to-model] --smiles_file_path=[path-to-file-containing-smiles] [--classes_path=[path-to-classes-file]] [--save_to=[path-to-output]]
 ```
-The input files should contain a list of line-separated SMILES strings. This generates a CSV file that contains the
-one row for each SMILES string and one column for each class.
-The `classes_path` is the path to the dataset's `raw/classes.txt` file that contains the relationship between model output and ChEBI-IDs.
+
+* **`--checkpoint_path`**: Path to the Lightning checkpoint file (must end with `.ckpt`).
+
+* **`--smiles_file_path`**: Path to a text file containing one SMILES string per line.
+
+* **`--save_to`** *(optional)*: Path to which the predictions are saved as a CSV file. The CSV will contain one row per SMILES string and one column per predicted class.
+
+* **`--classes_path`** *(optional)*: Path to the dataset’s `raw/classes.txt` file, which maps model output indices to ChEBI IDs.
+
+  * If provided, the CSV columns will be named using the ChEBI IDs.
+  * If omitted, the script will try to locate the file automatically. If it cannot be found, the columns will be numbered sequentially.
 
 ## Evaluation
 
 You can evaluate a model trained on the ontology extension task in one of two ways:
 
 ### 1. Using the Jupyter Notebook
-An example notebook is provided at `tutorials/eval_model_basic.ipynb`.
+An example notebook is provided at `tutorials/eval_model_basic.ipynb`.
 - Load your finetuned model and run the evaluation cells to compute metrics on the test set.
 
 ### 2. Using the Lightning CLI
diff --git a/chebai/cli.py b/chebai/cli.py
index 96262447..de48f615 100644
--- a/chebai/cli.py
+++ b/chebai/cli.py
@@ -83,7 +83,6 @@ def subcommands() -> Dict[str, Set[str]]:
         "validate": {"model", "dataloaders", "datamodule"},
         "test": {"model", "dataloaders", "datamodule"},
         "predict": {"model", "dataloaders", "datamodule"},
-        "predict_from_file": {"model"},
     }
diff --git a/chebai/models/base.py b/chebai/models/base.py
index 7653f13c..c6c347a4 100644
--- a/chebai/models/base.py
+++ b/chebai/models/base.py
@@ -232,7 +232,15 @@ def predict_step(
         Returns:
             Dict[str, Union[torch.Tensor, Any]]: The result of the prediction step.
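+
+        Note: Any labels present in the batch are ignored; zero-valued dummy
+        labels are substituted so that ``_get_prediction_and_labels`` can be
+        reused on unlabeled prediction input.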
""" - return self._execute(batch, batch_idx, self.test_metrics, prefix="", log=False) + assert isinstance(batch, XYData) + batch = batch.to(self.device) + data = self._process_batch(batch, batch_idx) + model_output = self(data, **data.get("model_kwargs", dict())) + + # Dummy labels to avoid errors in _get_prediction_and_labels + labels = torch.zeros((len(batch), self.out_dim)).to(self.device) + pr, _ = self._get_prediction_and_labels(data, labels, model_output) + return pr def _execute( self, diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py index 68254007..ee60de99 100644 --- a/chebai/preprocessing/datasets/base.py +++ b/chebai/preprocessing/datasets/base.py @@ -339,8 +339,14 @@ def _load_data_from_file(self, path: str) -> List[Dict[str, Any]]: for d in tqdm.tqdm(self._load_dict(path), total=lines) if d["features"] is not None ] + + return self._filter_to_token_limit(data) + + def _filter_to_token_limit( + self, data: List[Dict[str, Any]] + ) -> List[Dict[str, Any]]: # filter for missing features in resulting data, keep features length below token limit - data = [ + return [ val for val in data if val["features"] is not None @@ -349,8 +355,6 @@ def _load_data_from_file(self, path: str) -> List[Dict[str, Any]]: ) ] - return data - def train_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]: """ Returns the train DataLoader. @@ -400,10 +404,14 @@ def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader] Returns: Union[DataLoader, List[DataLoader]]: A DataLoader object for test data. """ + return self.dataloader("test", shuffle=False, **kwargs) def predict_dataloader( - self, *args, **kwargs + self, + smiles_list: List[str], + model_hparams: Optional[dict] = None, + **kwargs, ) -> Union[DataLoader, List[DataLoader]]: """ Returns the predict DataLoader. @@ -415,7 +423,38 @@ def predict_dataloader( Returns: Union[DataLoader, List[DataLoader]]: A DataLoader object for prediction data. """ - return self.dataloader(self.prediction_kind, shuffle=False, **kwargs) + + data = self._process_input_for_prediction(smiles_list, model_hparams) + return DataLoader( + data, + collate_fn=self.reader.collator, + batch_size=self.batch_size, + **kwargs, + ) + + def _process_input_for_prediction( + self, smiles_list: list[str], model_hparams: Optional[dict] = None + ) -> list: + """ + Process input data for prediction. + + Args: + smiles_list (List[str]): List of SMILES strings. + + Returns: + List[Dict[str, Any]]: Processed input data. + """ + # Add dummy labels because the collate function requires them. + # Note: If labels are set to `None`, the collator will insert a `non_null_labels` entry into `loss_kwargs`, + # which later causes `_get_prediction_and_labels` method in the prediction pipeline to treat the data as empty. 
+        data = [
+            self.reader.to_data(
+                {"id": f"smiles_{idx}", "features": smiles, "labels": [1, 2]}
+            )
+            for idx, smiles in enumerate(smiles_list)
+        ]
+        data = self._filter_to_token_limit(data)
+        return data
 
     def prepare_data(self, *args, **kwargs) -> None:
         if self._prepare_data_flag != 1:
@@ -1190,7 +1229,8 @@ def _retrieve_splits_from_csv(self) -> None:
             print(f"Applying label filter from {self.apply_label_filter}...")
             with open(self.apply_label_filter, "r") as f:
                 label_filter = [line.strip() for line in f]
-            with open(os.path.join(self.processed_dir_main, "classes.txt"), "r") as cf:
+
+            with open(self.classes_txt_file_path, "r") as cf:
                 classes = [line.strip() for line in cf]
             # reorder labels
             old_labels = np.stack(df_data["labels"])
@@ -1315,3 +1355,14 @@ def processed_file_names_dict(self) -> dict:
         if self.n_token_limit is not None:
             return {"data": f"data_maxlen{self.n_token_limit}.pt"}
         return {"data": "data.pt"}
+
+    @property
+    def classes_txt_file_path(self) -> str:
+        """
+        Returns the path to the classes text file.
+
+        Returns:
+            str: The path to the classes text file.
+        """
+        # This property is also used by the predictor in `chebai/result/prediction.py`
+        return os.path.join(self.processed_dir_main, "classes.txt")
diff --git a/chebai/result/prediction.py b/chebai/result/prediction.py
new file mode 100644
index 00000000..a3b82364
--- /dev/null
+++ b/chebai/result/prediction.py
@@ -0,0 +1,158 @@
+import os
+from typing import List, Optional
+
+import pandas as pd
+import torch
+from jsonargparse import CLI
+from lightning.fabric.utilities.types import _PATH
+from lightning.pytorch.cli import instantiate_module
+from torch.utils.data import DataLoader
+
+from chebai.models.base import ChebaiBaseNet
+from chebai.preprocessing.datasets.base import XYBaseDataModule
+
+
+class Predictor:
+    def __init__(
+        self,
+        checkpoint_path: _PATH,
+        batch_size: Optional[int] = None,
+        compile_model: bool = True,
+    ):
+        """Initializes the Predictor with a model loaded from the checkpoint.
+
+        Args:
+            checkpoint_path: Path to the model checkpoint.
+            batch_size: Optional batch size for the DataLoader. If not provided,
+                the default from the datamodule will be used.
+            compile_model: Whether to compile the model using torch.compile. Default is True.
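+                Note: torch.compile requires PyTorch 2.x; consider setting
+                compile_model=False on older PyTorch versions or when debugging.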
+ """ + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + ckpt_file = torch.load( + checkpoint_path, map_location=self.device, weights_only=False + ) + print("-" * 50) + print(f"For Loaded checkpoint from: {checkpoint_path}") + print("Below are the modules loaded from the checkpoint:") + + self._dm_hparams = ckpt_file["datamodule_hyper_parameters"] + self._dm_hparams.pop("splits_file_path") + self._dm_hparams.pop("augment_smiles", None) + self._dm_hparams.pop("aug_smiles_variations", None) + assert "_class_path" in self._dm_hparams, ( + "Datamodule hyperparameters must include a '_class_path' key.\n" + "Hence, either the checkpoint is corrupted or " + "it was not saved properly with latest lightning version" + ) + self._dm: XYBaseDataModule = instantiate_module( + XYBaseDataModule, self._dm_hparams + ) + if batch_size is not None and int(batch_size) > 0: + self._dm.batch_size = int(batch_size) + print("*" * 10, f"Loaded datamodule class: {self._dm.__class__.__name__}") + + self._model_hparams = ckpt_file["hyper_parameters"] + self._model: ChebaiBaseNet = instantiate_module( + ChebaiBaseNet, self._model_hparams + ) + print("*" * 10, f"Loaded model class: {self._model.__class__.__name__}") + + if compile_model: + self._model = torch.compile(self._model) + self._model.eval() + print("-" * 50) + + def predict_from_file( + self, + smiles_file_path: _PATH, + save_to: _PATH = "predictions.csv", + classes_path: Optional[_PATH] = None, + ) -> None: + """ + Loads a model from a checkpoint and makes predictions on input data from a file. + + Args: + smiles_file_path: Path to the input file containing SMILES strings. + save_to: Path to save the predictions CSV file. + classes_path: Optional path to a file containing class names: + if no class names are provided, code will try to get the class path + from the datamodule, else the columns will be numbered. + """ + with open(smiles_file_path, "r") as input: + smiles_strings = [inp.strip() for inp in input.readlines()] + + preds: torch.Tensor = self.predict_smiles(smiles=smiles_strings) + + predictions_df = pd.DataFrame(preds.detach().cpu().numpy()) + + def _add_class_columns(class_file_path: _PATH): + with open(class_file_path, "r") as f: + predictions_df.columns = [cls.strip() for cls in f.readlines()] + + if classes_path is not None: + _add_class_columns(classes_path) + elif os.path.exists(self._dm.classes_txt_file_path): + _add_class_columns(self._dm.classes_txt_file_path) + + predictions_df.index = smiles_strings + predictions_df.to_csv(save_to) + + @torch.inference_mode() + def predict_smiles( + self, + smiles: List[str], + ) -> torch.Tensor: + """ + Predicts the output for a list of SMILES strings using the model. + + Args: + smiles: A list of SMILES strings. + + Returns: + A tensor containing the predictions. 
+ """ + # For certain data prediction piplines, we may need model hyperparameters + pred_dl: DataLoader = self._dm.predict_dataloader( + smiles_list=smiles, model_hparams=self._model_hparams + ) + + preds = [] + for batch_idx, batch in enumerate(pred_dl): + # For certain model prediction pipelines, we may need data module hyperparameters + preds.append( + self._model.predict_step(batch, batch_idx, dm_hparams=self._dm_hparams) + ) + + return torch.cat(preds) + + +class MainPredictor: + @staticmethod + def predict_from_file( + checkpoint_path: _PATH, + smiles_file_path: _PATH, + save_to: _PATH = "predictions.csv", + classes_path: Optional[_PATH] = None, + batch_size: Optional[int] = None, + ) -> None: + predictor = Predictor(checkpoint_path, batch_size) + predictor.predict_from_file( + smiles_file_path, + save_to, + classes_path, + ) + + @staticmethod + def predict_smiles( + checkpoint_path: _PATH, + smiles: List[str], + batch_size: Optional[int] = None, + ) -> torch.Tensor: + predictor = Predictor(checkpoint_path, batch_size) + return predictor.predict_smiles(smiles) + + +if __name__ == "__main__": + # python chebai/result/prediction.py predict_from_file --help + # python chebai/result/prediction.py predict_smiles --help + CLI(MainPredictor, as_positional=False) diff --git a/chebai/trainer/CustomTrainer.py b/chebai/trainer/CustomTrainer.py index f7fbce26..b9e4b0f3 100644 --- a/chebai/trainer/CustomTrainer.py +++ b/chebai/trainer/CustomTrainer.py @@ -1,18 +1,13 @@ import logging -from typing import Any, List, Optional, Tuple +from typing import Any, Optional, Tuple -import pandas as pd -import torch -from lightning import LightningModule, Trainer +from lightning import Trainer from lightning.fabric.utilities.data import _set_sampler_epoch -from lightning.fabric.utilities.types import _PATH from lightning.pytorch.loggers import WandbLogger from lightning.pytorch.loops.fit_loop import _FitLoop from lightning.pytorch.trainer import call -from torch.nn.utils.rnn import pad_sequence from chebai.loggers.custom import CustomLogger -from chebai.preprocessing.reader import CLS_TOKEN, ChemDataReader log = logging.getLogger(__name__) @@ -74,66 +69,18 @@ def _resolve_logging_argument(self, key: str, value: Any) -> Tuple[str, Any]: else: return key, value - def predict_from_file( + def predict( self, - model: LightningModule, - checkpoint_path: _PATH, - input_path: _PATH, - save_to: _PATH = "predictions.csv", - classes_path: Optional[_PATH] = None, - **kwargs, - ) -> None: - """ - Loads a model from a checkpoint and makes predictions on input data from a file. - - Args: - model: The model to use for predictions. - checkpoint_path: Path to the model checkpoint. - input_path: Path to the input file containing SMILES strings. - save_to: Path to save the predictions CSV file. - classes_path: Optional path to a file containing class names (if no class names are provided, columns will be numbered). 
- """ - loaded_model = model.__class__.load_from_checkpoint(checkpoint_path) - with open(input_path, "r") as input: - smiles_strings = [inp.strip() for inp in input.readlines()] - loaded_model.eval() - predictions = self._predict_smiles(loaded_model, smiles_strings) - predictions_df = pd.DataFrame(predictions.detach().cpu().numpy()) - if classes_path is not None: - with open(classes_path, "r") as f: - predictions_df.columns = [cls.strip() for cls in f.readlines()] - predictions_df.index = smiles_strings - predictions_df.to_csv(save_to) - - def _predict_smiles( - self, model: LightningModule, smiles: List[str] - ) -> torch.Tensor: - """ - Predicts the output for a list of SMILES strings using the model. - - Args: - model: The model to use for predictions. - smiles: A list of SMILES strings. - - Returns: - A tensor containing the predictions. - """ - reader = ChemDataReader() - parsed_smiles = [reader._read_data(s) for s in smiles] - x = pad_sequence( - [torch.tensor(a, device=model.device) for a in parsed_smiles], - batch_first=True, + model=None, + dataloaders=None, + datamodule=None, + return_predictions=None, + ckpt_path=None, + ): + raise NotImplementedError( + "CustomTrainer.predict is not implemented." + "Use `Prediction.predict_from_file` or `Prediction.predict_smiles` from `chebai/result/prediction.py` instead." ) - cls_tokens = ( - torch.ones(x.shape[0], dtype=torch.int, device=model.device).unsqueeze(-1) - * CLS_TOKEN - ) - features = torch.cat((cls_tokens, x), dim=1) - model_output = model({"features": features}) - preds = torch.sigmoid(model_output["logits"]) - - print(preds.shape) - return preds @property def log_dir(self) -> Optional[str]: @@ -157,7 +104,6 @@ def log_dir(self) -> Optional[str]: class LoadDataLaterFitLoop(_FitLoop): - def on_advance_start(self) -> None: """Calls the hook ``on_train_epoch_start`` **before** the dataloaders are setup. This is necessary so that the dataloaders can get information from the model. For example: The on_train_epoch_start