From e3297c64ef0d3013520f8a87712e3a26dbf96a98 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 17 Mar 2025 12:52:25 +0100 Subject: [PATCH 01/78] add dummy dataset class - quick testing purpose --- chebai/preprocessing/datasets/_dummy.py | 97 +++++++++++++++++++++++++ configs/data/_dummy.yml | 3 + 2 files changed, 100 insertions(+) create mode 100644 chebai/preprocessing/datasets/_dummy.py create mode 100644 configs/data/_dummy.yml diff --git a/chebai/preprocessing/datasets/_dummy.py b/chebai/preprocessing/datasets/_dummy.py new file mode 100644 index 00000000..11d34862 --- /dev/null +++ b/chebai/preprocessing/datasets/_dummy.py @@ -0,0 +1,97 @@ +# This file is for developers only + +__all__ = [] # Nothing should be imported from this file + + +import random + +import numpy as np +from torch.utils.data import DataLoader, Dataset + +from chebai.preprocessing.datasets import XYBaseDataModule +from chebai.preprocessing.reader import ChemDataReader + + +class _DummyDataModule(XYBaseDataModule): + + READER = ChemDataReader + + def __init__(self, num_of_labels: int, feature_vector_size: int, *args, **kwargs): + super().__init__(*args, **kwargs) + self._num_of_labels = num_of_labels + self._feature_vector_size = feature_vector_size + assert self._num_of_labels is not None + assert self._feature_vector_size is not None + + def prepare_data(self): + pass + + def setup(self, stage=None): + pass + + @property + def num_of_labels(self): + return self._num_of_labels + + @property + def feature_vector_size(self): + return self._feature_vector_size + + def train_dataloader(self, *args, **kwargs) -> DataLoader: + dataset = _DummyDataset(100, self.num_of_labels, self.feature_vector_size) + return DataLoader( + dataset, + collate_fn=self.reader.collator, + batch_size=self.batch_size, + **kwargs, + ) + + def test_dataloader(self, *args, **kwargs) -> DataLoader: + dataset = _DummyDataset(20, self.num_of_labels, self.feature_vector_size) + return DataLoader( + dataset, + collate_fn=self.reader.collator, + batch_size=self.batch_size, + **kwargs, + ) + + def val_dataloader(self, *args, **kwargs) -> DataLoader: + dataset = _DummyDataset(10, self.num_of_labels, self.feature_vector_size) + return DataLoader( + dataset, + collate_fn=self.reader.collator, + batch_size=self.batch_size, + **kwargs, + ) + + @property + def _name(self) -> str: + return "_DummyDataModule" + + +class _DummyDataset(Dataset): + def __init__(self, num_samples: int, num_labels: int, feature_vector_size: int): + self.num_samples = num_samples + self.num_labels = num_labels + self.feature_vector_size = feature_vector_size + + def __len__(self): + return self.num_samples + + def __getitem__(self, idx): + return { + "features": np.random.randint( + 10, 100, size=self.feature_vector_size + ), # Random feature vector + "labels": np.random.choice( + [False, True], size=self.num_labels + ), # Random boolean labels + "ident": random.randint(1, 40000), # Random identifier + "group": None, # Default group value + } + + +if __name__ == "__main__": + dataset = _DummyDataset(num_samples=100, num_labels=5, feature_vector_size=20) + for i in range(10): + print(dataset[i]) diff --git a/configs/data/_dummy.yml b/configs/data/_dummy.yml new file mode 100644 index 00000000..180b6860 --- /dev/null +++ b/configs/data/_dummy.yml @@ -0,0 +1,3 @@ +class_path: chebai.preprocessing.datasets._dummy._DummyDataModule +init_args: + feature_vector_size: 20 From f0e4758d55bbddc74b502ec4d6fc3ff7ed81cf62 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 17 Mar 2025 12:53:00 +0100 Subject: [PATCH 02/78] custom typehints --- chebai/custom_typehints/__init__.py | 3 +++ chebai/custom_typehints/model.py | 7 +++++++ 2 files changed, 10 insertions(+) create mode 100644 chebai/custom_typehints/__init__.py create mode 100644 chebai/custom_typehints/model.py diff --git a/chebai/custom_typehints/__init__.py b/chebai/custom_typehints/__init__.py new file mode 100644 index 00000000..72dce9b3 --- /dev/null +++ b/chebai/custom_typehints/__init__.py @@ -0,0 +1,3 @@ +from .model import ModelConfig + +__all__ = ["ModelConfig"] diff --git a/chebai/custom_typehints/model.py b/chebai/custom_typehints/model.py new file mode 100644 index 00000000..a0de7fe1 --- /dev/null +++ b/chebai/custom_typehints/model.py @@ -0,0 +1,7 @@ +from typing import TypedDict + + +class ModelConfig(TypedDict): + path: str + TPV: float + FPV: float From 4fda5653158e74c8cf09b3986a09f4da131dade2 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 17 Mar 2025 12:55:03 +0100 Subject: [PATCH 03/78] model base: make forward method as abstract method rebase ensemble_fr from protein_prediction to dev --- chebai/models/base.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/chebai/models/base.py b/chebai/models/base.py index 4ba27bbc..412010ad 100644 --- a/chebai/models/base.py +++ b/chebai/models/base.py @@ -1,9 +1,10 @@ import logging -from typing import Any, Dict, Optional, Union, Iterable + +from abc import ABC, abstractmethod +from typing import Any, Dict, Optional, Union import torch from lightning.pytorch.core.module import LightningModule -from torchmetrics import Metric from chebai.preprocessing.structures import XYData @@ -12,7 +13,7 @@ _MODEL_REGISTRY = dict() -class ChebaiBaseNet(LightningModule): +class ChebaiBaseNet(LightningModule, ABC): """ Base class for Chebai neural network models inheriting from PyTorch Lightning's LightningModule. @@ -347,6 +348,7 @@ def _log_metrics(self, prefix: str, metrics: torch.nn.Module, batch_size: int): logger=True, ) + @abstractmethod def forward(self, x: Dict[str, Any]) -> torch.Tensor: """ Defines the forward pass. @@ -357,7 +359,7 @@ def forward(self, x: Dict[str, Any]) -> torch.Tensor: Returns: torch.Tensor: The model output. """ - raise NotImplementedError + pass def configure_optimizers(self, **kwargs) -> torch.optim.Optimizer: """ From 7f7c6a0d414e18cc99331c26e49848d5c6338eee Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 17 Mar 2025 12:55:47 +0100 Subject: [PATCH 04/78] ensemble: abstract code --- chebai/models/ensemble.py | 264 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 264 insertions(+) create mode 100644 chebai/models/ensemble.py diff --git a/chebai/models/ensemble.py b/chebai/models/ensemble.py new file mode 100644 index 00000000..e79f6148 --- /dev/null +++ b/chebai/models/ensemble.py @@ -0,0 +1,264 @@ +import os.path +from abc import ABC, abstractmethod +from typing import Any, Dict, Optional, Union + +import torch +from torch import Tensor + +from chebai.custom_typehints import ModelConfig +from chebai.models import ChebaiBaseNet, Electra +from chebai.preprocessing.structures import XYData + + +class _EnsembleBase(ChebaiBaseNet, ABC): + def __init__(self, model_configs: Dict[str, ModelConfig], **kwargs): + super().__init__(**kwargs) + + self._validate_model_configs(model_configs) + + self.models: Dict[str, ChebaiBaseNet] = {} + self.model_configs: Dict[str, ModelConfig] = model_configs + + for model_name in self.model_configs: + model_path = self.model_configs[model_name]["path"] + if os.path.exists(model_path): + self.models[model_name] = Electra.load_from_checkpoint( + model_path, map_location="cpu" + ) + else: + raise FileNotFoundError( + f"Model {model_name} does not exist in the given path {model_path}" + ) + + for model in self.models.values(): + model.freeze() + + # TODO: Later discuss whether this threshold should be independent of metric threshold or not ? + # if kwargs.get("threshold") is None: + # first_metric_key = next(iter(self.train_metrics)) # Get the first key + # first_metric = self.train_metrics[first_metric_key] # Get the metric object + # self.threshold = int(first_metric.threshold) # Access threshold + # else: + # self.threshold = int(kwargs["threshold"]) + + @classmethod + def _validate_model_configs(cls, model_configs: Dict[str, ModelConfig]): + path_set = set() + required_keys = {"path", "TPV", "FPV"} + + for model_name, config in model_configs.items(): + missing_keys = required_keys - config.keys() + + if missing_keys: + raise AttributeError( + f"Missing keys {missing_keys} in model '{model_name}' configuration." + ) + + model_path = config["path"] + if not os.path.exists(model_path): + raise FileNotFoundError( + f"Model path '{model_path}' for '{model_name}' does not exist." + ) + + # if model_path in path_set: + # raise ValueError( + # f"Duplicate model path detected: '{model_path}'. Each model must have a unique path." + # ) + + path_set.add(model_path) + + # Validate 'tpv' and 'fpv' are either floats or convertible to float + for key in ["TPV", "FPV"]: + try: + value = float(config[key]) + if value < 0: + raise ValueError( + f"'{key}' in model '{model_name}' must be non-negative, but got {value}." + ) + except (TypeError, ValueError): + raise ValueError( + f"'{key}' in model '{model_name}' must be a float or convertible to float, but got {config[key]}." + ) + + @abstractmethod + def _get_prediction_and_labels( + self, data: Dict[str, Any], labels: torch.Tensor, output: torch.Tensor + ) -> (torch.Tensor, torch.Tensor): + pass + + +class ChebiEnsemble(_EnsembleBase): + + NAME = "ChebiEnsemble" + + def __init__(self, model_configs: Dict[str, ModelConfig], **kwargs): + super().__init__(model_configs, **kwargs) + # Add a dummy trainable parameter + self.dummy_param = torch.nn.Parameter(torch.randn(1)) + + def forward(self, data: Dict[str, Tensor], **kwargs: Any) -> Dict[str, Any]: + predictions = {} + confidences = {} + total_logits = torch.zeros( + data["labels"].shape[0], data["labels"].shape[1], device=self.device + ).to(self.device) + + print(data["features"].shape) # Debugging + + for name, model in self.models.items(): + output = model(data) + confidences[name] = torch.sigmoid(output["logits"]) + predictions[name] = ( + torch.sigmoid(output["logits"]) > 0.5 + ).long() # Multi-label classification + total_logits += output["logits"] + + return { + "logits": total_logits, + "pred_dict": predictions, + "conf_dict": confidences, + } + + def _get_prediction_and_labels(self, data, labels, model_output): + d = model_output["logits"] + # Aggregate predictions using weighted voting + metrics_preds = self.aggregate_predictions( + model_output["pred_dict"], model_output["conf_dict"] + ) + loss_kwargs = data.get("loss_kwargs", dict()) + if "non_null_labels" in loss_kwargs: + n = loss_kwargs["non_null_labels"] + d = d[n] + metrics_preds = metrics_preds[n] + return ( + torch.sigmoid(d), + labels.int() if labels is not None else None, + metrics_preds, + ) + + def _execute( + self, + batch: XYData, + batch_idx: int, + metrics: Optional[torch.nn.Module] = None, + prefix: Optional[str] = "", + log: Optional[bool] = True, + sync_dist: Optional[bool] = False, + ) -> Dict[str, Union[torch.Tensor, Any]]: + """ + Executes the model on a batch of data and returns the model output and predictions. + + Args: + batch (XYData): The input batch of data. + batch_idx (int): The index of the current batch. + metrics (torch.nn.Module): A dictionary of metrics to track. + prefix (str, optional): A prefix to add to the metric names. Defaults to "". + log (bool, optional): Whether to log the metrics. Defaults to True. + sync_dist (bool, optional): Whether to synchronize distributed training. Defaults to False. + + Returns: + Dict[str, Union[torch.Tensor, Any]]: A dictionary containing the processed data, labels, model output, + predictions, and loss (if applicable). + """ + assert isinstance(batch, XYData) + batch = batch.to(self.device) + data = self._process_batch(batch, batch_idx) + labels = data["labels"] + model_output = self(data, **data.get("model_kwargs", dict())) + pr, tar, metrics_preds = self._get_prediction_and_labels( + data, labels, model_output + ) + d = dict(data=data, labels=labels, output=model_output, preds=pr) + if log: + if self.criterion is not None: + loss_data, loss_labels, loss_kwargs_candidates = self._process_for_loss( + model_output, labels, data.get("loss_kwargs", dict()) + ) + loss_kwargs = dict() + if self.pass_loss_kwargs: + loss_kwargs = loss_kwargs_candidates + loss = self.criterion(loss_data, loss_labels, **loss_kwargs) + if isinstance(loss, tuple): + loss_additional = loss[1:] + for i, loss_add in enumerate(loss_additional): + self.log( + f"{prefix}loss_{i}", + loss_add if isinstance(loss_add, int) else loss_add.item(), + batch_size=len(batch), + on_step=True, + on_epoch=False, + prog_bar=False, + logger=True, + sync_dist=sync_dist, + ) + loss = loss[0] + + d["loss"] = loss + self.log( + f"{prefix}loss", + loss.item(), + batch_size=len(batch), + on_step=True, + on_epoch=True, + prog_bar=True, + logger=True, + sync_dist=sync_dist, + ) + if metrics and labels is not None: + for metric_name, metric in metrics.items(): + metric.update(metrics_preds, tar) + self._log_metrics(prefix, metrics, len(batch)) + return d + + def aggregate_predictions(self, predictions, confidences): + """Implements weighted voting based on trustworthiness.""" + batch_size, num_classes = list(predictions.values())[0].shape + + true_scores = torch.zeros(batch_size, num_classes, device=self.device) + false_scores = torch.zeros(batch_size, num_classes, device=self.device) + + for model, preds in predictions.items(): + tpv = float(self.model_configs[model]["TPV"]) + npv = float(self.model_configs[model]["FPV"]) + + confidence = confidences[model] + weight = confidence * (tpv * preds + npv * (1 - preds)) + + true_scores += weight * preds + false_scores += weight * (1 - preds) + + return (true_scores > false_scores).long() # Final class decision + + +class ChebiEnsembleLearning(_EnsembleBase): + + NAME = "ChebiEnsembleLearning" + + def __init__(self, model_configs: Dict[str, ModelConfig], **kwargs): + super().__init__(model_configs, **kwargs) + self.ensemble_classifier = torch.nn.Linear( + in_features=len(self.models) * self.out_dim, out_features=self.out_dim + ) + + def forward(self, data: Dict[str, Tensor], **kwargs: Any) -> Dict[str, Any]: + predictions = {} + confidences = {} + + for name, model in self.models.items(): + output = model(data["features"]) + confidence = torch.sigmoid(output) # Assuming confidence scores + predictions[name] = output.argmax(dim=1) # Convert logits to class + confidences[name] = confidence.max(dim=1).values # Max confidence + + # Aggregate predictions using weighted voting + final_preds = self.aggregate_predictions(predictions, confidences) + return final_preds + + def _get_prediction_and_labels( + self, data: Dict[str, Any], labels: torch.Tensor, output: torch.Tensor + ) -> (torch.Tensor, torch.Tensor): + pass + + +if __name__ == "__main__": + pass From 4d3f4f6c9dc33160f79b779486a140d6073ac27f Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 17 Mar 2025 20:18:37 +0100 Subject: [PATCH 05/78] ignore lightning logs --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index f9cb175a..f616e866 100644 --- a/.gitignore +++ b/.gitignore @@ -167,3 +167,4 @@ cython_debug/ /logs /results_buffer electra_pretrained.ckpt +/lightning_logs From 55959de74caf0f5c2fee54121c1c6008778e73b9 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 17 Mar 2025 20:21:26 +0100 Subject: [PATCH 06/78] ensemble: fix for grad runtime error - dummy param should be linked to loss to build gradient graph - Error : element 0 of tensors does not require grad and does not have a grad_fn --- chebai/models/ensemble.py | 31 ++++++++++++++++++++++++++----- 1 file changed, 26 insertions(+), 5 deletions(-) diff --git a/chebai/models/ensemble.py b/chebai/models/ensemble.py index e79f6148..e704ee5b 100644 --- a/chebai/models/ensemble.py +++ b/chebai/models/ensemble.py @@ -1,6 +1,6 @@ import os.path from abc import ABC, abstractmethod -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional, Tuple, Union import torch from torch import Tensor @@ -94,7 +94,7 @@ class ChebiEnsemble(_EnsembleBase): def __init__(self, model_configs: Dict[str, ModelConfig], **kwargs): super().__init__(model_configs, **kwargs) # Add a dummy trainable parameter - self.dummy_param = torch.nn.Parameter(torch.randn(1)) + self.dummy_param = torch.nn.Parameter(torch.randn(1, requires_grad=True)) def forward(self, data: Dict[str, Tensor], **kwargs: Any) -> Dict[str, Any]: predictions = {} @@ -103,8 +103,6 @@ def forward(self, data: Dict[str, Tensor], **kwargs: Any) -> Dict[str, Any]: data["labels"].shape[0], data["labels"].shape[1], device=self.device ).to(self.device) - print(data["features"].shape) # Debugging - for name, model in self.models.items(): output = model(data) confidences[name] = torch.sigmoid(output["logits"]) @@ -193,7 +191,8 @@ def _execute( ) loss = loss[0] - d["loss"] = loss + d["loss"] = loss + 0 * self.dummy_param.sum() + self.log( f"{prefix}loss", loss.item(), @@ -229,6 +228,28 @@ def aggregate_predictions(self, predictions, confidences): return (true_scores > false_scores).long() # Final class decision + def _process_for_loss( + self, + model_output: Dict[str, Tensor], + labels: Tensor, + loss_kwargs: Dict[str, Any], + ) -> Tuple[Tensor, Tensor, Dict[str, Any]]: + """ + Process the model output for calculating the loss. + + Args: + model_output (Dict[str, Tensor]): The output of the model. + labels (Tensor): The target labels. + loss_kwargs (Dict[str, Any]): Additional loss arguments. + + Returns: + tuple: A tuple containing the processed model output, labels, and loss arguments. + """ + kwargs_copy = dict(loss_kwargs) + if labels is not None: + labels = labels.float() + return model_output["logits"], labels, kwargs_copy + class ChebiEnsembleLearning(_EnsembleBase): From 9513fea12716f67e6555240facffeacb33d46b94 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 17 Mar 2025 21:06:33 +0100 Subject: [PATCH 07/78] ensemble: config for ensemble model --- configs/model/ensemble/chebiEnsemble.yml | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 configs/model/ensemble/chebiEnsemble.yml diff --git a/configs/model/ensemble/chebiEnsemble.yml b/configs/model/ensemble/chebiEnsemble.yml new file mode 100644 index 00000000..bb529a5d --- /dev/null +++ b/configs/model/ensemble/chebiEnsemble.yml @@ -0,0 +1,4 @@ +class_path: chebai.models.ensemble.ChebiEnsemble +init_args: + optimizer_kwargs: + lr: 1e-3 From 287538591eee2a1d76f7e229f633d3770dc6e0d1 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Fri, 21 Mar 2025 23:55:51 +0100 Subject: [PATCH 08/78] ensemble: add MLP layer on top ensemble models --- chebai/models/ensemble.py | 70 +++++++++++--------- configs/model/ensemble/ensemble_learning.yml | 4 ++ 2 files changed, 41 insertions(+), 33 deletions(-) create mode 100644 configs/model/ensemble/ensemble_learning.yml diff --git a/chebai/models/ensemble.py b/chebai/models/ensemble.py index e704ee5b..e4a69da5 100644 --- a/chebai/models/ensemble.py +++ b/chebai/models/ensemble.py @@ -14,8 +14,6 @@ class _EnsembleBase(ChebaiBaseNet, ABC): def __init__(self, model_configs: Dict[str, ModelConfig], **kwargs): super().__init__(**kwargs) - self._validate_model_configs(model_configs) - self.models: Dict[str, ChebaiBaseNet] = {} self.model_configs: Dict[str, ModelConfig] = model_configs @@ -41,6 +39,23 @@ def __init__(self, model_configs: Dict[str, ModelConfig], **kwargs): # else: # self.threshold = int(kwargs["threshold"]) + @abstractmethod + def _get_prediction_and_labels( + self, data: Dict[str, Any], labels: torch.Tensor, output: torch.Tensor + ) -> (torch.Tensor, torch.Tensor): + pass + + +class ChebiEnsemble(_EnsembleBase): + + NAME = "ChebiEnsemble" + + def __init__(self, model_configs: Dict[str, ModelConfig], **kwargs): + self._validate_model_configs(model_configs) + super().__init__(model_configs, **kwargs) + # Add a dummy trainable parameter + self.dummy_param = torch.nn.Parameter(torch.randn(1, requires_grad=True)) + @classmethod def _validate_model_configs(cls, model_configs: Dict[str, ModelConfig]): path_set = set() @@ -80,22 +95,6 @@ def _validate_model_configs(cls, model_configs: Dict[str, ModelConfig]): f"'{key}' in model '{model_name}' must be a float or convertible to float, but got {config[key]}." ) - @abstractmethod - def _get_prediction_and_labels( - self, data: Dict[str, Any], labels: torch.Tensor, output: torch.Tensor - ) -> (torch.Tensor, torch.Tensor): - pass - - -class ChebiEnsemble(_EnsembleBase): - - NAME = "ChebiEnsemble" - - def __init__(self, model_configs: Dict[str, ModelConfig], **kwargs): - super().__init__(model_configs, **kwargs) - # Add a dummy trainable parameter - self.dummy_param = torch.nn.Parameter(torch.randn(1, requires_grad=True)) - def forward(self, data: Dict[str, Tensor], **kwargs: Any) -> Dict[str, Any]: predictions = {} confidences = {} @@ -255,30 +254,35 @@ class ChebiEnsembleLearning(_EnsembleBase): NAME = "ChebiEnsembleLearning" - def __init__(self, model_configs: Dict[str, ModelConfig], **kwargs): + def __init__(self, model_configs: Dict[str, Dict], **kwargs): super().__init__(model_configs, **kwargs) - self.ensemble_classifier = torch.nn.Linear( - in_features=len(self.models) * self.out_dim, out_features=self.out_dim - ) - def forward(self, data: Dict[str, Tensor], **kwargs: Any) -> Dict[str, Any]: - predictions = {} - confidences = {} + from chebai.models.ffn import FFN + ffn_kwargs = kwargs.copy() + ffn_kwargs["input_size"] = len(self.model_configs) * int(kwargs["out_dim"]) + self.ffn: FFN = FFN(**ffn_kwargs) + + def forward(self, data: Dict[str, Tensor], **kwargs: Any) -> Dict[str, Any]: + logits_list = [] for name, model in self.models.items(): - output = model(data["features"]) - confidence = torch.sigmoid(output) # Assuming confidence scores - predictions[name] = output.argmax(dim=1) # Convert logits to class - confidences[name] = confidence.max(dim=1).values # Max confidence + output = model(data) + logits_list.append(output["logits"]) - # Aggregate predictions using weighted voting - final_preds = self.aggregate_predictions(predictions, confidences) - return final_preds + return self.ffn({"features": torch.cat(logits_list, dim=1)}) def _get_prediction_and_labels( self, data: Dict[str, Any], labels: torch.Tensor, output: torch.Tensor ) -> (torch.Tensor, torch.Tensor): - pass + return self.ffn._get_prediction_and_labels(data, labels, output) + + def _process_for_loss( + self, + model_output: Dict[str, torch.Tensor], + labels: torch.Tensor, + loss_kwargs: Dict[str, Any], + ) -> (torch.Tensor, torch.Tensor, Dict[str, Any]): + return self.ffn._process_for_loss(model_output, labels, loss_kwargs) if __name__ == "__main__": diff --git a/configs/model/ensemble/ensemble_learning.yml b/configs/model/ensemble/ensemble_learning.yml new file mode 100644 index 00000000..73257b49 --- /dev/null +++ b/configs/model/ensemble/ensemble_learning.yml @@ -0,0 +1,4 @@ +class_path: chebai.models.ensemble.ChebiEnsembleLearning +init_args: + optimizer_kwargs: + lr: 1e-3 From 7f892d9b3012885fd57095c1c2d9a92f02a5bf49 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Fri, 21 Mar 2025 23:56:37 +0100 Subject: [PATCH 09/78] base: fix import --- chebai/models/base.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/chebai/models/base.py b/chebai/models/base.py index 412010ad..070d9a21 100644 --- a/chebai/models/base.py +++ b/chebai/models/base.py @@ -1,7 +1,6 @@ import logging - from abc import ABC, abstractmethod -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Iterable, Optional, Union import torch from lightning.pytorch.core.module import LightningModule From f60b2d8a3764a04d8270f5f272c2c60c6db9a5ef Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 22 Mar 2025 00:32:23 +0100 Subject: [PATCH 10/78] ensemble: code improvements --- chebai/models/ensemble.py | 42 +++++++++++++++++---------------------- 1 file changed, 18 insertions(+), 24 deletions(-) diff --git a/chebai/models/ensemble.py b/chebai/models/ensemble.py index e4a69da5..33d02f78 100644 --- a/chebai/models/ensemble.py +++ b/chebai/models/ensemble.py @@ -19,13 +19,19 @@ def __init__(self, model_configs: Dict[str, ModelConfig], **kwargs): for model_name in self.model_configs: model_path = self.model_configs[model_name]["path"] - if os.path.exists(model_path): + if not os.path.exists(model_path): + raise FileNotFoundError( + f"Model path '{model_path}' for '{model_name}' does not exist." + ) + + # Attempt to load the model to check validity + try: self.models[model_name] = Electra.load_from_checkpoint( - model_path, map_location="cpu" + model_path, map_location=self.device ) - else: - raise FileNotFoundError( - f"Model {model_name} does not exist in the given path {model_path}" + except Exception as e: + raise RuntimeError( + f"Failed to load model '{model_name}' from {model_path}: {e}" ) for model in self.models.values(): @@ -70,10 +76,6 @@ def _validate_model_configs(cls, model_configs: Dict[str, ModelConfig]): ) model_path = config["path"] - if not os.path.exists(model_path): - raise FileNotFoundError( - f"Model path '{model_path}' for '{model_name}' does not exist." - ) # if model_path in path_set: # raise ValueError( @@ -100,14 +102,13 @@ def forward(self, data: Dict[str, Tensor], **kwargs: Any) -> Dict[str, Any]: confidences = {} total_logits = torch.zeros( data["labels"].shape[0], data["labels"].shape[1], device=self.device - ).to(self.device) + ) for name, model in self.models.items(): output = model(data) - confidences[name] = torch.sigmoid(output["logits"]) - predictions[name] = ( - torch.sigmoid(output["logits"]) > 0.5 - ).long() # Multi-label classification + sigmoid_logits = torch.sigmoid(output["logits"]) + confidences[name] = sigmoid_logits + predictions[name] = (sigmoid_logits > 0.5).long() total_logits += output["logits"] return { @@ -211,21 +212,18 @@ def _execute( def aggregate_predictions(self, predictions, confidences): """Implements weighted voting based on trustworthiness.""" batch_size, num_classes = list(predictions.values())[0].shape - true_scores = torch.zeros(batch_size, num_classes, device=self.device) false_scores = torch.zeros(batch_size, num_classes, device=self.device) for model, preds in predictions.items(): tpv = float(self.model_configs[model]["TPV"]) npv = float(self.model_configs[model]["FPV"]) - - confidence = confidences[model] - weight = confidence * (tpv * preds + npv * (1 - preds)) + weight = confidences[model] * (tpv * preds + npv * (1 - preds)) true_scores += weight * preds false_scores += weight * (1 - preds) - return (true_scores > false_scores).long() # Final class decision + return (true_scores > false_scores).long() def _process_for_loss( self, @@ -264,11 +262,7 @@ def __init__(self, model_configs: Dict[str, Dict], **kwargs): self.ffn: FFN = FFN(**ffn_kwargs) def forward(self, data: Dict[str, Tensor], **kwargs: Any) -> Dict[str, Any]: - logits_list = [] - for name, model in self.models.items(): - output = model(data) - logits_list.append(output["logits"]) - + logits_list = [model(data)["logits"] for model in self.models.values()] return self.ffn({"features": torch.cat(logits_list, dim=1)}) def _get_prediction_and_labels( From 72a6b37c36be7df2cbcbd931098cca277545eee1 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 24 Mar 2025 20:35:50 +0100 Subject: [PATCH 11/78] ensemble: add class path to config and load model via this class --- chebai/custom_typehints/__init__.py | 3 - chebai/custom_typehints/model.py | 7 -- chebai/models/ensemble.py | 127 ++++++++++++++++++---------- 3 files changed, 81 insertions(+), 56 deletions(-) delete mode 100644 chebai/custom_typehints/__init__.py delete mode 100644 chebai/custom_typehints/model.py diff --git a/chebai/custom_typehints/__init__.py b/chebai/custom_typehints/__init__.py deleted file mode 100644 index 72dce9b3..00000000 --- a/chebai/custom_typehints/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .model import ModelConfig - -__all__ = ["ModelConfig"] diff --git a/chebai/custom_typehints/model.py b/chebai/custom_typehints/model.py deleted file mode 100644 index a0de7fe1..00000000 --- a/chebai/custom_typehints/model.py +++ /dev/null @@ -1,7 +0,0 @@ -from typing import TypedDict - - -class ModelConfig(TypedDict): - path: str - TPV: float - FPV: float diff --git a/chebai/models/ensemble.py b/chebai/models/ensemble.py index 33d02f78..856b0197 100644 --- a/chebai/models/ensemble.py +++ b/chebai/models/ensemble.py @@ -1,42 +1,54 @@ +import importlib import os.path from abc import ABC, abstractmethod from typing import Any, Dict, Optional, Tuple, Union import torch +from lightning.pytorch import LightningModule from torch import Tensor -from chebai.custom_typehints import ModelConfig -from chebai.models import ChebaiBaseNet, Electra +from chebai.models import ChebaiBaseNet from chebai.preprocessing.structures import XYData class _EnsembleBase(ChebaiBaseNet, ABC): - def __init__(self, model_configs: Dict[str, ModelConfig], **kwargs): + def __init__(self, model_configs: Dict[str, Dict], **kwargs): super().__init__(**kwargs) + self._validate_model_configs(model_configs) - self.models: Dict[str, ChebaiBaseNet] = {} - self.model_configs: Dict[str, ModelConfig] = model_configs + self.models: Dict[str, LightningModule] = {} + self.model_configs = model_configs for model_name in self.model_configs: - model_path = self.model_configs[model_name]["path"] - if not os.path.exists(model_path): + model_ckpt_path = self.model_configs[model_name]["ckpt_path"] + model_class_path = self.model_configs[model_name]["class_path"] + if not os.path.exists(model_ckpt_path): raise FileNotFoundError( - f"Model path '{model_path}' for '{model_name}' does not exist." + f"Model path '{model_ckpt_path}' for '{model_name}' does not exist." ) - # Attempt to load the model to check validity + class_name = model_class_path.split(".")[-1] + module_path = ".".join(model_class_path.split(".")[:-1]) + try: - self.models[model_name] = Electra.load_from_checkpoint( - model_path, map_location=self.device - ) + module = importlib.import_module(module_path) + lightning_cls: LightningModule = getattr(module, class_name) + + model = lightning_cls.load_from_checkpoint(model_ckpt_path) + model.eval() + model.freeze() + self.models[model_name] = model + + except ModuleNotFoundError: + print(f"Module '{module_path}' not found!") + except AttributeError: + print(f"Class '{class_name}' not found in '{module_path}'!") + except Exception as e: raise RuntimeError( - f"Failed to load model '{model_name}' from {model_path}: {e}" + f"Failed to load model '{model_name}' from {model_ckpt_path}: \n {e}" ) - for model in self.models.values(): - model.freeze() - # TODO: Later discuss whether this threshold should be independent of metric threshold or not ? # if kwargs.get("threshold") is None: # first_metric_key = next(iter(self.train_metrics)) # Get the first key @@ -45,27 +57,12 @@ def __init__(self, model_configs: Dict[str, ModelConfig], **kwargs): # else: # self.threshold = int(kwargs["threshold"]) - @abstractmethod - def _get_prediction_and_labels( - self, data: Dict[str, Any], labels: torch.Tensor, output: torch.Tensor - ) -> (torch.Tensor, torch.Tensor): - pass - - -class ChebiEnsemble(_EnsembleBase): - - NAME = "ChebiEnsemble" - - def __init__(self, model_configs: Dict[str, ModelConfig], **kwargs): - self._validate_model_configs(model_configs) - super().__init__(model_configs, **kwargs) - # Add a dummy trainable parameter - self.dummy_param = torch.nn.Parameter(torch.randn(1, requires_grad=True)) - @classmethod - def _validate_model_configs(cls, model_configs: Dict[str, ModelConfig]): + def _validate_model_configs(cls, model_configs: Dict[str, Dict]): path_set = set() - required_keys = {"path", "TPV", "FPV"} + class_set = set() + + required_keys = {"class_path", "ckpt_path"} for model_name, config in model_configs.items(): missing_keys = required_keys - config.keys() @@ -75,27 +72,65 @@ def _validate_model_configs(cls, model_configs: Dict[str, ModelConfig]): f"Missing keys {missing_keys} in model '{model_name}' configuration." ) - model_path = config["path"] + model_path = config["ckpt_path"] + class_path = config["class_path"] # if model_path in path_set: # raise ValueError( # f"Duplicate model path detected: '{model_path}'. Each model must have a unique path." # ) + # if class_path not in class_set: + # raise ValueError( + # f"Duplicate class path detected: '{class_path}'. Each model must have a unique path." + # ) + path_set.add(model_path) + class_set.add(class_path) - # Validate 'tpv' and 'fpv' are either floats or convertible to float - for key in ["TPV", "FPV"]: - try: - value = float(config[key]) - if value < 0: - raise ValueError( - f"'{key}' in model '{model_name}' must be non-negative, but got {value}." - ) - except (TypeError, ValueError): + cls._extra_validation(model_name, config) + + @classmethod + def _extra_validation(cls, model_name: str, config: Dict[str, Any]): + pass + + @abstractmethod + def _get_prediction_and_labels( + self, data: Dict[str, Any], labels: torch.Tensor, output: torch.Tensor + ) -> (torch.Tensor, torch.Tensor): + pass + + +class ChebiEnsemble(_EnsembleBase): + + NAME = "ChebiEnsemble" + + def __init__(self, model_configs: Dict[str, Dict], **kwargs): + super().__init__(model_configs, **kwargs) + + # Add a dummy trainable parameter + self.dummy_param = torch.nn.Parameter(torch.randn(1, requires_grad=True)) + + @classmethod + def _extra_validation(cls, model_name: str, config: Dict[str, Any]): + + if "TPV" not in config.keys() or "FPV" not in config.keys(): + raise AttributeError( + f"Missing keys 'TPV' and/or 'FPV' in model '{model_name}' configuration." + ) + + # Validate 'tpv' and 'fpv' are either floats or convertible to float + for key in ["TPV", "FPV"]: + try: + value = float(config[key]) + if value < 0: raise ValueError( - f"'{key}' in model '{model_name}' must be a float or convertible to float, but got {config[key]}." + f"'{key}' in model '{model_name}' must be non-negative, but got {value}." ) + except (TypeError, ValueError): + raise ValueError( + f"'{key}' in model '{model_name}' must be a float or convertible to float, but got {config[key]}." + ) def forward(self, data: Dict[str, Tensor], **kwargs: Any) -> Dict[str, Any]: predictions = {} From 82a96dca157082e0278d151fb0ea265802f35af5 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Tue, 25 Mar 2025 20:07:02 +0100 Subject: [PATCH 12/78] ensemble: changes for out of scope labels for certain models --- chebai/models/ensemble.py | 95 ++++++++++++++++++++++++++++++++++----- 1 file changed, 83 insertions(+), 12 deletions(-) diff --git a/chebai/models/ensemble.py b/chebai/models/ensemble.py index 856b0197..e526820a 100644 --- a/chebai/models/ensemble.py +++ b/chebai/models/ensemble.py @@ -61,7 +61,9 @@ def __init__(self, model_configs: Dict[str, Dict], **kwargs): def _validate_model_configs(cls, model_configs: Dict[str, Dict]): path_set = set() class_set = set() + labels_set = set() + sets_ = {"path": path_set, "class": class_set, "labels": labels_set} required_keys = {"class_path", "ckpt_path"} for model_name, config in model_configs.items(): @@ -88,10 +90,12 @@ def _validate_model_configs(cls, model_configs: Dict[str, Dict]): path_set.add(model_path) class_set.add(class_path) - cls._extra_validation(model_name, config) + cls._extra_validation(model_name, config, sets_) @classmethod - def _extra_validation(cls, model_name: str, config: Dict[str, Any]): + def _extra_validation( + cls, model_name: str, config: Dict[str, Any], sets_: Dict[str, set] + ): pass @abstractmethod @@ -110,9 +114,23 @@ def __init__(self, model_configs: Dict[str, Dict], **kwargs): # Add a dummy trainable parameter self.dummy_param = torch.nn.Parameter(torch.randn(1, requires_grad=True)) + self._num_models_per_label: Optional[torch.Tensor] = None + self._generate_model_label_mask() @classmethod - def _extra_validation(cls, model_name: str, config: Dict[str, Any]): + def _extra_validation( + cls, model_name: str, config: Dict[str, Any], sets_: Dict[str, set] + ): + + if "labels_path" not in config: + raise AttributeError("Missing 'labels_path' key in config!") + + labels_path = config["labels_path"] + # if labels_path not in sets_["labels"]: + # raise ValueError( + # f"Duplicate labels path detected: '{labels_path}'. Each model must have a unique path." + # ) + sets_["labels"].add(labels_path) if "TPV" not in config.keys() or "FPV" not in config.keys(): raise AttributeError( @@ -132,19 +150,62 @@ def _extra_validation(cls, model_name: str, config: Dict[str, Any]): f"'{key}' in model '{model_name}' must be a float or convertible to float, but got {config[key]}." ) + def _generate_model_label_mask(self): + labels_dict = {} + num_models_per_label = torch.zeros(1, self.out_dim, device=self.device) + for model_name, model_config in self.model_configs.items(): + labels_path = model_config["labels_path"] + if not os.path.exists(labels_path): + raise FileNotFoundError(f"Labels path '{labels_path}' does not exist.") + + with open(labels_path, "r") as f: + labels_list = [int(line.strip()) for line in f] + + model_label_indices = [] + for label in labels_list: + if label not in labels_dict: + labels_dict[label] = len(labels_dict) + + model_label_indices.append(labels_dict[label]) + + # Create masks to apply predictions only to known classes + mask = torch.zeros(self.out_dim, device=self.device, dtype=torch.bool) + mask[ + torch.tensor(model_label_indices, dtype=torch.int, device=self.device) + ] = True + + self.model_configs[model_name]["labels_mask"] = mask + num_models_per_label += mask + + self._num_models_per_label = num_models_per_label + def forward(self, data: Dict[str, Tensor], **kwargs: Any) -> Dict[str, Any]: predictions = {} confidences = {} + + assert data["labels"].shape[1] == self.out_dim + + # Initialize total_logits with zeros total_logits = torch.zeros( - data["labels"].shape[0], data["labels"].shape[1], device=self.device + data["labels"].shape[0], self.out_dim, device=self.device ) for name, model in self.models.items(): output = model(data) + mask = self.model_configs[name]["labels_mask"] + + # Consider logits and confidence only for valid classes sigmoid_logits = torch.sigmoid(output["logits"]) - confidences[name] = sigmoid_logits - predictions[name] = (sigmoid_logits > 0.5).long() - total_logits += output["logits"] + prediction = torch.full_like(total_logits, -1, dtype=torch.bool) + confidence = torch.full_like(total_logits, -1, dtype=torch.float) + prediction[:, mask] = sigmoid_logits > 0.5 + confidence[:, mask] = sigmoid_logits + + predictions[name] = prediction + confidences[name] = confidence + total_logits += output[ + "logits" + ] # Don't play a role here, just for lightning flow completeness return { "logits": total_logits, @@ -250,15 +311,25 @@ def aggregate_predictions(self, predictions, confidences): true_scores = torch.zeros(batch_size, num_classes, device=self.device) false_scores = torch.zeros(batch_size, num_classes, device=self.device) - for model, preds in predictions.items(): + for model, conf in confidences.items(): tpv = float(self.model_configs[model]["TPV"]) npv = float(self.model_configs[model]["FPV"]) - weight = confidences[model] * (tpv * preds + npv * (1 - preds)) - true_scores += weight * preds - false_scores += weight * (1 - preds) + # Determine which classes the model provides predictions for + mask = self.model_configs[model]["labels_mask"] + weight = conf * (tpv * conf + npv * (1 - conf)) + + # Apply mask: Only update scores for valid classes + true_scores += weight * conf * mask + false_scores += weight * (1 - conf) * mask + + # Avoid division by zero: Set valid_counts to 1 where it's zero + valid_counts = self._num_models_per_label.clamp(min=1) + + # Normalize by valid contributions to prevent bias, this step can be optional depending upon scenario + final_preds = (true_scores / valid_counts) > (false_scores / valid_counts) - return (true_scores > false_scores).long() + return final_preds def _process_for_loss( self, From 26f5ab43b34d4b62fb6fcb686029a6ba4b1c3354 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 31 Mar 2025 11:46:47 +0200 Subject: [PATCH 13/78] ensemble: correct confidence val calculation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - confidence=2×∣x−0.5∣ --- chebai/models/ensemble.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chebai/models/ensemble.py b/chebai/models/ensemble.py index e526820a..4c3ad348 100644 --- a/chebai/models/ensemble.py +++ b/chebai/models/ensemble.py @@ -199,7 +199,7 @@ def forward(self, data: Dict[str, Tensor], **kwargs: Any) -> Dict[str, Any]: prediction = torch.full_like(total_logits, -1, dtype=torch.bool) confidence = torch.full_like(total_logits, -1, dtype=torch.float) prediction[:, mask] = sigmoid_logits > 0.5 - confidence[:, mask] = sigmoid_logits + confidence[:, mask] = 2 * torch.abs(sigmoid_logits - 0.5) predictions[name] = prediction confidences[name] = confidence From 9b851c5b37fc03aa450aafadfce100075a77e0ff Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 31 Mar 2025 16:22:31 +0200 Subject: [PATCH 14/78] ensemble: update for tpv/fpv value for each label --- chebai/cli.py | 6 ++ chebai/models/ensemble.py | 149 ++++++++++++++++++++++++-------------- 2 files changed, 102 insertions(+), 53 deletions(-) diff --git a/chebai/cli.py b/chebai/cli.py index b7e78d17..6018699f 100644 --- a/chebai/cli.py +++ b/chebai/cli.py @@ -48,6 +48,12 @@ def add_arguments_to_parser(self, parser: LightningArgumentParser): "model.init_args.out_dim", "trainer.callbacks.init_args.num_labels" ) + parser.link_arguments( + "data.processed_dir_main", + "model.init_args.data_processed_dir_main", + apply_on="instantiate", + ) + @staticmethod def subcommands() -> Dict[str, Set[str]]: """ diff --git a/chebai/models/ensemble.py b/chebai/models/ensemble.py index 4c3ad348..d61c449d 100644 --- a/chebai/models/ensemble.py +++ b/chebai/models/ensemble.py @@ -1,4 +1,5 @@ import importlib +import json import os.path from abc import ABC, abstractmethod from typing import Any, Dict, Optional, Tuple, Union @@ -12,42 +13,19 @@ class _EnsembleBase(ChebaiBaseNet, ABC): - def __init__(self, model_configs: Dict[str, Dict], **kwargs): + def __init__( + self, model_configs: Dict[str, Dict], data_processed_dir_main: str, **kwargs + ): super().__init__(**kwargs) self._validate_model_configs(model_configs) + self.data_processed_dir_main = data_processed_dir_main self.models: Dict[str, LightningModule] = {} self.model_configs = model_configs + self.dm_labels: Dict[str, int] = {} - for model_name in self.model_configs: - model_ckpt_path = self.model_configs[model_name]["ckpt_path"] - model_class_path = self.model_configs[model_name]["class_path"] - if not os.path.exists(model_ckpt_path): - raise FileNotFoundError( - f"Model path '{model_ckpt_path}' for '{model_name}' does not exist." - ) - - class_name = model_class_path.split(".")[-1] - module_path = ".".join(model_class_path.split(".")[:-1]) - - try: - module = importlib.import_module(module_path) - lightning_cls: LightningModule = getattr(module, class_name) - - model = lightning_cls.load_from_checkpoint(model_ckpt_path) - model.eval() - model.freeze() - self.models[model_name] = model - - except ModuleNotFoundError: - print(f"Module '{module_path}' not found!") - except AttributeError: - print(f"Class '{class_name}' not found in '{module_path}'!") - - except Exception as e: - raise RuntimeError( - f"Failed to load model '{model_name}' from {model_ckpt_path}: \n {e}" - ) + self._load_data_module_labels() + self._load_ensemble_models() # TODO: Later discuss whether this threshold should be independent of metric threshold or not ? # if kwargs.get("threshold") is None: @@ -98,6 +76,47 @@ def _extra_validation( ): pass + def _load_ensemble_models(self): + for model_name in self.model_configs: + model_ckpt_path = self.model_configs[model_name]["ckpt_path"] + model_class_path = self.model_configs[model_name]["class_path"] + if not os.path.exists(model_ckpt_path): + raise FileNotFoundError( + f"Model path '{model_ckpt_path}' for '{model_name}' does not exist." + ) + + class_name = model_class_path.split(".")[-1] + module_path = ".".join(model_class_path.split(".")[:-1]) + + try: + module = importlib.import_module(module_path) + lightning_cls: LightningModule = getattr(module, class_name) + + model = lightning_cls.load_from_checkpoint(model_ckpt_path) + model.eval() + model.freeze() + self.models[model_name] = model + + except ModuleNotFoundError: + print(f"Module '{module_path}' not found!") + except AttributeError: + print(f"Class '{class_name}' not found in '{module_path}'!") + + except Exception as e: + raise RuntimeError( + f"Failed to load model '{model_name}' from {model_ckpt_path}: \n {e}" + ) + + def _load_data_module_labels(self): + classes_txt_file = os.path.join(self.data_processed_dir_main, "classes.txt") + if not os.path.exists(classes_txt_file): + raise FileNotFoundError(f"{classes_txt_file} does not exist") + else: + with open(classes_txt_file, "r") as f: + for line in f: + if line.strip() not in self.dm_labels: + self.dm_labels[line.strip()] = len(self.dm_labels) + @abstractmethod def _get_prediction_and_labels( self, data: Dict[str, Any], labels: torch.Tensor, output: torch.Tensor @@ -132,41 +151,49 @@ def _extra_validation( # ) sets_["labels"].add(labels_path) - if "TPV" not in config.keys() or "FPV" not in config.keys(): - raise AttributeError( - f"Missing keys 'TPV' and/or 'FPV' in model '{model_name}' configuration." - ) + with open(labels_path, "r") as f: + model_labels = json.load(f) - # Validate 'tpv' and 'fpv' are either floats or convertible to float - for key in ["TPV", "FPV"]: - try: - value = float(config[key]) - if value < 0: + for label, label_dict in model_labels.items(): + + if "TPV" not in label_dict.keys() or "FPV" not in label_dict.keys(): + raise AttributeError( + f"Missing keys 'TPV' and/or 'FPV' in model '{model_name}' configuration." + ) + + # Validate 'tpv' and 'fpv' are either floats or convertible to float + for key in ["TPV", "FPV"]: + try: + value = float(label_dict[key]) + if value < 0: + raise ValueError( + f"'{key}' in model '{model_name}' and label '{label}' must be non-negative, but got {value}." + ) + except (TypeError, ValueError): raise ValueError( - f"'{key}' in model '{model_name}' must be non-negative, but got {value}." + f"'{key}' in model '{model_name}' and label '{label}' must be a float or convertible to float, but got {label_dict[key]}." ) - except (TypeError, ValueError): - raise ValueError( - f"'{key}' in model '{model_name}' must be a float or convertible to float, but got {config[key]}." - ) def _generate_model_label_mask(self): - labels_dict = {} num_models_per_label = torch.zeros(1, self.out_dim, device=self.device) + for model_name, model_config in self.model_configs.items(): labels_path = model_config["labels_path"] if not os.path.exists(labels_path): raise FileNotFoundError(f"Labels path '{labels_path}' does not exist.") with open(labels_path, "r") as f: - labels_list = [int(line.strip()) for line in f] + labels_dict = json.load(f) - model_label_indices = [] - for label in labels_list: - if label not in labels_dict: - labels_dict[label] = len(labels_dict) + model_label_indices, tpv_label_values, fpv_label_values = [], [], [] + for label in labels_dict.keys(): + if label in self.dm_labels: + model_label_indices.append(self.dm_labels[label]) + tpv_label_values.append(float(labels_dict[label]["TPV"])) + fpv_label_values.append(float(labels_dict[label]["FPV"])) - model_label_indices.append(labels_dict[label]) + if not all([model_label_indices, tpv_label_values, fpv_label_values]): + raise ValueError(f"Values are empty for labels of model {model_name}") # Create masks to apply predictions only to known classes mask = torch.zeros(self.out_dim, device=self.device, dtype=torch.bool) @@ -174,7 +201,23 @@ def _generate_model_label_mask(self): torch.tensor(model_label_indices, dtype=torch.int, device=self.device) ] = True + tpv_tensor = torch.full_like( + mask, -1, dtype=torch.float, device=self.device + ) + fpv_tensor = torch.full_like( + mask, -1, dtype=torch.float, device=self.device + ) + + tpv_tensor[mask] = torch.tensor( + tpv_label_values, dtype=torch.float, device=self.device + ) + fpv_tensor[mask] = torch.tensor( + fpv_label_values, dtype=torch.float, device=self.device + ) + self.model_configs[model_name]["labels_mask"] = mask + self.model_configs[model_name]["tpv_tensor"] = tpv_tensor + self.model_configs[model_name]["fpv_tensor"] = fpv_tensor num_models_per_label += mask self._num_models_per_label = num_models_per_label @@ -312,8 +355,8 @@ def aggregate_predictions(self, predictions, confidences): false_scores = torch.zeros(batch_size, num_classes, device=self.device) for model, conf in confidences.items(): - tpv = float(self.model_configs[model]["TPV"]) - npv = float(self.model_configs[model]["FPV"]) + tpv = self.model_configs[model]["tpv_tensor"] + npv = self.model_configs[model]["fpv_tensor"] # Determine which classes the model provides predictions for mask = self.model_configs[model]["labels_mask"] From 0541ed228b795c42ba5d8a50c23a4371315d43b7 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Tue, 1 Apr 2025 13:17:48 +0200 Subject: [PATCH 15/78] ensemble: add docstrings and typehints --- chebai/models/ensemble.py | 246 +++++++++++++++++++++++++++++++++----- 1 file changed, 214 insertions(+), 32 deletions(-) diff --git a/chebai/models/ensemble.py b/chebai/models/ensemble.py index d61c449d..5c1ca230 100644 --- a/chebai/models/ensemble.py +++ b/chebai/models/ensemble.py @@ -9,13 +9,35 @@ from torch import Tensor from chebai.models import ChebaiBaseNet +from chebai.models.ffn import FFN from chebai.preprocessing.structures import XYData class _EnsembleBase(ChebaiBaseNet, ABC): + """ + Base class for ensemble models in the Chebai framework. + + Inherits from ChebaiBaseNet and provides functionality to load multiple models, + validate configuration, and manage predictions. + + Attributes: + data_processed_dir_main (str): Directory where the processed data is stored. + models (Dict[str, LightningModule]): A dictionary of loaded models. + model_configs (Dict[str, Dict]): Configuration dictionary for models in the ensemble. + dm_labels (Dict[str, int]): Mapping of label names to integer indices. + """ + def __init__( self, model_configs: Dict[str, Dict], data_processed_dir_main: str, **kwargs ): + """ + Initializes the ensemble model and loads configuration, models, and labels. + + Args: + model_configs (Dict[str, Dict]): Dictionary of model configurations. + data_processed_dir_main (str): Path to the processed data directory. + **kwargs: Additional arguments for initialization. + """ super().__init__(**kwargs) self._validate_model_configs(model_configs) @@ -27,16 +49,18 @@ def __init__( self._load_data_module_labels() self._load_ensemble_models() - # TODO: Later discuss whether this threshold should be independent of metric threshold or not ? - # if kwargs.get("threshold") is None: - # first_metric_key = next(iter(self.train_metrics)) # Get the first key - # first_metric = self.train_metrics[first_metric_key] # Get the metric object - # self.threshold = int(first_metric.threshold) # Access threshold - # else: - # self.threshold = int(kwargs["threshold"]) - @classmethod def _validate_model_configs(cls, model_configs: Dict[str, Dict]): + """ + Validates the model configurations to ensure required keys are present. + + Args: + model_configs (Dict[str, Dict]): Dictionary of model configurations. + + Raises: + AttributeError: If required keys are missing in the configuration. + ValueError: If there are duplicate model paths or class paths. + """ path_set = set() class_set = set() labels_set = set() @@ -55,15 +79,15 @@ def _validate_model_configs(cls, model_configs: Dict[str, Dict]): model_path = config["ckpt_path"] class_path = config["class_path"] - # if model_path in path_set: - # raise ValueError( - # f"Duplicate model path detected: '{model_path}'. Each model must have a unique path." - # ) + if model_path in path_set: + raise ValueError( + f"Duplicate model path detected: '{model_path}'. Each model must have a unique path." + ) - # if class_path not in class_set: - # raise ValueError( - # f"Duplicate class path detected: '{class_path}'. Each model must have a unique path." - # ) + if class_path not in class_set: + raise ValueError( + f"Duplicate class path detected: '{class_path}'. Each model must have a unique path." + ) path_set.add(model_path) class_set.add(class_path) @@ -74,9 +98,27 @@ def _validate_model_configs(cls, model_configs: Dict[str, Dict]): def _extra_validation( cls, model_name: str, config: Dict[str, Any], sets_: Dict[str, set] ): + """ + Perform extra validation on the model configuration, if necessary. + + This method can be extended by subclasses to add additional validation logic. + + Args: + model_name (str): The name of the model. + config (Dict[str, Any]): The configuration dictionary for the model. + sets_ (Dict[str, set]): A dictionary of sets to track model paths, class paths, and labels. + """ pass def _load_ensemble_models(self): + """ + Loads the models specified in the configuration and initializes them. + + Raises: + FileNotFoundError: If the model checkpoint path does not exist. + ModuleNotFoundError: If the module containing the model class is not found. + AttributeError: If the specified class is not found within the module. + """ for model_name in self.model_configs: model_ckpt_path = self.model_configs[model_name]["ckpt_path"] model_class_path = self.model_configs[model_name]["class_path"] @@ -108,6 +150,12 @@ def _load_ensemble_models(self): ) def _load_data_module_labels(self): + """ + Loads the label mapping from the classes.txt file for loaded data. + + Raises: + FileNotFoundError: If the classes.txt file does not exist. + """ classes_txt_file = os.path.join(self.data_processed_dir_main, "classes.txt") if not os.path.exists(classes_txt_file): raise FileNotFoundError(f"{classes_txt_file} does not exist") @@ -120,15 +168,43 @@ def _load_data_module_labels(self): @abstractmethod def _get_prediction_and_labels( self, data: Dict[str, Any], labels: torch.Tensor, output: torch.Tensor - ) -> (torch.Tensor, torch.Tensor): + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Abstract method for obtaining predictions and labels. + + Args: + data (Dict[str, Any]): The input data. + labels (torch.Tensor): The target labels. + output (torch.Tensor): The model output. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: The predicted labels and the ground truth labels. + """ pass class ChebiEnsemble(_EnsembleBase): + """ + Ensemble model that aggregates predictions from multiple models for the Chebai task. + + This model combines the outputs of several individual models and aggregates their predictions + using a weighted voting strategy based on trustworthiness (TPV and FPV). This strategy can modified by overriding + `aggregate_predictions` method by subclasses, as per needs. + + There is are relevant trainable parameters for this ensemble model, hence trainer.max_epochs should be set to 1. + `_dummy_param` exists for only lighting module completeness and compatability purpose. + """ NAME = "ChebiEnsemble" def __init__(self, model_configs: Dict[str, Dict], **kwargs): + """ + Initializes the ensemble model and computes the model-label mask. + + Args: + model_configs (Dict[str, Dict]): Dictionary of model configurations. + **kwargs: Additional arguments for initialization. + """ super().__init__(model_configs, **kwargs) # Add a dummy trainable parameter @@ -140,15 +216,27 @@ def __init__(self, model_configs: Dict[str, Dict], **kwargs): def _extra_validation( cls, model_name: str, config: Dict[str, Any], sets_: Dict[str, set] ): + """ + Additional validation for the ensemble model configuration. + Args: + model_name (str): The model name. + config (Dict[str, Any]): The configuration dictionary. + sets_ (Dict[str, set]): The set of paths for labels. + + Raises: + AttributeError: If the 'labels_path' key is missing. + ValueError: If the 'labels_path' contains duplicate entries or certain are not convertible to float. + """ if "labels_path" not in config: raise AttributeError("Missing 'labels_path' key in config!") labels_path = config["labels_path"] - # if labels_path not in sets_["labels"]: - # raise ValueError( - # f"Duplicate labels path detected: '{labels_path}'. Each model must have a unique path." - # ) + if labels_path not in sets_["labels"]: + raise ValueError( + f"Duplicate labels path detected: '{labels_path}'. Each model must have a unique path." + ) + sets_["labels"].add(labels_path) with open(labels_path, "r") as f: @@ -175,6 +263,14 @@ def _extra_validation( ) def _generate_model_label_mask(self): + """ + Generates a mask indicating the labels handled by each model, and retrieves corresponding the TPV and FPV values + as tensors. + + Raises: + FileNotFoundError: If the labels path does not exist. + ValueError: If label values are empty for any model. + """ num_models_per_label = torch.zeros(1, self.out_dim, device=self.device) for model_name, model_config in self.model_configs.items(): @@ -223,6 +319,16 @@ def _generate_model_label_mask(self): self._num_models_per_label = num_models_per_label def forward(self, data: Dict[str, Tensor], **kwargs: Any) -> Dict[str, Any]: + """ + Forward pass through the ensemble model, aggregating predictions from all models. + + Args: + data (Dict[str, Tensor]): Input data including features and labels. + **kwargs: Additional arguments for the forward pass. + + Returns: + Dict[str, Any]: The aggregated logits, predictions, and confidences. + """ predictions = {} confidences = {} @@ -257,6 +363,17 @@ def forward(self, data: Dict[str, Tensor], **kwargs: Any) -> Dict[str, Any]: } def _get_prediction_and_labels(self, data, labels, model_output): + """ + Gets predictions and labels from the model output. + + Args: + data (Dict[str, Any]): The input data. + labels (torch.Tensor): The target labels. + model_output (Dict[str, Tensor]): The model's output. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: The predictions and the ground truth labels. + """ d = model_output["logits"] # Aggregate predictions using weighted voting metrics_preds = self.aggregate_predictions( @@ -348,8 +465,30 @@ def _execute( self._log_metrics(prefix, metrics, len(batch)) return d - def aggregate_predictions(self, predictions, confidences): - """Implements weighted voting based on trustworthiness.""" + def aggregate_predictions( + self, predictions: Dict[str, torch.Tensor], confidences: Dict[str, torch.Tensor] + ) -> torch.Tensor: + """ + Implements weighted voting based on trustworthiness. + + This method aggregates predictions from multiple models using a weighted voting mechanism. + The weight of each model's prediction is determined by its True Positive Value (TPV) and + False Positive Value (FPV), scaled by the confidence score. + + Args: + predictions (Dict[str, torch.Tensor]): + A dictionary mapping model names to their respective binary class predictions + (shape: `[batch_size, num_classes]`). + confidences (Dict[str, torch.Tensor]): + A dictionary mapping model names to their respective confidence scores + (shape: `[batch_size, num_classes]`). + + Returns: + torch.Tensor: + A tensor of final aggregated predictions based on weighted voting + (shape: `[batch_size, num_classes]`), where values are `True` for positive class + and `False` otherwise. + """ batch_size, num_classes = list(predictions.values())[0].shape true_scores = torch.zeros(batch_size, num_classes, device=self.device) false_scores = torch.zeros(batch_size, num_classes, device=self.device) @@ -369,7 +508,7 @@ def aggregate_predictions(self, predictions, confidences): # Avoid division by zero: Set valid_counts to 1 where it's zero valid_counts = self._num_models_per_label.clamp(min=1) - # Normalize by valid contributions to prevent bias, this step can be optional depending upon scenario + # Normalize by valid contributions to prevent bias final_preds = (true_scores / valid_counts) > (false_scores / valid_counts) return final_preds @@ -398,33 +537,76 @@ def _process_for_loss( class ChebiEnsembleLearning(_EnsembleBase): + """ + A specialized ensemble learning class for ChEBI classification. + + This ensemble combines multiple models by concatenating their logits and + passing them through a feedforward neural network (FFN) for final predictions. + """ NAME = "ChebiEnsembleLearning" - def __init__(self, model_configs: Dict[str, Dict], **kwargs): - super().__init__(model_configs, **kwargs) + def __init__(self, model_configs: Dict[str, Dict], **kwargs: Any): + """ + Initializes the ChebiEnsembleLearning class. - from chebai.models.ffn import FFN + Args: + model_configs (Dict[str, Dict]): Configuration dictionary for each model. + **kwargs (Any): Additional keyword arguments for configuring the FFN. + """ + super().__init__(model_configs, **kwargs) ffn_kwargs = kwargs.copy() ffn_kwargs["input_size"] = len(self.model_configs) * int(kwargs["out_dim"]) self.ffn: FFN = FFN(**ffn_kwargs) def forward(self, data: Dict[str, Tensor], **kwargs: Any) -> Dict[str, Any]: + """ + Performs a forward pass through the ensemble model. + + Args: + data (Dict[str, Tensor]): Input data dictionary for the models. + **kwargs (Any): Additional keyword arguments. + + Returns: + Dict[str, Any]: Output from the FFN model. + """ logits_list = [model(data)["logits"] for model in self.models.values()] return self.ffn({"features": torch.cat(logits_list, dim=1)}) def _get_prediction_and_labels( - self, data: Dict[str, Any], labels: torch.Tensor, output: torch.Tensor - ) -> (torch.Tensor, torch.Tensor): + self, data: Dict[str, Any], labels: Tensor, output: Tensor + ) -> Tuple[Tensor, Tensor]: + """ + Extracts predictions and labels for evaluation. + + Args: + data (Dict[str, Any]): Input data dictionary. + labels (Tensor): Ground truth labels. + output (Tensor): Model output. + + Returns: + Tuple[Tensor, Tensor]: Processed predictions and labels. + """ return self.ffn._get_prediction_and_labels(data, labels, output) def _process_for_loss( self, - model_output: Dict[str, torch.Tensor], - labels: torch.Tensor, + model_output: Dict[str, Tensor], + labels: Tensor, loss_kwargs: Dict[str, Any], - ) -> (torch.Tensor, torch.Tensor, Dict[str, Any]): + ) -> Tuple[Tensor, Tensor, Dict[str, Any]]: + """ + Processes model output and labels for computing the loss. + + Args: + model_output (Dict[str, Tensor]): Output dictionary from the model. + labels (Tensor): Ground truth labels. + loss_kwargs (Dict[str, Any]): Additional arguments for loss computation. + + Returns: + Tuple[Tensor, Tensor, Dict[str, Any]]: Loss, processed predictions, and additional info. + """ return self.ffn._process_for_loss(model_output, labels, loss_kwargs) From 3ace30a01c978a56dc8d9e59fd2eb7f4a185dc12 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 4 May 2025 23:21:02 +0200 Subject: [PATCH 16/78] remove optimizer kwargs as not needed --- configs/model/ensemble/chebiEnsemble.yml | 2 -- 1 file changed, 2 deletions(-) diff --git a/configs/model/ensemble/chebiEnsemble.yml b/configs/model/ensemble/chebiEnsemble.yml index bb529a5d..2524a8be 100644 --- a/configs/model/ensemble/chebiEnsemble.yml +++ b/configs/model/ensemble/chebiEnsemble.yml @@ -1,4 +1,2 @@ class_path: chebai.models.ensemble.ChebiEnsemble init_args: - optimizer_kwargs: - lr: 1e-3 From ddcdeac45c47baa4db8c4155fc866f67cc07dace Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 4 May 2025 23:26:08 +0200 Subject: [PATCH 17/78] add template to ensemble config --- configs/model/ensemble/chebiEnsemble.yml | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/configs/model/ensemble/chebiEnsemble.yml b/configs/model/ensemble/chebiEnsemble.yml index 2524a8be..bc4547b0 100644 --- a/configs/model/ensemble/chebiEnsemble.yml +++ b/configs/model/ensemble/chebiEnsemble.yml @@ -1,2 +1,15 @@ class_path: chebai.models.ensemble.ChebiEnsemble init_args: + model_configs: { + "model_1_name": { + "ckpt_path": "path/to/your/model_1/checkpoint.ckpt", + "class_path": "path/to/your/model_1/class", + "labels_path": "path/to/your/model_1/classes.json", + }, + + "model_2_name": { + "ckpt_path": "path/to/your/model_2/checkpoint.ckpt", + "class_path": "path/to/your/model_2/class", + "labels_path": "path/to/your/model_2/classes.json", + }, + } From eb1798ccbb73ded8351096c2383480dbe8a7de54 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 4 May 2025 23:42:09 +0200 Subject: [PATCH 18/78] Update .gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index f616e866..db40440e 100644 --- a/.gitignore +++ b/.gitignore @@ -168,3 +168,4 @@ cython_debug/ /results_buffer electra_pretrained.ckpt /lightning_logs +.isort.cfg From ed92ac56bdf2a5114836e1f39e13175fa00de1a4 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 5 May 2025 19:56:22 +0200 Subject: [PATCH 19/78] each model's each label has TPV, FPV --- chebai/models/ensemble.py | 146 ++++++++++++++------------------------ 1 file changed, 55 insertions(+), 91 deletions(-) diff --git a/chebai/models/ensemble.py b/chebai/models/ensemble.py index 5c1ca230..15ce4e06 100644 --- a/chebai/models/ensemble.py +++ b/chebai/models/ensemble.py @@ -61,12 +61,9 @@ def _validate_model_configs(cls, model_configs: Dict[str, Dict]): AttributeError: If required keys are missing in the configuration. ValueError: If there are duplicate model paths or class paths. """ - path_set = set() - class_set = set() - labels_set = set() + path_set, class_set, labels_set = set(), set(), set() - sets_ = {"path": path_set, "class": class_set, "labels": labels_set} - required_keys = {"class_path", "ckpt_path"} + required_keys = {"class_path", "ckpt_path", "labels_path"} for model_name, config in model_configs.items(): missing_keys = required_keys - config.keys() @@ -78,37 +75,26 @@ def _validate_model_configs(cls, model_configs: Dict[str, Dict]): model_path = config["ckpt_path"] class_path = config["class_path"] + labels_path = config["labels_path"] if model_path in path_set: raise ValueError( - f"Duplicate model path detected: '{model_path}'. Each model must have a unique path." + f"Duplicate model path detected: '{model_path}'. Each model must have a unique model-checkpoint path." ) - if class_path not in class_set: + if class_path in class_set: raise ValueError( - f"Duplicate class path detected: '{class_path}'. Each model must have a unique path." + f"Duplicate class path detected: '{class_path}'. Each model must have a unique class path." + ) + + if labels_path in labels_set: + raise ValueError( + f"Duplicate labels path: {labels_path}. Each model must have unique labels path." ) path_set.add(model_path) class_set.add(class_path) - - cls._extra_validation(model_name, config, sets_) - - @classmethod - def _extra_validation( - cls, model_name: str, config: Dict[str, Any], sets_: Dict[str, set] - ): - """ - Perform extra validation on the model configuration, if necessary. - - This method can be extended by subclasses to add additional validation logic. - - Args: - model_name (str): The name of the model. - config (Dict[str, Any]): The configuration dictionary for the model. - sets_ (Dict[str, set]): A dictionary of sets to track model paths, class paths, and labels. - """ - pass + labels_path.add(labels_path) def _load_ensemble_models(self): """ @@ -122,6 +108,7 @@ def _load_ensemble_models(self): for model_name in self.model_configs: model_ckpt_path = self.model_configs[model_name]["ckpt_path"] model_class_path = self.model_configs[model_name]["class_path"] + model_labels_path = self.model_configs[model_name]["labels_path"] if not os.path.exists(model_ckpt_path): raise FileNotFoundError( f"Model path '{model_ckpt_path}' for '{model_name}' does not exist." @@ -134,10 +121,15 @@ def _load_ensemble_models(self): module = importlib.import_module(module_path) lightning_cls: LightningModule = getattr(module, class_name) - model = lightning_cls.load_from_checkpoint(model_ckpt_path) + model = lightning_cls.load_from_checkpoint( + model_ckpt_path, input_dim=self.input_dim + ) model.eval() model.freeze() self.models[model_name] = model + self.models_configs[model_name]["labels"] = self._load_model_labels( + model_labels_path + ) except ModuleNotFoundError: print(f"Module '{module_path}' not found!") @@ -149,21 +141,37 @@ def _load_ensemble_models(self): f"Failed to load model '{model_name}' from {model_ckpt_path}: \n {e}" ) - def _load_data_module_labels(self): - """ - Loads the label mapping from the classes.txt file for loaded data. + @staticmethod + def _load_model_labels(labels_path: str) -> Dict[str, float]: + if not os.path.exists(labels_path): + raise FileNotFoundError(f"{labels_path} does not exist.") - Raises: - FileNotFoundError: If the classes.txt file does not exist. - """ - classes_txt_file = os.path.join(self.data_processed_dir_main, "classes.txt") - if not os.path.exists(classes_txt_file): - raise FileNotFoundError(f"{classes_txt_file} does not exist") - else: - with open(classes_txt_file, "r") as f: - for line in f: - if line.strip() not in self.dm_labels: - self.dm_labels[line.strip()] = len(self.dm_labels) + if not labels_path.endswith(".json"): + raise TypeError(f"{labels_path} is not a JSON file.") + + with open(labels_path, "r") as f: + model_labels = json.load(f) + + labels_dict = {} + for label, label_dict in model_labels.items(): + msg = f"for label {label}" + if "TPV" not in label_dict.keys() or "FPV" not in label_dict.keys(): + raise AttributeError(f"Missing keys 'TPV' and/or 'FPV' {msg}") + + # Validate 'tpv' and 'fpv' are either floats or convertible to float + for key in ["TPV", "FPV"]: + try: + value = float(label_dict[key]) + if value < 0: + raise ValueError( + f"'{key}' must be non-negative but got {value} {msg}" + ) + except (TypeError, ValueError): + raise ValueError( + f"'{key}' must be a float or convertible to float, but got {label_dict[key]} {msg}" + ) + labels_dict[label][key] = value + return labels_dict @abstractmethod def _get_prediction_and_labels( @@ -182,6 +190,12 @@ def _get_prediction_and_labels( """ pass + def controller(self): + pass + + def consolidator(self): + pass + class ChebiEnsemble(_EnsembleBase): """ @@ -212,56 +226,6 @@ def __init__(self, model_configs: Dict[str, Dict], **kwargs): self._num_models_per_label: Optional[torch.Tensor] = None self._generate_model_label_mask() - @classmethod - def _extra_validation( - cls, model_name: str, config: Dict[str, Any], sets_: Dict[str, set] - ): - """ - Additional validation for the ensemble model configuration. - - Args: - model_name (str): The model name. - config (Dict[str, Any]): The configuration dictionary. - sets_ (Dict[str, set]): The set of paths for labels. - - Raises: - AttributeError: If the 'labels_path' key is missing. - ValueError: If the 'labels_path' contains duplicate entries or certain are not convertible to float. - """ - if "labels_path" not in config: - raise AttributeError("Missing 'labels_path' key in config!") - - labels_path = config["labels_path"] - if labels_path not in sets_["labels"]: - raise ValueError( - f"Duplicate labels path detected: '{labels_path}'. Each model must have a unique path." - ) - - sets_["labels"].add(labels_path) - - with open(labels_path, "r") as f: - model_labels = json.load(f) - - for label, label_dict in model_labels.items(): - - if "TPV" not in label_dict.keys() or "FPV" not in label_dict.keys(): - raise AttributeError( - f"Missing keys 'TPV' and/or 'FPV' in model '{model_name}' configuration." - ) - - # Validate 'tpv' and 'fpv' are either floats or convertible to float - for key in ["TPV", "FPV"]: - try: - value = float(label_dict[key]) - if value < 0: - raise ValueError( - f"'{key}' in model '{model_name}' and label '{label}' must be non-negative, but got {value}." - ) - except (TypeError, ValueError): - raise ValueError( - f"'{key}' in model '{model_name}' and label '{label}' must be a float or convertible to float, but got {label_dict[key]}." - ) - def _generate_model_label_mask(self): """ Generates a mask indicating the labels handled by each model, and retrieves corresponding the TPV and FPV values From 0ec03b10b975b1350f21bd6388bed10688131028 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 15 May 2025 15:21:52 +0200 Subject: [PATCH 20/78] remove ensemble learning class --- chebai/models/ensemble.py | 78 --------------------------------------- 1 file changed, 78 deletions(-) diff --git a/chebai/models/ensemble.py b/chebai/models/ensemble.py index 15ce4e06..25f748c5 100644 --- a/chebai/models/ensemble.py +++ b/chebai/models/ensemble.py @@ -498,81 +498,3 @@ def _process_for_loss( if labels is not None: labels = labels.float() return model_output["logits"], labels, kwargs_copy - - -class ChebiEnsembleLearning(_EnsembleBase): - """ - A specialized ensemble learning class for ChEBI classification. - - This ensemble combines multiple models by concatenating their logits and - passing them through a feedforward neural network (FFN) for final predictions. - """ - - NAME = "ChebiEnsembleLearning" - - def __init__(self, model_configs: Dict[str, Dict], **kwargs: Any): - """ - Initializes the ChebiEnsembleLearning class. - - Args: - model_configs (Dict[str, Dict]): Configuration dictionary for each model. - **kwargs (Any): Additional keyword arguments for configuring the FFN. - """ - super().__init__(model_configs, **kwargs) - - ffn_kwargs = kwargs.copy() - ffn_kwargs["input_size"] = len(self.model_configs) * int(kwargs["out_dim"]) - self.ffn: FFN = FFN(**ffn_kwargs) - - def forward(self, data: Dict[str, Tensor], **kwargs: Any) -> Dict[str, Any]: - """ - Performs a forward pass through the ensemble model. - - Args: - data (Dict[str, Tensor]): Input data dictionary for the models. - **kwargs (Any): Additional keyword arguments. - - Returns: - Dict[str, Any]: Output from the FFN model. - """ - logits_list = [model(data)["logits"] for model in self.models.values()] - return self.ffn({"features": torch.cat(logits_list, dim=1)}) - - def _get_prediction_and_labels( - self, data: Dict[str, Any], labels: Tensor, output: Tensor - ) -> Tuple[Tensor, Tensor]: - """ - Extracts predictions and labels for evaluation. - - Args: - data (Dict[str, Any]): Input data dictionary. - labels (Tensor): Ground truth labels. - output (Tensor): Model output. - - Returns: - Tuple[Tensor, Tensor]: Processed predictions and labels. - """ - return self.ffn._get_prediction_and_labels(data, labels, output) - - def _process_for_loss( - self, - model_output: Dict[str, Tensor], - labels: Tensor, - loss_kwargs: Dict[str, Any], - ) -> Tuple[Tensor, Tensor, Dict[str, Any]]: - """ - Processes model output and labels for computing the loss. - - Args: - model_output (Dict[str, Tensor]): Output dictionary from the model. - labels (Tensor): Ground truth labels. - loss_kwargs (Dict[str, Any]): Additional arguments for loss computation. - - Returns: - Tuple[Tensor, Tensor, Dict[str, Any]]: Loss, processed predictions, and additional info. - """ - return self.ffn._process_for_loss(model_output, labels, loss_kwargs) - - -if __name__ == "__main__": - pass From dabe5ff3cce57901e2079d549a8f6fd681b5f708 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 15 May 2025 22:51:47 +0200 Subject: [PATCH 21/78] update code change --- chebai/models/ensemble.py | 82 +++++++++++++++++++-------------------- 1 file changed, 40 insertions(+), 42 deletions(-) diff --git a/chebai/models/ensemble.py b/chebai/models/ensemble.py index 25f748c5..eaa22bc7 100644 --- a/chebai/models/ensemble.py +++ b/chebai/models/ensemble.py @@ -9,7 +9,6 @@ from torch import Tensor from chebai.models import ChebaiBaseNet -from chebai.models.ffn import FFN from chebai.preprocessing.structures import XYData @@ -39,7 +38,8 @@ def __init__( **kwargs: Additional arguments for initialization. """ super().__init__(**kwargs) - self._validate_model_configs(model_configs) + if kwargs.get("_validate_configs", True): + self._validate_model_configs(model_configs) self.data_processed_dir_main = data_processed_dir_main self.models: Dict[str, LightningModule] = {} @@ -79,7 +79,8 @@ def _validate_model_configs(cls, model_configs: Dict[str, Dict]): if model_path in path_set: raise ValueError( - f"Duplicate model path detected: '{model_path}'. Each model must have a unique model-checkpoint path." + f"Duplicate model path detected: '{model_path}'. " + f"Each model must have a unique model-checkpoint path." ) if class_path in class_set: @@ -94,16 +95,11 @@ def _validate_model_configs(cls, model_configs: Dict[str, Dict]): path_set.add(model_path) class_set.add(class_path) - labels_path.add(labels_path) + labels_set.add(labels_path) def _load_ensemble_models(self): """ Loads the models specified in the configuration and initializes them. - - Raises: - FileNotFoundError: If the model checkpoint path does not exist. - ModuleNotFoundError: If the module containing the model class is not found. - AttributeError: If the specified class is not found within the module. """ for model_name in self.model_configs: model_ckpt_path = self.model_configs[model_name]["ckpt_path"] @@ -116,33 +112,38 @@ def _load_ensemble_models(self): class_name = model_class_path.split(".")[-1] module_path = ".".join(model_class_path.split(".")[:-1]) + module = importlib.import_module(module_path) + lightning_cls: LightningModule = getattr(module, class_name) - try: - module = importlib.import_module(module_path) - lightning_cls: LightningModule = getattr(module, class_name) + model = lightning_cls.load_from_checkpoint( + model_ckpt_path, input_dim=self.input_dim + ) + model.eval() + model.freeze() - model = lightning_cls.load_from_checkpoint( - model_ckpt_path, input_dim=self.input_dim - ) - model.eval() - model.freeze() - self.models[model_name] = model - self.models_configs[model_name]["labels"] = self._load_model_labels( - model_labels_path - ) + self.models[model_name] = model + self.model_configs[model_name]["labels"] = self._load_model_labels( + model_labels_path, model_name + ) - except ModuleNotFoundError: - print(f"Module '{module_path}' not found!") - except AttributeError: - print(f"Class '{class_name}' not found in '{module_path}'!") + def _load_data_module_labels(self): + """ + Loads the label mapping from the classes.txt file for loaded data. - except Exception as e: - raise RuntimeError( - f"Failed to load model '{model_name}' from {model_ckpt_path}: \n {e}" - ) + Raises: + FileNotFoundError: If the classes.txt file does not exist. + """ + classes_txt_file = os.path.join(self.data_processed_dir_main, "classes.txt") + if not os.path.exists(classes_txt_file): + raise FileNotFoundError(f"{classes_txt_file} does not exist") + else: + with open(classes_txt_file, "r") as f: + for line in f: + if line.strip() not in self.dm_labels: + self.dm_labels[line.strip()] = len(self.dm_labels) @staticmethod - def _load_model_labels(labels_path: str) -> Dict[str, float]: + def _load_model_labels(labels_path: str, model_name: str) -> Dict[str, float]: if not os.path.exists(labels_path): raise FileNotFoundError(f"{labels_path} does not exist.") @@ -154,7 +155,7 @@ def _load_model_labels(labels_path: str) -> Dict[str, float]: labels_dict = {} for label, label_dict in model_labels.items(): - msg = f"for label {label}" + msg = f"for model {model_name} for label {label}" if "TPV" not in label_dict.keys() or "FPV" not in label_dict.keys(): raise AttributeError(f"Missing keys 'TPV' and/or 'FPV' {msg}") @@ -170,7 +171,7 @@ def _load_model_labels(labels_path: str) -> Dict[str, float]: raise ValueError( f"'{key}' must be a float or convertible to float, but got {label_dict[key]} {msg}" ) - labels_dict[label][key] = value + labels_dict.setdefault(label, {})[key] = value return labels_dict @abstractmethod @@ -193,7 +194,9 @@ def _get_prediction_and_labels( def controller(self): pass - def consolidator(self): + def consolidator( + self, + ): pass @@ -238,19 +241,14 @@ def _generate_model_label_mask(self): num_models_per_label = torch.zeros(1, self.out_dim, device=self.device) for model_name, model_config in self.model_configs.items(): - labels_path = model_config["labels_path"] - if not os.path.exists(labels_path): - raise FileNotFoundError(f"Labels path '{labels_path}' does not exist.") - - with open(labels_path, "r") as f: - labels_dict = json.load(f) + labels_dict = model_config["labels"] model_label_indices, tpv_label_values, fpv_label_values = [], [], [] for label in labels_dict.keys(): if label in self.dm_labels: model_label_indices.append(self.dm_labels[label]) - tpv_label_values.append(float(labels_dict[label]["TPV"])) - fpv_label_values.append(float(labels_dict[label]["FPV"])) + tpv_label_values.append(labels_dict[label]["TPV"]) + fpv_label_values.append(labels_dict[label]["FPV"]) if not all([model_label_indices, tpv_label_values, fpv_label_values]): raise ValueError(f"Values are empty for labels of model {model_name}") @@ -318,7 +316,7 @@ def forward(self, data: Dict[str, Tensor], **kwargs: Any) -> Dict[str, Any]: confidences[name] = confidence total_logits += output[ "logits" - ] # Don't play a role here, just for lightning flow completeness + ] # This doesn't play a role here, just for lightning flow completeness return { "logits": total_logits, From 7db384ab64ceaf1063390b4080763064c29f277a Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Fri, 16 May 2025 21:26:02 +0200 Subject: [PATCH 22/78] add ensemble base to new python dir --- chebai/ensemble/__init__.py | 0 chebai/ensemble/base.py | 260 ++++++++++++++++++++++++++++++++++++ 2 files changed, 260 insertions(+) create mode 100644 chebai/ensemble/__init__.py create mode 100644 chebai/ensemble/base.py diff --git a/chebai/ensemble/__init__.py b/chebai/ensemble/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/chebai/ensemble/base.py b/chebai/ensemble/base.py new file mode 100644 index 00000000..531f72f3 --- /dev/null +++ b/chebai/ensemble/base.py @@ -0,0 +1,260 @@ +import importlib +import json +import os +from abc import ABC, abstractmethod +from collections import deque +from typing import Deque, Dict, Optional + +import torch +from lightning import LightningModule + +from chebai.models import ChebaiBaseNet + + +class EnsembleBase(ABC): + """ + Base class for ensemble models in the Chebai framework. + + Inherits from ChebaiBaseNet and provides functionality to load multiple models, + validate configuration, and manage predictions. + + Attributes: + data_processed_dir_main (str): Directory where the processed data is stored. + models (Dict[str, LightningModule]): A dictionary of loaded models. + model_configs (Dict[str, Dict]): Configuration dictionary for models in the ensemble. + dm_labels (Dict[str, int]): Mapping of label names to integer indices. + """ + + def __init__( + self, model_configs: Dict[str, Dict], data_processed_dir_main: str, **kwargs + ): + """ + Initializes the ensemble model and loads configuration, models, and labels. + + Args: + model_configs (Dict[str, Dict]): Dictionary of model configurations. + data_processed_dir_main (str): Path to the processed data directory. + **kwargs: Additional arguments for initialization. + """ + if kwargs.get("_validate_configs", False): + self._validate_model_configs(model_configs) + + self.device = "cuda" if torch.cuda.is_available() else "cpu" + self.input_dim = kwargs.get("input_dim", None) + self.num_of_labels: Optional[int] = ( + None # will be set by `_load_data_module_labels` method + ) + self.data_processed_dir_main = data_processed_dir_main + self.models: Dict[str, LightningModule] = {} + self.model_configs = model_configs + self.dm_labels: Dict[str, int] = {} + + self._load_data_module_labels() + self._num_models_per_label: torch.Tensor = torch.zeros( + 1, self.num_of_labels, device=self.device + ) + self._model_queue: Deque = deque() + self._collated_data = None + + @classmethod + def _validate_model_configs(cls, model_configs: Dict[str, Dict]): + """ + Validates the model configurations to ensure required keys are present. + + Args: + model_configs (Dict[str, Dict]): Dictionary of model configurations. + + Raises: + AttributeError: If required keys are missing in the configuration. + ValueError: If there are duplicate model paths or class paths. + """ + path_set, class_set, labels_set = set(), set(), set() + + required_keys = {"class_path", "ckpt_path", "labels_path"} + + for model_name, config in model_configs.items(): + missing_keys = required_keys - config.keys() + + if missing_keys: + raise AttributeError( + f"Missing keys {missing_keys} in model '{model_name}' configuration." + ) + + model_path = config["ckpt_path"] + class_path = config["class_path"] + labels_path = config["labels_path"] + + if model_path in path_set: + raise ValueError( + f"Duplicate model path detected: '{model_path}'. " + f"Each model must have a unique model-checkpoint path." + ) + + if class_path in class_set: + raise ValueError( + f"Duplicate class path detected: '{class_path}'. Each model must have a unique class path." + ) + + if labels_path in labels_set: + raise ValueError( + f"Duplicate labels path: {labels_path}. Each model must have unique labels path." + ) + + path_set.add(model_path) + class_set.add(class_path) + labels_set.add(labels_path) + + def _load_data_module_labels(self): + """ + Loads the label mapping from the classes.txt file for loaded data. + + Raises: + FileNotFoundError: If the classes.txt file does not exist. + """ + classes_txt_file = os.path.join(self.data_processed_dir_main, "classes.txt") + if not os.path.exists(classes_txt_file): + raise FileNotFoundError(f"{classes_txt_file} does not exist") + else: + with open(classes_txt_file, "r") as f: + for line in f: + if line.strip() not in self.dm_labels: + self.dm_labels[line.strip()] = len(self.dm_labels) + self.num_of_labels = len(self.dm_labels) + + def run_ensemble(self): + batch_size = 10 + true_scores = torch.zeros(batch_size, self.num_of_labels, device=self.device) + false_scores = torch.zeros(batch_size, self.num_of_labels, device=self.device) + + while self._model_queue: + model, model_props = self._load_model_and_its_props( + self._model_queue.popleft() + ) + pred_conf_dict = self._controller(model, model_props) + self._consolidator( + pred_conf_dict, + model_props, + true_scores=true_scores, + false_scores=false_scores, + ) + + self._consolidate_on_finish(true_scores=true_scores, false_scores=false_scores) + + def _load_model_and_its_props(self, model_name): + """ + Loads the models specified in the configuration and initializes them. + """ + model_ckpt_path = self.model_configs[model_name]["ckpt_path"] + model_class_path = self.model_configs[model_name]["class_path"] + model_labels_path = self.model_configs[model_name]["labels_path"] + if not os.path.exists(model_ckpt_path): + raise FileNotFoundError( + f"Model path '{model_ckpt_path}' for '{model_name}' does not exist." + ) + + class_name = model_class_path.split(".")[-1] + module_path = ".".join(model_class_path.split(".")[:-1]) + module = importlib.import_module(module_path) + lightning_cls: LightningModule = getattr(module, class_name) + assert isinstance(lightning_cls, type), f"{class_name} is not a class." + assert issubclass( + lightning_cls, ChebaiBaseNet + ), f"{class_name} must inherit from ChebaiBaseNet" + + model = lightning_cls.load_from_checkpoint( + model_ckpt_path, input_dim=self.input_dim + ) + model.eval() + model.freeze() + + model_label_props = self._generate_model_label_props( + model_name, model_labels_path + ) + + return model, model_label_props + + def _generate_model_label_props(self, model_name: str, labels_path: str): + """ + Generates a mask indicating the labels handled by each model, and retrieves corresponding the TPV and FPV values + as tensors. + + Raises: + FileNotFoundError: If the labels path does not exist. + ValueError: If label values are empty for any model. + """ + labels_dict = self._load_model_labels(labels_path) + + model_label_indices, tpv_label_values, fpv_label_values = [], [], [] + for label in labels_dict.keys(): + if label in self.dm_labels: + try: + self._validate_model_labels_json_element(labels_dict[label]) + except Exception as e: + raise Exception(f"Label '{label}' has an unexpected error: {e}") + + model_label_indices.append(self.dm_labels[label]) + tpv_label_values.append(labels_dict[label]["TPV"]) + fpv_label_values.append(labels_dict[label]["FPV"]) + + if not all([model_label_indices, tpv_label_values, fpv_label_values]): + raise ValueError(f"Values are empty for labels of model {model_name}") + + # Create masks to apply predictions only to known classes + mask = torch.zeros(self.num_of_labels, device=self.device, dtype=torch.bool) + mask[torch.tensor(model_label_indices, dtype=torch.int, device=self.device)] = ( + True + ) + + tpv_tensor = torch.full_like(mask, -1, dtype=torch.float, device=self.device) + fpv_tensor = torch.full_like(mask, -1, dtype=torch.float, device=self.device) + + tpv_tensor[mask] = torch.tensor( + tpv_label_values, dtype=torch.float, device=self.device + ) + fpv_tensor[mask] = torch.tensor( + fpv_label_values, dtype=torch.float, device=self.device + ) + self._num_models_per_label += mask + return {"mask": mask, "tpv_tensor": tpv_tensor, "fpv_tensor": fpv_tensor} + + @staticmethod + def _load_model_labels(labels_path: str) -> Dict[str, Dict[str, float]]: + if not os.path.exists(labels_path): + raise FileNotFoundError(f"{labels_path} does not exist.") + + if not labels_path.endswith(".json"): + raise TypeError(f"{labels_path} is not a JSON file.") + + with open(labels_path, "r") as f: + model_labels = json.load(f) + return model_labels + + @staticmethod + def _validate_model_labels_json_element(label_dict: Dict[str, float]): + if "TPV" not in label_dict.keys() or "FPV" not in label_dict.keys(): + raise AttributeError(f"Missing keys 'TPV' and/or 'FPV'") + + # Validate 'tpv' and 'fpv' are either floats or convertible to float + for key in ["TPV", "FPV"]: + try: + value = float(label_dict[key]) + if value < 0: + raise ValueError(f"'{key}' must be non-negative but got {value}") + except (TypeError, ValueError): + raise ValueError( + f"'{key}' must be a float or convertible to float, but got {label_dict[key]}" + ) + + @abstractmethod + def _controller(self, model, model_props, **kwargs): + pass + + @abstractmethod + def _consolidator( + self, pred_conf_dict, model_props, *, true_scores, false_scores, **kwargs + ): + pass + + @abstractmethod + def _consolidate_on_finish(self, *, true_scores, false_scores): + pass From 65a51e0af5076f0fb58e1c7618e18a12a5872338 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Fri, 16 May 2025 21:26:37 +0200 Subject: [PATCH 23/78] add ensemble controller --- chebai/ensemble/controller.py | 58 +++++++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) create mode 100644 chebai/ensemble/controller.py diff --git a/chebai/ensemble/controller.py b/chebai/ensemble/controller.py new file mode 100644 index 00000000..1e76e8a8 --- /dev/null +++ b/chebai/ensemble/controller.py @@ -0,0 +1,58 @@ +import os.path +from abc import ABC + +import torch + +from chebai.ensemble.base import EnsembleBase +from chebai.models import ChebaiBaseNet +from chebai.preprocessing.collate import RaggedCollator + + +class _Controller(EnsembleBase, ABC): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._collator = RaggedCollator() + + self._collated_data = self._load_and_collate_data() + self.total_data_size: int = len(self._collated_data) + + def _load_and_collate_data(self): + data = torch.load( + os.path.join(self.data_processed_dir_main, "data.pt"), + weights_only=False, + map_location=self.device, + ) + collated_data = self._collator(data) + collated_data.x = collated_data.to_x(self.device) + if collated_data.y is not None: + collated_data.y = collated_data.to_y(self.device) + return collated_data + + def _forward_pass(self, model: ChebaiBaseNet): + processable_data = model._process_batch(self._collated_data, 0) + del processable_data["loss_kwargs"] + model_output = model(processable_data, **processable_data["model_kwargs"]) + return model_output + + def _get_pred_conf_from_model_output(self, model_output, model_label_mask): + # Consider logits and confidence only for valid classes + sigmoid_logits = torch.sigmoid(model_output["logits"]) + prediction = torch.full( + (self.total_data_size, self.num_of_labels), -1, dtype=torch.bool + ) + confidence = torch.full( + (self.total_data_size, self.num_of_labels), -1, dtype=torch.float + ) + prediction[:, model_label_mask] = sigmoid_logits > 0.5 + confidence[:, model_label_mask] = 2 * torch.abs(sigmoid_logits - 0.5) + return {"prediction": prediction, "confidence": confidence} + + +class SimpleController(_Controller): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._model_queue = list(self.model_configs.keys()) + + def _controller(self, model, model_props, **kwargs): + model_output = self._forward_pass(model) + return self._get_pred_conf_from_model_output(model_output, model_props["mask"]) From 37d46f76ae42d4e3fac86b5e01f2552b0d2a53bf Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 18 May 2025 15:08:55 +0200 Subject: [PATCH 24/78] add utils.print_metrics to ensemble --- chebai/ensemble/base.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/chebai/ensemble/base.py b/chebai/ensemble/base.py index 531f72f3..b0acd7a4 100644 --- a/chebai/ensemble/base.py +++ b/chebai/ensemble/base.py @@ -9,6 +9,7 @@ from lightning import LightningModule from chebai.models import ChebaiBaseNet +from chebai.result.classification import print_metrics class EnsembleBase(ABC): @@ -39,7 +40,7 @@ def __init__( if kwargs.get("_validate_configs", False): self._validate_model_configs(model_configs) - self.device = "cuda" if torch.cuda.is_available() else "cpu" + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.input_dim = kwargs.get("input_dim", None) self.num_of_labels: Optional[int] = ( None # will be set by `_load_data_module_labels` method @@ -131,6 +132,8 @@ def run_ensemble(self): self._model_queue.popleft() ) pred_conf_dict = self._controller(model, model_props) + del model # Model can be huge to keep it in memory, delete as no longer needed + self._consolidator( pred_conf_dict, model_props, @@ -138,7 +141,15 @@ def run_ensemble(self): false_scores=false_scores, ) - self._consolidate_on_finish(true_scores=true_scores, false_scores=false_scores) + final_preds = self._consolidate_on_finish( + true_scores=true_scores, false_scores=false_scores + ) + print_metrics( + final_preds, + self._collated_data.y, + self.device, + classes=list(self.dm_labels.keys()), + ) def _load_model_and_its_props(self, model_name): """ From bc6e131316b0183814ebb55e32c997634c0614d7 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 18 May 2025 15:09:35 +0200 Subject: [PATCH 25/78] add consolidator --- chebai/ensemble/consolidator.py | 44 +++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) create mode 100644 chebai/ensemble/consolidator.py diff --git a/chebai/ensemble/consolidator.py b/chebai/ensemble/consolidator.py new file mode 100644 index 00000000..0332e17e --- /dev/null +++ b/chebai/ensemble/consolidator.py @@ -0,0 +1,44 @@ +from abc import ABC + +from chebai.ensemble.base import EnsembleBase + + +class WeightedMajorityVoting(EnsembleBase, ABC): + def _consolidator( + self, pred_conf_dict, model_props, *, true_scores, false_scores, **kwargs + ): + tpv = model_props["tpv_tensor"] + npv = model_props["fpv_tensor"] + conf = pred_conf_dict["confidence"] + + # Determine which classes the model provides predictions for + mask = model_props["mask"] + weight = conf * (tpv * conf + npv * (1 - conf)) + + # Apply mask: Only update scores for valid classes + true_scores += weight * conf * mask + false_scores += weight * (1 - conf) * mask + + def _consolidate_on_finish(self, *, true_scores, false_scores): + # Avoid division by zero: Set valid_counts to 1 where it's zero + valid_counts = self._num_models_per_label.clamp(min=1) + + # Normalize by valid contributions to prevent bias + final_preds = (true_scores / valid_counts) > (false_scores / valid_counts) + return final_preds + + +class MajorityVoting(EnsembleBase, ABC): + def _consolidator( + self, pred_conf_dict, model_props, *, true_scores, false_scores, **kwargs + ): + conf = pred_conf_dict["confidence"] + + # Determine which classes the model provides predictions for + mask = model_props["mask"] + # Apply mask: Only update scores for valid classes + true_scores += conf * mask + false_scores += (1 - conf) * mask + + def _consolidate_on_finish(self, *, true_scores, false_scores): + return true_scores > false_scores From b9dbd97cb05b3fe68331a3af2e4f135d8c46e17f Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 18 May 2025 15:10:23 +0200 Subject: [PATCH 26/78] add to needed classes to init --- chebai/ensemble/__init__.py | 11 +++++++++++ chebai/ensemble/controller.py | 2 +- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/chebai/ensemble/__init__.py b/chebai/ensemble/__init__.py index e69de29b..225b5a9b 100644 --- a/chebai/ensemble/__init__.py +++ b/chebai/ensemble/__init__.py @@ -0,0 +1,11 @@ +from .consolidator import WeightedMajorityVoting +from .controller import NoActivationCondition + + +class FullEnsembleWMV(NoActivationCondition, WeightedMajorityVoting): + """Full Ensemble (no activation condition) with Weighted Majority Voting""" + + pass + + +__all__ = ["FullEnsembleWMV"] diff --git a/chebai/ensemble/controller.py b/chebai/ensemble/controller.py index 1e76e8a8..0e1945e0 100644 --- a/chebai/ensemble/controller.py +++ b/chebai/ensemble/controller.py @@ -48,7 +48,7 @@ def _get_pred_conf_from_model_output(self, model_output, model_label_mask): return {"prediction": prediction, "confidence": confidence} -class SimpleController(_Controller): +class NoActivationCondition(_Controller): def __init__(self, **kwargs): super().__init__(**kwargs) self._model_queue = list(self.model_configs.keys()) From 4d6856d9e236e57a03e766d20dd92d2ceb82a65f Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 19 May 2025 01:09:34 +0200 Subject: [PATCH 27/78] add rank_zero_info printing --- chebai/ensemble/base.py | 42 ++++++++++++++++++++++++++--------------- 1 file changed, 27 insertions(+), 15 deletions(-) diff --git a/chebai/ensemble/base.py b/chebai/ensemble/base.py index b0acd7a4..ba6952a8 100644 --- a/chebai/ensemble/base.py +++ b/chebai/ensemble/base.py @@ -7,6 +7,7 @@ import torch from lightning import LightningModule +from lightning_utilities.core.rank_zero import rank_zero_info from chebai.models import ChebaiBaseNet from chebai.result.classification import print_metrics @@ -37,7 +38,7 @@ def __init__( data_processed_dir_main (str): Path to the processed data directory. **kwargs: Additional arguments for initialization. """ - if kwargs.get("_validate_configs", False): + if bool(kwargs.get("_validate_configs", True)): self._validate_model_configs(model_configs) self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -113,6 +114,8 @@ def _load_data_module_labels(self): FileNotFoundError: If the classes.txt file does not exist. """ classes_txt_file = os.path.join(self.data_processed_dir_main, "classes.txt") + rank_zero_info(f"Loading {classes_txt_file} ....") + if not os.path.exists(classes_txt_file): raise FileNotFoundError(f"{classes_txt_file} does not exist") else: @@ -128,12 +131,15 @@ def run_ensemble(self): false_scores = torch.zeros(batch_size, self.num_of_labels, device=self.device) while self._model_queue: - model, model_props = self._load_model_and_its_props( - self._model_queue.popleft() - ) + model_name = self._model_queue.popleft() + rank_zero_info(f"Processing model: {model_name}") + model, model_props = self._load_model_and_its_props(model_name) + + rank_zero_info("\t Passing model to controller to generate predictions...") pred_conf_dict = self._controller(model, model_props) del model # Model can be huge to keep it in memory, delete as no longer needed + rank_zero_info("\t Passing predictions to consolidator to aggregation") self._consolidator( pred_conf_dict, model_props, @@ -141,6 +147,9 @@ def run_ensemble(self): false_scores=false_scores, ) + rank_zero_info( + f"Consolidate predictions of the ensemble: {self.__class__.__name__}" + ) final_preds = self._consolidate_on_finish( true_scores=true_scores, false_scores=false_scores ) @@ -172,19 +181,21 @@ def _load_model_and_its_props(self, model_name): lightning_cls, ChebaiBaseNet ), f"{class_name} must inherit from ChebaiBaseNet" - model = lightning_cls.load_from_checkpoint( - model_ckpt_path, input_dim=self.input_dim - ) - model.eval() - model.freeze() - - model_label_props = self._generate_model_label_props( - model_name, model_labels_path - ) + try: + model = lightning_cls.load_from_checkpoint( + model_ckpt_path, input_dim=self.input_dim + ) + model.eval() + model.freeze() + model_label_props = self._generate_model_label_props(model_labels_path) + except Exception as e: + raise RuntimeError( + f"For model {model_name} following exception as occurred \n Error: {e}" + ) return model, model_label_props - def _generate_model_label_props(self, model_name: str, labels_path: str): + def _generate_model_label_props(self, labels_path: str): """ Generates a mask indicating the labels handled by each model, and retrieves corresponding the TPV and FPV values as tensors. @@ -193,6 +204,7 @@ def _generate_model_label_props(self, model_name: str, labels_path: str): FileNotFoundError: If the labels path does not exist. ValueError: If label values are empty for any model. """ + rank_zero_info("\t Generating mask model's labels and other properties") labels_dict = self._load_model_labels(labels_path) model_label_indices, tpv_label_values, fpv_label_values = [], [], [] @@ -208,7 +220,7 @@ def _generate_model_label_props(self, model_name: str, labels_path: str): fpv_label_values.append(labels_dict[label]["FPV"]) if not all([model_label_indices, tpv_label_values, fpv_label_values]): - raise ValueError(f"Values are empty for labels of model {model_name}") + raise ValueError(f"Values are empty for labels of the model") # Create masks to apply predictions only to known classes mask = torch.zeros(self.num_of_labels, device=self.device, dtype=torch.bool) From 825916e7c950689c33e54198155aa62071790ec8 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 19 May 2025 01:10:55 +0200 Subject: [PATCH 28/78] add script for running ensemble --- chebai/ensemble/ensemble_run_script.py | 39 ++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 chebai/ensemble/ensemble_run_script.py diff --git a/chebai/ensemble/ensemble_run_script.py b/chebai/ensemble/ensemble_run_script.py new file mode 100644 index 00000000..127cc687 --- /dev/null +++ b/chebai/ensemble/ensemble_run_script.py @@ -0,0 +1,39 @@ +import importlib + +import yaml +from jsonargparse import ArgumentParser + +from chebai.ensemble.base import EnsembleBase + + +def load_class(class_path: str): + """Dynamically import a class from a full dotted path.""" + module_path, class_name = class_path.rsplit(".", 1) + module = importlib.import_module(module_path) + return getattr(module, class_name) + + +def load_config_and_instantiate(config_path: str): + with open(config_path, "r") as f: + config = yaml.safe_load(f) + + class_path = config["class_path"] + init_args = config.get("init_args", {}) + + cls = load_class(class_path) + if not issubclass(cls, EnsembleBase): + raise TypeError(f"{cls} must be subclass of EnsembleBase") + return cls(**init_args) + + +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument("--config", type=str, help="Path to the YAML config file") + + args = parser.parse_args() + ensemble = load_config_and_instantiate(args.config) + + if not isinstance(ensemble, EnsembleBase): + raise TypeError("Object must be an instance of `EnsembleBase`") + + ensemble.run_ensemble() From 69c52634c5460b7e07e2004c569569a73bb58185 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 19 May 2025 01:11:36 +0200 Subject: [PATCH 29/78] ensemble minor changes --- chebai/ensemble/consolidator.py | 2 +- chebai/ensemble/controller.py | 10 +++++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/chebai/ensemble/consolidator.py b/chebai/ensemble/consolidator.py index 0332e17e..7a042823 100644 --- a/chebai/ensemble/consolidator.py +++ b/chebai/ensemble/consolidator.py @@ -1,6 +1,6 @@ from abc import ABC -from chebai.ensemble.base import EnsembleBase +from .base import EnsembleBase class WeightedMajorityVoting(EnsembleBase, ABC): diff --git a/chebai/ensemble/controller.py b/chebai/ensemble/controller.py index 0e1945e0..b19974cf 100644 --- a/chebai/ensemble/controller.py +++ b/chebai/ensemble/controller.py @@ -1,12 +1,15 @@ import os.path from abc import ABC +from collections import deque +from typing import Deque import torch -from chebai.ensemble.base import EnsembleBase from chebai.models import ChebaiBaseNet from chebai.preprocessing.collate import RaggedCollator +from .base import EnsembleBase + class _Controller(EnsembleBase, ABC): def __init__(self, **kwargs): @@ -14,11 +17,12 @@ def __init__(self, **kwargs): self._collator = RaggedCollator() self._collated_data = self._load_and_collate_data() + self.input_dim = len(self._collated_data.x[0]) self.total_data_size: int = len(self._collated_data) def _load_and_collate_data(self): data = torch.load( - os.path.join(self.data_processed_dir_main, "data.pt"), + os.path.join(self.data_processed_dir_main, "smiles_token", "data.pt"), weights_only=False, map_location=self.device, ) @@ -51,7 +55,7 @@ def _get_pred_conf_from_model_output(self, model_output, model_label_mask): class NoActivationCondition(_Controller): def __init__(self, **kwargs): super().__init__(**kwargs) - self._model_queue = list(self.model_configs.keys()) + self._model_queue: Deque = deque(list(self.model_configs.keys())) def _controller(self, model, model_props, **kwargs): model_output = self._forward_pass(model) From 4bd00ac0c494ef68f8a225dd0f8cc0909139f647 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 19 May 2025 11:43:59 +0200 Subject: [PATCH 30/78] private instance var + reader_dir_name param --- chebai/ensemble/base.py | 66 ++++++++++++++++++++--------------- chebai/ensemble/controller.py | 14 ++++---- 2 files changed, 45 insertions(+), 35 deletions(-) diff --git a/chebai/ensemble/base.py b/chebai/ensemble/base.py index ba6952a8..639c8770 100644 --- a/chebai/ensemble/base.py +++ b/chebai/ensemble/base.py @@ -22,13 +22,17 @@ class EnsembleBase(ABC): Attributes: data_processed_dir_main (str): Directory where the processed data is stored. - models (Dict[str, LightningModule]): A dictionary of loaded models. + _models (Dict[str, LightningModule]): A dictionary of loaded models. model_configs (Dict[str, Dict]): Configuration dictionary for models in the ensemble. - dm_labels (Dict[str, int]): Mapping of label names to integer indices. + _dm_labels (Dict[str, int]): Mapping of label names to integer indices. """ def __init__( - self, model_configs: Dict[str, Dict], data_processed_dir_main: str, **kwargs + self, + model_configs: Dict[str, Dict], + data_processed_dir_main: str, + reader_dir_name: str = "smiles_token", + **kwargs, ): """ Initializes the ensemble model and loads configuration, models, and labels. @@ -41,22 +45,25 @@ def __init__( if bool(kwargs.get("_validate_configs", True)): self._validate_model_configs(model_configs) - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.model_configs = model_configs + self.data_processed_dir_main = data_processed_dir_main + self.reader_dir_name = reader_dir_name self.input_dim = kwargs.get("input_dim", None) - self.num_of_labels: Optional[int] = ( + + self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self._num_of_labels: Optional[int] = ( None # will be set by `_load_data_module_labels` method ) - self.data_processed_dir_main = data_processed_dir_main - self.models: Dict[str, LightningModule] = {} - self.model_configs = model_configs - self.dm_labels: Dict[str, int] = {} + self._models: Dict[str, LightningModule] = {} + self._dm_labels: Dict[str, int] = {} self._load_data_module_labels() self._num_models_per_label: torch.Tensor = torch.zeros( - 1, self.num_of_labels, device=self.device + 1, self._num_of_labels, device=self._device ) self._model_queue: Deque = deque() self._collated_data = None + self._total_data_size: Optional[int] = None @classmethod def _validate_model_configs(cls, model_configs: Dict[str, Dict]): @@ -121,14 +128,17 @@ def _load_data_module_labels(self): else: with open(classes_txt_file, "r") as f: for line in f: - if line.strip() not in self.dm_labels: - self.dm_labels[line.strip()] = len(self.dm_labels) - self.num_of_labels = len(self.dm_labels) + if line.strip() not in self._dm_labels: + self._dm_labels[line.strip()] = len(self._dm_labels) + self._num_of_labels = len(self._dm_labels) def run_ensemble(self): - batch_size = 10 - true_scores = torch.zeros(batch_size, self.num_of_labels, device=self.device) - false_scores = torch.zeros(batch_size, self.num_of_labels, device=self.device) + true_scores = torch.zeros( + self._total_data_size, self._num_of_labels, device=self._device + ) + false_scores = torch.zeros( + self._total_data_size, self._num_of_labels, device=self._device + ) while self._model_queue: model_name = self._model_queue.popleft() @@ -156,8 +166,8 @@ def run_ensemble(self): print_metrics( final_preds, self._collated_data.y, - self.device, - classes=list(self.dm_labels.keys()), + self._device, + classes=list(self._dm_labels.keys()), ) def _load_model_and_its_props(self, model_name): @@ -209,13 +219,13 @@ def _generate_model_label_props(self, labels_path: str): model_label_indices, tpv_label_values, fpv_label_values = [], [], [] for label in labels_dict.keys(): - if label in self.dm_labels: + if label in self._dm_labels: try: self._validate_model_labels_json_element(labels_dict[label]) except Exception as e: raise Exception(f"Label '{label}' has an unexpected error: {e}") - model_label_indices.append(self.dm_labels[label]) + model_label_indices.append(self._dm_labels[label]) tpv_label_values.append(labels_dict[label]["TPV"]) fpv_label_values.append(labels_dict[label]["FPV"]) @@ -223,19 +233,19 @@ def _generate_model_label_props(self, labels_path: str): raise ValueError(f"Values are empty for labels of the model") # Create masks to apply predictions only to known classes - mask = torch.zeros(self.num_of_labels, device=self.device, dtype=torch.bool) - mask[torch.tensor(model_label_indices, dtype=torch.int, device=self.device)] = ( - True - ) + mask = torch.zeros(self._num_of_labels, device=self._device, dtype=torch.bool) + mask[ + torch.tensor(model_label_indices, dtype=torch.int, device=self._device) + ] = True - tpv_tensor = torch.full_like(mask, -1, dtype=torch.float, device=self.device) - fpv_tensor = torch.full_like(mask, -1, dtype=torch.float, device=self.device) + tpv_tensor = torch.full_like(mask, -1, dtype=torch.float, device=self._device) + fpv_tensor = torch.full_like(mask, -1, dtype=torch.float, device=self._device) tpv_tensor[mask] = torch.tensor( - tpv_label_values, dtype=torch.float, device=self.device + tpv_label_values, dtype=torch.float, device=self._device ) fpv_tensor[mask] = torch.tensor( - fpv_label_values, dtype=torch.float, device=self.device + fpv_label_values, dtype=torch.float, device=self._device ) self._num_models_per_label += mask return {"mask": mask, "tpv_tensor": tpv_tensor, "fpv_tensor": fpv_tensor} diff --git a/chebai/ensemble/controller.py b/chebai/ensemble/controller.py index b19974cf..d9fccb39 100644 --- a/chebai/ensemble/controller.py +++ b/chebai/ensemble/controller.py @@ -18,18 +18,18 @@ def __init__(self, **kwargs): self._collated_data = self._load_and_collate_data() self.input_dim = len(self._collated_data.x[0]) - self.total_data_size: int = len(self._collated_data) + self._total_data_size: int = len(self._collated_data) def _load_and_collate_data(self): data = torch.load( - os.path.join(self.data_processed_dir_main, "smiles_token", "data.pt"), + os.path.join(self.data_processed_dir_main, self.reader_dir_name, "data.pt"), weights_only=False, - map_location=self.device, + map_location=self._device, ) collated_data = self._collator(data) - collated_data.x = collated_data.to_x(self.device) + collated_data.x = collated_data.to_x(self._device) if collated_data.y is not None: - collated_data.y = collated_data.to_y(self.device) + collated_data.y = collated_data.to_y(self._device) return collated_data def _forward_pass(self, model: ChebaiBaseNet): @@ -42,10 +42,10 @@ def _get_pred_conf_from_model_output(self, model_output, model_label_mask): # Consider logits and confidence only for valid classes sigmoid_logits = torch.sigmoid(model_output["logits"]) prediction = torch.full( - (self.total_data_size, self.num_of_labels), -1, dtype=torch.bool + (self._total_data_size, self._num_of_labels), -1, dtype=torch.bool ) confidence = torch.full( - (self.total_data_size, self.num_of_labels), -1, dtype=torch.float + (self._total_data_size, self._num_of_labels), -1, dtype=torch.float ) prediction[:, model_label_mask] = sigmoid_logits > 0.5 confidence[:, model_label_mask] = 2 * torch.abs(sigmoid_logits - 0.5) From 50057f04442a50ff559e1451a6902d68078f0ea6 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 19 May 2025 11:54:00 +0200 Subject: [PATCH 31/78] config for ensemble --- configs/ensemble/fullEnsembleWithWMV.yaml | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) create mode 100644 configs/ensemble/fullEnsembleWithWMV.yaml diff --git a/configs/ensemble/fullEnsembleWithWMV.yaml b/configs/ensemble/fullEnsembleWithWMV.yaml new file mode 100644 index 00000000..626be694 --- /dev/null +++ b/configs/ensemble/fullEnsembleWithWMV.yaml @@ -0,0 +1,18 @@ +class_path: chebai.ensemble.FullEnsembleWMV +init_args: + data_processed_dir_main: "path/to/data/processed/main/directory" +# reader_dir_name: "name_of_reader_dir" # default is `smiles_token` +# _validate_configs: False # To avoid check for using same model with same model configs + model_configs: { + "model_1_name": { + "ckpt_path": "path/to/your/model_1/checkpoint.ckpt", + "class_path": "path/to/your/model_1/class", + "labels_path": "path/to/your/model_1/classes.json", + }, + + "model_2_name": { + "ckpt_path": "path/to/your/model_2/checkpoint.ckpt", + "class_path": "path/to/your/model_2/class", + "labels_path": "path/to/your/model_2/classes.json", + }, + } From ee7a16646a958bf5442553387e2f966df9d5ed99 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 19 May 2025 11:55:22 +0200 Subject: [PATCH 32/78] delete models/ensemble --- chebai/models/ensemble.py | 498 -------------------------------------- 1 file changed, 498 deletions(-) delete mode 100644 chebai/models/ensemble.py diff --git a/chebai/models/ensemble.py b/chebai/models/ensemble.py deleted file mode 100644 index eaa22bc7..00000000 --- a/chebai/models/ensemble.py +++ /dev/null @@ -1,498 +0,0 @@ -import importlib -import json -import os.path -from abc import ABC, abstractmethod -from typing import Any, Dict, Optional, Tuple, Union - -import torch -from lightning.pytorch import LightningModule -from torch import Tensor - -from chebai.models import ChebaiBaseNet -from chebai.preprocessing.structures import XYData - - -class _EnsembleBase(ChebaiBaseNet, ABC): - """ - Base class for ensemble models in the Chebai framework. - - Inherits from ChebaiBaseNet and provides functionality to load multiple models, - validate configuration, and manage predictions. - - Attributes: - data_processed_dir_main (str): Directory where the processed data is stored. - models (Dict[str, LightningModule]): A dictionary of loaded models. - model_configs (Dict[str, Dict]): Configuration dictionary for models in the ensemble. - dm_labels (Dict[str, int]): Mapping of label names to integer indices. - """ - - def __init__( - self, model_configs: Dict[str, Dict], data_processed_dir_main: str, **kwargs - ): - """ - Initializes the ensemble model and loads configuration, models, and labels. - - Args: - model_configs (Dict[str, Dict]): Dictionary of model configurations. - data_processed_dir_main (str): Path to the processed data directory. - **kwargs: Additional arguments for initialization. - """ - super().__init__(**kwargs) - if kwargs.get("_validate_configs", True): - self._validate_model_configs(model_configs) - - self.data_processed_dir_main = data_processed_dir_main - self.models: Dict[str, LightningModule] = {} - self.model_configs = model_configs - self.dm_labels: Dict[str, int] = {} - - self._load_data_module_labels() - self._load_ensemble_models() - - @classmethod - def _validate_model_configs(cls, model_configs: Dict[str, Dict]): - """ - Validates the model configurations to ensure required keys are present. - - Args: - model_configs (Dict[str, Dict]): Dictionary of model configurations. - - Raises: - AttributeError: If required keys are missing in the configuration. - ValueError: If there are duplicate model paths or class paths. - """ - path_set, class_set, labels_set = set(), set(), set() - - required_keys = {"class_path", "ckpt_path", "labels_path"} - - for model_name, config in model_configs.items(): - missing_keys = required_keys - config.keys() - - if missing_keys: - raise AttributeError( - f"Missing keys {missing_keys} in model '{model_name}' configuration." - ) - - model_path = config["ckpt_path"] - class_path = config["class_path"] - labels_path = config["labels_path"] - - if model_path in path_set: - raise ValueError( - f"Duplicate model path detected: '{model_path}'. " - f"Each model must have a unique model-checkpoint path." - ) - - if class_path in class_set: - raise ValueError( - f"Duplicate class path detected: '{class_path}'. Each model must have a unique class path." - ) - - if labels_path in labels_set: - raise ValueError( - f"Duplicate labels path: {labels_path}. Each model must have unique labels path." - ) - - path_set.add(model_path) - class_set.add(class_path) - labels_set.add(labels_path) - - def _load_ensemble_models(self): - """ - Loads the models specified in the configuration and initializes them. - """ - for model_name in self.model_configs: - model_ckpt_path = self.model_configs[model_name]["ckpt_path"] - model_class_path = self.model_configs[model_name]["class_path"] - model_labels_path = self.model_configs[model_name]["labels_path"] - if not os.path.exists(model_ckpt_path): - raise FileNotFoundError( - f"Model path '{model_ckpt_path}' for '{model_name}' does not exist." - ) - - class_name = model_class_path.split(".")[-1] - module_path = ".".join(model_class_path.split(".")[:-1]) - module = importlib.import_module(module_path) - lightning_cls: LightningModule = getattr(module, class_name) - - model = lightning_cls.load_from_checkpoint( - model_ckpt_path, input_dim=self.input_dim - ) - model.eval() - model.freeze() - - self.models[model_name] = model - self.model_configs[model_name]["labels"] = self._load_model_labels( - model_labels_path, model_name - ) - - def _load_data_module_labels(self): - """ - Loads the label mapping from the classes.txt file for loaded data. - - Raises: - FileNotFoundError: If the classes.txt file does not exist. - """ - classes_txt_file = os.path.join(self.data_processed_dir_main, "classes.txt") - if not os.path.exists(classes_txt_file): - raise FileNotFoundError(f"{classes_txt_file} does not exist") - else: - with open(classes_txt_file, "r") as f: - for line in f: - if line.strip() not in self.dm_labels: - self.dm_labels[line.strip()] = len(self.dm_labels) - - @staticmethod - def _load_model_labels(labels_path: str, model_name: str) -> Dict[str, float]: - if not os.path.exists(labels_path): - raise FileNotFoundError(f"{labels_path} does not exist.") - - if not labels_path.endswith(".json"): - raise TypeError(f"{labels_path} is not a JSON file.") - - with open(labels_path, "r") as f: - model_labels = json.load(f) - - labels_dict = {} - for label, label_dict in model_labels.items(): - msg = f"for model {model_name} for label {label}" - if "TPV" not in label_dict.keys() or "FPV" not in label_dict.keys(): - raise AttributeError(f"Missing keys 'TPV' and/or 'FPV' {msg}") - - # Validate 'tpv' and 'fpv' are either floats or convertible to float - for key in ["TPV", "FPV"]: - try: - value = float(label_dict[key]) - if value < 0: - raise ValueError( - f"'{key}' must be non-negative but got {value} {msg}" - ) - except (TypeError, ValueError): - raise ValueError( - f"'{key}' must be a float or convertible to float, but got {label_dict[key]} {msg}" - ) - labels_dict.setdefault(label, {})[key] = value - return labels_dict - - @abstractmethod - def _get_prediction_and_labels( - self, data: Dict[str, Any], labels: torch.Tensor, output: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Abstract method for obtaining predictions and labels. - - Args: - data (Dict[str, Any]): The input data. - labels (torch.Tensor): The target labels. - output (torch.Tensor): The model output. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: The predicted labels and the ground truth labels. - """ - pass - - def controller(self): - pass - - def consolidator( - self, - ): - pass - - -class ChebiEnsemble(_EnsembleBase): - """ - Ensemble model that aggregates predictions from multiple models for the Chebai task. - - This model combines the outputs of several individual models and aggregates their predictions - using a weighted voting strategy based on trustworthiness (TPV and FPV). This strategy can modified by overriding - `aggregate_predictions` method by subclasses, as per needs. - - There is are relevant trainable parameters for this ensemble model, hence trainer.max_epochs should be set to 1. - `_dummy_param` exists for only lighting module completeness and compatability purpose. - """ - - NAME = "ChebiEnsemble" - - def __init__(self, model_configs: Dict[str, Dict], **kwargs): - """ - Initializes the ensemble model and computes the model-label mask. - - Args: - model_configs (Dict[str, Dict]): Dictionary of model configurations. - **kwargs: Additional arguments for initialization. - """ - super().__init__(model_configs, **kwargs) - - # Add a dummy trainable parameter - self.dummy_param = torch.nn.Parameter(torch.randn(1, requires_grad=True)) - self._num_models_per_label: Optional[torch.Tensor] = None - self._generate_model_label_mask() - - def _generate_model_label_mask(self): - """ - Generates a mask indicating the labels handled by each model, and retrieves corresponding the TPV and FPV values - as tensors. - - Raises: - FileNotFoundError: If the labels path does not exist. - ValueError: If label values are empty for any model. - """ - num_models_per_label = torch.zeros(1, self.out_dim, device=self.device) - - for model_name, model_config in self.model_configs.items(): - labels_dict = model_config["labels"] - - model_label_indices, tpv_label_values, fpv_label_values = [], [], [] - for label in labels_dict.keys(): - if label in self.dm_labels: - model_label_indices.append(self.dm_labels[label]) - tpv_label_values.append(labels_dict[label]["TPV"]) - fpv_label_values.append(labels_dict[label]["FPV"]) - - if not all([model_label_indices, tpv_label_values, fpv_label_values]): - raise ValueError(f"Values are empty for labels of model {model_name}") - - # Create masks to apply predictions only to known classes - mask = torch.zeros(self.out_dim, device=self.device, dtype=torch.bool) - mask[ - torch.tensor(model_label_indices, dtype=torch.int, device=self.device) - ] = True - - tpv_tensor = torch.full_like( - mask, -1, dtype=torch.float, device=self.device - ) - fpv_tensor = torch.full_like( - mask, -1, dtype=torch.float, device=self.device - ) - - tpv_tensor[mask] = torch.tensor( - tpv_label_values, dtype=torch.float, device=self.device - ) - fpv_tensor[mask] = torch.tensor( - fpv_label_values, dtype=torch.float, device=self.device - ) - - self.model_configs[model_name]["labels_mask"] = mask - self.model_configs[model_name]["tpv_tensor"] = tpv_tensor - self.model_configs[model_name]["fpv_tensor"] = fpv_tensor - num_models_per_label += mask - - self._num_models_per_label = num_models_per_label - - def forward(self, data: Dict[str, Tensor], **kwargs: Any) -> Dict[str, Any]: - """ - Forward pass through the ensemble model, aggregating predictions from all models. - - Args: - data (Dict[str, Tensor]): Input data including features and labels. - **kwargs: Additional arguments for the forward pass. - - Returns: - Dict[str, Any]: The aggregated logits, predictions, and confidences. - """ - predictions = {} - confidences = {} - - assert data["labels"].shape[1] == self.out_dim - - # Initialize total_logits with zeros - total_logits = torch.zeros( - data["labels"].shape[0], self.out_dim, device=self.device - ) - - for name, model in self.models.items(): - output = model(data) - mask = self.model_configs[name]["labels_mask"] - - # Consider logits and confidence only for valid classes - sigmoid_logits = torch.sigmoid(output["logits"]) - prediction = torch.full_like(total_logits, -1, dtype=torch.bool) - confidence = torch.full_like(total_logits, -1, dtype=torch.float) - prediction[:, mask] = sigmoid_logits > 0.5 - confidence[:, mask] = 2 * torch.abs(sigmoid_logits - 0.5) - - predictions[name] = prediction - confidences[name] = confidence - total_logits += output[ - "logits" - ] # This doesn't play a role here, just for lightning flow completeness - - return { - "logits": total_logits, - "pred_dict": predictions, - "conf_dict": confidences, - } - - def _get_prediction_and_labels(self, data, labels, model_output): - """ - Gets predictions and labels from the model output. - - Args: - data (Dict[str, Any]): The input data. - labels (torch.Tensor): The target labels. - model_output (Dict[str, Tensor]): The model's output. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: The predictions and the ground truth labels. - """ - d = model_output["logits"] - # Aggregate predictions using weighted voting - metrics_preds = self.aggregate_predictions( - model_output["pred_dict"], model_output["conf_dict"] - ) - loss_kwargs = data.get("loss_kwargs", dict()) - if "non_null_labels" in loss_kwargs: - n = loss_kwargs["non_null_labels"] - d = d[n] - metrics_preds = metrics_preds[n] - return ( - torch.sigmoid(d), - labels.int() if labels is not None else None, - metrics_preds, - ) - - def _execute( - self, - batch: XYData, - batch_idx: int, - metrics: Optional[torch.nn.Module] = None, - prefix: Optional[str] = "", - log: Optional[bool] = True, - sync_dist: Optional[bool] = False, - ) -> Dict[str, Union[torch.Tensor, Any]]: - """ - Executes the model on a batch of data and returns the model output and predictions. - - Args: - batch (XYData): The input batch of data. - batch_idx (int): The index of the current batch. - metrics (torch.nn.Module): A dictionary of metrics to track. - prefix (str, optional): A prefix to add to the metric names. Defaults to "". - log (bool, optional): Whether to log the metrics. Defaults to True. - sync_dist (bool, optional): Whether to synchronize distributed training. Defaults to False. - - Returns: - Dict[str, Union[torch.Tensor, Any]]: A dictionary containing the processed data, labels, model output, - predictions, and loss (if applicable). - """ - assert isinstance(batch, XYData) - batch = batch.to(self.device) - data = self._process_batch(batch, batch_idx) - labels = data["labels"] - model_output = self(data, **data.get("model_kwargs", dict())) - pr, tar, metrics_preds = self._get_prediction_and_labels( - data, labels, model_output - ) - d = dict(data=data, labels=labels, output=model_output, preds=pr) - if log: - if self.criterion is not None: - loss_data, loss_labels, loss_kwargs_candidates = self._process_for_loss( - model_output, labels, data.get("loss_kwargs", dict()) - ) - loss_kwargs = dict() - if self.pass_loss_kwargs: - loss_kwargs = loss_kwargs_candidates - loss = self.criterion(loss_data, loss_labels, **loss_kwargs) - if isinstance(loss, tuple): - loss_additional = loss[1:] - for i, loss_add in enumerate(loss_additional): - self.log( - f"{prefix}loss_{i}", - loss_add if isinstance(loss_add, int) else loss_add.item(), - batch_size=len(batch), - on_step=True, - on_epoch=False, - prog_bar=False, - logger=True, - sync_dist=sync_dist, - ) - loss = loss[0] - - d["loss"] = loss + 0 * self.dummy_param.sum() - - self.log( - f"{prefix}loss", - loss.item(), - batch_size=len(batch), - on_step=True, - on_epoch=True, - prog_bar=True, - logger=True, - sync_dist=sync_dist, - ) - if metrics and labels is not None: - for metric_name, metric in metrics.items(): - metric.update(metrics_preds, tar) - self._log_metrics(prefix, metrics, len(batch)) - return d - - def aggregate_predictions( - self, predictions: Dict[str, torch.Tensor], confidences: Dict[str, torch.Tensor] - ) -> torch.Tensor: - """ - Implements weighted voting based on trustworthiness. - - This method aggregates predictions from multiple models using a weighted voting mechanism. - The weight of each model's prediction is determined by its True Positive Value (TPV) and - False Positive Value (FPV), scaled by the confidence score. - - Args: - predictions (Dict[str, torch.Tensor]): - A dictionary mapping model names to their respective binary class predictions - (shape: `[batch_size, num_classes]`). - confidences (Dict[str, torch.Tensor]): - A dictionary mapping model names to their respective confidence scores - (shape: `[batch_size, num_classes]`). - - Returns: - torch.Tensor: - A tensor of final aggregated predictions based on weighted voting - (shape: `[batch_size, num_classes]`), where values are `True` for positive class - and `False` otherwise. - """ - batch_size, num_classes = list(predictions.values())[0].shape - true_scores = torch.zeros(batch_size, num_classes, device=self.device) - false_scores = torch.zeros(batch_size, num_classes, device=self.device) - - for model, conf in confidences.items(): - tpv = self.model_configs[model]["tpv_tensor"] - npv = self.model_configs[model]["fpv_tensor"] - - # Determine which classes the model provides predictions for - mask = self.model_configs[model]["labels_mask"] - weight = conf * (tpv * conf + npv * (1 - conf)) - - # Apply mask: Only update scores for valid classes - true_scores += weight * conf * mask - false_scores += weight * (1 - conf) * mask - - # Avoid division by zero: Set valid_counts to 1 where it's zero - valid_counts = self._num_models_per_label.clamp(min=1) - - # Normalize by valid contributions to prevent bias - final_preds = (true_scores / valid_counts) > (false_scores / valid_counts) - - return final_preds - - def _process_for_loss( - self, - model_output: Dict[str, Tensor], - labels: Tensor, - loss_kwargs: Dict[str, Any], - ) -> Tuple[Tensor, Tensor, Dict[str, Any]]: - """ - Process the model output for calculating the loss. - - Args: - model_output (Dict[str, Tensor]): The output of the model. - labels (Tensor): The target labels. - loss_kwargs (Dict[str, Any]): Additional loss arguments. - - Returns: - tuple: A tuple containing the processed model output, labels, and loss arguments. - """ - kwargs_copy = dict(loss_kwargs) - if labels is not None: - labels = labels.float() - return model_output["logits"], labels, kwargs_copy From e9f1d9567ba3465aac217caeede901c7ac8f3ee8 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 19 May 2025 11:57:10 +0200 Subject: [PATCH 33/78] delete old ensemble config --- configs/model/ensemble/chebiEnsemble.yml | 15 --------------- configs/model/ensemble/ensemble_learning.yml | 4 ---- 2 files changed, 19 deletions(-) delete mode 100644 configs/model/ensemble/chebiEnsemble.yml delete mode 100644 configs/model/ensemble/ensemble_learning.yml diff --git a/configs/model/ensemble/chebiEnsemble.yml b/configs/model/ensemble/chebiEnsemble.yml deleted file mode 100644 index bc4547b0..00000000 --- a/configs/model/ensemble/chebiEnsemble.yml +++ /dev/null @@ -1,15 +0,0 @@ -class_path: chebai.models.ensemble.ChebiEnsemble -init_args: - model_configs: { - "model_1_name": { - "ckpt_path": "path/to/your/model_1/checkpoint.ckpt", - "class_path": "path/to/your/model_1/class", - "labels_path": "path/to/your/model_1/classes.json", - }, - - "model_2_name": { - "ckpt_path": "path/to/your/model_2/checkpoint.ckpt", - "class_path": "path/to/your/model_2/class", - "labels_path": "path/to/your/model_2/classes.json", - }, - } diff --git a/configs/model/ensemble/ensemble_learning.yml b/configs/model/ensemble/ensemble_learning.yml deleted file mode 100644 index 73257b49..00000000 --- a/configs/model/ensemble/ensemble_learning.yml +++ /dev/null @@ -1,4 +0,0 @@ -class_path: chebai.models.ensemble.ChebiEnsembleLearning -init_args: - optimizer_kwargs: - lr: 1e-3 From fca0305cac7424cad0151cb75b4b08152d9c1fd8 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 19 May 2025 12:23:25 +0200 Subject: [PATCH 34/78] add docstrings + typehints --- chebai/ensemble/base.py | 238 +++++++++++++++---------- chebai/ensemble/consolidator.py | 91 ++++++++-- chebai/ensemble/controller.py | 85 ++++++++- chebai/ensemble/ensemble_run_script.py | 47 ++++- 4 files changed, 337 insertions(+), 124 deletions(-) diff --git a/chebai/ensemble/base.py b/chebai/ensemble/base.py index 639c8770..f1cfd309 100644 --- a/chebai/ensemble/base.py +++ b/chebai/ensemble/base.py @@ -3,13 +3,14 @@ import os from abc import ABC, abstractmethod from collections import deque -from typing import Deque, Dict, Optional +from typing import Any, Deque, Dict, Optional, Tuple import torch from lightning import LightningModule from lightning_utilities.core.rank_zero import rank_zero_info from chebai.models import ChebaiBaseNet +from chebai.preprocessing.structures import XYData from chebai.result.classification import print_metrics @@ -17,38 +18,32 @@ class EnsembleBase(ABC): """ Base class for ensemble models in the Chebai framework. - Inherits from ChebaiBaseNet and provides functionality to load multiple models, - validate configuration, and manage predictions. - - Attributes: - data_processed_dir_main (str): Directory where the processed data is stored. - _models (Dict[str, LightningModule]): A dictionary of loaded models. - model_configs (Dict[str, Dict]): Configuration dictionary for models in the ensemble. - _dm_labels (Dict[str, int]): Mapping of label names to integer indices. + Handles loading, validating, and coordinating multiple models for ensemble prediction. """ def __init__( self, - model_configs: Dict[str, Dict], + model_configs: Dict[str, Dict[str, Any]], data_processed_dir_main: str, reader_dir_name: str = "smiles_token", - **kwargs, - ): + **kwargs: Any, + ) -> None: """ - Initializes the ensemble model and loads configuration, models, and labels. + Initializes the ensemble model and loads configurations, labels, and sets up the environment. Args: - model_configs (Dict[str, Dict]): Dictionary of model configurations. + model_configs (Dict[str, Dict[str, Any]]): Dictionary of model configurations. data_processed_dir_main (str): Path to the processed data directory. - **kwargs: Additional arguments for initialization. + reader_dir_name (str): Name of the directory used by the reader. Defaults to 'smiles_token'. + **kwargs (Any): Additional arguments, such as 'input_dim' and '_validate_configs'. """ if bool(kwargs.get("_validate_configs", True)): self._validate_model_configs(model_configs) - self.model_configs = model_configs - self.data_processed_dir_main = data_processed_dir_main - self.reader_dir_name = reader_dir_name - self.input_dim = kwargs.get("input_dim", None) + self.model_configs: Dict[str, Dict[str, Any]] = model_configs + self.data_processed_dir_main: str = data_processed_dir_main + self.reader_dir_name: str = reader_dir_name + self.input_dim: Optional[int] = kwargs.get("input_dim", None) self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self._num_of_labels: Optional[int] = ( @@ -61,78 +56,73 @@ def __init__( self._num_models_per_label: torch.Tensor = torch.zeros( 1, self._num_of_labels, device=self._device ) - self._model_queue: Deque = deque() - self._collated_data = None + self._model_queue: Deque[str] = deque() + self._collated_data: Optional[XYData] = None self._total_data_size: Optional[int] = None @classmethod - def _validate_model_configs(cls, model_configs: Dict[str, Dict]): + def _validate_model_configs(cls, model_configs: Dict[str, Dict[str, Any]]) -> None: """ - Validates the model configurations to ensure required keys are present. + Validates model configuration dictionary for required keys and uniqueness. Args: - model_configs (Dict[str, Dict]): Dictionary of model configurations. + model_configs (Dict[str, Dict[str, Any]]): Model configuration dictionary. Raises: - AttributeError: If required keys are missing in the configuration. - ValueError: If there are duplicate model paths or class paths. + AttributeError: If any model config is missing required keys. + ValueError: If duplicate paths are found for model checkpoint, class, or labels. """ path_set, class_set, labels_set = set(), set(), set() - required_keys = {"class_path", "ckpt_path", "labels_path"} for model_name, config in model_configs.items(): missing_keys = required_keys - config.keys() - if missing_keys: raise AttributeError( f"Missing keys {missing_keys} in model '{model_name}' configuration." ) - model_path = config["ckpt_path"] - class_path = config["class_path"] - labels_path = config["labels_path"] + model_path, class_path, labels_path = ( + config["ckpt_path"], + config["class_path"], + config["labels_path"], + ) if model_path in path_set: - raise ValueError( - f"Duplicate model path detected: '{model_path}'. " - f"Each model must have a unique model-checkpoint path." - ) - + raise ValueError(f"Duplicate model path detected: '{model_path}'.") if class_path in class_set: - raise ValueError( - f"Duplicate class path detected: '{class_path}'. Each model must have a unique class path." - ) - + raise ValueError(f"Duplicate class path detected: '{class_path}'.") if labels_path in labels_set: - raise ValueError( - f"Duplicate labels path: {labels_path}. Each model must have unique labels path." - ) + raise ValueError(f"Duplicate labels path: {labels_path}.") path_set.add(model_path) class_set.add(class_path) labels_set.add(labels_path) - def _load_data_module_labels(self): + def _load_data_module_labels(self) -> None: """ - Loads the label mapping from the classes.txt file for loaded data. + Loads class labels from the classes.txt file and sets internal label mapping. Raises: - FileNotFoundError: If the classes.txt file does not exist. + FileNotFoundError: If the expected classes.txt file is not found. """ classes_txt_file = os.path.join(self.data_processed_dir_main, "classes.txt") rank_zero_info(f"Loading {classes_txt_file} ....") if not os.path.exists(classes_txt_file): raise FileNotFoundError(f"{classes_txt_file} does not exist") - else: - with open(classes_txt_file, "r") as f: - for line in f: - if line.strip() not in self._dm_labels: - self._dm_labels[line.strip()] = len(self._dm_labels) + + with open(classes_txt_file, "r") as f: + for line in f: + label = line.strip() + if label not in self._dm_labels: + self._dm_labels[label] = len(self._dm_labels) self._num_of_labels = len(self._dm_labels) - def run_ensemble(self): + def run_ensemble(self) -> None: + """ + Executes the full ensemble prediction pipeline, aggregating predictions and printing metrics. + """ true_scores = torch.zeros( self._total_data_size, self._num_of_labels, device=self._device ) @@ -149,7 +139,7 @@ def run_ensemble(self): pred_conf_dict = self._controller(model, model_props) del model # Model can be huge to keep it in memory, delete as no longer needed - rank_zero_info("\t Passing predictions to consolidator to aggregation") + rank_zero_info("\t Passing predictions to consolidator for aggregation...") self._consolidator( pred_conf_dict, model_props, @@ -157,9 +147,7 @@ def run_ensemble(self): false_scores=false_scores, ) - rank_zero_info( - f"Consolidate predictions of the ensemble: {self.__class__.__name__}" - ) + rank_zero_info(f"Consolidating predictions for {self.__class__.__name__}") final_preds = self._consolidate_on_finish( true_scores=true_scores, false_scores=false_scores ) @@ -170,13 +158,23 @@ def run_ensemble(self): classes=list(self._dm_labels.keys()), ) - def _load_model_and_its_props(self, model_name): + def _load_model_and_its_props( + self, model_name: str + ) -> Tuple[LightningModule, Dict[str, torch.Tensor]]: """ - Loads the models specified in the configuration and initializes them. + Loads a model checkpoint and its label-related properties. + + Args: + model_name (str): Name of the model to load. + + Returns: + Tuple[LightningModule, Dict[str, torch.Tensor]]: The model and its label properties. """ - model_ckpt_path = self.model_configs[model_name]["ckpt_path"] - model_class_path = self.model_configs[model_name]["class_path"] - model_labels_path = self.model_configs[model_name]["labels_path"] + config = self.model_configs[model_name] + model_ckpt_path = config["ckpt_path"] + model_class_path = config["class_path"] + model_labels_path = config["labels_path"] + if not os.path.exists(model_ckpt_path): raise FileNotFoundError( f"Model path '{model_ckpt_path}' for '{model_name}' does not exist." @@ -185,7 +183,8 @@ def _load_model_and_its_props(self, model_name): class_name = model_class_path.split(".")[-1] module_path = ".".join(model_class_path.split(".")[:-1]) module = importlib.import_module(module_path) - lightning_cls: LightningModule = getattr(module, class_name) + lightning_cls = getattr(module, class_name) + assert isinstance(lightning_cls, type), f"{class_name} is not a class." assert issubclass( lightning_cls, ChebaiBaseNet @@ -199,44 +198,42 @@ def _load_model_and_its_props(self, model_name): model.freeze() model_label_props = self._generate_model_label_props(model_labels_path) except Exception as e: - raise RuntimeError( - f"For model {model_name} following exception as occurred \n Error: {e}" - ) + raise RuntimeError(f"Error loading model {model_name}") from e return model, model_label_props - def _generate_model_label_props(self, labels_path: str): + def _generate_model_label_props(self, labels_path: str) -> Dict[str, torch.Tensor]: """ - Generates a mask indicating the labels handled by each model, and retrieves corresponding the TPV and FPV values - as tensors. + Generates label mask and confidence tensors (TPV, FPV) for a model. - Raises: - FileNotFoundError: If the labels path does not exist. - ValueError: If label values are empty for any model. + Args: + labels_path (str): Path to the labels JSON file. + + Returns: + Dict[str, torch.Tensor]: Dictionary containing mask, TPV and FPV tensors. """ - rank_zero_info("\t Generating mask model's labels and other properties") + rank_zero_info("\t Generating model label masks and properties") labels_dict = self._load_model_labels(labels_path) model_label_indices, tpv_label_values, fpv_label_values = [], [], [] - for label in labels_dict.keys(): + + for label, props in labels_dict.items(): if label in self._dm_labels: try: self._validate_model_labels_json_element(labels_dict[label]) except Exception as e: - raise Exception(f"Label '{label}' has an unexpected error: {e}") + raise Exception(f"Label '{label}' has an unexpected error") from e model_label_indices.append(self._dm_labels[label]) - tpv_label_values.append(labels_dict[label]["TPV"]) - fpv_label_values.append(labels_dict[label]["FPV"]) + tpv_label_values.append(props["TPV"]) + fpv_label_values.append(props["FPV"]) if not all([model_label_indices, tpv_label_values, fpv_label_values]): - raise ValueError(f"Values are empty for labels of the model") + raise ValueError(f"No valid label values found in {labels_path}.") # Create masks to apply predictions only to known classes - mask = torch.zeros(self._num_of_labels, device=self._device, dtype=torch.bool) - mask[ - torch.tensor(model_label_indices, dtype=torch.int, device=self._device) - ] = True + mask = torch.zeros(self._num_of_labels, dtype=torch.bool, device=self._device) + mask[torch.tensor(model_label_indices, device=self._device)] = True tpv_tensor = torch.full_like(mask, -1, dtype=torch.float, device=self._device) fpv_tensor = torch.full_like(mask, -1, dtype=torch.float, device=self._device) @@ -247,47 +244,94 @@ def _generate_model_label_props(self, labels_path: str): fpv_tensor[mask] = torch.tensor( fpv_label_values, dtype=torch.float, device=self._device ) + self._num_models_per_label += mask return {"mask": mask, "tpv_tensor": tpv_tensor, "fpv_tensor": fpv_tensor} @staticmethod def _load_model_labels(labels_path: str) -> Dict[str, Dict[str, float]]: + """ + Loads a JSON label file for a model. + + Args: + labels_path (str): Path to the JSON file. + + Returns: + Dict[str, Dict[str, float]]: Parsed label confidence data. + + Raises: + FileNotFoundError: If the file is missing. + TypeError: If the file is not a JSON. + """ if not os.path.exists(labels_path): raise FileNotFoundError(f"{labels_path} does not exist.") - if not labels_path.endswith(".json"): raise TypeError(f"{labels_path} is not a JSON file.") - with open(labels_path, "r") as f: - model_labels = json.load(f) - return model_labels + return json.load(f) @staticmethod - def _validate_model_labels_json_element(label_dict: Dict[str, float]): - if "TPV" not in label_dict.keys() or "FPV" not in label_dict.keys(): - raise AttributeError(f"Missing keys 'TPV' and/or 'FPV'") + def _validate_model_labels_json_element(label_dict: Dict[str, Any]) -> None: + """ + Validates a label confidence dictionary to ensure required keys and values are valid. + + Args: + label_dict (Dict[str, Any]): Label data with TPV and FPV keys. - # Validate 'tpv' and 'fpv' are either floats or convertible to float + Raises: + AttributeError: If required keys are missing. + ValueError: If values are not valid floats or are negative. + """ for key in ["TPV", "FPV"]: + if key not in label_dict: + raise AttributeError(f"Missing key '{key}' in label dict.") try: value = float(label_dict[key]) if value < 0: raise ValueError(f"'{key}' must be non-negative but got {value}") - except (TypeError, ValueError): - raise ValueError( - f"'{key}' must be a float or convertible to float, but got {label_dict[key]}" - ) + except (TypeError, ValueError) as e: + raise ValueError(f"Invalid value for '{key}': {label_dict[key]}") from e @abstractmethod - def _controller(self, model, model_props, **kwargs): + def _controller( + self, + model: LightningModule, + model_props: Dict[str, torch.Tensor], + **kwargs: Any, + ) -> Dict[str, torch.Tensor]: + """ + Abstract method to define model-specific prediction logic. + + Returns: + Dict[str, torch.Tensor]: Predictions or confidence scores. + """ pass @abstractmethod def _consolidator( - self, pred_conf_dict, model_props, *, true_scores, false_scores, **kwargs - ): + self, + pred_conf_dict: Dict[str, torch.Tensor], + model_props: Dict[str, torch.Tensor], + *, + true_scores: torch.Tensor, + false_scores: torch.Tensor, + **kwargs: Any, + ) -> None: + """ + Abstract method to define aggregation logic. + + Should update the provided `true_scores` and `false_scores`. + """ pass @abstractmethod - def _consolidate_on_finish(self, *, true_scores, false_scores): + def _consolidate_on_finish( + self, *, true_scores: torch.Tensor, false_scores: torch.Tensor + ) -> torch.Tensor: + """ + Abstract method to produce final predictions after all models have been evaluated. + + Returns: + torch.Tensor: Final aggregated predictions. + """ pass diff --git a/chebai/ensemble/consolidator.py b/chebai/ensemble/consolidator.py index 7a042823..c8d79eed 100644 --- a/chebai/ensemble/consolidator.py +++ b/chebai/ensemble/consolidator.py @@ -1,44 +1,111 @@ from abc import ABC +from typing import Any, Dict + +from torch import Tensor from .base import EnsembleBase class WeightedMajorityVoting(EnsembleBase, ABC): + """ + Ensemble consolidator using weighted majority voting. + Each model's contribution is weighted by a function of confidence, + true positive value (TPV), and negative predictive value (NPV). + """ + def _consolidator( - self, pred_conf_dict, model_props, *, true_scores, false_scores, **kwargs - ): + self, + pred_conf_dict: Dict[str, Tensor], + model_props: Dict[str, Tensor], + *, + true_scores: Tensor, + false_scores: Tensor, + **kwargs: Any + ) -> None: + """ + Updates true/false scores based on model predictions using a weighted voting scheme. + + Args: + pred_conf_dict (Dict[str, Tensor]): Contains model predictions and confidence scores. + model_props (Dict[str, Tensor]): Contains mask, TPV and NPV tensors for model. + true_scores (Tensor): Tensor accumulating weighted "true" contributions. + false_scores (Tensor): Tensor accumulating weighted "false" contributions. + **kwargs (Any): Additional unused keyword arguments. + """ tpv = model_props["tpv_tensor"] npv = model_props["fpv_tensor"] conf = pred_conf_dict["confidence"] - - # Determine which classes the model provides predictions for mask = model_props["mask"] + weight = conf * (tpv * conf + npv * (1 - conf)) # Apply mask: Only update scores for valid classes true_scores += weight * conf * mask false_scores += weight * (1 - conf) * mask - def _consolidate_on_finish(self, *, true_scores, false_scores): + def _consolidate_on_finish( + self, *, true_scores: Tensor, false_scores: Tensor + ) -> Tensor: + """ + Finalizes predictions after all models have contributed their scores. + + Args: + true_scores (Tensor): Accumulated weighted true scores per label. + false_scores (Tensor): Accumulated weighted false scores per label. + + Returns: + Tensor: Final binary predictions (True if true_score > false_score). + """ # Avoid division by zero: Set valid_counts to 1 where it's zero valid_counts = self._num_models_per_label.clamp(min=1) - # Normalize by valid contributions to prevent bias final_preds = (true_scores / valid_counts) > (false_scores / valid_counts) return final_preds class MajorityVoting(EnsembleBase, ABC): + """ + Ensemble consolidator using simple majority voting. + Each model contributes equally; confidence is used directly as "vote weight". + """ + def _consolidator( - self, pred_conf_dict, model_props, *, true_scores, false_scores, **kwargs - ): - conf = pred_conf_dict["confidence"] + self, + pred_conf_dict: Dict[str, Tensor], + model_props: Dict[str, Tensor], + *, + true_scores: Tensor, + false_scores: Tensor, + **kwargs: Any + ) -> None: + """ + Updates true/false scores based on model predictions using unweighted voting. - # Determine which classes the model provides predictions for - mask = model_props["mask"] + Args: + pred_conf_dict (Dict[str, Tensor]): Contains model predictions and confidence scores. + model_props (Dict[str, Tensor]): Contains mask tensor for model. + true_scores (Tensor): Tensor accumulating true contributions. + false_scores (Tensor): Tensor accumulating false contributions. + **kwargs (Any): Additional unused keyword arguments. + """ + conf = pred_conf_dict["confidence"] # Apply mask: Only update scores for valid classes + mask = model_props["mask"] + true_scores += conf * mask false_scores += (1 - conf) * mask - def _consolidate_on_finish(self, *, true_scores, false_scores): + def _consolidate_on_finish( + self, *, true_scores: Tensor, false_scores: Tensor + ) -> Tensor: + """ + Finalizes predictions after all models have voted. + + Args: + true_scores (Tensor): Accumulated true votes per label. + false_scores (Tensor): Accumulated false votes per label. + + Returns: + Tensor: Final binary predictions (True if true_score > false_score). + """ return true_scores > false_scores diff --git a/chebai/ensemble/controller.py b/chebai/ensemble/controller.py index d9fccb39..ddffc6d0 100644 --- a/chebai/ensemble/controller.py +++ b/chebai/ensemble/controller.py @@ -1,9 +1,10 @@ import os.path from abc import ABC from collections import deque -from typing import Deque +from typing import Any, Deque, Dict import torch +from torch import Tensor from chebai.models import ChebaiBaseNet from chebai.preprocessing.collate import RaggedCollator @@ -12,7 +13,22 @@ class _Controller(EnsembleBase, ABC): - def __init__(self, **kwargs): + """ + Abstract base controller for ensemble models that handles data loading, collating, + and inference logic over a collection of models. + + Inherits from: + EnsembleBase: The base ensemble class with shared ensemble logic. + ABC: For defining abstract methods. + """ + + def __init__(self, **kwargs: Any): + """ + Initializes the controller with data loader and collator. + + Args: + **kwargs (Any): Keyword arguments passed to the EnsembleBase initializer. + """ super().__init__(**kwargs) self._collator = RaggedCollator() @@ -20,7 +36,13 @@ def __init__(self, **kwargs): self.input_dim = len(self._collated_data.x[0]) self._total_data_size: int = len(self._collated_data) - def _load_and_collate_data(self): + def _load_and_collate_data(self) -> Any: + """ + Loads and collates data using RaggedCollator. + + Returns: + Collated data object with `.x` and `.y` attributes moved to device. + """ data = torch.load( os.path.join(self.data_processed_dir_main, self.reader_dir_name, "data.pt"), weights_only=False, @@ -32,14 +54,35 @@ def _load_and_collate_data(self): collated_data.y = collated_data.to_y(self._device) return collated_data - def _forward_pass(self, model: ChebaiBaseNet): + def _forward_pass(self, model: ChebaiBaseNet) -> Dict[str, Tensor]: + """ + Runs a forward pass of the given model on the collated data. + + Args: + model (ChebaiBaseNet): The model to perform inference with. + + Returns: + Dict[str, Tensor]: Model output dictionary containing logits and other keys. + """ processable_data = model._process_batch(self._collated_data, 0) del processable_data["loss_kwargs"] model_output = model(processable_data, **processable_data["model_kwargs"]) return model_output - def _get_pred_conf_from_model_output(self, model_output, model_label_mask): - # Consider logits and confidence only for valid classes + def _get_pred_conf_from_model_output( + self, model_output: Dict[str, Tensor], model_label_mask: Tensor + ) -> Dict[str, Tensor]: + """ + Processes model output to extract binary predictions and confidence scores. + + Args: + model_output (Dict[str, Tensor]): Dictionary containing logits from the model. + model_label_mask (Tensor): A boolean mask indicating active labels for the model. + + Returns: + Dict[str, Tensor]: Dictionary with keys "prediction" and "confidence" containing + tensors of the same shape as logits, filled only for active labels. + """ sigmoid_logits = torch.sigmoid(model_output["logits"]) prediction = torch.full( (self._total_data_size, self._num_of_labels), -1, dtype=torch.bool @@ -53,10 +96,34 @@ def _get_pred_conf_from_model_output(self, model_output, model_label_mask): class NoActivationCondition(_Controller): - def __init__(self, **kwargs): + """ + A controller that queues and activates all models unconditionally. + + This implementation does not filter or select models dynamically. + """ + + def __init__(self, **kwargs: Any): + """ + Initializes the controller and loads all model names into the processing queue. + + Args: + **kwargs (Any): Keyword arguments passed to the _Controller initializer. + """ super().__init__(**kwargs) - self._model_queue: Deque = deque(list(self.model_configs.keys())) + self._model_queue: Deque[str] = deque(list(self.model_configs.keys())) + + def _controller( + self, model: ChebaiBaseNet, model_props: Dict[str, Tensor], **kwargs: Any + ) -> Dict[str, Tensor]: + """ + Performs inference with the model and extracts predictions and confidence values. + + Args: + model (ChebaiBaseNet): The model to perform inference with. + model_props (Dict[str, Tensor]): Dictionary with label mask and trust scores. - def _controller(self, model, model_props, **kwargs): + Returns: + Dict[str, Tensor]: Dictionary containing predictions and confidence scores. + """ model_output = self._forward_pass(model) return self._get_pred_conf_from_model_output(model_output, model_props["mask"]) diff --git a/chebai/ensemble/ensemble_run_script.py b/chebai/ensemble/ensemble_run_script.py index 127cc687..0862bef9 100644 --- a/chebai/ensemble/ensemble_run_script.py +++ b/chebai/ensemble/ensemble_run_script.py @@ -1,4 +1,5 @@ import importlib +from typing import Any, Dict, Type import yaml from jsonargparse import ArgumentParser @@ -6,34 +7,68 @@ from chebai.ensemble.base import EnsembleBase -def load_class(class_path: str): - """Dynamically import a class from a full dotted path.""" +def load_class(class_path: str) -> Type[EnsembleBase]: + """ + Dynamically imports and returns a class from a full dotted path. + + Args: + class_path (str): Full module path to the class (e.g., 'my_package.module.MyClass'). + + Returns: + Type[EnsembleBase]: The imported class object. + + Raises: + ModuleNotFoundError, AttributeError: If module or class cannot be loaded. + """ module_path, class_name = class_path.rsplit(".", 1) module = importlib.import_module(module_path) return getattr(module, class_name) -def load_config_and_instantiate(config_path: str): +def load_config_and_instantiate(config_path: str) -> EnsembleBase: + """ + Loads a YAML config file, imports the specified class, and instantiates it with the provided arguments. + + Args: + config_path (str): Path to the YAML configuration file. + + Returns: + EnsembleBase: An instantiated object of the specified class. + + Raises: + TypeError: If the loaded class is not a subclass of EnsembleBase. + """ with open(config_path, "r") as f: - config = yaml.safe_load(f) + config: Dict[str, Any] = yaml.safe_load(f) - class_path = config["class_path"] - init_args = config.get("init_args", {}) + class_path: str = config["class_path"] + init_args: Dict[str, Any] = config.get("init_args", {}) cls = load_class(class_path) + if not issubclass(cls, EnsembleBase): raise TypeError(f"{cls} must be subclass of EnsembleBase") + return cls(**init_args) if __name__ == "__main__": + # Example usage: + # python ensemble_run_script.py --config=configs/ensemble/fullEnsembleWithWMV.yaml + + # Set up argument parser to receive config file path from CLI parser = ArgumentParser() parser.add_argument("--config", type=str, help="Path to the YAML config file") + # Parse arguments from the command line args = parser.parse_args() + + # Load and instantiate the ensemble object ensemble = load_config_and_instantiate(args.config) + # Ensure the loaded object is a valid EnsembleBase instance if not isinstance(ensemble, EnsembleBase): raise TypeError("Object must be an instance of `EnsembleBase`") + # Run the ensemble pipeline ensemble.run_ensemble() From de6a70748313b3058a38f88f81469dfd820e8366 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Tue, 20 May 2025 14:53:19 +0200 Subject: [PATCH 35/78] delete dummy dataset --- chebai/preprocessing/datasets/_dummy.py | 97 ------------------------- configs/data/_dummy.yml | 3 - 2 files changed, 100 deletions(-) delete mode 100644 chebai/preprocessing/datasets/_dummy.py delete mode 100644 configs/data/_dummy.yml diff --git a/chebai/preprocessing/datasets/_dummy.py b/chebai/preprocessing/datasets/_dummy.py deleted file mode 100644 index 11d34862..00000000 --- a/chebai/preprocessing/datasets/_dummy.py +++ /dev/null @@ -1,97 +0,0 @@ -# This file is for developers only - -__all__ = [] # Nothing should be imported from this file - - -import random - -import numpy as np -from torch.utils.data import DataLoader, Dataset - -from chebai.preprocessing.datasets import XYBaseDataModule -from chebai.preprocessing.reader import ChemDataReader - - -class _DummyDataModule(XYBaseDataModule): - - READER = ChemDataReader - - def __init__(self, num_of_labels: int, feature_vector_size: int, *args, **kwargs): - super().__init__(*args, **kwargs) - self._num_of_labels = num_of_labels - self._feature_vector_size = feature_vector_size - assert self._num_of_labels is not None - assert self._feature_vector_size is not None - - def prepare_data(self): - pass - - def setup(self, stage=None): - pass - - @property - def num_of_labels(self): - return self._num_of_labels - - @property - def feature_vector_size(self): - return self._feature_vector_size - - def train_dataloader(self, *args, **kwargs) -> DataLoader: - dataset = _DummyDataset(100, self.num_of_labels, self.feature_vector_size) - return DataLoader( - dataset, - collate_fn=self.reader.collator, - batch_size=self.batch_size, - **kwargs, - ) - - def test_dataloader(self, *args, **kwargs) -> DataLoader: - dataset = _DummyDataset(20, self.num_of_labels, self.feature_vector_size) - return DataLoader( - dataset, - collate_fn=self.reader.collator, - batch_size=self.batch_size, - **kwargs, - ) - - def val_dataloader(self, *args, **kwargs) -> DataLoader: - dataset = _DummyDataset(10, self.num_of_labels, self.feature_vector_size) - return DataLoader( - dataset, - collate_fn=self.reader.collator, - batch_size=self.batch_size, - **kwargs, - ) - - @property - def _name(self) -> str: - return "_DummyDataModule" - - -class _DummyDataset(Dataset): - def __init__(self, num_samples: int, num_labels: int, feature_vector_size: int): - self.num_samples = num_samples - self.num_labels = num_labels - self.feature_vector_size = feature_vector_size - - def __len__(self): - return self.num_samples - - def __getitem__(self, idx): - return { - "features": np.random.randint( - 10, 100, size=self.feature_vector_size - ), # Random feature vector - "labels": np.random.choice( - [False, True], size=self.num_labels - ), # Random boolean labels - "ident": random.randint(1, 40000), # Random identifier - "group": None, # Default group value - } - - -if __name__ == "__main__": - dataset = _DummyDataset(num_samples=100, num_labels=5, feature_vector_size=20) - for i in range(10): - print(dataset[i]) diff --git a/configs/data/_dummy.yml b/configs/data/_dummy.yml deleted file mode 100644 index 180b6860..00000000 --- a/configs/data/_dummy.yml +++ /dev/null @@ -1,3 +0,0 @@ -class_path: chebai.preprocessing.datasets._dummy._DummyDataModule -init_args: - feature_vector_size: 20 From b471a05afd8d999a8de1ccd038a706287c76c5ee Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Tue, 20 May 2025 14:54:44 +0200 Subject: [PATCH 36/78] raname script with _ prefix --- .../ensemble/{ensemble_run_script.py => _ensemble_run_script.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename chebai/ensemble/{ensemble_run_script.py => _ensemble_run_script.py} (100%) diff --git a/chebai/ensemble/ensemble_run_script.py b/chebai/ensemble/_ensemble_run_script.py similarity index 100% rename from chebai/ensemble/ensemble_run_script.py rename to chebai/ensemble/_ensemble_run_script.py From 4c89dd3e978a4ca0b02e1d3a99dfaba8bd3e1670 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 21 May 2025 20:42:10 +0200 Subject: [PATCH 37/78] wrapper base --- chebai/wrappers/__init__.py | 3 +++ chebai/wrappers/_base.py | 41 +++++++++++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+) create mode 100644 chebai/wrappers/__init__.py create mode 100644 chebai/wrappers/_base.py diff --git a/chebai/wrappers/__init__.py b/chebai/wrappers/__init__.py new file mode 100644 index 00000000..62f836cd --- /dev/null +++ b/chebai/wrappers/__init__.py @@ -0,0 +1,3 @@ +from ._neural_network import NNWrapper + +__all__ = ["NNWrapper"] diff --git a/chebai/wrappers/_base.py b/chebai/wrappers/_base.py new file mode 100644 index 00000000..3a49bf59 --- /dev/null +++ b/chebai/wrappers/_base.py @@ -0,0 +1,41 @@ +import importlib +from abc import ABC, abstractmethod +from typing import overload + +import torch + + +class BaseWrapper(ABC): + def __init__(self, **kwargs): + self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + @overload + def predict(self, smiles_list: list) -> list: + pass + + @overload + def predict(self, data_file_path: str) -> list: + pass + + def predict(self, x: list | str) -> list: + if isinstance(x, list): + return self._predict_from_list_of_smiles(x) + elif isinstance(x, str): + return self._predict_from_data_file(x) + else: + raise TypeError(f"Type {type(x)} is not supported.") + + @abstractmethod + def _predict_from_list_of_smiles(self, smiles_list: list) -> list: + pass + + @abstractmethod + def _predict_from_data_file(self, data_file_path: str) -> list: + pass + + @staticmethod + def _load_class(class_path): + class_name = class_path.split(".")[-1] + module_path = ".".join(class_path.split(".")[:-1]) + module = importlib.import_module(module_path) + return getattr(module, class_name) From 1563c7624c524315dd2c54edb19bdd1079d0401f Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 21 May 2025 20:42:35 +0200 Subject: [PATCH 38/78] nn wrapper --- chebai/wrappers/_neural_network.py | 77 ++++++++++++++++++++++++++++++ 1 file changed, 77 insertions(+) create mode 100644 chebai/wrappers/_neural_network.py diff --git a/chebai/wrappers/_neural_network.py b/chebai/wrappers/_neural_network.py new file mode 100644 index 00000000..5345bdf2 --- /dev/null +++ b/chebai/wrappers/_neural_network.py @@ -0,0 +1,77 @@ +import os +from typing import Optional, Type + +import torch +from rdkit import Chem + +from chebai.models import ChebaiBaseNet +from chebai.preprocessing.reader import DataReader + +from ._base import BaseWrapper + + +class NNWrapper(BaseWrapper): + + def __init__( + self, + model: ChebaiBaseNet, + reader_cls: Type[DataReader], + reader_kwargs: Optional[dict] = None, + **kwargs, + ): + super().__init__(**kwargs) + self.model: ChebaiBaseNet = model + if reader_kwargs is None: + reader_kwargs = dict() + self.reader = reader_cls(**reader_kwargs) + self.collator = reader_cls.COLLATOR() + + def _forward_pass(self, batch): + processable_data = self.model._process_batch( + self.collator(batch).to(self._device), 0 + ) + return self.model(processable_data, **processable_data["model_kwargs"]) + + def _read_smiles(self, smiles): + return self.reader.to_data(dict(features=smiles, labels=None)) + + def _predict_from_list_of_smiles(self, smiles_list) -> list: + token_dicts = [] + could_not_parse = [] + index_map = dict() + for i, smiles in enumerate(smiles_list): + try: + # Try to parse the smiles string + if not smiles: + raise ValueError() + d = self._read_smiles(smiles) + # This is just for sanity checks + rdmol = Chem.MolFromSmiles(smiles, sanitize=False) + except Exception as e: + # Note if it fails + could_not_parse.append(i) + print(f"Failing to parse {smiles} due to {e}") + else: + if rdmol is None: + could_not_parse.append(i) + else: + index_map[i] = len(token_dicts) + token_dicts.append(d) + print(f"Predicting {len(token_dicts), token_dicts} out of {len(smiles_list)}") + if token_dicts: + model_output = self._forward_pass(token_dicts) + if not isinstance(model_output, dict) and not "logits" in model_output: + raise ValueError() + return model_output + else: + raise ValueError() + + def _predict_from_data_file( + self, processed_dir_main: str, data_file_name="data.pt" + ) -> list: + data = torch.load( + os.path.join(processed_dir_main, self.reader.name(), data_file_name), + weights_only=False, + map_location=self._device, + ) + return self._forward_pass(data) From 2fec9ef259cc4c6c71d04c631125030bfb0b0aeb Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 21 May 2025 20:43:59 +0200 Subject: [PATCH 39/78] rename ensemble internal files with _ prefix --- chebai/ensemble/__init__.py | 4 +- chebai/ensemble/{base.py => _base.py} | 49 +++++++++++-------- .../{consolidator.py => _consolidator.py} | 2 +- chebai/ensemble/_constants.py | 7 +++ .../{controller.py => _controller.py} | 36 +------------- .../{ => _scripts}/_ensemble_run_script.py | 2 +- 6 files changed, 41 insertions(+), 59 deletions(-) rename chebai/ensemble/{base.py => _base.py} (91%) rename chebai/ensemble/{consolidator.py => _consolidator.py} (99%) create mode 100644 chebai/ensemble/_constants.py rename chebai/ensemble/{controller.py => _controller.py} (71%) rename chebai/ensemble/{ => _scripts}/_ensemble_run_script.py (97%) diff --git a/chebai/ensemble/__init__.py b/chebai/ensemble/__init__.py index 225b5a9b..67570b10 100644 --- a/chebai/ensemble/__init__.py +++ b/chebai/ensemble/__init__.py @@ -1,5 +1,5 @@ -from .consolidator import WeightedMajorityVoting -from .controller import NoActivationCondition +from ._consolidator import WeightedMajorityVoting +from ._controller import NoActivationCondition class FullEnsembleWMV(NoActivationCondition, WeightedMajorityVoting): diff --git a/chebai/ensemble/base.py b/chebai/ensemble/_base.py similarity index 91% rename from chebai/ensemble/base.py rename to chebai/ensemble/_base.py index f1cfd309..04e07b0a 100644 --- a/chebai/ensemble/base.py +++ b/chebai/ensemble/_base.py @@ -13,6 +13,8 @@ from chebai.preprocessing.structures import XYData from chebai.result.classification import print_metrics +from ._constants import * + class EnsembleBase(ABC): """ @@ -25,7 +27,6 @@ def __init__( self, model_configs: Dict[str, Dict[str, Any]], data_processed_dir_main: str, - reader_dir_name: str = "smiles_token", **kwargs: Any, ) -> None: """ @@ -42,7 +43,6 @@ def __init__( self.model_configs: Dict[str, Dict[str, Any]] = model_configs self.data_processed_dir_main: str = data_processed_dir_main - self.reader_dir_name: str = reader_dir_name self.input_dim: Optional[int] = kwargs.get("input_dim", None) self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -73,7 +73,13 @@ def _validate_model_configs(cls, model_configs: Dict[str, Dict[str, Any]]) -> No ValueError: If duplicate paths are found for model checkpoint, class, or labels. """ path_set, class_set, labels_set = set(), set(), set() - required_keys = {"class_path", "ckpt_path", "labels_path"} + required_keys = { + MODEL_CKPT_PATH, + MODEL_CLS_PATH, + MODEL_LBL_PATH, + WRAPPER_CLS_PATH, + READER_CLS_PATH, + } for model_name, config in model_configs.items(): missing_keys = required_keys - config.keys() @@ -82,22 +88,24 @@ def _validate_model_configs(cls, model_configs: Dict[str, Dict[str, Any]]) -> No f"Missing keys {missing_keys} in model '{model_name}' configuration." ) - model_path, class_path, labels_path = ( - config["ckpt_path"], - config["class_path"], - config["labels_path"], + model_ckpt_path, model_class_path, model_labels_path = ( + config[MODEL_CKPT_PATH], + config[MODEL_CLS_PATH], + config[MODEL_LBL_PATH], ) - if model_path in path_set: - raise ValueError(f"Duplicate model path detected: '{model_path}'.") - if class_path in class_set: - raise ValueError(f"Duplicate class path detected: '{class_path}'.") - if labels_path in labels_set: - raise ValueError(f"Duplicate labels path: {labels_path}.") + if model_ckpt_path in path_set: + raise ValueError(f"Duplicate model path detected: '{model_ckpt_path}'.") + if model_class_path in class_set: + raise ValueError( + f"Duplicate class path detected: '{model_class_path}'." + ) + if model_labels_path in labels_set: + raise ValueError(f"Duplicate labels path: {model_labels_path}.") - path_set.add(model_path) - class_set.add(class_path) - labels_set.add(labels_path) + path_set.add(model_ckpt_path) + class_set.add(model_class_path) + labels_set.add(model_labels_path) def _load_data_module_labels(self) -> None: """ @@ -180,10 +188,11 @@ def _load_model_and_its_props( f"Model path '{model_ckpt_path}' for '{model_name}' does not exist." ) - class_name = model_class_path.split(".")[-1] - module_path = ".".join(model_class_path.split(".")[:-1]) - module = importlib.import_module(module_path) - lightning_cls = getattr(module, class_name) + def load_class(class_path): + class_name = class_path.split(".")[-1] + module_path = ".".join(class_path.split(".")[:-1]) + module = importlib.import_module(module_path) + return getattr(module, class_name) assert isinstance(lightning_cls, type), f"{class_name} is not a class." assert issubclass( diff --git a/chebai/ensemble/consolidator.py b/chebai/ensemble/_consolidator.py similarity index 99% rename from chebai/ensemble/consolidator.py rename to chebai/ensemble/_consolidator.py index c8d79eed..f629ef84 100644 --- a/chebai/ensemble/consolidator.py +++ b/chebai/ensemble/_consolidator.py @@ -3,7 +3,7 @@ from torch import Tensor -from .base import EnsembleBase +from ._base import EnsembleBase class WeightedMajorityVoting(EnsembleBase, ABC): diff --git a/chebai/ensemble/_constants.py b/chebai/ensemble/_constants.py new file mode 100644 index 00000000..6253c188 --- /dev/null +++ b/chebai/ensemble/_constants.py @@ -0,0 +1,7 @@ +MODEL_CLS_PATH = "model_class_path" +MODEL_LBL_PATH = "model_labels_path" +MODEL_CKPT_PATH = "model_ckpt_path" + +WRAPPER_CLS_PATH = "wrapper_class_path" + +READER_CLS_PATH = "reader_class_path" diff --git a/chebai/ensemble/controller.py b/chebai/ensemble/_controller.py similarity index 71% rename from chebai/ensemble/controller.py rename to chebai/ensemble/_controller.py index ddffc6d0..2868ace4 100644 --- a/chebai/ensemble/controller.py +++ b/chebai/ensemble/_controller.py @@ -9,7 +9,7 @@ from chebai.models import ChebaiBaseNet from chebai.preprocessing.collate import RaggedCollator -from .base import EnsembleBase +from ._base import EnsembleBase class _Controller(EnsembleBase, ABC): @@ -32,43 +32,9 @@ def __init__(self, **kwargs: Any): super().__init__(**kwargs) self._collator = RaggedCollator() - self._collated_data = self._load_and_collate_data() self.input_dim = len(self._collated_data.x[0]) self._total_data_size: int = len(self._collated_data) - def _load_and_collate_data(self) -> Any: - """ - Loads and collates data using RaggedCollator. - - Returns: - Collated data object with `.x` and `.y` attributes moved to device. - """ - data = torch.load( - os.path.join(self.data_processed_dir_main, self.reader_dir_name, "data.pt"), - weights_only=False, - map_location=self._device, - ) - collated_data = self._collator(data) - collated_data.x = collated_data.to_x(self._device) - if collated_data.y is not None: - collated_data.y = collated_data.to_y(self._device) - return collated_data - - def _forward_pass(self, model: ChebaiBaseNet) -> Dict[str, Tensor]: - """ - Runs a forward pass of the given model on the collated data. - - Args: - model (ChebaiBaseNet): The model to perform inference with. - - Returns: - Dict[str, Tensor]: Model output dictionary containing logits and other keys. - """ - processable_data = model._process_batch(self._collated_data, 0) - del processable_data["loss_kwargs"] - model_output = model(processable_data, **processable_data["model_kwargs"]) - return model_output - def _get_pred_conf_from_model_output( self, model_output: Dict[str, Tensor], model_label_mask: Tensor ) -> Dict[str, Tensor]: diff --git a/chebai/ensemble/_ensemble_run_script.py b/chebai/ensemble/_scripts/_ensemble_run_script.py similarity index 97% rename from chebai/ensemble/_ensemble_run_script.py rename to chebai/ensemble/_scripts/_ensemble_run_script.py index 0862bef9..045c2a53 100644 --- a/chebai/ensemble/_ensemble_run_script.py +++ b/chebai/ensemble/_scripts/_ensemble_run_script.py @@ -4,7 +4,7 @@ import yaml from jsonargparse import ArgumentParser -from chebai.ensemble.base import EnsembleBase +from ._base import EnsembleBase def load_class(class_path: str) -> Type[EnsembleBase]: From 682801f64f993e7dc33b512443ae906a7bcc96db Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 22 May 2025 12:23:17 +0200 Subject: [PATCH 40/78] chemlog wrapper --- chebai/wrappers/_chemlog.py | 103 ++++++++++++++++++++++++++++++++++++ 1 file changed, 103 insertions(+) create mode 100644 chebai/wrappers/_chemlog.py diff --git a/chebai/wrappers/_chemlog.py b/chebai/wrappers/_chemlog.py new file mode 100644 index 00000000..0df47e78 --- /dev/null +++ b/chebai/wrappers/_chemlog.py @@ -0,0 +1,103 @@ +import os +import sys +from typing import Optional + +sys.path.append(os.path.join("..", "..", "..", "PycharmProjects", "chemlog2")) +import chemlog +from chebi_utils import CHEBI_FRAGMENT, get_transitive_predictions +from chemlog.classification.charge_classifier import ( + ChargeCategories, + get_charge_category, +) +from chemlog.classification.peptide_size_classifier import get_n_amino_acid_residues +from chemlog.classification.proteinogenics_classifier import ( + get_proteinogenic_amino_acids, +) +from chemlog.classification.substructure_classifier import ( + is_diketopiperazine, + is_emericellamide, +) +from chemlog.cli import resolve_chebi_classes +from prediction_models.base import PredictionModel +from rdkit import Chem + + +class ChemLog(PredictionModel): + + def __init__( + self, + name: Optional[str] = None, + description: Optional[ + str + ] = "A rule-based model for predicting peptides and peptide-like molecules.", + ): + super().__init__(name, description) + + def get_chemlog_results(self, smiles_list) -> list: + all_preds = [] + for i, smiles in enumerate(smiles_list): + mol = Chem.MolFromSmiles(smiles, sanitize=False) + if mol is None or not smiles: + all_preds.append(None) + continue + mol.UpdatePropertyCache() + charge_category = get_charge_category(mol) + n_amino_acid_residues, add_output = get_n_amino_acid_residues(mol) + r = { + "charge_category": charge_category.name, + "n_amino_acid_residues": n_amino_acid_residues, + } + if n_amino_acid_residues == 5: + r["emericellamide"] = is_emericellamide(mol)[0] + if n_amino_acid_residues == 2: + r["2,5-diketopiperazines"] = is_diketopiperazine(mol)[0] + + chebi_classes = [f"CHEBI:{c}" for c in resolve_chebi_classes(r)] + + all_preds.append(chebi_classes) + return all_preds + + def get_chemlog_result_info(self, smiles): + """Get classification for single molecule with additional information.""" + mol = Chem.MolFromSmiles(smiles, sanitize=False) + if mol is None or not smiles: + return {"error": "Failed to parse SMILES"} + mol.UpdatePropertyCache() + try: + Chem.Kekulize(mol) + except Chem.KekulizeException as e: + pass + + charge_category = get_charge_category(mol) + n_amino_acid_residues, add_output = get_n_amino_acid_residues(mol) + if n_amino_acid_residues > 1: + proteinogenics, proteinogenics_locations, _ = get_proteinogenic_amino_acids( + mol, add_output["amino_residue"], add_output["carboxy_residue"] + ) + else: + proteinogenics, proteinogenics_locations, _ = [], [], [] + results = { + "charge_category": charge_category.name, + "n_amino_acid_residues": n_amino_acid_residues, + "proteinogenics": proteinogenics, + "proteinogenics_locations": proteinogenics_locations, + } + + if n_amino_acid_residues == 5: + emericellamide = is_emericellamide(mol) + results["emericellamide"] = emericellamide[0] + if emericellamide[0]: + results["emericellamide_atoms"] = emericellamide[1] + if n_amino_acid_residues == 2: + diketopiperazine = is_diketopiperazine(mol) + results["2,5-diketopiperazines"] = diketopiperazine[0] + if diketopiperazine[0]: + results["2,5-diketopiperazines_atoms"] = diketopiperazine[1] + + return {**results, **add_output} + + def predict(self, smiles_list): + return [ + get_transitive_predictions([positive_i]) + for positive_i in self.get_chemlog_results(smiles_list) + ] From 2b2d458a647ad91c01b0062deaea82caa693af6a Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 22 May 2025 12:23:39 +0200 Subject: [PATCH 41/78] gnn wrapper --- chebai/wrappers/_gnn.py | 115 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 115 insertions(+) create mode 100644 chebai/wrappers/_gnn.py diff --git a/chebai/wrappers/_gnn.py b/chebai/wrappers/_gnn.py new file mode 100644 index 00000000..6035879b --- /dev/null +++ b/chebai/wrappers/_gnn.py @@ -0,0 +1,115 @@ +from typing import Optional, Union + +import chebai_graph.preprocessing.properties as p +import torch +from chebai_graph.models.graph import ResGatedGraphConvNetGraphPred +from chebai_graph.preprocessing.datasets.chebi import ( + ChEBI50GraphProperties, + ChEBI100GraphProperties, + GraphPropertiesMixIn, +) +from chebai_graph.preprocessing.property_encoder import IndexEncoder, OneHotEncoder +from torch_geometric.data.data import Data as GeomData + +from ._neural_network import NNWrapper + +if torch.cuda.is_available(): + device = "cuda" +else: + device = "cpu" + + +class GNNResGated(NNWrapper): + + def __init__( + self, + checkpoint_path: str, + data_class: Union[GraphPropertiesMixIn, str], + prediction_headers_path: str, + batch_size: Optional[int] = 32, + name: Optional[str] = None, + description: Optional[str] = "Residual-gated Graph Convolutional Network for " + "predicting arbitrary ChEBI classes.", + ): + super().__init__(prediction_headers_path, batch_size, name, description) + self.model = ResGatedGraphConvNetGraphPred.load_from_checkpoint( + checkpoint_path, + map_location=torch.device(device), + criterion=None, + strict=False, + metrics=dict(train=dict(), test=dict(), validation=dict()), + pretrained_checkpoint=None, + config={ + "in_length": 256, + "hidden_length": 512, + "dropout_rate": 0.1, + "n_conv_layers": 3, + "n_linear_layers": 3, + "n_atom_properties": 158, + "n_bond_properties": 7, + "n_molecule_properties": 200, + }, + ) + + def _read_smiles(self, smiles): + d = self.reader.to_data(dict(features=smiles, labels=None)) + geom_data = d["features"] + assert isinstance(geom_data, GeomData), "" + edge_attr = geom_data.edge_attr + x = geom_data.x + molecule_attr = torch.empty((1, 0)) + for property in self.data_class.properties: + property_values = reader.read_property(smiles, property) + encoded_values = [] + for value in property_values: + # cant use standard encode for index encoder because model has been trained on a certain range of values + # use default value if we meet an unseen value + if isinstance(property.encoder, IndexEncoder): + if str(value) in property.encoder.cache: + index = ( + property.encoder.cache.index(str(value)) + + property.encoder.offset + ) + else: + index = 0 + print( + f"Unknown property value {value} for property {property} at smiles {smiles}" + ) + if isinstance(property.encoder, OneHotEncoder): + encoded_values.append( + torch.nn.functional.one_hot( + torch.tensor(index), + num_classes=property.encoder.get_encoding_length(), + ) + ) + else: + encoded_values.append(torch.tensor([index])) + + else: + encoded_values.append(property.encoder.encode(value)) + if len(encoded_values) > 0: + encoded_values = torch.stack(encoded_values) + + if isinstance(encoded_values, torch.Tensor): + if len(encoded_values.size()) == 0: + encoded_values = encoded_values.unsqueeze(0) + if len(encoded_values.size()) == 1: + encoded_values = encoded_values.unsqueeze(1) + else: + encoded_values = torch.zeros( + (0, property.encoder.get_encoding_length()) + ) + if isinstance(property, p.AtomProperty): + x = torch.cat([x, encoded_values], dim=1) + elif isinstance(property, p.BondProperty): + edge_attr = torch.cat([edge_attr, encoded_values], dim=1) + else: + molecule_attr = torch.cat([molecule_attr, encoded_values[0]], dim=1) + + d["features"] = GeomData( + x=x, + edge_index=geom_data.edge_index, + edge_attr=edge_attr, + molecule_attr=molecule_attr, + ) + return d From 7cbb73297003595d34c1334bad88fba798d08c68 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 22 May 2025 12:24:21 +0200 Subject: [PATCH 42/78] move related code from ensemble base to nn wrapper --- chebai/ensemble/_base.py | 135 ------------------------- chebai/wrappers/_neural_network.py | 153 +++++++++++++++++++++++++++-- 2 files changed, 143 insertions(+), 145 deletions(-) diff --git a/chebai/ensemble/_base.py b/chebai/ensemble/_base.py index 04e07b0a..25904ef2 100644 --- a/chebai/ensemble/_base.py +++ b/chebai/ensemble/_base.py @@ -166,141 +166,6 @@ def run_ensemble(self) -> None: classes=list(self._dm_labels.keys()), ) - def _load_model_and_its_props( - self, model_name: str - ) -> Tuple[LightningModule, Dict[str, torch.Tensor]]: - """ - Loads a model checkpoint and its label-related properties. - - Args: - model_name (str): Name of the model to load. - - Returns: - Tuple[LightningModule, Dict[str, torch.Tensor]]: The model and its label properties. - """ - config = self.model_configs[model_name] - model_ckpt_path = config["ckpt_path"] - model_class_path = config["class_path"] - model_labels_path = config["labels_path"] - - if not os.path.exists(model_ckpt_path): - raise FileNotFoundError( - f"Model path '{model_ckpt_path}' for '{model_name}' does not exist." - ) - - def load_class(class_path): - class_name = class_path.split(".")[-1] - module_path = ".".join(class_path.split(".")[:-1]) - module = importlib.import_module(module_path) - return getattr(module, class_name) - - assert isinstance(lightning_cls, type), f"{class_name} is not a class." - assert issubclass( - lightning_cls, ChebaiBaseNet - ), f"{class_name} must inherit from ChebaiBaseNet" - - try: - model = lightning_cls.load_from_checkpoint( - model_ckpt_path, input_dim=self.input_dim - ) - model.eval() - model.freeze() - model_label_props = self._generate_model_label_props(model_labels_path) - except Exception as e: - raise RuntimeError(f"Error loading model {model_name}") from e - - return model, model_label_props - - def _generate_model_label_props(self, labels_path: str) -> Dict[str, torch.Tensor]: - """ - Generates label mask and confidence tensors (TPV, FPV) for a model. - - Args: - labels_path (str): Path to the labels JSON file. - - Returns: - Dict[str, torch.Tensor]: Dictionary containing mask, TPV and FPV tensors. - """ - rank_zero_info("\t Generating model label masks and properties") - labels_dict = self._load_model_labels(labels_path) - - model_label_indices, tpv_label_values, fpv_label_values = [], [], [] - - for label, props in labels_dict.items(): - if label in self._dm_labels: - try: - self._validate_model_labels_json_element(labels_dict[label]) - except Exception as e: - raise Exception(f"Label '{label}' has an unexpected error") from e - - model_label_indices.append(self._dm_labels[label]) - tpv_label_values.append(props["TPV"]) - fpv_label_values.append(props["FPV"]) - - if not all([model_label_indices, tpv_label_values, fpv_label_values]): - raise ValueError(f"No valid label values found in {labels_path}.") - - # Create masks to apply predictions only to known classes - mask = torch.zeros(self._num_of_labels, dtype=torch.bool, device=self._device) - mask[torch.tensor(model_label_indices, device=self._device)] = True - - tpv_tensor = torch.full_like(mask, -1, dtype=torch.float, device=self._device) - fpv_tensor = torch.full_like(mask, -1, dtype=torch.float, device=self._device) - - tpv_tensor[mask] = torch.tensor( - tpv_label_values, dtype=torch.float, device=self._device - ) - fpv_tensor[mask] = torch.tensor( - fpv_label_values, dtype=torch.float, device=self._device - ) - - self._num_models_per_label += mask - return {"mask": mask, "tpv_tensor": tpv_tensor, "fpv_tensor": fpv_tensor} - - @staticmethod - def _load_model_labels(labels_path: str) -> Dict[str, Dict[str, float]]: - """ - Loads a JSON label file for a model. - - Args: - labels_path (str): Path to the JSON file. - - Returns: - Dict[str, Dict[str, float]]: Parsed label confidence data. - - Raises: - FileNotFoundError: If the file is missing. - TypeError: If the file is not a JSON. - """ - if not os.path.exists(labels_path): - raise FileNotFoundError(f"{labels_path} does not exist.") - if not labels_path.endswith(".json"): - raise TypeError(f"{labels_path} is not a JSON file.") - with open(labels_path, "r") as f: - return json.load(f) - - @staticmethod - def _validate_model_labels_json_element(label_dict: Dict[str, Any]) -> None: - """ - Validates a label confidence dictionary to ensure required keys and values are valid. - - Args: - label_dict (Dict[str, Any]): Label data with TPV and FPV keys. - - Raises: - AttributeError: If required keys are missing. - ValueError: If values are not valid floats or are negative. - """ - for key in ["TPV", "FPV"]: - if key not in label_dict: - raise AttributeError(f"Missing key '{key}' in label dict.") - try: - value = float(label_dict[key]) - if value < 0: - raise ValueError(f"'{key}' must be non-negative but got {value}") - except (TypeError, ValueError) as e: - raise ValueError(f"Invalid value for '{key}': {label_dict[key]}") from e - @abstractmethod def _controller( self, diff --git a/chebai/wrappers/_neural_network.py b/chebai/wrappers/_neural_network.py index 5345bdf2..1b6bdcaa 100644 --- a/chebai/wrappers/_neural_network.py +++ b/chebai/wrappers/_neural_network.py @@ -1,7 +1,10 @@ +import importlib +import json import os from typing import Optional, Type import torch +from lightning import LightningModule from rdkit import Chem from chebai.models import ChebaiBaseNet @@ -14,27 +17,19 @@ class NNWrapper(BaseWrapper): def __init__( self, - model: ChebaiBaseNet, + model_config: dict, reader_cls: Type[DataReader], reader_kwargs: Optional[dict] = None, **kwargs, ): super().__init__(**kwargs) + self._model_class_path = model_config[MODEL_CLS_PATH] self.model: ChebaiBaseNet = model if reader_kwargs is None: reader_kwargs = dict() self.reader = reader_cls(**reader_kwargs) self.collator = reader_cls.COLLATOR() - def _forward_pass(self, batch): - processable_data = self.model._process_batch( - self.collator(batch).to(self._device), 0 - ) - return self.model(processable_data, **processable_data["model_kwargs"]) - - def _read_smiles(self, smiles): - return self.reader.to_data(dict(features=smiles, labels=None)) - def _predict_from_list_of_smiles(self, smiles_list) -> list: token_dicts = [] could_not_parse = [] @@ -66,6 +61,15 @@ def _predict_from_list_of_smiles(self, smiles_list) -> list: else: raise ValueError() + def _read_smiles(self, smiles): + return self.reader.to_data(dict(features=smiles, labels=None)) + + def _forward_pass(self, batch): + processable_data = self.model._process_batch( + self.collator(batch).to(self._device), 0 + ) + return self.model(processable_data, **processable_data["model_kwargs"]) + def _predict_from_data_file( self, processed_dir_main: str, data_file_name="data.pt" ) -> list: @@ -75,3 +79,132 @@ def _predict_from_data_file( map_location=self._device, ) return self._forward_pass(data) + + def _load_model_and_its_props( + self, model_name: str + ) -> tuple[LightningModule, dict[str, torch.Tensor]]: + """ + Loads a model checkpoint and its label-related properties. + + Args: + model_name (str): Name of the model to load. + + Returns: + Tuple[LightningModule, Dict[str, torch.Tensor]]: The model and its label properties. + """ + config = self.model_configs[model_name] + model_ckpt_path = config["ckpt_path"] + model_class_path = config["class_path"] + model_labels_path = config["labels_path"] + + if not os.path.exists(model_ckpt_path): + raise FileNotFoundError( + f"Model path '{model_ckpt_path}' for '{model_name}' does not exist." + ) + + assert isinstance(lightning_cls, type), f"{class_name} is not a class." + assert issubclass( + lightning_cls, ChebaiBaseNet + ), f"{class_name} must inherit from ChebaiBaseNet" + + try: + model = lightning_cls.load_from_checkpoint( + model_ckpt_path, input_dim=self.input_dim + ) + model.eval() + model.freeze() + model_label_props = self._generate_model_label_props(model_labels_path) + except Exception as e: + raise RuntimeError(f"Error loading model {model_name}") from e + + return model, model_label_props + + def _generate_model_label_props(self, labels_path: str) -> dict[str, torch.Tensor]: + """ + Generates label mask and confidence tensors (TPV, FPV) for a model. + + Args: + labels_path (str): Path to the labels JSON file. + + Returns: + Dict[str, torch.Tensor]: Dictionary containing mask, TPV and FPV tensors. + """ + rank_zero_info("\t Generating model label masks and properties") + labels_dict = self._load_model_labels(labels_path) + + model_label_indices, tpv_label_values, fpv_label_values = [], [], [] + + for label, props in labels_dict.items(): + if label in self._dm_labels: + try: + self._validate_model_labels_json_element(labels_dict[label]) + except Exception as e: + raise Exception(f"Label '{label}' has an unexpected error") from e + + model_label_indices.append(self._dm_labels[label]) + tpv_label_values.append(props["TPV"]) + fpv_label_values.append(props["FPV"]) + + if not all([model_label_indices, tpv_label_values, fpv_label_values]): + raise ValueError(f"No valid label values found in {labels_path}.") + + # Create masks to apply predictions only to known classes + mask = torch.zeros(self._num_of_labels, dtype=torch.bool, device=self._device) + mask[torch.tensor(model_label_indices, device=self._device)] = True + + tpv_tensor = torch.full_like(mask, -1, dtype=torch.float, device=self._device) + fpv_tensor = torch.full_like(mask, -1, dtype=torch.float, device=self._device) + + tpv_tensor[mask] = torch.tensor( + tpv_label_values, dtype=torch.float, device=self._device + ) + fpv_tensor[mask] = torch.tensor( + fpv_label_values, dtype=torch.float, device=self._device + ) + + self._num_models_per_label += mask + return {"mask": mask, "tpv_tensor": tpv_tensor, "fpv_tensor": fpv_tensor} + + @staticmethod + def _load_model_labels(labels_path: str) -> dict[str, dict[str, float]]: + """ + Loads a JSON label file for a model. + + Args: + labels_path (str): Path to the JSON file. + + Returns: + Dict[str, Dict[str, float]]: Parsed label confidence data. + + Raises: + FileNotFoundError: If the file is missing. + TypeError: If the file is not a JSON. + """ + if not os.path.exists(labels_path): + raise FileNotFoundError(f"{labels_path} does not exist.") + if not labels_path.endswith(".json"): + raise TypeError(f"{labels_path} is not a JSON file.") + with open(labels_path, "r") as f: + return json.load(f) + + @staticmethod + def _validate_model_labels_json_element(label_dict: dict[str, str]) -> None: + """ + Validates a label confidence dictionary to ensure required keys and values are valid. + + Args: + label_dict (Dict[str, Any]): Label data with TPV and FPV keys. + + Raises: + AttributeError: If required keys are missing. + ValueError: If values are not valid floats or are negative. + """ + for key in ["TPV", "FPV"]: + if key not in label_dict: + raise AttributeError(f"Missing key '{key}' in label dict.") + try: + value = float(label_dict[key]) + if value < 0: + raise ValueError(f"'{key}' must be non-negative but got {value}") + except (TypeError, ValueError) as e: + raise ValueError(f"Invalid value for '{key}': {label_dict[key]}") from e From a7df38444f47d5f00d61d97980ca9d533f879c78 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 22 May 2025 12:26:03 +0200 Subject: [PATCH 43/78] move constants to wrappers --- chebai/{ensemble => wrappers}/_constants.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename chebai/{ensemble => wrappers}/_constants.py (100%) diff --git a/chebai/ensemble/_constants.py b/chebai/wrappers/_constants.py similarity index 100% rename from chebai/ensemble/_constants.py rename to chebai/wrappers/_constants.py From ee0aef18330627478e864d3254887c436c53227e Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 22 May 2025 20:10:26 +0200 Subject: [PATCH 44/78] move prop loading to base --- chebai/wrappers/_base.py | 120 ++++++++++++++++-- chebai/wrappers/_constants.py | 1 + chebai/wrappers/_neural_network.py | 196 ++++++++--------------------- 3 files changed, 165 insertions(+), 152 deletions(-) diff --git a/chebai/wrappers/_base.py b/chebai/wrappers/_base.py index 3a49bf59..83898b38 100644 --- a/chebai/wrappers/_base.py +++ b/chebai/wrappers/_base.py @@ -1,36 +1,140 @@ import importlib +import json +import os from abc import ABC, abstractmethod from typing import overload import torch +from ._constants import MODEL_CLS_PATH, MODEL_LBL_PATH, READER_CLS_PATH + class BaseWrapper(ABC): - def __init__(self, **kwargs): + def __init__( + self, + model_name: str, + model_config: dict[str, str], + dm_labels: dict[str, int], + **kwargs, + ): self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self._model_config = model_config + self._model_name = model_name + self._model_class_path = self._model_config[MODEL_CLS_PATH] + self._model_labels_path = self._model_config[MODEL_LBL_PATH] + self._dm_labels: dict[str, int] = dm_labels + self._model_props = self._generate_model_label_props() + + def _generate_model_label_props(self) -> dict[str, torch.Tensor]: + """ + Generates label mask and confidence tensors (TPV, FPV) for a model. + + Returns: + Dict[str, torch.Tensor]: Dictionary containing mask, TPV and FPV tensors. + """ + print("\t Generating model label masks and properties") + labels_dict = self._load_model_labels() + + model_label_indices, tpv_label_values, fpv_label_values = [], [], [] + + for label, props in labels_dict.items(): + if label in self._dm_labels: + try: + self._validate_model_labels_json_element(labels_dict[label]) + except Exception as e: + raise Exception(f"Label '{label}' has an unexpected error") from e + + model_label_indices.append(self._dm_labels[label]) + tpv_label_values.append(props["TPV"]) + fpv_label_values.append(props["FPV"]) + + if not all([model_label_indices, tpv_label_values, fpv_label_values]): + raise ValueError( + f"No valid label values found in {self._model_labels_path}." + ) + + # Create masks to apply predictions only to known classes + mask = torch.zeros(len(self._dm_labels), dtype=torch.bool, device=self._device) + mask[torch.tensor(model_label_indices, device=self._device)] = True + + tpv_tensor = torch.full_like(mask, -1, dtype=torch.float, device=self._device) + fpv_tensor = torch.full_like(mask, -1, dtype=torch.float, device=self._device) + + tpv_tensor[mask] = torch.tensor( + tpv_label_values, dtype=torch.float, device=self._device + ) + fpv_tensor[mask] = torch.tensor( + fpv_label_values, dtype=torch.float, device=self._device + ) + + return {"mask": mask, "tpv_tensor": tpv_tensor, "fpv_tensor": fpv_tensor} + + def _load_model_labels(self) -> dict[str, dict[str, float]]: + """ + Loads a JSON label file for a model. + + Returns: + Dict[str, Dict[str, float]]: Parsed label confidence data. + + Raises: + FileNotFoundError: If the file is missing. + TypeError: If the file is not a JSON. + """ + if not os.path.exists(self._model_labels_path): + raise FileNotFoundError(f"{self._model_labels_path} does not exist.") + if not self._model_labels_path.endswith(".json"): + raise TypeError(f"{self._model_labels_path} is not a JSON file.") + with open(self._model_labels_path, "r") as f: + return json.load(f) + + @staticmethod + def _validate_model_labels_json_element(label_dict: dict[str, float]) -> None: + """ + Validates a label confidence dictionary to ensure required keys and values are valid. + + Args: + label_dict (Dict[str, Any]): Label data with TPV and FPV keys. + + Raises: + AttributeError: If required keys are missing. + ValueError: If values are not valid floats or are negative. + """ + for key in ["TPV", "FPV"]: + if key not in label_dict: + raise AttributeError(f"Missing key '{key}' in label dict.") + try: + value = float(label_dict[key]) + if value < 0: + raise ValueError(f"'{key}' must be non-negative but got {value}") + except (TypeError, ValueError) as e: + raise ValueError(f"Invalid value for '{key}': {label_dict[key]}") from e + + @property + def name(self): + return f"Wrapper({self.__class__.__name__}) for model: {self._model_name}" @overload - def predict(self, smiles_list: list) -> list: + def predict(self, smiles_list: list) -> tuple[dict, dict]: pass @overload - def predict(self, data_file_path: str) -> list: + def predict(self, data_file_path: str) -> tuple[dict, dict]: pass - def predict(self, x: list | str) -> list: + def predict(self, x: list | str) -> tuple[dict, dict]: if isinstance(x, list): - return self._predict_from_list_of_smiles(x) + return self._predict_from_list_of_smiles(x), self._model_props elif isinstance(x, str): - return self._predict_from_data_file(x) + return self._predict_from_data_file(x), self._model_props else: raise TypeError(f"Type {type(x)} is not supported.") @abstractmethod - def _predict_from_list_of_smiles(self, smiles_list: list) -> list: + def _predict_from_list_of_smiles(self, smiles_list: list) -> dict: pass @abstractmethod - def _predict_from_data_file(self, data_file_path: str) -> list: + def _predict_from_data_file(self, data_file_path: str) -> dict: pass @staticmethod diff --git a/chebai/wrappers/_constants.py b/chebai/wrappers/_constants.py index 6253c188..effc017a 100644 --- a/chebai/wrappers/_constants.py +++ b/chebai/wrappers/_constants.py @@ -5,3 +5,4 @@ WRAPPER_CLS_PATH = "wrapper_class_path" READER_CLS_PATH = "reader_class_path" +READER_KWARGS = "reader_kwargs" diff --git a/chebai/wrappers/_neural_network.py b/chebai/wrappers/_neural_network.py index 1b6bdcaa..052798f8 100644 --- a/chebai/wrappers/_neural_network.py +++ b/chebai/wrappers/_neural_network.py @@ -1,34 +1,71 @@ -import importlib -import json import os from typing import Optional, Type import torch -from lightning import LightningModule from rdkit import Chem from chebai.models import ChebaiBaseNet from chebai.preprocessing.reader import DataReader from ._base import BaseWrapper +from ._constants import MODEL_CKPT_PATH, READER_CLS_PATH, READER_KWARGS class NNWrapper(BaseWrapper): def __init__( self, - model_config: dict, reader_cls: Type[DataReader], reader_kwargs: Optional[dict] = None, **kwargs, ): super().__init__(**kwargs) - self._model_class_path = model_config[MODEL_CLS_PATH] - self.model: ChebaiBaseNet = model - if reader_kwargs is None: - reader_kwargs = dict() - self.reader = reader_cls(**reader_kwargs) - self.collator = reader_cls.COLLATOR() + + self._model_ckpt_path = self._model_config[MODEL_CKPT_PATH] + self._reader_class_path = self._model_config[READER_CLS_PATH] + self._reader_kwargs: dict = ( + self._model_config[READER_KWARGS] + if self._model_config[READER_KWARGS] + else dict() + ) + + self._reader = self._load_class(self._reader_class_path)(**self._reader_kwargs) + self._collator = reader_cls.COLLATOR() + self._model: ChebaiBaseNet = self._load_model_() + + def _load_model_(self) -> ChebaiBaseNet: + """ + Loads a model checkpoint and its label-related properties. + + Args: + model_name (str): Name of the model to load. + + Returns: + Tuple[LightningModule, Dict[str, torch.Tensor]]: The model and its label properties. + """ + + if not os.path.exists(self._model_ckpt_path): + raise FileNotFoundError( + f"Model path '{self._model_ckpt_path}' for '{self._model_name}' does not exist." + ) + + lightning_cls = self._load_class(self._model_class_path) + + assert isinstance(lightning_cls, type), f"{lightning_cls} is not a class." + assert issubclass( + lightning_cls, ChebaiBaseNet + ), f"{lightning_cls} must inherit from ChebaiBaseNet" + + try: + model = lightning_cls.load_from_checkpoint( + self._model_ckpt_path, input_dim=self.input_dim + ) + model.eval() + model.freeze() + except Exception as e: + raise RuntimeError(f"Error loading model {self._model_name}") from e + + return model def _predict_from_list_of_smiles(self, smiles_list) -> list: token_dicts = [] @@ -62,149 +99,20 @@ def _predict_from_list_of_smiles(self, smiles_list) -> list: raise ValueError() def _read_smiles(self, smiles): - return self.reader.to_data(dict(features=smiles, labels=None)) + return self._reader.to_data(dict(features=smiles, labels=None)) def _forward_pass(self, batch): - processable_data = self.model._process_batch( - self.collator(batch).to(self._device), 0 + processable_data = self._model._process_batch( + self._collator(batch).to(self._device), 0 ) - return self.model(processable_data, **processable_data["model_kwargs"]) + return self._model(processable_data, **processable_data["model_kwargs"]) def _predict_from_data_file( self, processed_dir_main: str, data_file_name="data.pt" ) -> list: data = torch.load( - os.path.join(processed_dir_main, self.reader.name(), data_file_name), + os.path.join(processed_dir_main, self._reader.name(), data_file_name), weights_only=False, map_location=self._device, ) return self._forward_pass(data) - - def _load_model_and_its_props( - self, model_name: str - ) -> tuple[LightningModule, dict[str, torch.Tensor]]: - """ - Loads a model checkpoint and its label-related properties. - - Args: - model_name (str): Name of the model to load. - - Returns: - Tuple[LightningModule, Dict[str, torch.Tensor]]: The model and its label properties. - """ - config = self.model_configs[model_name] - model_ckpt_path = config["ckpt_path"] - model_class_path = config["class_path"] - model_labels_path = config["labels_path"] - - if not os.path.exists(model_ckpt_path): - raise FileNotFoundError( - f"Model path '{model_ckpt_path}' for '{model_name}' does not exist." - ) - - assert isinstance(lightning_cls, type), f"{class_name} is not a class." - assert issubclass( - lightning_cls, ChebaiBaseNet - ), f"{class_name} must inherit from ChebaiBaseNet" - - try: - model = lightning_cls.load_from_checkpoint( - model_ckpt_path, input_dim=self.input_dim - ) - model.eval() - model.freeze() - model_label_props = self._generate_model_label_props(model_labels_path) - except Exception as e: - raise RuntimeError(f"Error loading model {model_name}") from e - - return model, model_label_props - - def _generate_model_label_props(self, labels_path: str) -> dict[str, torch.Tensor]: - """ - Generates label mask and confidence tensors (TPV, FPV) for a model. - - Args: - labels_path (str): Path to the labels JSON file. - - Returns: - Dict[str, torch.Tensor]: Dictionary containing mask, TPV and FPV tensors. - """ - rank_zero_info("\t Generating model label masks and properties") - labels_dict = self._load_model_labels(labels_path) - - model_label_indices, tpv_label_values, fpv_label_values = [], [], [] - - for label, props in labels_dict.items(): - if label in self._dm_labels: - try: - self._validate_model_labels_json_element(labels_dict[label]) - except Exception as e: - raise Exception(f"Label '{label}' has an unexpected error") from e - - model_label_indices.append(self._dm_labels[label]) - tpv_label_values.append(props["TPV"]) - fpv_label_values.append(props["FPV"]) - - if not all([model_label_indices, tpv_label_values, fpv_label_values]): - raise ValueError(f"No valid label values found in {labels_path}.") - - # Create masks to apply predictions only to known classes - mask = torch.zeros(self._num_of_labels, dtype=torch.bool, device=self._device) - mask[torch.tensor(model_label_indices, device=self._device)] = True - - tpv_tensor = torch.full_like(mask, -1, dtype=torch.float, device=self._device) - fpv_tensor = torch.full_like(mask, -1, dtype=torch.float, device=self._device) - - tpv_tensor[mask] = torch.tensor( - tpv_label_values, dtype=torch.float, device=self._device - ) - fpv_tensor[mask] = torch.tensor( - fpv_label_values, dtype=torch.float, device=self._device - ) - - self._num_models_per_label += mask - return {"mask": mask, "tpv_tensor": tpv_tensor, "fpv_tensor": fpv_tensor} - - @staticmethod - def _load_model_labels(labels_path: str) -> dict[str, dict[str, float]]: - """ - Loads a JSON label file for a model. - - Args: - labels_path (str): Path to the JSON file. - - Returns: - Dict[str, Dict[str, float]]: Parsed label confidence data. - - Raises: - FileNotFoundError: If the file is missing. - TypeError: If the file is not a JSON. - """ - if not os.path.exists(labels_path): - raise FileNotFoundError(f"{labels_path} does not exist.") - if not labels_path.endswith(".json"): - raise TypeError(f"{labels_path} is not a JSON file.") - with open(labels_path, "r") as f: - return json.load(f) - - @staticmethod - def _validate_model_labels_json_element(label_dict: dict[str, str]) -> None: - """ - Validates a label confidence dictionary to ensure required keys and values are valid. - - Args: - label_dict (Dict[str, Any]): Label data with TPV and FPV keys. - - Raises: - AttributeError: If required keys are missing. - ValueError: If values are not valid floats or are negative. - """ - for key in ["TPV", "FPV"]: - if key not in label_dict: - raise AttributeError(f"Missing key '{key}' in label dict.") - try: - value = float(label_dict[key]) - if value < 0: - raise ValueError(f"'{key}' must be non-negative but got {value}") - except (TypeError, ValueError) as e: - raise ValueError(f"Invalid value for '{key}': {label_dict[key]}") from e From 8d8a748425637eba38930325968621acdc2a6e44 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 22 May 2025 20:37:49 +0200 Subject: [PATCH 45/78] move wrappers to ensemble --- chebai/{wrappers => ensemble/_wrappers}/__init__.py | 0 chebai/{wrappers => ensemble/_wrappers}/_base.py | 2 +- chebai/{wrappers => ensemble/_wrappers}/_chemlog.py | 0 chebai/{wrappers => ensemble/_wrappers}/_gnn.py | 2 +- .../{wrappers => ensemble/_wrappers}/_neural_network.py | 2 +- chebai/wrappers/_constants.py | 8 -------- 6 files changed, 3 insertions(+), 11 deletions(-) rename chebai/{wrappers => ensemble/_wrappers}/__init__.py (100%) rename chebai/{wrappers => ensemble/_wrappers}/_base.py (98%) rename chebai/{wrappers => ensemble/_wrappers}/_chemlog.py (100%) rename chebai/{wrappers => ensemble/_wrappers}/_gnn.py (98%) rename chebai/{wrappers => ensemble/_wrappers}/_neural_network.py (98%) delete mode 100644 chebai/wrappers/_constants.py diff --git a/chebai/wrappers/__init__.py b/chebai/ensemble/_wrappers/__init__.py similarity index 100% rename from chebai/wrappers/__init__.py rename to chebai/ensemble/_wrappers/__init__.py diff --git a/chebai/wrappers/_base.py b/chebai/ensemble/_wrappers/_base.py similarity index 98% rename from chebai/wrappers/_base.py rename to chebai/ensemble/_wrappers/_base.py index 83898b38..18c2d70b 100644 --- a/chebai/wrappers/_base.py +++ b/chebai/ensemble/_wrappers/_base.py @@ -6,7 +6,7 @@ import torch -from ._constants import MODEL_CLS_PATH, MODEL_LBL_PATH, READER_CLS_PATH +from .._constants import MODEL_CLS_PATH, MODEL_LBL_PATH class BaseWrapper(ABC): diff --git a/chebai/wrappers/_chemlog.py b/chebai/ensemble/_wrappers/_chemlog.py similarity index 100% rename from chebai/wrappers/_chemlog.py rename to chebai/ensemble/_wrappers/_chemlog.py diff --git a/chebai/wrappers/_gnn.py b/chebai/ensemble/_wrappers/_gnn.py similarity index 98% rename from chebai/wrappers/_gnn.py rename to chebai/ensemble/_wrappers/_gnn.py index 6035879b..3b777409 100644 --- a/chebai/wrappers/_gnn.py +++ b/chebai/ensemble/_wrappers/_gnn.py @@ -52,7 +52,7 @@ def __init__( ) def _read_smiles(self, smiles): - d = self.reader.to_data(dict(features=smiles, labels=None)) + d = self._reader.to_data(dict(features=smiles, labels=None)) geom_data = d["features"] assert isinstance(geom_data, GeomData), "" edge_attr = geom_data.edge_attr diff --git a/chebai/wrappers/_neural_network.py b/chebai/ensemble/_wrappers/_neural_network.py similarity index 98% rename from chebai/wrappers/_neural_network.py rename to chebai/ensemble/_wrappers/_neural_network.py index 052798f8..30246091 100644 --- a/chebai/wrappers/_neural_network.py +++ b/chebai/ensemble/_wrappers/_neural_network.py @@ -7,8 +7,8 @@ from chebai.models import ChebaiBaseNet from chebai.preprocessing.reader import DataReader +from .._constants import MODEL_CKPT_PATH, READER_CLS_PATH, READER_KWARGS from ._base import BaseWrapper -from ._constants import MODEL_CKPT_PATH, READER_CLS_PATH, READER_KWARGS class NNWrapper(BaseWrapper): diff --git a/chebai/wrappers/_constants.py b/chebai/wrappers/_constants.py deleted file mode 100644 index effc017a..00000000 --- a/chebai/wrappers/_constants.py +++ /dev/null @@ -1,8 +0,0 @@ -MODEL_CLS_PATH = "model_class_path" -MODEL_LBL_PATH = "model_labels_path" -MODEL_CKPT_PATH = "model_ckpt_path" - -WRAPPER_CLS_PATH = "wrapper_class_path" - -READER_CLS_PATH = "reader_class_path" -READER_KWARGS = "reader_kwargs" From 00bd478e89abc65e35012a3c8c44ce5372490478 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 22 May 2025 20:53:12 +0200 Subject: [PATCH 46/78] nn validate model config --- chebai/ensemble/_wrappers/_neural_network.py | 32 ++++++++++++++++++-- 1 file changed, 29 insertions(+), 3 deletions(-) diff --git a/chebai/ensemble/_wrappers/_neural_network.py b/chebai/ensemble/_wrappers/_neural_network.py index 30246091..8ff73e1f 100644 --- a/chebai/ensemble/_wrappers/_neural_network.py +++ b/chebai/ensemble/_wrappers/_neural_network.py @@ -15,10 +15,9 @@ class NNWrapper(BaseWrapper): def __init__( self, - reader_cls: Type[DataReader], - reader_kwargs: Optional[dict] = None, **kwargs, ): + self._validate_model_configs(**kwargs) super().__init__(**kwargs) self._model_ckpt_path = self._model_config[MODEL_CKPT_PATH] @@ -29,10 +28,37 @@ def __init__( else dict() ) - self._reader = self._load_class(self._reader_class_path)(**self._reader_kwargs) + reader_cls: Type[DataReader] = self._load_class(self._reader_class_path) + assert issubclass(reader_cls, DataReader), "" + self._reader = reader_cls(**self._reader_kwargs) self._collator = reader_cls.COLLATOR() self._model: ChebaiBaseNet = self._load_model_() + @classmethod + def _validate_model_configs( + cls, model_config: dict[str, str], model_name: str + ) -> None: + """ + Validates model configuration dictionary for required keys and uniqueness. + + Args: + model_configs (Dict[str, Dict[str, Any]]): Model configuration dictionary. + + Raises: + AttributeError: If any model config is missing required keys. + ValueError: If duplicate paths are found for model checkpoint, class, or labels. + """ + required_keys = { + MODEL_CKPT_PATH, + READER_CLS_PATH, + } + + missing_keys = required_keys - model_config.keys() + if missing_keys: + raise AttributeError( + f"Missing keys {missing_keys} in model '{model_name}' configuration." + ) + def _load_model_(self) -> ChebaiBaseNet: """ Loads a model checkpoint and its label-related properties. From 4f35007e406d89077fd0c66f7637c70f46e640ec Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 22 May 2025 21:01:54 +0200 Subject: [PATCH 47/78] utility for loading class --- .../ensemble/_scripts/_ensemble_run_script.py | 23 +++---------------- chebai/ensemble/_utils.py | 7 ++++++ chebai/ensemble/_wrappers/_base.py | 7 ------ chebai/ensemble/_wrappers/_neural_network.py | 8 +++---- 4 files changed, 13 insertions(+), 32 deletions(-) create mode 100644 chebai/ensemble/_utils.py diff --git a/chebai/ensemble/_scripts/_ensemble_run_script.py b/chebai/ensemble/_scripts/_ensemble_run_script.py index 045c2a53..0ff65787 100644 --- a/chebai/ensemble/_scripts/_ensemble_run_script.py +++ b/chebai/ensemble/_scripts/_ensemble_run_script.py @@ -1,28 +1,11 @@ -import importlib from typing import Any, Dict, Type import yaml from jsonargparse import ArgumentParser -from ._base import EnsembleBase +from chebai.ensemble._utils import _load_class - -def load_class(class_path: str) -> Type[EnsembleBase]: - """ - Dynamically imports and returns a class from a full dotted path. - - Args: - class_path (str): Full module path to the class (e.g., 'my_package.module.MyClass'). - - Returns: - Type[EnsembleBase]: The imported class object. - - Raises: - ModuleNotFoundError, AttributeError: If module or class cannot be loaded. - """ - module_path, class_name = class_path.rsplit(".", 1) - module = importlib.import_module(module_path) - return getattr(module, class_name) +from .._base import EnsembleBase def load_config_and_instantiate(config_path: str) -> EnsembleBase: @@ -44,7 +27,7 @@ def load_config_and_instantiate(config_path: str) -> EnsembleBase: class_path: str = config["class_path"] init_args: Dict[str, Any] = config.get("init_args", {}) - cls = load_class(class_path) + cls = _load_class(class_path) if not issubclass(cls, EnsembleBase): raise TypeError(f"{cls} must be subclass of EnsembleBase") diff --git a/chebai/ensemble/_utils.py b/chebai/ensemble/_utils.py new file mode 100644 index 00000000..c9e273bf --- /dev/null +++ b/chebai/ensemble/_utils.py @@ -0,0 +1,7 @@ +import importlib + + +def _load_class(class_path): + module_path, class_name = class_path.rsplit(".", 1) + module = importlib.import_module(module_path) + return getattr(module, class_name) diff --git a/chebai/ensemble/_wrappers/_base.py b/chebai/ensemble/_wrappers/_base.py index 18c2d70b..4410ed1c 100644 --- a/chebai/ensemble/_wrappers/_base.py +++ b/chebai/ensemble/_wrappers/_base.py @@ -136,10 +136,3 @@ def _predict_from_list_of_smiles(self, smiles_list: list) -> dict: @abstractmethod def _predict_from_data_file(self, data_file_path: str) -> dict: pass - - @staticmethod - def _load_class(class_path): - class_name = class_path.split(".")[-1] - module_path = ".".join(class_path.split(".")[:-1]) - module = importlib.import_module(module_path) - return getattr(module, class_name) diff --git a/chebai/ensemble/_wrappers/_neural_network.py b/chebai/ensemble/_wrappers/_neural_network.py index 8ff73e1f..2549c6b7 100644 --- a/chebai/ensemble/_wrappers/_neural_network.py +++ b/chebai/ensemble/_wrappers/_neural_network.py @@ -8,15 +8,13 @@ from chebai.preprocessing.reader import DataReader from .._constants import MODEL_CKPT_PATH, READER_CLS_PATH, READER_KWARGS +from .._utils import _load_class from ._base import BaseWrapper class NNWrapper(BaseWrapper): - def __init__( - self, - **kwargs, - ): + def __init__(self, **kwargs): self._validate_model_configs(**kwargs) super().__init__(**kwargs) @@ -28,7 +26,7 @@ def __init__( else dict() ) - reader_cls: Type[DataReader] = self._load_class(self._reader_class_path) + reader_cls: Type[DataReader] = _load_class(self._reader_class_path) assert issubclass(reader_cls, DataReader), "" self._reader = reader_cls(**self._reader_kwargs) self._collator = reader_cls.COLLATOR() From a1a70eb96a135542a62e35387a0b67467e8e075c Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 22 May 2025 21:19:28 +0200 Subject: [PATCH 48/78] Create _constants.py --- chebai/ensemble/_constants.py | 8 ++++++++ 1 file changed, 8 insertions(+) create mode 100644 chebai/ensemble/_constants.py diff --git a/chebai/ensemble/_constants.py b/chebai/ensemble/_constants.py new file mode 100644 index 00000000..effc017a --- /dev/null +++ b/chebai/ensemble/_constants.py @@ -0,0 +1,8 @@ +MODEL_CLS_PATH = "model_class_path" +MODEL_LBL_PATH = "model_labels_path" +MODEL_CKPT_PATH = "model_ckpt_path" + +WRAPPER_CLS_PATH = "wrapper_class_path" + +READER_CLS_PATH = "reader_class_path" +READER_KWARGS = "reader_kwargs" From f812cd7e3b1ee4ca45bc813d74d1f4c915ce12a0 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Fri, 23 May 2025 00:16:36 +0200 Subject: [PATCH 49/78] update controller for wrapper --- chebai/ensemble/_controller.py | 31 +++++++++++++++++++-------- chebai/ensemble/_wrappers/__init__.py | 3 ++- 2 files changed, 24 insertions(+), 10 deletions(-) diff --git a/chebai/ensemble/_controller.py b/chebai/ensemble/_controller.py index 2868ace4..8ae77cee 100644 --- a/chebai/ensemble/_controller.py +++ b/chebai/ensemble/_controller.py @@ -4,12 +4,16 @@ from typing import Any, Deque, Dict import torch +from lightning import LightningModule from torch import Tensor from chebai.models import ChebaiBaseNet from chebai.preprocessing.collate import RaggedCollator from ._base import EnsembleBase +from ._constants import WRAPPER_CLS_PATH +from ._utils import _load_class +from ._wrappers import BaseWrapper class _Controller(EnsembleBase, ABC): @@ -30,10 +34,7 @@ def __init__(self, **kwargs: Any): **kwargs (Any): Keyword arguments passed to the EnsembleBase initializer. """ super().__init__(**kwargs) - self._collator = RaggedCollator() - - self.input_dim = len(self._collated_data.x[0]) - self._total_data_size: int = len(self._collated_data) + self._kwargs = kwargs def _get_pred_conf_from_model_output( self, model_output: Dict[str, Tensor], model_label_mask: Tensor @@ -60,6 +61,20 @@ def _get_pred_conf_from_model_output( confidence[:, model_label_mask] = 2 * torch.abs(sigmoid_logits - 0.5) return {"prediction": prediction, "confidence": confidence} + def _wrap_model(self, model_name: str) -> BaseWrapper: + model_config = self._model_configs[model_name] + wrp_cls = _load_class(model_config[WRAPPER_CLS_PATH]) + assert issubclass(wrp_cls, BaseWrapper), "" + wrapped_model = wrp_cls( + model_name=model_name, + model_config=model_config, + dm_labels=self._dm_labels, + **self._kwargs + ) + assert isinstance(wrapped_model, BaseWrapper), "" + # del wrapped_model # Model can be huge to keep it in memory, delete as no longer needed + return wrapped_model + class NoActivationCondition(_Controller): """ @@ -76,11 +91,9 @@ def __init__(self, **kwargs: Any): **kwargs (Any): Keyword arguments passed to the _Controller initializer. """ super().__init__(**kwargs) - self._model_queue: Deque[str] = deque(list(self.model_configs.keys())) + self._model_queue: Deque[str] = deque(list(self._model_configs.keys())) - def _controller( - self, model: ChebaiBaseNet, model_props: Dict[str, Tensor], **kwargs: Any - ) -> Dict[str, Tensor]: + def _controller(self, model_name, **kwargs: Any) -> Dict[str, Tensor]: """ Performs inference with the model and extracts predictions and confidence values. @@ -91,5 +104,5 @@ def _controller( Returns: Dict[str, Tensor]: Dictionary containing predictions and confidence scores. """ - model_output = self._forward_pass(model) + wrapped_model = self._wrap_model(model_name) return self._get_pred_conf_from_model_output(model_output, model_props["mask"]) diff --git a/chebai/ensemble/_wrappers/__init__.py b/chebai/ensemble/_wrappers/__init__.py index 62f836cd..4c4bac6d 100644 --- a/chebai/ensemble/_wrappers/__init__.py +++ b/chebai/ensemble/_wrappers/__init__.py @@ -1,3 +1,4 @@ +from ._base import BaseWrapper from ._neural_network import NNWrapper -__all__ = ["NNWrapper"] +__all__ = ["NNWrapper", "BaseWrapper"] From c48bfd23631a96f403bfbaf87285842fa000e3f8 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Fri, 23 May 2025 00:17:06 +0200 Subject: [PATCH 50/78] update base for wrapper --- chebai/ensemble/_base.py | 61 ++++++++++++++++++++-------------------- 1 file changed, 30 insertions(+), 31 deletions(-) diff --git a/chebai/ensemble/_base.py b/chebai/ensemble/_base.py index 25904ef2..7f960dc0 100644 --- a/chebai/ensemble/_base.py +++ b/chebai/ensemble/_base.py @@ -7,13 +7,20 @@ import torch from lightning import LightningModule -from lightning_utilities.core.rank_zero import rank_zero_info from chebai.models import ChebaiBaseNet from chebai.preprocessing.structures import XYData from chebai.result.classification import print_metrics -from ._constants import * +from ._constants import ( + MODEL_CKPT_PATH, + MODEL_CLS_PATH, + MODEL_LBL_PATH, + READER_CLS_PATH, + WRAPPER_CLS_PATH, +) +from ._utils import _load_class +from ._wrappers import BaseWrapper class EnsembleBase(ABC): @@ -41,18 +48,17 @@ def __init__( if bool(kwargs.get("_validate_configs", True)): self._validate_model_configs(model_configs) - self.model_configs: Dict[str, Dict[str, Any]] = model_configs - self.data_processed_dir_main: str = data_processed_dir_main - self.input_dim: Optional[int] = kwargs.get("input_dim", None) + self._model_configs: Dict[str, Dict[str, Any]] = model_configs + self._data_processed_dir_main: str = data_processed_dir_main + self._input_dim: Optional[int] = kwargs.get("input_dim", None) + self._total_data_size: int = len(self._collated_data) self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - self._num_of_labels: Optional[int] = ( - None # will be set by `_load_data_module_labels` method - ) + self._models: Dict[str, LightningModule] = {} - self._dm_labels: Dict[str, int] = {} + self._dm_labels: Dict[str, int] = self._load_data_module_labels() + self._num_of_labels: int = len(self._dm_labels) - self._load_data_module_labels() self._num_models_per_label: torch.Tensor = torch.zeros( 1, self._num_of_labels, device=self._device ) @@ -72,13 +78,11 @@ def _validate_model_configs(cls, model_configs: Dict[str, Dict[str, Any]]) -> No AttributeError: If any model config is missing required keys. ValueError: If duplicate paths are found for model checkpoint, class, or labels. """ - path_set, class_set, labels_set = set(), set(), set() + class_set, labels_set = set(), set() required_keys = { - MODEL_CKPT_PATH, MODEL_CLS_PATH, MODEL_LBL_PATH, WRAPPER_CLS_PATH, - READER_CLS_PATH, } for model_name, config in model_configs.items(): @@ -88,14 +92,11 @@ def _validate_model_configs(cls, model_configs: Dict[str, Dict[str, Any]]) -> No f"Missing keys {missing_keys} in model '{model_name}' configuration." ) - model_ckpt_path, model_class_path, model_labels_path = ( - config[MODEL_CKPT_PATH], + model_class_path, model_labels_path = ( config[MODEL_CLS_PATH], config[MODEL_LBL_PATH], ) - if model_ckpt_path in path_set: - raise ValueError(f"Duplicate model path detected: '{model_ckpt_path}'.") if model_class_path in class_set: raise ValueError( f"Duplicate class path detected: '{model_class_path}'." @@ -103,29 +104,29 @@ def _validate_model_configs(cls, model_configs: Dict[str, Dict[str, Any]]) -> No if model_labels_path in labels_set: raise ValueError(f"Duplicate labels path: {model_labels_path}.") - path_set.add(model_ckpt_path) class_set.add(model_class_path) labels_set.add(model_labels_path) - def _load_data_module_labels(self) -> None: + def _load_data_module_labels(self) -> dict[str, int]: """ Loads class labels from the classes.txt file and sets internal label mapping. Raises: FileNotFoundError: If the expected classes.txt file is not found. """ - classes_txt_file = os.path.join(self.data_processed_dir_main, "classes.txt") - rank_zero_info(f"Loading {classes_txt_file} ....") + classes_txt_file = os.path.join(self._data_processed_dir_main, "classes.txt") + print(f"Loading {classes_txt_file} ....") if not os.path.exists(classes_txt_file): raise FileNotFoundError(f"{classes_txt_file} does not exist") + dm_labels_dict = {} with open(classes_txt_file, "r") as f: for line in f: label = line.strip() - if label not in self._dm_labels: - self._dm_labels[label] = len(self._dm_labels) - self._num_of_labels = len(self._dm_labels) + if label not in dm_labels_dict: + dm_labels_dict[label] = len(dm_labels_dict) + return dm_labels_dict def run_ensemble(self) -> None: """ @@ -140,14 +141,12 @@ def run_ensemble(self) -> None: while self._model_queue: model_name = self._model_queue.popleft() - rank_zero_info(f"Processing model: {model_name}") - model, model_props = self._load_model_and_its_props(model_name) + print(f"Processing model: {model_name}") - rank_zero_info("\t Passing model to controller to generate predictions...") - pred_conf_dict = self._controller(model, model_props) - del model # Model can be huge to keep it in memory, delete as no longer needed + print("\t Passing model to controller to generate predictions...") + pred_conf_dict, model_props = self._controller(model_name) - rank_zero_info("\t Passing predictions to consolidator for aggregation...") + print("\t Passing predictions to consolidator for aggregation...") self._consolidator( pred_conf_dict, model_props, @@ -155,7 +154,7 @@ def run_ensemble(self) -> None: false_scores=false_scores, ) - rank_zero_info(f"Consolidating predictions for {self.__class__.__name__}") + print(f"Consolidating predictions for {self.__class__.__name__}") final_preds = self._consolidate_on_finish( true_scores=true_scores, false_scores=false_scores ) From bf3cf640fe921fb6a581c44adf9e5692e41adeb5 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 1 Jun 2025 13:21:07 +0200 Subject: [PATCH 51/78] Update .gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index db40440e..da3bbbcb 100644 --- a/.gitignore +++ b/.gitignore @@ -169,3 +169,4 @@ cython_debug/ electra_pretrained.ckpt /lightning_logs .isort.cfg +/.vscode From 76d8a79c8e62b779bbf9af088f432a5847d7f574 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 1 Jun 2025 15:30:56 +0200 Subject: [PATCH 52/78] predict method implementation for data file and list of smiles --- chebai/ensemble/_base.py | 87 ++++++++++++-------- chebai/ensemble/_controller.py | 13 +-- chebai/ensemble/_wrappers/_base.py | 44 ++++------ chebai/ensemble/_wrappers/_neural_network.py | 29 ++++--- 4 files changed, 93 insertions(+), 80 deletions(-) diff --git a/chebai/ensemble/_base.py b/chebai/ensemble/_base.py index 7f960dc0..5841c937 100644 --- a/chebai/ensemble/_base.py +++ b/chebai/ensemble/_base.py @@ -1,26 +1,15 @@ -import importlib -import json -import os from abc import ABC, abstractmethod from collections import deque -from typing import Any, Deque, Dict, Optional, Tuple +from pathlib import Path +from typing import Any, Deque, Dict, Optional +import pandas as pd import torch from lightning import LightningModule -from chebai.models import ChebaiBaseNet -from chebai.preprocessing.structures import XYData from chebai.result.classification import print_metrics -from ._constants import ( - MODEL_CKPT_PATH, - MODEL_CLS_PATH, - MODEL_LBL_PATH, - READER_CLS_PATH, - WRAPPER_CLS_PATH, -) -from ._utils import _load_class -from ._wrappers import BaseWrapper +from ._constants import MODEL_CLS_PATH, MODEL_LBL_PATH, WRAPPER_CLS_PATH class EnsembleBase(ABC): @@ -33,7 +22,8 @@ class EnsembleBase(ABC): def __init__( self, model_configs: Dict[str, Dict[str, Any]], - data_processed_dir_main: str, + data_file_path: str, + classes_file_path: str, **kwargs: Any, ) -> None: """ @@ -41,7 +31,7 @@ def __init__( Args: model_configs (Dict[str, Dict[str, Any]]): Dictionary of model configurations. - data_processed_dir_main (str): Path to the processed data directory. + data_file_path (str): Path to the processed data directory. reader_dir_name (str): Name of the directory used by the reader. Defaults to 'smiles_token'. **kwargs (Any): Additional arguments, such as 'input_dim' and '_validate_configs'. """ @@ -49,22 +39,28 @@ def __init__( self._validate_model_configs(model_configs) self._model_configs: Dict[str, Dict[str, Any]] = model_configs - self._data_processed_dir_main: str = data_processed_dir_main + self._data_file_path: str = data_file_path + self._classes_file_path: str = classes_file_path self._input_dim: Optional[int] = kwargs.get("input_dim", None) - self._total_data_size: int = len(self._collated_data) + self._total_data_size: int = None + self._ensemble_input: list[str] | Path = self._process_input_to_ensemble( + data_file_path + ) + print(f"Total data size (data.pkl) is {self._total_data_size}") self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self._models: Dict[str, LightningModule] = {} - self._dm_labels: Dict[str, int] = self._load_data_module_labels() + self._dm_labels: Dict[str, int] = self._load_data_module_labels( + classes_file_path + ) self._num_of_labels: int = len(self._dm_labels) + print(f"Number of labes for this data is {self._num_of_labels} ") self._num_models_per_label: torch.Tensor = torch.zeros( 1, self._num_of_labels, device=self._device ) self._model_queue: Deque[str] = deque() - self._collated_data: Optional[XYData] = None - self._total_data_size: Optional[int] = None @classmethod def _validate_model_configs(cls, model_configs: Dict[str, Dict[str, Any]]) -> None: @@ -107,21 +103,43 @@ def _validate_model_configs(cls, model_configs: Dict[str, Dict[str, Any]]) -> No class_set.add(model_class_path) labels_set.add(model_labels_path) - def _load_data_module_labels(self) -> dict[str, int]: + def _process_input_to_ensemble(self, path: str): + p = Path(path) + if p.is_file(): + smiles_list = [] + with open(p, "r") as f: + for line in f: + # Skip empty or whitespace-only lines + if line.strip(): + # Split on whitespace and take the first item as the SMILES + smiles = line.strip().split()[0] + smiles_list.append(smiles) + self._total_data_size = len(smiles_list) + return smiles_list + elif p.is_dir(): + data_pkl_path = p / "data.pkl" + if not data_pkl_path.exists(): + raise FileNotFoundError() + self._total_data_size = len(pd.read_pickle(data_pkl_path)) + return p + else: + raise "Invalid path" + + @staticmethod + def _load_data_module_labels(classes_file_path: str) -> dict[str, int]: """ Loads class labels from the classes.txt file and sets internal label mapping. Raises: FileNotFoundError: If the expected classes.txt file is not found. """ - classes_txt_file = os.path.join(self._data_processed_dir_main, "classes.txt") - print(f"Loading {classes_txt_file} ....") - - if not os.path.exists(classes_txt_file): - raise FileNotFoundError(f"{classes_txt_file} does not exist") + classes_file_path = Path(classes_file_path) + if not classes_file_path.exists(): + raise FileNotFoundError(f"{classes_file_path} does not exist") + print(f"Loading {classes_file_path} ....") dm_labels_dict = {} - with open(classes_txt_file, "r") as f: + with open(classes_file_path, "r") as f: for line in f: label = line.strip() if label not in dm_labels_dict: @@ -132,6 +150,7 @@ def run_ensemble(self) -> None: """ Executes the full ensemble prediction pipeline, aggregating predictions and printing metrics. """ + assert self._total_data_size is not None and self._num_of_labels is not None true_scores = torch.zeros( self._total_data_size, self._num_of_labels, device=self._device ) @@ -144,12 +163,12 @@ def run_ensemble(self) -> None: print(f"Processing model: {model_name}") print("\t Passing model to controller to generate predictions...") - pred_conf_dict, model_props = self._controller(model_name) + controller_output = self._controller(model_name, self._ensemble_input) print("\t Passing predictions to consolidator for aggregation...") self._consolidator( - pred_conf_dict, - model_props, + pred_conf_dict=controller_output["pred_conf_dict"], + model_props=controller_output["model_props"], true_scores=true_scores, false_scores=false_scores, ) @@ -168,8 +187,8 @@ def run_ensemble(self) -> None: @abstractmethod def _controller( self, - model: LightningModule, - model_props: Dict[str, torch.Tensor], + model_name: str, + model_input: list[str] | Path, **kwargs: Any, ) -> Dict[str, torch.Tensor]: """ diff --git a/chebai/ensemble/_controller.py b/chebai/ensemble/_controller.py index 8ae77cee..adbd1b20 100644 --- a/chebai/ensemble/_controller.py +++ b/chebai/ensemble/_controller.py @@ -1,14 +1,11 @@ -import os.path from abc import ABC from collections import deque from typing import Any, Deque, Dict import torch -from lightning import LightningModule from torch import Tensor from chebai.models import ChebaiBaseNet -from chebai.preprocessing.collate import RaggedCollator from ._base import EnsembleBase from ._constants import WRAPPER_CLS_PATH @@ -72,7 +69,6 @@ def _wrap_model(self, model_name: str) -> BaseWrapper: **self._kwargs ) assert isinstance(wrapped_model, BaseWrapper), "" - # del wrapped_model # Model can be huge to keep it in memory, delete as no longer needed return wrapped_model @@ -93,7 +89,7 @@ def __init__(self, **kwargs: Any): super().__init__(**kwargs) self._model_queue: Deque[str] = deque(list(self._model_configs.keys())) - def _controller(self, model_name, **kwargs: Any) -> Dict[str, Tensor]: + def _controller(self, model_name, model_input, **kwargs: Any) -> Dict[str, Tensor]: """ Performs inference with the model and extracts predictions and confidence values. @@ -105,4 +101,9 @@ def _controller(self, model_name, **kwargs: Any) -> Dict[str, Tensor]: Dict[str, Tensor]: Dictionary containing predictions and confidence scores. """ wrapped_model = self._wrap_model(model_name) - return self._get_pred_conf_from_model_output(model_output, model_props["mask"]) + model_output, model_props = wrapped_model.predict(model_input) + del wrapped_model # Model can be huge to keep it in memory, delete asap as no longer needed + pred_conf_dict = self._get_pred_conf_from_model_output( + model_output, model_props["mask"] + ) + return {"pred_conf_dict": pred_conf_dict, "model_props": model_props} diff --git a/chebai/ensemble/_wrappers/_base.py b/chebai/ensemble/_wrappers/_base.py index 4410ed1c..c7ad1726 100644 --- a/chebai/ensemble/_wrappers/_base.py +++ b/chebai/ensemble/_wrappers/_base.py @@ -1,8 +1,7 @@ -import importlib import json import os from abc import ABC, abstractmethod -from typing import overload +from pathlib import Path import torch @@ -22,10 +21,9 @@ def __init__( self._model_name = model_name self._model_class_path = self._model_config[MODEL_CLS_PATH] self._model_labels_path = self._model_config[MODEL_LBL_PATH] - self._dm_labels: dict[str, int] = dm_labels - self._model_props = self._generate_model_label_props() + self._model_props = self._generate_model_label_props(dm_labels=dm_labels) - def _generate_model_label_props(self) -> dict[str, torch.Tensor]: + def _generate_model_label_props(self, dm_labels) -> dict[str, torch.Tensor]: """ Generates label mask and confidence tensors (TPV, FPV) for a model. @@ -38,13 +36,15 @@ def _generate_model_label_props(self) -> dict[str, torch.Tensor]: model_label_indices, tpv_label_values, fpv_label_values = [], [], [] for label, props in labels_dict.items(): - if label in self._dm_labels: + if label in dm_labels: try: self._validate_model_labels_json_element(labels_dict[label]) except Exception as e: - raise Exception(f"Label '{label}' has an unexpected error") from e + raise Exception( + f"Label '{label}' has an unexpected error \n Error: {e}" + ) - model_label_indices.append(self._dm_labels[label]) + model_label_indices.append(dm_labels[label]) tpv_label_values.append(props["TPV"]) fpv_label_values.append(props["FPV"]) @@ -54,7 +54,7 @@ def _generate_model_label_props(self) -> dict[str, torch.Tensor]: ) # Create masks to apply predictions only to known classes - mask = torch.zeros(len(self._dm_labels), dtype=torch.bool, device=self._device) + mask = torch.zeros(len(dm_labels), dtype=torch.bool, device=self._device) mask[torch.tensor(model_label_indices, device=self._device)] = True tpv_tensor = torch.full_like(mask, -1, dtype=torch.float, device=self._device) @@ -113,26 +113,14 @@ def _validate_model_labels_json_element(label_dict: dict[str, float]) -> None: def name(self): return f"Wrapper({self.__class__.__name__}) for model: {self._model_name}" - @overload - def predict(self, smiles_list: list) -> tuple[dict, dict]: - pass - - @overload - def predict(self, data_file_path: str) -> tuple[dict, dict]: - pass - - def predict(self, x: list | str) -> tuple[dict, dict]: - if isinstance(x, list): - return self._predict_from_list_of_smiles(x), self._model_props - elif isinstance(x, str): - return self._predict_from_data_file(x), self._model_props - else: - raise TypeError(f"Type {type(x)} is not supported.") + def predict(self, x: list) -> tuple[dict, dict]: + return self._predict_from_list_of_smiles(x), self._model_props @abstractmethod - def _predict_from_list_of_smiles(self, smiles_list: list) -> dict: - pass + def _predict_from_list_of_smiles(self, smiles_list: list) -> dict: ... + + def evaluate(self, data_processed_dir_main: Path) -> tuple[dict, dict]: + return self._evaluate_from_data_file(data_processed_dir_main), self._model_props @abstractmethod - def _predict_from_data_file(self, data_file_path: str) -> dict: - pass + def _evaluate_from_data_file(self, data_file_path: str) -> dict: ... diff --git a/chebai/ensemble/_wrappers/_neural_network.py b/chebai/ensemble/_wrappers/_neural_network.py index 2549c6b7..29c0b64b 100644 --- a/chebai/ensemble/_wrappers/_neural_network.py +++ b/chebai/ensemble/_wrappers/_neural_network.py @@ -1,5 +1,5 @@ import os -from typing import Optional, Type +from typing import Type import torch from rdkit import Chem @@ -15,7 +15,9 @@ class NNWrapper(BaseWrapper): def __init__(self, **kwargs): - self._validate_model_configs(**kwargs) + self._validate_model_configs( + model_config=kwargs["model_config"], model_name=kwargs["model_name"] + ) super().__init__(**kwargs) self._model_ckpt_path = self._model_config[MODEL_CKPT_PATH] @@ -30,11 +32,15 @@ def __init__(self, **kwargs): assert issubclass(reader_cls, DataReader), "" self._reader = reader_cls(**self._reader_kwargs) self._collator = reader_cls.COLLATOR() - self._model: ChebaiBaseNet = self._load_model_() + self._model: ChebaiBaseNet = self._load_model_( + input_dim=kwargs.get("input_dim", None) + ) @classmethod def _validate_model_configs( - cls, model_config: dict[str, str], model_name: str + cls, + model_config: dict[str, str], + model_name: str, ) -> None: """ Validates model configuration dictionary for required keys and uniqueness. @@ -57,12 +63,12 @@ def _validate_model_configs( f"Missing keys {missing_keys} in model '{model_name}' configuration." ) - def _load_model_(self) -> ChebaiBaseNet: + def _load_model_(self, input_dim: int | None) -> ChebaiBaseNet: """ Loads a model checkpoint and its label-related properties. Args: - model_name (str): Name of the model to load. + input_dim (int): Name of the model to load. Returns: Tuple[LightningModule, Dict[str, torch.Tensor]]: The model and its label properties. @@ -73,22 +79,21 @@ def _load_model_(self) -> ChebaiBaseNet: f"Model path '{self._model_ckpt_path}' for '{self._model_name}' does not exist." ) - lightning_cls = self._load_class(self._model_class_path) + lightning_cls = _load_class(self._model_class_path) assert isinstance(lightning_cls, type), f"{lightning_cls} is not a class." assert issubclass( lightning_cls, ChebaiBaseNet ), f"{lightning_cls} must inherit from ChebaiBaseNet" - try: model = lightning_cls.load_from_checkpoint( - self._model_ckpt_path, input_dim=self.input_dim + self._model_ckpt_path, input_dim=5 ) - model.eval() - model.freeze() except Exception as e: - raise RuntimeError(f"Error loading model {self._model_name}") from e + raise RuntimeError(f"Error loading model {self._model_name} \n Error: {e}") + model.eval() + model.freeze() return model def _predict_from_list_of_smiles(self, smiles_list) -> list: From 95d49c1f65e735af1fc0a6b6d69a022d31f91f73 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 1 Jun 2025 16:29:40 +0200 Subject: [PATCH 53/78] seperate method for evaluate and prediction --- chebai/ensemble/_base.py | 73 ++++++++++++-------- chebai/ensemble/_constants.py | 4 ++ chebai/ensemble/_controller.py | 41 ++++++++--- chebai/ensemble/_wrappers/_base.py | 13 +++- chebai/ensemble/_wrappers/_neural_network.py | 18 +++-- 5 files changed, 103 insertions(+), 46 deletions(-) diff --git a/chebai/ensemble/_base.py b/chebai/ensemble/_base.py index 5841c937..b1b94a61 100644 --- a/chebai/ensemble/_base.py +++ b/chebai/ensemble/_base.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from collections import deque from pathlib import Path -from typing import Any, Deque, Dict, Optional +from typing import Any, Deque, Dict, Literal, Optional import pandas as pd import torch @@ -9,7 +9,13 @@ from chebai.result.classification import print_metrics -from ._constants import MODEL_CLS_PATH, MODEL_LBL_PATH, WRAPPER_CLS_PATH +from ._constants import ( + EVAL_OP, + MODEL_CLS_PATH, + MODEL_LBL_PATH, + PRED_OP, + WRAPPER_CLS_PATH, +) class EnsembleBase(ABC): @@ -22,8 +28,8 @@ class EnsembleBase(ABC): def __init__( self, model_configs: Dict[str, Dict[str, Any]], - data_file_path: str, - classes_file_path: str, + data_processed_dir_main: str, + operation: str = EVAL_OP, **kwargs: Any, ) -> None: """ @@ -31,29 +37,31 @@ def __init__( Args: model_configs (Dict[str, Dict[str, Any]]): Dictionary of model configurations. - data_file_path (str): Path to the processed data directory. + data_processed_dir_main (str): Path to the processed data directory. reader_dir_name (str): Name of the directory used by the reader. Defaults to 'smiles_token'. **kwargs (Any): Additional arguments, such as 'input_dim' and '_validate_configs'. """ - if bool(kwargs.get("_validate_configs", True)): - self._validate_model_configs(model_configs) + if bool(kwargs.get("_perform_validation_checks", True)): + self._perform_validation_checks( + model_configs, operation=operation, **kwargs + ) self._model_configs: Dict[str, Dict[str, Any]] = model_configs - self._data_file_path: str = data_file_path - self._classes_file_path: str = classes_file_path + self._data_processed_dir_main: str = data_processed_dir_main + self._operation: str = operation + print(f"Ensemble operation: {self._operation}") + self._input_dim: Optional[int] = kwargs.get("input_dim", None) self._total_data_size: int = None self._ensemble_input: list[str] | Path = self._process_input_to_ensemble( - data_file_path + **kwargs ) print(f"Total data size (data.pkl) is {self._total_data_size}") self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self._models: Dict[str, LightningModule] = {} - self._dm_labels: Dict[str, int] = self._load_data_module_labels( - classes_file_path - ) + self._dm_labels: Dict[str, int] = self._load_data_module_labels() self._num_of_labels: int = len(self._dm_labels) print(f"Number of labes for this data is {self._num_of_labels} ") @@ -63,7 +71,9 @@ def __init__( self._model_queue: Deque[str] = deque() @classmethod - def _validate_model_configs(cls, model_configs: Dict[str, Dict[str, Any]]) -> None: + def _perform_validation_checks( + cls, model_configs: Dict[str, Dict[str, Any]], operation, **kwargs + ) -> None: """ Validates model configuration dictionary for required keys and uniqueness. @@ -74,6 +84,19 @@ def _validate_model_configs(cls, model_configs: Dict[str, Dict[str, Any]]) -> No AttributeError: If any model config is missing required keys. ValueError: If duplicate paths are found for model checkpoint, class, or labels. """ + if operation not in ["evaluate", "predict"]: + raise ValueError( + f"Invalid operation '{operation}'. Must be 'evaluate' or 'predict'." + ) + + if operation == "predict" and not kwargs.get("smiles_list_file_path", None): + raise ValueError( + "For 'predict' operation, 'smiles_list_file_path' must be provided." + ) + + if not Path(kwargs.get("smiles_list_file_path")).exists(): + raise FileNotFoundError(f"{kwargs.get('smiles_list_file_path')}") + class_set, labels_set = set(), set() required_keys = { MODEL_CLS_PATH, @@ -103,9 +126,9 @@ def _validate_model_configs(cls, model_configs: Dict[str, Dict[str, Any]]) -> No class_set.add(model_class_path) labels_set.add(model_labels_path) - def _process_input_to_ensemble(self, path: str): - p = Path(path) - if p.is_file(): + def _process_input_to_ensemble(self, **kwargs: any) -> list[str] | Path: + if self._operation == PRED_OP: + p = Path(kwargs["smiles_list_file_path"]) smiles_list = [] with open(p, "r") as f: for line in f: @@ -116,24 +139,23 @@ def _process_input_to_ensemble(self, path: str): smiles_list.append(smiles) self._total_data_size = len(smiles_list) return smiles_list - elif p.is_dir(): - data_pkl_path = p / "data.pkl" + elif self._operation == EVAL_OP: + data_pkl_path = Path(self._data_processed_dir_main) / "data.pkl" if not data_pkl_path.exists(): raise FileNotFoundError() self._total_data_size = len(pd.read_pickle(data_pkl_path)) return p else: - raise "Invalid path" + raise ValueError("Invalid operation") - @staticmethod - def _load_data_module_labels(classes_file_path: str) -> dict[str, int]: + def _load_data_module_labels(self) -> dict[str, int]: """ Loads class labels from the classes.txt file and sets internal label mapping. Raises: FileNotFoundError: If the expected classes.txt file is not found. """ - classes_file_path = Path(classes_file_path) + classes_file_path = Path(self._data_processed_dir_main) / "classes.txt" if not classes_file_path.exists(): raise FileNotFoundError(f"{classes_file_path} does not exist") print(f"Loading {classes_file_path} ....") @@ -197,14 +219,13 @@ def _controller( Returns: Dict[str, torch.Tensor]: Predictions or confidence scores. """ - pass @abstractmethod def _consolidator( self, + *, pred_conf_dict: Dict[str, torch.Tensor], model_props: Dict[str, torch.Tensor], - *, true_scores: torch.Tensor, false_scores: torch.Tensor, **kwargs: Any, @@ -214,7 +235,6 @@ def _consolidator( Should update the provided `true_scores` and `false_scores`. """ - pass @abstractmethod def _consolidate_on_finish( @@ -226,4 +246,3 @@ def _consolidate_on_finish( Returns: torch.Tensor: Final aggregated predictions. """ - pass diff --git a/chebai/ensemble/_constants.py b/chebai/ensemble/_constants.py index effc017a..3d3b5db7 100644 --- a/chebai/ensemble/_constants.py +++ b/chebai/ensemble/_constants.py @@ -6,3 +6,7 @@ READER_CLS_PATH = "reader_class_path" READER_KWARGS = "reader_kwargs" + + +PRED_OP = "prediction" +EVAL_OP = "evaluation" diff --git a/chebai/ensemble/_controller.py b/chebai/ensemble/_controller.py index adbd1b20..beaf6e15 100644 --- a/chebai/ensemble/_controller.py +++ b/chebai/ensemble/_controller.py @@ -1,14 +1,12 @@ -from abc import ABC +from abc import ABC, abstractmethod from collections import deque from typing import Any, Deque, Dict import torch from torch import Tensor -from chebai.models import ChebaiBaseNet - from ._base import EnsembleBase -from ._constants import WRAPPER_CLS_PATH +from ._constants import EVAL_OP, PRED_OP, WRAPPER_CLS_PATH from ._utils import _load_class from ._wrappers import BaseWrapper @@ -33,6 +31,30 @@ def __init__(self, **kwargs: Any): super().__init__(**kwargs) self._kwargs = kwargs + @abstractmethod + def _controller(self, model_name, model_input, **kwargs: Any) -> Dict[str, Tensor]: + """ + Performs inference with the model and extracts predictions and confidence values. + + Args: + model (ChebaiBaseNet): The model to perform inference with. + model_props (Dict[str, Tensor]): Dictionary with label mask and trust scores. + + Returns: + Dict[str, Tensor]: Dictionary containing predictions and confidence scores. + """ + wrapped_model = self._wrap_model(model_name) + if self._operation == PRED_OP: + model_output, model_props = wrapped_model.predict(model_input) + else: + model_output, model_props = wrapped_model.evaluate(model_input) + del wrapped_model # Model can be huge to keep it in memory, delete asap as no longer needed + + pred_conf_dict = self._get_pred_conf_from_model_output( + model_output, model_props["mask"] + ) + return {"pred_conf_dict": pred_conf_dict, "model_props": model_props} + def _get_pred_conf_from_model_output( self, model_output: Dict[str, Tensor], model_label_mask: Tensor ) -> Dict[str, Tensor]: @@ -100,10 +122,7 @@ def _controller(self, model_name, model_input, **kwargs: Any) -> Dict[str, Tenso Returns: Dict[str, Tensor]: Dictionary containing predictions and confidence scores. """ - wrapped_model = self._wrap_model(model_name) - model_output, model_props = wrapped_model.predict(model_input) - del wrapped_model # Model can be huge to keep it in memory, delete asap as no longer needed - pred_conf_dict = self._get_pred_conf_from_model_output( - model_output, model_props["mask"] - ) - return {"pred_conf_dict": pred_conf_dict, "model_props": model_props} + + output_dict = super()._controller(model_name, model_input, **kwargs) + # Some activation condition can be applied, not in this controller, so we return the output directly + return output_dict diff --git a/chebai/ensemble/_wrappers/_base.py b/chebai/ensemble/_wrappers/_base.py index c7ad1726..b3254ba7 100644 --- a/chebai/ensemble/_wrappers/_base.py +++ b/chebai/ensemble/_wrappers/_base.py @@ -114,13 +114,22 @@ def name(self): return f"Wrapper({self.__class__.__name__}) for model: {self._model_name}" def predict(self, x: list) -> tuple[dict, dict]: + if not isinstance(x, list): + raise TypeError(f"Input must be a list of SMILES strings, got {type(x)}") return self._predict_from_list_of_smiles(x), self._model_props @abstractmethod def _predict_from_list_of_smiles(self, smiles_list: list) -> dict: ... - def evaluate(self, data_processed_dir_main: Path) -> tuple[dict, dict]: - return self._evaluate_from_data_file(data_processed_dir_main), self._model_props + def evaluate( + self, data_processed_dir_main: Path, **kwargs: any + ) -> tuple[dict, dict]: + if not data_processed_dir_main.is_dir(): + raise NotADirectoryError(f"{data_processed_dir_main} is not a directory.") + return ( + self._evaluate_from_data_file(data_processed_dir_main, **kwargs), + self._model_props, + ) @abstractmethod def _evaluate_from_data_file(self, data_file_path: str) -> dict: ... diff --git a/chebai/ensemble/_wrappers/_neural_network.py b/chebai/ensemble/_wrappers/_neural_network.py index 29c0b64b..ef183a92 100644 --- a/chebai/ensemble/_wrappers/_neural_network.py +++ b/chebai/ensemble/_wrappers/_neural_network.py @@ -1,4 +1,5 @@ import os +from pathlib import Path from typing import Type import torch @@ -90,13 +91,18 @@ def _load_model_(self, input_dim: int | None) -> ChebaiBaseNet: self._model_ckpt_path, input_dim=5 ) except Exception as e: - raise RuntimeError(f"Error loading model {self._model_name} \n Error: {e}") + raise RuntimeError( + f"Error loading model {self._model_name} \n Error: {e}" + ) from e + assert isinstance( + model, ChebaiBaseNet + ), f"{model} is not a ChebaiBaseNet instance." model.eval() model.freeze() return model - def _predict_from_list_of_smiles(self, smiles_list) -> list: + def _predict_from_list_of_smiles(self, smiles_list: list[str]) -> list: token_dicts = [] could_not_parse = [] index_map = dict() @@ -131,16 +137,16 @@ def _read_smiles(self, smiles): return self._reader.to_data(dict(features=smiles, labels=None)) def _forward_pass(self, batch): - processable_data = self._model._process_batch( + processable_data = self._model._process_batch( # noqa self._collator(batch).to(self._device), 0 ) return self._model(processable_data, **processable_data["model_kwargs"]) - def _predict_from_data_file( - self, processed_dir_main: str, data_file_name="data.pt" + def _evaluate_from_data_file( + self, data_processed_dir_main: Path, data_file_name="data.pt" ) -> list: data = torch.load( - os.path.join(processed_dir_main, self._reader.name(), data_file_name), + data_processed_dir_main / self._reader.name() / data_file_name, weights_only=False, map_location=self._device, ) From a20ce7626efd3a3df0eb9fd29de83574e9d6dc0f Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 1 Jun 2025 20:45:14 +0200 Subject: [PATCH 54/78] store collated label or any model in instance var --- chebai/ensemble/_base.py | 40 +++++++++++--------- chebai/ensemble/_controller.py | 37 ++++++++---------- chebai/ensemble/_utils.py | 2 +- chebai/ensemble/_wrappers/_base.py | 1 + chebai/ensemble/_wrappers/_neural_network.py | 14 ++++--- 5 files changed, 49 insertions(+), 45 deletions(-) diff --git a/chebai/ensemble/_base.py b/chebai/ensemble/_base.py index b1b94a61..f6764556 100644 --- a/chebai/ensemble/_base.py +++ b/chebai/ensemble/_base.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from collections import deque from pathlib import Path -from typing import Any, Deque, Dict, Literal, Optional +from typing import Any, Deque, Dict, Optional import pandas as pd import torch @@ -38,7 +38,6 @@ def __init__( Args: model_configs (Dict[str, Dict[str, Any]]): Dictionary of model configurations. data_processed_dir_main (str): Path to the processed data directory. - reader_dir_name (str): Name of the directory used by the reader. Defaults to 'smiles_token'. **kwargs (Any): Additional arguments, such as 'input_dim' and '_validate_configs'. """ if bool(kwargs.get("_perform_validation_checks", True)): @@ -51,8 +50,8 @@ def __init__( self._operation: str = operation print(f"Ensemble operation: {self._operation}") - self._input_dim: Optional[int] = kwargs.get("input_dim", None) - self._total_data_size: int = None + # These instance variable will be set in method `_process_input_to_ensemble` + self._total_data_size: int | None = None self._ensemble_input: list[str] | Path = self._process_input_to_ensemble( **kwargs ) @@ -60,7 +59,6 @@ def __init__( self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - self._models: Dict[str, LightningModule] = {} self._dm_labels: Dict[str, int] = self._load_data_module_labels() self._num_of_labels: int = len(self._dm_labels) print(f"Number of labes for this data is {self._num_of_labels} ") @@ -69,6 +67,7 @@ def __init__( 1, self._num_of_labels, device=self._device ) self._model_queue: Deque[str] = deque() + self._collated_labels: torch.Tensor | None = None @classmethod def _perform_validation_checks( @@ -126,10 +125,10 @@ def _perform_validation_checks( class_set.add(model_class_path) labels_set.add(model_labels_path) - def _process_input_to_ensemble(self, **kwargs: any) -> list[str] | Path: + def _process_input_to_ensemble(self, **kwargs: Any) -> list[str] | Path: if self._operation == PRED_OP: p = Path(kwargs["smiles_list_file_path"]) - smiles_list = [] + smiles_list: list[str] = [] with open(p, "r") as f: for line in f: # Skip empty or whitespace-only lines @@ -140,11 +139,14 @@ def _process_input_to_ensemble(self, **kwargs: any) -> list[str] | Path: self._total_data_size = len(smiles_list) return smiles_list elif self._operation == EVAL_OP: - data_pkl_path = Path(self._data_processed_dir_main) / "data.pkl" + processed_dir_path = Path(self._data_processed_dir_main) + data_pkl_path = processed_dir_path / "data.pkl" if not data_pkl_path.exists(): - raise FileNotFoundError() + raise FileNotFoundError( + f"data.pkl does not exist in the {processed_dir_path} directory" + ) self._total_data_size = len(pd.read_pickle(data_pkl_path)) - return p + return processed_dir_path else: raise ValueError("Invalid operation") @@ -180,6 +182,9 @@ def run_ensemble(self) -> None: self._total_data_size, self._num_of_labels, device=self._device ) + print( + f"Running {self.__class__.__name__} ensemble for {self._operation} operation..." + ) while self._model_queue: model_name = self._model_queue.popleft() print(f"Processing model: {model_name}") @@ -195,16 +200,17 @@ def run_ensemble(self) -> None: false_scores=false_scores, ) - print(f"Consolidating predictions for {self.__class__.__name__}") final_preds = self._consolidate_on_finish( true_scores=true_scores, false_scores=false_scores ) - print_metrics( - final_preds, - self._collated_data.y, - self._device, - classes=list(self._dm_labels.keys()), - ) + + if self._operation == EVAL_OP: + print_metrics( + final_preds, + self._collated_labels, + self._device, + classes=list(self._dm_labels.keys()), + ) @abstractmethod def _controller( diff --git a/chebai/ensemble/_controller.py b/chebai/ensemble/_controller.py index beaf6e15..d6545ede 100644 --- a/chebai/ensemble/_controller.py +++ b/chebai/ensemble/_controller.py @@ -1,5 +1,6 @@ -from abc import ABC, abstractmethod +from abc import ABC from collections import deque +from pathlib import Path from typing import Any, Deque, Dict import torch @@ -7,7 +8,7 @@ from ._base import EnsembleBase from ._constants import EVAL_OP, PRED_OP, WRAPPER_CLS_PATH -from ._utils import _load_class +from ._utils import load_class from ._wrappers import BaseWrapper @@ -30,9 +31,16 @@ def __init__(self, **kwargs: Any): """ super().__init__(**kwargs) self._kwargs = kwargs + # If an activation condition correponding model is added to queue, removed from this set + # This is in order to avoid re-adding models that have already been processed + self._model_key_set: set[str] = set(self._model_configs.keys()) - @abstractmethod - def _controller(self, model_name, model_input, **kwargs: Any) -> Dict[str, Tensor]: + # Labels from any processed data.pt file for any reader + self._collated_labels: torch.Tensor | None = None + + def _controller( + self, model_name: str, model_input: list[str] | Path, **kwargs: Any + ) -> Dict[str, Tensor]: """ Performs inference with the model and extracts predictions and confidence values. @@ -82,7 +90,7 @@ def _get_pred_conf_from_model_output( def _wrap_model(self, model_name: str) -> BaseWrapper: model_config = self._model_configs[model_name] - wrp_cls = _load_class(model_config[WRAPPER_CLS_PATH]) + wrp_cls = load_class(model_config[WRAPPER_CLS_PATH]) assert issubclass(wrp_cls, BaseWrapper), "" wrapped_model = wrp_cls( model_name=model_name, @@ -90,6 +98,9 @@ def _wrap_model(self, model_name: str) -> BaseWrapper: dm_labels=self._dm_labels, **self._kwargs ) + if self._collated_labels is not None and self._operation == EVAL_OP: + self._collated_labels = wrapped_model.collated_labels + assert isinstance(wrapped_model, BaseWrapper), "" return wrapped_model @@ -110,19 +121,3 @@ def __init__(self, **kwargs: Any): """ super().__init__(**kwargs) self._model_queue: Deque[str] = deque(list(self._model_configs.keys())) - - def _controller(self, model_name, model_input, **kwargs: Any) -> Dict[str, Tensor]: - """ - Performs inference with the model and extracts predictions and confidence values. - - Args: - model (ChebaiBaseNet): The model to perform inference with. - model_props (Dict[str, Tensor]): Dictionary with label mask and trust scores. - - Returns: - Dict[str, Tensor]: Dictionary containing predictions and confidence scores. - """ - - output_dict = super()._controller(model_name, model_input, **kwargs) - # Some activation condition can be applied, not in this controller, so we return the output directly - return output_dict diff --git a/chebai/ensemble/_utils.py b/chebai/ensemble/_utils.py index c9e273bf..1da494f1 100644 --- a/chebai/ensemble/_utils.py +++ b/chebai/ensemble/_utils.py @@ -1,7 +1,7 @@ import importlib -def _load_class(class_path): +def load_class(class_path: str) -> type: module_path, class_name = class_path.rsplit(".", 1) module = importlib.import_module(module_path) return getattr(module, class_name) diff --git a/chebai/ensemble/_wrappers/_base.py b/chebai/ensemble/_wrappers/_base.py index b3254ba7..a1703917 100644 --- a/chebai/ensemble/_wrappers/_base.py +++ b/chebai/ensemble/_wrappers/_base.py @@ -22,6 +22,7 @@ def __init__( self._model_class_path = self._model_config[MODEL_CLS_PATH] self._model_labels_path = self._model_config[MODEL_LBL_PATH] self._model_props = self._generate_model_label_props(dm_labels=dm_labels) + self.collated_labels = None def _generate_model_label_props(self, dm_labels) -> dict[str, torch.Tensor]: """ diff --git a/chebai/ensemble/_wrappers/_neural_network.py b/chebai/ensemble/_wrappers/_neural_network.py index ef183a92..982a410c 100644 --- a/chebai/ensemble/_wrappers/_neural_network.py +++ b/chebai/ensemble/_wrappers/_neural_network.py @@ -7,9 +7,10 @@ from chebai.models import ChebaiBaseNet from chebai.preprocessing.reader import DataReader +from chebai.preprocessing.structures import XYData from .._constants import MODEL_CKPT_PATH, READER_CLS_PATH, READER_KWARGS -from .._utils import _load_class +from .._utils import load_class from ._base import BaseWrapper @@ -29,10 +30,11 @@ def __init__(self, **kwargs): else dict() ) - reader_cls: Type[DataReader] = _load_class(self._reader_class_path) + reader_cls: Type[DataReader] = load_class(self._reader_class_path) assert issubclass(reader_cls, DataReader), "" self._reader = reader_cls(**self._reader_kwargs) self._collator = reader_cls.COLLATOR() + self.collated_labels = None self._model: ChebaiBaseNet = self._load_model_( input_dim=kwargs.get("input_dim", None) ) @@ -80,7 +82,7 @@ def _load_model_(self, input_dim: int | None) -> ChebaiBaseNet: f"Model path '{self._model_ckpt_path}' for '{self._model_name}' does not exist." ) - lightning_cls = _load_class(self._model_class_path) + lightning_cls = load_class(self._model_class_path) assert isinstance(lightning_cls, type), f"{lightning_cls} is not a class." assert issubclass( @@ -137,9 +139,9 @@ def _read_smiles(self, smiles): return self._reader.to_data(dict(features=smiles, labels=None)) def _forward_pass(self, batch): - processable_data = self._model._process_batch( # noqa - self._collator(batch).to(self._device), 0 - ) + collated_batch: XYData = self._collator(batch).to(self._device) + self.collated_labels = collated_batch.y + processable_data = self._model._process_batch(collated_batch, 0) # noqa return self._model(processable_data, **processable_data["model_kwargs"]) def _evaluate_from_data_file( From c0cb6c90d0e30ccba7fa49a62f01f52c0511291f Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 1 Jun 2025 23:21:56 +0200 Subject: [PATCH 55/78] fix collated labels none error --- chebai/ensemble/__init__.py | 4 +++- chebai/ensemble/_base.py | 3 +++ chebai/ensemble/_controller.py | 12 ++++++++---- chebai/ensemble/_scripts/_ensemble_run_script.py | 9 ++++----- 4 files changed, 18 insertions(+), 10 deletions(-) diff --git a/chebai/ensemble/__init__.py b/chebai/ensemble/__init__.py index 67570b10..b68b0fd2 100644 --- a/chebai/ensemble/__init__.py +++ b/chebai/ensemble/__init__.py @@ -1,5 +1,7 @@ +from ._base import EnsembleBase from ._consolidator import WeightedMajorityVoting from ._controller import NoActivationCondition +from ._wrappers import NNWrapper class FullEnsembleWMV(NoActivationCondition, WeightedMajorityVoting): @@ -8,4 +10,4 @@ class FullEnsembleWMV(NoActivationCondition, WeightedMajorityVoting): pass -__all__ = ["FullEnsembleWMV"] +__all__ = ["FullEnsembleWMV", "NNWrapper"] diff --git a/chebai/ensemble/_base.py b/chebai/ensemble/_base.py index f6764556..e9abbcdf 100644 --- a/chebai/ensemble/_base.py +++ b/chebai/ensemble/_base.py @@ -205,6 +205,9 @@ def run_ensemble(self) -> None: ) if self._operation == EVAL_OP: + assert ( + self._collated_labels is not None + ), "Collated labels must be set for evaluation operation." print_metrics( final_preds, self._collated_labels, diff --git a/chebai/ensemble/_controller.py b/chebai/ensemble/_controller.py index d6545ede..b9a65a1f 100644 --- a/chebai/ensemble/_controller.py +++ b/chebai/ensemble/_controller.py @@ -7,7 +7,7 @@ from torch import Tensor from ._base import EnsembleBase -from ._constants import EVAL_OP, PRED_OP, WRAPPER_CLS_PATH +from ._constants import PRED_OP, WRAPPER_CLS_PATH from ._utils import load_class from ._wrappers import BaseWrapper @@ -35,7 +35,7 @@ def __init__(self, **kwargs: Any): # This is in order to avoid re-adding models that have already been processed self._model_key_set: set[str] = set(self._model_configs.keys()) - # Labels from any processed data.pt file for any reader + # Labels from any processed `data.pt` file of any reader self._collated_labels: torch.Tensor | None = None def _controller( @@ -56,6 +56,12 @@ def _controller( model_output, model_props = wrapped_model.predict(model_input) else: model_output, model_props = wrapped_model.evaluate(model_input) + if ( + self._collated_labels is None + and wrapped_model.collated_labels is not None + ): + self._collated_labels = wrapped_model.collated_labels + del wrapped_model # Model can be huge to keep it in memory, delete asap as no longer needed pred_conf_dict = self._get_pred_conf_from_model_output( @@ -98,8 +104,6 @@ def _wrap_model(self, model_name: str) -> BaseWrapper: dm_labels=self._dm_labels, **self._kwargs ) - if self._collated_labels is not None and self._operation == EVAL_OP: - self._collated_labels = wrapped_model.collated_labels assert isinstance(wrapped_model, BaseWrapper), "" return wrapped_model diff --git a/chebai/ensemble/_scripts/_ensemble_run_script.py b/chebai/ensemble/_scripts/_ensemble_run_script.py index 0ff65787..3dc6bd45 100644 --- a/chebai/ensemble/_scripts/_ensemble_run_script.py +++ b/chebai/ensemble/_scripts/_ensemble_run_script.py @@ -1,11 +1,10 @@ -from typing import Any, Dict, Type +from typing import Any, Dict import yaml from jsonargparse import ArgumentParser -from chebai.ensemble._utils import _load_class - -from .._base import EnsembleBase +from chebai.ensemble._base import EnsembleBase +from chebai.ensemble._utils import load_class def load_config_and_instantiate(config_path: str) -> EnsembleBase: @@ -27,7 +26,7 @@ def load_config_and_instantiate(config_path: str) -> EnsembleBase: class_path: str = config["class_path"] init_args: Dict[str, Any] = config.get("init_args", {}) - cls = _load_class(class_path) + cls = load_class(class_path) if not issubclass(cls, EnsembleBase): raise TypeError(f"{cls} must be subclass of EnsembleBase") From 9fc5d20865b4c537d1d1db0a73cb0b44c9f34511 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 2 Jun 2025 14:03:01 +0200 Subject: [PATCH 56/78] script to generate classes props --- .../_scripts/_generate_classes_props_json.py | 195 ++++++++++++++++++ 1 file changed, 195 insertions(+) create mode 100644 chebai/ensemble/_scripts/_generate_classes_props_json.py diff --git a/chebai/ensemble/_scripts/_generate_classes_props_json.py b/chebai/ensemble/_scripts/_generate_classes_props_json.py new file mode 100644 index 00000000..7ab65284 --- /dev/null +++ b/chebai/ensemble/_scripts/_generate_classes_props_json.py @@ -0,0 +1,195 @@ +"""Generate TPV/NPV JSON for multi-class classification models.""" + +import json +from pathlib import Path + +import pandas as pd +import torch +from jsonargparse import CLI +from sklearn.metrics import confusion_matrix +from torch.utils.data import DataLoader + +from chebai.ensemble._utils import load_class +from chebai.models.base import ChebaiBaseNet +from chebai.preprocessing.collate import Collator + + +class ClassesPropertiesGenerator: + """ + Computes TPV (Precision/ True Predictive Value) and NPV (Negative Predictive Value) + for each class in a multi-class classification problem using a PyTorch Lightning model. + """ + + @staticmethod + def load_class_labels(path: str) -> list[str]: + """ + Load a list of class names from a .json or .txt file. + + Args: + path (str): Path to class labels file. + + Returns: + list[str]: List of class names. + """ + with open(path) as f: + return [line.strip() for line in f if line.strip()] + + @staticmethod + def compute_tpv_npv( + y_true: list[int], y_pred: list[int], class_names: list[str] + ) -> dict[str, dict[str, float]]: + """ + Compute TPV and NPV for each class using the confusion matrix. + + Args: + y_true (list[int]): Ground truth labels. + y_pred (list[int]): Predicted labels. + class_names (list[str]): List of class names corresponding to class indices. + + Returns: + dict[str, dict[str, float]]: Dictionary with class names as keys and TPV/NPV as values. + """ + cm = confusion_matrix(y_true, y_pred, labels=list(range(len(class_names)))) + metrics = {} + + for i, cls in enumerate(class_names): + TP = cm[i, i] + FP = cm[:, i].sum() - TP + FN = cm[i, :].sum() - TP + TN = cm.sum() - (TP + FP + FN) + + TPV = TP / (TP + FP) if (TP + FP) > 0 else 0.0 + NPV = TN / (TN + FN) if (TN + FN) > 0 else 0.0 + + metrics[cls] = {"TPV": round(TPV, 4), "NPV": round(NPV, 4)} + + return metrics + + def generate_props( + self, + model_path: str, + model_class_path: str, + splits_path: str, + data_path: str, + classes_file_path: str, + collator_class_path: str, + output_path: str, + batch_size: int = 32, + ) -> None: + """ + Main method to compute TPV/NPV from validation data and save as JSON. + + Args: + model_path (str): Path to the PyTorch Lightning model checkpoint. + model_class_path (str): Full path to the model class to load. + splits_path (str): CSV file with 'id' and 'split' columns. + data_path (str): processed `data.pt` file path. + classes_file_path (str): Path to file containing class names `classes.txt`. + collator_class_path (str): Full path to the collator class. + output_path (str): Output path for the saving JSON file. + batch_size (int): Batch size for inference. + """ + print("Extracting validation data for computation...") + splits_df = pd.read_csv(splits_path) + validation_ids = set(splits_df[splits_df["split"] == "validation"]["id"]) + data_df = pd.DataFrame(torch.load(data_path, weights_only=False)) + val_df = data_df[data_df["ident"].isin(validation_ids)] + + # Load model + print(f"Loading model from {model_path} ...") + model_cls = load_class(model_class_path) + if not issubclass(model_cls, ChebaiBaseNet): + raise TypeError("Loaded model is not a valid LightningModule.") + model = model_cls.load_from_checkpoint(model_path, input_dim=3) + model.freeze() + model.eval() + + # Load collator + collator_cls = load_class(collator_class_path) + if not issubclass(collator_cls, Collator): + raise TypeError(f"{collator_cls} must be subclass of Collator") + collator = collator_cls() + + val_loader = DataLoader( + val_df.to_dict(orient="records"), + collate_fn=collator, + batch_size=batch_size, + shuffle=False, + ) + + print("Running inference on validation data...") + y_true, y_pred = [], [] + for batch_idx, batch in enumerate(val_loader): + data = model._process_batch(batch, batch_idx=batch_idx) + labels = data["labels"] + model_output = model(data, **data.get("model_kwargs", dict())) + sigmoid_logits = torch.sigmoid(model_output["logits"]) + preds = sigmoid_logits > 0.5 + y_pred.extend(preds) + y_true.extend(labels) + + # Compute and save metrics + print("Computing TPV and NPV metrics...") + classes_file_path = Path(classes_file_path) + if output_path is None: + output_path = classes_file_path.parent / "classes.json" + class_names = self.load_class_labels(classes_file_path) + metrics = self.compute_tpv_npv(y_true, y_pred, class_names) + with open(output_path, "w") as f: + json.dump(metrics, f, indent=2) + print(f"Saved TPV/NPV metrics to {output_path}") + + +class Main: + """ + Command-line interface wrapper for the ClassesPropertiesGenerator. + """ + + def generate( + self, + model_path: str, + splits_path: str, + data_path: str, + classes_file_path: str, + model_class_path: str, + collator_class_path: str = "chebai.preprocessing.collate.RaggedCollator", + batch_size: int = 32, + output_path: str = None, # Default path will be the directory of classes_file_path + ) -> None: + """ + Entry point for CLI use. + + Args: + model_path (str): Path to the PyTorch Lightning model checkpoint. + model_class_path (str): Full path to the model class to load. + splits_path (str): CSV file with 'id' and 'split' columns. + data_path (str): processed `data.pt` file path. + classes_file_path (str): Path to file containing class names `classes.txt`. + collator_class_path (str): Full path to the collator class. + output_path (str): Output path for the saving JSON file. + batch_size (int): Batch size for inference. + """ + generator = ClassesPropertiesGenerator() + generator.generate_props( + model_path=model_path, + model_class_path=model_class_path, + splits_path=splits_path, + data_path=data_path, + classes_file_path=classes_file_path, + collator_class_path=collator_class_path, + output_path=output_path, + batch_size=batch_size, + ) + + +if __name__ == "__main__": + # _generate_classes_props_json.py generate \ + # --model_path "model/ckpt/path" \ + # --splits_path "splits/file/path" \ + # --data_path "data.pt/file/path" \ + # --classes_file_path "classes/file/path" \ + # --model_class_path "model.class.path" \ + # --collator_class_path "collator.class.path" \ + # --batch_size 32 \ # Optional, default is 32 + # --output_path "output/file/path" # Optional, default will be the directory of classes_file_path + CLI(Main, as_positional=False) From 93e9b7398060821052370ed37519d48a249783f3 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 2 Jun 2025 16:28:13 +0200 Subject: [PATCH 57/78] save prediction to csv for predict operation mode --- chebai/ensemble/__init__.py | 1 - chebai/ensemble/_base.py | 44 +++++++++++++++----- chebai/ensemble/_controller.py | 2 +- chebai/ensemble/_wrappers/_base.py | 2 +- chebai/ensemble/_wrappers/_neural_network.py | 1 - 5 files changed, 36 insertions(+), 14 deletions(-) diff --git a/chebai/ensemble/__init__.py b/chebai/ensemble/__init__.py index b68b0fd2..e6f227de 100644 --- a/chebai/ensemble/__init__.py +++ b/chebai/ensemble/__init__.py @@ -1,4 +1,3 @@ -from ._base import EnsembleBase from ._consolidator import WeightedMajorityVoting from ._controller import NoActivationCondition from ._wrappers import NNWrapper diff --git a/chebai/ensemble/_base.py b/chebai/ensemble/_base.py index e9abbcdf..cc1b7258 100644 --- a/chebai/ensemble/_base.py +++ b/chebai/ensemble/_base.py @@ -1,11 +1,10 @@ from abc import ABC, abstractmethod from collections import deque from pathlib import Path -from typing import Any, Deque, Dict, Optional +from typing import Any, Deque, Dict import pandas as pd import torch -from lightning import LightningModule from chebai.result.classification import print_metrics @@ -29,7 +28,7 @@ def __init__( self, model_configs: Dict[str, Dict[str, Any]], data_processed_dir_main: str, - operation: str = EVAL_OP, + operation_mode: str = EVAL_OP, **kwargs: Any, ) -> None: """ @@ -42,13 +41,13 @@ def __init__( """ if bool(kwargs.get("_perform_validation_checks", True)): self._perform_validation_checks( - model_configs, operation=operation, **kwargs + model_configs, operation=operation_mode, **kwargs ) self._model_configs: Dict[str, Dict[str, Any]] = model_configs self._data_processed_dir_main: str = data_processed_dir_main - self._operation: str = operation - print(f"Ensemble operation: {self._operation}") + self._operation_mode: str = operation_mode + print(f"Ensemble operation: {self._operation_mode}") # These instance variable will be set in method `_process_input_to_ensemble` self._total_data_size: int | None = None @@ -126,7 +125,7 @@ def _perform_validation_checks( labels_set.add(model_labels_path) def _process_input_to_ensemble(self, **kwargs: Any) -> list[str] | Path: - if self._operation == PRED_OP: + if self._operation_mode == PRED_OP: p = Path(kwargs["smiles_list_file_path"]) smiles_list: list[str] = [] with open(p, "r") as f: @@ -138,7 +137,7 @@ def _process_input_to_ensemble(self, **kwargs: Any) -> list[str] | Path: smiles_list.append(smiles) self._total_data_size = len(smiles_list) return smiles_list - elif self._operation == EVAL_OP: + elif self._operation_mode == EVAL_OP: processed_dir_path = Path(self._data_processed_dir_main) data_pkl_path = processed_dir_path / "data.pkl" if not data_pkl_path.exists(): @@ -183,7 +182,7 @@ def run_ensemble(self) -> None: ) print( - f"Running {self.__class__.__name__} ensemble for {self._operation} operation..." + f"Running {self.__class__.__name__} ensemble for {self._operation_mode} operation..." ) while self._model_queue: model_name = self._model_queue.popleft() @@ -204,7 +203,7 @@ def run_ensemble(self) -> None: true_scores=true_scores, false_scores=false_scores ) - if self._operation == EVAL_OP: + if self._operation_mode == EVAL_OP: assert ( self._collated_labels is not None ), "Collated labels must be set for evaluation operation." @@ -214,6 +213,31 @@ def run_ensemble(self) -> None: self._device, classes=list(self._dm_labels.keys()), ) + else: + # Get SMILES and label names + smiles_list = self._ensemble_input + label_names = list(self._dm_labels.keys()) + # Efficient conversion from tensor to NumPy + preds_np = final_preds.detach().cpu().numpy() + + assert ( + len(smiles_list) == preds_np.shape[0] + ), "Length of SMILES list does not match number of predictions." + assert ( + len(label_names) == preds_np.shape[1] + ), "Number of label names does not match number of predictions." + + # Build DataFrame + df = pd.DataFrame(preds_np, columns=label_names) + df.insert(0, "SMILES", smiles_list) + + # Save to CSV + output_path = ( + Path(self._data_processed_dir_main) / "ensemble_predictions.csv" + ) + df.to_csv(output_path, index=False) + + print(f"Predictions saved to {output_path}") @abstractmethod def _controller( diff --git a/chebai/ensemble/_controller.py b/chebai/ensemble/_controller.py index b9a65a1f..6a42e3a1 100644 --- a/chebai/ensemble/_controller.py +++ b/chebai/ensemble/_controller.py @@ -52,7 +52,7 @@ def _controller( Dict[str, Tensor]: Dictionary containing predictions and confidence scores. """ wrapped_model = self._wrap_model(model_name) - if self._operation == PRED_OP: + if self._operation_mode == PRED_OP: model_output, model_props = wrapped_model.predict(model_input) else: model_output, model_props = wrapped_model.evaluate(model_input) diff --git a/chebai/ensemble/_wrappers/_base.py b/chebai/ensemble/_wrappers/_base.py index a1703917..88e0d7fe 100644 --- a/chebai/ensemble/_wrappers/_base.py +++ b/chebai/ensemble/_wrappers/_base.py @@ -43,7 +43,7 @@ def _generate_model_label_props(self, dm_labels) -> dict[str, torch.Tensor]: except Exception as e: raise Exception( f"Label '{label}' has an unexpected error \n Error: {e}" - ) + ) from e model_label_indices.append(dm_labels[label]) tpv_label_values.append(props["TPV"]) diff --git a/chebai/ensemble/_wrappers/_neural_network.py b/chebai/ensemble/_wrappers/_neural_network.py index 982a410c..72f9e60c 100644 --- a/chebai/ensemble/_wrappers/_neural_network.py +++ b/chebai/ensemble/_wrappers/_neural_network.py @@ -126,7 +126,6 @@ def _predict_from_list_of_smiles(self, smiles_list: list[str]) -> list: else: index_map[i] = len(token_dicts) token_dicts.append(d) - print(f"Predicting {len(token_dicts), token_dicts} out of {len(smiles_list)}") if token_dicts: model_output = self._forward_pass(token_dicts) if not isinstance(model_output, dict) and not "logits" in model_output: From 954431ca605fa8a2eaebb32d9e5d6da3093f3d81 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 2 Jun 2025 17:46:17 +0200 Subject: [PATCH 58/78] use multilabel cm --- .../_scripts/_generate_classes_props_json.py | 27 ++++++++++--------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/chebai/ensemble/_scripts/_generate_classes_props_json.py b/chebai/ensemble/_scripts/_generate_classes_props_json.py index 7ab65284..d00b4e8b 100644 --- a/chebai/ensemble/_scripts/_generate_classes_props_json.py +++ b/chebai/ensemble/_scripts/_generate_classes_props_json.py @@ -6,7 +6,7 @@ import pandas as pd import torch from jsonargparse import CLI -from sklearn.metrics import confusion_matrix +from sklearn.metrics import multilabel_confusion_matrix from torch.utils.data import DataLoader from chebai.ensemble._utils import load_class @@ -36,30 +36,31 @@ def load_class_labels(path: str) -> list[str]: @staticmethod def compute_tpv_npv( - y_true: list[int], y_pred: list[int], class_names: list[str] + y_true: list[torch.Tensor], y_pred: list[torch.Tensor], class_names: list[str] ) -> dict[str, dict[str, float]]: """ - Compute TPV and NPV for each class using the confusion matrix. + Compute TPV and NPV for each class in a multi-label classification problem. Args: - y_true (list[int]): Ground truth labels. - y_pred (list[int]): Predicted labels. + y_true (list[Tensor]): List of binary ground truth label tensors per sample. + y_pred (list[Tensor]): List of binary prediction tensors per sample. class_names (list[str]): List of class names corresponding to class indices. Returns: dict[str, dict[str, float]]: Dictionary with class names as keys and TPV/NPV as values. """ - cm = confusion_matrix(y_true, y_pred, labels=list(range(len(class_names)))) - metrics = {} + # Convert list of tensors to a single binary matrix + y_true_tensor = torch.stack(y_true).cpu().numpy().astype(int) + y_pred_tensor = torch.stack(y_pred).cpu().numpy().astype(int) + + cm = multilabel_confusion_matrix(y_true_tensor, y_pred_tensor) + metrics = {} for i, cls in enumerate(class_names): - TP = cm[i, i] - FP = cm[:, i].sum() - TP - FN = cm[i, :].sum() - TP - TN = cm.sum() - (TP + FP + FN) + tn, fp, fn, tp = cm[i].ravel() - TPV = TP / (TP + FP) if (TP + FP) > 0 else 0.0 - NPV = TN / (TN + FN) if (TN + FN) > 0 else 0.0 + TPV = tp / (tp + fp) if (tp + fp) > 0 else 0.0 + NPV = tn / (tn + fn) if (tn + fn) > 0 else 0.0 metrics[cls] = {"TPV": round(TPV, 4), "NPV": round(NPV, 4)} From 6ce02a71129bde5842d3d8b178f2316014e35bf9 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 9 Jun 2025 18:31:08 +0200 Subject: [PATCH 59/78] raise error for duplicate subclass/wrapper --- chebai/ensemble/_wrappers/_base.py | 31 ++++++++++++++++++++++++------ 1 file changed, 25 insertions(+), 6 deletions(-) diff --git a/chebai/ensemble/_wrappers/_base.py b/chebai/ensemble/_wrappers/_base.py index 88e0d7fe..b09b80ef 100644 --- a/chebai/ensemble/_wrappers/_base.py +++ b/chebai/ensemble/_wrappers/_base.py @@ -7,6 +7,8 @@ from .._constants import MODEL_CLS_PATH, MODEL_LBL_PATH +_MODEL_REGISTRY = {} + class BaseWrapper(ABC): def __init__( @@ -20,10 +22,31 @@ def __init__( self._model_config = model_config self._model_name = model_name self._model_class_path = self._model_config[MODEL_CLS_PATH] + self._model_labels_path = self._model_config[MODEL_LBL_PATH] self._model_props = self._generate_model_label_props(dm_labels=dm_labels) self.collated_labels = None + @classmethod + def _cls_name(cls) -> str: + return cls.__name__ + + @property + def name(self): + return f"Wrapper({self._cls_name()}) for model: {self._model_name}" + + def __init_subclass__(cls): + """ + Automatically registers subclasses in the model registry to prevent duplicates. + + Args: + **kwargs: Additional keyword arguments. + """ + if cls._cls_name() in _MODEL_REGISTRY: + raise ValueError(f"Model {cls._cls_name()} does already exist") + else: + _MODEL_REGISTRY[cls._cls_name()] = cls + def _generate_model_label_props(self, dm_labels) -> dict[str, torch.Tensor]: """ Generates label mask and confidence tensors (TPV, FPV) for a model. @@ -41,7 +64,7 @@ def _generate_model_label_props(self, dm_labels) -> dict[str, torch.Tensor]: try: self._validate_model_labels_json_element(labels_dict[label]) except Exception as e: - raise Exception( + raise RuntimeError( f"Label '{label}' has an unexpected error \n Error: {e}" ) from e @@ -110,10 +133,6 @@ def _validate_model_labels_json_element(label_dict: dict[str, float]) -> None: except (TypeError, ValueError) as e: raise ValueError(f"Invalid value for '{key}': {label_dict[key]}") from e - @property - def name(self): - return f"Wrapper({self.__class__.__name__}) for model: {self._model_name}" - def predict(self, x: list) -> tuple[dict, dict]: if not isinstance(x, list): raise TypeError(f"Input must be a list of SMILES strings, got {type(x)}") @@ -133,4 +152,4 @@ def evaluate( ) @abstractmethod - def _evaluate_from_data_file(self, data_file_path: str) -> dict: ... + def _evaluate_from_data_file(self, data_processed_dir_main: str) -> dict: ... From 549a71f1a31cfdeb199d0bc47bb9e70c2e8625f8 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 9 Jun 2025 20:42:05 +0200 Subject: [PATCH 60/78] add model load kwargs and move cls path to nn wrapper --- chebai/ensemble/_constants.py | 1 + chebai/ensemble/_wrappers/_base.py | 8 +++-- chebai/ensemble/_wrappers/_neural_network.py | 37 +++++++++++++------- 3 files changed, 31 insertions(+), 15 deletions(-) diff --git a/chebai/ensemble/_constants.py b/chebai/ensemble/_constants.py index 3d3b5db7..88349979 100644 --- a/chebai/ensemble/_constants.py +++ b/chebai/ensemble/_constants.py @@ -1,4 +1,5 @@ MODEL_CLS_PATH = "model_class_path" +MODEL_LD_KWARGS = "model_load_kwargs" MODEL_LBL_PATH = "model_labels_path" MODEL_CKPT_PATH = "model_ckpt_path" diff --git a/chebai/ensemble/_wrappers/_base.py b/chebai/ensemble/_wrappers/_base.py index b09b80ef..fc57f7f7 100644 --- a/chebai/ensemble/_wrappers/_base.py +++ b/chebai/ensemble/_wrappers/_base.py @@ -5,7 +5,7 @@ import torch -from .._constants import MODEL_CLS_PATH, MODEL_LBL_PATH +from .._constants import MODEL_LBL_PATH _MODEL_REGISTRY = {} @@ -18,10 +18,14 @@ def __init__( dm_labels: dict[str, int], **kwargs, ): + if MODEL_LBL_PATH not in model_config: + raise AttributeError( + f"Missing key {MODEL_LBL_PATH} in model '{model_name}' configuration." + ) + self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self._model_config = model_config self._model_name = model_name - self._model_class_path = self._model_config[MODEL_CLS_PATH] self._model_labels_path = self._model_config[MODEL_LBL_PATH] self._model_props = self._generate_model_label_props(dm_labels=dm_labels) diff --git a/chebai/ensemble/_wrappers/_neural_network.py b/chebai/ensemble/_wrappers/_neural_network.py index 72f9e60c..e9bde049 100644 --- a/chebai/ensemble/_wrappers/_neural_network.py +++ b/chebai/ensemble/_wrappers/_neural_network.py @@ -9,7 +9,13 @@ from chebai.preprocessing.reader import DataReader from chebai.preprocessing.structures import XYData -from .._constants import MODEL_CKPT_PATH, READER_CLS_PATH, READER_KWARGS +from .._constants import ( + MODEL_CKPT_PATH, + MODEL_CLS_PATH, + MODEL_LD_KWARGS, + READER_CLS_PATH, + READER_KWARGS, +) from .._utils import load_class from ._base import BaseWrapper @@ -22,12 +28,20 @@ def __init__(self, **kwargs): ) super().__init__(**kwargs) + self._model_class_path = self._model_config[MODEL_CLS_PATH] self._model_ckpt_path = self._model_config[MODEL_CKPT_PATH] + self._model_ld_kwargs: dict = ( + self._model_config[MODEL_LD_KWARGS] + if MODEL_LD_KWARGS in self._model_config + and self._model_config[MODEL_LD_KWARGS] + else {} + ) + self._reader_class_path = self._model_config[READER_CLS_PATH] self._reader_kwargs: dict = ( self._model_config[READER_KWARGS] - if self._model_config[READER_KWARGS] - else dict() + if READER_KWARGS in self._model_config and self._model_config[READER_KWARGS] + else {} ) reader_cls: Type[DataReader] = load_class(self._reader_class_path) @@ -35,9 +49,7 @@ def __init__(self, **kwargs): self._reader = reader_cls(**self._reader_kwargs) self._collator = reader_cls.COLLATOR() self.collated_labels = None - self._model: ChebaiBaseNet = self._load_model_( - input_dim=kwargs.get("input_dim", None) - ) + self._model: ChebaiBaseNet = self._load_model_() @classmethod def _validate_model_configs( @@ -55,10 +67,7 @@ def _validate_model_configs( AttributeError: If any model config is missing required keys. ValueError: If duplicate paths are found for model checkpoint, class, or labels. """ - required_keys = { - MODEL_CKPT_PATH, - READER_CLS_PATH, - } + required_keys = {MODEL_CKPT_PATH, READER_CLS_PATH, MODEL_CLS_PATH} missing_keys = required_keys - model_config.keys() if missing_keys: @@ -66,7 +75,7 @@ def _validate_model_configs( f"Missing keys {missing_keys} in model '{model_name}' configuration." ) - def _load_model_(self, input_dim: int | None) -> ChebaiBaseNet: + def _load_model_(self) -> ChebaiBaseNet: """ Loads a model checkpoint and its label-related properties. @@ -90,7 +99,7 @@ def _load_model_(self, input_dim: int | None) -> ChebaiBaseNet: ), f"{lightning_cls} must inherit from ChebaiBaseNet" try: model = lightning_cls.load_from_checkpoint( - self._model_ckpt_path, input_dim=5 + self._model_ckpt_path, input_dim=5, **self._model_ld_kwargs ) except Exception as e: raise RuntimeError( @@ -140,7 +149,9 @@ def _read_smiles(self, smiles): def _forward_pass(self, batch): collated_batch: XYData = self._collator(batch).to(self._device) self.collated_labels = collated_batch.y - processable_data = self._model._process_batch(collated_batch, 0) # noqa + processable_data = self._model._process_batch( # pylint: disable=W0212 + collated_batch, 0 + ) return self._model(processable_data, **processable_data["model_kwargs"]) def _evaluate_from_data_file( From 366c72bb84d0d5d7130499d42d7cc56798de627d Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 9 Jun 2025 20:43:27 +0200 Subject: [PATCH 61/78] refine chemlog wrapper --- chebai/ensemble/_wrappers/_chemlog.py | 47 ++++++++++----------------- 1 file changed, 18 insertions(+), 29 deletions(-) diff --git a/chebai/ensemble/_wrappers/_chemlog.py b/chebai/ensemble/_wrappers/_chemlog.py index 0df47e78..c4e43da0 100644 --- a/chebai/ensemble/_wrappers/_chemlog.py +++ b/chebai/ensemble/_wrappers/_chemlog.py @@ -1,14 +1,7 @@ -import os -import sys -from typing import Optional +from pathlib import Path -sys.path.append(os.path.join("..", "..", "..", "PycharmProjects", "chemlog2")) -import chemlog -from chebi_utils import CHEBI_FRAGMENT, get_transitive_predictions -from chemlog.classification.charge_classifier import ( - ChargeCategories, - get_charge_category, -) +import pandas as pd +from chemlog.classification.charge_classifier import get_charge_category from chemlog.classification.peptide_size_classifier import get_n_amino_acid_residues from chemlog.classification.proteinogenics_classifier import ( get_proteinogenic_amino_acids, @@ -18,31 +11,33 @@ is_emericellamide, ) from chemlog.cli import resolve_chebi_classes -from prediction_models.base import PredictionModel from rdkit import Chem +from chebai.ensemble._wrappers._base import BaseWrapper + -class ChemLog(PredictionModel): +class ChemLog(BaseWrapper): - def __init__( - self, - name: Optional[str] = None, - description: Optional[ - str - ] = "A rule-based model for predicting peptides and peptide-like molecules.", - ): - super().__init__(name, description) + def _predict_from_list_of_smiles(self, smiles_list): + return self.get_chemlog_results(smiles_list) + + def _evaluate_from_data_file( + self, data_processed_dir_main: Path, data_file_name="data.pkl" + ) -> list: + data_df = pd.read_pickle(data_processed_dir_main / data_file_name) + smiles_list = data_df["SMILES"].to_list() + return self.get_chemlog_results(smiles_list) def get_chemlog_results(self, smiles_list) -> list: all_preds = [] - for i, smiles in enumerate(smiles_list): + for smiles in smiles_list: mol = Chem.MolFromSmiles(smiles, sanitize=False) if mol is None or not smiles: all_preds.append(None) continue mol.UpdatePropertyCache() charge_category = get_charge_category(mol) - n_amino_acid_residues, add_output = get_n_amino_acid_residues(mol) + n_amino_acid_residues, _ = get_n_amino_acid_residues(mol) r = { "charge_category": charge_category.name, "n_amino_acid_residues": n_amino_acid_residues, @@ -65,7 +60,7 @@ def get_chemlog_result_info(self, smiles): mol.UpdatePropertyCache() try: Chem.Kekulize(mol) - except Chem.KekulizeException as e: + except Chem.KekulizeException: pass charge_category = get_charge_category(mol) @@ -95,9 +90,3 @@ def get_chemlog_result_info(self, smiles): results["2,5-diketopiperazines_atoms"] = diketopiperazine[1] return {**results, **add_output} - - def predict(self, smiles_list): - return [ - get_transitive_predictions([positive_i]) - for positive_i in self.get_chemlog_results(smiles_list) - ] From 2739c64f9fc75e5827c64459906c8b0f348ad1c6 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 11 Jun 2025 00:30:36 +0200 Subject: [PATCH 62/78] use data class instead of explicit reader, collator - for clean code and code reusability --- chebai/ensemble/_constants.py | 5 +- chebai/ensemble/_wrappers/_neural_network.py | 51 ++++++++++---------- 2 files changed, 27 insertions(+), 29 deletions(-) diff --git a/chebai/ensemble/_constants.py b/chebai/ensemble/_constants.py index 88349979..04426baf 100644 --- a/chebai/ensemble/_constants.py +++ b/chebai/ensemble/_constants.py @@ -5,9 +5,8 @@ WRAPPER_CLS_PATH = "wrapper_class_path" -READER_CLS_PATH = "reader_class_path" -READER_KWARGS = "reader_kwargs" - +DATA_CLS_PATH = "data_class_path" +DATA_CLS_KWARGS = "data_class_kwargs" PRED_OP = "prediction" EVAL_OP = "evaluation" diff --git a/chebai/ensemble/_wrappers/_neural_network.py b/chebai/ensemble/_wrappers/_neural_network.py index e9bde049..ae188ba4 100644 --- a/chebai/ensemble/_wrappers/_neural_network.py +++ b/chebai/ensemble/_wrappers/_neural_network.py @@ -1,20 +1,19 @@ import os from pathlib import Path -from typing import Type import torch from rdkit import Chem from chebai.models import ChebaiBaseNet -from chebai.preprocessing.reader import DataReader +from chebai.preprocessing.datasets.base import XYBaseDataModule from chebai.preprocessing.structures import XYData from .._constants import ( + DATA_CLS_KWARGS, + DATA_CLS_PATH, MODEL_CKPT_PATH, MODEL_CLS_PATH, MODEL_LD_KWARGS, - READER_CLS_PATH, - READER_KWARGS, ) from .._utils import load_class from ._base import BaseWrapper @@ -30,24 +29,9 @@ def __init__(self, **kwargs): self._model_class_path = self._model_config[MODEL_CLS_PATH] self._model_ckpt_path = self._model_config[MODEL_CKPT_PATH] - self._model_ld_kwargs: dict = ( - self._model_config[MODEL_LD_KWARGS] - if MODEL_LD_KWARGS in self._model_config - and self._model_config[MODEL_LD_KWARGS] - else {} - ) - - self._reader_class_path = self._model_config[READER_CLS_PATH] - self._reader_kwargs: dict = ( - self._model_config[READER_KWARGS] - if READER_KWARGS in self._model_config and self._model_config[READER_KWARGS] - else {} - ) + self._model_ld_kwargs: dict = self._model_config.get(MODEL_LD_KWARGS, {}) - reader_cls: Type[DataReader] = load_class(self._reader_class_path) - assert issubclass(reader_cls, DataReader), "" - self._reader = reader_cls(**self._reader_kwargs) - self._collator = reader_cls.COLLATOR() + self._data_cls_instance: XYBaseDataModule = self._load_data_instance() self.collated_labels = None self._model: ChebaiBaseNet = self._load_model_() @@ -67,7 +51,7 @@ def _validate_model_configs( AttributeError: If any model config is missing required keys. ValueError: If duplicate paths are found for model checkpoint, class, or labels. """ - required_keys = {MODEL_CKPT_PATH, READER_CLS_PATH, MODEL_CLS_PATH} + required_keys = {MODEL_CKPT_PATH, DATA_CLS_PATH, MODEL_CLS_PATH} missing_keys = required_keys - model_config.keys() if missing_keys: @@ -75,6 +59,15 @@ def _validate_model_configs( f"Missing keys {missing_keys} in model '{model_name}' configuration." ) + def _load_data_instance(self): + data_cls = load_class(self._model_config[DATA_CLS_PATH]) + assert isinstance(data_cls, type), f"{data_cls} is not a class." + assert issubclass( + data_cls, XYBaseDataModule + ), f"{data_cls} must inherit from XYBaseDataModule" + data_kwargs = self._model_config.get(DATA_CLS_KWARGS, {}) + return data_cls(**data_kwargs) + def _load_model_(self) -> ChebaiBaseNet: """ Loads a model checkpoint and its label-related properties. @@ -143,11 +136,15 @@ def _predict_from_list_of_smiles(self, smiles_list: list[str]) -> list: else: raise ValueError() - def _read_smiles(self, smiles): - return self._reader.to_data(dict(features=smiles, labels=None)) + def _read_smiles(self, smiles: str): + return self._data_cls_instance.reader.to_data( + dict(features=smiles, labels=None) + ) def _forward_pass(self, batch): - collated_batch: XYData = self._collator(batch).to(self._device) + collated_batch: XYData = self._data_cls_instance.reader.collator(batch).to( + self._device + ) self.collated_labels = collated_batch.y processable_data = self._model._process_batch( # pylint: disable=W0212 collated_batch, 0 @@ -158,7 +155,9 @@ def _evaluate_from_data_file( self, data_processed_dir_main: Path, data_file_name="data.pt" ) -> list: data = torch.load( - data_processed_dir_main / self._reader.name() / data_file_name, + data_processed_dir_main + / self._data_cls_instance.reader.name() + / data_file_name, weights_only=False, map_location=self._device, ) From a96ae435a3725736b51632f851f6ee6a54068fd9 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 11 Jun 2025 00:41:40 +0200 Subject: [PATCH 63/78] refine gnn wrapper --- chebai/ensemble/_wrappers/_gnn.py | 79 +++++++++----------- chebai/ensemble/_wrappers/_neural_network.py | 4 + 2 files changed, 38 insertions(+), 45 deletions(-) diff --git a/chebai/ensemble/_wrappers/_gnn.py b/chebai/ensemble/_wrappers/_gnn.py index 3b777409..21efd27a 100644 --- a/chebai/ensemble/_wrappers/_gnn.py +++ b/chebai/ensemble/_wrappers/_gnn.py @@ -1,65 +1,47 @@ -from typing import Optional, Union +from pathlib import Path import chebai_graph.preprocessing.properties as p import torch -from chebai_graph.models.graph import ResGatedGraphConvNetGraphPred -from chebai_graph.preprocessing.datasets.chebi import ( - ChEBI50GraphProperties, - ChEBI100GraphProperties, - GraphPropertiesMixIn, -) from chebai_graph.preprocessing.property_encoder import IndexEncoder, OneHotEncoder from torch_geometric.data.data import Data as GeomData +from .._constants import DATA_CLS_KWARGS from ._neural_network import NNWrapper -if torch.cuda.is_available(): - device = "cuda" -else: - device = "cpu" - class GNNResGated(NNWrapper): - def __init__( - self, - checkpoint_path: str, - data_class: Union[GraphPropertiesMixIn, str], - prediction_headers_path: str, - batch_size: Optional[int] = 32, - name: Optional[str] = None, - description: Optional[str] = "Residual-gated Graph Convolutional Network for " - "predicting arbitrary ChEBI classes.", - ): - super().__init__(prediction_headers_path, batch_size, name, description) - self.model = ResGatedGraphConvNetGraphPred.load_from_checkpoint( - checkpoint_path, - map_location=torch.device(device), - criterion=None, - strict=False, - metrics=dict(train=dict(), test=dict(), validation=dict()), - pretrained_checkpoint=None, - config={ - "in_length": 256, - "hidden_length": 512, - "dropout_rate": 0.1, - "n_conv_layers": 3, - "n_linear_layers": 3, - "n_atom_properties": 158, - "n_bond_properties": 7, - "n_molecule_properties": 200, - }, + def _pre_load_hook(self): + self._model_config[DATA_CLS_KWARGS] = self._model_config.get( + DATA_CLS_KWARGS, + dict( + properties=[ + p.AtomType(), + p.NumAtomBonds(), + p.AtomCharge(), + p.AtomAromaticity(), + p.AtomHybridization(), + p.AtomNumHs(), + p.BondType(), + p.BondInRing(), + p.BondAromaticity(), + p.RDKit2DNormalized(), + ] + ), ) + return super()._pre_load_hook() def _read_smiles(self, smiles): - d = self._reader.to_data(dict(features=smiles, labels=None)) + d = self._data_cls_instance.reader.to_data(dict(features=smiles, labels=None)) geom_data = d["features"] - assert isinstance(geom_data, GeomData), "" + assert isinstance(geom_data, GeomData), "Must be an instance of GeoData" edge_attr = geom_data.edge_attr x = geom_data.x molecule_attr = torch.empty((1, 0)) - for property in self.data_class.properties: - property_values = reader.read_property(smiles, property) + for property in self._data_cls_instance.properties: + property_values = self._data_cls_instance.reader.read_property( + smiles, property + ) encoded_values = [] for value in property_values: # cant use standard encode for index encoder because model has been trained on a certain range of values @@ -77,7 +59,7 @@ def _read_smiles(self, smiles): ) if isinstance(property.encoder, OneHotEncoder): encoded_values.append( - torch.nn.functional.one_hot( + torch.nn.functional.one_hot( # pylint: disable=E1102 torch.tensor(index), num_classes=property.encoder.get_encoding_length(), ) @@ -113,3 +95,10 @@ def _read_smiles(self, smiles): molecule_attr=molecule_attr, ) return d + + # def _evaluate_from_data_file( + # self, data_processed_dir_main: Path, data_file_name="data.pt" + # ) -> list: + # data_path = data_processed_dir_main / self._reader.name() / data_file_name + # data_dict = self._data_class.load_processed_data_from_file(data_path) + # return self._forward_pass(data) diff --git a/chebai/ensemble/_wrappers/_neural_network.py b/chebai/ensemble/_wrappers/_neural_network.py index ae188ba4..dbf5c4a1 100644 --- a/chebai/ensemble/_wrappers/_neural_network.py +++ b/chebai/ensemble/_wrappers/_neural_network.py @@ -31,6 +31,8 @@ def __init__(self, **kwargs): self._model_ckpt_path = self._model_config[MODEL_CKPT_PATH] self._model_ld_kwargs: dict = self._model_config.get(MODEL_LD_KWARGS, {}) + self._pre_load_hook() + self._data_cls_instance: XYBaseDataModule = self._load_data_instance() self.collated_labels = None self._model: ChebaiBaseNet = self._load_model_() @@ -59,6 +61,8 @@ def _validate_model_configs( f"Missing keys {missing_keys} in model '{model_name}' configuration." ) + def _pre_load_hook(self) -> None: ... + def _load_data_instance(self): data_cls = load_class(self._model_config[DATA_CLS_PATH]) assert isinstance(data_cls, type), f"{data_cls} is not a class." From a6800b3de61e4ea13f31942f248e907150e107de Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 12 Jun 2025 11:56:21 +0200 Subject: [PATCH 64/78] correct PPV and FPV key and rectify nn wrapper --- chebai/ensemble/_controller.py | 4 ++-- .../_scripts/_generate_classes_props_json.py | 4 ++-- chebai/ensemble/_wrappers/_base.py | 21 ++++++------------ chebai/ensemble/_wrappers/_neural_network.py | 22 ++++++------------- 4 files changed, 18 insertions(+), 33 deletions(-) diff --git a/chebai/ensemble/_controller.py b/chebai/ensemble/_controller.py index 6a42e3a1..9509f1b5 100644 --- a/chebai/ensemble/_controller.py +++ b/chebai/ensemble/_controller.py @@ -31,7 +31,7 @@ def __init__(self, **kwargs: Any): """ super().__init__(**kwargs) self._kwargs = kwargs - # If an activation condition correponding model is added to queue, removed from this set + # If an activation condition corresponding model is added to queue, removed from this set # This is in order to avoid re-adding models that have already been processed self._model_key_set: set[str] = set(self._model_configs.keys()) @@ -55,7 +55,7 @@ def _controller( if self._operation_mode == PRED_OP: model_output, model_props = wrapped_model.predict(model_input) else: - model_output, model_props = wrapped_model.evaluate(model_input) + model_output, model_props = wrapped_model.evaluate() if ( self._collated_labels is None and wrapped_model.collated_labels is not None diff --git a/chebai/ensemble/_scripts/_generate_classes_props_json.py b/chebai/ensemble/_scripts/_generate_classes_props_json.py index d00b4e8b..0ce5fc9f 100644 --- a/chebai/ensemble/_scripts/_generate_classes_props_json.py +++ b/chebai/ensemble/_scripts/_generate_classes_props_json.py @@ -59,10 +59,10 @@ def compute_tpv_npv( for i, cls in enumerate(class_names): tn, fp, fn, tp = cm[i].ravel() - TPV = tp / (tp + fp) if (tp + fp) > 0 else 0.0 + PPV = tp / (tp + fp) if (tp + fp) > 0 else 0.0 NPV = tn / (tn + fn) if (tn + fn) > 0 else 0.0 - metrics[cls] = {"TPV": round(TPV, 4), "NPV": round(NPV, 4)} + metrics[cls] = {"PPV": round(TPV, 4), "NPV": round(NPV, 4)} return metrics diff --git a/chebai/ensemble/_wrappers/_base.py b/chebai/ensemble/_wrappers/_base.py index fc57f7f7..56bc54ed 100644 --- a/chebai/ensemble/_wrappers/_base.py +++ b/chebai/ensemble/_wrappers/_base.py @@ -73,8 +73,8 @@ def _generate_model_label_props(self, dm_labels) -> dict[str, torch.Tensor]: ) from e model_label_indices.append(dm_labels[label]) - tpv_label_values.append(props["TPV"]) - fpv_label_values.append(props["FPV"]) + tpv_label_values.append(props["PPV"]) + fpv_label_values.append(props["NPV"]) if not all([model_label_indices, tpv_label_values, fpv_label_values]): raise ValueError( @@ -121,13 +121,13 @@ def _validate_model_labels_json_element(label_dict: dict[str, float]) -> None: Validates a label confidence dictionary to ensure required keys and values are valid. Args: - label_dict (Dict[str, Any]): Label data with TPV and FPV keys. + label_dict (Dict[str, Any]): Label data with PPV and NPV keys. Raises: AttributeError: If required keys are missing. ValueError: If values are not valid floats or are negative. """ - for key in ["TPV", "FPV"]: + for key in ["PPV", "NPV"]: if key not in label_dict: raise AttributeError(f"Missing key '{key}' in label dict.") try: @@ -145,15 +145,8 @@ def predict(self, x: list) -> tuple[dict, dict]: @abstractmethod def _predict_from_list_of_smiles(self, smiles_list: list) -> dict: ... - def evaluate( - self, data_processed_dir_main: Path, **kwargs: any - ) -> tuple[dict, dict]: - if not data_processed_dir_main.is_dir(): - raise NotADirectoryError(f"{data_processed_dir_main} is not a directory.") - return ( - self._evaluate_from_data_file(data_processed_dir_main, **kwargs), - self._model_props, - ) + def evaluate(self, **kwargs) -> tuple[dict, dict]: + return self._evaluate_from_data_file(**kwargs), self._model_props @abstractmethod - def _evaluate_from_data_file(self, data_processed_dir_main: str) -> dict: ... + def _evaluate_from_data_file(self, **kwargs) -> dict: ... diff --git a/chebai/ensemble/_wrappers/_neural_network.py b/chebai/ensemble/_wrappers/_neural_network.py index dbf5c4a1..16937e01 100644 --- a/chebai/ensemble/_wrappers/_neural_network.py +++ b/chebai/ensemble/_wrappers/_neural_network.py @@ -1,7 +1,6 @@ import os from pathlib import Path -import torch from rdkit import Chem from chebai.models import ChebaiBaseNet @@ -113,7 +112,6 @@ def _load_model_(self) -> ChebaiBaseNet: def _predict_from_list_of_smiles(self, smiles_list: list[str]) -> list: token_dicts = [] could_not_parse = [] - index_map = dict() for i, smiles in enumerate(smiles_list): try: # Try to parse the smiles string @@ -130,7 +128,6 @@ def _predict_from_list_of_smiles(self, smiles_list: list[str]) -> list: if rdmol is None: could_not_parse.append(i) else: - index_map[i] = len(token_dicts) token_dicts.append(d) if token_dicts: model_output = self._forward_pass(token_dicts) @@ -140,12 +137,12 @@ def _predict_from_list_of_smiles(self, smiles_list: list[str]) -> list: else: raise ValueError() - def _read_smiles(self, smiles: str): + def _read_smiles(self, smiles: str) -> dict: return self._data_cls_instance.reader.to_data( dict(features=smiles, labels=None) ) - def _forward_pass(self, batch): + def _forward_pass(self, batch: list[dict]) -> dict: collated_batch: XYData = self._data_cls_instance.reader.collator(batch).to( self._device ) @@ -155,14 +152,9 @@ def _forward_pass(self, batch): ) return self._model(processable_data, **processable_data["model_kwargs"]) - def _evaluate_from_data_file( - self, data_processed_dir_main: Path, data_file_name="data.pt" - ) -> list: - data = torch.load( - data_processed_dir_main - / self._data_cls_instance.reader.name() - / data_file_name, - weights_only=False, - map_location=self._device, + def _evaluate_from_data_file(self) -> list: + filename = self._data_cls_instance.processed_file_names_dict["data"] + data_list_of_dict = self._data_cls_instance.load_processed_data_from_file( + os.path.join(self._data_cls_instance.processed_dir, filename) ) - return self._forward_pass(data) + return self._forward_pass(data_list_of_dict) From 64b3e7ee221f094db0925b083377a7be2a909560 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 12 Jun 2025 13:35:36 +0200 Subject: [PATCH 65/78] load cls, load model as utilities --- chebai/ensemble/_utils.py | 56 +++++++++++++++++ chebai/ensemble/_wrappers/_neural_network.py | 64 ++++---------------- 2 files changed, 68 insertions(+), 52 deletions(-) diff --git a/chebai/ensemble/_utils.py b/chebai/ensemble/_utils.py index 1da494f1..0e5238a0 100644 --- a/chebai/ensemble/_utils.py +++ b/chebai/ensemble/_utils.py @@ -1,7 +1,63 @@ import importlib +from pathlib import Path + +from chebai.models.base import ChebaiBaseNet +from chebai.preprocessing.datasets.base import XYBaseDataModule def load_class(class_path: str) -> type: module_path, class_name = class_path.rsplit(".", 1) module = importlib.import_module(module_path) return getattr(module, class_name) + + +def load_data_instance(data_cls_path: str, data_cls_kwargs: dict): + assert isinstance(data_cls_kwargs, dict), "data_cls_kwargs must be a dict" + data_cls = load_class(data_cls_path) + assert isinstance(data_cls, type), f"{data_cls} is not a class." + assert issubclass( + data_cls, XYBaseDataModule + ), f"{data_cls} must inherit from XYBaseDataModule" + return data_cls(**data_cls_kwargs) + + +def load_model_for_inference( + model_ckpt_path: str, model_cls_path: str, model_load_kwargs: dict, **kwargs +) -> ChebaiBaseNet: + """ + Loads a model checkpoint and its label-related properties. + + Args: + input_dim (int): Name of the model to load. + + Returns: + Tuple[LightningModule, Dict[str, torch.Tensor]]: The model and its label properties. + """ + assert isinstance(model_load_kwargs, dict), "model_load_kwargs must be a dict" + + model_name = kwargs.get("model_name", model_ckpt_path) + + if not Path(model_ckpt_path).exists(): + raise FileNotFoundError( + f"Model path '{model_ckpt_path}' for '{model_name}' does not exist." + ) + + lightning_cls = load_class(model_cls_path) + + assert isinstance(lightning_cls, type), f"{lightning_cls} is not a class." + assert issubclass( + lightning_cls, ChebaiBaseNet + ), f"{lightning_cls} must inherit from ChebaiBaseNet" + try: + model = lightning_cls.load_from_checkpoint( + model_ckpt_path, input_dim=5, **model_load_kwargs + ) + except Exception as e: + raise RuntimeError(f"Error loading model {model_name} \n Error: {e}") from e + + assert isinstance( + model, ChebaiBaseNet + ), f"Model: {model}(Model Name: {model_name}) is not a ChebaiBaseNet instance." + model.eval() + model.freeze() + return model diff --git a/chebai/ensemble/_wrappers/_neural_network.py b/chebai/ensemble/_wrappers/_neural_network.py index 16937e01..7b3dcd6c 100644 --- a/chebai/ensemble/_wrappers/_neural_network.py +++ b/chebai/ensemble/_wrappers/_neural_network.py @@ -1,4 +1,3 @@ -import os from pathlib import Path from rdkit import Chem @@ -14,7 +13,7 @@ MODEL_CLS_PATH, MODEL_LD_KWARGS, ) -from .._utils import load_class +from .._utils import load_class, load_data_instance, load_model_for_inference from ._base import BaseWrapper @@ -32,9 +31,17 @@ def __init__(self, **kwargs): self._pre_load_hook() - self._data_cls_instance: XYBaseDataModule = self._load_data_instance() + self._data_cls_instance: XYBaseDataModule = load_data_instance( + self._model_config[DATA_CLS_PATH], + self._model_config.get(DATA_CLS_KWARGS, {}), + ) self.collated_labels = None - self._model: ChebaiBaseNet = self._load_model_() + self._model: ChebaiBaseNet = load_model_for_inference( + self._model_config[MODEL_CKPT_PATH], + self._model_config[MODEL_CLS_PATH], + self._model_config.get(MODEL_LD_KWARGS, {}), + **kwargs, + ) @classmethod def _validate_model_configs( @@ -62,53 +69,6 @@ def _validate_model_configs( def _pre_load_hook(self) -> None: ... - def _load_data_instance(self): - data_cls = load_class(self._model_config[DATA_CLS_PATH]) - assert isinstance(data_cls, type), f"{data_cls} is not a class." - assert issubclass( - data_cls, XYBaseDataModule - ), f"{data_cls} must inherit from XYBaseDataModule" - data_kwargs = self._model_config.get(DATA_CLS_KWARGS, {}) - return data_cls(**data_kwargs) - - def _load_model_(self) -> ChebaiBaseNet: - """ - Loads a model checkpoint and its label-related properties. - - Args: - input_dim (int): Name of the model to load. - - Returns: - Tuple[LightningModule, Dict[str, torch.Tensor]]: The model and its label properties. - """ - - if not os.path.exists(self._model_ckpt_path): - raise FileNotFoundError( - f"Model path '{self._model_ckpt_path}' for '{self._model_name}' does not exist." - ) - - lightning_cls = load_class(self._model_class_path) - - assert isinstance(lightning_cls, type), f"{lightning_cls} is not a class." - assert issubclass( - lightning_cls, ChebaiBaseNet - ), f"{lightning_cls} must inherit from ChebaiBaseNet" - try: - model = lightning_cls.load_from_checkpoint( - self._model_ckpt_path, input_dim=5, **self._model_ld_kwargs - ) - except Exception as e: - raise RuntimeError( - f"Error loading model {self._model_name} \n Error: {e}" - ) from e - - assert isinstance( - model, ChebaiBaseNet - ), f"{model} is not a ChebaiBaseNet instance." - model.eval() - model.freeze() - return model - def _predict_from_list_of_smiles(self, smiles_list: list[str]) -> list: token_dicts = [] could_not_parse = [] @@ -155,6 +115,6 @@ def _forward_pass(self, batch: list[dict]) -> dict: def _evaluate_from_data_file(self) -> list: filename = self._data_cls_instance.processed_file_names_dict["data"] data_list_of_dict = self._data_cls_instance.load_processed_data_from_file( - os.path.join(self._data_cls_instance.processed_dir, filename) + Path(self._data_cls_instance.processed_dir) / filename ) return self._forward_pass(data_list_of_dict) From 7e673f28c9de64ef1c60ef0b54966aa88a10564a Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 12 Jun 2025 13:36:37 +0200 Subject: [PATCH 66/78] evaluate_from_data_file not needed for gnn wrapper --- chebai/ensemble/_wrappers/_gnn.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/chebai/ensemble/_wrappers/_gnn.py b/chebai/ensemble/_wrappers/_gnn.py index 21efd27a..c01a2ffd 100644 --- a/chebai/ensemble/_wrappers/_gnn.py +++ b/chebai/ensemble/_wrappers/_gnn.py @@ -14,8 +14,8 @@ class GNNResGated(NNWrapper): def _pre_load_hook(self): self._model_config[DATA_CLS_KWARGS] = self._model_config.get( DATA_CLS_KWARGS, - dict( - properties=[ + { + "properties": [ p.AtomType(), p.NumAtomBonds(), p.AtomCharge(), @@ -27,9 +27,8 @@ def _pre_load_hook(self): p.BondAromaticity(), p.RDKit2DNormalized(), ] - ), + }, ) - return super()._pre_load_hook() def _read_smiles(self, smiles): d = self._data_cls_instance.reader.to_data(dict(features=smiles, labels=None)) @@ -95,10 +94,3 @@ def _read_smiles(self, smiles): molecule_attr=molecule_attr, ) return d - - # def _evaluate_from_data_file( - # self, data_processed_dir_main: Path, data_file_name="data.pt" - # ) -> list: - # data_path = data_processed_dir_main / self._reader.name() / data_file_name - # data_dict = self._data_class.load_processed_data_from_file(data_path) - # return self._forward_pass(data) From 0c1be27e2ff3c87d525d15e525bafc8f8e2874d9 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 12 Jun 2025 13:47:26 +0200 Subject: [PATCH 67/78] use dataclass and utilities --- .../_scripts/_generate_classes_props_json.py | 209 ++++++++---------- 1 file changed, 93 insertions(+), 116 deletions(-) diff --git a/chebai/ensemble/_scripts/_generate_classes_props_json.py b/chebai/ensemble/_scripts/_generate_classes_props_json.py index 0ce5fc9f..519628e9 100644 --- a/chebai/ensemble/_scripts/_generate_classes_props_json.py +++ b/chebai/ensemble/_scripts/_generate_classes_props_json.py @@ -1,196 +1,173 @@ -"""Generate TPV/NPV JSON for multi-class classification models.""" - import json from pathlib import Path -import pandas as pd import torch from jsonargparse import CLI from sklearn.metrics import multilabel_confusion_matrix -from torch.utils.data import DataLoader -from chebai.ensemble._utils import load_class -from chebai.models.base import ChebaiBaseNet -from chebai.preprocessing.collate import Collator +from chebai.ensemble._utils import load_data_instance, load_model_for_inference +from chebai.preprocessing.datasets.base import XYBaseDataModule class ClassesPropertiesGenerator: """ - Computes TPV (Precision/ True Predictive Value) and NPV (Negative Predictive Value) - for each class in a multi-class classification problem using a PyTorch Lightning model. + Computes PPV (Positive Predictive Value) and NPV (Negative Predictive Value) + for each class in a multi-label classification problem using a PyTorch Lightning model. """ @staticmethod - def load_class_labels(path: str) -> list[str]: + def load_class_labels(path: Path) -> list[str]: """ Load a list of class names from a .json or .txt file. Args: - path (str): Path to class labels file. + path: Path to the class labels file (txt or json). Returns: - list[str]: List of class names. + A list of class names, one per line. """ - with open(path) as f: + path = Path(path) + with path.open() as f: return [line.strip() for line in f if line.strip()] @staticmethod def compute_tpv_npv( - y_true: list[torch.Tensor], y_pred: list[torch.Tensor], class_names: list[str] + y_true: list[torch.Tensor], + y_pred: list[torch.Tensor], + class_names: list[str], ) -> dict[str, dict[str, float]]: """ - Compute TPV and NPV for each class in a multi-label classification problem. + Compute TPV (precision) and NPV for each class in a multi-label setting. Args: - y_true (list[Tensor]): List of binary ground truth label tensors per sample. - y_pred (list[Tensor]): List of binary prediction tensors per sample. - class_names (list[str]): List of class names corresponding to class indices. + y_true: List of binary ground-truth label tensors, one tensor per sample. + y_pred: List of binary prediction tensors, one tensor per sample. + class_names: Ordered list of class names corresponding to class indices. Returns: - dict[str, dict[str, float]]: Dictionary with class names as keys and TPV/NPV as values. + Dictionary mapping each class name to its TPV and NPV metrics: + { + "class_name": {"PPV": float, "NPV": float}, + ... + } """ - # Convert list of tensors to a single binary matrix - y_true_tensor = torch.stack(y_true).cpu().numpy().astype(int) - y_pred_tensor = torch.stack(y_pred).cpu().numpy().astype(int) - - cm = multilabel_confusion_matrix(y_true_tensor, y_pred_tensor) - - metrics = {} - for i, cls in enumerate(class_names): - tn, fp, fn, tp = cm[i].ravel() + # Stack per-sample tensors into (n_samples, n_classes) numpy arrays + true_np = torch.stack(y_true).cpu().numpy().astype(int) + pred_np = torch.stack(y_pred).cpu().numpy().astype(int) - PPV = tp / (tp + fp) if (tp + fp) > 0 else 0.0 - NPV = tn / (tn + fn) if (tn + fn) > 0 else 0.0 + # Compute confusion matrix for each class + cm = multilabel_confusion_matrix(true_np, pred_np) - metrics[cls] = {"PPV": round(TPV, 4), "NPV": round(NPV, 4)} - - return metrics + results: dict[str, dict[str, float]] = {} + for idx, cls_name in enumerate(class_names): + tn, fp, fn, tp = cm[idx].ravel() + tpv = tp / (tp + fp) if (tp + fp) > 0 else 0.0 + npv = tn / (tn + fn) if (tn + fn) > 0 else 0.0 + results[cls_name] = {"PPV": round(tpv, 4), "NPV": round(npv, 4)} + return results def generate_props( self, - model_path: str, + model_ckpt_path: str, + data_cls_path: str, model_class_path: str, - splits_path: str, - data_path: str, - classes_file_path: str, - collator_class_path: str, - output_path: str, - batch_size: int = 32, + output_path: str | None = None, + data_cls_kwargs: dict | None = None, + model_load_kwargs: dict | None = None, ) -> None: """ - Main method to compute TPV/NPV from validation data and save as JSON. + Run inference on validation set, compute TPV/NPV per class, and save to JSON. Args: - model_path (str): Path to the PyTorch Lightning model checkpoint. - model_class_path (str): Full path to the model class to load. - splits_path (str): CSV file with 'id' and 'split' columns. - data_path (str): processed `data.pt` file path. - classes_file_path (str): Path to file containing class names `classes.txt`. - collator_class_path (str): Full path to the collator class. - output_path (str): Output path for the saving JSON file. - batch_size (int): Batch size for inference. + model_ckpt_path: Path to the PyTorch Lightning checkpoint file. + data_cls_path: Import path or module path to the data module class. + model_class_path: Import path or module path to the model class. + output_path: Optional path where to write the JSON metrics file. + Defaults to '/classes.json'. + data_cls_kwargs: Optional dict of kwargs to initialize the data module. + model_load_kwargs: Optional dict of kwargs when loading the model. """ print("Extracting validation data for computation...") - splits_df = pd.read_csv(splits_path) - validation_ids = set(splits_df[splits_df["split"] == "validation"]["id"]) - data_df = pd.DataFrame(torch.load(data_path, weights_only=False)) - val_df = data_df[data_df["ident"].isin(validation_ids)] - - # Load model - print(f"Loading model from {model_path} ...") - model_cls = load_class(model_class_path) - if not issubclass(model_cls, ChebaiBaseNet): - raise TypeError("Loaded model is not a valid LightningModule.") - model = model_cls.load_from_checkpoint(model_path, input_dim=3) - model.freeze() - model.eval() - - # Load collator - collator_cls = load_class(collator_class_path) - if not issubclass(collator_cls, Collator): - raise TypeError(f"{collator_cls} must be subclass of Collator") - collator = collator_cls() - - val_loader = DataLoader( - val_df.to_dict(orient="records"), - collate_fn=collator, - batch_size=batch_size, - shuffle=False, + + data_module: XYBaseDataModule = load_data_instance( + data_cls_path, data_cls_kwargs or {} + ) + model = load_model_for_inference( + model_ckpt_path, model_class_path, model_load_kwargs or {} ) + val_loader = data_module.val_dataloader() print("Running inference on validation data...") + y_true, y_pred = [], [] for batch_idx, batch in enumerate(val_loader): - data = model._process_batch(batch, batch_idx=batch_idx) + data = model._process_batch( # pylint: disable=W0212 + batch, batch_idx=batch_idx + ) labels = data["labels"] - model_output = model(data, **data.get("model_kwargs", dict())) - sigmoid_logits = torch.sigmoid(model_output["logits"]) - preds = sigmoid_logits > 0.5 + outputs = model(data, **data.get("model_kwargs", {})) + logits = outputs["logits"] + preds = torch.sigmoid(logits) > 0.5 y_pred.extend(preds) y_true.extend(labels) - # Compute and save metrics print("Computing TPV and NPV metrics...") - classes_file_path = Path(classes_file_path) + classes_file = Path(data_module.processed_dir_main) / "classes.txt" if output_path is None: - output_path = classes_file_path.parent / "classes.json" - class_names = self.load_class_labels(classes_file_path) + output_file = Path(data_module.processed_dir_main) / "classes.json" + else: + output_file = Path(output_path) + + class_names = self.load_class_labels(classes_file) metrics = self.compute_tpv_npv(y_true, y_pred, class_names) - with open(output_path, "w") as f: + + with output_file.open("w") as f: json.dump(metrics, f, indent=2) - print(f"Saved TPV/NPV metrics to {output_path}") + print(f"Saved TPV/NPV metrics to {output_file}") class Main: """ - Command-line interface wrapper for the ClassesPropertiesGenerator. + CLI wrapper for ClassesPropertiesGenerator. """ def generate( self, - model_path: str, - splits_path: str, - data_path: str, - classes_file_path: str, + model_ckpt_path: str, + data_cls_path: str, model_class_path: str, - collator_class_path: str = "chebai.preprocessing.collate.RaggedCollator", - batch_size: int = 32, - output_path: str = None, # Default path will be the directory of classes_file_path + output_path: str | None = None, + data_cls_kwargs: dict | None = None, + model_load_kwargs: dict | None = None, ) -> None: """ - Entry point for CLI use. + CLI command to generate TPV/NPV JSON. Args: - model_path (str): Path to the PyTorch Lightning model checkpoint. - model_class_path (str): Full path to the model class to load. - splits_path (str): CSV file with 'id' and 'split' columns. - data_path (str): processed `data.pt` file path. - classes_file_path (str): Path to file containing class names `classes.txt`. - collator_class_path (str): Full path to the collator class. - output_path (str): Output path for the saving JSON file. - batch_size (int): Batch size for inference. + model_ckpt_path: Path to the Lightning model checkpoint. + data_cls_path: Module path to data module class. + model_class_path: Module path to model class. + output_path: Output JSON file path (optional). + data_cls_kwargs: Kwargs for data module instantiation (optional). + model_load_kwargs: Kwargs for model loading (optional). """ generator = ClassesPropertiesGenerator() generator.generate_props( - model_path=model_path, - model_class_path=model_class_path, - splits_path=splits_path, - data_path=data_path, - classes_file_path=classes_file_path, - collator_class_path=collator_class_path, - output_path=output_path, - batch_size=batch_size, + model_ckpt_path, + data_cls_path, + model_class_path, + output_path, + data_cls_kwargs, + model_load_kwargs, ) if __name__ == "__main__": # _generate_classes_props_json.py generate \ - # --model_path "model/ckpt/path" \ - # --splits_path "splits/file/path" \ - # --data_path "data.pt/file/path" \ - # --classes_file_path "classes/file/path" \ + # --model_ckpt_path "model/ckpt/path" \ + # --data_cls_path "data.class.path" \ # --model_class_path "model.class.path" \ - # --collator_class_path "collator.class.path" \ - # --batch_size 32 \ # Optional, default is 32 - # --output_path "output/file/path" # Optional, default will be the directory of classes_file_path + # --output_path "output/file/path" # Optional + # --data_cls_kwargs "{kwargs1: 1, kwargs2: 2}" # Optional + # --model_load_kwargs "{kwargs1: 1, kwargs2: 2}" # Optional CLI(Main, as_positional=False) From e5ec383daffaf52694e5e116a6903847b079117c Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 12 Jun 2025 17:07:18 +0200 Subject: [PATCH 68/78] pass config file for model, data instead of explicit params --- chebai/ensemble/_base.py | 45 ++++------------ chebai/ensemble/_constants.py | 6 +-- .../_scripts/_generate_classes_props_json.py | 51 ++++++++----------- chebai/ensemble/_utils.py | 25 +++++++++ chebai/ensemble/_wrappers/_base.py | 1 - chebai/ensemble/_wrappers/_neural_network.py | 37 ++++++-------- 6 files changed, 75 insertions(+), 90 deletions(-) diff --git a/chebai/ensemble/_base.py b/chebai/ensemble/_base.py index cc1b7258..0eaae312 100644 --- a/chebai/ensemble/_base.py +++ b/chebai/ensemble/_base.py @@ -8,13 +8,7 @@ from chebai.result.classification import print_metrics -from ._constants import ( - EVAL_OP, - MODEL_CLS_PATH, - MODEL_LBL_PATH, - PRED_OP, - WRAPPER_CLS_PATH, -) +from ._constants import EVAL_OP, PRED_OP, WRAPPER_CLS_PATH class EnsembleBase(ABC): @@ -60,7 +54,7 @@ def __init__( self._dm_labels: Dict[str, int] = self._load_data_module_labels() self._num_of_labels: int = len(self._dm_labels) - print(f"Number of labes for this data is {self._num_of_labels} ") + print(f"Number of labels for this data is {self._num_of_labels} ") self._num_models_per_label: torch.Tensor = torch.zeros( 1, self._num_of_labels, device=self._device @@ -87,20 +81,16 @@ def _perform_validation_checks( f"Invalid operation '{operation}'. Must be 'evaluate' or 'predict'." ) - if operation == "predict" and not kwargs.get("smiles_list_file_path", None): - raise ValueError( - "For 'predict' operation, 'smiles_list_file_path' must be provided." - ) + if operation == "predict": + if kwargs.get("smiles_list_file_path", None): + raise ValueError( + "For 'predict' operation, 'smiles_list_file_path' must be provided." + ) - if not Path(kwargs.get("smiles_list_file_path")).exists(): - raise FileNotFoundError(f"{kwargs.get('smiles_list_file_path')}") + if not Path(kwargs.get("smiles_list_file_path")).exists(): + raise FileNotFoundError(f"{kwargs.get('smiles_list_file_path')}") - class_set, labels_set = set(), set() - required_keys = { - MODEL_CLS_PATH, - MODEL_LBL_PATH, - WRAPPER_CLS_PATH, - } + required_keys = {WRAPPER_CLS_PATH} for model_name, config in model_configs.items(): missing_keys = required_keys - config.keys() @@ -109,21 +99,6 @@ def _perform_validation_checks( f"Missing keys {missing_keys} in model '{model_name}' configuration." ) - model_class_path, model_labels_path = ( - config[MODEL_CLS_PATH], - config[MODEL_LBL_PATH], - ) - - if model_class_path in class_set: - raise ValueError( - f"Duplicate class path detected: '{model_class_path}'." - ) - if model_labels_path in labels_set: - raise ValueError(f"Duplicate labels path: {model_labels_path}.") - - class_set.add(model_class_path) - labels_set.add(model_labels_path) - def _process_input_to_ensemble(self, **kwargs: Any) -> list[str] | Path: if self._operation_mode == PRED_OP: p = Path(kwargs["smiles_list_file_path"]) diff --git a/chebai/ensemble/_constants.py b/chebai/ensemble/_constants.py index 04426baf..3c0459de 100644 --- a/chebai/ensemble/_constants.py +++ b/chebai/ensemble/_constants.py @@ -1,12 +1,10 @@ -MODEL_CLS_PATH = "model_class_path" -MODEL_LD_KWARGS = "model_load_kwargs" MODEL_LBL_PATH = "model_labels_path" MODEL_CKPT_PATH = "model_ckpt_path" WRAPPER_CLS_PATH = "wrapper_class_path" -DATA_CLS_PATH = "data_class_path" -DATA_CLS_KWARGS = "data_class_kwargs" +DATA_CONFIG_PATH = "data_config_file_path" +MODEL_CONFIG_PATH = "model_config_file_path" PRED_OP = "prediction" EVAL_OP = "evaluation" diff --git a/chebai/ensemble/_scripts/_generate_classes_props_json.py b/chebai/ensemble/_scripts/_generate_classes_props_json.py index 519628e9..91d50725 100644 --- a/chebai/ensemble/_scripts/_generate_classes_props_json.py +++ b/chebai/ensemble/_scripts/_generate_classes_props_json.py @@ -5,9 +5,10 @@ from jsonargparse import CLI from sklearn.metrics import multilabel_confusion_matrix -from chebai.ensemble._utils import load_data_instance, load_model_for_inference from chebai.preprocessing.datasets.base import XYBaseDataModule +from .._utils import load_data_instance, load_model_for_inference, parse_config_file + class ClassesPropertiesGenerator: """ @@ -69,31 +70,30 @@ def compute_tpv_npv( def generate_props( self, model_ckpt_path: str, - data_cls_path: str, - model_class_path: str, + model_config_file_path: str, + data_config_file_path: str, output_path: str | None = None, - data_cls_kwargs: dict | None = None, - model_load_kwargs: dict | None = None, ) -> None: """ Run inference on validation set, compute TPV/NPV per class, and save to JSON. Args: model_ckpt_path: Path to the PyTorch Lightning checkpoint file. - data_cls_path: Import path or module path to the data module class. - model_class_path: Import path or module path to the model class. + model_config_file_path: Path to yaml config file of the model. + data_config_file_path: Path to yaml config file of the data. output_path: Optional path where to write the JSON metrics file. Defaults to '/classes.json'. - data_cls_kwargs: Optional dict of kwargs to initialize the data module. - model_load_kwargs: Optional dict of kwargs when loading the model. """ print("Extracting validation data for computation...") + data_cls_path, data_cls_kwargs = parse_config_file(data_config_file_path) data_module: XYBaseDataModule = load_data_instance( - data_cls_path, data_cls_kwargs or {} + data_cls_path, data_cls_kwargs ) + + model_class_path, model_kwargs = parse_config_file(model_config_file_path) model = load_model_for_inference( - model_ckpt_path, model_class_path, model_load_kwargs or {} + model_ckpt_path, model_class_path, model_kwargs ) val_loader = data_module.val_dataloader() @@ -134,40 +134,33 @@ class Main: def generate( self, model_ckpt_path: str, - data_cls_path: str, - model_class_path: str, + model_config_file_path: str, + data_config_file_path: str, output_path: str | None = None, - data_cls_kwargs: dict | None = None, - model_load_kwargs: dict | None = None, ) -> None: """ CLI command to generate TPV/NPV JSON. Args: - model_ckpt_path: Path to the Lightning model checkpoint. - data_cls_path: Module path to data module class. - model_class_path: Module path to model class. - output_path: Output JSON file path (optional). - data_cls_kwargs: Kwargs for data module instantiation (optional). - model_load_kwargs: Kwargs for model loading (optional). + model_ckpt_path: Path to the PyTorch Lightning checkpoint file. + model_config_file_path: Path to yaml config file of the model. + data_config_file_path: Path to yaml config file of the data. + output_path: Optional path where to write the JSON metrics file. + Defaults to '/classes.json'. """ generator = ClassesPropertiesGenerator() generator.generate_props( model_ckpt_path, - data_cls_path, - model_class_path, + model_config_file_path, + data_config_file_path, output_path, - data_cls_kwargs, - model_load_kwargs, ) if __name__ == "__main__": # _generate_classes_props_json.py generate \ # --model_ckpt_path "model/ckpt/path" \ - # --data_cls_path "data.class.path" \ - # --model_class_path "model.class.path" \ + # --model_config_file_path "model/config/file/path" \ + # --data_config_file_path "data/config/file/path" \ # --output_path "output/file/path" # Optional - # --data_cls_kwargs "{kwargs1: 1, kwargs2: 2}" # Optional - # --model_load_kwargs "{kwargs1: 1, kwargs2: 2}" # Optional CLI(Main, as_positional=False) diff --git a/chebai/ensemble/_utils.py b/chebai/ensemble/_utils.py index 0e5238a0..a54d2fd8 100644 --- a/chebai/ensemble/_utils.py +++ b/chebai/ensemble/_utils.py @@ -1,6 +1,8 @@ import importlib from pathlib import Path +import yaml + from chebai.models.base import ChebaiBaseNet from chebai.preprocessing.datasets.base import XYBaseDataModule @@ -61,3 +63,26 @@ def load_model_for_inference( model.eval() model.freeze() return model + + +def parse_config_file(config_path: str) -> tuple[str, dict]: + path = Path(config_path) + + # Check file existence + if not path.exists(): + raise FileNotFoundError(f"Config file not found: {config_path}") + + # Check file extension + if path.suffix.lower() not in [".yml", ".yaml"]: + raise ValueError( + f"Unsupported config file type: {path.suffix}. Expected .yaml or .yml" + ) + + # Load YAML content + with open(path, "r") as f: + config: dict = yaml.safe_load(f) + + class_path: str = config["class_path"] + init_args: dict = config.get("init_args", {}) + assert isinstance(init_args, dict), "init_args must be a dictionary" + return class_path, init_args diff --git a/chebai/ensemble/_wrappers/_base.py b/chebai/ensemble/_wrappers/_base.py index 56bc54ed..31b7883b 100644 --- a/chebai/ensemble/_wrappers/_base.py +++ b/chebai/ensemble/_wrappers/_base.py @@ -1,7 +1,6 @@ import json import os from abc import ABC, abstractmethod -from pathlib import Path import torch diff --git a/chebai/ensemble/_wrappers/_neural_network.py b/chebai/ensemble/_wrappers/_neural_network.py index 7b3dcd6c..cf231ed4 100644 --- a/chebai/ensemble/_wrappers/_neural_network.py +++ b/chebai/ensemble/_wrappers/_neural_network.py @@ -6,14 +6,8 @@ from chebai.preprocessing.datasets.base import XYBaseDataModule from chebai.preprocessing.structures import XYData -from .._constants import ( - DATA_CLS_KWARGS, - DATA_CLS_PATH, - MODEL_CKPT_PATH, - MODEL_CLS_PATH, - MODEL_LD_KWARGS, -) -from .._utils import load_class, load_data_instance, load_model_for_inference +from .._constants import DATA_CONFIG_PATH, MODEL_CKPT_PATH, MODEL_CONFIG_PATH +from .._utils import load_data_instance, load_model_for_inference, parse_config_file from ._base import BaseWrapper @@ -25,24 +19,25 @@ def __init__(self, **kwargs): ) super().__init__(**kwargs) - self._model_class_path = self._model_config[MODEL_CLS_PATH] self._model_ckpt_path = self._model_config[MODEL_CKPT_PATH] - self._model_ld_kwargs: dict = self._model_config.get(MODEL_LD_KWARGS, {}) - - self._pre_load_hook() + self._model_config_path = self._model_config[MODEL_CONFIG_PATH] + self._data_config_path = self._model_config[DATA_CONFIG_PATH] + data_cls_path, data_kwargs = parse_config_file(self._data_config_path) self._data_cls_instance: XYBaseDataModule = load_data_instance( - self._model_config[DATA_CLS_PATH], - self._model_config.get(DATA_CLS_KWARGS, {}), + data_cls_path, data_kwargs ) - self.collated_labels = None + + model_cls_path, model_kwargs = parse_config_file(self._model_config_path) self._model: ChebaiBaseNet = load_model_for_inference( - self._model_config[MODEL_CKPT_PATH], - self._model_config[MODEL_CLS_PATH], - self._model_config.get(MODEL_LD_KWARGS, {}), - **kwargs, + self._model_ckpt_path, + model_cls_path, + model_kwargs, + model_name=kwargs["model_name"], ) + self.collated_labels = None + @classmethod def _validate_model_configs( cls, @@ -59,7 +54,7 @@ def _validate_model_configs( AttributeError: If any model config is missing required keys. ValueError: If duplicate paths are found for model checkpoint, class, or labels. """ - required_keys = {MODEL_CKPT_PATH, DATA_CLS_PATH, MODEL_CLS_PATH} + required_keys = {MODEL_CKPT_PATH, MODEL_CONFIG_PATH, DATA_CONFIG_PATH} missing_keys = required_keys - model_config.keys() if missing_keys: @@ -112,7 +107,7 @@ def _forward_pass(self, batch: list[dict]) -> dict: ) return self._model(processable_data, **processable_data["model_kwargs"]) - def _evaluate_from_data_file(self) -> list: + def _evaluate_from_data_file(self, **kwargs) -> list: filename = self._data_cls_instance.processed_file_names_dict["data"] data_list_of_dict = self._data_cls_instance.load_processed_data_from_file( Path(self._data_cls_instance.processed_dir) / filename From 9a3328f17f093586eefb6ca75f41b46b15313f37 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 12 Jun 2025 17:20:21 +0200 Subject: [PATCH 69/78] use utility for scripts --- chebai/ensemble/_scripts/_ensemble_run_script.py | 10 ++-------- .../ensemble/_scripts/_generate_classes_props_json.py | 7 +++++-- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/chebai/ensemble/_scripts/_ensemble_run_script.py b/chebai/ensemble/_scripts/_ensemble_run_script.py index 3dc6bd45..26d003c4 100644 --- a/chebai/ensemble/_scripts/_ensemble_run_script.py +++ b/chebai/ensemble/_scripts/_ensemble_run_script.py @@ -1,10 +1,7 @@ -from typing import Any, Dict - -import yaml from jsonargparse import ArgumentParser from chebai.ensemble._base import EnsembleBase -from chebai.ensemble._utils import load_class +from chebai.ensemble._utils import load_class, parse_config_file def load_config_and_instantiate(config_path: str) -> EnsembleBase: @@ -20,11 +17,8 @@ def load_config_and_instantiate(config_path: str) -> EnsembleBase: Raises: TypeError: If the loaded class is not a subclass of EnsembleBase. """ - with open(config_path, "r") as f: - config: Dict[str, Any] = yaml.safe_load(f) - class_path: str = config["class_path"] - init_args: Dict[str, Any] = config.get("init_args", {}) + class_path, init_args = parse_config_file(config_path) cls = load_class(class_path) diff --git a/chebai/ensemble/_scripts/_generate_classes_props_json.py b/chebai/ensemble/_scripts/_generate_classes_props_json.py index 91d50725..a1fa30d2 100644 --- a/chebai/ensemble/_scripts/_generate_classes_props_json.py +++ b/chebai/ensemble/_scripts/_generate_classes_props_json.py @@ -5,10 +5,13 @@ from jsonargparse import CLI from sklearn.metrics import multilabel_confusion_matrix +from chebai.ensemble._utils import ( + load_data_instance, + load_model_for_inference, + parse_config_file, +) from chebai.preprocessing.datasets.base import XYBaseDataModule -from .._utils import load_data_instance, load_model_for_inference, parse_config_file - class ClassesPropertiesGenerator: """ From f40eff901ceb0da1bc6b71cbf8af942534dac872 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 12 Jun 2025 19:29:57 +0200 Subject: [PATCH 70/78] dm should have splits_file_path or splits.csv in its dir --- .../_scripts/_generate_classes_props_json.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/chebai/ensemble/_scripts/_generate_classes_props_json.py b/chebai/ensemble/_scripts/_generate_classes_props_json.py index a1fa30d2..b16fc708 100644 --- a/chebai/ensemble/_scripts/_generate_classes_props_json.py +++ b/chebai/ensemble/_scripts/_generate_classes_props_json.py @@ -94,6 +94,21 @@ def generate_props( data_cls_path, data_cls_kwargs ) + splits_file_path = Path(data_module.processed_dir_main, "splits.csv") + if data_module.splits_file_path is None: + if not splits_file_path.exists(): + raise RuntimeError( + "Either the data module should be initialized with a `splits_file_path`, " + f"or the file `{splits_file_path}` must exists.\n" + "This is to prevent the data module from dynamically generating the splits." + ) + + print( + f"`splits_file_path` is not provided as an initialization parameter to the data module\n" + f"Using splits from the file {splits_file_path}" + ) + data_module.splits_file_path = splits_file_path + model_class_path, model_kwargs = parse_config_file(model_config_file_path) model = load_model_for_inference( model_ckpt_path, model_class_path, model_kwargs From e89ec4f1e9f8a83093b32450887c2b368178abc4 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 15 Jun 2025 18:06:19 +0200 Subject: [PATCH 71/78] fix gnn logits error --- chebai/ensemble/__init__.py | 4 ++-- chebai/ensemble/_controller.py | 3 +++ chebai/ensemble/_utils.py | 7 +------ chebai/ensemble/_wrappers/__init__.py | 3 ++- chebai/ensemble/_wrappers/_gnn.py | 30 ++++++--------------------- 5 files changed, 14 insertions(+), 33 deletions(-) diff --git a/chebai/ensemble/__init__.py b/chebai/ensemble/__init__.py index e6f227de..2d75f066 100644 --- a/chebai/ensemble/__init__.py +++ b/chebai/ensemble/__init__.py @@ -1,6 +1,6 @@ from ._consolidator import WeightedMajorityVoting from ._controller import NoActivationCondition -from ._wrappers import NNWrapper +from ._wrappers import GNNWrapper, NNWrapper class FullEnsembleWMV(NoActivationCondition, WeightedMajorityVoting): @@ -9,4 +9,4 @@ class FullEnsembleWMV(NoActivationCondition, WeightedMajorityVoting): pass -__all__ = ["FullEnsembleWMV", "NNWrapper"] +__all__ = ["FullEnsembleWMV", "NNWrapper", "GNNWrapper"] diff --git a/chebai/ensemble/_controller.py b/chebai/ensemble/_controller.py index 9509f1b5..35629c99 100644 --- a/chebai/ensemble/_controller.py +++ b/chebai/ensemble/_controller.py @@ -62,6 +62,9 @@ def _controller( ): self._collated_labels = wrapped_model.collated_labels + assert ( + isinstance(model_output, dict) and "logits" in model_output + ), "Forward pass should return dict containing logits for consistency across all types of models" del wrapped_model # Model can be huge to keep it in memory, delete asap as no longer needed pred_conf_dict = self._get_pred_conf_from_model_output( diff --git a/chebai/ensemble/_utils.py b/chebai/ensemble/_utils.py index a54d2fd8..6f38891f 100644 --- a/chebai/ensemble/_utils.py +++ b/chebai/ensemble/_utils.py @@ -29,9 +29,6 @@ def load_model_for_inference( """ Loads a model checkpoint and its label-related properties. - Args: - input_dim (int): Name of the model to load. - Returns: Tuple[LightningModule, Dict[str, torch.Tensor]]: The model and its label properties. """ @@ -51,9 +48,7 @@ def load_model_for_inference( lightning_cls, ChebaiBaseNet ), f"{lightning_cls} must inherit from ChebaiBaseNet" try: - model = lightning_cls.load_from_checkpoint( - model_ckpt_path, input_dim=5, **model_load_kwargs - ) + model = lightning_cls.load_from_checkpoint(model_ckpt_path, **model_load_kwargs) except Exception as e: raise RuntimeError(f"Error loading model {model_name} \n Error: {e}") from e diff --git a/chebai/ensemble/_wrappers/__init__.py b/chebai/ensemble/_wrappers/__init__.py index 4c4bac6d..5478d32b 100644 --- a/chebai/ensemble/_wrappers/__init__.py +++ b/chebai/ensemble/_wrappers/__init__.py @@ -1,4 +1,5 @@ from ._base import BaseWrapper +from ._gnn import GNNWrapper from ._neural_network import NNWrapper -__all__ = ["NNWrapper", "BaseWrapper"] +__all__ = ["NNWrapper", "BaseWrapper", "GNNWrapper"] diff --git a/chebai/ensemble/_wrappers/_gnn.py b/chebai/ensemble/_wrappers/_gnn.py index c01a2ffd..9b81f77a 100644 --- a/chebai/ensemble/_wrappers/_gnn.py +++ b/chebai/ensemble/_wrappers/_gnn.py @@ -1,35 +1,12 @@ -from pathlib import Path - import chebai_graph.preprocessing.properties as p import torch from chebai_graph.preprocessing.property_encoder import IndexEncoder, OneHotEncoder from torch_geometric.data.data import Data as GeomData -from .._constants import DATA_CLS_KWARGS from ._neural_network import NNWrapper -class GNNResGated(NNWrapper): - - def _pre_load_hook(self): - self._model_config[DATA_CLS_KWARGS] = self._model_config.get( - DATA_CLS_KWARGS, - { - "properties": [ - p.AtomType(), - p.NumAtomBonds(), - p.AtomCharge(), - p.AtomAromaticity(), - p.AtomHybridization(), - p.AtomNumHs(), - p.BondType(), - p.BondInRing(), - p.BondAromaticity(), - p.RDKit2DNormalized(), - ] - }, - ) - +class GNNWrapper(NNWrapper): def _read_smiles(self, smiles): d = self._data_cls_instance.reader.to_data(dict(features=smiles, labels=None)) geom_data = d["features"] @@ -94,3 +71,8 @@ def _read_smiles(self, smiles): molecule_attr=molecule_attr, ) return d + + def _evaluate_from_data_file(self, **kwargs) -> list: + model_logits = super()._evaluate_from_data_file(**kwargs) + # Currently gnn in forward method, logits are returned instead of dict containing logits + return {"logits": model_logits} From 8d406370b702d48345a13ee9285bfdfa0843e2b5 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 15 Jun 2025 22:46:12 +0200 Subject: [PATCH 72/78] fix gnn predict_from smiles list logits error --- chebai/ensemble/_wrappers/_gnn.py | 13 +++++-- chebai/ensemble/_wrappers/_neural_network.py | 36 +++++++++----------- 2 files changed, 27 insertions(+), 22 deletions(-) diff --git a/chebai/ensemble/_wrappers/_gnn.py b/chebai/ensemble/_wrappers/_gnn.py index 9b81f77a..b0b35f65 100644 --- a/chebai/ensemble/_wrappers/_gnn.py +++ b/chebai/ensemble/_wrappers/_gnn.py @@ -25,8 +25,7 @@ def _read_smiles(self, smiles): if isinstance(property.encoder, IndexEncoder): if str(value) in property.encoder.cache: index = ( - property.encoder.cache.index(str(value)) - + property.encoder.offset + property.encoder.cache[str(value)] + property.encoder.offset ) else: index = 0 @@ -60,7 +59,11 @@ def _read_smiles(self, smiles): if isinstance(property, p.AtomProperty): x = torch.cat([x, encoded_values], dim=1) elif isinstance(property, p.BondProperty): - edge_attr = torch.cat([edge_attr, encoded_values], dim=1) + # Concat/Duplicate properties values for undirected graph as `edge_index` has first src to tgt edges, then tgt to src edges + edge_attr = torch.cat( + [edge_attr, torch.cat([encoded_values, encoded_values], dim=0)], + dim=1, + ) else: molecule_attr = torch.cat([molecule_attr, encoded_values[0]], dim=1) @@ -76,3 +79,7 @@ def _evaluate_from_data_file(self, **kwargs) -> list: model_logits = super()._evaluate_from_data_file(**kwargs) # Currently gnn in forward method, logits are returned instead of dict containing logits return {"logits": model_logits} + + def _predict_from_list_of_smiles(self, smiles_list: list[str]) -> dict: + model_logits = super()._predict_from_list_of_smiles(smiles_list) + return {"logits": model_logits} diff --git a/chebai/ensemble/_wrappers/_neural_network.py b/chebai/ensemble/_wrappers/_neural_network.py index cf231ed4..3ab5702d 100644 --- a/chebai/ensemble/_wrappers/_neural_network.py +++ b/chebai/ensemble/_wrappers/_neural_network.py @@ -62,35 +62,33 @@ def _validate_model_configs( f"Missing keys {missing_keys} in model '{model_name}' configuration." ) - def _pre_load_hook(self) -> None: ... - - def _predict_from_list_of_smiles(self, smiles_list: list[str]) -> list: + def _predict_from_list_of_smiles(self, smiles_list: list[str]) -> dict: token_dicts = [] could_not_parse = [] + for i, smiles in enumerate(smiles_list): try: - # Try to parse the smiles string if not smiles: - raise ValueError() + raise ValueError("Empty SMILES string") d = self._read_smiles(smiles) - # This is just for sanity checks rdmol = Chem.MolFromSmiles(smiles, sanitize=False) + if rdmol is None: + raise ValueError("RDKit failed to parse") except Exception as e: - # Note if it fails could_not_parse.append(i) - print(f"Failing to parse {smiles} due to {e}") + print(f"Failing to parse '{smiles}' (index {i}): {e}") else: - if rdmol is None: - could_not_parse.append(i) - else: - token_dicts.append(d) - if token_dicts: - model_output = self._forward_pass(token_dicts) - if not isinstance(model_output, dict) and not "logits" in model_output: - raise ValueError() - return model_output - else: - raise ValueError() + token_dicts.append(d) + + if not token_dicts: + raise ValueError("No valid SMILES could be parsed.") + + model_output = self._forward_pass(token_dicts) + # ----- This check is handled in controller + # if not isinstance(model_output, dict) or "logits" not in model_output: + # raise ValueError("Model output is malformed; expected dict with 'logits' key.") + + return model_output def _read_smiles(self, smiles: str) -> dict: return self._data_cls_instance.reader.to_data( From 3976531a058a9d76aea72dcf3170bbb667df8217 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 15 Jun 2025 23:49:20 +0200 Subject: [PATCH 73/78] chemlog wrapper return logits --- chebai/ensemble/__init__.py | 4 ++-- chebai/ensemble/_base.py | 2 +- chebai/ensemble/_controller.py | 4 +++- chebai/ensemble/_wrappers/__init__.py | 3 ++- chebai/ensemble/_wrappers/_chemlog.py | 13 ++++++------- 5 files changed, 14 insertions(+), 12 deletions(-) diff --git a/chebai/ensemble/__init__.py b/chebai/ensemble/__init__.py index 2d75f066..b64a80a8 100644 --- a/chebai/ensemble/__init__.py +++ b/chebai/ensemble/__init__.py @@ -1,6 +1,6 @@ from ._consolidator import WeightedMajorityVoting from ._controller import NoActivationCondition -from ._wrappers import GNNWrapper, NNWrapper +from ._wrappers import ChemLogWrapper, GNNWrapper, NNWrapper class FullEnsembleWMV(NoActivationCondition, WeightedMajorityVoting): @@ -9,4 +9,4 @@ class FullEnsembleWMV(NoActivationCondition, WeightedMajorityVoting): pass -__all__ = ["FullEnsembleWMV", "NNWrapper", "GNNWrapper"] +__all__ = ["FullEnsembleWMV", "NNWrapper", "GNNWrapper", "ChemLogWrapper"] diff --git a/chebai/ensemble/_base.py b/chebai/ensemble/_base.py index 0eaae312..3c316a79 100644 --- a/chebai/ensemble/_base.py +++ b/chebai/ensemble/_base.py @@ -120,7 +120,7 @@ def _process_input_to_ensemble(self, **kwargs: Any) -> list[str] | Path: f"data.pkl does not exist in the {processed_dir_path} directory" ) self._total_data_size = len(pd.read_pickle(data_pkl_path)) - return processed_dir_path + return data_pkl_path else: raise ValueError("Invalid operation") diff --git a/chebai/ensemble/_controller.py b/chebai/ensemble/_controller.py index 35629c99..2d4dad64 100644 --- a/chebai/ensemble/_controller.py +++ b/chebai/ensemble/_controller.py @@ -55,7 +55,9 @@ def _controller( if self._operation_mode == PRED_OP: model_output, model_props = wrapped_model.predict(model_input) else: - model_output, model_props = wrapped_model.evaluate() + model_output, model_props = wrapped_model.evaluate( + data_pkl_file_path=model_input + ) if ( self._collated_labels is None and wrapped_model.collated_labels is not None diff --git a/chebai/ensemble/_wrappers/__init__.py b/chebai/ensemble/_wrappers/__init__.py index 5478d32b..e68fec43 100644 --- a/chebai/ensemble/_wrappers/__init__.py +++ b/chebai/ensemble/_wrappers/__init__.py @@ -1,5 +1,6 @@ from ._base import BaseWrapper +from ._chemlog import ChemLogWrapper from ._gnn import GNNWrapper from ._neural_network import NNWrapper -__all__ = ["NNWrapper", "BaseWrapper", "GNNWrapper"] +__all__ = ["NNWrapper", "BaseWrapper", "GNNWrapper", "ChemLogWrapper"] diff --git a/chebai/ensemble/_wrappers/_chemlog.py b/chebai/ensemble/_wrappers/_chemlog.py index c4e43da0..ed25e61d 100644 --- a/chebai/ensemble/_wrappers/_chemlog.py +++ b/chebai/ensemble/_wrappers/_chemlog.py @@ -16,17 +16,16 @@ from chebai.ensemble._wrappers._base import BaseWrapper -class ChemLog(BaseWrapper): +class ChemLogWrapper(BaseWrapper): def _predict_from_list_of_smiles(self, smiles_list): - return self.get_chemlog_results(smiles_list) + return {"logits": self.get_chemlog_results(smiles_list)} - def _evaluate_from_data_file( - self, data_processed_dir_main: Path, data_file_name="data.pkl" - ) -> list: - data_df = pd.read_pickle(data_processed_dir_main / data_file_name) + def _evaluate_from_data_file(self, **kwargs) -> list: + data_pkl_file_path = kwargs["data_pkl_file_path"] + data_df = pd.read_pickle(data_pkl_file_path) smiles_list = data_df["SMILES"].to_list() - return self.get_chemlog_results(smiles_list) + return {"logits": self.get_chemlog_results(smiles_list)} def get_chemlog_results(self, smiles_list) -> list: all_preds = [] From e4e9a28ef06e25ffe6d9c60fc8826a32bed927d0 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Tue, 24 Jun 2025 10:36:32 +0200 Subject: [PATCH 74/78] save tp, fp, fn and tn as model properties --- .../_scripts/_generate_classes_props_json.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/chebai/ensemble/_scripts/_generate_classes_props_json.py b/chebai/ensemble/_scripts/_generate_classes_props_json.py index b16fc708..9aa7c6c3 100644 --- a/chebai/ensemble/_scripts/_generate_classes_props_json.py +++ b/chebai/ensemble/_scripts/_generate_classes_props_json.py @@ -67,7 +67,14 @@ def compute_tpv_npv( tn, fp, fn, tp = cm[idx].ravel() tpv = tp / (tp + fp) if (tp + fp) > 0 else 0.0 npv = tn / (tn + fn) if (tn + fn) > 0 else 0.0 - results[cls_name] = {"PPV": round(tpv, 4), "NPV": round(npv, 4)} + results[cls_name] = { + "PPV": round(tpv, 4), + "NPV": round(npv, 4), + "TN": int(tn), + "FP": int(fp), + "FN": int(fn), + "TP": int(tp), + } return results def generate_props( @@ -78,7 +85,7 @@ def generate_props( output_path: str | None = None, ) -> None: """ - Run inference on validation set, compute TPV/NPV per class, and save to JSON. + Run inference on validation set, compute TPV/NPV per class, and save to JSON. Args: model_ckpt_path: Path to the PyTorch Lightning checkpoint file. @@ -124,7 +131,7 @@ def generate_props( ) labels = data["labels"] outputs = model(data, **data.get("model_kwargs", {})) - logits = outputs["logits"] + logits = outputs["logits"] if isinstance(outputs, dict) else outputs preds = torch.sigmoid(logits) > 0.5 y_pred.extend(preds) y_true.extend(labels) From 41dd1c6c832e83d12d2792b8c667a4ecd81d821a Mon Sep 17 00:00:00 2001 From: sfluegel Date: Tue, 24 Jun 2025 14:46:24 +0200 Subject: [PATCH 75/78] move ensemble to chebifier repo, move property calculation and utils to results --- chebai/ensemble/__init__.py | 12 - chebai/ensemble/_base.py | 256 ------------------ chebai/ensemble/_consolidator.py | 111 -------- chebai/ensemble/_constants.py | 10 - chebai/ensemble/_controller.py | 132 --------- .../ensemble/_scripts/_ensemble_run_script.py | 50 ---- chebai/ensemble/_utils.py | 83 ------ chebai/ensemble/_wrappers/__init__.py | 6 - chebai/ensemble/_wrappers/_base.py | 151 ----------- chebai/ensemble/_wrappers/_chemlog.py | 91 ------- chebai/ensemble/_wrappers/_gnn.py | 85 ------ chebai/ensemble/_wrappers/_neural_network.py | 113 -------- .../_generate_classes_props_json.py | 4 +- chebai/result/utils.py | 76 ++++++ configs/ensemble/fullEnsembleWithWMV.yaml | 18 -- 15 files changed, 78 insertions(+), 1120 deletions(-) delete mode 100644 chebai/ensemble/__init__.py delete mode 100644 chebai/ensemble/_base.py delete mode 100644 chebai/ensemble/_consolidator.py delete mode 100644 chebai/ensemble/_constants.py delete mode 100644 chebai/ensemble/_controller.py delete mode 100644 chebai/ensemble/_scripts/_ensemble_run_script.py delete mode 100644 chebai/ensemble/_utils.py delete mode 100644 chebai/ensemble/_wrappers/__init__.py delete mode 100644 chebai/ensemble/_wrappers/_base.py delete mode 100644 chebai/ensemble/_wrappers/_chemlog.py delete mode 100644 chebai/ensemble/_wrappers/_gnn.py delete mode 100644 chebai/ensemble/_wrappers/_neural_network.py rename chebai/{ensemble/_scripts => result}/_generate_classes_props_json.py (99%) delete mode 100644 configs/ensemble/fullEnsembleWithWMV.yaml diff --git a/chebai/ensemble/__init__.py b/chebai/ensemble/__init__.py deleted file mode 100644 index b64a80a8..00000000 --- a/chebai/ensemble/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -from ._consolidator import WeightedMajorityVoting -from ._controller import NoActivationCondition -from ._wrappers import ChemLogWrapper, GNNWrapper, NNWrapper - - -class FullEnsembleWMV(NoActivationCondition, WeightedMajorityVoting): - """Full Ensemble (no activation condition) with Weighted Majority Voting""" - - pass - - -__all__ = ["FullEnsembleWMV", "NNWrapper", "GNNWrapper", "ChemLogWrapper"] diff --git a/chebai/ensemble/_base.py b/chebai/ensemble/_base.py deleted file mode 100644 index 3c316a79..00000000 --- a/chebai/ensemble/_base.py +++ /dev/null @@ -1,256 +0,0 @@ -from abc import ABC, abstractmethod -from collections import deque -from pathlib import Path -from typing import Any, Deque, Dict - -import pandas as pd -import torch - -from chebai.result.classification import print_metrics - -from ._constants import EVAL_OP, PRED_OP, WRAPPER_CLS_PATH - - -class EnsembleBase(ABC): - """ - Base class for ensemble models in the Chebai framework. - - Handles loading, validating, and coordinating multiple models for ensemble prediction. - """ - - def __init__( - self, - model_configs: Dict[str, Dict[str, Any]], - data_processed_dir_main: str, - operation_mode: str = EVAL_OP, - **kwargs: Any, - ) -> None: - """ - Initializes the ensemble model and loads configurations, labels, and sets up the environment. - - Args: - model_configs (Dict[str, Dict[str, Any]]): Dictionary of model configurations. - data_processed_dir_main (str): Path to the processed data directory. - **kwargs (Any): Additional arguments, such as 'input_dim' and '_validate_configs'. - """ - if bool(kwargs.get("_perform_validation_checks", True)): - self._perform_validation_checks( - model_configs, operation=operation_mode, **kwargs - ) - - self._model_configs: Dict[str, Dict[str, Any]] = model_configs - self._data_processed_dir_main: str = data_processed_dir_main - self._operation_mode: str = operation_mode - print(f"Ensemble operation: {self._operation_mode}") - - # These instance variable will be set in method `_process_input_to_ensemble` - self._total_data_size: int | None = None - self._ensemble_input: list[str] | Path = self._process_input_to_ensemble( - **kwargs - ) - print(f"Total data size (data.pkl) is {self._total_data_size}") - - self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - self._dm_labels: Dict[str, int] = self._load_data_module_labels() - self._num_of_labels: int = len(self._dm_labels) - print(f"Number of labels for this data is {self._num_of_labels} ") - - self._num_models_per_label: torch.Tensor = torch.zeros( - 1, self._num_of_labels, device=self._device - ) - self._model_queue: Deque[str] = deque() - self._collated_labels: torch.Tensor | None = None - - @classmethod - def _perform_validation_checks( - cls, model_configs: Dict[str, Dict[str, Any]], operation, **kwargs - ) -> None: - """ - Validates model configuration dictionary for required keys and uniqueness. - - Args: - model_configs (Dict[str, Dict[str, Any]]): Model configuration dictionary. - - Raises: - AttributeError: If any model config is missing required keys. - ValueError: If duplicate paths are found for model checkpoint, class, or labels. - """ - if operation not in ["evaluate", "predict"]: - raise ValueError( - f"Invalid operation '{operation}'. Must be 'evaluate' or 'predict'." - ) - - if operation == "predict": - if kwargs.get("smiles_list_file_path", None): - raise ValueError( - "For 'predict' operation, 'smiles_list_file_path' must be provided." - ) - - if not Path(kwargs.get("smiles_list_file_path")).exists(): - raise FileNotFoundError(f"{kwargs.get('smiles_list_file_path')}") - - required_keys = {WRAPPER_CLS_PATH} - - for model_name, config in model_configs.items(): - missing_keys = required_keys - config.keys() - if missing_keys: - raise AttributeError( - f"Missing keys {missing_keys} in model '{model_name}' configuration." - ) - - def _process_input_to_ensemble(self, **kwargs: Any) -> list[str] | Path: - if self._operation_mode == PRED_OP: - p = Path(kwargs["smiles_list_file_path"]) - smiles_list: list[str] = [] - with open(p, "r") as f: - for line in f: - # Skip empty or whitespace-only lines - if line.strip(): - # Split on whitespace and take the first item as the SMILES - smiles = line.strip().split()[0] - smiles_list.append(smiles) - self._total_data_size = len(smiles_list) - return smiles_list - elif self._operation_mode == EVAL_OP: - processed_dir_path = Path(self._data_processed_dir_main) - data_pkl_path = processed_dir_path / "data.pkl" - if not data_pkl_path.exists(): - raise FileNotFoundError( - f"data.pkl does not exist in the {processed_dir_path} directory" - ) - self._total_data_size = len(pd.read_pickle(data_pkl_path)) - return data_pkl_path - else: - raise ValueError("Invalid operation") - - def _load_data_module_labels(self) -> dict[str, int]: - """ - Loads class labels from the classes.txt file and sets internal label mapping. - - Raises: - FileNotFoundError: If the expected classes.txt file is not found. - """ - classes_file_path = Path(self._data_processed_dir_main) / "classes.txt" - if not classes_file_path.exists(): - raise FileNotFoundError(f"{classes_file_path} does not exist") - print(f"Loading {classes_file_path} ....") - - dm_labels_dict = {} - with open(classes_file_path, "r") as f: - for line in f: - label = line.strip() - if label not in dm_labels_dict: - dm_labels_dict[label] = len(dm_labels_dict) - return dm_labels_dict - - def run_ensemble(self) -> None: - """ - Executes the full ensemble prediction pipeline, aggregating predictions and printing metrics. - """ - assert self._total_data_size is not None and self._num_of_labels is not None - true_scores = torch.zeros( - self._total_data_size, self._num_of_labels, device=self._device - ) - false_scores = torch.zeros( - self._total_data_size, self._num_of_labels, device=self._device - ) - - print( - f"Running {self.__class__.__name__} ensemble for {self._operation_mode} operation..." - ) - while self._model_queue: - model_name = self._model_queue.popleft() - print(f"Processing model: {model_name}") - - print("\t Passing model to controller to generate predictions...") - controller_output = self._controller(model_name, self._ensemble_input) - - print("\t Passing predictions to consolidator for aggregation...") - self._consolidator( - pred_conf_dict=controller_output["pred_conf_dict"], - model_props=controller_output["model_props"], - true_scores=true_scores, - false_scores=false_scores, - ) - - final_preds = self._consolidate_on_finish( - true_scores=true_scores, false_scores=false_scores - ) - - if self._operation_mode == EVAL_OP: - assert ( - self._collated_labels is not None - ), "Collated labels must be set for evaluation operation." - print_metrics( - final_preds, - self._collated_labels, - self._device, - classes=list(self._dm_labels.keys()), - ) - else: - # Get SMILES and label names - smiles_list = self._ensemble_input - label_names = list(self._dm_labels.keys()) - # Efficient conversion from tensor to NumPy - preds_np = final_preds.detach().cpu().numpy() - - assert ( - len(smiles_list) == preds_np.shape[0] - ), "Length of SMILES list does not match number of predictions." - assert ( - len(label_names) == preds_np.shape[1] - ), "Number of label names does not match number of predictions." - - # Build DataFrame - df = pd.DataFrame(preds_np, columns=label_names) - df.insert(0, "SMILES", smiles_list) - - # Save to CSV - output_path = ( - Path(self._data_processed_dir_main) / "ensemble_predictions.csv" - ) - df.to_csv(output_path, index=False) - - print(f"Predictions saved to {output_path}") - - @abstractmethod - def _controller( - self, - model_name: str, - model_input: list[str] | Path, - **kwargs: Any, - ) -> Dict[str, torch.Tensor]: - """ - Abstract method to define model-specific prediction logic. - - Returns: - Dict[str, torch.Tensor]: Predictions or confidence scores. - """ - - @abstractmethod - def _consolidator( - self, - *, - pred_conf_dict: Dict[str, torch.Tensor], - model_props: Dict[str, torch.Tensor], - true_scores: torch.Tensor, - false_scores: torch.Tensor, - **kwargs: Any, - ) -> None: - """ - Abstract method to define aggregation logic. - - Should update the provided `true_scores` and `false_scores`. - """ - - @abstractmethod - def _consolidate_on_finish( - self, *, true_scores: torch.Tensor, false_scores: torch.Tensor - ) -> torch.Tensor: - """ - Abstract method to produce final predictions after all models have been evaluated. - - Returns: - torch.Tensor: Final aggregated predictions. - """ diff --git a/chebai/ensemble/_consolidator.py b/chebai/ensemble/_consolidator.py deleted file mode 100644 index f629ef84..00000000 --- a/chebai/ensemble/_consolidator.py +++ /dev/null @@ -1,111 +0,0 @@ -from abc import ABC -from typing import Any, Dict - -from torch import Tensor - -from ._base import EnsembleBase - - -class WeightedMajorityVoting(EnsembleBase, ABC): - """ - Ensemble consolidator using weighted majority voting. - Each model's contribution is weighted by a function of confidence, - true positive value (TPV), and negative predictive value (NPV). - """ - - def _consolidator( - self, - pred_conf_dict: Dict[str, Tensor], - model_props: Dict[str, Tensor], - *, - true_scores: Tensor, - false_scores: Tensor, - **kwargs: Any - ) -> None: - """ - Updates true/false scores based on model predictions using a weighted voting scheme. - - Args: - pred_conf_dict (Dict[str, Tensor]): Contains model predictions and confidence scores. - model_props (Dict[str, Tensor]): Contains mask, TPV and NPV tensors for model. - true_scores (Tensor): Tensor accumulating weighted "true" contributions. - false_scores (Tensor): Tensor accumulating weighted "false" contributions. - **kwargs (Any): Additional unused keyword arguments. - """ - tpv = model_props["tpv_tensor"] - npv = model_props["fpv_tensor"] - conf = pred_conf_dict["confidence"] - mask = model_props["mask"] - - weight = conf * (tpv * conf + npv * (1 - conf)) - - # Apply mask: Only update scores for valid classes - true_scores += weight * conf * mask - false_scores += weight * (1 - conf) * mask - - def _consolidate_on_finish( - self, *, true_scores: Tensor, false_scores: Tensor - ) -> Tensor: - """ - Finalizes predictions after all models have contributed their scores. - - Args: - true_scores (Tensor): Accumulated weighted true scores per label. - false_scores (Tensor): Accumulated weighted false scores per label. - - Returns: - Tensor: Final binary predictions (True if true_score > false_score). - """ - # Avoid division by zero: Set valid_counts to 1 where it's zero - valid_counts = self._num_models_per_label.clamp(min=1) - # Normalize by valid contributions to prevent bias - final_preds = (true_scores / valid_counts) > (false_scores / valid_counts) - return final_preds - - -class MajorityVoting(EnsembleBase, ABC): - """ - Ensemble consolidator using simple majority voting. - Each model contributes equally; confidence is used directly as "vote weight". - """ - - def _consolidator( - self, - pred_conf_dict: Dict[str, Tensor], - model_props: Dict[str, Tensor], - *, - true_scores: Tensor, - false_scores: Tensor, - **kwargs: Any - ) -> None: - """ - Updates true/false scores based on model predictions using unweighted voting. - - Args: - pred_conf_dict (Dict[str, Tensor]): Contains model predictions and confidence scores. - model_props (Dict[str, Tensor]): Contains mask tensor for model. - true_scores (Tensor): Tensor accumulating true contributions. - false_scores (Tensor): Tensor accumulating false contributions. - **kwargs (Any): Additional unused keyword arguments. - """ - conf = pred_conf_dict["confidence"] - # Apply mask: Only update scores for valid classes - mask = model_props["mask"] - - true_scores += conf * mask - false_scores += (1 - conf) * mask - - def _consolidate_on_finish( - self, *, true_scores: Tensor, false_scores: Tensor - ) -> Tensor: - """ - Finalizes predictions after all models have voted. - - Args: - true_scores (Tensor): Accumulated true votes per label. - false_scores (Tensor): Accumulated false votes per label. - - Returns: - Tensor: Final binary predictions (True if true_score > false_score). - """ - return true_scores > false_scores diff --git a/chebai/ensemble/_constants.py b/chebai/ensemble/_constants.py deleted file mode 100644 index 3c0459de..00000000 --- a/chebai/ensemble/_constants.py +++ /dev/null @@ -1,10 +0,0 @@ -MODEL_LBL_PATH = "model_labels_path" -MODEL_CKPT_PATH = "model_ckpt_path" - -WRAPPER_CLS_PATH = "wrapper_class_path" - -DATA_CONFIG_PATH = "data_config_file_path" -MODEL_CONFIG_PATH = "model_config_file_path" - -PRED_OP = "prediction" -EVAL_OP = "evaluation" diff --git a/chebai/ensemble/_controller.py b/chebai/ensemble/_controller.py deleted file mode 100644 index 2d4dad64..00000000 --- a/chebai/ensemble/_controller.py +++ /dev/null @@ -1,132 +0,0 @@ -from abc import ABC -from collections import deque -from pathlib import Path -from typing import Any, Deque, Dict - -import torch -from torch import Tensor - -from ._base import EnsembleBase -from ._constants import PRED_OP, WRAPPER_CLS_PATH -from ._utils import load_class -from ._wrappers import BaseWrapper - - -class _Controller(EnsembleBase, ABC): - """ - Abstract base controller for ensemble models that handles data loading, collating, - and inference logic over a collection of models. - - Inherits from: - EnsembleBase: The base ensemble class with shared ensemble logic. - ABC: For defining abstract methods. - """ - - def __init__(self, **kwargs: Any): - """ - Initializes the controller with data loader and collator. - - Args: - **kwargs (Any): Keyword arguments passed to the EnsembleBase initializer. - """ - super().__init__(**kwargs) - self._kwargs = kwargs - # If an activation condition corresponding model is added to queue, removed from this set - # This is in order to avoid re-adding models that have already been processed - self._model_key_set: set[str] = set(self._model_configs.keys()) - - # Labels from any processed `data.pt` file of any reader - self._collated_labels: torch.Tensor | None = None - - def _controller( - self, model_name: str, model_input: list[str] | Path, **kwargs: Any - ) -> Dict[str, Tensor]: - """ - Performs inference with the model and extracts predictions and confidence values. - - Args: - model (ChebaiBaseNet): The model to perform inference with. - model_props (Dict[str, Tensor]): Dictionary with label mask and trust scores. - - Returns: - Dict[str, Tensor]: Dictionary containing predictions and confidence scores. - """ - wrapped_model = self._wrap_model(model_name) - if self._operation_mode == PRED_OP: - model_output, model_props = wrapped_model.predict(model_input) - else: - model_output, model_props = wrapped_model.evaluate( - data_pkl_file_path=model_input - ) - if ( - self._collated_labels is None - and wrapped_model.collated_labels is not None - ): - self._collated_labels = wrapped_model.collated_labels - - assert ( - isinstance(model_output, dict) and "logits" in model_output - ), "Forward pass should return dict containing logits for consistency across all types of models" - del wrapped_model # Model can be huge to keep it in memory, delete asap as no longer needed - - pred_conf_dict = self._get_pred_conf_from_model_output( - model_output, model_props["mask"] - ) - return {"pred_conf_dict": pred_conf_dict, "model_props": model_props} - - def _get_pred_conf_from_model_output( - self, model_output: Dict[str, Tensor], model_label_mask: Tensor - ) -> Dict[str, Tensor]: - """ - Processes model output to extract binary predictions and confidence scores. - - Args: - model_output (Dict[str, Tensor]): Dictionary containing logits from the model. - model_label_mask (Tensor): A boolean mask indicating active labels for the model. - - Returns: - Dict[str, Tensor]: Dictionary with keys "prediction" and "confidence" containing - tensors of the same shape as logits, filled only for active labels. - """ - sigmoid_logits = torch.sigmoid(model_output["logits"]) - prediction = torch.full( - (self._total_data_size, self._num_of_labels), -1, dtype=torch.bool - ) - confidence = torch.full( - (self._total_data_size, self._num_of_labels), -1, dtype=torch.float - ) - prediction[:, model_label_mask] = sigmoid_logits > 0.5 - confidence[:, model_label_mask] = 2 * torch.abs(sigmoid_logits - 0.5) - return {"prediction": prediction, "confidence": confidence} - - def _wrap_model(self, model_name: str) -> BaseWrapper: - model_config = self._model_configs[model_name] - wrp_cls = load_class(model_config[WRAPPER_CLS_PATH]) - assert issubclass(wrp_cls, BaseWrapper), "" - wrapped_model = wrp_cls( - model_name=model_name, - model_config=model_config, - dm_labels=self._dm_labels, - **self._kwargs - ) - - assert isinstance(wrapped_model, BaseWrapper), "" - return wrapped_model - - -class NoActivationCondition(_Controller): - """ - A controller that queues and activates all models unconditionally. - - This implementation does not filter or select models dynamically. - """ - - def __init__(self, **kwargs: Any): - """ - Initializes the controller and loads all model names into the processing queue. - - Args: - **kwargs (Any): Keyword arguments passed to the _Controller initializer. - """ - super().__init__(**kwargs) - self._model_queue: Deque[str] = deque(list(self._model_configs.keys())) diff --git a/chebai/ensemble/_scripts/_ensemble_run_script.py b/chebai/ensemble/_scripts/_ensemble_run_script.py deleted file mode 100644 index 26d003c4..00000000 --- a/chebai/ensemble/_scripts/_ensemble_run_script.py +++ /dev/null @@ -1,50 +0,0 @@ -from jsonargparse import ArgumentParser - -from chebai.ensemble._base import EnsembleBase -from chebai.ensemble._utils import load_class, parse_config_file - - -def load_config_and_instantiate(config_path: str) -> EnsembleBase: - """ - Loads a YAML config file, imports the specified class, and instantiates it with the provided arguments. - - Args: - config_path (str): Path to the YAML configuration file. - - Returns: - EnsembleBase: An instantiated object of the specified class. - - Raises: - TypeError: If the loaded class is not a subclass of EnsembleBase. - """ - - class_path, init_args = parse_config_file(config_path) - - cls = load_class(class_path) - - if not issubclass(cls, EnsembleBase): - raise TypeError(f"{cls} must be subclass of EnsembleBase") - - return cls(**init_args) - - -if __name__ == "__main__": - # Example usage: - # python ensemble_run_script.py --config=configs/ensemble/fullEnsembleWithWMV.yaml - - # Set up argument parser to receive config file path from CLI - parser = ArgumentParser() - parser.add_argument("--config", type=str, help="Path to the YAML config file") - - # Parse arguments from the command line - args = parser.parse_args() - - # Load and instantiate the ensemble object - ensemble = load_config_and_instantiate(args.config) - - # Ensure the loaded object is a valid EnsembleBase instance - if not isinstance(ensemble, EnsembleBase): - raise TypeError("Object must be an instance of `EnsembleBase`") - - # Run the ensemble pipeline - ensemble.run_ensemble() diff --git a/chebai/ensemble/_utils.py b/chebai/ensemble/_utils.py deleted file mode 100644 index 6f38891f..00000000 --- a/chebai/ensemble/_utils.py +++ /dev/null @@ -1,83 +0,0 @@ -import importlib -from pathlib import Path - -import yaml - -from chebai.models.base import ChebaiBaseNet -from chebai.preprocessing.datasets.base import XYBaseDataModule - - -def load_class(class_path: str) -> type: - module_path, class_name = class_path.rsplit(".", 1) - module = importlib.import_module(module_path) - return getattr(module, class_name) - - -def load_data_instance(data_cls_path: str, data_cls_kwargs: dict): - assert isinstance(data_cls_kwargs, dict), "data_cls_kwargs must be a dict" - data_cls = load_class(data_cls_path) - assert isinstance(data_cls, type), f"{data_cls} is not a class." - assert issubclass( - data_cls, XYBaseDataModule - ), f"{data_cls} must inherit from XYBaseDataModule" - return data_cls(**data_cls_kwargs) - - -def load_model_for_inference( - model_ckpt_path: str, model_cls_path: str, model_load_kwargs: dict, **kwargs -) -> ChebaiBaseNet: - """ - Loads a model checkpoint and its label-related properties. - - Returns: - Tuple[LightningModule, Dict[str, torch.Tensor]]: The model and its label properties. - """ - assert isinstance(model_load_kwargs, dict), "model_load_kwargs must be a dict" - - model_name = kwargs.get("model_name", model_ckpt_path) - - if not Path(model_ckpt_path).exists(): - raise FileNotFoundError( - f"Model path '{model_ckpt_path}' for '{model_name}' does not exist." - ) - - lightning_cls = load_class(model_cls_path) - - assert isinstance(lightning_cls, type), f"{lightning_cls} is not a class." - assert issubclass( - lightning_cls, ChebaiBaseNet - ), f"{lightning_cls} must inherit from ChebaiBaseNet" - try: - model = lightning_cls.load_from_checkpoint(model_ckpt_path, **model_load_kwargs) - except Exception as e: - raise RuntimeError(f"Error loading model {model_name} \n Error: {e}") from e - - assert isinstance( - model, ChebaiBaseNet - ), f"Model: {model}(Model Name: {model_name}) is not a ChebaiBaseNet instance." - model.eval() - model.freeze() - return model - - -def parse_config_file(config_path: str) -> tuple[str, dict]: - path = Path(config_path) - - # Check file existence - if not path.exists(): - raise FileNotFoundError(f"Config file not found: {config_path}") - - # Check file extension - if path.suffix.lower() not in [".yml", ".yaml"]: - raise ValueError( - f"Unsupported config file type: {path.suffix}. Expected .yaml or .yml" - ) - - # Load YAML content - with open(path, "r") as f: - config: dict = yaml.safe_load(f) - - class_path: str = config["class_path"] - init_args: dict = config.get("init_args", {}) - assert isinstance(init_args, dict), "init_args must be a dictionary" - return class_path, init_args diff --git a/chebai/ensemble/_wrappers/__init__.py b/chebai/ensemble/_wrappers/__init__.py deleted file mode 100644 index e68fec43..00000000 --- a/chebai/ensemble/_wrappers/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from ._base import BaseWrapper -from ._chemlog import ChemLogWrapper -from ._gnn import GNNWrapper -from ._neural_network import NNWrapper - -__all__ = ["NNWrapper", "BaseWrapper", "GNNWrapper", "ChemLogWrapper"] diff --git a/chebai/ensemble/_wrappers/_base.py b/chebai/ensemble/_wrappers/_base.py deleted file mode 100644 index 31b7883b..00000000 --- a/chebai/ensemble/_wrappers/_base.py +++ /dev/null @@ -1,151 +0,0 @@ -import json -import os -from abc import ABC, abstractmethod - -import torch - -from .._constants import MODEL_LBL_PATH - -_MODEL_REGISTRY = {} - - -class BaseWrapper(ABC): - def __init__( - self, - model_name: str, - model_config: dict[str, str], - dm_labels: dict[str, int], - **kwargs, - ): - if MODEL_LBL_PATH not in model_config: - raise AttributeError( - f"Missing key {MODEL_LBL_PATH} in model '{model_name}' configuration." - ) - - self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - self._model_config = model_config - self._model_name = model_name - - self._model_labels_path = self._model_config[MODEL_LBL_PATH] - self._model_props = self._generate_model_label_props(dm_labels=dm_labels) - self.collated_labels = None - - @classmethod - def _cls_name(cls) -> str: - return cls.__name__ - - @property - def name(self): - return f"Wrapper({self._cls_name()}) for model: {self._model_name}" - - def __init_subclass__(cls): - """ - Automatically registers subclasses in the model registry to prevent duplicates. - - Args: - **kwargs: Additional keyword arguments. - """ - if cls._cls_name() in _MODEL_REGISTRY: - raise ValueError(f"Model {cls._cls_name()} does already exist") - else: - _MODEL_REGISTRY[cls._cls_name()] = cls - - def _generate_model_label_props(self, dm_labels) -> dict[str, torch.Tensor]: - """ - Generates label mask and confidence tensors (TPV, FPV) for a model. - - Returns: - Dict[str, torch.Tensor]: Dictionary containing mask, TPV and FPV tensors. - """ - print("\t Generating model label masks and properties") - labels_dict = self._load_model_labels() - - model_label_indices, tpv_label_values, fpv_label_values = [], [], [] - - for label, props in labels_dict.items(): - if label in dm_labels: - try: - self._validate_model_labels_json_element(labels_dict[label]) - except Exception as e: - raise RuntimeError( - f"Label '{label}' has an unexpected error \n Error: {e}" - ) from e - - model_label_indices.append(dm_labels[label]) - tpv_label_values.append(props["PPV"]) - fpv_label_values.append(props["NPV"]) - - if not all([model_label_indices, tpv_label_values, fpv_label_values]): - raise ValueError( - f"No valid label values found in {self._model_labels_path}." - ) - - # Create masks to apply predictions only to known classes - mask = torch.zeros(len(dm_labels), dtype=torch.bool, device=self._device) - mask[torch.tensor(model_label_indices, device=self._device)] = True - - tpv_tensor = torch.full_like(mask, -1, dtype=torch.float, device=self._device) - fpv_tensor = torch.full_like(mask, -1, dtype=torch.float, device=self._device) - - tpv_tensor[mask] = torch.tensor( - tpv_label_values, dtype=torch.float, device=self._device - ) - fpv_tensor[mask] = torch.tensor( - fpv_label_values, dtype=torch.float, device=self._device - ) - - return {"mask": mask, "tpv_tensor": tpv_tensor, "fpv_tensor": fpv_tensor} - - def _load_model_labels(self) -> dict[str, dict[str, float]]: - """ - Loads a JSON label file for a model. - - Returns: - Dict[str, Dict[str, float]]: Parsed label confidence data. - - Raises: - FileNotFoundError: If the file is missing. - TypeError: If the file is not a JSON. - """ - if not os.path.exists(self._model_labels_path): - raise FileNotFoundError(f"{self._model_labels_path} does not exist.") - if not self._model_labels_path.endswith(".json"): - raise TypeError(f"{self._model_labels_path} is not a JSON file.") - with open(self._model_labels_path, "r") as f: - return json.load(f) - - @staticmethod - def _validate_model_labels_json_element(label_dict: dict[str, float]) -> None: - """ - Validates a label confidence dictionary to ensure required keys and values are valid. - - Args: - label_dict (Dict[str, Any]): Label data with PPV and NPV keys. - - Raises: - AttributeError: If required keys are missing. - ValueError: If values are not valid floats or are negative. - """ - for key in ["PPV", "NPV"]: - if key not in label_dict: - raise AttributeError(f"Missing key '{key}' in label dict.") - try: - value = float(label_dict[key]) - if value < 0: - raise ValueError(f"'{key}' must be non-negative but got {value}") - except (TypeError, ValueError) as e: - raise ValueError(f"Invalid value for '{key}': {label_dict[key]}") from e - - def predict(self, x: list) -> tuple[dict, dict]: - if not isinstance(x, list): - raise TypeError(f"Input must be a list of SMILES strings, got {type(x)}") - return self._predict_from_list_of_smiles(x), self._model_props - - @abstractmethod - def _predict_from_list_of_smiles(self, smiles_list: list) -> dict: ... - - def evaluate(self, **kwargs) -> tuple[dict, dict]: - return self._evaluate_from_data_file(**kwargs), self._model_props - - @abstractmethod - def _evaluate_from_data_file(self, **kwargs) -> dict: ... diff --git a/chebai/ensemble/_wrappers/_chemlog.py b/chebai/ensemble/_wrappers/_chemlog.py deleted file mode 100644 index ed25e61d..00000000 --- a/chebai/ensemble/_wrappers/_chemlog.py +++ /dev/null @@ -1,91 +0,0 @@ -from pathlib import Path - -import pandas as pd -from chemlog.classification.charge_classifier import get_charge_category -from chemlog.classification.peptide_size_classifier import get_n_amino_acid_residues -from chemlog.classification.proteinogenics_classifier import ( - get_proteinogenic_amino_acids, -) -from chemlog.classification.substructure_classifier import ( - is_diketopiperazine, - is_emericellamide, -) -from chemlog.cli import resolve_chebi_classes -from rdkit import Chem - -from chebai.ensemble._wrappers._base import BaseWrapper - - -class ChemLogWrapper(BaseWrapper): - - def _predict_from_list_of_smiles(self, smiles_list): - return {"logits": self.get_chemlog_results(smiles_list)} - - def _evaluate_from_data_file(self, **kwargs) -> list: - data_pkl_file_path = kwargs["data_pkl_file_path"] - data_df = pd.read_pickle(data_pkl_file_path) - smiles_list = data_df["SMILES"].to_list() - return {"logits": self.get_chemlog_results(smiles_list)} - - def get_chemlog_results(self, smiles_list) -> list: - all_preds = [] - for smiles in smiles_list: - mol = Chem.MolFromSmiles(smiles, sanitize=False) - if mol is None or not smiles: - all_preds.append(None) - continue - mol.UpdatePropertyCache() - charge_category = get_charge_category(mol) - n_amino_acid_residues, _ = get_n_amino_acid_residues(mol) - r = { - "charge_category": charge_category.name, - "n_amino_acid_residues": n_amino_acid_residues, - } - if n_amino_acid_residues == 5: - r["emericellamide"] = is_emericellamide(mol)[0] - if n_amino_acid_residues == 2: - r["2,5-diketopiperazines"] = is_diketopiperazine(mol)[0] - - chebi_classes = [f"CHEBI:{c}" for c in resolve_chebi_classes(r)] - - all_preds.append(chebi_classes) - return all_preds - - def get_chemlog_result_info(self, smiles): - """Get classification for single molecule with additional information.""" - mol = Chem.MolFromSmiles(smiles, sanitize=False) - if mol is None or not smiles: - return {"error": "Failed to parse SMILES"} - mol.UpdatePropertyCache() - try: - Chem.Kekulize(mol) - except Chem.KekulizeException: - pass - - charge_category = get_charge_category(mol) - n_amino_acid_residues, add_output = get_n_amino_acid_residues(mol) - if n_amino_acid_residues > 1: - proteinogenics, proteinogenics_locations, _ = get_proteinogenic_amino_acids( - mol, add_output["amino_residue"], add_output["carboxy_residue"] - ) - else: - proteinogenics, proteinogenics_locations, _ = [], [], [] - results = { - "charge_category": charge_category.name, - "n_amino_acid_residues": n_amino_acid_residues, - "proteinogenics": proteinogenics, - "proteinogenics_locations": proteinogenics_locations, - } - - if n_amino_acid_residues == 5: - emericellamide = is_emericellamide(mol) - results["emericellamide"] = emericellamide[0] - if emericellamide[0]: - results["emericellamide_atoms"] = emericellamide[1] - if n_amino_acid_residues == 2: - diketopiperazine = is_diketopiperazine(mol) - results["2,5-diketopiperazines"] = diketopiperazine[0] - if diketopiperazine[0]: - results["2,5-diketopiperazines_atoms"] = diketopiperazine[1] - - return {**results, **add_output} diff --git a/chebai/ensemble/_wrappers/_gnn.py b/chebai/ensemble/_wrappers/_gnn.py deleted file mode 100644 index b0b35f65..00000000 --- a/chebai/ensemble/_wrappers/_gnn.py +++ /dev/null @@ -1,85 +0,0 @@ -import chebai_graph.preprocessing.properties as p -import torch -from chebai_graph.preprocessing.property_encoder import IndexEncoder, OneHotEncoder -from torch_geometric.data.data import Data as GeomData - -from ._neural_network import NNWrapper - - -class GNNWrapper(NNWrapper): - def _read_smiles(self, smiles): - d = self._data_cls_instance.reader.to_data(dict(features=smiles, labels=None)) - geom_data = d["features"] - assert isinstance(geom_data, GeomData), "Must be an instance of GeoData" - edge_attr = geom_data.edge_attr - x = geom_data.x - molecule_attr = torch.empty((1, 0)) - for property in self._data_cls_instance.properties: - property_values = self._data_cls_instance.reader.read_property( - smiles, property - ) - encoded_values = [] - for value in property_values: - # cant use standard encode for index encoder because model has been trained on a certain range of values - # use default value if we meet an unseen value - if isinstance(property.encoder, IndexEncoder): - if str(value) in property.encoder.cache: - index = ( - property.encoder.cache[str(value)] + property.encoder.offset - ) - else: - index = 0 - print( - f"Unknown property value {value} for property {property} at smiles {smiles}" - ) - if isinstance(property.encoder, OneHotEncoder): - encoded_values.append( - torch.nn.functional.one_hot( # pylint: disable=E1102 - torch.tensor(index), - num_classes=property.encoder.get_encoding_length(), - ) - ) - else: - encoded_values.append(torch.tensor([index])) - - else: - encoded_values.append(property.encoder.encode(value)) - if len(encoded_values) > 0: - encoded_values = torch.stack(encoded_values) - - if isinstance(encoded_values, torch.Tensor): - if len(encoded_values.size()) == 0: - encoded_values = encoded_values.unsqueeze(0) - if len(encoded_values.size()) == 1: - encoded_values = encoded_values.unsqueeze(1) - else: - encoded_values = torch.zeros( - (0, property.encoder.get_encoding_length()) - ) - if isinstance(property, p.AtomProperty): - x = torch.cat([x, encoded_values], dim=1) - elif isinstance(property, p.BondProperty): - # Concat/Duplicate properties values for undirected graph as `edge_index` has first src to tgt edges, then tgt to src edges - edge_attr = torch.cat( - [edge_attr, torch.cat([encoded_values, encoded_values], dim=0)], - dim=1, - ) - else: - molecule_attr = torch.cat([molecule_attr, encoded_values[0]], dim=1) - - d["features"] = GeomData( - x=x, - edge_index=geom_data.edge_index, - edge_attr=edge_attr, - molecule_attr=molecule_attr, - ) - return d - - def _evaluate_from_data_file(self, **kwargs) -> list: - model_logits = super()._evaluate_from_data_file(**kwargs) - # Currently gnn in forward method, logits are returned instead of dict containing logits - return {"logits": model_logits} - - def _predict_from_list_of_smiles(self, smiles_list: list[str]) -> dict: - model_logits = super()._predict_from_list_of_smiles(smiles_list) - return {"logits": model_logits} diff --git a/chebai/ensemble/_wrappers/_neural_network.py b/chebai/ensemble/_wrappers/_neural_network.py deleted file mode 100644 index 3ab5702d..00000000 --- a/chebai/ensemble/_wrappers/_neural_network.py +++ /dev/null @@ -1,113 +0,0 @@ -from pathlib import Path - -from rdkit import Chem - -from chebai.models import ChebaiBaseNet -from chebai.preprocessing.datasets.base import XYBaseDataModule -from chebai.preprocessing.structures import XYData - -from .._constants import DATA_CONFIG_PATH, MODEL_CKPT_PATH, MODEL_CONFIG_PATH -from .._utils import load_data_instance, load_model_for_inference, parse_config_file -from ._base import BaseWrapper - - -class NNWrapper(BaseWrapper): - - def __init__(self, **kwargs): - self._validate_model_configs( - model_config=kwargs["model_config"], model_name=kwargs["model_name"] - ) - super().__init__(**kwargs) - - self._model_ckpt_path = self._model_config[MODEL_CKPT_PATH] - self._model_config_path = self._model_config[MODEL_CONFIG_PATH] - self._data_config_path = self._model_config[DATA_CONFIG_PATH] - - data_cls_path, data_kwargs = parse_config_file(self._data_config_path) - self._data_cls_instance: XYBaseDataModule = load_data_instance( - data_cls_path, data_kwargs - ) - - model_cls_path, model_kwargs = parse_config_file(self._model_config_path) - self._model: ChebaiBaseNet = load_model_for_inference( - self._model_ckpt_path, - model_cls_path, - model_kwargs, - model_name=kwargs["model_name"], - ) - - self.collated_labels = None - - @classmethod - def _validate_model_configs( - cls, - model_config: dict[str, str], - model_name: str, - ) -> None: - """ - Validates model configuration dictionary for required keys and uniqueness. - - Args: - model_configs (Dict[str, Dict[str, Any]]): Model configuration dictionary. - - Raises: - AttributeError: If any model config is missing required keys. - ValueError: If duplicate paths are found for model checkpoint, class, or labels. - """ - required_keys = {MODEL_CKPT_PATH, MODEL_CONFIG_PATH, DATA_CONFIG_PATH} - - missing_keys = required_keys - model_config.keys() - if missing_keys: - raise AttributeError( - f"Missing keys {missing_keys} in model '{model_name}' configuration." - ) - - def _predict_from_list_of_smiles(self, smiles_list: list[str]) -> dict: - token_dicts = [] - could_not_parse = [] - - for i, smiles in enumerate(smiles_list): - try: - if not smiles: - raise ValueError("Empty SMILES string") - d = self._read_smiles(smiles) - rdmol = Chem.MolFromSmiles(smiles, sanitize=False) - if rdmol is None: - raise ValueError("RDKit failed to parse") - except Exception as e: - could_not_parse.append(i) - print(f"Failing to parse '{smiles}' (index {i}): {e}") - else: - token_dicts.append(d) - - if not token_dicts: - raise ValueError("No valid SMILES could be parsed.") - - model_output = self._forward_pass(token_dicts) - # ----- This check is handled in controller - # if not isinstance(model_output, dict) or "logits" not in model_output: - # raise ValueError("Model output is malformed; expected dict with 'logits' key.") - - return model_output - - def _read_smiles(self, smiles: str) -> dict: - return self._data_cls_instance.reader.to_data( - dict(features=smiles, labels=None) - ) - - def _forward_pass(self, batch: list[dict]) -> dict: - collated_batch: XYData = self._data_cls_instance.reader.collator(batch).to( - self._device - ) - self.collated_labels = collated_batch.y - processable_data = self._model._process_batch( # pylint: disable=W0212 - collated_batch, 0 - ) - return self._model(processable_data, **processable_data["model_kwargs"]) - - def _evaluate_from_data_file(self, **kwargs) -> list: - filename = self._data_cls_instance.processed_file_names_dict["data"] - data_list_of_dict = self._data_cls_instance.load_processed_data_from_file( - Path(self._data_cls_instance.processed_dir) / filename - ) - return self._forward_pass(data_list_of_dict) diff --git a/chebai/ensemble/_scripts/_generate_classes_props_json.py b/chebai/result/_generate_classes_props_json.py similarity index 99% rename from chebai/ensemble/_scripts/_generate_classes_props_json.py rename to chebai/result/_generate_classes_props_json.py index 9aa7c6c3..b8704591 100644 --- a/chebai/ensemble/_scripts/_generate_classes_props_json.py +++ b/chebai/result/_generate_classes_props_json.py @@ -5,12 +5,12 @@ from jsonargparse import CLI from sklearn.metrics import multilabel_confusion_matrix -from chebai.ensemble._utils import ( +from chebai.preprocessing.datasets.base import XYBaseDataModule +from chebai.result.utils import ( load_data_instance, load_model_for_inference, parse_config_file, ) -from chebai.preprocessing.datasets.base import XYBaseDataModule class ClassesPropertiesGenerator: diff --git a/chebai/result/utils.py b/chebai/result/utils.py index 991960d6..61679c6f 100644 --- a/chebai/result/utils.py +++ b/chebai/result/utils.py @@ -222,6 +222,82 @@ def load_results_from_buffer( return test_preds, test_labels +def load_class(class_path: str) -> type: + module_path, class_name = class_path.rsplit(".", 1) + module = importlib.import_module(module_path) + return getattr(module, class_name) + + +def load_data_instance(data_cls_path: str, data_cls_kwargs: dict): + assert isinstance(data_cls_kwargs, dict), "data_cls_kwargs must be a dict" + data_cls = load_class(data_cls_path) + assert isinstance(data_cls, type), f"{data_cls} is not a class." + assert issubclass( + data_cls, XYBaseDataModule + ), f"{data_cls} must inherit from XYBaseDataModule" + return data_cls(**data_cls_kwargs) + + +def load_model_for_inference( + model_ckpt_path: str, model_cls_path: str, model_load_kwargs: dict, **kwargs +) -> ChebaiBaseNet: + """ + Loads a model checkpoint and its label-related properties. + + Returns: + Tuple[LightningModule, Dict[str, torch.Tensor]]: The model and its label properties. + """ + assert isinstance(model_load_kwargs, dict), "model_load_kwargs must be a dict" + + model_name = kwargs.get("model_name", model_ckpt_path) + + if not Path(model_ckpt_path).exists(): + raise FileNotFoundError( + f"Model path '{model_ckpt_path}' for '{model_name}' does not exist." + ) + + lightning_cls = load_class(model_cls_path) + + assert isinstance(lightning_cls, type), f"{lightning_cls} is not a class." + assert issubclass( + lightning_cls, ChebaiBaseNet + ), f"{lightning_cls} must inherit from ChebaiBaseNet" + try: + model = lightning_cls.load_from_checkpoint(model_ckpt_path, **model_load_kwargs) + except Exception as e: + raise RuntimeError(f"Error loading model {model_name} \n Error: {e}") from e + + assert isinstance( + model, ChebaiBaseNet + ), f"Model: {model}(Model Name: {model_name}) is not a ChebaiBaseNet instance." + model.eval() + model.freeze() + return model + + +def parse_config_file(config_path: str) -> tuple[str, dict]: + path = Path(config_path) + + # Check file existence + if not path.exists(): + raise FileNotFoundError(f"Config file not found: {config_path}") + + # Check file extension + if path.suffix.lower() not in [".yml", ".yaml"]: + raise ValueError( + f"Unsupported config file type: {path.suffix}. Expected .yaml or .yml" + ) + + # Load YAML content + with open(path, "r") as f: + config: dict = yaml.safe_load(f) + + class_path: str = config["class_path"] + init_args: dict = config.get("init_args", {}) + assert isinstance(init_args, dict), "init_args must be a dictionary" + return class_path, init_args + + if __name__ == "__main__": import sys diff --git a/configs/ensemble/fullEnsembleWithWMV.yaml b/configs/ensemble/fullEnsembleWithWMV.yaml deleted file mode 100644 index 626be694..00000000 --- a/configs/ensemble/fullEnsembleWithWMV.yaml +++ /dev/null @@ -1,18 +0,0 @@ -class_path: chebai.ensemble.FullEnsembleWMV -init_args: - data_processed_dir_main: "path/to/data/processed/main/directory" -# reader_dir_name: "name_of_reader_dir" # default is `smiles_token` -# _validate_configs: False # To avoid check for using same model with same model configs - model_configs: { - "model_1_name": { - "ckpt_path": "path/to/your/model_1/checkpoint.ckpt", - "class_path": "path/to/your/model_1/class", - "labels_path": "path/to/your/model_1/classes.json", - }, - - "model_2_name": { - "ckpt_path": "path/to/your/model_2/checkpoint.ckpt", - "class_path": "path/to/your/model_2/class", - "labels_path": "path/to/your/model_2/classes.json", - }, - } From 44b60ddb0a50ba076b1d24a225adce7ac3b97fc9 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Tue, 24 Jun 2025 14:53:37 +0200 Subject: [PATCH 76/78] remove data processed dir param linking --- chebai/cli.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/chebai/cli.py b/chebai/cli.py index 97078b5d..61a1da8a 100644 --- a/chebai/cli.py +++ b/chebai/cli.py @@ -70,12 +70,6 @@ def call_data_methods(data: Type[XYBaseDataModule]): "data.num_of_labels", "trainer.callbacks.init_args.num_labels" ) - parser.link_arguments( - "data.processed_dir_main", - "model.init_args.data_processed_dir_main", - apply_on="instantiate", - ) - @staticmethod def subcommands() -> Dict[str, Set[str]]: """ From f16550cedcc95dc83beaebb964d478b3656e186a Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Tue, 24 Jun 2025 15:00:47 +0200 Subject: [PATCH 77/78] fix utils imports --- chebai/result/utils.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/chebai/result/utils.py b/chebai/result/utils.py index 61679c6f..78d20013 100644 --- a/chebai/result/utils.py +++ b/chebai/result/utils.py @@ -1,14 +1,16 @@ +import importlib import os import shutil -from typing import Optional, Tuple, Union +from pathlib import Path +from typing import Optional, Tuple import torch import tqdm import wandb import wandb.util as wandb_util +import yaml from chebai.models.base import ChebaiBaseNet -from chebai.models.electra import Electra from chebai.preprocessing.datasets.base import XYBaseDataModule from chebai.preprocessing.datasets.chebi import _ChEBIDataExtractor @@ -121,7 +123,7 @@ def evaluate_model( save_batch_size = 128 n_saved = 1 - print(f"") + print("") for i in tqdm.tqdm(range(0, len(data_list), batch_size)): if not ( skip_existing_preds @@ -307,5 +309,5 @@ def parse_config_file(config_path: str) -> tuple[str, dict]: ) os.makedirs(buffer_dir_concat, exist_ok=True) preds, labels = load_results_from_buffer(buffer_dir, "cpu") - torch.save(preds, os.path.join(buffer_dir_concat, f"preds000.pt")) - torch.save(labels, os.path.join(buffer_dir_concat, f"labels000.pt")) + torch.save(preds, os.path.join(buffer_dir_concat, "preds000.pt")) + torch.save(labels, os.path.join(buffer_dir_concat, "labels000.pt")) From c33dec37b4cbda112f2bf47957ec0b1c85d4c8b8 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Tue, 24 Jun 2025 15:15:45 +0200 Subject: [PATCH 78/78] update gitignore --- .gitignore | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index da3bbbcb..bafec1d9 100644 --- a/.gitignore +++ b/.gitignore @@ -167,6 +167,12 @@ cython_debug/ /logs /results_buffer electra_pretrained.ckpt -/lightning_logs + +build +.virtual_documents +.jupyter +chebai.egg-info +lightning_logs +logs .isort.cfg /.vscode