diff --git a/README.md b/README.md index eeecd714..ec1fa1ed 100644 --- a/README.md +++ b/README.md @@ -3,14 +3,30 @@ ChEBai is a deep learning library designed for the integration of deep learning methods with chemical ontologies, particularly ChEBI. The library emphasizes the incorporation of the semantic qualities of the ontology into the learning process. -## Installation +## News + +We now support regression tasks! + +## Note for developers -You can install ChEBai via pip: +If you have used ChEBai before PR #39, the file structure in which your ChEBI-data is saved has changed. This means that +datasets will be freshly generated. The data however is the same. If you want to keep the old data (including the old +splits), you can use a migration script. It copies the old data to the new location for a specific ChEBI class +(including chebi version and other parameters). The script can be called by specifying the data module from a config ``` -pip install chebai +python chebai/preprocessing/migration/chebi_data_migration.py migrate --datamodule=[path-to-data-config] +``` +or by specifying the class name (e.g. `ChEBIOver50`) and arguments separately ``` +python chebai/preprocessing/migration/chebi_data_migration.py migrate --class_name=[data-class] [--chebi_version=[version]] +``` +The new dataset will by default generate random data splits (with a given seed). +To reuse a fixed data split, you have to provide the path of the csv file generated during the migration: +`--data.init_args.splits_file_path=[path-to-processed_data]/splits.csv` -Alternatively, you can get the latest development version directly from GitHub: +## Installation + +To install ChEBai, follow these steps: 1. 
Clone the repository: ``` @@ -63,11 +79,16 @@ A command with additional options may look like this: python3 -m chebai fit --trainer=configs/training/default_trainer.yml --model=configs/model/electra.yml --model.train_metrics=configs/metrics/micro-macro-f1.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --model.val_metrics=configs/metrics/micro-macro-f1.yml --model.pretrained_checkpoint=electra_pretrained.ckpt --model.load_prefix=generator. --data=configs/data/chebi50.yml --model.criterion=configs/loss/bce.yml --data.init_args.batch_size=10 --trainer.logger.init_args.name=chebi50_bce_unweighted --data.init_args.num_workers=9 --model.pass_loss_kwargs=false --data.init_args.chebi_version=231 --data.init_args.data_limit=1000 ``` -### Fine-tuning for Toxicity prediction +### Fine-tuning for classification tasks, e.g. Toxicity prediction ``` python -m chebai fit --config=[path-to-your-tox21-config] --trainer.callbacks=configs/training/default_callbacks.yml --model.pretrained_checkpoint=[path-to-pretrained-model] ``` +### Fine-tuning for regression tasks, e.g. 
solubility prediction +``` +python -m chebai fit --config=[path-to-your-esol-config] --trainer.callbacks=configs/training/solCur_callbacks.yml --model.pretrained_checkpoint=[path-to-pretrained-model] +``` + ### Predicting classes given SMILES strings ``` python3 -m chebai predict_from_file --model=[path-to-model-config] --checkpoint_path=[path-to-model] --input_path={path-to-file-containing-smiles] [--classes_path=[path-to-classes-file]] [--save_to=[path-to-output]] diff --git a/chebai/cli.py b/chebai/cli.py index 96262447..502a5834 100644 --- a/chebai/cli.py +++ b/chebai/cli.py @@ -60,15 +60,40 @@ def call_data_methods(data: Type[XYBaseDataModule]): ) for kind in ("train", "val", "test"): - for average in ("micro-f1", "macro-f1", "balanced-accuracy"): + for average in ( + "micro-f1", + "macro-f1", + "balanced-accuracy", + "roc-auc", + "f1", + "mse", + "rmse", + "r2", + ): + # When using lightning > 2.5.1 then need to uncomment all metrics that are not used + # for average in ("mse", "rmse","r2"): # for regression + # for average in ("f1", "roc-auc"): # for binary classification + # for average in ("micro-f1", "macro-f1", "roc-auc"): # for multilabel classification + # for average in ("micro-f1", "macro-f1", "balanced-accuracy", "roc-auc"): # for multilabel classification using balanced-accuracy parser.link_arguments( "data.num_of_labels", f"model.init_args.{kind}_metrics.init_args.metrics.{average}.init_args.num_labels", apply_on="instantiate", ) + parser.link_arguments( "data.num_of_labels", "trainer.callbacks.init_args.num_labels" ) + # parser.link_arguments( + # "model.init_args.out_dim", "trainer.callbacks.init_args.num_labels" + # ) + # parser.link_arguments( + # "data", "model.init_args.criterion.init_args.data_extractor" + # ) + # parser.link_arguments( + # "data.init_args.chebi_version", + # "model.init_args.criterion.init_args.data_extractor.init_args.chebi_version", + # ) @staticmethod def subcommands() -> Dict[str, Set[str]]: diff --git 
a/chebai/loss/focal_loss.py b/chebai/loss/focal_loss.py new file mode 100644 index 00000000..0fcc3c61 --- /dev/null +++ b/chebai/loss/focal_loss.py @@ -0,0 +1,152 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +# from https://github.com/itakurah/Focal-loss-PyTorch + + +class FocalLoss(nn.Module): + def __init__( + self, + gamma=2, + alpha=None, + reduction="mean", + task_type="binary", + num_classes=None, + ): + """ + Unified Focal Loss class for binary, multi-class, and multi-label classification tasks. + :param gamma: Focusing parameter, controls the strength of the modulating factor (1 - p_t)^gamma + :param alpha: Balancing factor, can be a scalar or a tensor for class-wise weights. If None, no class balancing is used. + :param reduction: Specifies the reduction method: 'none' | 'mean' | 'sum' + :param task_type: Specifies the type of task: 'binary', 'multi-class', or 'multi-label' + :param num_classes: Number of classes (only required for multi-class classification) + """ + super(FocalLoss, self).__init__() + self.gamma = gamma + self.alpha = alpha + self.reduction = reduction + self.task_type = task_type + self.num_classes = num_classes + + # Handle alpha for class balancing in multi-class tasks + if ( + task_type == "multi-class" + and alpha is not None + and isinstance(alpha, (list, torch.Tensor)) + ): + assert ( + num_classes is not None + ), "num_classes must be specified for multi-class classification" + if isinstance(alpha, list): + self.alpha = torch.Tensor(alpha) + else: + self.alpha = alpha + + def forward(self, inputs, targets): + """ + Forward pass to compute the Focal Loss based on the specified task type. + :param inputs: Predictions (logits) from the model. + Shape: + - binary/multi-label: (batch_size, num_classes) + - multi-class: (batch_size, num_classes) + :param targets: Ground truth labels. 
+ Shape: + - binary: (batch_size,) + - multi-label: (batch_size, num_classes) + - multi-class: (batch_size,) + """ + if self.task_type == "binary": + return self.binary_focal_loss(inputs, targets) + elif self.task_type == "multi-class": + return self.multi_class_focal_loss(inputs, targets) + elif self.task_type == "multi-label": + return self.multi_label_focal_loss(inputs, targets) + else: + raise ValueError( + f"Unsupported task_type '{self.task_type}'. Use 'binary', 'multi-class', or 'multi-label'." + ) + + def binary_focal_loss(self, inputs, targets): + """Focal loss for binary classification.""" + probs = torch.sigmoid(inputs) + targets = targets.float() + + # Compute binary cross entropy + bce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") + + # Compute focal weight + p_t = probs * targets + (1 - probs) * (1 - targets) + focal_weight = (1 - p_t) ** self.gamma + + # Apply alpha if provided + if self.alpha is not None: + alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets) + bce_loss = alpha_t * bce_loss + + # Apply focal loss weighting + loss = focal_weight * bce_loss + + if self.reduction == "mean": + return loss.mean() + elif self.reduction == "sum": + return loss.sum() + return loss + + def multi_class_focal_loss(self, inputs, targets): + """Focal loss for multi-class classification.""" + if self.alpha is not None: + alpha = self.alpha.to(inputs.device) + + # Convert logits to probabilities with softmax + probs = F.softmax(inputs, dim=1) + + # One-hot encode the targets + targets_one_hot = F.one_hot(targets, num_classes=self.num_classes).float() + + # Compute cross-entropy for each class + ce_loss = -targets_one_hot * torch.log(probs) + + # Compute focal weight + p_t = torch.sum(probs * targets_one_hot, dim=1) # p_t for each sample + focal_weight = (1 - p_t) ** self.gamma + + # Apply alpha if provided (per-class weighting) + if self.alpha is not None: + alpha_t = alpha.gather(0, targets) + ce_loss = 
alpha_t.unsqueeze(1) * ce_loss + + # Apply focal loss weight + loss = focal_weight.unsqueeze(1) * ce_loss + + if self.reduction == "mean": + return loss.mean() + elif self.reduction == "sum": + return loss.sum() + return loss + + def multi_label_focal_loss(self, inputs, targets): + """Focal loss for multi-label classification.""" + probs = torch.sigmoid(inputs) + + # Compute binary cross entropy + bce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") + + # Compute focal weight + p_t = probs * targets + (1 - probs) * (1 - targets) + focal_weight = (1 - p_t) ** self.gamma + + # Apply alpha if provided + if self.alpha is not None: + alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets) + bce_loss = alpha_t * bce_loss + + # Apply focal loss weight + loss = focal_weight * bce_loss + + if self.reduction == "mean": + return loss.mean() + elif self.reduction == "sum": + return loss.sum() + return loss diff --git a/chebai/loss/semantic.py b/chebai/loss/semantic.py index 89abb175..3fef3085 100644 --- a/chebai/loss/semantic.py +++ b/chebai/loss/semantic.py @@ -2,7 +2,7 @@ import math import os import pickle -from typing import TYPE_CHECKING, List, Literal, Union +from typing import TYPE_CHECKING, List, Literal, Union, Tuple import torch @@ -62,7 +62,7 @@ def __init__( pos_epsilon: float = 0.01, multiply_by_softmax: bool = False, use_sigmoidal_implication: bool = False, - weight_epoch_dependent: Union[bool | tuple[int, int]] = False, + weight_epoch_dependent: Union[bool, Tuple[int, int]] = False, start_at_epoch: int = 0, violations_per_cls_aggregator: Literal[ "sum", "max", "mean", "log-sum", "log-max", "log-mean" diff --git a/chebai/models/base.py b/chebai/models/base.py index 7653f13c..c4447907 100644 --- a/chebai/models/base.py +++ b/chebai/models/base.py @@ -42,7 +42,8 @@ def __init__( exclude_hyperparameter_logging: Optional[Iterable[str]] = None, **kwargs, ): - super().__init__() + super().__init__(**kwargs) + # super().__init__() 
if exclude_hyperparameter_logging is None: exclude_hyperparameter_logging = tuple() self.criterion = criterion @@ -273,7 +274,6 @@ def _execute( loss_kwargs = dict() if self.pass_loss_kwargs: loss_kwargs = loss_kwargs_candidates - loss_kwargs["current_epoch"] = self.trainer.current_epoch loss = self.criterion(loss_data, loss_labels, **loss_kwargs) if isinstance(loss, tuple): unnamed_loss_index = 1 diff --git a/chebai/models/electra.py b/chebai/models/electra.py index c053db1c..88bc73e7 100644 --- a/chebai/models/electra.py +++ b/chebai/models/electra.py @@ -19,6 +19,7 @@ logging.getLogger("pysmiles").setLevel(logging.CRITICAL) + from chebai.loss.semantic import DisjointLoss as ElectraChEBIDisjointLoss # noqa @@ -40,6 +41,7 @@ class ElectraPre(ChebaiBaseNet): def __init__(self, config: Dict[str, Any] = None, **kwargs: Any): super().__init__(config=config, **kwargs) + self.generator_config = ElectraConfig(**config["generator"]) self.generator = ElectraForMaskedLM(self.generator_config) self.discriminator_config = ElectraConfig(**config["discriminator"]) @@ -224,6 +226,7 @@ def __init__( config: Optional[Dict[str, Any]] = None, pretrained_checkpoint: Optional[str] = None, load_prefix: Optional[str] = None, + model_type="classification", freeze_electra: bool = False, **kwargs: Any, ): @@ -237,6 +240,8 @@ def __init__( config["num_labels"] = self.out_dim self.config = ElectraConfig(**config, output_attentions=True) self.word_dropout = nn.Dropout(config.get("word_dropout", 0)) + self.model_type = model_type + self.pass_loss_kwargs = True in_d = self.config.hidden_size self.output = nn.Sequential( @@ -285,9 +290,16 @@ def _process_for_loss( tuple: A tuple containing the processed model output, labels, and loss arguments. 
""" kwargs_copy = dict(loss_kwargs) + output = model_output["logits"] if labels is not None: labels = labels.float() - return model_output["logits"], labels, kwargs_copy + if "missing_labels" in kwargs_copy: + missing_labels = kwargs_copy.pop("missing_labels") + output = output * (~missing_labels).int() - 10000 * missing_labels.int() + labels = labels * (~missing_labels).int() + if self.model_type == "classification": + assert ((labels <= torch.tensor(1.0)) & (labels >= torch.tensor(0.0))).all() + return output, labels, kwargs_copy def _get_prediction_and_labels( self, data: Dict[str, Any], labels: Tensor, model_output: Dict[str, Tensor] @@ -308,7 +320,25 @@ def _get_prediction_and_labels( if "non_null_labels" in loss_kwargs: n = loss_kwargs["non_null_labels"] d = d[n] - return torch.sigmoid(d), labels.int() if labels is not None else None + if self.model_type == "classification": + # print(self.model_type, ' in electra 324') + # for mulitclass here softmax instead of sigmoid + d = torch.sigmoid( + d + ) # changing this made a difference for the roc-auc but not the f1, why? 
+ if "missing_labels" in loss_kwargs: + missing_labels = loss_kwargs["missing_labels"] + d = d * (~missing_labels).int().to( + device=d.device + ) # we set the prob of missing labels to 0 + labels = labels * (~missing_labels).int().to( + device=d.device + ) # we set the labels of missing labels to 0 + return d, labels.int() if labels is not None else None + elif self.model_type == "regression": + return d, labels + else: + raise ValueError("Please specify a valid model type in your model config.") def forward(self, data: Dict[str, Tensor], **kwargs: Any) -> Dict[str, Any]: """ diff --git a/chebai/preprocessing/bin/smiles_token/tokens.txt b/chebai/preprocessing/bin/smiles_token/tokens.txt index 960173cd..c5553958 100644 --- a/chebai/preprocessing/bin/smiles_token/tokens.txt +++ b/chebai/preprocessing/bin/smiles_token/tokens.txt @@ -4371,3 +4371,5 @@ b [90Sr] [32PH2] [CaH2] +[NH3] +[OH2] diff --git a/chebai/preprocessing/collate.py b/chebai/preprocessing/collate.py index b420ef47..308ed6c7 100644 --- a/chebai/preprocessing/collate.py +++ b/chebai/preprocessing/collate.py @@ -64,7 +64,7 @@ def __call__(self, data: List[Union[Dict, Tuple]]) -> XYData: Handles both fully and partially labeled data, where some samples may have `None` as their label. The indices of non-null labels are stored in the `non_null_labels` field, which is used to filter out predictions for unlabeled data during evaluation (e.g., F1, MSE). For models supporting partially labeled data, this method - ensures alignment between features and labels. + ensures alignment between features and labels. Missing labels are passed as a loss keyword. Args: data (List[Union[Dict, Tuple]]): List of ragged data samples. 
Each sample can be a dictionary or tuple @@ -81,10 +81,15 @@ def __call__(self, data: List[Union[Dict, Tuple]]) -> XYData: if isinstance(data[0], tuple): # For legacy data x, y, idents = zip(*data) + missing_labels = None else: x, y, idents = zip( *((d["features"], d["labels"], d.get("ident")) for d in data) ) + missing_labels = [ + d.get("missing_labels", [False for _ in y[0]]) for d in data + ] + if any(x is not None for x in y): # If any label is not None: (None, None, `1`, None) if any(x is None for x in y): @@ -97,11 +102,13 @@ def __call__(self, data: List[Union[Dict, Tuple]]) -> XYData: else: # If all labels are not None: (`0`, `2`, `1`, `3`) y = self.process_label_rows(y) + else: # If all labels are None : (`None`, `None`, `None`, `None`) y = None loss_kwargs["non_null_labels"] = [] + loss_kwargs["missing_labels"] = torch.tensor(missing_labels) # Calculate the lengths of each sequence, create a binary mask for valid (non-padded) positions lens = torch.tensor(list(map(len, x))) model_kwargs["mask"] = torch.arange(max(lens))[None, :] < lens[:, None] diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py index 68254007..02b6ec72 100644 --- a/chebai/preprocessing/datasets/base.py +++ b/chebai/preprocessing/datasets/base.py @@ -332,6 +332,7 @@ def _load_data_from_file(self, path: str) -> List[Dict[str, Any]]: Returns: List: A list of dictionaries containing the features and labels. 
""" + lines = self._get_data_size(path) print(f"Processing {lines} lines...") data = [ diff --git a/chebai/preprocessing/datasets/molecule_classification.py b/chebai/preprocessing/datasets/molecule_classification.py new file mode 100644 index 00000000..c2916675 --- /dev/null +++ b/chebai/preprocessing/datasets/molecule_classification.py @@ -0,0 +1,1052 @@ +import csv +import gzip +import os +import shutil +from tempfile import NamedTemporaryFile +from typing import Dict, List +from urllib import request + +import numpy as np +import torch +from sklearn.model_selection import GroupShuffleSplit, train_test_split + +from chebai.preprocessing import reader as dr +from chebai.preprocessing.datasets.base import XYBaseDataModule + + +class ClinTox(XYBaseDataModule): + """Data module for ClinTox MoleculeNet dataset.""" + + HEADERS = [ + "FDA_APPROVED", + "CT_TOX", + ] + + @property + def _name(self) -> str: + """Returns the name of the dataset.""" + return "ClinTox" + + @property + def label_number(self) -> int: + """Returns the number of labels.""" + return 2 + + @property + def raw_file_names(self) -> List[str]: + """Returns a list of raw file names.""" + return ["clintox.csv"] + + # @property + # def processed_file_names(self) -> List[str]: + # """Returns a list of processed file names.""" + # return ["test.pt", "train.pt", "validation.pt"] + + @property + def processed_file_names_dict(self) -> dict: + return { + "test": "test.pt", + "train": "train.pt", + "validation": "validation.pt", + } + + def download(self) -> None: + """Downloads and extracts the dataset.""" + with NamedTemporaryFile("rb") as gout: + request.urlretrieve( + "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/clintox.csv.gz", + gout.name, + ) + with gzip.open(gout.name) as gfile: + with open(os.path.join(self.raw_dir, "clintox.csv"), "wt") as fout: + fout.write(gfile.read().decode()) + + def setup_processed(self) -> None: + """Processes and splits the dataset.""" + print("Create splits") + 
data = list( + self._load_data_from_file(os.path.join(self.raw_dir, "clintox.csv")) + ) + groups = np.array([d["group"] for d in data]) + if not all(g is None for g in groups): + split_size = int( + len(set(groups)) * (1 - self.test_split - self.validation_split) + ) + os.makedirs(self.processed_dir, exist_ok=True) + splitter = GroupShuffleSplit(train_size=split_size, n_splits=1) + + train_split_index, temp_split_index = next( + splitter.split(data, groups=groups) + ) + + split_groups = groups[temp_split_index] + + splitter = GroupShuffleSplit( + train_size=int( + len(set(split_groups)) + * (1 - self.test_split - self.validation_split) + ), + n_splits=1, + ) + test_split_index, validation_split_index = next( + splitter.split(temp_split_index, groups=split_groups) + ) + train_split = [data[i] for i in train_split_index] + test_split = [ + d for d in (data[temp_split_index[i]] for i in test_split_index) + ] + validation_split = [ + d for d in (data[temp_split_index[i]] for i in validation_split_index) + ] + else: + train_split, test_split = train_test_split( + data, test_size=self.test_split, shuffle=True + ) + train_split, validation_split = train_test_split( + train_split, test_size=self.validation_split, shuffle=True + ) + for k, split in [ + ("test", test_split), + ("train", train_split), + ("validation", validation_split), + ]: + print("transform", k) + torch.save( + split, + os.path.join(self.processed_dir, f"{k}.pt"), + ) + + def setup(self, **kwargs) -> None: + """Sets up the dataset by downloading and processing if necessary.""" + if any( + not os.path.isfile(os.path.join(self.raw_dir, f)) + for f in self.raw_file_names + ): + self.download() + if any( + not os.path.isfile(os.path.join(self.processed_dir, f)) + for f in self.processed_file_names + ): + self.setup_processed() + + self._after_setup() + + def _set_processed_data_props(self): + """ + Load processed data and extract metadata. 
+ + Sets: + - self._num_of_labels: Number of target labels in the dataset. + - self._feature_vector_size: Maximum feature vector length across all data points. + """ + pt_file_path = os.path.join( + self.processed_dir, self.processed_file_names_dict["train"] + ) + data_pt = torch.load(pt_file_path, weights_only=False) + + self._num_of_labels = len(data_pt[0]["labels"]) + self._feature_vector_size = max(len(d["features"]) for d in data_pt) + + def _load_dict(self, input_file_path: str) -> List[Dict]: + """Loads data from a CSV file. + + Args: + input_file_path (str): Path to the CSV file. + + Returns: + List[Dict]: List of data dictionaries. + """ + i = 0 + with open(input_file_path, "r") as input_file: + reader = csv.DictReader(input_file) + for row in reader: + i += 1 + smiles = row["smiles"] + labels = [ + bool(int(label)) if label else None + for label in (row[k] for k in self.HEADERS) + ] + # group = int(row["group"]) + yield dict( + features=smiles, + labels=labels, + ident=i, + # group=group + ) + # yield dict(features=smiles, labels=labels, ident=i) + # yield self.reader.to_data(dict(features=smiles, labels=labels, ident=i)) + + def _perform_data_preparation(self, *args, **kwargs) -> None: + pass + + +class BBBP(XYBaseDataModule): + """Data module for ClinTox MoleculeNet dataset.""" + + HEADERS = [ + "p_np", + ] + + @property + def _name(self) -> str: + """Returns the name of the dataset.""" + return "BBBP" + + @property + def label_number(self) -> int: + """Returns the number of labels.""" + return 1 + + @property + def raw_file_names(self) -> List[str]: + """Returns a list of raw file names.""" + return ["bbbp.csv"] + + # @property + # def processed_file_names(self) -> List[str]: + # """Returns a list of processed file names.""" + # return ["test.pt", "train.pt", "validation.pt"] + + @property + def processed_file_names_dict(self) -> dict: + return { + "test": "test.pt", + "train": "train.pt", + "validation": "validation.pt", + } + + def download(self) -> 
None: + """Downloads and extracts the dataset.""" + with open(os.path.join(self.raw_dir, "bbbp.csv"), "ab") as dst: + with request.urlopen( + "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/BBBP.csv", + ) as src: + shutil.copyfileobj(src, dst) + + def setup_processed(self) -> None: + """Processes and splits the dataset.""" + print("Create splits") + data = list(self._load_data_from_file(os.path.join(self.raw_dir, "bbbp.csv"))) + groups = np.array([d["group"] for d in data]) + if not all(g is None for g in groups): + print("Group shuffled") + split_size = int( + len(set(groups)) * (1 - self.test_split - self.validation_split) + ) + os.makedirs(self.processed_dir, exist_ok=True) + splitter = GroupShuffleSplit(train_size=split_size, n_splits=1) + + train_split_index, temp_split_index = next( + splitter.split(data, groups=groups) + ) + + split_groups = groups[temp_split_index] + + splitter = GroupShuffleSplit( + train_size=int( + len(set(split_groups)) + * (1 - self.test_split - self.validation_split) + ), + n_splits=1, + ) + test_split_index, validation_split_index = next( + splitter.split(temp_split_index, groups=split_groups) + ) + train_split = [data[i] for i in train_split_index] + test_split = [ + d + for d in (data[temp_split_index[i]] for i in test_split_index) + # if d["original"] + ] + validation_split = [ + d + for d in (data[temp_split_index[i]] for i in validation_split_index) + # if d["original"] + ] + else: + train_split, test_split = train_test_split( + data, test_size=self.test_split, shuffle=True + ) + train_split, validation_split = train_test_split( + train_split, test_size=self.validation_split, shuffle=True + ) + for k, split in [ + ("test", test_split), + ("train", train_split), + ("validation", validation_split), + ]: + print("transform", k) + torch.save( + split, + os.path.join(self.processed_dir, f"{k}.pt"), + ) + + def setup(self, **kwargs) -> None: + """Sets up the dataset by downloading and processing if necessary.""" + if any( + 
not os.path.isfile(os.path.join(self.raw_dir, f)) + for f in self.raw_file_names + ): + self.download() + if any( + not os.path.isfile(os.path.join(self.processed_dir, f)) + for f in self.processed_file_names + ): + self.setup_processed() + + self._after_setup() + + def _set_processed_data_props(self): + """ + Load processed data and extract metadata. + + Sets: + - self._num_of_labels: Number of target labels in the dataset. + - self._feature_vector_size: Maximum feature vector length across all data points. + """ + pt_file_path = os.path.join( + self.processed_dir, self.processed_file_names_dict["train"] + ) + data_pt = torch.load(pt_file_path, weights_only=False) + + self._num_of_labels = len(data_pt[0]["labels"]) + self._feature_vector_size = max(len(d["features"]) for d in data_pt) + + def _load_dict(self, input_file_path: str) -> List[Dict]: + """Loads data from a CSV file. + + Args: + input_file_path (str): Path to the CSV file. + + Returns: + List[Dict]: List of data dictionaries. 
+ """ + i = 0 + with open(input_file_path, "r") as input_file: + reader = csv.DictReader(input_file) + for row in reader: + i += 1 + smiles = row["smiles"] + labels = [int(row["p_np"])] + # group = int(row["group"]) + yield dict( + features=smiles, + labels=labels, + ident=i, + # , group=group + ) + # yield self.reader.to_data(dict(features=smiles, labels=labels, ident=i)) + + def _perform_data_preparation(self, *args, **kwargs) -> None: + pass + + +class Sider(XYBaseDataModule): + """Data module for ClinTox MoleculeNet dataset.""" + + HEADERS = [ + "Hepatobiliary disorders", + "Metabolism and nutrition disorders", + "Product issues", + "Eye disorders", + "Investigations", + "Musculoskeletal and connective tissue disorders", + "Gastrointestinal disorders", + "Social circumstances", + "Immune system disorders", + "Reproductive system and breast disorders", + "Neoplasms benign, malignant and unspecified (incl cysts and polyps)", + "General disorders and administration site conditions", + "Endocrine disorders", + "Surgical and medical procedures", + "Vascular disorders", + "Blood and lymphatic system disorders", + "Skin and subcutaneous tissue disorders", + "Congenital, familial and genetic disorders", + "Infections and infestations", + "Respiratory, thoracic and mediastinal disorders", + "Psychiatric disorders", + "Renal and urinary disorders", + "Pregnancy, puerperium and perinatal conditions", + "Ear and labyrinth disorders", + "Cardiac disorders", + "Nervous system disorders", + "Injury, poisoning and procedural complications", + ] + + @property + def _name(self) -> str: + """Returns the name of the dataset.""" + return "Sider" + + @property + def label_number(self) -> int: + """Returns the number of labels.""" + return 27 + + @property + def raw_file_names(self) -> List[str]: + """Returns a list of raw file names.""" + return ["sider.csv"] + + # @property + # def processed_file_names(self) -> List[str]: + # """Returns a list of processed file names.""" + # return 
["test.pt", "train.pt", "validation.pt"] + + @property + def processed_file_names_dict(self) -> dict: + return { + "test": "test.pt", + "train": "train.pt", + "validation": "validation.pt", + } + + def download(self) -> None: + """Downloads and extracts the dataset.""" + with NamedTemporaryFile("rb") as gout: + request.urlretrieve( + "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/sider.csv.gz", + gout.name, + ) + with gzip.open(gout.name) as gfile: + with open(os.path.join(self.raw_dir, "sider.csv"), "wt") as fout: + fout.write(gfile.read().decode()) + + def setup_processed(self) -> None: + """Processes and splits the dataset.""" + print("Create splits") + data = list(self._load_data_from_file(os.path.join(self.raw_dir, "sider.csv"))) + groups = np.array([d["group"] for d in data]) + if not all(g is None for g in groups): + split_size = int( + len(set(groups)) * (1 - self.test_split - self.validation_split) + ) + os.makedirs(self.processed_dir, exist_ok=True) + splitter = GroupShuffleSplit(train_size=split_size, n_splits=1) + + train_split_index, temp_split_index = next( + splitter.split(data, groups=groups) + ) + + split_groups = groups[temp_split_index] + + splitter = GroupShuffleSplit( + train_size=int( + len(set(split_groups)) + * (1 - self.test_split - self.validation_split) + ), + n_splits=1, + ) + test_split_index, validation_split_index = next( + splitter.split(temp_split_index, groups=split_groups) + ) + train_split = [data[i] for i in train_split_index] + test_split = [ + d + for d in (data[temp_split_index[i]] for i in test_split_index) + # if d["original"] + ] + validation_split = [ + d + for d in (data[temp_split_index[i]] for i in validation_split_index) + # if d["original"] + ] + else: + train_split, test_split = train_test_split( + data, test_size=self.test_split, shuffle=True + ) + train_split, validation_split = train_test_split( + train_split, test_size=self.validation_split, shuffle=True + ) + for k, split in [ + ("test", test_split), 
+ ("train", train_split), + ("validation", validation_split), + ]: + print("transform", k) + torch.save( + split, + os.path.join(self.processed_dir, f"{k}.pt"), + ) + + def setup(self, **kwargs) -> None: + """Sets up the dataset by downloading and processing if necessary.""" + if any( + not os.path.isfile(os.path.join(self.raw_dir, f)) + for f in self.raw_file_names + ): + self.download() + if any( + not os.path.isfile(os.path.join(self.processed_dir, f)) + for f in self.processed_file_names + ): + self.setup_processed() + + self._after_setup() + + def _set_processed_data_props(self): + """ + Load processed data and extract metadata. + + Sets: + - self._num_of_labels: Number of target labels in the dataset. + - self._feature_vector_size: Maximum feature vector length across all data points. + """ + pt_file_path = os.path.join( + self.processed_dir, self.processed_file_names_dict["train"] + ) + data_pt = torch.load(pt_file_path, weights_only=False) + + self._num_of_labels = len(data_pt[0]["labels"]) + self._feature_vector_size = max(len(d["features"]) for d in data_pt) + + def _load_dict(self, input_file_path: str) -> List[Dict]: + """Loads data from a CSV file. + + Args: + input_file_path (str): Path to the CSV file. + + Returns: + List[Dict]: List of data dictionaries. 
+ """ + i = 0 + with open(input_file_path, "r") as input_file: + reader = csv.DictReader(input_file) + for row in reader: + i += 1 + smiles = row["smiles"] + labels = [ + bool(int(label)) if label else None + for label in (row[k] for k in self.HEADERS) + ] + # group = row["group"] + yield dict( + features=smiles, + labels=labels, + ident=i, + # , group=group + ) + # yield self.reader.to_data(dict(features=smiles, labels=labels, ident=i)) + + def _perform_data_preparation(self, *args, **kwargs) -> None: + pass + + +class Bace(XYBaseDataModule): + """Data module for ClinTox MoleculeNet dataset.""" + + HEADERS = [ + "class", + ] + + @property + def _name(self) -> str: + """Returns the name of the dataset.""" + return "Bace" + + @property + def label_number(self) -> int: + """Returns the number of labels.""" + return 1 + + @property + def raw_file_names(self) -> List[str]: + """Returns a list of raw file names.""" + return ["bace.csv"] + + # @property + # def processed_file_names(self) -> List[str]: + # """Returns a list of processed file names.""" + # return ["test.pt", "train.pt", "validation.pt"] + + @property + def processed_file_names_dict(self) -> dict: + return { + "test": "test.pt", + "train": "train.pt", + "validation": "validation.pt", + } + + def download(self) -> None: + """Downloads and extracts the dataset.""" + with open(os.path.join(self.raw_dir, "bace.csv"), "ab") as dst: + with request.urlopen( + "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/bace.csv", + ) as src: + shutil.copyfileobj(src, dst) + + def setup_processed(self) -> None: + """Processes and splits the dataset.""" + print("Create splits") + data = list(self._load_data_from_file(os.path.join(self.raw_dir, "bace.csv"))) + # groups = np.array([d.get("group") for d in data]) + + # if not all(g is None for g in groups): + # split_size = int(len(set(groups)) * (1 - self.test_split - self.validation_split)) + # os.makedirs(self.processed_dir, exist_ok=True) + # splitter = 
GroupShuffleSplit(train_size=split_size, n_splits=1) + + # train_split_index, temp_split_index = next( + # splitter.split(data, groups=groups) + # ) + + # split_groups = groups[temp_split_index] + + # splitter = GroupShuffleSplit( + # train_size=int(len(set(split_groups)) * (1 - self.test_split - self.validation_split)), n_splits=1 + # ) + # test_split_index, validation_split_index = next( + # splitter.split(temp_split_index, groups=split_groups) + # ) + # train_split = [data[i] for i in train_split_index] + # test_split = [ + # d + # for d in (data[temp_split_index[i]] for i in test_split_index) + # ] + # validation_split = [ + # d + # for d in (data[temp_split_index[i]] for i in validation_split_index) + # ] + # else: + train_split, test_split = train_test_split( + data, test_size=self.test_split, shuffle=True + ) + train_split, validation_split = train_test_split( + train_split, test_size=self.validation_split, shuffle=True + ) + for k, split in [ + ("test", test_split), + ("train", train_split), + ("validation", validation_split), + ]: + print("transform", k) + torch.save( + split, + os.path.join(self.processed_dir, f"{k}.pt"), + ) + + def setup(self, **kwargs) -> None: + """Sets up the dataset by downloading and processing if necessary.""" + if any( + not os.path.isfile(os.path.join(self.raw_dir, f)) + for f in self.raw_file_names + ): + self.download() + if any( + not os.path.isfile(os.path.join(self.processed_dir, f)) + for f in self.processed_file_names + ): + self.setup_processed() + + self._after_setup() + + def _set_processed_data_props(self): + """ + Load processed data and extract metadata. + + Sets: + - self._num_of_labels: Number of target labels in the dataset. + - self._feature_vector_size: Maximum feature vector length across all data points. 
+ """ + pt_file_path = os.path.join( + self.processed_dir, self.processed_file_names_dict["train"] + ) + data_pt = torch.load(pt_file_path, weights_only=False) + + self._num_of_labels = len(data_pt[0]["labels"]) + self._feature_vector_size = max(len(d["features"]) for d in data_pt) + + def _load_dict(self, input_file_path: str) -> List[Dict]: + """Loads data from a CSV file. + + Args: + input_file_path (str): Path to the CSV file. + + Returns: + List[Dict]: List of data dictionaries. + """ + i = 0 + with open(input_file_path, "r") as input_file: + reader = csv.DictReader(input_file) + for row in reader: + i += 1 + smiles = row["mol"] + labels = [int(row["Class"])] + # group = row["group"] + yield dict(features=smiles, labels=labels, ident=i) # , group=group + # yield self.reader.to_data(dict(features=smiles, labels=labels, ident=i)) + + def _perform_data_preparation(self, *args, **kwargs) -> None: + pass + + +class HIV(XYBaseDataModule): + """Data module for ClinTox MoleculeNet dataset.""" + + HEADERS = [ + "HIV_active", + ] + + @property + def _name(self) -> str: + """Returns the name of the dataset.""" + return "HIV" + + @property + def label_number(self) -> int: + """Returns the number of labels.""" + return 1 + + @property + def raw_file_names(self) -> List[str]: + """Returns a list of raw file names.""" + return ["hiv.csv"] + + # @property + # def processed_file_names(self) -> List[str]: + # """Returns a list of processed file names.""" + # return ["test.pt", "train.pt", "validation.pt"] + + @property + def processed_file_names_dict(self) -> dict: + return { + "test": "test.pt", + "train": "train.pt", + "validation": "validation.pt", + } + + def download(self) -> None: + """Downloads and extracts the dataset.""" + with open(os.path.join(self.raw_dir, "hiv.csv"), "ab") as dst: + with request.urlopen( + "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/HIV.csv", + ) as src: + shutil.copyfileobj(src, dst) + + def setup_processed(self) -> None: + 
"""Processes and splits the dataset.""" + print("Create splits") + data = list(self._load_data_from_file(os.path.join(self.raw_dir, "hiv.csv"))) + groups = np.array([d["group"] for d in data]) + if not all(g is None for g in groups): + print("Group shuffled") + split_size = int( + len(set(groups)) * (1 - self.test_split - self.validation_split) + ) + os.makedirs(self.processed_dir, exist_ok=True) + splitter = GroupShuffleSplit(train_size=split_size, n_splits=1) + + train_split_index, temp_split_index = next( + splitter.split(data, groups=groups) + ) + + split_groups = groups[temp_split_index] + + splitter = GroupShuffleSplit( + train_size=int( + len(set(split_groups)) + * (1 - self.test_split - self.validation_split) + ), + n_splits=1, + ) + test_split_index, validation_split_index = next( + splitter.split(temp_split_index, groups=split_groups) + ) + train_split = [data[i] for i in train_split_index] + test_split = [ + d for d in (data[temp_split_index[i]] for i in test_split_index) + ] + validation_split = [ + d for d in (data[temp_split_index[i]] for i in validation_split_index) + ] + else: + train_split, test_split = train_test_split( + data, test_size=self.test_split, shuffle=True + ) + train_split, validation_split = train_test_split( + train_split, test_size=self.validation_split, shuffle=True + ) + for k, split in [ + ("test", test_split), + ("train", train_split), + ("validation", validation_split), + ]: + print("transform", k) + torch.save( + split, + os.path.join(self.processed_dir, f"{k}.pt"), + ) + + def setup(self, **kwargs) -> None: + """Sets up the dataset by downloading and processing if necessary.""" + if any( + not os.path.isfile(os.path.join(self.raw_dir, f)) + for f in self.raw_file_names + ): + self.download() + if any( + not os.path.isfile(os.path.join(self.processed_dir, f)) + for f in self.processed_file_names + ): + self.setup_processed() + + self._after_setup() + + def _set_processed_data_props(self): + """ + Load processed data and 
extract metadata. + + Sets: + - self._num_of_labels: Number of target labels in the dataset. + - self._feature_vector_size: Maximum feature vector length across all data points. + """ + pt_file_path = os.path.join( + self.processed_dir, self.processed_file_names_dict["train"] + ) + data_pt = torch.load(pt_file_path, weights_only=False) + + self._num_of_labels = len(data_pt[0]["labels"]) + self._feature_vector_size = max(len(d["features"]) for d in data_pt) + + def _load_dict(self, input_file_path: str) -> List[Dict]: + """Loads data from a CSV file. + + Args: + input_file_path (str): Path to the CSV file. + + Returns: + List[Dict]: List of data dictionaries. + """ + i = 0 + with open(input_file_path, "r") as input_file: + reader = csv.DictReader(input_file) + for row in reader: + if len(row) > 1: + i += 1 + smiles = row["smiles"] + labels = [int(row["HIV_active"])] + # group = int(row["group"]) + yield dict( + features=smiles, + labels=labels, + ident=i, + # , group=group + ) + # yield self.reader.to_data(dict(features=smiles, labels=labels, ident=i)) + + def _perform_data_preparation(self, *args, **kwargs) -> None: + pass + + +class MUV(XYBaseDataModule): + """Data module for ClinTox MoleculeNet dataset.""" + + HEADERS = [ + "MUV-466", + "MUV-548", + "MUV-600", + "MUV-644", + "MUV-652", + "MUV-689", + "MUV-692", + "MUV-712", + "MUV-713", + "MUV-733", + "MUV-737", + "MUV-810", + "MUV-832", + "MUV-846", + "MUV-852", + "MUV-858", + "MUV-859", + ] + + @property + def _name(self) -> str: + """Returns the name of the dataset.""" + return "MUV" + + @property + def label_number(self) -> int: + """Returns the number of labels.""" + return 17 + + @property + def raw_file_names(self) -> List[str]: + """Returns a list of raw file names.""" + return ["muv.csv"] + + # @property + # def processed_file_names(self) -> List[str]: + # """Returns a list of processed file names.""" + # return ["test.pt", "train.pt", "validation.pt"] + + @property + def processed_file_names_dict(self) 
-> dict: + return { + "test": "test.pt", + "train": "train.pt", + "validation": "validation.pt", + } + + def download(self) -> None: + """Downloads and extracts the dataset.""" + with NamedTemporaryFile("rb") as gout: + request.urlretrieve( + "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/muv.csv.gz", + gout.name, + ) + with gzip.open(gout.name) as gfile: + with open(os.path.join(self.raw_dir, "muv.csv"), "wt") as fout: + fout.write(gfile.read().decode()) + + def setup_processed(self) -> None: + """Processes and splits the dataset.""" + print("Create splits") + data = list(self._load_data_from_file(os.path.join(self.raw_dir, "muv.csv"))) + groups = np.array([d["group"] for d in data]) + if not all(g is None for g in groups): + split_size = int( + len(set(groups)) * (1 - self.test_split - self.validation_split) + ) + os.makedirs(self.processed_dir, exist_ok=True) + splitter = GroupShuffleSplit(train_size=split_size, n_splits=1) + + train_split_index, temp_split_index = next( + splitter.split(data, groups=groups) + ) + + split_groups = groups[temp_split_index] + + splitter = GroupShuffleSplit( + train_size=int( + len(set(split_groups)) + * (1 - self.test_split - self.validation_split) + ), + n_splits=1, + ) + test_split_index, validation_split_index = next( + splitter.split(temp_split_index, groups=split_groups) + ) + train_split = [data[i] for i in train_split_index] + test_split = [ + d + for d in (data[temp_split_index[i]] for i in test_split_index) + # if d["original"] + ] + validation_split = [ + d + for d in (data[temp_split_index[i]] for i in validation_split_index) + # if d["original"] + ] + else: + train_split, test_split = train_test_split( + data, test_size=self.test_split, shuffle=True + ) + train_split, validation_split = train_test_split( + train_split, test_size=self.validation_split, shuffle=True + ) + for k, split in [ + ("test", test_split), + ("train", train_split), + ("validation", validation_split), + ]: + print("transform", k) + 
torch.save( + split, + os.path.join(self.processed_dir, f"{k}.pt"), + ) + + def setup(self, **kwargs) -> None: + """Sets up the dataset by downloading and processing if necessary.""" + if any( + not os.path.isfile(os.path.join(self.raw_dir, f)) + for f in self.raw_file_names + ): + self.download() + if any( + not os.path.isfile(os.path.join(self.processed_dir, f)) + for f in self.processed_file_names + ): + self.setup_processed() + + self._after_setup() + + def _set_processed_data_props(self): + """ + Load processed data and extract metadata. + + Sets: + - self._num_of_labels: Number of target labels in the dataset. + - self._feature_vector_size: Maximum feature vector length across all data points. + """ + pt_file_path = os.path.join( + self.processed_dir, self.processed_file_names_dict["train"] + ) + data_pt = torch.load(pt_file_path, weights_only=False) + + self._num_of_labels = len(data_pt[0]["labels"]) + self._feature_vector_size = max(len(d["features"]) for d in data_pt) + + def _load_dict(self, input_file_path: str) -> List[Dict]: + """Loads data from a CSV file. + + Args: + input_file_path (str): Path to the CSV file. + + Returns: + List[Dict]: List of data dictionaries. 
+ """ + i = 0 + with open(input_file_path, "r") as input_file: + reader = csv.DictReader(input_file) + for row in reader: + i += 1 + smiles = row["smiles"] + labels = [ + bool(int(label)) if label else None + for label in (row[k] for k in self.HEADERS) + ] + # group = row["group"] + yield dict(features=smiles, labels=labels, ident=i) # , group=group) + # yield self.reader.to_data(dict(features=smiles, labels=labels, ident=i)) + + def _perform_data_preparation(self, *args, **kwargs) -> None: + pass + + +class BaceChem(Bace): + """Chemical data reader for Tox21MolNet dataset.""" + + READER = dr.ChemDataReader + + +class SiderChem(Sider): + """Chemical data reader for Tox21MolNet dataset.""" + + READER = dr.ChemDataReader + + +class BBBPChem(BBBP): + """Chemical data reader for Tox21MolNet dataset.""" + + READER = dr.ChemDataReader + + +class ClinToxChem(ClinTox): + """Chemical data reader for Tox21MolNet dataset.""" + + READER = dr.ChemDataReader + + +class HIVChem(HIV): + """Chemical data reader for Tox21MolNet dataset.""" + + READER = dr.ChemDataReader + + +class MUVChem(MUV): + """Chemical data reader for Tox21MolNet dataset.""" + + READER = dr.ChemDataReader diff --git a/chebai/preprocessing/datasets/molecule_regression.py b/chebai/preprocessing/datasets/molecule_regression.py new file mode 100644 index 00000000..bc74df34 --- /dev/null +++ b/chebai/preprocessing/datasets/molecule_regression.py @@ -0,0 +1,282 @@ +from urllib import request +import csv +import os +import shutil +from typing import Dict, List + +from sklearn.model_selection import train_test_split +import torch + +from chebai.preprocessing import reader as dr +from chebai.preprocessing.datasets.base import XYBaseDataModule + + +class Lipo(XYBaseDataModule): + HEADERS = [ + "exp", + ] + + @property + def _name(self): + return "Lipo" + + @property + def label_number(self): + return 1 + + @property + def raw_file_names(self): + return ["Lipo.csv"] + + # @property + # def processed_file_names(self): + # 
return ["test.pt", "train.pt", "validation.pt"] + + @property + def processed_file_names_dict(self) -> dict: + return { + "test": "test.pt", + "train": "train.pt", + "validation": "validation.pt", + } + + def download(self): + # download + with open(os.path.join(self.raw_dir, "Lipo.csv"), "ab") as dst: + with request.urlopen( + "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/Lipophilicity.csv", + ) as src: + shutil.copyfileobj(src, dst) + + def setup_processed(self): + print("Create splits") + data = list(self._load_data_from_file(os.path.join(self.raw_dir, "Lipo.csv"))) + print(len(data)) + + train_split, test_split = train_test_split( + data, test_size=self.test_split, shuffle=True + ) + train_split, validation_split = train_test_split( + train_split, test_size=self.validation_split, shuffle=True + ) + + if False: + train_split, test_split = train_test_split( + data, train_size=self.train_split, shuffle=True + ) + test_split, validation_split = train_test_split( + test_split, train_size=0.5, shuffle=True + ) + for k, split in [ + ("test", test_split), + ("train", train_split), + ("validation", validation_split), + ]: + print("transform", k) + torch.save( + split, + os.path.join(self.processed_dir, f"{k}.pt"), + ) + + def setup(self, **kwargs): + if any( + not os.path.isfile(os.path.join(self.raw_dir, f)) + for f in self.raw_file_names + ): + self.download() + print( + [ + not os.path.isfile(os.path.join(self.processed_dir, f)) + for f in self.processed_file_names + ] + ) + if any( + not os.path.isfile(os.path.join(self.processed_dir, f)) + for f in self.processed_file_names + ): + self.setup_processed() + + self._after_setup() + + def _set_processed_data_props(self): + """ + Load processed data and extract metadata. + + Sets: + - self._num_of_labels: Number of target labels in the dataset. + - self._feature_vector_size: Maximum feature vector length across all data points. 
+ """ + pt_file_path = os.path.join( + self.processed_dir, self.processed_file_names_dict["train"] + ) + data_pt = torch.load(pt_file_path, weights_only=False) + + self._num_of_labels = len(data_pt[0]["labels"]) + self._feature_vector_size = max(len(d["features"]) for d in data_pt) + + def _load_dict(self, input_file_path: str) -> List[Dict]: + """Loads data from a CSV file. + + Args: + input_file_path (str): Path to the CSV file. + + Returns: + List[Dict]: List of data dictionaries. + """ + smiles_l = [] + labels_l = [] + with open(input_file_path, "r") as input_file: + reader = csv.DictReader(input_file) + print(reader.fieldnames) + for row in reader: + smiles_l.append(row["smiles"]) + labels_l.append(float(row["exp"])) + + for i in range(0, len(smiles_l)): + yield dict(features=smiles_l[i], labels=[labels_l[i]], ident=i) + # yield self.reader.to_data(dict(features=smiles_l[i], labels=[labels_l[i]], ident=i)) + + def _perform_data_preparation(self, *args, **kwargs) -> None: + pass + + +class LipoChem(Lipo): + """Chemical data reader for the solubility dataset.""" + + READER = dr.ChemDataReader + + +class FreeSolv(XYBaseDataModule): + HEADERS = [ + "expt", + ] + + @property + def _name(self): + return "FreeSolv" + + @property + def label_number(self): + return 1 + + @property + def raw_file_names(self): + return ["FreeSolv.csv"] + + # @property + # def processed_file_names(self): + # return ["test.pt", "train.pt", "validation.pt"] + + @property + def processed_file_names_dict(self) -> dict: + return { + "test": "test.pt", + "train": "train.pt", + "validation": "validation.pt", + } + + def download(self): + # download + with open(os.path.join(self.raw_dir, "FreeSolv.csv"), "ab") as dst: + with request.urlopen( + "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/SAMPL.csv", + ) as src: + shutil.copyfileobj(src, dst) + + def setup_processed(self): + print("Create splits") + data = list( + self._load_data_from_file(os.path.join(self.raw_dir, "FreeSolv.csv")) + 
) + print(len(data)) + train_split, test_split = train_test_split( + data, test_size=self.test_split, shuffle=True + ) + train_split, validation_split = train_test_split( + train_split, test_size=self.validation_split, shuffle=True + ) + + if False: + train_split, test_split = train_test_split( + data, train_size=self.train_split, shuffle=True + ) + test_split, validation_split = train_test_split( + test_split, train_size=0.5, shuffle=True + ) + for k, split in [ + ("test", test_split), + ("train", train_split), + ("validation", validation_split), + ]: + print("transform", k) + torch.save( + split, + os.path.join(self.processed_dir, f"{k}.pt"), + ) + + def setup(self, **kwargs): + if any( + not os.path.isfile(os.path.join(self.raw_dir, f)) + for f in self.raw_file_names + ): + self.download() + print( + [ + not os.path.isfile(os.path.join(self.processed_dir, f)) + for f in self.processed_file_names + ] + ) + if any( + not os.path.isfile(os.path.join(self.processed_dir, f)) + for f in self.processed_file_names + ): + self.setup_processed() + + self._after_setup() + + def _set_processed_data_props(self): + """ + Load processed data and extract metadata. + + Sets: + - self._num_of_labels: Number of target labels in the dataset. + - self._feature_vector_size: Maximum feature vector length across all data points. + """ + pt_file_path = os.path.join( + self.processed_dir, self.processed_file_names_dict["train"] + ) + data_pt = torch.load(pt_file_path, weights_only=False) + + self._num_of_labels = len(data_pt[0]["labels"]) + self._feature_vector_size = max(len(d["features"]) for d in data_pt) + + def _load_dict(self, input_file_path: str) -> List[Dict]: + """Loads data from a CSV file. + + Args: + input_file_path (str): Path to the CSV file. + + Returns: + List[Dict]: List of data dictionaries. 
+ """ + smiles_l = [] + labels_l = [] + with open(input_file_path, "r") as input_file: + reader = csv.DictReader(input_file) + print(reader.fieldnames) + for row in reader: + smiles_l.append(row["smiles"]) + labels_l.append(float(row["expt"])) + + for i in range(0, len(smiles_l)): + yield dict(features=smiles_l[i], labels=[labels_l[i]], ident=i) + # yield self.reader.to_data(dict(features=smiles_l[i], labels=[labels_l[i]], ident=i)) + + def _perform_data_preparation(self, *args, **kwargs) -> None: + pass + + +class FreeSolvChem(FreeSolv): + """Chemical data reader for the solubility dataset.""" + + READER = dr.ChemDataReader diff --git a/chebai/preprocessing/datasets/solCuration.py b/chebai/preprocessing/datasets/solCuration.py new file mode 100644 index 00000000..61b88ce1 --- /dev/null +++ b/chebai/preprocessing/datasets/solCuration.py @@ -0,0 +1,298 @@ +from urllib import request +import csv +import os +import shutil +from typing import Dict, List + +from sklearn.model_selection import train_test_split +import torch + +from chebai.preprocessing import reader as dr +from chebai.preprocessing.datasets.base import XYBaseDataModule + + +class SolCuration(XYBaseDataModule): + HEADERS = [ + "logS", + ] + + @property + def _name(self): + return "SolCuration" + + @property + def label_number(self): + return 1 + + @property + def raw_file_names(self): + return ["solCuration.csv"] + + # @property + # def processed_file_names(self): + # return ["test.pt", "train.pt", "validation.pt"] + + @property + def processed_file_names_dict(self) -> dict: + return { + "test": "test.pt", + "train": "train.pt", + "validation": "validation.pt", + } + + def download(self): + # download and combine all the available curated datasets from xxx + db_sol = ["aqsol", "aqua", "esol", "ochem", "phys"] + with open(os.path.join(self.raw_dir, "solCuration.csv"), "ab") as dst: + for i, db in enumerate(db_sol): + with request.urlopen( + 
f"https://raw.githubusercontent.com/Mengjintao/SolCuration/master/cure/{db}_cure.csv", + ) as src: + if i > 0: + src.readline() + shutil.copyfileobj(src, dst) + + def setup_processed(self): + print("Create splits") + data = list( + self._load_data_from_file(os.path.join(self.raw_dir, "solCuration.csv")) + ) + print(len(data)) + + train_split, test_split = train_test_split( + data, test_size=self.test_split, shuffle=True + ) + train_split, validation_split = train_test_split( + train_split, test_size=self.validation_split, shuffle=True + ) + + if False: + train_split, test_split = train_test_split( + data, train_size=self.train_split, shuffle=True + ) + test_split, validation_split = train_test_split( + test_split, train_size=0.5, shuffle=True + ) + for k, split in [ + ("test", test_split), + ("train", train_split), + ("validation", validation_split), + ]: + print("transform", k) + torch.save( + split, + os.path.join(self.processed_dir, f"{k}.pt"), + ) + + def setup(self, **kwargs): + if any( + not os.path.isfile(os.path.join(self.raw_dir, f)) + for f in self.raw_file_names + ): + self.download() + print( + [ + not os.path.isfile(os.path.join(self.processed_dir, f)) + for f in self.processed_file_names + ] + ) + if any( + not os.path.isfile(os.path.join(self.processed_dir, f)) + for f in self.processed_file_names + ): + self.setup_processed() + + self._after_setup() + + def _set_processed_data_props(self): + """ + Load processed data and extract metadata. + + Sets: + - self._num_of_labels: Number of target labels in the dataset. + - self._feature_vector_size: Maximum feature vector length across all data points. 
+ """ + pt_file_path = os.path.join( + self.processed_dir, self.processed_file_names_dict["train"] + ) + data_pt = torch.load(pt_file_path, weights_only=False) + + self._num_of_labels = len(data_pt[0]["labels"]) + self._feature_vector_size = max(len(d["features"]) for d in data_pt) + + def _load_data_from_file(self, input_file_path: str) -> List[Dict]: + """Loads data from a CSV file. + + Args: + input_file_path (str): Path to the CSV file. + + Returns: + List[Dict]: List of data dictionaries. + """ + smiles_l = [] + labels_l = [] + with open(input_file_path, "r") as input_file: + reader = csv.DictReader(input_file) + for row in reader: + if row["smiles"] not in smiles_l: + smiles_l.append(row["smiles"]) + labels_l.append(float(row["logS"])) + # print(len(smiles_l), len(labels_l)) + # labels_l.append(np.floor(float(row["logS"]))) + # onehotencoding + # label_binarizer = LabelBinarizer() + # label_binarizer.fit(labels_l) + # onehot_label_l = label_binarizer.transform(labels_l) + + # normalise data to be between 0 and 1 + # labels_norm = [(float(label)-min(labels_l))/(max(labels_l)-min(labels_l)) for label in labels_l] + for i in range(0, len(smiles_l)): + yield self.reader.to_data( + dict(features=smiles_l[i], labels=[labels_l[i]], ident=i) + ) + + def _perform_data_preparation(self, *args, **kwargs) -> None: + pass + + +class SolCurationChem(SolCuration): + """Chemical data reader for the solubility dataset.""" + + READER = dr.ChemDataReader + + +class SolESOL(XYBaseDataModule): + HEADERS = [ + "logS", + ] + + @property + def _name(self): + return "SolESOL" + + @property + def label_number(self): + return 1 + + @property + def raw_file_names(self): + return ["solESOL.csv"] + + # @property + # def processed_file_names(self): + # return ["test.pt", "train.pt", "validation.pt"] + + @property + def processed_file_names_dict(self) -> dict: + return { + "test": "test.pt", + "train": "train.pt", + "validation": "validation.pt", + } + + def download(self): + # download + 
with open(os.path.join(self.raw_dir, "solESOL.csv"), "ab") as dst: + with request.urlopen( + "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/delaney-processed.csv", + ) as src: + shutil.copyfileobj(src, dst) + + def setup_processed(self): + print("Create splits") + data = list( + self._load_data_from_file(os.path.join(self.raw_dir, "solESOL.csv")) + ) + print(len(data)) + + train_split, test_split = train_test_split( + data, test_size=self.test_split, shuffle=True + ) + train_split, validation_split = train_test_split( + train_split, test_size=self.validation_split, shuffle=True + ) + + if False: + train_split, test_split = train_test_split( + data, train_size=self.train_split, shuffle=True + ) + test_split, validation_split = train_test_split( + test_split, train_size=0.5, shuffle=True + ) + for k, split in [ + ("test", test_split), + ("train", train_split), + ("validation", validation_split), + ]: + print("transform", k) + torch.save( + split, + os.path.join(self.processed_dir, f"{k}.pt"), + ) + + def setup(self, **kwargs): + if any( + not os.path.isfile(os.path.join(self.raw_dir, f)) + for f in self.raw_file_names + ): + self.download() + print( + [ + not os.path.isfile(os.path.join(self.processed_dir, f)) + for f in self.processed_file_names + ] + ) + if any( + not os.path.isfile(os.path.join(self.processed_dir, f)) + for f in self.processed_file_names + ): + self.setup_processed() + + self._after_setup() + + def _set_processed_data_props(self): + """ + Load processed data and extract metadata. + + Sets: + - self._num_of_labels: Number of target labels in the dataset. + - self._feature_vector_size: Maximum feature vector length across all data points. 
+ """ + pt_file_path = os.path.join( + self.processed_dir, self.processed_file_names_dict["train"] + ) + data_pt = torch.load(pt_file_path, weights_only=False) + + self._num_of_labels = len(data_pt[0]["labels"]) + self._feature_vector_size = max(len(d["features"]) for d in data_pt) + + def _load_dict(self, input_file_path: str) -> List[Dict]: + """Loads data from a CSV file. + + Args: + input_file_path (str): Path to the CSV file. + + Returns: + List[Dict]: List of data dictionaries. + """ + smiles_l = [] + labels_l = [] + with open(input_file_path, "r") as input_file: + reader = csv.DictReader(input_file) + print(reader.fieldnames) + for row in reader: + smiles_l.append(row["smiles"]) + labels_l.append(float(row["measured log solubility in mols per litre"])) + + for i in range(0, len(smiles_l)): + yield dict(features=smiles_l[i], labels=[labels_l[i]], ident=i) + # yield self.reader.to_data(dict(features=smiles_l[i], labels=[labels_l[i]], ident=i)) + + def _perform_data_preparation(self, *args, **kwargs) -> None: + pass + + +class SolESOLChem(SolESOL): + """Chemical data reader for the solubility dataset.""" + + READER = dr.ChemDataReader diff --git a/chebai/preprocessing/datasets/tox21.py b/chebai/preprocessing/datasets/tox21.py index 976c910e..709c620d 100644 --- a/chebai/preprocessing/datasets/tox21.py +++ b/chebai/preprocessing/datasets/tox21.py @@ -44,10 +44,18 @@ def raw_file_names(self) -> List[str]: """Returns a list of raw file names.""" return ["tox21.csv"] + # @property + # def processed_file_names(self) -> List[str]: + # """Returns a list of processed file names.""" + # return ["test.pt", "train.pt", "validation.pt"] + @property - def processed_file_names(self) -> List[str]: - """Returns a list of processed file names.""" - return ["test.pt", "train.pt", "validation.pt"] + def processed_file_names_dict(self) -> dict: + return { + "test": "test.pt", + "train": "train.pt", + "validation": "validation.pt", + } def download(self) -> None: """Downloads and 
extracts the dataset.""" @@ -63,10 +71,13 @@ def download(self) -> None: def setup_processed(self) -> None: """Processes and splits the dataset.""" print("Create splits") - data = self._load_data_from_file(os.path.join(self.raw_dir, "tox21.csv")) - groups = np.array([d["group"] for d in data]) + data = list(self._load_data_from_file(os.path.join(self.raw_dir, "tox21.csv"))) + groups = np.array([d.get("group") for d in data]) + if not all(g is None for g in groups): - split_size = int(len(set(groups)) * self.train_split) + split_size = int( + len(set(groups)) * (1 - self.test_split - self.validation_split) + ) os.makedirs(self.processed_dir, exist_ok=True) splitter = GroupShuffleSplit(train_size=split_size, n_splits=1) @@ -77,7 +88,11 @@ def setup_processed(self) -> None: split_groups = groups[temp_split_index] splitter = GroupShuffleSplit( - train_size=int(len(set(split_groups)) * self.train_split), n_splits=1 + train_size=int( + len(set(split_groups)) + * (1 - self.test_split - self.validation_split) + ), + n_splits=1, ) test_split_index, validation_split_index = next( splitter.split(temp_split_index, groups=split_groups) @@ -86,20 +101,21 @@ def setup_processed(self) -> None: test_split = [ d for d in (data[temp_split_index[i]] for i in test_split_index) - if d["original"] + # if d["original"] ] validation_split = [ d for d in (data[temp_split_index[i]] for i in validation_split_index) - if d["original"] + # if d["original"] ] else: train_split, test_split = train_test_split( - data, train_size=self.train_split, shuffle=True + data, test_size=self.test_split, shuffle=True ) - test_split, validation_split = train_test_split( - test_split, train_size=0.5, shuffle=True + train_split, validation_split = train_test_split( + train_split, test_size=self.validation_split, shuffle=True ) + for k, split in [ ("test", test_split), ("train", train_split), @@ -128,9 +144,10 @@ def setup(self, **kwargs) -> None: ): self.setup_processed() - self._set_processed_data_props() + # 
self._set_processed_data_props() + self._after_setup() - def _load_data_from_file(self, input_file_path: str) -> List[Dict]: + def _load_dict(self, input_file_path: str) -> List[Dict]: """Loads data from a CSV file. Args: @@ -144,10 +161,36 @@ def _load_data_from_file(self, input_file_path: str) -> List[Dict]: for row in reader: smiles = row["smiles"] labels = [ - bool(int(line)) if line else None - for line in (row[k] for k in self.HEADERS) + bool(int(float(label))) if len(label) > 1 else None + for label in (row[k] for k in self.HEADERS) ] - yield dict(features=smiles, labels=labels, ident=row["mol_id"]) + # group = int(row["group"]) + yield dict( + features=smiles, + labels=labels, + ident=row["mol_id"], + # group=group + ) + # yield self.reader.to_data(dict(features=smiles, labels=labels, ident=row["mol_id"])) + + def _set_processed_data_props(self): + """ + Load processed data and extract metadata. + + Sets: + - self._num_of_labels: Number of target labels in the dataset. + - self._feature_vector_size: Maximum feature vector length across all data points. + """ + pt_file_path = os.path.join( + self.processed_dir, self.processed_file_names_dict["train"] + ) + data_pt = torch.load(pt_file_path, weights_only=False) + + self._num_of_labels = len(data_pt[0]["labels"]) + self._feature_vector_size = max(len(d["features"]) for d in data_pt) + + def _perform_data_preparation(self, *args, **kwargs) -> None: + pass class Tox21Challenge(XYBaseDataModule): diff --git a/chebai/preprocessing/reader.py b/chebai/preprocessing/reader.py index 22b91a0e..2b3b1b0e 100644 --- a/chebai/preprocessing/reader.py +++ b/chebai/preprocessing/reader.py @@ -94,13 +94,18 @@ def _read_group(self, raw: Any) -> Any: return raw def _read_components(self, row: Dict[str, Any]) -> Dict[str, Any]: - """Read and return components from the row.""" + """Read and return components from the row. 
If the data contains any missing labels (`None`), they are tracked + under the additional `missing_labels` keyword.""" + labels = self._get_raw_label(row) + additional_kwargs = self._get_additional_kwargs(row) + if any(label is None for label in labels): + additional_kwargs["missing_labels"] = [label is None for label in labels] return dict( features=self._get_raw_data(row), - labels=self._get_raw_label(row), + labels=labels, ident=self._get_raw_id(row), group=self._get_raw_group(row), - additional_kwargs=self._get_additional_kwargs(row), + additional_kwargs=additional_kwargs, ) def to_data(self, row: Dict[str, Any]) -> Dict[str, Any]: @@ -209,6 +214,24 @@ def _read_data(self, raw_data: str) -> List[int]: print(f"\t{e}") return None + def _back_to_smiles(self, smiles_encoded): + + token_file = self.reader.token_path + token_coding = {} + counter = 0 + smiles_decoded = "" + + # todo: for now just copied over from a notebook but ideally do this using the cache + with open(token_file, "r") as file: + for line in file: + token_coding[counter] = line.strip() + counter += 1 + + for token in smiles_encoded: + smiles_decoded += token_coding[token - EMBEDDING_OFFSET] + + return smiles_decoded + class DeepChemDataReader(ChemDataReader): """ diff --git a/chebai/result/classification.py b/chebai/result/classification.py index eff8662c..ab8b1e2d 100644 --- a/chebai/result/classification.py +++ b/chebai/result/classification.py @@ -10,6 +10,11 @@ MultilabelF1Score, MultilabelPrecision, MultilabelRecall, + MultilabelAUROC, + BinaryF1Score, + BinaryAUROC, + BinaryAveragePrecision, + MultilabelAveragePrecision, ) from chebai.callbacks.epoch_metrics import BalancedAccuracy, MacroF1 @@ -59,6 +64,8 @@ def print_metrics( top_k: The number of top classes to display based on F1 score. markdown_output: If True, print metrics in markdown format. 
""" + if device != labels.device: + device = labels.device f1_micro = MultilabelF1Score(preds.shape[1], average="micro").to(device=device) my_f1_macro = MacroF1(preds.shape[1]).to(device=device) my_bal_acc = BalancedAccuracy(preds.shape[1]).to(device=device) @@ -106,3 +113,52 @@ def print_metrics( print( f"Found {len(zeros)} classes with F1-score == 0 (and non-zero labels): {', '.join(zeros)}" ) + + +def metrics_classification_multilabel( + preds: Tensor, + labels: Tensor, + device: torch.device, +): + + if device != labels.device: + device = labels.device + + my_bal_acc = BalancedAccuracy(preds.shape[1]).to(device=device) + + bal_acc = my_bal_acc(preds, labels).cpu().numpy() + my_f1_macro = MultilabelF1Score(preds.shape[1], average="micro").to(device=device) + f1_micro = MacroF1(preds.shape[1]).to(device=device) + my_auc_roc = MultilabelAUROC(preds.shape[1]).to(device=device) + my_av_prec = MultilabelAveragePrecision(preds.shape[1]).to(device=device) + + macro_f1 = my_f1_macro(preds, labels).cpu().numpy() + micro_f1 = f1_micro(preds, labels).cpu().numpy() + auc_roc = my_auc_roc(preds, labels).cpu().numpy() + prc_auc = my_av_prec(preds, labels).cpu().numpy() + + return auc_roc, macro_f1, micro_f1, bal_acc, prc_auc + + +def metrics_classification_binary( + preds: Tensor, + labels: Tensor, + device: torch.device, +): + + if device != labels.device: + device = labels.device + + my_auc_roc = BinaryAUROC() + my_f1 = BinaryF1Score().to(device=device) + my_av_prec = BinaryAveragePrecision().to(device=device) + my_bal_acc = BalancedAccuracy(preds.shape[1]).to(device=device) + + bal_acc = my_bal_acc(preds, labels).cpu().numpy() + auc_roc = my_auc_roc(preds, labels).cpu().numpy() + # my_auc_roc.update(preds.cpu()[:, 0], labels.cpu()[:, 0]) + # auc_roc = my_auc_roc.compute().numpy() + f1_score = my_f1(preds, labels).cpu().numpy() + prc_auc = my_av_prec(preds, labels).cpu().numpy() + + return auc_roc, f1_score, bal_acc, prc_auc diff --git a/chebai/result/molplot.py 
b/chebai/result/molplot.py index 6f8d1e79..2c548a26 100644 --- a/chebai/result/molplot.py +++ b/chebai/result/molplot.py @@ -11,7 +11,8 @@ from networkx.algorithms.isomorphism import GraphMatcher from pysmiles.read_smiles import LOGGER, TokenType, _tokenize from rdkit import Chem -from rdkit.Chem.Draw import MolToMPL, rdMolDraw2D +from rdkit.Chem.Draw import MolToMPL # , rdMolDraw2D +from rdkit.Chem.Draw import rdMolDraw2D from chebai.preprocessing.datasets import JCI_500_COLUMNS_INT from chebai.result.base import ResultProcessor diff --git a/chebai/result/regression.py b/chebai/result/regression.py new file mode 100644 index 00000000..ed660f12 --- /dev/null +++ b/chebai/result/regression.py @@ -0,0 +1,68 @@ +import torch +from torch import Tensor +from torchmetrics.regression import MeanSquaredError + +# from chebai.callbacks.epoch_metrics import BalancedAccuracy, MacroF1 +# from chebai.result.utils import * + +# def visualise_f1(logs_path: str) -> None: +# """ +# Visualize F1 scores from metrics.csv and save the plot as f1_plot.png. + +# Args: +# logs_path: The path to the directory containing metrics.csv. +# """ +# df = pd.read_csv(os.path.join(logs_path, "metrics.csv")) +# df_loss = df.melt( +# id_vars="epoch", +# value_vars=[ +# "val_ep_macro-f1", +# "val_micro-f1", +# "train_micro-f1", +# "train_ep_macro-f1", +# ], +# ) +# lineplt = sns.lineplot(df_loss, x="epoch", y="value", hue="variable") +# plt.savefig(os.path.join(logs_path, "f1_plot.png")) +# plt.show() + + +def metrics_regression( + preds: Tensor, + labels: Tensor, + device: torch.device, + markdown_output: bool = False, +) -> None: + """ + Prints relevant metrics, including micro and macro F1, recall and precision, + best k classes, and worst classes. + + Args: + preds: Predicted labels as a tensor. + labels: True labels as a tensor. + device: The device to perform computations on. + classes: Optional list of class names. + top_k: The number of top classes to display based on F1 score. 
+ markdown_output: If True, print metrics in markdown format. + """ + mse = MeanSquaredError() + mse = mse.to(labels.device) + + rmse = MeanSquaredError(squared=False) + rmse = rmse.to(labels.device) + + return (mse(preds, labels), rmse(preds, labels)) + + # print(f"Micro-F1: {f1_micro(preds, labels):3f}") + # print(f"Balanced Accuracy: {my_bal_acc(preds, labels):3f}") + + # if markdown_output: + # print( + # f"| Model | MSE | RMSE | Macro-Precision | Micro-Precision | Macro-Recall | Micro-Recall | Balanced Accuracy" + # ) + # print(f"| --- | --- | --- | --- | --- | --- | --- | --- |") + # print( + # f"| Elektra | {my_f1_macro(preds, labels):3f} | {f1_micro(preds, labels):3f} | {precision_macro(preds, labels):3f} | " + # f"{precision_micro(preds, labels):3f} | {recall_macro(preds, labels):3f} | " + # f"{recall_micro(preds, labels):3f} | {my_bal_acc(preds, labels):3f} |" + # ) diff --git a/chebai/result/utils.py b/chebai/result/utils.py index 27ff1783..55549297 100644 --- a/chebai/result/utils.py +++ b/chebai/result/utils.py @@ -62,7 +62,7 @@ def _run_batch(batch, model, collate): if collated.y is not None: collated.y = collated.to_y(model.device) processable_data = model._process_batch(collated, 0) - del processable_data["loss_kwargs"] + # del processable_data["loss_kwargs"] model_output = model(processable_data, **processable_data["model_kwargs"]) preds, labels = model._get_prediction_and_labels( processable_data, processable_data["labels"], model_output @@ -70,6 +70,20 @@ def _run_batch(batch, model, collate): return preds, labels +def _run_batch_give_attention(batch, model, collate): + collated = collate(batch) + collated.x = collated.to_x(model.device) + if collated.y is not None: + collated.y = collated.to_y(model.device) + processable_data = model._process_batch(collated, 0) + # del processable_data["loss_kwargs"] + model_output = model(processable_data, **processable_data["model_kwargs"]) + preds, labels = model._get_prediction_and_labels( + 
processable_data, processable_data["labels"], model_output + ) + return preds, labels, model_output + + def _concat_tuple(l_): if isinstance(l_[0], tuple): print(l_[0]) @@ -87,7 +101,7 @@ def evaluate_model( kind: str = "test", ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """ - Runs the model on the test set of the data module or on the dataset found in the specified file. + Runs a classification model on the test set of the data module or on the dataset found in the specified file. If buffer_dir is set, results will be saved in buffer_dir. Note: @@ -105,6 +119,7 @@ def evaluate_model( Returns: Tensors with predictions and labels. """ + assert model.model_type == "classification" model.eval() collate = data_module.reader.COLLATOR() @@ -158,16 +173,422 @@ def evaluate_model( return test_preds, test_labels return test_preds, None elif len(preds_list) > 0: - if len(preds_list) > 0 and preds_list[0] is not None: + if preds_list[0] is not None: torch.save( _concat_tuple(preds_list), os.path.join(buffer_dir, f"preds{save_ind:03d}.pt"), ) - if len(labels_list) > 0 and labels_list[0] is not None: + if labels_list[0] is not None: + torch.save( + _concat_tuple(labels_list), + os.path.join(buffer_dir, f"labels{save_ind:03d}.pt"), + ) + return torch.cat(preds_list), torch.cat(labels_list) + + +def evaluate_model_regression( + model: ChebaiBaseNet, + data_module: XYBaseDataModule, + filename: Optional[str] = None, + buffer_dir: Optional[str] = None, + batch_size: int = 32, + skip_existing_preds: bool = False, + kind: str = "test", +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Runs a regression model on the test set of the data module or on the dataset found in the specified file. + If buffer_dir is set, results will be saved in buffer_dir. + + Note: + No need to provide "filename" parameter for Chebi dataset, "kind" parameter should be provided. + + Args: + model: The model to evaluate. + data_module: The data module containing the dataset. 
+ filename: Optional file name for the dataset. + buffer_dir: Optional directory to save the results. + batch_size: The batch size for evaluation. + skip_existing_preds: Whether to skip evaluation if predictions already exist. + kind: Kind of split of the data to be used for testing the model. Default is `test`. + + Returns: + Tensors with predictions and labels. + """ + model.eval() + collate = data_module.reader.COLLATOR() + + if isinstance(data_module, _ChEBIDataExtractor): + # As the dynamic split change is implemented only for chebi-dataset as of now + data_df = data_module.dynamic_split_dfs[kind] + data_list = data_df.to_dict(orient="records") + else: + data_list = data_module.load_processed_data("test", filename) + data_list = data_list[: data_module.data_limit] + preds_list = [] + labels_list = [] + preds_list_all = [] + labels_list_all = [] + if buffer_dir is not None: + os.makedirs(buffer_dir, exist_ok=True) + save_ind = 0 + save_batch_size = 128 + n_saved = 1 + + print("") + for i in tqdm.tqdm(range(0, len(data_list), batch_size)): + if not ( + skip_existing_preds + and os.path.isfile(os.path.join(buffer_dir, f"preds{save_ind:03d}.pt")) + ): + preds, labels = _run_batch(data_list[i : i + batch_size], model, collate) + preds_list.append(preds) + labels_list.append(labels) + preds_list_all.append(preds) + labels_list_all.append(labels) + if buffer_dir is not None: + if n_saved * batch_size >= save_batch_size: + torch.save( + _concat_tuple(preds_list), + os.path.join(buffer_dir, f"preds{save_ind:03d}.pt"), + ) + if labels_list[0] is not None: + torch.save( + _concat_tuple(labels_list), + os.path.join(buffer_dir, f"labels{save_ind:03d}.pt"), + ) + preds_list = [] + labels_list = [] + if n_saved * batch_size >= save_batch_size: + save_ind += 1 + n_saved = 0 + n_saved += 1 + + if buffer_dir is None: + test_preds = _concat_tuple(preds_list) + if labels_list is not None: + test_labels = _concat_tuple(labels_list) + + return test_preds, test_labels + return 
test_preds, None + else: + torch.save( + _concat_tuple(preds_list), + os.path.join(buffer_dir, f"preds{save_ind:03d}.pt"), + ) + if labels_list[0] is not None: + torch.save( + _concat_tuple(labels_list), + os.path.join(buffer_dir, f"labels{save_ind:03d}.pt"), + ) + return torch.cat(preds_list_all), torch.cat(labels_list_all) + + +def evaluate_model_regression_attention( + model: ChebaiBaseNet, + data_module: XYBaseDataModule, + filename: Optional[str] = None, + buffer_dir: Optional[str] = None, + batch_size: int = 32, + skip_existing_preds: bool = False, + kind: str = "test", +) -> Tuple[torch.Tensor, Optional[torch.Tensor], list, list]: + """ + Runs the model on the test set of the data module or on the dataset found in the specified file. + If buffer_dir is set, results will be saved in buffer_dir. + + Note: + No need to provide "filename" parameter for Chebi dataset, "kind" parameter should be provided. + + Args: + model: The model to evaluate. + data_module: The data module containing the dataset. + filename: Optional file name for the dataset. + buffer_dir: Optional directory to save the results. + batch_size: The batch size for evaluation. + skip_existing_preds: Whether to skip evaluation if predictions already exist. + kind: Kind of split of the data to be used for testing the model. Default is `test`. + + Returns: + Tensors with predictions and labels. 
+ """ + model.eval() + collate = data_module.reader.COLLATOR() + + if isinstance(data_module, _ChEBIDataExtractor): + # As the dynamic split change is implemented only for chebi-dataset as of now + data_df = data_module.dynamic_split_dfs[kind] + data_list = data_df.to_dict(orient="records") + else: + data_list = data_module.load_processed_data("test", filename) + data_list = data_list[: data_module.data_limit] + preds_list = [] + labels_list = [] + preds_list_all = [] + labels_list_all = [] + features_list_all = [] + attention_list_all = [] + if buffer_dir is not None: + os.makedirs(buffer_dir, exist_ok=True) + save_ind = 0 + save_batch_size = 128 + n_saved = 1 + + print("") + for i in tqdm.tqdm(range(0, len(data_list), batch_size)): + if not ( + skip_existing_preds + and os.path.isfile(os.path.join(buffer_dir, f"preds{save_ind:03d}.pt")) + ): + preds, labels, model_output = _run_batch_give_attention( + data_list[i : i + batch_size], model, collate + ) + preds_list.append(preds) + labels_list.append(labels) + preds_list_all.append(preds) + labels_list_all.append(labels) + attention_list_all.append(model_output) + features_list_all.append(data_list[i : i + batch_size]) + if buffer_dir is not None: + if n_saved * batch_size >= save_batch_size: + torch.save( + _concat_tuple(preds_list), + os.path.join(buffer_dir, f"preds{save_ind:03d}.pt"), + ) + if labels_list[0] is not None: + torch.save( + _concat_tuple(labels_list), + os.path.join(buffer_dir, f"labels{save_ind:03d}.pt"), + ) + preds_list = [] + labels_list = [] + if n_saved * batch_size >= save_batch_size: + save_ind += 1 + n_saved = 0 + n_saved += 1 + + if buffer_dir is None: + test_preds = _concat_tuple(preds_list) + if labels_list is not None: + test_labels = _concat_tuple(labels_list) + + return test_preds, test_labels, features_list_all, attention_list_all + return test_preds, None + else: + torch.save( + _concat_tuple(preds_list), + os.path.join(buffer_dir, f"preds{save_ind:03d}.pt"), + ) + if labels_list[0] 
is not None: torch.save( _concat_tuple(labels_list), os.path.join(buffer_dir, f"labels{save_ind:03d}.pt"), ) + return torch.cat(preds_list_all), torch.cat(labels_list_all) + + +# def evaluate_model_regression( +# model: ChebaiBaseNet, +# data_module: XYBaseDataModule, +# filename: Optional[str] = None, +# buffer_dir: Optional[str] = None, +# batch_size: int = 32, +# skip_existing_preds: bool = False, +# kind: str = "test", +# ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: +# """ +# Runs the model on the test set of the data module or on the dataset found in the specified file. +# If buffer_dir is set, results will be saved in buffer_dir. + +# Note: +# No need to provide "filename" parameter for Chebi dataset, "kind" parameter should be provided. + +# Args: +# model: The model to evaluate. +# data_module: The data module containing the dataset. +# filename: Optional file name for the dataset. +# buffer_dir: Optional directory to save the results. +# batch_size: The batch size for evaluation. +# skip_existing_preds: Whether to skip evaluation if predictions already exist. +# kind: Kind of split of the data to be used for testing the model. Default is `test`. + +# Returns: +# Tensors with predictions and labels. 
+# """ +# model.eval() +# collate = data_module.reader.COLLATOR() + +# if isinstance(data_module, _ChEBIDataExtractor): +# # As the dynamic split change is implemented only for chebi-dataset as of now +# data_df = data_module.dynamic_split_dfs[kind] +# data_list = data_df.to_dict(orient="records") +# else: +# data_list = data_module.load_processed_data("test", filename) +# data_list = data_list[: data_module.data_limit] +# preds_list = [] +# labels_list = [] +# preds_list_all = [] +# labels_list_all = [] +# if buffer_dir is not None: +# os.makedirs(buffer_dir, exist_ok=True) +# save_ind = 0 +# save_batch_size = 128 +# n_saved = 1 + +# print("") +# for i in tqdm.tqdm(range(0, len(data_list), batch_size)): +# if not ( +# skip_existing_preds +# and os.path.isfile(os.path.join(buffer_dir, f"preds{save_ind:03d}.pt")) +# ): +# preds, labels = _run_batch(data_list[i : i + batch_size], model, collate) +# preds_list.append(preds) +# labels_list.append(labels) +# preds_list_all.append(preds) +# labels_list_all.append(labels) +# if buffer_dir is not None: +# if n_saved * batch_size >= save_batch_size: +# torch.save( +# _concat_tuple(preds_list), +# os.path.join(buffer_dir, f"preds{save_ind:03d}.pt"), +# ) +# if labels_list[0] is not None: +# torch.save( +# _concat_tuple(labels_list), +# os.path.join(buffer_dir, f"labels{save_ind:03d}.pt"), +# ) +# preds_list = [] +# labels_list = [] +# if n_saved * batch_size >= save_batch_size: +# save_ind += 1 +# n_saved = 0 +# n_saved += 1 + +# if buffer_dir is None: +# test_preds = _concat_tuple(preds_list) +# if labels_list is not None: +# test_labels = _concat_tuple(labels_list) + +# return test_preds, test_labels +# return test_preds, None +# else: +# torch.save( +# _concat_tuple(preds_list), +# os.path.join(buffer_dir, f"preds{save_ind:03d}.pt"), +# ) +# if labels_list[0] is not None: +# torch.save( +# _concat_tuple(labels_list), +# os.path.join(buffer_dir, f"labels{save_ind:03d}.pt"), +# ) +# return torch.cat(preds_list_all), 
torch.cat(labels_list_all) + + +# def evaluate_model_regression_attention( +# model: ChebaiBaseNet, +# data_module: XYBaseDataModule, +# filename: Optional[str] = None, +# buffer_dir: Optional[str] = None, +# batch_size: int = 32, +# skip_existing_preds: bool = False, +# kind: str = "test", +# ) -> Tuple[torch.Tensor, Optional[torch.Tensor], list, list]: +# """ +# Runs the model on the test set of the data module or on the dataset found in the specified file. +# If buffer_dir is set, results will be saved in buffer_dir. + +# Note: +# No need to provide "filename" parameter for Chebi dataset, "kind" parameter should be provided. + +# Args: +# model: The model to evaluate. +# data_module: The data module containing the dataset. +# filename: Optional file name for the dataset. +# buffer_dir: Optional directory to save the results. +# batch_size: The batch size for evaluation. +# skip_existing_preds: Whether to skip evaluation if predictions already exist. +# kind: Kind of split of the data to be used for testing the model. Default is `test`. + +# Returns: +# Tensors with predictions and labels. 
+# """ +# model.eval() +# collate = data_module.reader.COLLATOR() + +# if isinstance(data_module, _ChEBIDataExtractor): +# # As the dynamic split change is implemented only for chebi-dataset as of now +# data_df = data_module.dynamic_split_dfs[kind] +# data_list = data_df.to_dict(orient="records") +# else: +# data_list = data_module.load_processed_data("test", filename) +# data_list = data_list[: data_module.data_limit] +# preds_list = [] +# labels_list = [] +# preds_list_all = [] +# labels_list_all = [] +# features_list_all = [] +# attention_list_all = [] +# if buffer_dir is not None: +# os.makedirs(buffer_dir, exist_ok=True) +# save_ind = 0 +# save_batch_size = 128 +# n_saved = 1 + +# print("") +# for i in tqdm.tqdm(range(0, len(data_list), batch_size)): +# if not ( +# skip_existing_preds +# and os.path.isfile(os.path.join(buffer_dir, f"preds{save_ind:03d}.pt")) +# ): +# preds, labels, model_output = _run_batch_give_attention( +# data_list[i : i + batch_size], model, collate +# ) +# preds_list.append(preds) +# labels_list.append(labels) +# preds_list_all.append(preds) +# labels_list_all.append(labels) +# attention_list_all.append(model_output) +# features_list_all.append(data_list[i : i + batch_size]) +# if buffer_dir is not None: +# if n_saved * batch_size >= save_batch_size: +# torch.save( +# _concat_tuple(preds_list), +# os.path.join(buffer_dir, f"preds{save_ind:03d}.pt"), +# ) +# if labels_list[0] is not None: +# torch.save( +# _concat_tuple(labels_list), +# os.path.join(buffer_dir, f"labels{save_ind:03d}.pt"), +# ) +# preds_list = [] +# labels_list = [] +# if n_saved * batch_size >= save_batch_size: +# save_ind += 1 +# n_saved = 0 +# n_saved += 1 + +# if buffer_dir is None: +# test_preds = _concat_tuple(preds_list) +# if labels_list is not None: +# test_labels = _concat_tuple(labels_list) + +# return test_preds, test_labels, features_list_all, attention_list_all +# return test_preds, None +# else: +# torch.save( +# _concat_tuple(preds_list), +# 
os.path.join(buffer_dir, f"preds{save_ind:03d}.pt"), +# ) +# if labels_list[0] is not None: +# torch.save( +# _concat_tuple(labels_list), +# os.path.join(buffer_dir, f"labels{save_ind:03d}.pt"), +# ) +# return ( +# torch.cat(preds_list_all), +# torch.cat(labels_list_all), +# features_list_all, +# attention_list_all, +# ) def load_results_from_buffer( diff --git a/chebai/train.py b/chebai/train.py index bab5f089..883f2263 100644 --- a/chebai/train.py +++ b/chebai/train.py @@ -127,8 +127,9 @@ def _execute( Returns: - train_running_loss (float): Average loss over the data. - - f1 (float): Average F1 score over the data. + - f1 (float): Average F1 score over the data. -> so this is for classification tasks only? """ + train_running_loss = 0.0 data_size = 0 f1 = 0 diff --git a/chebai/trainer/CustomTrainer.py b/chebai/trainer/CustomTrainer.py index f7fbce26..e93cff85 100644 --- a/chebai/trainer/CustomTrainer.py +++ b/chebai/trainer/CustomTrainer.py @@ -130,9 +130,11 @@ def _predict_smiles( ) features = torch.cat((cls_tokens, x), dim=1) model_output = model({"features": features}) - preds = torch.sigmoid(model_output["logits"]) + if model.model_type == "regression": + preds = model_output["logits"] + else: + preds = torch.sigmoid(model_output["logits"]) - print(preds.shape) return preds @property diff --git a/configs/data/moleculenet/bace_moleculenet.yml b/configs/data/moleculenet/bace_moleculenet.yml new file mode 100644 index 00000000..e5d4bdb7 --- /dev/null +++ b/configs/data/moleculenet/bace_moleculenet.yml @@ -0,0 +1,5 @@ +class_path: chebai.preprocessing.datasets.molecule_classification.BaceChem +init_args: + batch_size: 32 + validation_split: 0.05 + test_split: 0.15 \ No newline at end of file diff --git a/configs/data/moleculenet/bbbp_moleculenet.yml b/configs/data/moleculenet/bbbp_moleculenet.yml new file mode 100644 index 00000000..01479443 --- /dev/null +++ b/configs/data/moleculenet/bbbp_moleculenet.yml @@ -0,0 +1,5 @@ +class_path: 
chebai.preprocessing.datasets.molecule_classification.BBBPChem +init_args: + batch_size: 32 + validation_split: 0.05 + test_split: 0.15 diff --git a/configs/data/moleculenet/clintox_moleculenet.yml b/configs/data/moleculenet/clintox_moleculenet.yml new file mode 100644 index 00000000..d7b7c3be --- /dev/null +++ b/configs/data/moleculenet/clintox_moleculenet.yml @@ -0,0 +1,5 @@ +class_path: chebai.preprocessing.datasets.molecule_classification.ClinToxChem +init_args: + batch_size: 32 + validation_split: 0.05 + test_split: 0.15 diff --git a/configs/data/moleculenet/freesolv_moleculenet.yml b/configs/data/moleculenet/freesolv_moleculenet.yml new file mode 100644 index 00000000..0378a0c0 --- /dev/null +++ b/configs/data/moleculenet/freesolv_moleculenet.yml @@ -0,0 +1,5 @@ +class_path: chebai.preprocessing.datasets.molecule_regression.FreeSolvChem +init_args: + batch_size: 32 + validation_split: 0.05 + test_split: 0.15 diff --git a/configs/data/moleculenet/hiv_moleculenet.yml b/configs/data/moleculenet/hiv_moleculenet.yml new file mode 100644 index 00000000..70c74434 --- /dev/null +++ b/configs/data/moleculenet/hiv_moleculenet.yml @@ -0,0 +1,5 @@ +class_path: chebai.preprocessing.datasets.molecule_classification.HIVChem +init_args: + batch_size: 32 + validation_split: 0.05 + test_split: 0.15 \ No newline at end of file diff --git a/configs/data/moleculenet/lipo_moleculenet.yml b/configs/data/moleculenet/lipo_moleculenet.yml new file mode 100644 index 00000000..c246db5b --- /dev/null +++ b/configs/data/moleculenet/lipo_moleculenet.yml @@ -0,0 +1,5 @@ +class_path: chebai.preprocessing.datasets.molecule_regression.LipoChem +init_args: + batch_size: 32 + validation_split: 0.05 + test_split: 0.15 \ No newline at end of file diff --git a/configs/data/moleculenet/muv_moleculenet.yml b/configs/data/moleculenet/muv_moleculenet.yml new file mode 100644 index 00000000..f4eba3e1 --- /dev/null +++ b/configs/data/moleculenet/muv_moleculenet.yml @@ -0,0 +1,5 @@ +class_path: 
chebai.preprocessing.datasets.molecule_classification.MUVChem +init_args: + batch_size: 32 + validation_split: 0.05 + test_split: 0.15 \ No newline at end of file diff --git a/configs/data/pubchem_kmeans.yml b/configs/data/moleculenet/pubchem_kmeans.yml similarity index 100% rename from configs/data/pubchem_kmeans.yml rename to configs/data/moleculenet/pubchem_kmeans.yml diff --git a/configs/data/moleculenet/sider_moleculenet.yml b/configs/data/moleculenet/sider_moleculenet.yml new file mode 100644 index 00000000..a1d635c5 --- /dev/null +++ b/configs/data/moleculenet/sider_moleculenet.yml @@ -0,0 +1,5 @@ +class_path: chebai.preprocessing.datasets.molecule_classification.SiderChem +init_args: + batch_size: 10 + validation_split: 0.05 + test_split: 0.15 \ No newline at end of file diff --git a/configs/data/moleculenet/solubilityCuration.yml b/configs/data/moleculenet/solubilityCuration.yml new file mode 100644 index 00000000..89145905 --- /dev/null +++ b/configs/data/moleculenet/solubilityCuration.yml @@ -0,0 +1,5 @@ +class_path: chebai.preprocessing.datasets.solCuration.SolCurationChem +init_args: + batch_size: 32 + validation_split: 0.05 + test_split: 0.15 diff --git a/configs/data/moleculenet/solubilityESOL.yml b/configs/data/moleculenet/solubilityESOL.yml new file mode 100644 index 00000000..a58c4ba0 --- /dev/null +++ b/configs/data/moleculenet/solubilityESOL.yml @@ -0,0 +1,5 @@ +class_path: chebai.preprocessing.datasets.solCuration.SolESOLChem +init_args: + batch_size: 32 + validation_split: 0.05 + test_split: 0.15 diff --git a/configs/data/tox21/tox21_moleculenet.yml b/configs/data/tox21/tox21_moleculenet.yml index 5579a829..1e8af70f 100644 --- a/configs/data/tox21/tox21_moleculenet.yml +++ b/configs/data/tox21/tox21_moleculenet.yml @@ -1,3 +1,5 @@ class_path: chebai.preprocessing.datasets.tox21.Tox21MolNetChem init_args: - batch_size: 10 + batch_size: 32 + validation_split: 0.05 + test_split: 0.15 diff --git a/configs/loss/bce.yml b/configs/loss/bce.yml index 
e2fc30b8..10135513 100644 --- a/configs/loss/bce.yml +++ b/configs/loss/bce.yml @@ -1 +1,3 @@ class_path: chebai.loss.bce_weighted.BCEWeighted +init_args: + beta: 1000 \ No newline at end of file diff --git a/configs/loss/bce_new.yml b/configs/loss/bce_new.yml new file mode 100644 index 00000000..f8fbe98d --- /dev/null +++ b/configs/loss/bce_new.yml @@ -0,0 +1 @@ +class_path: torch.nn.BCEWithLogitsLoss \ No newline at end of file diff --git a/configs/loss/bce_try.yml b/configs/loss/bce_try.yml new file mode 100644 index 00000000..ff8f9d4e --- /dev/null +++ b/configs/loss/bce_try.yml @@ -0,0 +1 @@ +class_path: torch.nn.BCELoss \ No newline at end of file diff --git a/configs/loss/focal_loss_12.yml b/configs/loss/focal_loss_12.yml new file mode 100644 index 00000000..0351a942 --- /dev/null +++ b/configs/loss/focal_loss_12.yml @@ -0,0 +1,4 @@ +class_path: chebai.loss.focal_loss.FocalLoss +init_args: + task_type: multi-label + num_classes: 12 \ No newline at end of file diff --git a/configs/loss/mae.yml b/configs/loss/mae.yml new file mode 100644 index 00000000..75e011be --- /dev/null +++ b/configs/loss/mae.yml @@ -0,0 +1 @@ +class_path: torch.nn.L1Loss \ No newline at end of file diff --git a/configs/loss/mse.yml b/configs/loss/mse.yml new file mode 100644 index 00000000..16fab1c8 --- /dev/null +++ b/configs/loss/mse.yml @@ -0,0 +1 @@ +class_path: torch.nn.MSELoss \ No newline at end of file diff --git a/configs/metrics/mae.yml b/configs/metrics/mae.yml new file mode 100644 index 00000000..323e5fb4 --- /dev/null +++ b/configs/metrics/mae.yml @@ -0,0 +1,5 @@ +class_path: torchmetrics.MetricCollection +init_args: + metrics: + mae: + class_path: torchmetrics.regression.MeanAbsoluteError \ No newline at end of file diff --git a/configs/metrics/micro-macro-f1-roc-auc-17.yml b/configs/metrics/micro-macro-f1-roc-auc-17.yml new file mode 100644 index 00000000..a730c129 --- /dev/null +++ b/configs/metrics/micro-macro-f1-roc-auc-17.yml @@ -0,0 +1,13 @@ +class_path: 
torchmetrics.MetricCollection +init_args: + metrics: + micro-f1: + class_path: torchmetrics.classification.MultilabelF1Score + init_args: + average: micro + macro-f1: + class_path: chebai.callbacks.epoch_metrics.MacroF1 + roc-auc: + class_path: torchmetrics.classification.MultilabelAUROC + init_args: + num_labels: 17 diff --git a/configs/metrics/micro-macro-f1-roc-auc-17_test.yml b/configs/metrics/micro-macro-f1-roc-auc-17_test.yml new file mode 100644 index 00000000..0a42fb0e --- /dev/null +++ b/configs/metrics/micro-macro-f1-roc-auc-17_test.yml @@ -0,0 +1,22 @@ +class_path: torchmetrics.MetricCollection +init_args: + metrics: + micro-f1: + class_path: torchmetrics.classification.MultilabelF1Score + init_args: + average: micro + num_labels: 17 + macro-f1: + class_path: chebai.callbacks.epoch_metrics.MacroF1 + roc-auc: + class_path: torchmetrics.classification.MultilabelAUROC + init_args: + num_labels: 17 + precision: + class_path: torchmetrics.classification.MultilabelPrecision + init_args: + num_labels: 17 + recall: + class_path: torchmetrics.classification.MultilabelRecall + init_args: + num_labels: 17 \ No newline at end of file diff --git a/configs/metrics/micro-macro-f1-roc-auc-2.yml b/configs/metrics/micro-macro-f1-roc-auc-2.yml new file mode 100644 index 00000000..d69bf123 --- /dev/null +++ b/configs/metrics/micro-macro-f1-roc-auc-2.yml @@ -0,0 +1,13 @@ +class_path: torchmetrics.MetricCollection +init_args: + metrics: + micro-f1: + class_path: torchmetrics.classification.MultilabelF1Score + init_args: + average: micro + macro-f1: + class_path: chebai.callbacks.epoch_metrics.MacroF1 + roc-auc: + class_path: torchmetrics.classification.MultilabelAUROC + init_args: + num_labels: 2 diff --git a/configs/metrics/micro-macro-f1-roc-auc-27.yml b/configs/metrics/micro-macro-f1-roc-auc-27.yml new file mode 100644 index 00000000..81b2b091 --- /dev/null +++ b/configs/metrics/micro-macro-f1-roc-auc-27.yml @@ -0,0 +1,13 @@ +class_path: torchmetrics.MetricCollection 
+init_args: + metrics: + micro-f1: + class_path: torchmetrics.classification.MultilabelF1Score + init_args: + average: micro + macro-f1: + class_path: chebai.callbacks.epoch_metrics.MacroF1 + roc-auc: + class_path: torchmetrics.classification.MultilabelAUROC + init_args: + num_labels: 27 diff --git a/configs/metrics/micro-macro-f1-roc-auc-binary.yml b/configs/metrics/micro-macro-f1-roc-auc-binary.yml new file mode 100644 index 00000000..05834343 --- /dev/null +++ b/configs/metrics/micro-macro-f1-roc-auc-binary.yml @@ -0,0 +1,7 @@ +class_path: torchmetrics.MetricCollection +init_args: + metrics: + f1: + class_path: torchmetrics.classification.BinaryF1Score + roc-auc: + class_path: torchmetrics.classification.BinaryAUROC diff --git a/configs/metrics/micro-macro-f1-roc-auc.yml b/configs/metrics/micro-macro-f1-roc-auc.yml new file mode 100644 index 00000000..18ddfff1 --- /dev/null +++ b/configs/metrics/micro-macro-f1-roc-auc.yml @@ -0,0 +1,13 @@ +class_path: torchmetrics.MetricCollection +init_args: + metrics: + micro-f1: + class_path: torchmetrics.classification.MultilabelF1Score + init_args: + average: micro + macro-f1: + class_path: chebai.callbacks.epoch_metrics.MacroF1 + roc-auc: + class_path: torchmetrics.classification.MultilabelAUROC + init_args: + num_labels: 12 diff --git a/configs/metrics/mse-rmse-r2.yml b/configs/metrics/mse-rmse-r2.yml new file mode 100644 index 00000000..ad7bb53f --- /dev/null +++ b/configs/metrics/mse-rmse-r2.yml @@ -0,0 +1,11 @@ +class_path: torchmetrics.MetricCollection +init_args: + metrics: + mse: + class_path: torchmetrics.regression.MeanSquaredError + rmse: + class_path: torchmetrics.regression.MeanSquaredError + init_args: + squared: False + r2: + class_path: torchmetrics.regression.R2Score \ No newline at end of file diff --git a/configs/metrics/mse.yml b/configs/metrics/mse.yml new file mode 100644 index 00000000..1914442e --- /dev/null +++ b/configs/metrics/mse.yml @@ -0,0 +1,5 @@ +class_path: torchmetrics.MetricCollection
+init_args: + metrics: + mse: + class_path: torchmetrics.regression.MeanSquaredError \ No newline at end of file diff --git a/configs/model/OPT_experiments/electra_LR.yml b/configs/model/OPT_experiments/electra_LR.yml new file mode 100644 index 00000000..5e12a0ae --- /dev/null +++ b/configs/model/OPT_experiments/electra_LR.yml @@ -0,0 +1,12 @@ +class_path: chebai.models.Electra +init_args: + model_type: classification + optimizer_kwargs: + lr: 1e-5 + config: + vocab_size: 1400 + max_position_embeddings: 1800 + num_attention_heads: 8 + num_hidden_layers: 6 + type_vocab_size: 1 + hidden_size: 256 diff --git a/configs/model/OPT_experiments/electra_tox_expl.yml b/configs/model/OPT_experiments/electra_tox_expl.yml new file mode 100644 index 00000000..e17ad570 --- /dev/null +++ b/configs/model/OPT_experiments/electra_tox_expl.yml @@ -0,0 +1,15 @@ +class_path: chebai.models.Electra +init_args: + model_type: classification + optimizer_kwargs: + lr: 1e-4 + weight_decay: 0.0001 + config: + vocab_size: 1400 + max_position_embeddings: 1800 + num_attention_heads: 8 + num_hidden_layers: 6 + type_vocab_size: 1 + hidden_size: 256 + hidden_dropout_prob: 0.4 + word_dropout: 0.2 diff --git a/configs/model/OPT_experiments/electra_tox_paper.yml b/configs/model/OPT_experiments/electra_tox_paper.yml new file mode 100644 index 00000000..9f6797c8 --- /dev/null +++ b/configs/model/OPT_experiments/electra_tox_paper.yml @@ -0,0 +1,15 @@ +class_path: chebai.models.Electra +init_args: + model_type: classification + optimizer_kwargs: + lr: 1e-4 + # weight_decay: 0.0001 + config: + vocab_size: 1400 + max_position_embeddings: 1800 + num_attention_heads: 8 + num_hidden_layers: 6 + type_vocab_size: 1 + hidden_size: 256 + hidden_dropout_prob: 0.4 + word_dropout: 0.2 diff --git a/configs/model/OPT_experiments/electra_tox_paper_regression.yml b/configs/model/OPT_experiments/electra_tox_paper_regression.yml new file mode 100644 index 00000000..640c7ba0 --- /dev/null +++ 
b/configs/model/OPT_experiments/electra_tox_paper_regression.yml @@ -0,0 +1,15 @@ +class_path: chebai.models.Electra +init_args: + model_type: regression + optimizer_kwargs: + lr: 1e-4 + # weight_decay: 0.0001 + config: + vocab_size: 1400 + max_position_embeddings: 1800 + num_attention_heads: 8 + num_hidden_layers: 6 + type_vocab_size: 1 + hidden_size: 256 + hidden_dropout_prob: 0.4 + word_dropout: 0.2 diff --git a/configs/model/electra.yml b/configs/model/electra.yml index 663a8fa1..053f5d65 100644 --- a/configs/model/electra.yml +++ b/configs/model/electra.yml @@ -1,11 +1,12 @@ class_path: chebai.models.Electra init_args: + model_type: regression optimizer_kwargs: - lr: 1e-3 + lr: 1e-4 config: vocab_size: 4400 max_position_embeddings: 1800 num_attention_heads: 8 num_hidden_layers: 6 type_vocab_size: 1 - hidden_size: 256 + hidden_size: 256 \ No newline at end of file diff --git a/configs/model/electra_tox.yml b/configs/model/electra_tox.yml new file mode 100644 index 00000000..fbba5993 --- /dev/null +++ b/configs/model/electra_tox.yml @@ -0,0 +1,13 @@ +class_path: chebai.models.Electra +init_args: + model_type: classification + optimizer_kwargs: + lr: 1e-4 + config: + vocab_size: 1400 + max_position_embeddings: 1800 + num_attention_heads: 8 + num_hidden_layers: 6 + type_vocab_size: 1 + hidden_size: 256 + out_dim: 12 diff --git a/configs/training/binary_callbacks.yml b/configs/training/binary_callbacks.yml new file mode 100644 index 00000000..013b8c77 --- /dev/null +++ b/configs/training/binary_callbacks.yml @@ -0,0 +1,43 @@ +- class_path: chebai.callbacks.model_checkpoint.CustomModelCheckpoint + init_args: + monitor: val_f1 + mode: 'max' + filename: 'best_f1_{epoch:02d}_{val_loss:.4f}_{val_f1:.4f}_{val_roc-auc:.4f}' + every_n_epochs: 1 + save_top_k: 1 +# - class_path: chebai.callbacks.model_checkpoint.CustomModelCheckpoint +# init_args: +# monitor: val_loss +# mode: 'min' +# filename: 'best_loss_{epoch:02d}_{val_loss:.4f}_{val_f1:.4f}_{val_roc-auc:.4f}' +# 
every_n_epochs: 1 +# save_top_k: 1 +- class_path: chebai.callbacks.model_checkpoint.CustomModelCheckpoint + init_args: + monitor: val_roc-auc + mode: 'max' + filename: 'best_roc-auc_{epoch:02d}_{val_loss:.4f}_{val_f1:.4f}_{val_roc-auc:.4f}' + every_n_epochs: 1 + save_top_k: 1 +- class_path: chebai.callbacks.model_checkpoint.CustomModelCheckpoint + init_args: + filename: 'per_{epoch:02d}_{val_loss:.4f}_{val_f1:.4f}_{val_roc-auc:.4f}' + every_n_epochs: 25 + save_top_k: 1 + +# - class_path: lightning.pytorch.callbacks.early_stopping.EarlyStopping +# init_args: +# monitor: "val_roc-auc" +# min_delta: 0.0 +# patience: 5 +# verbose: False +# mode: "max" + + +# - class_path: lightning.pytorch.callbacks.early_stopping.EarlyStopping +# init_args: +# monitor: "val_loss_epoch" +# min_delta: 0.0 +# patience: 10 +# verbose: False +# mode: "min" diff --git a/configs/training/binary_trainer.yml b/configs/training/binary_trainer.yml new file mode 100644 index 00000000..5787a67c --- /dev/null +++ b/configs/training/binary_trainer.yml @@ -0,0 +1,5 @@ +min_epochs: 20 +max_epochs: 100 +default_root_dir: &default_root_dir logs +logger: csv_logger.yml +callbacks: binary_callbacks.yml diff --git a/configs/training/default_callbacks.yml b/configs/training/default_callbacks.yml index ade7d149..ee76e0d5 100644 --- a/configs/training/default_callbacks.yml +++ b/configs/training/default_callbacks.yml @@ -2,11 +2,32 @@ init_args: monitor: val_micro-f1 mode: 'max' - filename: 'best_{epoch:02d}_{val_loss:.4f}_{val_macro-f1:.4f}_{val_micro-f1:.4f}' + filename: 'best_micro_f1_{epoch:02d}_{val_loss:.4f}_{val_macro-f1:.4f}_{val_micro-f1:.4f}_{val_roc-auc:.4f}' every_n_epochs: 1 - save_top_k: 3 + save_top_k: 1 +# - class_path: chebai.callbacks.model_checkpoint.CustomModelCheckpoint +# init_args: +# monitor: val_loss +# mode: 'min' +# filename: 'best_loss_{epoch:02d}_{val_loss:.4f}_{val_macro-f1:.4f}_{val_micro-f1:.4f}_{val_roc-auc:.4f}' +# every_n_epochs: 1 +# save_top_k: 1 - class_path: 
chebai.callbacks.model_checkpoint.CustomModelCheckpoint init_args: - filename: 'per_{epoch:02d}_{val_loss:.4f}_{val_macro-f1:.4f}_{val_micro-f1:.4f}' + monitor: val_roc-auc + mode: 'max' + filename: 'best_roc-auc_{epoch:02d}_{val_loss:.4f}_{val_macro-f1:.4f}_{val_micro-f1:.4f}_{val_roc-auc:.4f}' + every_n_epochs: 1 + save_top_k: 1 +# - class_path: lightning.pytorch.callbacks.early_stopping.EarlyStopping +# init_args: +# monitor: "val_roc-auc" +# min_delta: 0.0 +# patience: 5 +# verbose: False +# mode: "max" +- class_path: chebai.callbacks.model_checkpoint.CustomModelCheckpoint + init_args: + filename: 'per_{epoch:02d}_{val_loss:.4f}_{val_macro-f1:.4f}_{val_micro-f1:.4f}_{val_roc-auc:.4f}' every_n_epochs: 25 - save_top_k: -1 + save_top_k: 1 diff --git a/configs/training/default_trainer.yml b/configs/training/default_trainer.yml index 91aa4244..0ce68a49 100644 --- a/configs/training/default_trainer.yml +++ b/configs/training/default_trainer.yml @@ -1,4 +1,4 @@ -min_epochs: 100 +min_epochs: 20 max_epochs: 100 default_root_dir: &default_root_dir logs logger: csv_logger.yml diff --git a/configs/training/early_stop_callbacks_regression.yml b/configs/training/early_stop_callbacks_regression.yml new file mode 100644 index 00000000..ebf314aa --- /dev/null +++ b/configs/training/early_stop_callbacks_regression.yml @@ -0,0 +1,19 @@ +- class_path: chebai.callbacks.model_checkpoint.CustomModelCheckpoint + init_args: + monitor: val_loss + mode: 'min' + filename: 'best_{epoch:02d}_{val_loss:.4f}_{val_mse:.4f}' + every_n_epochs: 1 + save_top_k: 3 +- class_path: chebai.callbacks.model_checkpoint.CustomModelCheckpoint + init_args: + filename: 'per_{epoch:02d}_{val_loss:.4f}_{val_mse:.4f}' + every_n_epochs: 25 + save_top_k: -1 +# - class_path: lightning.pytorch.callbacks.early_stopping.EarlyStopping +# init_args: +# monitor: "val_loss_epoch" +# min_delta: 0.0 +# patience: 5 +# verbose: False +# mode: "min" diff --git a/configs/training/early_stop_callbacks_tox21.yml 
b/configs/training/early_stop_callbacks_tox21.yml new file mode 100644 index 00000000..468ca2a2 --- /dev/null +++ b/configs/training/early_stop_callbacks_tox21.yml @@ -0,0 +1,28 @@ +- class_path: chebai.callbacks.model_checkpoint.CustomModelCheckpoint + init_args: + monitor: val_micro-f1 + mode: 'max' + filename: 'best_micro_f1_{epoch:02d}_{val_loss:.4f}_{val_macro-f1:.4f}_{val_micro-f1:.4f}_{val_roc-auc:.4f}' + every_n_epochs: 1 + save_top_k: 1 +- class_path: chebai.callbacks.model_checkpoint.CustomModelCheckpoint + init_args: + monitor: val_loss + mode: 'min' + filename: 'best_loss_{epoch:02d}_{val_loss:.4f}_{val_macro-f1:.4f}_{val_micro-f1:.4f}_{val_roc-auc:.4f}' + every_n_epochs: 1 + save_top_k: 1 +- class_path: chebai.callbacks.model_checkpoint.CustomModelCheckpoint + init_args: + monitor: val_roc-auc + mode: 'max' + filename: 'best_roc-auc_{epoch:02d}_{val_loss:.4f}_{val_macro-f1:.4f}_{val_micro-f1:.4f}_{val_roc-auc:.4f}' + every_n_epochs: 1 + save_top_k: 1 +- class_path: lightning.pytorch.callbacks.early_stopping.EarlyStopping + init_args: + monitor: "val_roc-auc" + min_delta: 0.0 + patience: 5 + verbose: False + mode: "max" diff --git a/configs/training/solCur_callbacks.yml b/configs/training/solCur_callbacks.yml new file mode 100644 index 00000000..155ab7c6 --- /dev/null +++ b/configs/training/solCur_callbacks.yml @@ -0,0 +1,34 @@ +# - class_path: chebai.callbacks.model_checkpoint.CustomModelCheckpoint +# init_args: +# monitor: val_loss +# mode: 'min' +# filename: 'best_loss_{epoch:02d}_{val_loss:.4f}_{val_mse:.4f}_{val_rmse:.4f}_{val_r2:.4f}' +# every_n_epochs: 1 +# save_top_k: 1 +- class_path: chebai.callbacks.model_checkpoint.CustomModelCheckpoint + init_args: + monitor: val_r2 + mode: 'max' + filename: 'best_r2_{epoch:02d}_{val_loss:.4f}_{val_mse:.4f}_{val_rmse:.4f}_{val_r2:.4f}' + every_n_epochs: 1 + save_top_k: 1 +- class_path: chebai.callbacks.model_checkpoint.CustomModelCheckpoint + init_args: + monitor: val_rmse + mode: 'min' + filename: 
'best_rmse_{epoch:02d}_{val_loss:.4f}_{val_mse:.4f}_{val_rmse:.4f}_{val_r2:.4f}' + every_n_epochs: 1 + save_top_k: 1 +# - class_path: lightning.pytorch.callbacks.early_stopping.EarlyStopping +# init_args: +# monitor: "val_rmse" +# min_delta: 0.0 +# patience: 5 +# verbose: False +# mode: "min" + +- class_path: chebai.callbacks.model_checkpoint.CustomModelCheckpoint + init_args: + filename: 'per_{epoch:02d}_{val_loss:.4f}_{val_mse:.4f}_{val_rmse:.4f}_{val_r2:.4f}' + every_n_epochs: 25 + save_top_k: 1 diff --git a/configs/training/wandb_logger.yml b/configs/training/wandb_logger.yml index b0dd8870..b7c51418 100644 --- a/configs/training/wandb_logger.yml +++ b/configs/training/wandb_logger.yml @@ -3,4 +3,4 @@ init_args: save_dir: logs project: 'chebai' entity: 'chebai' - log_model: 'all' + log_model: 'all' \ No newline at end of file diff --git a/tests/unit/mock_data/tox_mock_data.py b/tests/unit/mock_data/tox_mock_data.py index b5f85bda..bee21a2d 100644 --- a/tests/unit/mock_data/tox_mock_data.py +++ b/tests/unit/mock_data/tox_mock_data.py @@ -395,6 +395,10 @@ def data_in_dict_format() -> List[Dict]: for dict_ in data_list: dict_["features"] = Tox21ChallengeMockData.FEATURE_OF_SMILES dict_["group"] = None + if any(label is None for label in dict_["labels"]): + dict_["missing_labels"] = [ + True if label is None else False for label in dict_["labels"] + ] return data_list @@ -506,5 +510,9 @@ def get_setup_processed_output_data() -> List[Dict]: "group": None, } ) + if any(label is None for label in dict_["labels"]): + complete_list[-1]["missing_labels"] = [ + True if label is None else False for label in dict_["labels"] + ] return complete_list diff --git a/tutorials/data_exploration_chebi.ipynb b/tutorials/data_exploration_chebi.ipynb index 81256f4a..03285b56 100644 --- a/tutorials/data_exploration_chebi.ipynb +++ b/tutorials/data_exploration_chebi.ipynb @@ -1077,9 +1077,9 @@ ], "metadata": { "kernelspec": { - "display_name": "Python (env_chebai)", + "display_name": 
"Python 3 (ipykernel)", "language": "python", - "name": "env_chebai" + "name": "python3" }, "language_info": { "codemirror_mode": { @@ -1091,7 +1091,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.14" + "version": "3.12.11" } }, "nbformat": 4, diff --git a/tutorials/demo_process_results.ipynb b/tutorials/demo_process_results.ipynb index b62af78e..76a181b6 100644 --- a/tutorials/demo_process_results.ipynb +++ b/tutorials/demo_process_results.ipynb @@ -8,7 +8,10 @@ "end_time": "2023-11-29T08:17:25.832642900Z", "start_time": "2023-11-29T08:17:25.816890700Z" }, - "collapsed": true + "collapsed": true, + "jupyter": { + "outputs_hidden": true + } }, "outputs": [], "source": [ @@ -37,7 +40,10 @@ "end_time": "2023-11-24T09:13:26.387885900Z", "start_time": "2023-11-24T09:06:23.191727Z" }, - "collapsed": false + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } }, "outputs": [ { @@ -109,7 +115,10 @@ "end_time": "2023-11-29T08:33:48.374202Z", "start_time": "2023-11-29T08:33:48.261436600Z" }, - "collapsed": false + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } }, "outputs": [ { @@ -239,7 +248,10 @@ "end_time": "2023-11-24T09:55:24.187152800Z", "start_time": "2023-11-24T09:55:21.580572700Z" }, - "collapsed": false + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } }, "outputs": [ { @@ -275,6 +287,9 @@ "execution_count": 2, "metadata": { "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, "pycharm": { "name": "#%%\n" } @@ -299,6 +314,9 @@ "execution_count": 4, "metadata": { "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, "pycharm": { "name": "#%%\n" } @@ -320,9 +338,9 @@ " \"per_epoch=99_val_loss=0.0167_val_micro-f1=0.91.ckpt\",\n", ")\n", "model_path_v200 = \"electra_c100_bce_unweighted.ckpt\"\n", - "model_v148 = Electra.load_from_checkpoint(model_path_v148).to(\"cpu\")\n", - "model_v200 = 
Electra.load_from_checkpoint(model_path_v200).to(\"cpu\")\n", - "model_v227 = Electra.load_from_checkpoint(model_path_v227).to(\"cpu\")\n", + "model_v148 = Electra.load_from_checkpoint(model_path_v148, pretrained_checkpoint=None).to(\"cpu\")\n", + "model_v200 = Electra.load_from_checkpoint(model_path_v200, pretrained_checkpoint=None).to(\"cpu\")\n", + "model_v227 = Electra.load_from_checkpoint(model_path_v227, pretrained_checkpoint=None).to(\"cpu\")\n", "\n", "data_module_v200 = ChEBIOver100()\n", "data_module_v148 = ChEBIOver100(chebi_version_train=148)\n", @@ -338,6 +356,9 @@ "execution_count": 7, "metadata": { "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, "pycharm": { "name": "#%%\n" } @@ -360,6 +381,9 @@ "execution_count": 3, "metadata": { "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, "pycharm": { "name": "#%%\n" } @@ -382,6 +406,9 @@ "execution_count": 4, "metadata": { "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, "pycharm": { "name": "#%%\n" } @@ -403,6 +430,9 @@ "execution_count": 9, "metadata": { "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, "pycharm": { "name": "#%%\n" } @@ -428,6 +458,9 @@ "execution_count": 11, "metadata": { "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, "pycharm": { "name": "#%%\n" } @@ -451,6 +484,9 @@ "execution_count": 5, "metadata": { "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, "pycharm": { "name": "#%%\n" } @@ -483,6 +519,9 @@ "execution_count": 58, "metadata": { "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, "pycharm": { "name": "#%%\n" } @@ -643,6 +682,9 @@ "execution_count": 12, "metadata": { "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, "pycharm": { "name": "#%%\n" } @@ -700,6 +742,9 @@ "execution_count": 11, "metadata": { "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, "pycharm": { "name": "#%%\n" } @@ -730,7 +775,10 @@ { "cell_type": "markdown", 
"metadata": { - "collapsed": false + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } }, "source": [ "Results:\n", @@ -762,6 +810,9 @@ "execution_count": 40, "metadata": { "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, "pycharm": { "name": "#%%\n" } @@ -794,6 +845,9 @@ "execution_count": 41, "metadata": { "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, "pycharm": { "name": "#%%\n" } @@ -826,6 +880,9 @@ "execution_count": 42, "metadata": { "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, "pycharm": { "name": "#%%\n" } @@ -858,6 +915,9 @@ "execution_count": 13, "metadata": { "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, "pycharm": { "name": "#%%\n" } @@ -912,6 +972,9 @@ "start_time": "2023-11-24T07:36:43.594504200Z" }, "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, "pycharm": { "name": "#%%\n" } @@ -958,6 +1021,9 @@ "start_time": "2023-11-24T07:36:51.800819200Z" }, "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, "pycharm": { "name": "#%%\n" } @@ -984,6 +1050,9 @@ "execution_count": null, "metadata": { "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, "pycharm": { "name": "#%%\n" } @@ -1010,6 +1079,9 @@ "execution_count": null, "metadata": { "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, "pycharm": { "name": "#%%\n" } @@ -1035,23 +1107,23 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", - "version": 2 + "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", - "pygments_lexer": "ipython2", - "version": "2.7.6" + "pygments_lexer": "ipython3", + "version": "3.12.11" } }, "nbformat": 4, - "nbformat_minor": 0 + "nbformat_minor": 4 } diff --git a/tutorials/eval_model_basic.ipynb 
b/tutorials/eval_model_basic.ipynb index c8f851c5..a2c570e1 100644 --- a/tutorials/eval_model_basic.ipynb +++ b/tutorials/eval_model_basic.ipynb @@ -126,7 +126,7 @@ ], "source": [ "# evaluates model, stores results in buffer_dir\n", - "model = model_class.load_from_checkpoint(checkpoint_path)\n", + "model = model_class.load_from_checkpoint(checkpoint_path, pretrained_checkpoint=None)\n", "if buffer_dir is None:\n", " preds, labels = evaluate_model(\n", " model,\n", @@ -234,7 +234,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.14" + "version": "3.12.11" } }, "nbformat": 4, diff --git a/tutorials/process_results_old_chebi.ipynb b/tutorials/process_results_old_chebi.ipynb index 9b05883a..cb3ec3be 100644 --- a/tutorials/process_results_old_chebi.ipynb +++ b/tutorials/process_results_old_chebi.ipynb @@ -3,7 +3,10 @@ { "cell_type": "markdown", "metadata": { - "collapsed": false + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } }, "source": [ "This script evaluates two models trained on the datasets $ChEBI_{v200}^{854}$ and $ChEBI_{v148}^{709}$." 
@@ -17,7 +20,10 @@ "end_time": "2023-12-01T09:09:32.987478800Z", "start_time": "2023-12-01T09:09:32.979311Z" }, - "collapsed": true + "collapsed": true, + "jupyter": { + "outputs_hidden": true + } }, "outputs": [], "source": [ @@ -43,6 +49,9 @@ "start_time": "2023-12-01T09:09:34.063840600Z" }, "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, "pycharm": { "name": "#%%\n" } @@ -52,8 +61,8 @@ "model_path_v200 = os.path.join(\"models\", \"electra_c100_bce_unweighted.ckpt\")\n", "model_path_v148 = os.path.join(\"models\", \"electra_c100_bce_unweighted_v148.ckpt\")\n", "\n", - "model_v200 = Electra.load_from_checkpoint(model_path_v200).to(DEVICE)\n", - "model_v148 = Electra.load_from_checkpoint(model_path_v148).to(DEVICE)\n", + "model_v200 = Electra.load_from_checkpoint(model_path_v200, pretrained_checkpoint=None).to(DEVICE)\n", + "model_v148 = Electra.load_from_checkpoint(model_path_v148, pretrained_checkpoint=None).to(DEVICE)\n", "\n", "data_module_v200 = ChEBIOver100(chebi_version=200)\n", "data_module_v148 = ChEBIOver100(chebi_version=200, chebi_version_train=148)" @@ -68,6 +77,9 @@ "start_time": "2023-12-01T09:09:35.195490300Z" }, "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, "pycharm": { "name": "#%%\n" } @@ -91,6 +103,9 @@ "start_time": "2023-12-01T09:09:37.598008300Z" }, "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, "pycharm": { "name": "#%%\n" } @@ -116,6 +131,9 @@ "start_time": "2023-12-01T09:11:07.914456300Z" }, "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, "pycharm": { "name": "#%%\n" } @@ -142,6 +160,9 @@ "execution_count": 12, "metadata": { "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, "pycharm": { "name": "#%%\n" } @@ -196,6 +217,9 @@ "execution_count": 11, "metadata": { "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, "pycharm": { "name": "#%%\n" } @@ -228,6 +252,9 @@ "execution_count": 40, "metadata": { "collapsed": false, + "jupyter": { + 
"outputs_hidden": false + }, "pycharm": { "name": "#%%\n" } @@ -260,6 +287,9 @@ "execution_count": 41, "metadata": { "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, "pycharm": { "name": "#%%\n" } @@ -292,6 +322,9 @@ "execution_count": 42, "metadata": { "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, "pycharm": { "name": "#%%\n" } @@ -336,9 +369,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.10" + "version": "3.12.11" } }, "nbformat": 4, - "nbformat_minor": 0 + "nbformat_minor": 4 }