From bdba4428ce26ca2fb8285a87880714329610fec5 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 4 Nov 2024 13:02:22 +0100 Subject: [PATCH 01/71] script to evaluate go predictions --- chebai/result/evaluate_predictions.py | 70 +++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) create mode 100644 chebai/result/evaluate_predictions.py diff --git a/chebai/result/evaluate_predictions.py b/chebai/result/evaluate_predictions.py new file mode 100644 index 00000000..e25c2130 --- /dev/null +++ b/chebai/result/evaluate_predictions.py @@ -0,0 +1,70 @@ +import torch +from jsonargparse import CLI +from torchmetrics.functional.classification import multilabel_auroc + +from chebai.result.utils import load_results_from_buffer + + +class EvaluatePredictions: + def __init__(self, eval_dir: str): + """ + Initializes the EvaluatePredictions class. + + Args: + eval_dir (str): Path to the directory containing evaluation files. + """ + self.eval_dir = eval_dir + self.metrics = [] + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.num_labels = None + + @staticmethod + def validate_eval_dir(label_files: torch.Tensor, pred_files: torch.Tensor) -> None: + """ + Validates that the number of labels matches the number of predictions, + ensuring that they have the same shape. + + Args: + label_files (torch.Tensor): Tensor containing label data. + pred_files (torch.Tensor): Tensor containing prediction data. + + Raises: + ValueError: If label and prediction tensors are mismatched in shape. + """ + if label_files is None or pred_files is None: + raise ValueError("Both label and prediction tensors must be provided.") + + # Check if the number of labels matches the number of predictions + if label_files.shape[0] != pred_files.shape[0]: + raise ValueError( + "Number of label tensors does not match the number of prediction tensors." + ) + + # Validate that the last dimension matches the expected number of classes + if label_files.shape[1] != pred_files.shape[1]: + raise ValueError( + "Label and prediction tensors must have the same shape in terms of class outputs." + ) + + def evaluate(self) -> None: + """ + Loads predictions and labels, validates file correspondence, and calculates Multilabel AUROC. + """ + test_preds, test_labels = load_results_from_buffer(self.eval_dir, self.device) + self.validate_eval_dir(test_labels, test_preds) + self.num_labels = test_preds.shape[1] + + ml_auroc = multilabel_auroc( + test_preds, test_labels, num_labels=self.num_labels + ).item() + + print("Multilabel AUC-ROC:", ml_auroc) + + +class Main: + def evaluate(self, eval_dir: str): + EvaluatePredictions(eval_dir).evaluate() + + +if __name__ == "__main__": + CLI(Main) From 6c0fce185ef8fc754a05c2923dd1bbb9382d2f06 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 4 Nov 2024 15:41:16 +0100 Subject: [PATCH 02/71] add fmax to evaluation script --- chebai/result/evaluate_predictions.py | 39 ++++++++++++++++++++++++++- 1 file changed, 38 insertions(+), 1 deletion(-) diff --git a/chebai/result/evaluate_predictions.py b/chebai/result/evaluate_predictions.py index e25c2130..48ddef83 100644 --- a/chebai/result/evaluate_predictions.py +++ b/chebai/result/evaluate_predictions.py @@ -1,7 +1,11 @@ +from typing import Tuple + +import numpy as np import torch from jsonargparse import CLI from torchmetrics.functional.classification import multilabel_auroc +from chebai.callbacks.epoch_metrics import MacroF1 from chebai.result.utils import load_results_from_buffer @@ -48,7 +52,7 @@ def validate_eval_dir(label_files: torch.Tensor, pred_files: torch.Tensor) -> No def evaluate(self) -> None: """ - Loads predictions and labels, validates file correspondence, and calculates Multilabel AUROC. + Loads predictions and labels, validates file correspondence, and calculates Multilabel AUROC and Fmax. """ test_preds, test_labels = load_results_from_buffer(self.eval_dir, self.device) self.validate_eval_dir(test_labels, test_preds) @@ -60,6 +64,38 @@ def evaluate(self) -> None: print("Multilabel AUC-ROC:", ml_auroc) + fmax, threshold = self.calculate_fmax(test_preds, test_labels) + print(f"F-max : {fmax}, threshold: {threshold}") + + def calculate_fmax( + self, test_preds: torch.Tensor, test_labels: torch.Tensor + ) -> Tuple[float, float]: + """ + Calculates the Fmax metric using the F1 score at various thresholds. + + Args: + test_preds (torch.Tensor): Predicted scores for the labels. + test_labels (torch.Tensor): True labels for the evaluation. + + Returns: + Tuple[float, float]: The maximum F1 score and the corresponding threshold. + """ + thresholds = np.linspace(0, 1, 100) + fmax = 0.0 + best_threshold = 0.0 + + for t in thresholds: + custom_f1_metric = MacroF1(num_labels=self.num_labels, threshold=t) + custom_f1_metric.update(test_preds, test_labels) + custom_f1_metric_score = custom_f1_metric.compute().item() + + # Check if the current score is the best we've seen + if custom_f1_metric_score > fmax: + fmax = custom_f1_metric_score + best_threshold = t + + return fmax, best_threshold + class Main: def evaluate(self, eval_dir: str): @@ -67,4 +103,5 @@ def evaluate(self, eval_dir: str): if __name__ == "__main__": + # evaluate_predictions.py evaluate CLI(Main) From 58ae92d9a73b889256e02fe6a88d97ebbce5437f Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Tue, 5 Nov 2024 23:37:29 +0100 Subject: [PATCH 03/71] add base code for deep_go data migration - migration from deep go format to chebai->go_uniprot format --- .../migration/deep_go_data_mirgration.py | 54 +++++++++++++++++++ 1 file changed, 54 insertions(+) create mode 100644 chebai/preprocessing/migration/deep_go_data_mirgration.py diff --git a/chebai/preprocessing/migration/deep_go_data_mirgration.py b/chebai/preprocessing/migration/deep_go_data_mirgration.py new file mode 100644 index 00000000..ce35ff0b --- /dev/null +++ b/chebai/preprocessing/migration/deep_go_data_mirgration.py @@ -0,0 +1,54 @@ +from typing import List + +import pandas as pd + +# https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/utils.py#L18-L22 +NAMESPACES = { + "cc": "cellular_component", + "mf": "molecular_function", + "bp": "biological_process", +} + +# https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/aminoacids.py#L11 +MAXLEN = 1000 + + +def load_data(data_dir): + test_df = pd.DataFrame(pd.read_pickle("test_data.pkl")) + train_df = pd.DataFrame(pd.read_pickle("train_data.pkl")) + validation_df = pd.DataFrame(pd.read_pickle("valid_data.pkl")) + + required_columns = [ + "proteins", + "accessions", + "sequences", + # https://github.com/bio-ontology-research-group/deepgo2/blob/main/gendata/uni2pandas.py#L45-L58 + "exp_annotations", # Directly associated GO ids + # https://github.com/bio-ontology-research-group/deepgo2/blob/main/gendata/uni2pandas.py#L60-L69 + "prop_annotations", # Transitively associated GO ids + ] + + new_df = pd.concat( + [ + train_df[required_columns], + validation_df[required_columns], + test_df[required_columns], + ], + ignore_index=True, + ) + # Generate splits.csv file to store ids of each corresponding split + split_assignment_list: List[pd.DataFrame] = [ + pd.DataFrame({"id": train_df["proteins"], "split": "train"}), + pd.DataFrame({"id": validation_df["proteins"], "split": "validation"}), + pd.DataFrame({"id": test_df["proteins"], "split": "test"}), + ] + + combined_split_assignment = pd.concat(split_assignment_list, ignore_index=True) + + +def save_data(data_dir, data_df): + pass + + +if __name__ == "__main__": + pass From 78a38de062c603b9e6d193a1a3a2278a56f9da82 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Tue, 5 Nov 2024 23:38:01 +0100 Subject: [PATCH 04/71] varry fmax threshold as per paper --- chebai/result/evaluate_predictions.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/chebai/result/evaluate_predictions.py b/chebai/result/evaluate_predictions.py index 48ddef83..355c07c0 100644 --- a/chebai/result/evaluate_predictions.py +++ b/chebai/result/evaluate_predictions.py @@ -80,7 +80,8 @@ def calculate_fmax( Returns: Tuple[float, float]: The maximum F1 score and the corresponding threshold. """ - thresholds = np.linspace(0, 1, 100) + # https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/metrics.py#L51-L52 + thresholds = np.linspace(0, 1, 101) fmax = 0.0 best_threshold = 0.0 From 3a4e007fc0267ad72d4c2f43f7bfb99fdb1245ee Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Tue, 5 Nov 2024 23:38:40 +0100 Subject: [PATCH 05/71] go_uniprot: add sequence len to docstring --- chebai/preprocessing/datasets/go_uniprot.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/chebai/preprocessing/datasets/go_uniprot.py b/chebai/preprocessing/datasets/go_uniprot.py index a2c4ae54..12bb0adc 100644 --- a/chebai/preprocessing/datasets/go_uniprot.py +++ b/chebai/preprocessing/datasets/go_uniprot.py @@ -56,10 +56,16 @@ class _GOUniProtDataExtractor(_DynamicDataset, ABC): Args: dynamic_data_split_seed (int, optional): The seed for random data splitting. Defaults to 42. splits_file_path (str, optional): Path to the splits CSV file. Defaults to None. - **kwargs: Additional keyword arguments passed to XYBaseDataModule. + max_sequence_length (int, optional): Specifies the maximum allowed sequence length for a protein, with a + default of 1002. During data preprocessing, any proteins exceeding this length will be excluded from further + processing. + **kwargs: Additional keyword arguments passed to DynamicDataset and XYBaseDataModule. Attributes: dynamic_data_split_seed (int): The seed for random data splitting, default is 42. + max_sequence_length (int, optional): Specifies the maximum allowed sequence length for a protein, with a + default of 1002. During data preprocessing, any proteins exceeding this length will be excluded from further + processing. splits_file_path (Optional[str]): Path to the CSV file containing split assignments. """ From 227a014af32479932208c2dcc565babc8b3fbbf8 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 6 Nov 2024 15:47:46 +0100 Subject: [PATCH 06/71] update experiment evidence codes as per DeepGo SE - https://github.com/ChEB-AI/python-chebai/issues/36#issuecomment-2458153758 --- chebai/preprocessing/datasets/go_uniprot.py | 7 ++++++- chebai/preprocessing/datasets/protein_pretraining.py | 4 ++-- tests/unit/mock_data/ontology_mock_data.py | 4 ++-- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/chebai/preprocessing/datasets/go_uniprot.py b/chebai/preprocessing/datasets/go_uniprot.py index 12bb0adc..73edc976 100644 --- a/chebai/preprocessing/datasets/go_uniprot.py +++ b/chebai/preprocessing/datasets/go_uniprot.py @@ -43,6 +43,11 @@ "IEP", "TAS", "IC", + "HTP", + "HDA", + "HMP", + "HGI", + "HEP", } # https://github.com/bio-ontology-research-group/deepgo/blob/d97447a05c108127fee97982fd2c57929b2cf7eb/aaindex.py#L8 @@ -414,7 +419,7 @@ def _get_swiss_to_go_mapping(self) -> pd.DataFrame: Quote from the DeepGo Paper: `We select proteins with annotations having experimental evidence codes - (EXP, IDA, IPI, IMP, IGI, IEP, TAS, IC) and filter the proteins by a + `EXPERIMENTAL_EVIDENCE_CODES` and filter the proteins by a maximum length of 1002, ignoring proteins with ambiguous amino acid codes (B, O, J, U, X, Z) in their sequence.` diff --git a/chebai/preprocessing/datasets/protein_pretraining.py b/chebai/preprocessing/datasets/protein_pretraining.py index 8550db2b..f6e9d66d 100644 --- a/chebai/preprocessing/datasets/protein_pretraining.py +++ b/chebai/preprocessing/datasets/protein_pretraining.py @@ -96,8 +96,8 @@ def _download_required_data(self) -> str: def _parse_protein_data_for_pretraining(self) -> pd.DataFrame: """ Parses the Swiss-Prot data and returns a DataFrame containing Swiss-Prot proteins which does not have any valid - Gene Ontology(GO) label. A valid GO label is the one which has one of the following evidence code - (EXP, IDA, IPI, IMP, IGI, IEP, TAS, IC). + Gene Ontology(GO) label. A valid GO label is the one which has one of the following evidence code defined in + `EXPERIMENTAL_EVIDENCE_CODES`. The DataFrame includes the following columns: - "swiss_id": The unique identifier for each Swiss-Prot record. diff --git a/tests/unit/mock_data/ontology_mock_data.py b/tests/unit/mock_data/ontology_mock_data.py index a05b89f1..ca6148e7 100644 --- a/tests/unit/mock_data/ontology_mock_data.py +++ b/tests/unit/mock_data/ontology_mock_data.py @@ -668,8 +668,8 @@ def get_UniProt_raw_data() -> str: - **Swiss_Prot_11**: Has only Invalid GO class but lacks a sequence. Note: - A valid GO label is the one which has one of the following evidence code - (EXP, IDA, IPI, IMP, IGI, IEP, TAS, IC). + A valid GO label is the one which has one of the following evidence code defined in + `EXPERIMENTAL_EVIDENCE_CODES`. Returns: str: The raw UniProt data in string format. From c6d60cddd23e1e8d137ec4d02285061baa987d31 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 6 Nov 2024 16:40:53 +0100 Subject: [PATCH 07/71] consIder `X` as a valid amino acid as per DeepGO-SE - https://github.com/ChEB-AI/python-chebai/issues/36#issuecomment-2458153758 --- chebai/preprocessing/datasets/go_uniprot.py | 27 +++++++++++++------ .../datasets/protein_pretraining.py | 4 +-- chebai/preprocessing/reader.py | 4 ++- tests/unit/mock_data/ontology_mock_data.py | 13 ++++----- 4 files changed, 31 insertions(+), 17 deletions(-) diff --git a/chebai/preprocessing/datasets/go_uniprot.py b/chebai/preprocessing/datasets/go_uniprot.py index 73edc976..7b1c16e3 100644 --- a/chebai/preprocessing/datasets/go_uniprot.py +++ b/chebai/preprocessing/datasets/go_uniprot.py @@ -1,13 +1,22 @@ -# Reference for this file : +# References for this file : +# Reference 1: # Maxat Kulmanov, Mohammed Asif Khan, Robert Hoehndorf; # DeepGO: Predicting protein functions from sequence and interactions # using a deep ontology-aware classifier, Bioinformatics, 2017. # https://doi.org/10.1093/bioinformatics/btx624 # Github: https://github.com/bio-ontology-research-group/deepgo + +# Reference 2: # https://www.ebi.ac.uk/GOA/downloads # https://ftp.uniprot.org/pub/databases/uniprot/current_release/knowledgebase/complete/docs/keywlist.txt # https://www.uniprot.org/uniprotkb +# Reference 3: +# Kulmanov, M., Guzmán-Vega, F.J., Duek Roggli, +# P. et al. Protein function prediction as approximate semantic entailment. Nat Mach Intell 6, 220–228 (2024). +# https://doi.org/10.1038/s42256-024-00795-w +# https://github.com/bio-ontology-research-group/deepgo2 + __all__ = [ "GOUniProtOver250", "GOUniProtOver50", @@ -34,6 +43,7 @@ from chebai.preprocessing import reader as dr from chebai.preprocessing.datasets.base import _DynamicDataset +# https://github.com/bio-ontology-research-group/deepgo/blob/master/utils.py#L15 EXPERIMENTAL_EVIDENCE_CODES = { "EXP", "IDA", @@ -43,6 +53,8 @@ "IEP", "TAS", "IC", + # New evidence codes added in latest paper year 2024 Reference number 3 + # https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/utils.py#L24-L26 "HTP", "HDA", "HMP", @@ -51,7 +63,9 @@ } # https://github.com/bio-ontology-research-group/deepgo/blob/d97447a05c108127fee97982fd2c57929b2cf7eb/aaindex.py#L8 -AMBIGUOUS_AMINO_ACIDS = {"B", "O", "J", "U", "X", "Z", "*"} +# https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/aminoacids.py#L10 +# `X` is now considered as valid amino acid, as per latest paper year 2024 Refernce number 3 +AMBIGUOUS_AMINO_ACIDS = {"B", "O", "J", "U", "Z", "*"} class _GOUniProtDataExtractor(_DynamicDataset, ABC): @@ -416,12 +430,9 @@ def _get_swiss_to_go_mapping(self) -> pd.DataFrame: Note: This mapping is necessary because the GO data does not include the protein sequence representation. - - Quote from the DeepGo Paper: - `We select proteins with annotations having experimental evidence codes - `EXPERIMENTAL_EVIDENCE_CODES` and filter the proteins by a - maximum length of 1002, ignoring proteins with ambiguous amino acid codes - (B, O, J, U, X, Z) in their sequence.` + We select proteins with annotations having experimental evidence codes, as specified in + `EXPERIMENTAL_EVIDENCE_CODES` and filter the proteins by a maximum length of 1002, ignoring proteins with + ambiguous amino acid codes specified in `AMBIGUOUS_AMINO_ACIDS` in their sequence. Check the link below for keyword details: https://ftp.uniprot.org/pub/databases/uniprot/current_release/knowledgebase/complete/docs/keywlist.txt diff --git a/chebai/preprocessing/datasets/protein_pretraining.py b/chebai/preprocessing/datasets/protein_pretraining.py index f6e9d66d..63d53144 100644 --- a/chebai/preprocessing/datasets/protein_pretraining.py +++ b/chebai/preprocessing/datasets/protein_pretraining.py @@ -96,7 +96,7 @@ def _download_required_data(self) -> str: def _parse_protein_data_for_pretraining(self) -> pd.DataFrame: """ Parses the Swiss-Prot data and returns a DataFrame containing Swiss-Prot proteins which does not have any valid - Gene Ontology(GO) label. A valid GO label is the one which has one of the following evidence code defined in + Gene Ontology(GO) label. A valid GO label is the one which has one of the following evidence codes, as specified in `EXPERIMENTAL_EVIDENCE_CODES`. The DataFrame includes the following columns: @@ -104,7 +104,7 @@ def _parse_protein_data_for_pretraining(self) -> pd.DataFrame: - "sequence": The protein sequence. Note: - We ignore proteins with ambiguous amino acid codes (B, O, J, U, X, Z) in their sequence.` + We ignore proteins with ambiguous amino acid specified in `AMBIGUOUS_AMINO_ACIDS` in their sequence.` Returns: pd.DataFrame: A DataFrame where each row corresponds to a Swiss-Prot record with not associated valid GO. diff --git a/chebai/preprocessing/reader.py b/chebai/preprocessing/reader.py index e220e1e4..a08a3f91 100644 --- a/chebai/preprocessing/reader.py +++ b/chebai/preprocessing/reader.py @@ -348,7 +348,7 @@ class ProteinDataReader(DataReader): COLLATOR = RaggedCollator - # 20 natural amino acid notation + # 21 natural amino acid notation AA_LETTER = [ "A", "R", @@ -370,6 +370,8 @@ class ProteinDataReader(DataReader): "W", "Y", "V", + # https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/aminoacids.py#L3-L5 + "X", # Consider valid in latest paper year 2024 Reference number 3 in go_uniprot.py ] def name(self) -> str: diff --git a/tests/unit/mock_data/ontology_mock_data.py b/tests/unit/mock_data/ontology_mock_data.py index ca6148e7..552d2918 100644 --- a/tests/unit/mock_data/ontology_mock_data.py +++ b/tests/unit/mock_data/ontology_mock_data.py @@ -658,18 +658,19 @@ def get_UniProt_raw_data() -> str: - **Swiss_Prot_1**: A valid protein with three valid GO classes and one invalid GO class. - **Swiss_Prot_2**: Another valid protein with two valid GO classes and one invalid. - **Swiss_Prot_3**: Contains valid GO classes but has a sequence length > 1002. - - **Swiss_Prot_4**: Has valid GO classes but contains an invalid amino acid, 'X'. + - **Swiss_Prot_4**: Has valid GO classes but contains an invalid amino acid, 'B'. - **Swiss_Prot_5**: Has a sequence but no GO classes associated. - **Swiss_Prot_6**: Has GO classes without any associated evidence codes. - **Swiss_Prot_7**: Has a GO class with an invalid evidence code. - **Swiss_Prot_8**: Has a sequence length > 1002 and has only invalid GO class. - - **Swiss_Prot_9**: Has no GO classes but contains an invalid amino acid, 'X', in its sequence. + - **Swiss_Prot_9**: Has no GO classes but contains an invalid amino acid, 'B', in its sequence. - **Swiss_Prot_10**: Has a valid GO class but lacks a sequence. - **Swiss_Prot_11**: Has only Invalid GO class but lacks a sequence. Note: - A valid GO label is the one which has one of the following evidence code defined in - `EXPERIMENTAL_EVIDENCE_CODES`. + A valid GO label is the one which has one of the following evidence code specified in + go_uniprot.py->`EXPERIMENTAL_EVIDENCE_CODES`. + Invalid amino acids are specified in go_uniprot.py->`AMBIGUOUS_AMINO_ACIDS`. Returns: str: The raw UniProt data in string format. @@ -715,7 +716,7 @@ def get_UniProt_raw_data() -> str: "DR GO; GO:0000005; P:regulation of viral transcription; IEA:InterPro.\n" "DR GO; GO:0000006; P:regulation of viral transcription; EXP:PomBase.\n" "SQ SEQUENCE 60 AA; 29735 MW; B4840739BF7D4121 CRC64;\n" - " XAFSAEDVLK EYDRRRRMEA LLLSLYYPND RKLLDYKEWS PPRVQVECPK APVEWNNPPS\n" + " BAFSAEDVLK EYDRRRRMEA LLLSLYYPND RKLLDYKEWS PPRVQVECPK APVEWNNPPS\n" "//\n" # Below protein with sequence string but has no GO class "ID Swiss_Prot_5 Reviewed; 60 AA.\n" @@ -749,7 +750,7 @@ def get_UniProt_raw_data() -> str: "ID Swiss_Prot_9 Reviewed; 60 AA.\n" "AC Q6GZX4;\n" "SQ SEQUENCE 60 AA; 29735 MW; B4840739BF7D4121 CRC64;\n" - " XAFSAEDVLK EYDRRRRMEA LLLSLYYPND RKLLDYKEWS PPRVQVECPK APVEWNNPPS\n" + " BAFSAEDVLK EYDRRRRMEA LLLSLYYPND RKLLDYKEWS PPRVQVECPK APVEWNNPPS\n" "//\n" # Below protein with a `valid` associated GO class but without sequence string "ID Swiss_Prot_10 Reviewed; 60 AA.\n" From ca5461fce0bf4a431f620af0e7ad3df81c61b1b5 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 6 Nov 2024 20:40:09 +0100 Subject: [PATCH 08/71] deepgo se mirgration : add class to migrate --- .../migration/deep_go_data_mirgration.py | 342 +++++++++++++++--- 1 file changed, 297 insertions(+), 45 deletions(-) diff --git a/chebai/preprocessing/migration/deep_go_data_mirgration.py b/chebai/preprocessing/migration/deep_go_data_mirgration.py index ce35ff0b..a33e407b 100644 --- a/chebai/preprocessing/migration/deep_go_data_mirgration.py +++ b/chebai/preprocessing/migration/deep_go_data_mirgration.py @@ -1,54 +1,306 @@ -from typing import List +import os +from collections import OrderedDict +from random import randint +from typing import List, Literal import pandas as pd +from jsonargparse import CLI -# https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/utils.py#L18-L22 -NAMESPACES = { - "cc": "cellular_component", - "mf": "molecular_function", - "bp": "biological_process", -} - -# https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/aminoacids.py#L11 -MAXLEN = 1000 - - -def load_data(data_dir): - test_df = pd.DataFrame(pd.read_pickle("test_data.pkl")) - train_df = pd.DataFrame(pd.read_pickle("train_data.pkl")) - validation_df = pd.DataFrame(pd.read_pickle("valid_data.pkl")) - - required_columns = [ - "proteins", - "accessions", - "sequences", - # https://github.com/bio-ontology-research-group/deepgo2/blob/main/gendata/uni2pandas.py#L45-L58 - "exp_annotations", # Directly associated GO ids - # https://github.com/bio-ontology-research-group/deepgo2/blob/main/gendata/uni2pandas.py#L60-L69 - "prop_annotations", # Transitively associated GO ids - ] - - new_df = pd.concat( - [ - train_df[required_columns], - validation_df[required_columns], - test_df[required_columns], - ], - ignore_index=True, - ) - # Generate splits.csv file to store ids of each corresponding split - split_assignment_list: List[pd.DataFrame] = [ - pd.DataFrame({"id": train_df["proteins"], "split": "train"}), - pd.DataFrame({"id": validation_df["proteins"], "split": "validation"}), - pd.DataFrame({"id": test_df["proteins"], "split": "test"}), - ] +from chebai.preprocessing.datasets.go_uniprot import ( + GOUniProtOver50, + GOUniProtOver250, + _GOUniProtDataExtractor, +) + + +class DeepGoDataMigration: + """ + A class to handle data migration and processing for the DeepGO project. It migrates the data from the DeepGO-SE + data structure to our data structure followed for GO-UniProt data. + + Attributes: + _CORRESPONDING_GO_CLASSES (dict): Mapping of GO branches to specific data extractor classes. + _MAXLEN (int): Maximum sequence length for sequences. + _LABELS_START_IDX (int): Starting index for labels in the dataset. + + Methods: + __init__(data_dir, go_branch): Initializes the data directory and GO branch. + _load_data(): Loads train, validation, test, and terms data from the specified directory. + _record_splits(): Creates a DataFrame with IDs and their corresponding split. + migrate(): Executes the migration process including data loading, processing, and saving. + _extract_required_data_from_splits(): Extracts required columns from the splits data. + _generate_labels(data_df): Generates label columns for the data based on GO terms. + extract_go_id(go_list): Extracts GO IDs from a list. + save_migrated_data(data_df, splits_df): Saves the processed data and splits. + """ + + # Link for the namespaces convention used for GO branch + # https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/utils.py#L18-L22 + _CORRESPONDING_GO_CLASSES = { + "cc": GOUniProtOver50, + "mf": GOUniProtOver50, + "bp": GOUniProtOver250, + } + + # https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/aminoacids.py#L11 + _MAXLEN = 1000 + _LABELS_START_IDX = _GOUniProtDataExtractor._LABELS_START_IDX + + def __init__(self, data_dir: str, go_branch: Literal["cc", "mf", "bp"]): + """ + Initializes the data migration object with a data directory and GO branch. + + Args: + data_dir (str): Directory containing the data files. + go_branch (Literal["cc", "mf", "bp"]): GO branch to use (cellular_component, molecular_function, or biological_process). + """ + valid_go_branches = list(self._CORRESPONDING_GO_CLASSES.keys()) + if go_branch not in valid_go_branches: + raise ValueError(f"go_branch must be one of {valid_go_branches}") + self._go_branch = go_branch + + self._data_dir = os.path.join(data_dir, go_branch) + self._train_df: pd.DataFrame = None + self._test_df: pd.DataFrame = None + self._validation_df: pd.DataFrame = None + self._terms_df: pd.DataFrame = None + self._classes: List[str] = None + + def _load_data(self) -> None: + """ + Loads the test, train, validation, and terms data from the pickled files + in the data directory. + """ + try: + print(f"Loading data from {self._data_dir}......") + self._test_df = pd.DataFrame( + pd.read_pickle(os.path.join(self._data_dir, "test_data.pkl")) + ) + self._train_df = pd.DataFrame( + pd.read_pickle(os.path.join(self._data_dir, "train_data.pkl")) + ) + self._validation_df = pd.DataFrame( + pd.read_pickle(os.path.join(self._data_dir, "valid_data.pkl")) + ) + self._terms_df = pd.DataFrame( + pd.read_pickle(os.path.join(self._data_dir, "terms.pkl")) + ) + except FileNotFoundError as e: + print(f"Error loading data: {e}") + + def _record_splits(self) -> pd.DataFrame: + """ + Creates a DataFrame that stores the IDs and their corresponding data splits. + + Returns: + pd.DataFrame: A combined DataFrame containing split assignments. + """ + print("Recording splits...") + split_assignment_list: List[pd.DataFrame] = [ + pd.DataFrame({"id": self._train_df["proteins"], "split": "train"}), + pd.DataFrame( + {"id": self._validation_df["proteins"], "split": "validation"} + ), + pd.DataFrame({"id": self._test_df["proteins"], "split": "test"}), + ] + + combined_split_assignment = pd.concat(split_assignment_list, ignore_index=True) + return combined_split_assignment + + def migrate(self) -> None: + """ + Executes the data migration by loading, processing, and saving the data. + """ + print("Migration started......") + self._load_data() + if not all( + [self._train_df, self._validation_df, self._test_df, self._terms_df] + ): + raise Exception( + "Data splits or terms data is not available in instance variables." + ) + splits_df = self._record_splits() + + data_df = self._extract_required_data_from_splits() + data_with_labels_df = self._generate_labels(data_df) + + if not all([data_with_labels_df, splits_df, self._classes]): + raise Exception( + "Data splits or terms data is not available in instance variables." + ) + + self.save_migrated_data(data_df, splits_df) + + def _extract_required_data_from_splits(self) -> pd.DataFrame: + """ + Extracts required columns from the combined data splits. + + Returns: + pd.DataFrame: A DataFrame containing the essential columns for processing. + """ + print("Combining the data splits with required data..... ") + required_columns = [ + "proteins", + "accessions", + "sequences", + # https://github.com/bio-ontology-research-group/deepgo2/blob/main/gendata/uni2pandas.py#L45-L58 + "exp_annotations", # Directly associated GO ids + # https://github.com/bio-ontology-research-group/deepgo2/blob/main/gendata/uni2pandas.py#L60-L69 + "prop_annotations", # Transitively associated GO ids + ] + + new_df = pd.concat( + [ + self._train_df[required_columns], + self._validation_df[required_columns], + self._test_df[required_columns], + ], + ignore_index=True, + ) + new_df["go_ids"] = new_df.apply( + lambda row: self.extract_go_id(row["exp_annotations"]) + + self.extract_go_id(row["prop_annotations"]), + axis=1, + ) - combined_split_assignment = pd.concat(split_assignment_list, ignore_index=True) + data_df = pd.DataFrame( + OrderedDict( + swiss_id=new_df["proteins"], + accession=new_df["accessions"], + go_ids=new_df["go_ids"], + sequence=new_df["sequences"], + ) + ) + return data_df + def _generate_labels(self, data_df: pd.DataFrame) -> pd.DataFrame: + """ + Generates label columns for each GO term in the dataset. -def save_data(data_dir, data_df): - pass + Args: + data_df (pd.DataFrame): DataFrame containing data with GO IDs. + + Returns: + pd.DataFrame: DataFrame with new label columns. + """ + print("Generating labels based on terms.pkl file.......") + parsed_go_ids: pd.Series = self._terms_df.apply( + lambda row: self.extract_go_id(row["gos"]) + ) + all_go_ids_list = parsed_go_ids.values.tolist() + self._classes = all_go_ids_list + new_label_columns = pd.DataFrame( + False, index=data_df.index, columns=all_go_ids_list + ) + data_df = pd.concat([data_df, new_label_columns], axis=1) + + for index, row in data_df.iterrows(): + for go_id in row["go_ids"]: + if go_id in data_df.columns: + data_df.at[index, go_id] = True + + data_df = data_df[data_df.iloc[:, self._LABELS_START_IDX :].any(axis=1)] + return data_df + + @staticmethod + def extract_go_id(go_list: List[str]) -> List[str]: + """ + Extracts and parses GO IDs from a list of GO annotations. + + Args: + go_list (List[str]): List of GO annotation strings. + + Returns: + List[str]: List of parsed GO IDs. + """ + return [ + _GOUniProtDataExtractor._parse_go_id(go_id_str) for go_id_str in go_list + ] + + def save_migrated_data( + self, data_df: pd.DataFrame, splits_df: pd.DataFrame + ) -> None: + """ + Saves the processed data and split information. + + Args: + data_df (pd.DataFrame): Data with GO labels. + splits_df (pd.DataFrame): Split assignment DataFrame. + """ + print("Saving transformed data......") + go_class_instance: _GOUniProtDataExtractor = self._CORRESPONDING_GO_CLASSES[ + self._go_branch + ](go_branch=self._go_branch, max_sequence_length=self._MAXLEN) + + go_class_instance.save_processed( + data_df, go_class_instance.processed_file_names_dict["data"] + ) + print( + f"{go_class_instance.processed_file_names_dict['data']} saved to {go_class_instance.processed_dir_main}" + ) + + splits_df.to_csv( + os.path.join(go_class_instance.processed_dir_main, "splits.csv"), + index=False, + ) + print(f"splits.csv saved to {go_class_instance.processed_dir_main}") + + classes = sorted(self._classes) + with open( + os.path.join(go_class_instance.processed_dir_main, "classes.txt"), "wt" + ) as fout: + fout.writelines(str(node) + "\n" for node in classes) + print(f"classes.txt saved to {go_class_instance.processed_dir_main}") + print("Migration completed!") + + +class Main: + """ + Main class to handle the migration process for DeepGoDataMigration. + + Methods: + migrate(data_dir: str, go_branch: Literal["cc", "mf", "bp"]): + Initiates the migration process for the specified data directory and GO branch. + """ + + def migrate(self, data_dir: str, go_branch: str) -> None: + """ + Initiates the migration process by creating a DeepGoDataMigration instance + and invoking its migrate method. + + Args: + data_dir (str): Directory containing the data files. + go_branch (Literal["cc", "mf", "bp"]): GO branch to use + ("cc" for cellular_component, + "mf" for molecular_function, + or "bp" for biological_process). + """ + DeepGoDataMigration(data_dir, go_branch).migrate() + + +class Main1: + def __init__(self, max_prize: int = 100): + """ + Args: + max_prize: Maximum prize that can be awarded. + """ + self.max_prize = max_prize + + def person(self, name: str, additional_prize: int = 0): + """ + Args: + name: Name of the winner. + additional_prize: Additional prize that can be added to the prize amount. + """ + prize = randint(0, self.max_prize) + additional_prize + return f"{name} won {prize}€!" if __name__ == "__main__": - pass + # Example: python script_name.py migrate data_dir="data/deep_go_se_training_data" go_branch="bp" + # --data_dir specifies the directory containing the data files. + # --go_branch specifies the GO branch (cc, mf, or bp) you want to use for the migration. + CLI( + Main1, + description="DeepGoDataMigration CLI tool to handle migration of GO data for specified branches (cc, mf, bp).", + ) From dfb9430795a7a45826eb350d3068074e2b567a83 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 7 Nov 2024 11:12:15 +0100 Subject: [PATCH 09/71] migration: rectify errors --- .../migration/deep_go_data_mirgration.py | 65 ++++++++----------- 1 file changed, 28 insertions(+), 37 deletions(-) diff --git a/chebai/preprocessing/migration/deep_go_data_mirgration.py b/chebai/preprocessing/migration/deep_go_data_mirgration.py index a33e407b..5c22b389 100644 --- a/chebai/preprocessing/migration/deep_go_data_mirgration.py +++ b/chebai/preprocessing/migration/deep_go_data_mirgration.py @@ -1,7 +1,6 @@ import os from collections import OrderedDict -from random import randint -from typing import List, Literal +from typing import List, Literal, Optional import pandas as pd from jsonargparse import CLI @@ -59,12 +58,12 @@ def __init__(self, data_dir: str, go_branch: Literal["cc", "mf", "bp"]): raise ValueError(f"go_branch must be one of {valid_go_branches}") self._go_branch = go_branch - self._data_dir = os.path.join(data_dir, go_branch) - self._train_df: pd.DataFrame = None - self._test_df: pd.DataFrame = None - self._validation_df: pd.DataFrame = None - self._terms_df: pd.DataFrame = None - self._classes: List[str] = None + self._data_dir: str = os.path.join(rf"{data_dir}", go_branch) + self._train_df: Optional[pd.DataFrame] = None + self._test_df: Optional[pd.DataFrame] = None + self._validation_df: Optional[pd.DataFrame] = None + self._terms_df: Optional[pd.DataFrame] = None + self._classes: Optional[List[str]] = None def _load_data(self) -> None: """ @@ -114,7 +113,13 @@ def migrate(self) -> None: print("Migration started......") self._load_data() if not all( - [self._train_df, self._validation_df, self._test_df, self._terms_df] + df is not None + for df in [ + self._train_df, + self._validation_df, + self._test_df, + self._terms_df, + ] ): raise Exception( "Data splits or terms data is not available in instance variables." @@ -124,7 +129,9 @@ def migrate(self) -> None: data_df = self._extract_required_data_from_splits() data_with_labels_df = self._generate_labels(data_df) - if not all([data_with_labels_df, splits_df, self._classes]): + if not all( + var is not None for var in [data_with_labels_df, splits_df, self._classes] + ): raise Exception( "Data splits or terms data is not available in instance variables." ) @@ -184,8 +191,8 @@ def _generate_labels(self, data_df: pd.DataFrame) -> pd.DataFrame: pd.DataFrame: DataFrame with new label columns. """ print("Generating labels based on terms.pkl file.......") - parsed_go_ids: pd.Series = self._terms_df.apply( - lambda row: self.extract_go_id(row["gos"]) + parsed_go_ids: pd.Series = self._terms_df["gos"].apply( + lambda gos: _GOUniProtDataExtractor._parse_go_id(gos) ) all_go_ids_list = parsed_go_ids.values.tolist() self._classes = all_go_ids_list @@ -203,7 +210,7 @@ def _generate_labels(self, data_df: pd.DataFrame) -> pd.DataFrame: return data_df @staticmethod - def extract_go_id(go_list: List[str]) -> List[str]: + def extract_go_id(go_list: List[str]) -> List[int]: """ Extracts and parses GO IDs from a list of GO annotations. @@ -230,13 +237,13 @@ def save_migrated_data( print("Saving transformed data......") go_class_instance: _GOUniProtDataExtractor = self._CORRESPONDING_GO_CLASSES[ self._go_branch - ](go_branch=self._go_branch, max_sequence_length=self._MAXLEN) + ](go_branch=self._go_branch.upper(), max_sequence_length=self._MAXLEN) go_class_instance.save_processed( - data_df, go_class_instance.processed_file_names_dict["data"] + data_df, go_class_instance.processed_main_file_names_dict["data"] ) print( - f"{go_class_instance.processed_file_names_dict['data']} saved to {go_class_instance.processed_dir_main}" + f"{go_class_instance.processed_main_file_names_dict['data']} saved to {go_class_instance.processed_dir_main}" ) splits_df.to_csv( @@ -263,7 +270,8 @@ class Main: Initiates the migration process for the specified data directory and GO branch. """ - def migrate(self, data_dir: str, go_branch: str) -> None: + @staticmethod + def migrate(data_dir: str, go_branch: Literal["cc", "mf", "bp"]) -> None: """ Initiates the migration process by creating a DeepGoDataMigration instance and invoking its migrate method. @@ -278,29 +286,12 @@ def migrate(self, data_dir: str, go_branch: str) -> None: DeepGoDataMigration(data_dir, go_branch).migrate() -class Main1: - def __init__(self, max_prize: int = 100): - """ - Args: - max_prize: Maximum prize that can be awarded. - """ - self.max_prize = max_prize - - def person(self, name: str, additional_prize: int = 0): - """ - Args: - name: Name of the winner. - additional_prize: Additional prize that can be added to the prize amount. - """ - prize = randint(0, self.max_prize) + additional_prize - return f"{name} won {prize}€!" - - if __name__ == "__main__": - # Example: python script_name.py migrate data_dir="data/deep_go_se_training_data" go_branch="bp" + # Example: python script_name.py migrate --data_dir="data/deep_go_se_training_data" --go_branch="bp" # --data_dir specifies the directory containing the data files. # --go_branch specifies the GO branch (cc, mf, or bp) you want to use for the migration. CLI( - Main1, + Main, description="DeepGoDataMigration CLI tool to handle migration of GO data for specified branches (cc, mf, bp).", + as_positional=False, ) From 085b13b5798398d4dca9477ed8ad80ecf50d2e0b Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 7 Nov 2024 13:25:21 +0100 Subject: [PATCH 10/71] protein trigram containing tokenS with `X` - https://github.com/ChEB-AI/python-chebai/issues/36#issuecomment-2458153758 --- .../bin/protein_token_3_gram/tokens.txt | 359 ++++++++++++++++++ 1 file changed, 359 insertions(+) diff --git a/chebai/preprocessing/bin/protein_token_3_gram/tokens.txt b/chebai/preprocessing/bin/protein_token_3_gram/tokens.txt index 69dca126..534e5db1 100644 --- a/chebai/preprocessing/bin/protein_token_3_gram/tokens.txt +++ b/chebai/preprocessing/bin/protein_token_3_gram/tokens.txt @@ -7998,3 +7998,362 @@ WWC WCC WCH WWM +TAX +AXD +XDR +IEX +EXV +QAX +AXX +XXE +XES +MXN +XNF +NRX +RXX +XXX +XXR +XRI +SAX +AXG +XGG +PRX +RXR +XRX +RXE +XEF +QEX +EXQ +XQR +REX +EXR +RXQ +XQQ +DRX +RXP +XPG +QMX +MXT +XTX +TXR +XRM +APX +PXX +XXG +XGI +NLX +LXX +XXM +XMA +LNX +NXE +XEA +GTX +TXN +XND +LIX +IXI +XIM +MVX +VXX +XXK +XKT +GLX +LXP +XPP +QGX +GXD +XDL +XAP +QNX +NXM +XMN +VAX +XGV +IKX +KXY +KEX +EXL +XLY +GQX +QXE +XEP +PLX +XKC +PVX +XKE +RXI +XIR +AXL +XLN +LLX +LXD +XDA +AXE +XEL +GGX +GXG +KAX +XXA +XAG +XWS +SPX +PXC +XCD +GWX +WXH +XHF +MPX +ESX +SXN +XNK +DLX +LXN +XNS +QXG +XGD +ITX +XRG +NEX +EXA +XAL +LDX +DXI +XII +TPX +PXM +XMR +NXG +XGY +ASX +SXV +XVE +TKX +KXA +KRX +XXT +XTL +IDX +DXX +XXL +XLV +AKX +KXX +QHX +HXV +XVN +NSX +SXX +XKX +XDP +DAX +AXK +XKQ +PIX +IXX +XXF +VLX +XDI +DIX +IXL +XLK +LKX +KXV +XVA +DNX +NXD +ILX +LXK +XKV +VYX +YXE +XEI +RXS +XSH +KGX +XGF +AVX +VXY +XYG +HVX +XXI +XID +TVX +XXS +XSA +ENX +NXX +XMD +IIX +XMQ +AEX +EXX +XME +PGX +GXP +XPR +SKX +KXF +XFT +HRX +XSW +PQX +XGR +QQX +VTX +XRP +PSX +SXP +XPL +VGX +GXY +RSX +SXS +XSL +VSX +XST +AXV +XVL +AGX +GXX +XTK +KLX +LXR +XRV +AHX +HXC +XCS +LVX +VXN +XNR +NGX +GXL +TSX +SXQ +XQN +KXL +XLL +VIX +IXG +XGA +GFX +FXG +XGL +PTX +TXT +XTS +EMX +MXQ +SXY +XYA +IQX +QXY +XYR +TXK +IGX +XPS +PXT +XTG +NXQ +VKX +KXS +XSN +GVX +VXE +GRX +XRE +YKX +KXE +XEE +EEX +EXT +XTI +EHX +HXN +XNL +NDX +DXD +IAX +KSX +SXL +RRX +XRK +DDX +DXE +RXG +VXL +XLS +DTX +TXG +VXF +XFA +XIG +VXT +XTA +ISX +SXR +XRY +VQX +QXP +XPC +LGX +GXS +HGX +XGH +XXD +XDD +KKX +XXV +PKX +XLT +XSP +XLD +RAX +AXS +XSI +IYX +YXX +XXP +XPI +MSX +SXT +GEX +XHP +LFX +FXX +VXI +XIW +QTX +TXX +XXQ +XQA +FLX +DXN +XNC +MXS +XSR +YLX +EQX +QXS +TMX +MXC +XCY +NXA +XAV +EXE +XEQ +HPX +PXP +LMX +MXX +KTX +XKK +XXH +XHS +MKX +XIH +WRX +XKS +EXY +XYQ +QKX From 3e0bae0d75c0d3330a75c3c72e6ffa023ae2b37b Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 7 Nov 2024 13:28:59 +0100 Subject: [PATCH 11/71] protein token unigram contain `X` - https://github.com/ChEB-AI/python-chebai/issues/36#issuecomment-2458153758 --- chebai/preprocessing/bin/protein_token/tokens.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/chebai/preprocessing/bin/protein_token/tokens.txt b/chebai/preprocessing/bin/protein_token/tokens.txt index 72ad1b6d..c31c5b72 100644 --- a/chebai/preprocessing/bin/protein_token/tokens.txt +++ b/chebai/preprocessing/bin/protein_token/tokens.txt @@ -18,3 +18,4 @@ W E V H +X From 99b5af1e263aa86ccaf1f350fb8703da202e13ec Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Tue, 12 Nov 2024 00:21:29 +0100 Subject: [PATCH 12/71] add migration for deepgo1 - 2018 paper --- .../migration/deep_go/__init__.py | 0 .../deep_go/migrate_deep_go_1_data.py | 310 ++++++++++++++++++ .../migrate_deep_go_2_data.py} | 10 +- 3 files changed, 318 insertions(+), 2 deletions(-) create mode 100644 chebai/preprocessing/migration/deep_go/__init__.py create mode 100644 chebai/preprocessing/migration/deep_go/migrate_deep_go_1_data.py rename chebai/preprocessing/migration/{deep_go_data_mirgration.py => deep_go/migrate_deep_go_2_data.py} (96%) diff --git a/chebai/preprocessing/migration/deep_go/__init__.py b/chebai/preprocessing/migration/deep_go/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/chebai/preprocessing/migration/deep_go/migrate_deep_go_1_data.py b/chebai/preprocessing/migration/deep_go/migrate_deep_go_1_data.py new file mode 100644 index 00000000..be709364 --- /dev/null +++ b/chebai/preprocessing/migration/deep_go/migrate_deep_go_1_data.py @@ -0,0 +1,310 @@ +import os +from collections import OrderedDict +from typing import List, Literal, Optional + +import pandas as pd +from jsonargparse import CLI + +from chebai.preprocessing.datasets.go_uniprot import ( + GOUniProtOver50, + GOUniProtOver250, + _GOUniProtDataExtractor, +) + + +class DeepGo1DataMigration: + """ + A class to handle data migration and processing for the DeepGO project. + It migrates the deepGO data to our data structure followed for GO-UniProt data. + + It migrates the data of DeepGO model of the below research paper: + Maxat Kulmanov, Mohammed Asif Khan, Robert Hoehndorf, + DeepGO: predicting protein functions from sequence and interactions using a deep ontology-aware classifier, + Bioinformatics, Volume 34, Issue 4, February 2018, Pages 660–668 + (https://doi.org/10.1093/bioinformatics/btx624), + + Attributes: + _CORRESPONDING_GO_CLASSES (dict): Mapping of GO branches to specific data extractor classes. + _MAXLEN (int): Maximum sequence length for sequences. + _LABELS_START_IDX (int): Starting index for labels in the dataset. + + Methods: + __init__(data_dir, go_branch): Initializes the data directory and GO branch. + _load_data(): Loads train, validation, test, and terms data from the specified directory. + _record_splits(): Creates a DataFrame with IDs and their corresponding split. + migrate(): Executes the migration process including data loading, processing, and saving. + _extract_required_data_from_splits(): Extracts required columns from the splits data. + _get_labels_columns(data_df): Generates label columns for the data based on GO terms. + extract_go_id(go_list): Extracts GO IDs from a list. + save_migrated_data(data_df, splits_df): Saves the processed data and splits. + """ + + # Number of annotations for each go_branch as per the research paper + _CORRESPONDING_GO_CLASSES = { + "cc": GOUniProtOver50, + "mf": GOUniProtOver50, + "bp": GOUniProtOver250, + } + + _MAXLEN = 1002 + _LABELS_START_IDX = _GOUniProtDataExtractor._LABELS_START_IDX + + def __init__(self, data_dir: str, go_branch: Literal["cc", "mf", "bp"]): + """ + Initializes the data migration object with a data directory and GO branch. + + Args: + data_dir (str): Directory containing the data files. + go_branch (Literal["cc", "mf", "bp"]): GO branch to use (cellular_component, molecular_function, or biological_process). + """ + valid_go_branches = list(self._CORRESPONDING_GO_CLASSES.keys()) + if go_branch not in valid_go_branches: + raise ValueError(f"go_branch must be one of {valid_go_branches}") + self._go_branch = go_branch + + self._data_dir: str = rf"{data_dir}" + self._train_df: Optional[pd.DataFrame] = None + self._test_df: Optional[pd.DataFrame] = None + self._validation_df: Optional[pd.DataFrame] = None + self._terms_df: Optional[pd.DataFrame] = None + self._classes: Optional[List[str]] = None + + def _load_data(self) -> None: + """ + Loads the test, train, validation, and terms data from the pickled files + in the data directory. + """ + try: + print(f"Loading data from {self._data_dir}......") + self._test_df = pd.DataFrame( + pd.read_pickle( + os.path.join(self._data_dir, f"test-{self._go_branch}.pkl") + ) + ) + self._train_df = pd.DataFrame( + pd.read_pickle( + os.path.join(self._data_dir, f"train-{self._go_branch}.pkl") + ) + ) + # self._validation_df = pd.DataFrame( + # pd.read_pickle(os.path.join(self._data_dir, f"valid-{self._go_branch}.pkl")) + # ) + self._terms_df = pd.DataFrame( + pd.read_pickle(os.path.join(self._data_dir, f"{self._go_branch}.pkl")) + ) + + except FileNotFoundError as e: + print(f"Error loading data: {e}") + + def _record_splits(self) -> pd.DataFrame: + """ + Creates a DataFrame that stores the IDs and their corresponding data splits. + + Returns: + pd.DataFrame: A combined DataFrame containing split assignments. + """ + print("Recording splits...") + split_assignment_list: List[pd.DataFrame] = [ + pd.DataFrame({"id": self._train_df["proteins"], "split": "train"}), + # pd.DataFrame( + # {"id": self._validation_df["proteins"], "split": "validation"} + # ), + pd.DataFrame({"id": self._test_df["proteins"], "split": "test"}), + ] + + combined_split_assignment = pd.concat(split_assignment_list, ignore_index=True) + return combined_split_assignment + + def migrate(self) -> None: + """ + Executes the data migration by loading, processing, and saving the data. + """ + print("Migration started......") + self._load_data() + if not all( + df is not None + for df in [ + self._train_df, + # self._validation_df, + self._test_df, + self._terms_df, + ] + ): + raise Exception( + "Data splits or terms data is not available in instance variables." + ) + splits_df = self._record_splits() + + data_with_labels_df = self._extract_required_data_from_splits() + + if not all( + var is not None for var in [data_with_labels_df, splits_df, self._classes] + ): + raise Exception( + "Data splits or terms data is not available in instance variables." + ) + + self.save_migrated_data(data_with_labels_df, splits_df) + + def _extract_required_data_from_splits(self) -> pd.DataFrame: + """ + Extracts required columns from the combined data splits. + + Returns: + pd.DataFrame: A DataFrame containing the essential columns for processing. + """ + print("Combining the data splits with required data..... ") + required_columns = [ + "proteins", + "accessions", + "sequences", + # Note: The GO classes here only directly related one, and not transitive GO classes + "gos", + "labels", + ] + + new_df = pd.concat( + [ + self._train_df[required_columns], + # self._validation_df[required_columns], + self._test_df[required_columns], + ], + ignore_index=True, + ) + new_df["go_ids"] = new_df.apply( + lambda row: self.extract_go_id(row["gos"]), axis=1 + ) + + labels_df = self._get_labels_colums(new_df) + + data_df = pd.DataFrame( + OrderedDict( + swiss_id=new_df["proteins"], + accession=new_df["accessions"], + go_ids=new_df["go_ids"], + sequence=new_df["sequences"], + ) + ) + + df = pd.concat([data_df, labels_df], axis=1) + + return df + + def _get_labels_colums(self, data_df: pd.DataFrame) -> pd.DataFrame: + """ + Generates a DataFrame with one-hot encoded columns for each GO term label, + based on the terms provided in `self._terms_df` and the existing labels in `data_df`. + + This method extracts GO IDs from the `functions` column of `self._terms_df`, + creating a list of all unique GO IDs. It then uses this list to create new + columns in the returned DataFrame, where each row has binary values + (0 or 1) indicating the presence of each GO ID in the corresponding entry of + `data_df['labels']`. + + Args: + data_df (pd.DataFrame): DataFrame containing data with a 'labels' column, + which holds lists of GO ID labels for each row. + + Returns: + pd.DataFrame: A DataFrame with the same index as `data_df` and one column + per GO ID, containing binary values indicating label presence. + """ + print("Generating labels based on terms.pkl file.......") + parsed_go_ids: pd.Series = self._terms_df["functions"].apply( + lambda gos: _GOUniProtDataExtractor._parse_go_id(gos) + ) + all_go_ids_list = parsed_go_ids.values.tolist() + self._classes = all_go_ids_list + + new_label_columns = pd.DataFrame( + data_df["labels"].tolist(), index=data_df.index, columns=all_go_ids_list + ) + + return new_label_columns + + @staticmethod + def extract_go_id(go_list: List[str]) -> List[int]: + """ + Extracts and parses GO IDs from a list of GO annotations. + + Args: + go_list (List[str]): List of GO annotation strings. + + Returns: + List[str]: List of parsed GO IDs. + """ + return [ + _GOUniProtDataExtractor._parse_go_id(go_id_str) for go_id_str in go_list + ] + + def save_migrated_data( + self, data_df: pd.DataFrame, splits_df: pd.DataFrame + ) -> None: + """ + Saves the processed data and split information. + + Args: + data_df (pd.DataFrame): Data with GO labels. + splits_df (pd.DataFrame): Split assignment DataFrame. + """ + print("Saving transformed data......") + go_class_instance: _GOUniProtDataExtractor = self._CORRESPONDING_GO_CLASSES[ + self._go_branch + ](go_branch=self._go_branch.upper(), max_sequence_length=self._MAXLEN) + + go_class_instance.save_processed( + data_df, go_class_instance.processed_main_file_names_dict["data"] + ) + print( + f"{go_class_instance.processed_main_file_names_dict['data']} saved to {go_class_instance.processed_dir_main}" + ) + + splits_df.to_csv( + os.path.join(go_class_instance.processed_dir_main, "splits.csv"), + index=False, + ) + print(f"splits.csv saved to {go_class_instance.processed_dir_main}") + + classes = sorted(self._classes) + with open( + os.path.join(go_class_instance.processed_dir_main, "classes.txt"), "wt" + ) as fout: + fout.writelines(str(node) + "\n" for node in classes) + print(f"classes.txt saved to {go_class_instance.processed_dir_main}") + print("Migration completed!") + + +class Main: + """ + Main class to handle the migration process for DeepGo1DataMigration. + + Methods: + migrate(data_dir: str, go_branch: Literal["cc", "mf", "bp"]): + Initiates the migration process for the specified data directory and GO branch. + """ + + @staticmethod + def migrate(data_dir: str, go_branch: Literal["cc", "mf", "bp"]) -> None: + """ + Initiates the migration process by creating a DeepGoDataMigration instance + and invoking its migrate method. + + Args: + data_dir (str): Directory containing the data files. + go_branch (Literal["cc", "mf", "bp"]): GO branch to use + ("cc" for cellular_component, + "mf" for molecular_function, + or "bp" for biological_process). + """ + DeepGo1DataMigration(data_dir, go_branch).migrate() + + +if __name__ == "__main__": + # Example: python script_name.py migrate --data_dir="data/deep_go1" --go_branch="mf" + # --data_dir specifies the directory containing the data files. + # --go_branch specifies the GO branch (cc, mf, or bp) you want to use for the migration. + CLI( + Main, + description="DeepGo1DataMigration CLI tool to handle migration of GO data for specified branches (cc, mf, bp).", + as_positional=False, + ) diff --git a/chebai/preprocessing/migration/deep_go_data_mirgration.py b/chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py similarity index 96% rename from chebai/preprocessing/migration/deep_go_data_mirgration.py rename to chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py index 5c22b389..0d5266ef 100644 --- a/chebai/preprocessing/migration/deep_go_data_mirgration.py +++ b/chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py @@ -12,11 +12,17 @@ ) -class DeepGoDataMigration: +class DeepGo2DataMigration: """ A class to handle data migration and processing for the DeepGO project. It migrates the data from the DeepGO-SE data structure to our data structure followed for GO-UniProt data. + It migrates the data of DeepGO model of the below research paper: + Maxat Kulmanov, Mohammed Asif Khan, Robert Hoehndorf, + DeepGO: predicting protein functions from sequence and interactions using a deep ontology-aware classifier, + Bioinformatics, Volume 34, Issue 4, February 2018, Pages 660–668 + (https://doi.org/10.1093/bioinformatics/btx624), + Attributes: _CORRESPONDING_GO_CLASSES (dict): Mapping of GO branches to specific data extractor classes. _MAXLEN (int): Maximum sequence length for sequences. @@ -283,7 +289,7 @@ def migrate(data_dir: str, go_branch: Literal["cc", "mf", "bp"]) -> None: "mf" for molecular_function, or "bp" for biological_process). """ - DeepGoDataMigration(data_dir, go_branch).migrate() + DeepGo2DataMigration(data_dir, go_branch).migrate() if __name__ == "__main__": From a15d49254c1d5a378dc8ac64508392c55fcb3841 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Tue, 12 Nov 2024 17:40:39 +0100 Subject: [PATCH 13/71] deepgo1: create non-exclusive val set as a placeholder --- .../deep_go/migrate_deep_go_1_data.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/chebai/preprocessing/migration/deep_go/migrate_deep_go_1_data.py b/chebai/preprocessing/migration/deep_go/migrate_deep_go_1_data.py index be709364..f42b08c3 100644 --- a/chebai/preprocessing/migration/deep_go/migrate_deep_go_1_data.py +++ b/chebai/preprocessing/migration/deep_go/migrate_deep_go_1_data.py @@ -89,6 +89,14 @@ def _load_data(self) -> None: # self._validation_df = pd.DataFrame( # pd.read_pickle(os.path.join(self._data_dir, f"valid-{self._go_branch}.pkl")) # ) + + # DeepGO1 data does not include a separate validation split, but our data structure requires one. + # To accommodate this, we will create a placeholder validation split by duplicating a small subset of the + # training data. However, to ensure a fair comparison with DeepGO1, we will retain the full training set + # without creating an exclusive validation split from it. + # Therefore, any metrics calculated on this placeholder validation set should be disregarded, as they do not + # reflect true validation performance. + self._validation_df = self._train_df[len(self._train_df) - 5 :] self._terms_df = pd.DataFrame( pd.read_pickle(os.path.join(self._data_dir, f"{self._go_branch}.pkl")) ) @@ -106,9 +114,9 @@ def _record_splits(self) -> pd.DataFrame: print("Recording splits...") split_assignment_list: List[pd.DataFrame] = [ pd.DataFrame({"id": self._train_df["proteins"], "split": "train"}), - # pd.DataFrame( - # {"id": self._validation_df["proteins"], "split": "validation"} - # ), + pd.DataFrame( + {"id": self._validation_df["proteins"], "split": "validation"} + ), pd.DataFrame({"id": self._test_df["proteins"], "split": "test"}), ] @@ -125,7 +133,7 @@ def migrate(self) -> None: df is not None for df in [ self._train_df, - # self._validation_df, + self._validation_df, self._test_df, self._terms_df, ] @@ -166,7 +174,7 @@ def _extract_required_data_from_splits(self) -> pd.DataFrame: new_df = pd.concat( [ self._train_df[required_columns], - # self._validation_df[required_columns], + self._validation_df[required_columns], self._test_df[required_columns], ], ignore_index=True, From e0a85247f2f7b561593d6cb9536e66aefb9ecebf Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 13 Nov 2024 22:45:32 +0100 Subject: [PATCH 14/71] deepgo1: further split train set into train and val for - +migration structure changes --- .../deep_go/migrate_deep_go_1_data.py | 241 +++++++++--------- 1 file changed, 118 insertions(+), 123 deletions(-) diff --git a/chebai/preprocessing/migration/deep_go/migrate_deep_go_1_data.py b/chebai/preprocessing/migration/deep_go/migrate_deep_go_1_data.py index f42b08c3..48188cd7 100644 --- a/chebai/preprocessing/migration/deep_go/migrate_deep_go_1_data.py +++ b/chebai/preprocessing/migration/deep_go/migrate_deep_go_1_data.py @@ -1,53 +1,29 @@ import os from collections import OrderedDict -from typing import List, Literal, Optional +from typing import List, Literal, Optional, Tuple import pandas as pd +from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit from jsonargparse import CLI -from chebai.preprocessing.datasets.go_uniprot import ( - GOUniProtOver50, - GOUniProtOver250, - _GOUniProtDataExtractor, -) +from chebai.preprocessing.datasets.go_uniprot import DeepGO1MigratedData class DeepGo1DataMigration: """ A class to handle data migration and processing for the DeepGO project. - It migrates the deepGO data to our data structure followed for GO-UniProt data. + It migrates the DeepGO data to our data structure followed for GO-UniProt data. - It migrates the data of DeepGO model of the below research paper: + This class handles data from the DeepGO model as described in: Maxat Kulmanov, Mohammed Asif Khan, Robert Hoehndorf, DeepGO: predicting protein functions from sequence and interactions using a deep ontology-aware classifier, Bioinformatics, Volume 34, Issue 4, February 2018, Pages 660–668 - (https://doi.org/10.1093/bioinformatics/btx624), - - Attributes: - _CORRESPONDING_GO_CLASSES (dict): Mapping of GO branches to specific data extractor classes. - _MAXLEN (int): Maximum sequence length for sequences. - _LABELS_START_IDX (int): Starting index for labels in the dataset. - - Methods: - __init__(data_dir, go_branch): Initializes the data directory and GO branch. - _load_data(): Loads train, validation, test, and terms data from the specified directory. - _record_splits(): Creates a DataFrame with IDs and their corresponding split. - migrate(): Executes the migration process including data loading, processing, and saving. - _extract_required_data_from_splits(): Extracts required columns from the splits data. - _get_labels_columns(data_df): Generates label columns for the data based on GO terms. - extract_go_id(go_list): Extracts GO IDs from a list. - save_migrated_data(data_df, splits_df): Saves the processed data and splits. + (https://doi.org/10.1093/bioinformatics/btx624). """ - # Number of annotations for each go_branch as per the research paper - _CORRESPONDING_GO_CLASSES = { - "cc": GOUniProtOver50, - "mf": GOUniProtOver50, - "bp": GOUniProtOver250, - } - + # Max sequence length as per DeepGO1 _MAXLEN = 1002 - _LABELS_START_IDX = _GOUniProtDataExtractor._LABELS_START_IDX + _LABELS_START_IDX = DeepGO1MigratedData._LABELS_START_IDX def __init__(self, data_dir: str, go_branch: Literal["cc", "mf", "bp"]): """ @@ -55,9 +31,9 @@ def __init__(self, data_dir: str, go_branch: Literal["cc", "mf", "bp"]): Args: data_dir (str): Directory containing the data files. - go_branch (Literal["cc", "mf", "bp"]): GO branch to use (cellular_component, molecular_function, or biological_process). + go_branch (Literal["cc", "mf", "bp"]): GO branch to use. """ - valid_go_branches = list(self._CORRESPONDING_GO_CLASSES.keys()) + valid_go_branches = list(DeepGO1MigratedData.GO_BRANCH_MAPPING.keys()) if go_branch not in valid_go_branches: raise ValueError(f"go_branch must be one of {valid_go_branches}") self._go_branch = go_branch @@ -69,34 +45,60 @@ def __init__(self, data_dir: str, go_branch: Literal["cc", "mf", "bp"]): self._terms_df: Optional[pd.DataFrame] = None self._classes: Optional[List[str]] = None + def migrate(self) -> None: + """ + Executes the data migration by loading, processing, and saving the data. + """ + print("Starting the migration process...") + self._load_data() + if not all( + df is not None + for df in [ + self._train_df, + self._validation_df, + self._test_df, + self._terms_df, + ] + ): + raise Exception( + "Data splits or terms data is not available in instance variables." + ) + splits_df = self._record_splits() + data_with_labels_df = self._extract_required_data_from_splits() + + if not all( + var is not None for var in [data_with_labels_df, splits_df, self._classes] + ): + raise Exception( + "Data splits or terms data is not available in instance variables." + ) + + self.save_migrated_data(data_with_labels_df, splits_df) + def _load_data(self) -> None: """ Loads the test, train, validation, and terms data from the pickled files in the data directory. """ try: - print(f"Loading data from {self._data_dir}......") + print(f"Loading data files from directory: {self._data_dir}") self._test_df = pd.DataFrame( pd.read_pickle( os.path.join(self._data_dir, f"test-{self._go_branch}.pkl") ) ) - self._train_df = pd.DataFrame( + + # DeepGO 1 lacks a validation split, so we will create one by further splitting the training set. + # Although this reduces the training data slightly compared to the original DeepGO setup, + # given the data size, the impact should be minimal. + train_df = pd.DataFrame( pd.read_pickle( os.path.join(self._data_dir, f"train-{self._go_branch}.pkl") ) ) - # self._validation_df = pd.DataFrame( - # pd.read_pickle(os.path.join(self._data_dir, f"valid-{self._go_branch}.pkl")) - # ) - - # DeepGO1 data does not include a separate validation split, but our data structure requires one. - # To accommodate this, we will create a placeholder validation split by duplicating a small subset of the - # training data. However, to ensure a fair comparison with DeepGO1, we will retain the full training set - # without creating an exclusive validation split from it. - # Therefore, any metrics calculated on this placeholder validation set should be disregarded, as they do not - # reflect true validation performance. - self._validation_df = self._train_df[len(self._train_df) - 5 :] + + self._train_df, self._validation_df = self._get_train_val_split(train_df) + self._terms_df = pd.DataFrame( pd.read_pickle(os.path.join(self._data_dir, f"{self._go_branch}.pkl")) ) @@ -104,6 +106,35 @@ def _load_data(self) -> None: except FileNotFoundError as e: print(f"Error loading data: {e}") + @staticmethod + def _get_train_val_split( + train_df: pd.DataFrame, + ) -> Tuple[pd.DataFrame, pd.DataFrame]: + """ + Splits the training data into a smaller training set and a validation set. + + Args: + train_df (pd.DataFrame): Original training DataFrame. + + Returns: + Tuple[pd.DataFrame, pd.DataFrame]: Training and validation DataFrames. + """ + labels_list_train = train_df["labels"].tolist() + train_split = 0.85 + test_size = ((1 - train_split) ** 2) / train_split + + splitter = MultilabelStratifiedShuffleSplit( + n_splits=1, test_size=test_size, random_state=42 + ) + + train_indices, validation_indices = next( + splitter.split(labels_list_train, labels_list_train) + ) + + df_validation = train_df.iloc[validation_indices] + df_train = train_df.iloc[train_indices] + return df_train, df_validation + def _record_splits(self) -> pd.DataFrame: """ Creates a DataFrame that stores the IDs and their corresponding data splits. @@ -111,7 +142,7 @@ def _record_splits(self) -> pd.DataFrame: Returns: pd.DataFrame: A combined DataFrame containing split assignments. """ - print("Recording splits...") + print("Recording data splits for train, validation, and test sets.") split_assignment_list: List[pd.DataFrame] = [ pd.DataFrame({"id": self._train_df["proteins"], "split": "train"}), pd.DataFrame( @@ -123,37 +154,6 @@ def _record_splits(self) -> pd.DataFrame: combined_split_assignment = pd.concat(split_assignment_list, ignore_index=True) return combined_split_assignment - def migrate(self) -> None: - """ - Executes the data migration by loading, processing, and saving the data. - """ - print("Migration started......") - self._load_data() - if not all( - df is not None - for df in [ - self._train_df, - self._validation_df, - self._test_df, - self._terms_df, - ] - ): - raise Exception( - "Data splits or terms data is not available in instance variables." - ) - splits_df = self._record_splits() - - data_with_labels_df = self._extract_required_data_from_splits() - - if not all( - var is not None for var in [data_with_labels_df, splits_df, self._classes] - ): - raise Exception( - "Data splits or terms data is not available in instance variables." - ) - - self.save_migrated_data(data_with_labels_df, splits_df) - def _extract_required_data_from_splits(self) -> pd.DataFrame: """ Extracts required columns from the combined data splits. @@ -161,12 +161,11 @@ def _extract_required_data_from_splits(self) -> pd.DataFrame: Returns: pd.DataFrame: A DataFrame containing the essential columns for processing. """ - print("Combining the data splits with required data..... ") + print("Combining data splits into a single DataFrame with required columns.") required_columns = [ "proteins", "accessions", "sequences", - # Note: The GO classes here only directly related one, and not transitive GO classes "gos", "labels", ] @@ -183,7 +182,7 @@ def _extract_required_data_from_splits(self) -> pd.DataFrame: lambda row: self.extract_go_id(row["gos"]), axis=1 ) - labels_df = self._get_labels_colums(new_df) + labels_df = self._get_labels_columns(new_df) data_df = pd.DataFrame( OrderedDict( @@ -198,28 +197,32 @@ def _extract_required_data_from_splits(self) -> pd.DataFrame: return df - def _get_labels_colums(self, data_df: pd.DataFrame) -> pd.DataFrame: + @staticmethod + def extract_go_id(go_list: List[str]) -> List[int]: """ - Generates a DataFrame with one-hot encoded columns for each GO term label, - based on the terms provided in `self._terms_df` and the existing labels in `data_df`. + Extracts and parses GO IDs from a list of GO annotations. - This method extracts GO IDs from the `functions` column of `self._terms_df`, - creating a list of all unique GO IDs. It then uses this list to create new - columns in the returned DataFrame, where each row has binary values - (0 or 1) indicating the presence of each GO ID in the corresponding entry of - `data_df['labels']`. + Args: + go_list (List[str]): List of GO annotation strings. + + Returns: + List[int]: List of parsed GO IDs. + """ + return [DeepGO1MigratedData._parse_go_id(go_id_str) for go_id_str in go_list] + + def _get_labels_columns(self, data_df: pd.DataFrame) -> pd.DataFrame: + """ + Generates columns for labels based on provided selected terms. Args: - data_df (pd.DataFrame): DataFrame containing data with a 'labels' column, - which holds lists of GO ID labels for each row. + data_df (pd.DataFrame): DataFrame with GO annotations and labels. Returns: - pd.DataFrame: A DataFrame with the same index as `data_df` and one column - per GO ID, containing binary values indicating label presence. + pd.DataFrame: DataFrame with label columns. """ - print("Generating labels based on terms.pkl file.......") + print("Generating label columns from provided selected terms.") parsed_go_ids: pd.Series = self._terms_df["functions"].apply( - lambda gos: _GOUniProtDataExtractor._parse_go_id(gos) + lambda gos: DeepGO1MigratedData._parse_go_id(gos) ) all_go_ids_list = parsed_go_ids.values.tolist() self._classes = all_go_ids_list @@ -230,21 +233,6 @@ def _get_labels_colums(self, data_df: pd.DataFrame) -> pd.DataFrame: return new_label_columns - @staticmethod - def extract_go_id(go_list: List[str]) -> List[int]: - """ - Extracts and parses GO IDs from a list of GO annotations. - - Args: - go_list (List[str]): List of GO annotation strings. - - Returns: - List[str]: List of parsed GO IDs. - """ - return [ - _GOUniProtDataExtractor._parse_go_id(go_id_str) for go_id_str in go_list - ] - def save_migrated_data( self, data_df: pd.DataFrame, splits_df: pd.DataFrame ) -> None: @@ -255,31 +243,38 @@ def save_migrated_data( data_df (pd.DataFrame): Data with GO labels. splits_df (pd.DataFrame): Split assignment DataFrame. """ - print("Saving transformed data......") - go_class_instance: _GOUniProtDataExtractor = self._CORRESPONDING_GO_CLASSES[ - self._go_branch - ](go_branch=self._go_branch.upper(), max_sequence_length=self._MAXLEN) + print("Saving transformed data files.") - go_class_instance.save_processed( - data_df, go_class_instance.processed_main_file_names_dict["data"] + deepgo_migr_inst: DeepGO1MigratedData = DeepGO1MigratedData( + go_branch=DeepGO1MigratedData.GO_BRANCH_MAPPING[self._go_branch], + max_sequence_length=self._MAXLEN, + ) + + # Save data file + deepgo_migr_inst.save_processed( + data_df, deepgo_migr_inst.processed_main_file_names_dict["data"] ) print( - f"{go_class_instance.processed_main_file_names_dict['data']} saved to {go_class_instance.processed_dir_main}" + f"{deepgo_migr_inst.processed_main_file_names_dict['data']} saved to {deepgo_migr_inst.processed_dir_main}" ) + # Save splits file splits_df.to_csv( - os.path.join(go_class_instance.processed_dir_main, "splits.csv"), + os.path.join(deepgo_migr_inst.processed_dir_main, "splits_deep_go1.csv"), index=False, ) - print(f"splits.csv saved to {go_class_instance.processed_dir_main}") + print(f"splits_deep_go1.csv saved to {deepgo_migr_inst.processed_dir_main}") + # Save classes file classes = sorted(self._classes) with open( - os.path.join(go_class_instance.processed_dir_main, "classes.txt"), "wt" + os.path.join(deepgo_migr_inst.processed_dir_main, "classes_deep_go1.txt"), + "wt", ) as fout: fout.writelines(str(node) + "\n" for node in classes) - print(f"classes.txt saved to {go_class_instance.processed_dir_main}") - print("Migration completed!") + print(f"classes_deep_go1.txt saved to {deepgo_migr_inst.processed_dir_main}") + + print("Migration process completed!") class Main: From 093be281a3784972a80abd647dff7f79ceaca553 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 13 Nov 2024 22:56:44 +0100 Subject: [PATCH 15/71] migration script update --- .../deep_go/migrate_deep_go_1_data.py | 2 +- .../deep_go/migrate_deep_go_2_data.py | 163 ++++++++---------- 2 files changed, 71 insertions(+), 94 deletions(-) diff --git a/chebai/preprocessing/migration/deep_go/migrate_deep_go_1_data.py b/chebai/preprocessing/migration/deep_go/migrate_deep_go_1_data.py index 48188cd7..ad8ae322 100644 --- a/chebai/preprocessing/migration/deep_go/migrate_deep_go_1_data.py +++ b/chebai/preprocessing/migration/deep_go/migrate_deep_go_1_data.py @@ -14,7 +14,7 @@ class DeepGo1DataMigration: A class to handle data migration and processing for the DeepGO project. It migrates the DeepGO data to our data structure followed for GO-UniProt data. - This class handles data from the DeepGO model as described in: + This class handles migration of data from the DeepGO paper below: Maxat Kulmanov, Mohammed Asif Khan, Robert Hoehndorf, DeepGO: predicting protein functions from sequence and interactions using a deep ontology-aware classifier, Bioinformatics, Volume 34, Issue 4, February 2018, Pages 660–668 diff --git a/chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py b/chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py index 0d5266ef..3d4109e1 100644 --- a/chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py +++ b/chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py @@ -5,11 +5,7 @@ import pandas as pd from jsonargparse import CLI -from chebai.preprocessing.datasets.go_uniprot import ( - GOUniProtOver50, - GOUniProtOver250, - _GOUniProtDataExtractor, -) +from chebai.preprocessing.datasets.go_uniprot import DeepGO2MigratedData class DeepGo2DataMigration: @@ -17,39 +13,16 @@ class DeepGo2DataMigration: A class to handle data migration and processing for the DeepGO project. It migrates the data from the DeepGO-SE data structure to our data structure followed for GO-UniProt data. - It migrates the data of DeepGO model of the below research paper: + This class handles migration of data from the DeepGO paper below: Maxat Kulmanov, Mohammed Asif Khan, Robert Hoehndorf, DeepGO: predicting protein functions from sequence and interactions using a deep ontology-aware classifier, Bioinformatics, Volume 34, Issue 4, February 2018, Pages 660–668 - (https://doi.org/10.1093/bioinformatics/btx624), - - Attributes: - _CORRESPONDING_GO_CLASSES (dict): Mapping of GO branches to specific data extractor classes. - _MAXLEN (int): Maximum sequence length for sequences. - _LABELS_START_IDX (int): Starting index for labels in the dataset. - - Methods: - __init__(data_dir, go_branch): Initializes the data directory and GO branch. - _load_data(): Loads train, validation, test, and terms data from the specified directory. - _record_splits(): Creates a DataFrame with IDs and their corresponding split. - migrate(): Executes the migration process including data loading, processing, and saving. - _extract_required_data_from_splits(): Extracts required columns from the splits data. - _generate_labels(data_df): Generates label columns for the data based on GO terms. - extract_go_id(go_list): Extracts GO IDs from a list. - save_migrated_data(data_df, splits_df): Saves the processed data and splits. + (https://doi.org/10.1093/bioinformatics/btx624) """ - # Link for the namespaces convention used for GO branch - # https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/utils.py#L18-L22 - _CORRESPONDING_GO_CLASSES = { - "cc": GOUniProtOver50, - "mf": GOUniProtOver50, - "bp": GOUniProtOver250, - } - # https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/aminoacids.py#L11 _MAXLEN = 1000 - _LABELS_START_IDX = _GOUniProtDataExtractor._LABELS_START_IDX + _LABELS_START_IDX = DeepGO2MigratedData._LABELS_START_IDX def __init__(self, data_dir: str, go_branch: Literal["cc", "mf", "bp"]): """ @@ -57,9 +30,9 @@ def __init__(self, data_dir: str, go_branch: Literal["cc", "mf", "bp"]): Args: data_dir (str): Directory containing the data files. - go_branch (Literal["cc", "mf", "bp"]): GO branch to use (cellular_component, molecular_function, or biological_process). + go_branch (Literal["cc", "mf", "bp"]): GO branch to use. """ - valid_go_branches = list(self._CORRESPONDING_GO_CLASSES.keys()) + valid_go_branches = list(DeepGO2MigratedData.GO_BRANCH_MAPPING.keys()) if go_branch not in valid_go_branches: raise ValueError(f"go_branch must be one of {valid_go_branches}") self._go_branch = go_branch @@ -71,13 +44,45 @@ def __init__(self, data_dir: str, go_branch: Literal["cc", "mf", "bp"]): self._terms_df: Optional[pd.DataFrame] = None self._classes: Optional[List[str]] = None + def migrate(self) -> None: + """ + Executes the data migration by loading, processing, and saving the data. + """ + print("Starting the migration process...") + self._load_data() + if not all( + df is not None + for df in [ + self._train_df, + self._validation_df, + self._test_df, + self._terms_df, + ] + ): + raise Exception( + "Data splits or terms data is not available in instance variables." + ) + splits_df = self._record_splits() + + data_df = self._extract_required_data_from_splits() + data_with_labels_df = self._generate_labels(data_df) + + if not all( + var is not None for var in [data_with_labels_df, splits_df, self._classes] + ): + raise Exception( + "Data splits or terms data is not available in instance variables." + ) + + self.save_migrated_data(data_df, splits_df) + def _load_data(self) -> None: """ Loads the test, train, validation, and terms data from the pickled files in the data directory. """ try: - print(f"Loading data from {self._data_dir}......") + print(f"Loading data from directory: {self._data_dir}......") self._test_df = pd.DataFrame( pd.read_pickle(os.path.join(self._data_dir, "test_data.pkl")) ) @@ -100,7 +105,7 @@ def _record_splits(self) -> pd.DataFrame: Returns: pd.DataFrame: A combined DataFrame containing split assignments. """ - print("Recording splits...") + print("Recording data splits for train, validation, and test sets.") split_assignment_list: List[pd.DataFrame] = [ pd.DataFrame({"id": self._train_df["proteins"], "split": "train"}), pd.DataFrame( @@ -112,38 +117,6 @@ def _record_splits(self) -> pd.DataFrame: combined_split_assignment = pd.concat(split_assignment_list, ignore_index=True) return combined_split_assignment - def migrate(self) -> None: - """ - Executes the data migration by loading, processing, and saving the data. - """ - print("Migration started......") - self._load_data() - if not all( - df is not None - for df in [ - self._train_df, - self._validation_df, - self._test_df, - self._terms_df, - ] - ): - raise Exception( - "Data splits or terms data is not available in instance variables." - ) - splits_df = self._record_splits() - - data_df = self._extract_required_data_from_splits() - data_with_labels_df = self._generate_labels(data_df) - - if not all( - var is not None for var in [data_with_labels_df, splits_df, self._classes] - ): - raise Exception( - "Data splits or terms data is not available in instance variables." - ) - - self.save_migrated_data(data_df, splits_df) - def _extract_required_data_from_splits(self) -> pd.DataFrame: """ Extracts required columns from the combined data splits. @@ -186,6 +159,19 @@ def _extract_required_data_from_splits(self) -> pd.DataFrame: ) return data_df + @staticmethod + def extract_go_id(go_list: List[str]) -> List[int]: + """ + Extracts and parses GO IDs from a list of GO annotations. + + Args: + go_list (List[str]): List of GO annotation strings. + + Returns: + List[str]: List of parsed GO IDs. + """ + return [DeepGO2MigratedData._parse_go_id(go_id_str) for go_id_str in go_list] + def _generate_labels(self, data_df: pd.DataFrame) -> pd.DataFrame: """ Generates label columns for each GO term in the dataset. @@ -198,7 +184,7 @@ def _generate_labels(self, data_df: pd.DataFrame) -> pd.DataFrame: """ print("Generating labels based on terms.pkl file.......") parsed_go_ids: pd.Series = self._terms_df["gos"].apply( - lambda gos: _GOUniProtDataExtractor._parse_go_id(gos) + lambda gos: DeepGO2MigratedData._parse_go_id(gos) ) all_go_ids_list = parsed_go_ids.values.tolist() self._classes = all_go_ids_list @@ -215,21 +201,6 @@ def _generate_labels(self, data_df: pd.DataFrame) -> pd.DataFrame: data_df = data_df[data_df.iloc[:, self._LABELS_START_IDX :].any(axis=1)] return data_df - @staticmethod - def extract_go_id(go_list: List[str]) -> List[int]: - """ - Extracts and parses GO IDs from a list of GO annotations. - - Args: - go_list (List[str]): List of GO annotation strings. - - Returns: - List[str]: List of parsed GO IDs. - """ - return [ - _GOUniProtDataExtractor._parse_go_id(go_id_str) for go_id_str in go_list - ] - def save_migrated_data( self, data_df: pd.DataFrame, splits_df: pd.DataFrame ) -> None: @@ -241,29 +212,35 @@ def save_migrated_data( splits_df (pd.DataFrame): Split assignment DataFrame. """ print("Saving transformed data......") - go_class_instance: _GOUniProtDataExtractor = self._CORRESPONDING_GO_CLASSES[ - self._go_branch - ](go_branch=self._go_branch.upper(), max_sequence_length=self._MAXLEN) + deepgo_migr_inst: DeepGO2MigratedData = DeepGO2MigratedData( + go_branch=DeepGO2MigratedData.GO_BRANCH_MAPPING[self._go_branch], + max_sequence_length=self._MAXLEN, + ) - go_class_instance.save_processed( - data_df, go_class_instance.processed_main_file_names_dict["data"] + # Save data file + deepgo_migr_inst.save_processed( + data_df, deepgo_migr_inst.processed_main_file_names_dict["data"] ) print( - f"{go_class_instance.processed_main_file_names_dict['data']} saved to {go_class_instance.processed_dir_main}" + f"{deepgo_migr_inst.processed_main_file_names_dict['data']} saved to {deepgo_migr_inst.processed_dir_main}" ) + # Save split file splits_df.to_csv( - os.path.join(go_class_instance.processed_dir_main, "splits.csv"), + os.path.join(deepgo_migr_inst.processed_dir_main, "splits_deep_go2.csv"), index=False, ) - print(f"splits.csv saved to {go_class_instance.processed_dir_main}") + print(f"splits_deep_go2.csv saved to {deepgo_migr_inst.processed_dir_main}") + # Save classes.txt file classes = sorted(self._classes) with open( - os.path.join(go_class_instance.processed_dir_main, "classes.txt"), "wt" + os.path.join(deepgo_migr_inst.processed_dir_main, "classes_deep_go2.txt"), + "wt", ) as fout: fout.writelines(str(node) + "\n" for node in classes) - print(f"classes.txt saved to {go_class_instance.processed_dir_main}") + print(f"classes_deep_go2.txt saved to {deepgo_migr_inst.processed_dir_main}") + print("Migration completed!") From 14db9d641a8b627dd2a878eee736158abdeddbcc Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 13 Nov 2024 23:00:18 +0100 Subject: [PATCH 16/71] add classes to use migrated deepgo data --- chebai/preprocessing/datasets/go_uniprot.py | 186 ++++++++++++++++++++ 1 file changed, 186 insertions(+) diff --git a/chebai/preprocessing/datasets/go_uniprot.py b/chebai/preprocessing/datasets/go_uniprot.py index 7b1c16e3..16bd6a31 100644 --- a/chebai/preprocessing/datasets/go_uniprot.py +++ b/chebai/preprocessing/datasets/go_uniprot.py @@ -22,6 +22,8 @@ "GOUniProtOver50", "EXPERIMENTAL_EVIDENCE_CODES", "AMBIGUOUS_AMINO_ACIDS", + "DeepGO1MigratedData", + "DeepGO2MigratedData", ] import gzip @@ -731,3 +733,187 @@ class GOUniProtOver50(_GOUniProtOverX): """ THRESHOLD: int = 50 + + +class _DeepGOMigratedData(_GOUniProtDataExtractor, ABC): + """ + Base class for use of the migrated DeepGO data with common properties, name formatting, and file paths. + + Attributes: + READER (dr.ProteinDataReader): Protein data reader class. + THRESHOLD (Optional[int]): Threshold value for GO class selection, + determined by the GO branch type in derived classes. + """ + + READER: dr.ProteinDataReader = dr.ProteinDataReader + THRESHOLD: Optional[int] = None + + # Mapping from GO branch conventions used in DeepGO to our conventions + GO_BRANCH_MAPPING: dict = { + "cc": "CC", + "mf": "MF", + "bp": "BP", + } + + @property + def _name(self) -> str: + """ + Generates a unique identifier for the migrated data based on the GO + branch and max sequence length, optionally including a threshold. + + Returns: + str: A formatted name string for the data. + """ + threshold_part = f"GO{self.THRESHOLD}_" if self.THRESHOLD is not None else "" + + if self.go_branch != self._ALL_GO_BRANCHES: + return f"{threshold_part}{self.go_branch}_{self.max_sequence_length}" + + return f"{threshold_part}{self.max_sequence_length}" + + # ------------------------------ Phase: Prepare data ----------------------------------- + def prepare_data(self, *args: Any, **kwargs: Any) -> None: + """ + Checks for the existence of migrated DeepGO data in the specified directory. + Raises an error if the required data file is not found, prompting + migration from DeepGO to this data structure. + + Args: + *args (Any): Additional positional arguments. + **kwargs (Any): Additional keyword arguments. + + Raises: + FileNotFoundError: If the processed data file does not exist. + """ + print("Checking for processed data in", self.processed_dir_main) + + processed_name = self.processed_main_file_names_dict["data"] + if not os.path.isfile(os.path.join(self.processed_dir_main, processed_name)): + raise FileNotFoundError( + f"File {processed_name} not found.\n" + f"You must run the appropriate DeepGO migration script " + f"(chebai/preprocessing/migration/deep_go) before executing this configuration " + f"to migrate data from DeepGO to this data structure." + ) + + def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> List: + # Selection of GO classes not needed for migrated data + pass + + # ------------------------------ Phase: Raw Properties ----------------------------------- + @property + @abstractmethod + def processed_main_file_names_dict(self) -> Dict[str, str]: + """ + Abstract property for defining main processed file names. + These files are stored in the same directory as the generated data files + but have distinct names to differentiate them during training. + + Returns: + dict: A dictionary with key-value pairs for main processed file names. + """ + pass + + @property + @abstractmethod + def processed_file_names_dict(self) -> Dict[str, str]: + """ + Abstract property for defining additional processed file names. + These files are stored in the same directory as the generated data files + but have distinct names to differentiate them during training. + + Returns: + dict: A dictionary with key-value pairs for processed file names. + """ + pass + + +class DeepGO1MigratedData(_DeepGOMigratedData): + """ + Migrated data class specific to DeepGO1. Sets threshold values according + to the research paper based on the GO branch. + + Note: + Refer reference number 1 at the top of this file for the corresponding research paper. + + Args: + **kwargs: Arbitrary keyword arguments passed to the superclass. + + Raises: + ValueError: If an unsupported GO branch is provided. + """ + + def __init__(self, **kwargs): + # https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/aminoacids.py#L11 + assert int(kwargs.get("max_sequence_length")) == 1002 + + # Set threshold based on GO branch, as per DeepGO1 paper and its data. + if kwargs.get("go_branch") in ["CC", "MF"]: + self.THRESHOLD = 50 + elif kwargs.get("go_branch") == "BP": + self.THRESHOLD = 250 + else: + raise ValueError( + f"DeepGO1 paper has no defined threshold for branch {self.go_branch}" + ) + + super(_DeepGOMigratedData, self).__init__(**kwargs) + + @property + def processed_main_file_names_dict(self) -> Dict[str, str]: + """ + Returns main processed file names specific to DeepGO1. + + Returns: + dict: Dictionary with the main data file name for DeepGO1. + """ + return {"data": "data_deep_go1.pkl"} + + @property + def processed_file_names_dict(self) -> Dict[str, str]: + """ + Returns processed file names specific to DeepGO1. + + Returns: + dict: Dictionary with data file name for DeepGO1. + """ + return {"data": "data_deep_go1.pt"} + + +class DeepGO2MigratedData(_DeepGOMigratedData): + """ + Migrated data class specific to DeepGO2, inheriting from DeepGO1MigratedData + with different processed file names. + + Note: + Refer reference number 3 at the top of this file for the corresponding research paper. + + Returns: + dict: Dictionary with file names specific to DeepGO2. + """ + + def __init__(self, **kwargs): + # https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/aminoacids.py#L11 + assert int(kwargs.get("max_sequence_length")) == 1000 + + super(_DeepGOMigratedData, self).__init__(**kwargs) + + @property + def processed_main_file_names_dict(self) -> Dict[str, str]: + """ + Returns main processed file names specific to DeepGO2. + + Returns: + dict: Dictionary with the main data file name for DeepGO2. + """ + return {"data": "data_deep_go2.pkl"} + + @property + def processed_file_names_dict(self) -> Dict[str, str]: + """ + Returns processed file names specific to DeepGO2. + + Returns: + dict: Dictionary with data file name for DeepGO2. + """ + return {"data": "data_deep_go2.pt"} From 8922d4dc9c403648f6a039ac1144091383703f68 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 13 Nov 2024 23:24:07 +0100 Subject: [PATCH 17/71] deepgo: minor code change --- chebai/preprocessing/datasets/go_uniprot.py | 2 +- .../migration/deep_go/migrate_deep_go_1_data.py | 5 ++++- .../migration/deep_go/migrate_deep_go_2_data.py | 5 ++++- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/chebai/preprocessing/datasets/go_uniprot.py b/chebai/preprocessing/datasets/go_uniprot.py index 16bd6a31..22d13e3f 100644 --- a/chebai/preprocessing/datasets/go_uniprot.py +++ b/chebai/preprocessing/datasets/go_uniprot.py @@ -764,7 +764,7 @@ def _name(self) -> str: Returns: str: A formatted name string for the data. """ - threshold_part = f"GO{self.THRESHOLD}_" if self.THRESHOLD is not None else "" + threshold_part = f"GO{self.THRESHOLD}_" if self.THRESHOLD is not None else "GO_" if self.go_branch != self._ALL_GO_BRANCHES: return f"{threshold_part}{self.go_branch}_{self.max_sequence_length}" diff --git a/chebai/preprocessing/migration/deep_go/migrate_deep_go_1_data.py b/chebai/preprocessing/migration/deep_go/migrate_deep_go_1_data.py index ad8ae322..d9122c75 100644 --- a/chebai/preprocessing/migration/deep_go/migrate_deep_go_1_data.py +++ b/chebai/preprocessing/migration/deep_go/migrate_deep_go_1_data.py @@ -104,7 +104,10 @@ def _load_data(self) -> None: ) except FileNotFoundError as e: - print(f"Error loading data: {e}") + raise FileNotFoundError( + f"Data file not found in directory: {e}. " + "Please ensure all required files are available in the specified directory." + ) @staticmethod def _get_train_val_split( diff --git a/chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py b/chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py index 3d4109e1..b24b3cfb 100644 --- a/chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py +++ b/chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py @@ -96,7 +96,10 @@ def _load_data(self) -> None: pd.read_pickle(os.path.join(self._data_dir, "terms.pkl")) ) except FileNotFoundError as e: - print(f"Error loading data: {e}") + raise FileNotFoundError( + f"Data file not found in directory: {e}. " + "Please ensure all required files are available in the specified directory." + ) def _record_splits(self) -> pd.DataFrame: """ From 796356cc3253e40eabfcc5a3d884c8bec089e086 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 13 Nov 2024 23:42:11 +0100 Subject: [PATCH 18/71] modify prints to display actual file name --- chebai/preprocessing/datasets/base.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py index dfa0f999..f382f050 100644 --- a/chebai/preprocessing/datasets/base.py +++ b/chebai/preprocessing/datasets/base.py @@ -728,7 +728,7 @@ def prepare_data(self, *args: Any, **kwargs: Any) -> None: processed_name = self.processed_main_file_names_dict["data"] if not os.path.isfile(os.path.join(self.processed_dir_main, processed_name)): - print("Missing processed data file (`data.pkl` file)") + print(f"Missing processed data file (`{processed_name}` file)") os.makedirs(self.processed_dir_main, exist_ok=True) data_path = self._download_required_data() g = self._extract_class_hierarchy(data_path) @@ -812,12 +812,15 @@ def setup_processed(self) -> None: None """ os.makedirs(self.processed_dir, exist_ok=True) - print("Missing transformed data (`data.pt` file). Transforming data.... ") + processed_main_file_name = self.processed_main_file_names_dict["data"] + print( + f"Missing transformed data (`{processed_main_file_name}` file). Transforming data.... " + ) torch.save( self._load_data_from_file( os.path.join( self.processed_dir_main, - self.processed_main_file_names_dict["data"], + processed_main_file_name, ) ), os.path.join(self.processed_dir, self.processed_file_names_dict["data"]), From 3c11a690718ca743ac28d75438fa9bf9996adf84 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 17 Nov 2024 23:42:20 +0100 Subject: [PATCH 19/71] create sub dir for deego dataset and move rel files --- chebai/preprocessing/datasets/deepGO/__init__.py | 0 chebai/preprocessing/datasets/{ => deepGO}/go_uniprot.py | 0 chebai/preprocessing/datasets/{ => deepGO}/protein_pretraining.py | 0 .../preprocessing/datasets/deepGO/protein_protein_interactions.py | 0 4 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 chebai/preprocessing/datasets/deepGO/__init__.py rename chebai/preprocessing/datasets/{ => deepGO}/go_uniprot.py (100%) rename chebai/preprocessing/datasets/{ => deepGO}/protein_pretraining.py (100%) create mode 100644 chebai/preprocessing/datasets/deepGO/protein_protein_interactions.py diff --git a/chebai/preprocessing/datasets/deepGO/__init__.py b/chebai/preprocessing/datasets/deepGO/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/chebai/preprocessing/datasets/go_uniprot.py b/chebai/preprocessing/datasets/deepGO/go_uniprot.py similarity index 100% rename from chebai/preprocessing/datasets/go_uniprot.py rename to chebai/preprocessing/datasets/deepGO/go_uniprot.py diff --git a/chebai/preprocessing/datasets/protein_pretraining.py b/chebai/preprocessing/datasets/deepGO/protein_pretraining.py similarity index 100% rename from chebai/preprocessing/datasets/protein_pretraining.py rename to chebai/preprocessing/datasets/deepGO/protein_pretraining.py diff --git a/chebai/preprocessing/datasets/deepGO/protein_protein_interactions.py b/chebai/preprocessing/datasets/deepGO/protein_protein_interactions.py new file mode 100644 index 00000000..e69de29b From 2b571c5f3b3d30fadc2ec77329ce1d16b70a99d1 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 17 Nov 2024 23:51:14 +0100 Subject: [PATCH 20/71] update imports as per new deepGO dir --- chebai/preprocessing/datasets/deepGO/protein_pretraining.py | 2 +- .../preprocessing/migration/deep_go/migrate_deep_go_1_data.py | 2 +- .../preprocessing/migration/deep_go/migrate_deep_go_2_data.py | 2 +- tests/unit/dataset_classes/testGOUniProDataExtractor.py | 2 +- tests/unit/dataset_classes/testGoUniProtOverX.py | 2 +- tutorials/data_exploration_go.ipynb | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/chebai/preprocessing/datasets/deepGO/protein_pretraining.py b/chebai/preprocessing/datasets/deepGO/protein_pretraining.py index d2a2b6db..8f7e9c4d 100644 --- a/chebai/preprocessing/datasets/deepGO/protein_pretraining.py +++ b/chebai/preprocessing/datasets/deepGO/protein_pretraining.py @@ -12,7 +12,7 @@ from sklearn.model_selection import train_test_split from chebai.preprocessing.datasets.base import _DynamicDataset -from chebai.preprocessing.datasets.go_uniprot import ( +from chebai.preprocessing.datasets.deepGO.go_uniprot import ( AMBIGUOUS_AMINO_ACIDS, EXPERIMENTAL_EVIDENCE_CODES, GOUniProtOver250, diff --git a/chebai/preprocessing/migration/deep_go/migrate_deep_go_1_data.py b/chebai/preprocessing/migration/deep_go/migrate_deep_go_1_data.py index d9122c75..7d59c699 100644 --- a/chebai/preprocessing/migration/deep_go/migrate_deep_go_1_data.py +++ b/chebai/preprocessing/migration/deep_go/migrate_deep_go_1_data.py @@ -6,7 +6,7 @@ from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit from jsonargparse import CLI -from chebai.preprocessing.datasets.go_uniprot import DeepGO1MigratedData +from chebai.preprocessing.datasets.deepGO.go_uniprot import DeepGO1MigratedData class DeepGo1DataMigration: diff --git a/chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py b/chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py index b24b3cfb..d63bcad3 100644 --- a/chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py +++ b/chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py @@ -5,7 +5,7 @@ import pandas as pd from jsonargparse import CLI -from chebai.preprocessing.datasets.go_uniprot import DeepGO2MigratedData +from chebai.preprocessing.datasets.deepGO.go_uniprot import DeepGO2MigratedData class DeepGo2DataMigration: diff --git a/tests/unit/dataset_classes/testGOUniProDataExtractor.py b/tests/unit/dataset_classes/testGOUniProDataExtractor.py index 9da48bee..96ff9a3a 100644 --- a/tests/unit/dataset_classes/testGOUniProDataExtractor.py +++ b/tests/unit/dataset_classes/testGOUniProDataExtractor.py @@ -6,7 +6,7 @@ import networkx as nx import pandas as pd -from chebai.preprocessing.datasets.go_uniprot import _GOUniProtDataExtractor +from chebai.preprocessing.datasets.deepGO.go_uniprot import _GOUniProtDataExtractor from chebai.preprocessing.reader import ProteinDataReader from tests.unit.mock_data.ontology_mock_data import GOUniProtMockData diff --git a/tests/unit/dataset_classes/testGoUniProtOverX.py b/tests/unit/dataset_classes/testGoUniProtOverX.py index d4157770..3f329c56 100644 --- a/tests/unit/dataset_classes/testGoUniProtOverX.py +++ b/tests/unit/dataset_classes/testGoUniProtOverX.py @@ -5,7 +5,7 @@ import networkx as nx import pandas as pd -from chebai.preprocessing.datasets.go_uniprot import _GOUniProtOverX +from chebai.preprocessing.datasets.deepGO.go_uniprot import _GOUniProtOverX from tests.unit.mock_data.ontology_mock_data import GOUniProtMockData diff --git a/tutorials/data_exploration_go.ipynb b/tutorials/data_exploration_go.ipynb index 6f67c82b..1a205e37 100644 --- a/tutorials/data_exploration_go.ipynb +++ b/tutorials/data_exploration_go.ipynb @@ -70,7 +70,7 @@ } }, "outputs": [], - "source": "from chebai.preprocessing.datasets.go_uniprot import GOUniProtOver250" + "source": "from chebai.preprocessing.datasets.deepGO.go_uniprot import GOUniProtOver250" }, { "cell_type": "code", From f75e30bcbbc3c3a7d916fa30ddee8fa34af1c486 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 17 Nov 2024 23:54:53 +0100 Subject: [PATCH 21/71] update import dir for pretrain test --- tests/unit/dataset_classes/testProteinPretrainingData.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/unit/dataset_classes/testProteinPretrainingData.py b/tests/unit/dataset_classes/testProteinPretrainingData.py index cb6b0688..caac3eac 100644 --- a/tests/unit/dataset_classes/testProteinPretrainingData.py +++ b/tests/unit/dataset_classes/testProteinPretrainingData.py @@ -1,7 +1,9 @@ import unittest from unittest.mock import PropertyMock, mock_open, patch -from chebai.preprocessing.datasets.protein_pretraining import _ProteinPretrainingData +from chebai.preprocessing.datasets.deepGO.protein_pretraining import ( + _ProteinPretrainingData, +) from chebai.preprocessing.reader import ProteinDataReader from tests.unit.mock_data.ontology_mock_data import GOUniProtMockData From 1b8b270c4b4ec99d81739c80ca658c9f7696da10 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 4 Dec 2024 12:06:11 +0100 Subject: [PATCH 22/71] migration fix : truncate seq and save data with labels --- .../deep_go/migrate_deep_go_2_data.py | 62 +++++++++++++++---- 1 file changed, 49 insertions(+), 13 deletions(-) diff --git a/chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py b/chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py index d63bcad3..1edec52b 100644 --- a/chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py +++ b/chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py @@ -20,17 +20,19 @@ class DeepGo2DataMigration: (https://doi.org/10.1093/bioinformatics/btx624) """ - # https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/aminoacids.py#L11 - _MAXLEN = 1000 _LABELS_START_IDX = DeepGO2MigratedData._LABELS_START_IDX - def __init__(self, data_dir: str, go_branch: Literal["cc", "mf", "bp"]): + def __init__( + self, data_dir: str, go_branch: Literal["cc", "mf", "bp"], max_len: int = 1000 + ): """ Initializes the data migration object with a data directory and GO branch. Args: data_dir (str): Directory containing the data files. go_branch (Literal["cc", "mf", "bp"]): GO branch to use. + max_len (int): Used to truncate the sequence to this length. Default is 1000. + # https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/aminoacids.py#L11 """ valid_go_branches = list(DeepGO2MigratedData.GO_BRANCH_MAPPING.keys()) if go_branch not in valid_go_branches: @@ -38,6 +40,8 @@ def __init__(self, data_dir: str, go_branch: Literal["cc", "mf", "bp"]): self._go_branch = go_branch self._data_dir: str = os.path.join(rf"{data_dir}", go_branch) + self._max_len: int = max_len + self._train_df: Optional[pd.DataFrame] = None self._test_df: Optional[pd.DataFrame] = None self._validation_df: Optional[pd.DataFrame] = None @@ -74,33 +78,61 @@ def migrate(self) -> None: "Data splits or terms data is not available in instance variables." ) - self.save_migrated_data(data_df, splits_df) + self.save_migrated_data(data_with_labels_df, splits_df) def _load_data(self) -> None: """ Loads the test, train, validation, and terms data from the pickled files in the data directory. """ + try: print(f"Loading data from directory: {self._data_dir}......") - self._test_df = pd.DataFrame( - pd.read_pickle(os.path.join(self._data_dir, "test_data.pkl")) + self._test_df = self._truncate_sequences( + pd.DataFrame( + pd.read_pickle(os.path.join(self._data_dir, "test_data.pkl")) + ) ) - self._train_df = pd.DataFrame( - pd.read_pickle(os.path.join(self._data_dir, "train_data.pkl")) + self._train_df = self._truncate_sequences( + pd.DataFrame( + pd.read_pickle(os.path.join(self._data_dir, "train_data.pkl")) + ) ) - self._validation_df = pd.DataFrame( - pd.read_pickle(os.path.join(self._data_dir, "valid_data.pkl")) + self._validation_df = self._truncate_sequences( + pd.DataFrame( + pd.read_pickle(os.path.join(self._data_dir, "valid_data.pkl")) + ) ) + self._terms_df = pd.DataFrame( pd.read_pickle(os.path.join(self._data_dir, "terms.pkl")) ) + except FileNotFoundError as e: raise FileNotFoundError( f"Data file not found in directory: {e}. " "Please ensure all required files are available in the specified directory." ) + def _truncate_sequences( + self, df: pd.DataFrame, column: str = "sequences" + ) -> pd.DataFrame: + """ + Truncate sequences in a specified column of a dataframe to the maximum length. + + https://github.com/bio-ontology-research-group/deepgo2/blob/main/train_cnn.py#L206-L217 + + Args: + df (pd.DataFrame): The input dataframe containing the data to be processed. + column (str, optional): The column containing sequences to truncate. + Defaults to "sequences". + + Returns: + pd.DataFrame: The dataframe with sequences truncated to `self._max_len`. + """ + df[column] = df[column].apply(lambda x: x[: self._max_len]) + return df + def _record_splits(self) -> pd.DataFrame: """ Creates a DataFrame that stores the IDs and their corresponding data splits. @@ -217,7 +249,7 @@ def save_migrated_data( print("Saving transformed data......") deepgo_migr_inst: DeepGO2MigratedData = DeepGO2MigratedData( go_branch=DeepGO2MigratedData.GO_BRANCH_MAPPING[self._go_branch], - max_sequence_length=self._MAXLEN, + max_sequence_length=self._max_len, ) # Save data file @@ -257,7 +289,9 @@ class Main: """ @staticmethod - def migrate(data_dir: str, go_branch: Literal["cc", "mf", "bp"]) -> None: + def migrate( + data_dir: str, go_branch: Literal["cc", "mf", "bp"], max_len: int = 1000 + ) -> None: """ Initiates the migration process by creating a DeepGoDataMigration instance and invoking its migrate method. @@ -268,8 +302,10 @@ def migrate(data_dir: str, go_branch: Literal["cc", "mf", "bp"]) -> None: ("cc" for cellular_component, "mf" for molecular_function, or "bp" for biological_process). + max_len (int): Used to truncate the sequence to this length. Default is 1000. + # https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/aminoacids.py#L11 """ - DeepGo2DataMigration(data_dir, go_branch).migrate() + DeepGo2DataMigration(data_dir, go_branch, max_len).migrate() if __name__ == "__main__": From bcda11ca7517c4e60456303bc98418f345ca6f08 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 4 Dec 2024 12:37:56 +0100 Subject: [PATCH 23/71] Delete protein_protein_interactions.py --- .../preprocessing/datasets/deepGO/protein_protein_interactions.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 chebai/preprocessing/datasets/deepGO/protein_protein_interactions.py diff --git a/chebai/preprocessing/datasets/deepGO/protein_protein_interactions.py b/chebai/preprocessing/datasets/deepGO/protein_protein_interactions.py deleted file mode 100644 index e69de29b..00000000 From 85c47a05aa36a2bde9f07fca71f8838fe8fd5e96 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 4 Dec 2024 15:55:57 +0100 Subject: [PATCH 24/71] migration: replace invalid amino acid with "X" notation - https://github.com/ChEB-AI/python-chebai/pull/64#issuecomment-2517067073 --- .../deep_go/migrate_deep_go_2_data.py | 55 ++++++++++++++++++- 1 file changed, 52 insertions(+), 3 deletions(-) diff --git a/chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py b/chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py index 1edec52b..0bb07914 100644 --- a/chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py +++ b/chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py @@ -1,4 +1,5 @@ import os +import re from collections import OrderedDict from typing import List, Literal, Optional @@ -6,6 +7,7 @@ from jsonargparse import CLI from chebai.preprocessing.datasets.deepGO.go_uniprot import DeepGO2MigratedData +from chebai.preprocessing.reader import ProteinDataReader class DeepGo2DataMigration: @@ -88,17 +90,25 @@ def _load_data(self) -> None: try: print(f"Loading data from directory: {self._data_dir}......") - self._test_df = self._truncate_sequences( + + print( + "Pre-processing the data before loading them into instance variables\n" + f"2-Steps preprocessing: \n" + f"\t 1: Truncating every sequence to {self._max_len}\n" + f"\t 2: Replacing every amino acid which is not in {ProteinDataReader.AA_LETTER}" + ) + + self._test_df = self._pre_process_data( pd.DataFrame( pd.read_pickle(os.path.join(self._data_dir, "test_data.pkl")) ) ) - self._train_df = self._truncate_sequences( + self._train_df = self._pre_process_data( pd.DataFrame( pd.read_pickle(os.path.join(self._data_dir, "train_data.pkl")) ) ) - self._validation_df = self._truncate_sequences( + self._validation_df = self._pre_process_data( pd.DataFrame( pd.read_pickle(os.path.join(self._data_dir, "valid_data.pkl")) ) @@ -114,6 +124,21 @@ def _load_data(self) -> None: "Please ensure all required files are available in the specified directory." ) + def _pre_process_data(self, df: pd.DataFrame) -> pd.DataFrame: + """ + Pre-processes the input dataframe by truncating sequences to the maximum + length and replacing invalid amino acids with 'X'. + + Args: + df (pd.DataFrame): The dataframe to preprocess. + + Returns: + pd.DataFrame: The processed dataframe. + """ + df = self._truncate_sequences(df) + df = self._replace_invalid_amino_acids(df) + return df + def _truncate_sequences( self, df: pd.DataFrame, column: str = "sequences" ) -> pd.DataFrame: @@ -133,6 +158,30 @@ def _truncate_sequences( df[column] = df[column].apply(lambda x: x[: self._max_len]) return df + @staticmethod + def _replace_invalid_amino_acids( + df: pd.DataFrame, column: str = "sequences" + ) -> pd.DataFrame: + """ + Replaces invalid amino acids in a sequence with 'X' using regex. + + https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/aminoacids.py#L26-L33 + https://github.com/ChEB-AI/python-chebai/pull/64#issuecomment-2517067073 + + Args: + df (pd.DataFrame): The dataframe containing the sequences to be processed. + column (str, optional): The column containing the sequences. Defaults to "sequences". + + Returns: + pd.DataFrame: The dataframe with invalid amino acids replaced by 'X'. + """ + valid_amino_acids = "".join(ProteinDataReader.AA_LETTER) + # Replace any character not in the valid set with 'X' + df[column] = df[column].apply( + lambda x: re.sub(f"[^{valid_amino_acids}]", "X", x) + ) + return df + def _record_splits(self) -> pd.DataFrame: """ Creates a DataFrame that stores the IDs and their corresponding data splits. From fbb5c58064e2171964290d0a7d6d7f1d3da35173 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 4 Dec 2024 16:41:22 +0100 Subject: [PATCH 25/71] update deepgo configs --- configs/data/deepGO/deepgo_1_migrated_data.yml | 4 ++++ configs/data/deepGO/deepgo_2_migrated_data.yml | 4 ++++ configs/data/deepGO/go250.yml | 3 +++ configs/data/deepGO/go50.yml | 1 + configs/data/go250.yml | 3 --- configs/data/go50.yml | 1 - 6 files changed, 12 insertions(+), 4 deletions(-) create mode 100644 configs/data/deepGO/deepgo_1_migrated_data.yml create mode 100644 configs/data/deepGO/deepgo_2_migrated_data.yml create mode 100644 configs/data/deepGO/go250.yml create mode 100644 configs/data/deepGO/go50.yml delete mode 100644 configs/data/go250.yml delete mode 100644 configs/data/go50.yml diff --git a/configs/data/deepGO/deepgo_1_migrated_data.yml b/configs/data/deepGO/deepgo_1_migrated_data.yml new file mode 100644 index 00000000..0924e023 --- /dev/null +++ b/configs/data/deepGO/deepgo_1_migrated_data.yml @@ -0,0 +1,4 @@ +class_path: chebai.preprocessing.datasets.deepGO.go_uniprot.DeepGO1MigratedData +init_args: + go_branch: "MF" + max_sequence_length: 1002 diff --git a/configs/data/deepGO/deepgo_2_migrated_data.yml b/configs/data/deepGO/deepgo_2_migrated_data.yml new file mode 100644 index 00000000..1ed2ad09 --- /dev/null +++ b/configs/data/deepGO/deepgo_2_migrated_data.yml @@ -0,0 +1,4 @@ +class_path: chebai.preprocessing.datasets.deepGO.go_uniprot.DeepGO2MigratedData +init_args: + go_branch: "MF" + max_sequence_length: 1000 diff --git a/configs/data/deepGO/go250.yml b/configs/data/deepGO/go250.yml new file mode 100644 index 00000000..01e34aa4 --- /dev/null +++ b/configs/data/deepGO/go250.yml @@ -0,0 +1,3 @@ +class_path: chebai.preprocessing.datasets.go_uniprot.deepGO.GOUniProtOver250 +init_args: + go_branch: "BP" diff --git a/configs/data/deepGO/go50.yml b/configs/data/deepGO/go50.yml new file mode 100644 index 00000000..bee43773 --- /dev/null +++ b/configs/data/deepGO/go50.yml @@ -0,0 +1 @@ +class_path: chebai.preprocessing.datasets.deepGO.go_uniprot.GOUniProtOver50 diff --git a/configs/data/go250.yml b/configs/data/go250.yml deleted file mode 100644 index 5598495c..00000000 --- a/configs/data/go250.yml +++ /dev/null @@ -1,3 +0,0 @@ -class_path: chebai.preprocessing.datasets.go_uniprot.GOUniProtOver250 -init_args: - go_branch: "BP" diff --git a/configs/data/go50.yml b/configs/data/go50.yml deleted file mode 100644 index 2ed4d14c..00000000 --- a/configs/data/go50.yml +++ /dev/null @@ -1 +0,0 @@ -class_path: chebai.preprocessing.datasets.go_uniprot.GOUniProtOver50 From 272446db7a5dd0f2aa08de6e96fd9a6d11a0e3d2 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 9 Dec 2024 13:04:30 +0100 Subject: [PATCH 26/71] add esm2 reader for deepGO --- chebai/preprocessing/reader.py | 257 ++++++++++++++++++++++++++++++++- setup.py | 1 + 2 files changed, 257 insertions(+), 1 deletion(-) diff --git a/chebai/preprocessing/reader.py b/chebai/preprocessing/reader.py index a08a3f91..dff2ff51 100644 --- a/chebai/preprocessing/reader.py +++ b/chebai/preprocessing/reader.py @@ -1,8 +1,18 @@ import os -from typing import Any, Dict, List, Optional +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple +from urllib.error import HTTPError import deepsmiles import selfies as sf +import torch +from esm import Alphabet +from esm.model.esm2 import ESM2 +from esm.pretrained import ( + _has_regression_weights, + load_model_and_alphabet_core, + load_model_and_alphabet_local, +) from pysmiles.read_smiles import _tokenize from transformers import RobertaTokenizerFast @@ -471,3 +481,248 @@ def on_finish(self) -> None: print(f"Saving {len(self.cache)} tokens to {self.token_path}...") print(f"First 10 tokens: {self.cache[:10]}") pk.writelines([f"{c}\n" for c in self.cache]) + + +class ESM2EmbeddingReader(DataReader): + """ + A data reader to process protein sequences using the ESM2 model for embeddings. + + References: + https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/extract_esm.py + + Note: + For layer availability by model, Please check below link: + https://github.com/facebookresearch/esm?tab=readme-ov-file#pre-trained-models- + + To test this reader, try lighter models: + esm2_t6_8M_UR50D: 6 layers (valid layers: 1–6), (~28 Mb) - A tiny 8M parameter model. + esm2_t12_35M_UR50D: 12 layers (valid layers: 1–12), (~128 Mb) - A slightly larger, 35M parameter model. + These smaller models are good for testing and debugging purposes. + + """ + + # https://github.com/facebookresearch/esm/blob/main/esm/pretrained.py#L53 + _MODELS_URL = "https://dl.fbaipublicfiles.com/fair-esm/models/{}.pt" + _REGRESSION_URL = ( + "https://dl.fbaipublicfiles.com/fair-esm/regression/{}-contact-regression.pt" + ) + + def __init__( + self, + save_model_dir: str, + model_name: str = "esm2_t36_3B_UR50D", + device: Optional[torch.device] = None, + truncation_length: int = 1022, + toks_per_batch: int = 4096, + return_contacts: bool = False, + repr_layer: int = 36, + *args, + **kwargs, + ): + """ + Initialize the ESM2EmbeddingReader class. + + Args: + save_model_dir (str): Directory to save/load the pretrained ESM model. + model_name (str): Name of the pretrained model. Defaults to "esm2_t36_3B_UR50D". + device (torch.device or str, optional): Device for computation (e.g., 'cpu', 'cuda'). + truncation_length (int): Maximum sequence length for truncation. Defaults to 1022. + toks_per_batch (int): Tokens per batch for data processing. Defaults to 4096. + return_contacts (bool): Whether to return contact maps. Defaults to False. + repr_layers (int): Layer number to extract representations from. Defaults to 36. + """ + self.save_model_dir = save_model_dir + if not os.path.exists(self.save_model_dir): + os.makedirs((os.path.dirname(self.save_model_dir)), exist_ok=True) + self.model_name = model_name + self.device = device + self.truncation_length = truncation_length + self.toks_per_batch = toks_per_batch + self.return_contacts = return_contacts + self.repr_layer = repr_layer + + self._model: Optional[ESM2] = None + self._alphabet: Optional[Alphabet] = None + + self._model, self._alphabet = self.load_model_and_alphabet() + self._model.eval() + + if self.device: + self._model = self._model.to(device) + + super().__init__(*args, **kwargs) + + def load_model_and_alphabet(self) -> Tuple[ESM2, Alphabet]: + """ + Load the ESM2 model and its alphabet. + + References: + https://github.com/facebookresearch/esm/blob/main/esm/pretrained.py#L24-L28 + + Returns: + Tuple[ESM2, Alphabet]: Loaded model and alphabet. + """ + model_location = os.path.join(self.save_model_dir, f"{self.model_name}.pt") + if os.path.exists(model_location): + return load_model_and_alphabet_local(model_location) + else: + return self.load_model_and_alphabet_hub() + + def load_model_and_alphabet_hub(self) -> Tuple[ESM2, Alphabet]: + """ + Load the model and alphabet from the hub URL. + + References: + https://github.com/facebookresearch/esm/blob/main/esm/pretrained.py#L62-L64 + + Returns: + Tuple[ESM2, Alphabet]: Loaded model and alphabet. + """ + model_url = self._MODELS_URL.format(self.model_name) + model_data = self.load_hub_workaround(model_url) + regression_data = None + if _has_regression_weights(self.model_name): + regression_url = self._REGRESSION_URL.format(self.model_name) + regression_data = self.load_hub_workaround(regression_url) + return load_model_and_alphabet_core( + self.model_name, model_data, regression_data + ) + + def load_hub_workaround(self, url) -> torch.Tensor: + """ + Workaround to load models from the PyTorch Hub. + + References: + https://github.com/facebookresearch/esm/blob/main/esm/pretrained.py#L31-L43 + + Returns: + torch.Tensor: Loaded model state dictionary. + """ + try: + data = torch.hub.load_state_dict_from_url( + url, self.save_model_dir, progress=True, map_location=self.device + ) + + except RuntimeError: + # Handle PyTorch version issues + fn = Path(url).name + data = torch.load( + f"{torch.hub.get_dir()}/checkpoints/{fn}", + map_location="cpu", + ) + except HTTPError as e: + raise Exception( + f"Could not load {url}. Did you specify the correct model name?" + ) + return data + + def name(self) -> None: + """ + Returns the name of the data reader. This method identifies the specific type of data reader. + + Returns: + str: The name of the data reader, which is "protein_token". + """ + return "esm2_embedding" + + @property + def token_path(self) -> None: + """ + Not used as no token file is not created for this reader. + + Returns: + str: Empty string since this method is not implemented. + """ + return + + def _read_data(self, raw_data: str) -> List[int]: + """ + Reads protein sequence data and generates embeddings. + + Args: + raw_data (str): The protein sequence. + + Returns: + List[int]: Embeddings generated for the sequence. + """ + alp_tokens_idx = self._sequence_to_alphabet_tokens_idx(raw_data) + return self._alphabet_tokens_to_esm_embedding(alp_tokens_idx).tolist() + + def _sequence_to_alphabet_tokens_idx(self, sequence: str) -> torch.Tensor: + """ + Converts a protein sequence into ESM alphabet token indices. + + Args: + sequence (str): Protein sequence. + + References: + https://github.com/facebookresearch/esm/blob/2b369911bb5b4b0dda914521b9475cad1656b2ac/esm/data.py#L249-L250 + https://github.com/facebookresearch/esm/blob/2b369911bb5b4b0dda914521b9475cad1656b2ac/esm/data.py#L262-L297 + + Returns: + torch.Tensor: Tokenized sequence with special tokens (BOS/EOS) included. + """ + seq_encoded = self._alphabet.encode(sequence) + tokens = [] + + # Add BOS token if configured + if self._alphabet.prepend_bos: + tokens.append(self._alphabet.cls_idx) + + # Add the main sequence + tokens.extend(seq_encoded) + + # Add EOS token if configured + if self._alphabet.append_eos: + tokens.append(self._alphabet.eos_idx) + + # Convert to PyTorch tensor and return + return torch.tensor([tokens], dtype=torch.int64) + + def _alphabet_tokens_to_esm_embedding(self, tokens: torch.Tensor) -> torch.Tensor: + """ + Converts alphabet tokens into ESM embeddings. + + Args: + tokens (torch.Tensor): Tokenized protein sequences. + + References: + https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/extract_esm.py#L82-L107 + + Returns: + torch.Tensor: Protein embedding from the specified representation layer. + """ + if self.device: + tokens = tokens.to(self.device, non_blocking=True) + + with torch.no_grad(): + out = self._model( + tokens, + repr_layers=[ + self.repr_layer, + ], + return_contacts=self.return_contacts, + ) + + # Extract representations and compute the mean embedding for each layer + representations = { + layer: t.to(self.device) for layer, t in out["representations"].items() + } + truncate_len = min(self.truncation_length, tokens.size(1)) + + result = { + "mean_representations": { + layer: t[0, 1 : truncate_len + 1].mean(0).clone() + for layer, t in representations.items() + } + } + return result["mean_representations"][self.repr_layer] + + def on_finish(self) -> None: + """ + Not used here as no token file exists for this reader. + + Returns: + None + """ + pass diff --git a/setup.py b/setup.py index 58bfc75b..ba134e41 100644 --- a/setup.py +++ b/setup.py @@ -51,6 +51,7 @@ "pyyaml", "torchmetrics", "biopython", + "fair-esm", ], extras_require={"dev": ["black", "isort", "pre-commit"]}, ) From a12354b527f670da28ac6b8f200b659d4d67ab43 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Mon, 9 Dec 2024 15:03:03 +0100 Subject: [PATCH 27/71] increase electra vocab size --- chebai/models/electra.py | 2 +- configs/model/electra.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/chebai/models/electra.py b/chebai/models/electra.py index 7009406d..dc6c719b 100644 --- a/chebai/models/electra.py +++ b/chebai/models/electra.py @@ -329,7 +329,7 @@ def forward(self, data: Dict[str, Tensor], **kwargs: Any) -> Dict[str, Any]: except RuntimeError as e: print(f"RuntimeError at forward: {e}") print(f'data[features]: {data["features"]}') - raise Exception + raise e inp = self.word_dropout(inp) electra = self.electra(inputs_embeds=inp, **kwargs) d = electra.last_hidden_state[:, 0, :] diff --git a/configs/model/electra.yml b/configs/model/electra.yml index c3cf2fdf..ade89acd 100644 --- a/configs/model/electra.yml +++ b/configs/model/electra.yml @@ -3,7 +3,7 @@ init_args: optimizer_kwargs: lr: 1e-3 config: - vocab_size: 1400 + vocab_size: 8500 max_position_embeddings: 1800 num_attention_heads: 8 num_hidden_layers: 6 From 66732a7cf5e9e8f0f2f338848a00333bb0375ec4 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 9 Dec 2024 21:53:06 +0100 Subject: [PATCH 28/71] fix: print right name of missing file --- chebai/preprocessing/datasets/base.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py index f382f050..fc64c808 100644 --- a/chebai/preprocessing/datasets/base.py +++ b/chebai/preprocessing/datasets/base.py @@ -812,18 +812,18 @@ def setup_processed(self) -> None: None """ os.makedirs(self.processed_dir, exist_ok=True) - processed_main_file_name = self.processed_main_file_names_dict["data"] + transformed_file_name = self.processed_file_names_dict["data"] print( - f"Missing transformed data (`{processed_main_file_name}` file). Transforming data.... " + f"Missing transformed data (`{transformed_file_name}` file). Transforming data.... " ) torch.save( self._load_data_from_file( os.path.join( self.processed_dir_main, - processed_main_file_name, + self.processed_main_file_names_dict["data"], ) ), - os.path.join(self.processed_dir, self.processed_file_names_dict["data"]), + os.path.join(self.processed_dir, transformed_file_name), ) @staticmethod From e7b3d800da1f3ae2aeb17e9202f1a2d45e1a5083 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 9 Dec 2024 21:56:35 +0100 Subject: [PATCH 29/71] migration : add esm2 embeddings - modify deepgo2 migration script to migrate the esm2 embeddings too - modify migration class to use esm2 embeddings or reader features, based on input --- .../datasets/deepGO/go_uniprot.py | 95 ++++++++++++++++++- .../deep_go/migrate_deep_go_2_data.py | 2 + chebai/preprocessing/reader.py | 3 +- 3 files changed, 97 insertions(+), 3 deletions(-) diff --git a/chebai/preprocessing/datasets/deepGO/go_uniprot.py b/chebai/preprocessing/datasets/deepGO/go_uniprot.py index 22d13e3f..3c957e6c 100644 --- a/chebai/preprocessing/datasets/deepGO/go_uniprot.py +++ b/chebai/preprocessing/datasets/deepGO/go_uniprot.py @@ -40,6 +40,7 @@ import pandas as pd import requests import torch +import tqdm from Bio import SwissProt from chebai.preprocessing import reader as dr @@ -892,12 +893,95 @@ class DeepGO2MigratedData(_DeepGOMigratedData): dict: Dictionary with file names specific to DeepGO2. """ - def __init__(self, **kwargs): + _LABELS_START_IDX: int = 5 # additional esm2_embeddings column in the dataframe + _ESM_EMBEDDINGS_COL_IDX: int = 4 + + def __init__(self, use_esm2_embeddings=False, **kwargs): # https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/aminoacids.py#L11 assert int(kwargs.get("max_sequence_length")) == 1000 - + self.use_esm2_embeddings: bool = use_esm2_embeddings super(_DeepGOMigratedData, self).__init__(**kwargs) + # ------------------------------ Phase: Setup data ----------------------------------- + def _load_data_from_file(self, path: str) -> List[Dict[str, Any]]: + """ + Load and process data from a file into a list of dictionaries containing features and labels. + + This method processes data differently based on the `use_esm2_embeddings` flag: + - If `use_esm2_embeddings` is True, raw dictionaries from `_load_dict` are returned, _load_dict already returns + the numerical features (esm2 embeddings) from the data file, hence no reader is required. + - Otherwise, a reader is used to process the data (generate numerical features). + + Args: + path (str): The path to the input file. + + Returns: + List[Dict[str, Any]]: A list of dictionaries with the following keys: + - `features`: Sequence or embedding data, depending on the context. + - `labels`: A boolean array of labels. + - `ident`: The identifier for the sequence. + """ + lines = self._get_data_size(path) + print(f"Processing {lines} lines...") + + if self.use_esm2_embeddings: + data = [ + d + for d in tqdm.tqdm(self._load_dict(path), total=lines) + if d["features"] is not None + ] + else: + data = [ + self.reader.to_data(d) + for d in tqdm.tqdm(self._load_dict(path), total=lines) + if d["features"] is not None + ] + + # filter for missing features in resulting data + data = [val for val in data if val["features"] is not None] + + return data + + def _load_dict(self, input_file_path: str) -> Generator[Dict[str, Any], None, None]: + """ + Loads data from a pickled file and yields individual dictionaries for each row. + + The pickled file is expected to contain rows with the following structure: + - Data at row index `self._ID_IDX`: ID of go data instance + - Data at row index `self._DATA_REPRESENTATION_IDX`: Sequence representation of protein + - Data at row index `self._ESM2_EMBEDDINGS_COL_IDX`: ESM2 embeddings of the protein + - Data from row index `self._LABELS_START_IDX` onwards: Labels + + The method adapts based on the `use_esm2_embeddings` flag: + - If `use_esm2_embeddings` is True, features are loaded from the column specified by `self._ESM_EMBEDDINGS_COL_IDX`. + - Otherwise, features are loaded from the column specified by `self._DATA_REPRESENTATION_IDX`. + + Args: + input_file_path (str): The path to the pickled input file. + + Yields: + Dict[str, Any]: A dictionary containing: + - `features` (Any): Sequence or embedding data for the instance. + - `labels` (np.ndarray): A boolean array of labels starting from row index 4. + - `ident` (Any): The identifier from row index 0. + """ + with open(input_file_path, "rb") as input_file: + df = pd.read_pickle(input_file) + + if self.use_esm2_embeddings: + features_idx = self._ESM_EMBEDDINGS_COL_IDX + else: + features_idx = self._DATA_REPRESENTATION_IDX + + for row in df.values: + labels = row[self._LABELS_START_IDX :].astype(bool) + yield dict( + features=row[features_idx], + labels=labels, + ident=row[self._ID_IDX], + ) + + # ------------------------------ Phase: Raw Properties ----------------------------------- @property def processed_main_file_names_dict(self) -> Dict[str, str]: """ @@ -917,3 +1001,10 @@ def processed_file_names_dict(self) -> Dict[str, str]: dict: Dictionary with data file name for DeepGO2. """ return {"data": "data_deep_go2.pt"} + + @property + def identifier(self) -> tuple: + """Identifier for the dataset.""" + if self.use_esm2_embeddings: + return (dr.ESM2EmbeddingReader.name(),) + return (self.reader.name(),) diff --git a/chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py b/chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py index 0bb07914..68d7dc78 100644 --- a/chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py +++ b/chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py @@ -217,6 +217,7 @@ def _extract_required_data_from_splits(self) -> pd.DataFrame: "exp_annotations", # Directly associated GO ids # https://github.com/bio-ontology-research-group/deepgo2/blob/main/gendata/uni2pandas.py#L60-L69 "prop_annotations", # Transitively associated GO ids + "esm2", ] new_df = pd.concat( @@ -239,6 +240,7 @@ def _extract_required_data_from_splits(self) -> pd.DataFrame: accession=new_df["accessions"], go_ids=new_df["go_ids"], sequence=new_df["sequences"], + esm2_embeddings=new_df["esm2"], ) ) return data_df diff --git a/chebai/preprocessing/reader.py b/chebai/preprocessing/reader.py index dff2ff51..88e4fedd 100644 --- a/chebai/preprocessing/reader.py +++ b/chebai/preprocessing/reader.py @@ -616,7 +616,8 @@ def load_hub_workaround(self, url) -> torch.Tensor: ) return data - def name(self) -> None: + @staticmethod + def name() -> None: """ Returns the name of the data reader. This method identifies the specific type of data reader. From 862c8ef5743f3711c92c9b922cd269203940936e Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 5 Jan 2025 17:07:27 +0100 Subject: [PATCH 30/71] scope dataset: add scope abstract code --- .../preprocessing/datasets/scope/__init__.py | 0 chebai/preprocessing/datasets/scope/scope.py | 381 ++++++++++++++++++ 2 files changed, 381 insertions(+) create mode 100644 chebai/preprocessing/datasets/scope/__init__.py create mode 100644 chebai/preprocessing/datasets/scope/scope.py diff --git a/chebai/preprocessing/datasets/scope/__init__.py b/chebai/preprocessing/datasets/scope/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/chebai/preprocessing/datasets/scope/scope.py b/chebai/preprocessing/datasets/scope/scope.py new file mode 100644 index 00000000..a987f53d --- /dev/null +++ b/chebai/preprocessing/datasets/scope/scope.py @@ -0,0 +1,381 @@ +import gzip +import itertools +import os +import pickle +import shutil +from abc import ABC +from collections import OrderedDict +from tempfile import NamedTemporaryFile +from typing import Any, Dict, Generator, List, Optional, Tuple, Union + +import fastobo +import networkx as nx +import pandas as pd +import requests +import torch +from Bio import SeqIO +from Bio.Seq import Seq + +from chebai.preprocessing.datasets.base import _DynamicDataset +from chebai.preprocessing.reader import ProteinDataReader + + +class _SCOPeDataExtractor(_DynamicDataset, ABC): + """ + A class for extracting and processing data from the Gene Ontology (GO) dataset and the Swiss UniProt dataset. + + Args: + dynamic_data_split_seed (int, optional): The seed for random data splitting. Defaults to 42. + splits_file_path (str, optional): Path to the splits CSV file. Defaults to None. + max_sequence_length (int, optional): Specifies the maximum allowed sequence length for a protein, with a + default of 1002. During data preprocessing, any proteins exceeding this length will be excluded from further + processing. + **kwargs: Additional keyword arguments passed to DynamicDataset and XYBaseDataModule. + """ + + _GO_DATA_INIT = "GO" + _SWISS_DATA_INIT = "SWISS" + + # -- Index for columns of processed `data.pkl` (derived from `_get_swiss_to_go_mapping` & `_graph_to_raw_dataset` + # "swiss_id" at row index 0 + # "accession" at row index 1 + # "go_ids" at row index 2 + # "sequence" at row index 3 + # labels starting from row index 4 + _ID_IDX: int = 0 + _DATA_REPRESENTATION_IDX: int = 3 # here `sequence` column + _LABELS_START_IDX: int = 4 + + _SCOPE_GENERAL_URL = "https://scop.berkeley.edu/downloads/parse/dir.{data_type}.scope.{version_number}-stable.txt" + _PDB_SEQUENCE_DATA_URL = ( + "https://files.rcsb.org/pub/pdb/derived_data/pdb_seqres.txt.gz" + ) + + def __init__( + self, + scope_version: float, + scope_version_train: Optional[float] = None, + **kwargs, + ): + + self.scope_version: float = scope_version + self.scope_version_train: float = scope_version_train + + super(_SCOPeDataExtractor, self).__init__(**kwargs) + + if self.scope_version_train is not None: + # Instantiate another same class with "scope_version" as "scope_version_train", if train_version is given + # This is to get the data from respective directory related to "scope_version_train" + _init_kwargs = kwargs + _init_kwargs["chebi_version"] = self.scope_version_train + self._scope_version_train_obj = self.__class__( + **_init_kwargs, + ) + + @staticmethod + def _get_scope_url(data_type: str, version_number: float) -> str: + """ + Generates the URL for downloading SCOPe files. + + Args: + data_type (str): The type of data (e.g., 'cla', 'hie', 'des'). + version_number (str): The version of the SCOPe file. + + Returns: + str: The formatted SCOPe file URL. + """ + return _SCOPeDataExtractor._SCOPE_GENERAL_URL.format( + data_type=data_type, version_number=version_number + ) + + # ------------------------------ Phase: Prepare data ----------------------------------- + def _download_required_data(self) -> str: + """ + Downloads the required raw data related to Gene Ontology (GO) and Swiss-UniProt dataset. + + Returns: + str: Path to the downloaded data. + """ + self._download_pdb_sequence_data() + return self._download_scope_raw_data() + + def _download_pdb_sequence_data(self) -> None: + pdb_seq_file_path = os.path.join(self.raw_dir, self.raw_file_names_dict["PDB"]) + os.makedirs(os.path.dirname(pdb_seq_file_path), exist_ok=True) + + if not os.path.isfile(pdb_seq_file_path): + print(f"Downloading PDB sequence data....") + + # Create a temporary file + with NamedTemporaryFile(delete=False) as tf: + temp_filename = tf.name + print(f"Downloading to temporary file {temp_filename}") + + # Download the file + response = requests.get(self._PDB_SEQUENCE_DATA_URL, stream=True) + with open(temp_filename, "wb") as temp_file: + shutil.copyfileobj(response.raw, temp_file) + + print(f"Downloaded to {temp_filename}") + + # Unpack the gzipped file + try: + print(f"Unzipping the file....") + with gzip.open(temp_filename, "rb") as f_in: + output_file_path = pdb_seq_file_path + with open(output_file_path, "wb") as f_out: + shutil.copyfileobj(f_in, f_out) + print(f"Unpacked and saved to {output_file_path}") + + except Exception as e: + print(f"Failed to unpack the file: {e}") + finally: + # Clean up the temporary file + os.remove(temp_filename) + print(f"Removed temporary file {temp_filename}") + + def _download_scope_raw_data(self) -> str: + os.makedirs(self.raw_dir, exist_ok=True) + for data_type in ["CLA", "COM", "HIE", "DES"]: + data_file_name = self.raw_file_names_dict[data_type] + scope_path = os.path.join(self.raw_dir, data_file_name) + if not os.path.isfile(scope_path): + print(f"Missing Scope: {data_file_name} raw data, Downloading...") + r = requests.get( + self._get_scope_url(data_type.lower(), self.scope_version), + allow_redirects=False, + verify=False, # Disable SSL verification + ) + r.raise_for_status() # Check if the request was successful + open(scope_path, "wb").write(r.content) + return "dummy/path" + + def _parse_pdb_sequence_file(self) -> Dict[str, Dict[str, str]]: + pdb_chain_seq_mapping: Dict[str, Dict[str, str]] = {} + for record in SeqIO.parse( + os.path.join(self.raw_dir, self.raw_file_names_dict["PDB"]), "fasta" + ): + pdb_id, chain = record.id.split("_") + pdb_chain_seq_mapping.setdefault(pdb_id, {})[chain] = str(record.seq) + return pdb_chain_seq_mapping + + def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph: + print("Extracting class hierarchy...") + + # Load and preprocess CLA file + df_cla = pd.read_csv( + os.path.join(self.raw_dir, self.raw_file_names_dict["CLA"]), + sep="\t", + header=None, + comment="#", + ) + df_cla.columns = [ + "sid", + "PDB_ID", + "description", + "sccs", + "sunid", + "ancestor_nodes", + ] + df_cla["sunid"] = pd.to_numeric( + df_cla["sunid"], errors="coerce", downcast="integer" + ) + df_cla["ancestor_nodes"] = df_cla["ancestor_nodes"].apply( + lambda x: {k: int(v) for k, v in (item.split("=") for item in x.split(","))} + ) + df_cla.set_index("sunid", inplace=True) + + # Load and preprocess HIE file + df_hie = pd.read_csv( + os.path.join(self.raw_dir, self.raw_file_names_dict["HIE"]), + sep="\t", + header=None, + comment="#", + ) + df_hie.columns = ["sunid", "parent_sunid", "children_sunids"] + df_hie["sunid"] = pd.to_numeric( + df_hie["sunid"], errors="coerce", downcast="integer" + ) + df_hie["parent_sunid"] = df_hie["parent_sunid"].replace("-", -1).astype(int) + df_hie["children_sunids"] = df_hie["children_sunids"].apply( + lambda x: list(map(int, x.split(","))) if x != "-" else [] + ) + + # Initialize directed graph + g = nx.DiGraph() + + # Add nodes and edges efficiently + g.add_edges_from( + df_hie[df_hie["parent_sunid"] != -1].apply( + lambda row: (row["parent_sunid"], row["sunid"]), axis=1 + ) + ) + g.add_edges_from( + df_hie.explode("children_sunids") + .dropna() + .apply(lambda row: (row["sunid"], row["children_sunids"]), axis=1) + ) + + pdb_chain_seq_mapping = self._parse_pdb_sequence_file() + + node_to_pdb_id = df_cla["PDB_ID"].to_dict() + + for node in g.nodes(): + pdb_id = node_to_pdb_id[node] + chain_mapping = pdb_chain_seq_mapping.get(pdb_id, {}) + + # Add nodes and edges for chains in the mapping + for chain, sequence in chain_mapping.items(): + chain_node = f"{pdb_id}_{chain}" + g.add_node(chain_node, sequence=sequence) + g.add_edge(node, chain_node) + + print("Compute transitive closure...") + return nx.transitive_closure_dag(g) + + def _graph_to_raw_dataset(self, g: nx.DiGraph) -> pd.DataFrame: + """ + Processes a directed acyclic graph (DAG) to create a raw dataset in DataFrame format. The dataset includes + Swiss-Prot protein data and their associations with Gene Ontology (GO) terms. + + Note: + - GO classes are used as labels in the dataset. Each GO term is represented as a column, and its value + indicates whether a Swiss-Prot protein is associated with that GO term. + - Swiss-Prot proteins serve as samples. There is no 1-to-1 correspondence between Swiss-Prot proteins + and GO terms. + + Data Format: pd.DataFrame + - Column 0 : swiss_id (Identifier for SwissProt protein) + - Column 1 : Accession of the protein + - Column 2 : GO IDs (associated GO terms) + - Column 3 : Sequence of the protein + - Column 4 to Column "n": Each column corresponding to a class with value True/False indicating whether the + protein is associated with this GO term. + + Args: + g (nx.DiGraph): The class hierarchy graph. + + Returns: + pd.DataFrame: The raw dataset created from the graph. + """ + print(f"Processing graph") + + data_df = self._get_swiss_to_go_mapping() + # add ancestors to go ids + data_df["go_ids"] = data_df["go_ids"].apply( + lambda go_ids: sorted( + set( + itertools.chain.from_iterable( + [ + [go_id] + list(g.predecessors(go_id)) + for go_id in go_ids + if go_id in g.nodes + ] + ) + ) + ) + ) + # Initialize the GO term labels/columns to False + selected_classes = self.select_classes(g, data_df=data_df) + new_label_columns = pd.DataFrame( + False, index=data_df.index, columns=selected_classes + ) + data_df = pd.concat([data_df, new_label_columns], axis=1) + + # Set True for the corresponding GO IDs in the DataFrame go labels/columns + for index, row in data_df.iterrows(): + for go_id in row["go_ids"]: + if go_id in data_df.columns: + data_df.at[index, go_id] = True + + # This filters the DataFrame to include only the rows where at least one value in the row from 5th column + # onwards is True/non-zero. + # Quote from DeepGo Paper: `For training and testing, we use proteins which have been annotated with at least + # one GO term from the set of the GO terms for the model` + data_df = data_df[data_df.iloc[:, self._LABELS_START_IDX :].any(axis=1)] + return data_df + + # ------------------------------ Phase: Setup data ----------------------------------- + def _load_dict(self, input_file_path: str) -> Generator[Dict[str, Any], None, None]: + with open(input_file_path, "rb") as input_file: + df = pd.read_pickle(input_file) + for row in df.values: + labels = row[self._LABELS_START_IDX :].astype(bool) + # chebai.preprocessing.reader.DataReader only needs features, labels, ident, group + # "group" set to None, by default as no such entity for this data + yield dict( + features=row[self._DATA_REPRESENTATION_IDX], + labels=labels, + ident=row[self._ID_IDX], + ) + + # ------------------------------ Phase: Dynamic Splits ----------------------------------- + def _get_data_splits(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: + try: + filename = self.processed_file_names_dict["data"] + data_go = torch.load( + os.path.join(self.processed_dir, filename), weights_only=False + ) + except FileNotFoundError: + raise FileNotFoundError( + f"File data.pt doesn't exists. " + f"Please call 'prepare_data' and/or 'setup' methods to generate the dataset files" + ) + + df_go_data = pd.DataFrame(data_go) + train_df_go, df_test = self.get_test_split( + df_go_data, seed=self.dynamic_data_split_seed + ) + + # Get all splits + df_train, df_val = self.get_train_val_splits_given_test( + train_df_go, + df_test, + seed=self.dynamic_data_split_seed, + ) + + return df_train, df_val, df_test + + # ------------------------------ Phase: Raw Properties ----------------------------------- + @property + def base_dir(self) -> str: + """ + Returns the base directory path for storing GO-Uniprot data. + + Returns: + str: The path to the base directory, which is "data/GO_UniProt". + """ + return os.path.join("data", "SCOPe", f"version_{self.scope_version}") + + @property + def raw_file_names_dict(self) -> dict: + """ + Returns a dictionary of raw file names used in data processing. + + Returns: + dict: A dictionary mapping dataset names to their respective file names. + For example, {"GO": "go-basic.obo", "SwissUniProt": "uniprot_sprot.dat"}. + """ + return { + "CLA": "cla.txt", + "DES": "des.txt", + "HIE": "hie.txt", + "COM": "com.txt", + "PDB": "pdb_sequences.txt", + } + + +class SCOPE(_SCOPeDataExtractor): + READER = ProteinDataReader + + @property + def _name(self) -> str: + return "test" + + def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> List: + pass + + +if __name__ == "__main__": + scope = SCOPE(scope_version=2.08) + scope._parse_pdb_sequence_file() From 7da8963c169c3e59ac9eb512c65e73313f7370cd Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 5 Jan 2025 17:17:02 +0100 Subject: [PATCH 31/71] base: make _name property abstract method - this will help to identify methods that needs to be implemented during coding and not during runtime --- chebai/preprocessing/datasets/base.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py index fc64c808..6158b9dc 100644 --- a/chebai/preprocessing/datasets/base.py +++ b/chebai/preprocessing/datasets/base.py @@ -155,8 +155,19 @@ def fold_dir(self) -> str: return f"cv_{self.inner_k_folds}_fold" @property + @abstractmethod def _name(self) -> str: - raise NotImplementedError + """ + Abstract property representing the name of the data module. + + This property should be implemented in subclasses to provide a unique name for the data module. + The name is used to create subdirectories within the base directory or `processed_dir_main` + for storing relevant data associated with this module. + + Returns: + str: The name of the data module. + """ + pass def _filter_labels(self, row: dict) -> dict: """ From 976f2b895e3ee8fce4a9bcbde6ace30539e7845a Mon Sep 17 00:00:00 2001 From: sfluegel Date: Fri, 10 Jan 2025 13:55:50 +0100 Subject: [PATCH 32/71] add simple Feed-forward network (for ESM2->chebi task) --- chebai/models/ffn.py | 55 ++++++++++++++++++++++++++++ configs/data/deepGO/deepgo2_esm2.yml | 5 +++ configs/model/ffn.yml | 7 ++++ 3 files changed, 67 insertions(+) create mode 100644 chebai/models/ffn.py create mode 100644 configs/data/deepGO/deepgo2_esm2.yml create mode 100644 configs/model/ffn.yml diff --git a/chebai/models/ffn.py b/chebai/models/ffn.py new file mode 100644 index 00000000..77046ae6 --- /dev/null +++ b/chebai/models/ffn.py @@ -0,0 +1,55 @@ +from typing import Dict, Any, Tuple + +from chebai.models import ChebaiBaseNet +import torch +from torch import Tensor + +class FFN(ChebaiBaseNet): + + NAME = "FFN" + + def __init__(self, input_size: int = 1000, num_hidden_layers: int = 3, hidden_size: int = 128, **kwargs): + super().__init__(**kwargs) + + self.layers = torch.nn.ModuleList() + self.layers.append(torch.nn.Linear(input_size, hidden_size)) + for _ in range(num_hidden_layers): + self.layers.append(torch.nn.Linear(hidden_size, hidden_size)) + self.layers.append(torch.nn.Linear(hidden_size, self.out_dim)) + + def _get_prediction_and_labels(self, data, labels, model_output): + d = model_output["logits"] + loss_kwargs = data.get("loss_kwargs", dict()) + if "non_null_labels" in loss_kwargs: + n = loss_kwargs["non_null_labels"] + d = data[n] + return torch.sigmoid(d), labels.int() if labels is not None else None + + def _process_for_loss( + self, + model_output: Dict[str, Tensor], + labels: Tensor, + loss_kwargs: Dict[str, Any], + ) -> Tuple[Tensor, Tensor, Dict[str, Any]]: + """ + Process the model output for calculating the loss. + + Args: + model_output (Dict[str, Tensor]): The output of the model. + labels (Tensor): The target labels. + loss_kwargs (Dict[str, Any]): Additional loss arguments. + + Returns: + tuple: A tuple containing the processed model output, labels, and loss arguments. + """ + kwargs_copy = dict(loss_kwargs) + if labels is not None: + labels = labels.float() + return model_output["logits"], labels, kwargs_copy + + def forward(self, data, **kwargs): + x = data["features"] + for layer in self.layers: + x = torch.relu(layer(x)) + return {"logits": x} + diff --git a/configs/data/deepGO/deepgo2_esm2.yml b/configs/data/deepGO/deepgo2_esm2.yml new file mode 100644 index 00000000..4b3ae3b1 --- /dev/null +++ b/configs/data/deepGO/deepgo2_esm2.yml @@ -0,0 +1,5 @@ +class_path: chebai.preprocessing.datasets.deepGO.go_uniprot.DeepGO2MigratedData +init_args: + go_branch: "MF" + max_sequence_length: 1000 + use_esm2_embeddings: True \ No newline at end of file diff --git a/configs/model/ffn.yml b/configs/model/ffn.yml new file mode 100644 index 00000000..193c6f64 --- /dev/null +++ b/configs/model/ffn.yml @@ -0,0 +1,7 @@ +class_path: chebai.models.ffn.FFN +init_args: + optimizer_kwargs: + lr: 1e-3 + hidden_size: 128 + num_hidden_layers: 3 + input_size: 2560 From 3b174875ecdcc981a3c0a245e535d83bdd5811e3 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Fri, 10 Jan 2025 14:11:51 +0100 Subject: [PATCH 33/71] reformat using Black --- chebai/models/ffn.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/chebai/models/ffn.py b/chebai/models/ffn.py index 77046ae6..ca1f6f22 100644 --- a/chebai/models/ffn.py +++ b/chebai/models/ffn.py @@ -4,11 +4,18 @@ import torch from torch import Tensor + class FFN(ChebaiBaseNet): NAME = "FFN" - def __init__(self, input_size: int = 1000, num_hidden_layers: int = 3, hidden_size: int = 128, **kwargs): + def __init__( + self, + input_size: int = 1000, + num_hidden_layers: int = 3, + hidden_size: int = 128, + **kwargs + ): super().__init__(**kwargs) self.layers = torch.nn.ModuleList() @@ -52,4 +59,3 @@ def forward(self, data, **kwargs): for layer in self.layers: x = torch.relu(layer(x)) return {"logits": x} - From f4d1d74be2e24a74006de78ed53df6ab5d6cf82b Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 22 Jan 2025 13:38:30 +0100 Subject: [PATCH 34/71] scope: data preparation code --- chebai/preprocessing/datasets/scope/scope.py | 295 ++++++++++++------- 1 file changed, 196 insertions(+), 99 deletions(-) diff --git a/chebai/preprocessing/datasets/scope/scope.py b/chebai/preprocessing/datasets/scope/scope.py index a987f53d..fe41ba51 100644 --- a/chebai/preprocessing/datasets/scope/scope.py +++ b/chebai/preprocessing/datasets/scope/scope.py @@ -51,13 +51,28 @@ class _SCOPeDataExtractor(_DynamicDataset, ABC): "https://files.rcsb.org/pub/pdb/derived_data/pdb_seqres.txt.gz" ) + SCOPE_HIERARCHY: Dict[str, str] = { + "cl": "class", + "cf": "fold", + "sf": "superfamily", + "fa": "family", + "dm": "protein", + "sp": "species", + "px": "domain", + } + def __init__( self, scope_version: float, scope_version_train: Optional[float] = None, + scope_hierarchy_level: str = "cl", **kwargs, ): + assert ( + scope_hierarchy_level in self.SCOPE_HIERARCHY.keys() + ), f"level can contain only one of the following values {self.SCOPE_HIERARCHY.keys()}" + self.scope_hierarchy_level = scope_hierarchy_level self.scope_version: float = scope_version self.scope_version_train: float = scope_version_train @@ -67,7 +82,8 @@ def __init__( # Instantiate another same class with "scope_version" as "scope_version_train", if train_version is given # This is to get the data from respective directory related to "scope_version_train" _init_kwargs = kwargs - _init_kwargs["chebi_version"] = self.scope_version_train + _init_kwargs["scope_version"] = self.scope_version_train + _init_kwargs["scope_hierarchy_level"] = self.scope_hierarchy_level self._scope_version_train_obj = self.__class__( **_init_kwargs, ) @@ -150,18 +166,40 @@ def _download_scope_raw_data(self) -> str: open(scope_path, "wb").write(r.content) return "dummy/path" - def _parse_pdb_sequence_file(self) -> Dict[str, Dict[str, str]]: - pdb_chain_seq_mapping: Dict[str, Dict[str, str]] = {} - for record in SeqIO.parse( - os.path.join(self.raw_dir, self.raw_file_names_dict["PDB"]), "fasta" - ): - pdb_id, chain = record.id.split("_") - pdb_chain_seq_mapping.setdefault(pdb_id, {})[chain] = str(record.seq) - return pdb_chain_seq_mapping - def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph: print("Extracting class hierarchy...") + df_scope = self._get_scope_data() + + g = nx.DiGraph() + + egdes = [] + for _, row in df_scope.iterrows(): + g.add_node(row["sunid"], **{"sid": row["sid"], "level": row["level"]}) + if row["parent_sunid"] != -1: + egdes.append((row["parent_sunid"], row["sunid"])) + + for children_id in row["children_sunids"]: + egdes.append((row["sunid"], children_id)) + + g.add_edges_from(egdes) + + print("Computing transitive closure") + return nx.transitive_closure_dag(g) + + def _get_scope_data(self) -> pd.DataFrame: + df_cla = self._get_classification_data() + df_hie = self._get_hierarchy_data() + df_des = self._get_node_description_data() + df_hie_with_cla = pd.merge(df_hie, df_cla, how="left", on="sunid") + df_all = pd.merge( + df_hie_with_cla, + df_des.drop(columns=["sid"], axis=1), + how="left", + on="sunid", + ) + return df_all + def _get_classification_data(self) -> pd.DataFrame: # Load and preprocess CLA file df_cla = pd.read_csv( os.path.join(self.raw_dir, self.raw_file_names_dict["CLA"]), @@ -175,125 +213,166 @@ def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph: "description", "sccs", "sunid", - "ancestor_nodes", + "hie_levels", ] - df_cla["sunid"] = pd.to_numeric( - df_cla["sunid"], errors="coerce", downcast="integer" - ) - df_cla["ancestor_nodes"] = df_cla["ancestor_nodes"].apply( + + # Convert to dict - {cl:46456, cf:46457, sf:46458, fa:46459, dm:46460, sp:116748, px:113449} + df_cla["hie_levels"] = df_cla["hie_levels"].apply( lambda x: {k: int(v) for k, v in (item.split("=") for item in x.split(","))} ) - df_cla.set_index("sunid", inplace=True) + # Split ancestor_nodes into separate columns and assign values + for key in self.SCOPE_HIERARCHY.keys(): + df_cla[self.SCOPE_HIERARCHY[key]] = df_cla["hie_levels"].apply( + lambda x: x[key] + ) + + df_cla["sunid"] = df_cla["sunid"].astype("int64") + + return df_cla + + def _get_hierarchy_data(self) -> pd.DataFrame: # Load and preprocess HIE file df_hie = pd.read_csv( os.path.join(self.raw_dir, self.raw_file_names_dict["HIE"]), sep="\t", header=None, comment="#", + low_memory=False, ) df_hie.columns = ["sunid", "parent_sunid", "children_sunids"] - df_hie["sunid"] = pd.to_numeric( - df_hie["sunid"], errors="coerce", downcast="integer" - ) + + # if not parent id, then insert -1 df_hie["parent_sunid"] = df_hie["parent_sunid"].replace("-", -1).astype(int) + # convert children ids to list of ids df_hie["children_sunids"] = df_hie["children_sunids"].apply( lambda x: list(map(int, x.split(","))) if x != "-" else [] ) - # Initialize directed graph - g = nx.DiGraph() + # Ensure the 'sunid' column in both DataFrames has the same type + df_hie["sunid"] = df_hie["sunid"].astype("int64") + return df_hie - # Add nodes and edges efficiently - g.add_edges_from( - df_hie[df_hie["parent_sunid"] != -1].apply( - lambda row: (row["parent_sunid"], row["sunid"]), axis=1 - ) - ) - g.add_edges_from( - df_hie.explode("children_sunids") - .dropna() - .apply(lambda row: (row["sunid"], row["children_sunids"]), axis=1) + def _get_node_description_data(self): + # Load and preprocess HIE file + df_des = pd.read_csv( + os.path.join(self.raw_dir, self.raw_file_names_dict["DES"]), + sep="\t", + header=None, + comment="#", + low_memory=False, ) + df_des.columns = ["sunid", "level", "scss", "sid", "description"] + df_des.loc[len(df_des)] = {"sunid": 0, "level": "root"} - pdb_chain_seq_mapping = self._parse_pdb_sequence_file() + # Ensure the 'sunid' column in both DataFrames has the same type + df_des["sunid"] = df_des["sunid"].astype("int64") + return df_des - node_to_pdb_id = df_cla["PDB_ID"].to_dict() + def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame: + print(f"Process graph") - for node in g.nodes(): - pdb_id = node_to_pdb_id[node] - chain_mapping = pdb_chain_seq_mapping.get(pdb_id, {}) + sids = nx.get_node_attributes(graph, "sid") + levels = nx.get_node_attributes(graph, "level") - # Add nodes and edges for chains in the mapping - for chain, sequence in chain_mapping.items(): - chain_node = f"{pdb_id}_{chain}" - g.add_node(chain_node, sequence=sequence) - g.add_edge(node, chain_node) + sun_ids = [] + sids_list = [] - print("Compute transitive closure...") - return nx.transitive_closure_dag(g) + selected_sids_dict = self.select_classes(graph) - def _graph_to_raw_dataset(self, g: nx.DiGraph) -> pd.DataFrame: - """ - Processes a directed acyclic graph (DAG) to create a raw dataset in DataFrame format. The dataset includes - Swiss-Prot protein data and their associations with Gene Ontology (GO) terms. - - Note: - - GO classes are used as labels in the dataset. Each GO term is represented as a column, and its value - indicates whether a Swiss-Prot protein is associated with that GO term. - - Swiss-Prot proteins serve as samples. There is no 1-to-1 correspondence between Swiss-Prot proteins - and GO terms. - - Data Format: pd.DataFrame - - Column 0 : swiss_id (Identifier for SwissProt protein) - - Column 1 : Accession of the protein - - Column 2 : GO IDs (associated GO terms) - - Column 3 : Sequence of the protein - - Column 4 to Column "n": Each column corresponding to a class with value True/False indicating whether the - protein is associated with this GO term. + for sun_id, level in levels.items(): + if level == self.scope_hierarchy_level and sun_id in selected_sids_dict: + sun_ids.append(sun_id) + sids_list.append(sids.get(sun_id)) - Args: - g (nx.DiGraph): The class hierarchy graph. + # data_df = pd.DataFrame(OrderedDict(sun_id=sun_ids, sids=sids_list)) + df_cla = self._get_classification_data() + target_col_name = self.SCOPE_HIERARCHY[self.scope_hierarchy_level] + df_cla = df_cla[df_cla[target_col_name].isin(sun_ids)] + df_cla = df_cla[["sid", target_col_name]] - Returns: - pd.DataFrame: The raw dataset created from the graph. - """ - print(f"Processing graph") - - data_df = self._get_swiss_to_go_mapping() - # add ancestors to go ids - data_df["go_ids"] = data_df["go_ids"].apply( - lambda go_ids: sorted( - set( - itertools.chain.from_iterable( - [ - [go_id] + list(g.predecessors(go_id)) - for go_id in go_ids - if go_id in g.nodes - ] - ) - ) - ) + assert ( + len(df_cla) > 1 + ), "dataframe should have more than one instance for `pd.get_dummies` to work as expected" + df_encoded = pd.get_dummies( + df_cla, columns=[target_col_name], drop_first=False, sparse=True ) - # Initialize the GO term labels/columns to False - selected_classes = self.select_classes(g, data_df=data_df) - new_label_columns = pd.DataFrame( - False, index=data_df.index, columns=selected_classes + + pdb_chain_seq_mapping = self._parse_pdb_sequence_file() + + sequence_hierarchy_df = pd.DataFrame( + columns=list(df_encoded.columns) + ["sids"] ) - data_df = pd.concat([data_df, new_label_columns], axis=1) - # Set True for the corresponding GO IDs in the DataFrame go labels/columns - for index, row in data_df.iterrows(): - for go_id in row["go_ids"]: - if go_id in data_df.columns: - data_df.at[index, go_id] = True + for _, row in df_encoded.iterrows(): + assert sum(row.iloc[1:].tolist()) == 1 + sid = row["sid"] + # SID: 7-char identifier ("d" + 4-char PDB ID + chain ID ('_' for none, '.' for multiple) + # + domain specifier ('_' if not needed)) + assert len(sid) == 7, "sid should have 7 characters" + pdb_id, chain_id = sid[1:5], sid[5] + + pdb_to_chain_mapping = pdb_chain_seq_mapping.get(pdb_id, None) + if not pdb_to_chain_mapping: + continue + + if chain_id != "_": + chain_sequence = pdb_to_chain_mapping.get(chain_id, None) + if chain_sequence: + self._update_or_add_sequence( + chain_sequence, row, sequence_hierarchy_df + ) + + else: + # Add nodes and edges for chains in the mapping + for chain, chain_sequence in pdb_to_chain_mapping.items(): + self._update_or_add_sequence( + chain_sequence, row, sequence_hierarchy_df + ) + + sequence_hierarchy_df.drop(columns=["sid"], axis=1, inplace=True) + sequence_hierarchy_df = sequence_hierarchy_df[ + ["sids"] + [col for col in sequence_hierarchy_df.columns if col != "sids"] + ] # This filters the DataFrame to include only the rows where at least one value in the row from 5th column # onwards is True/non-zero. - # Quote from DeepGo Paper: `For training and testing, we use proteins which have been annotated with at least - # one GO term from the set of the GO terms for the model` - data_df = data_df[data_df.iloc[:, self._LABELS_START_IDX :].any(axis=1)] - return data_df + sequence_hierarchy_df = sequence_hierarchy_df[ + sequence_hierarchy_df.iloc[:, self._LABELS_START_IDX :].any(axis=1) + ] + return sequence_hierarchy_df + + def _parse_pdb_sequence_file(self) -> Dict[str, Dict[str, str]]: + pdb_chain_seq_mapping: Dict[str, Dict[str, str]] = {} + for record in SeqIO.parse( + os.path.join(self.raw_dir, self.raw_file_names_dict["PDB"]), "fasta" + ): + pdb_id, chain = record.id.split("_") + if str(record.seq): + pdb_chain_seq_mapping.setdefault(pdb_id.lower(), {})[chain.lower()] = ( + str(record.seq) + ) + return pdb_chain_seq_mapping + + @staticmethod + def _update_or_add_sequence(sequence, row, sequence_hierarchy_df): + # Check if sequence already exists as an index + # Slice the series starting from column 2 + sliced_data = row.iloc[1:] # Slice starting from the second column (index 1) + + # Get the column name with the True value + true_column = sliced_data.idxmax() if sliced_data.any() else None + + if sequence in sequence_hierarchy_df.index: + # Update encoded columns only if they are True + if row[true_column] is True: + sequence_hierarchy_df.loc[sequence, true_column] = True + sequence_hierarchy_df.loc[sequence, "sids"].append(row["sid"]) + else: + # Add new row with sequence as the index and hierarchy data + new_row = row + new_row["sids"] = [row["sid"]] + sequence_hierarchy_df.loc[sequence] = new_row # ------------------------------ Phase: Setup data ----------------------------------- def _load_dict(self, input_file_path: str) -> Generator[Dict[str, Any], None, None]: @@ -367,15 +446,33 @@ def raw_file_names_dict(self) -> dict: class SCOPE(_SCOPeDataExtractor): READER = ProteinDataReader + THRESHOLD = 1 @property def _name(self) -> str: return "test" - def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> List: - pass + def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> Dict: + # Filter nodes and create a dictionary of node and out-degree + sun_ids_dict = { + node: g.out_degree(node) # Store node and its out-degree + for node in g.nodes + if g.out_degree(node) >= self.THRESHOLD + } + + # Return a sorted dictionary (by out-degree or node id) + sorted_dict = dict( + sorted(sun_ids_dict.items(), key=lambda item: item[0], reverse=False) + ) + + filename = "classes.txt" + with open(os.path.join(self.processed_dir_main, filename), "wt") as fout: + fout.writelines(str(sun_id) + "\n" for sun_id in sorted_dict.keys()) + + return sorted_dict if __name__ == "__main__": scope = SCOPE(scope_version=2.08) - scope._parse_pdb_sequence_file() + g = scope._extract_class_hierarchy("d") + scope._graph_to_raw_dataset(g) From 431da47fe679ff852cf5440c5d5069e777552883 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 23 Jan 2025 17:35:10 +0100 Subject: [PATCH 35/71] scope: include all levels --- chebai/preprocessing/datasets/scope/scope.py | 85 +++++++++++--------- 1 file changed, 48 insertions(+), 37 deletions(-) diff --git a/chebai/preprocessing/datasets/scope/scope.py b/chebai/preprocessing/datasets/scope/scope.py index fe41ba51..e935773a 100644 --- a/chebai/preprocessing/datasets/scope/scope.py +++ b/chebai/preprocessing/datasets/scope/scope.py @@ -33,9 +33,6 @@ class _SCOPeDataExtractor(_DynamicDataset, ABC): **kwargs: Additional keyword arguments passed to DynamicDataset and XYBaseDataModule. """ - _GO_DATA_INIT = "GO" - _SWISS_DATA_INIT = "SWISS" - # -- Index for columns of processed `data.pkl` (derived from `_get_swiss_to_go_mapping` & `_graph_to_raw_dataset` # "swiss_id" at row index 0 # "accession" at row index 1 @@ -43,8 +40,8 @@ class _SCOPeDataExtractor(_DynamicDataset, ABC): # "sequence" at row index 3 # labels starting from row index 4 _ID_IDX: int = 0 - _DATA_REPRESENTATION_IDX: int = 3 # here `sequence` column - _LABELS_START_IDX: int = 4 + _DATA_REPRESENTATION_IDX: int = 2 # here `sequence` column + _LABELS_START_IDX: int = 3 _SCOPE_GENERAL_URL = "https://scop.berkeley.edu/downloads/parse/dir.{data_type}.scope.{version_number}-stable.txt" _PDB_SEQUENCE_DATA_URL = ( @@ -65,14 +62,8 @@ def __init__( self, scope_version: float, scope_version_train: Optional[float] = None, - scope_hierarchy_level: str = "cl", **kwargs, ): - - assert ( - scope_hierarchy_level in self.SCOPE_HIERARCHY.keys() - ), f"level can contain only one of the following values {self.SCOPE_HIERARCHY.keys()}" - self.scope_hierarchy_level = scope_hierarchy_level self.scope_version: float = scope_version self.scope_version_train: float = scope_version_train @@ -83,7 +74,6 @@ def __init__( # This is to get the data from respective directory related to "scope_version_train" _init_kwargs = kwargs _init_kwargs["scope_version"] = self.scope_version_train - _init_kwargs["scope_hierarchy_level"] = self.scope_hierarchy_level self._scope_version_train_obj = self.__class__( **_init_kwargs, ) @@ -275,37 +265,50 @@ def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame: sids = nx.get_node_attributes(graph, "sid") levels = nx.get_node_attributes(graph, "level") - sun_ids = [] + sun_ids = {} sids_list = [] selected_sids_dict = self.select_classes(graph) for sun_id, level in levels.items(): - if level == self.scope_hierarchy_level and sun_id in selected_sids_dict: - sun_ids.append(sun_id) + if sun_id in selected_sids_dict: + sun_ids.setdefault(level, []).append(sun_id) sids_list.append(sids.get(sun_id)) + # Remove root node, as it will True for all instances + sun_ids.pop("root", None) + # data_df = pd.DataFrame(OrderedDict(sun_id=sun_ids, sids=sids_list)) df_cla = self._get_classification_data() - target_col_name = self.SCOPE_HIERARCHY[self.scope_hierarchy_level] - df_cla = df_cla[df_cla[target_col_name].isin(sun_ids)] - df_cla = df_cla[["sid", target_col_name]] + + for level, selected_sun_ids in sun_ids.items(): + df_cla = df_cla[df_cla[self.SCOPE_HIERARCHY[level]].isin(selected_sun_ids)] assert ( len(df_cla) > 1 ), "dataframe should have more than one instance for `pd.get_dummies` to work as expected" df_encoded = pd.get_dummies( - df_cla, columns=[target_col_name], drop_first=False, sparse=True + df_cla, + columns=list(self.SCOPE_HIERARCHY.values()), + drop_first=False, + sparse=True, ) pdb_chain_seq_mapping = self._parse_pdb_sequence_file() - sequence_hierarchy_df = pd.DataFrame( - columns=list(df_encoded.columns) + ["sids"] - ) + encoded_target_cols = {} + for col in self.SCOPE_HIERARCHY.values(): + encoded_target_cols[col] = [ + t_col for t_col in df_encoded.columns if t_col.startswith(col) + ] + + encoded_target_columns = [] + for level in self.SCOPE_HIERARCHY.values(): + encoded_target_columns.extend(encoded_target_cols[level]) + + sequence_hierarchy_df = pd.DataFrame(columns=["sids"] + encoded_target_columns) for _, row in df_encoded.iterrows(): - assert sum(row.iloc[1:].tolist()) == 1 sid = row["sid"] # SID: 7-char identifier ("d" + 4-char PDB ID + chain ID ('_' for none, '.' for multiple) # + domain specifier ('_' if not needed)) @@ -320,19 +323,22 @@ def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame: chain_sequence = pdb_to_chain_mapping.get(chain_id, None) if chain_sequence: self._update_or_add_sequence( - chain_sequence, row, sequence_hierarchy_df + chain_sequence, row, sequence_hierarchy_df, encoded_target_cols ) else: # Add nodes and edges for chains in the mapping for chain, chain_sequence in pdb_to_chain_mapping.items(): self._update_or_add_sequence( - chain_sequence, row, sequence_hierarchy_df + chain_sequence, row, sequence_hierarchy_df, encoded_target_cols ) sequence_hierarchy_df.drop(columns=["sid"], axis=1, inplace=True) + sequence_hierarchy_df.reset_index(inplace=True) + sequence_hierarchy_df["id"] = range(1, len(sequence_hierarchy_df) + 1) + sequence_hierarchy_df = sequence_hierarchy_df[ - ["sids"] + [col for col in sequence_hierarchy_df.columns if col != "sids"] + ["id", "sids", "sequence"] + encoded_target_columns ] # This filters the DataFrame to include only the rows where at least one value in the row from 5th column @@ -355,19 +361,24 @@ def _parse_pdb_sequence_file(self) -> Dict[str, Dict[str, str]]: return pdb_chain_seq_mapping @staticmethod - def _update_or_add_sequence(sequence, row, sequence_hierarchy_df): - # Check if sequence already exists as an index - # Slice the series starting from column 2 - sliced_data = row.iloc[1:] # Slice starting from the second column (index 1) - - # Get the column name with the True value - true_column = sliced_data.idxmax() if sliced_data.any() else None - + def _update_or_add_sequence( + sequence, row, sequence_hierarchy_df, encoded_col_names + ): if sequence in sequence_hierarchy_df.index: # Update encoded columns only if they are True - if row[true_column] is True: + for col in encoded_col_names: + assert ( + sum(row[encoded_col_names[col]].tolist()) == 1 + ), "A instance can belong to only one hierarchy level" + sliced_data = row[ + encoded_col_names[col] + ] # Slice starting from the second column (index 1) + # Get the column name with the True value + true_column = sliced_data.idxmax() if sliced_data.any() else None sequence_hierarchy_df.loc[sequence, true_column] = True - sequence_hierarchy_df.loc[sequence, "sids"].append(row["sid"]) + + sequence_hierarchy_df.loc[sequence, "sids"].append(row["sid"]) + else: # Add new row with sequence as the index and hierarchy data new_row = row @@ -446,7 +457,7 @@ def raw_file_names_dict(self) -> dict: class SCOPE(_SCOPeDataExtractor): READER = ProteinDataReader - THRESHOLD = 1 + THRESHOLD = 10000 @property def _name(self) -> str: From 43d4550786ddddd51cad16809fc7328fa00a985c Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Fri, 24 Jan 2025 13:43:56 +0100 Subject: [PATCH 36/71] scope: remove domain level from one hot encoding --- chebai/preprocessing/datasets/scope/scope.py | 23 +++++++++++++++----- 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/chebai/preprocessing/datasets/scope/scope.py b/chebai/preprocessing/datasets/scope/scope.py index e935773a..993fa501 100644 --- a/chebai/preprocessing/datasets/scope/scope.py +++ b/chebai/preprocessing/datasets/scope/scope.py @@ -279,17 +279,28 @@ def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame: sun_ids.pop("root", None) # data_df = pd.DataFrame(OrderedDict(sun_id=sun_ids, sids=sids_list)) + if not sun_ids: + raise RuntimeError("No sunid selected.") + df_cla = self._get_classification_data() + hierarchy_levels = list(self.SCOPE_HIERARCHY.values()) + hierarchy_levels.remove("domain") + + df_cla = df_cla[["sid", "sunid"] + hierarchy_levels] for level, selected_sun_ids in sun_ids.items(): - df_cla = df_cla[df_cla[self.SCOPE_HIERARCHY[level]].isin(selected_sun_ids)] + if selected_sun_ids: + df_cla = df_cla[ + df_cla[self.SCOPE_HIERARCHY[level]].isin(selected_sun_ids) + ] assert ( len(df_cla) > 1 ), "dataframe should have more than one instance for `pd.get_dummies` to work as expected" + df_encoded = pd.get_dummies( df_cla, - columns=list(self.SCOPE_HIERARCHY.values()), + columns=hierarchy_levels, drop_first=False, sparse=True, ) @@ -297,13 +308,13 @@ def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame: pdb_chain_seq_mapping = self._parse_pdb_sequence_file() encoded_target_cols = {} - for col in self.SCOPE_HIERARCHY.values(): + for col in hierarchy_levels: encoded_target_cols[col] = [ t_col for t_col in df_encoded.columns if t_col.startswith(col) ] encoded_target_columns = [] - for level in self.SCOPE_HIERARCHY.values(): + for level in hierarchy_levels: encoded_target_columns.extend(encoded_target_cols[level]) sequence_hierarchy_df = pd.DataFrame(columns=["sids"] + encoded_target_columns) @@ -333,8 +344,8 @@ def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame: chain_sequence, row, sequence_hierarchy_df, encoded_target_cols ) - sequence_hierarchy_df.drop(columns=["sid"], axis=1, inplace=True) sequence_hierarchy_df.reset_index(inplace=True) + sequence_hierarchy_df.rename(columns={"index": "sequence"}, inplace=True) sequence_hierarchy_df["id"] = range(1, len(sequence_hierarchy_df) + 1) sequence_hierarchy_df = sequence_hierarchy_df[ @@ -457,7 +468,7 @@ def raw_file_names_dict(self) -> dict: class SCOPE(_SCOPeDataExtractor): READER = ProteinDataReader - THRESHOLD = 10000 + THRESHOLD = 2143 @property def _name(self) -> str: From 6735e4185b6df2f1c843f8a1eba938334919fab3 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Fri, 24 Jan 2025 15:21:23 +0100 Subject: [PATCH 37/71] scope: add documentation --- chebai/preprocessing/datasets/scope/scope.py | 331 ++++++++++++++++--- 1 file changed, 291 insertions(+), 40 deletions(-) diff --git a/chebai/preprocessing/datasets/scope/scope.py b/chebai/preprocessing/datasets/scope/scope.py index 993fa501..3ac8e6aa 100644 --- a/chebai/preprocessing/datasets/scope/scope.py +++ b/chebai/preprocessing/datasets/scope/scope.py @@ -1,20 +1,27 @@ +# References for this file : + +# Reference 1: +# John-Marc Chandonia, Naomi K Fox, Steven E Brenner, SCOPe: classification of large macromolecular structures +# in the structural classification of proteins—extended database, Nucleic Acids Research, Volume 47, +# Issue D1, 08 January 2019, Pages D475–D481, https://doi.org/10.1093/nar/gky1134 +# https://scop.berkeley.edu/about/ver=2.08 + +# Reference 2: +# Murzin AG, Brenner SE, Hubbard TJP, Chothia C. 1995. SCOP: a structural classification of proteins database for +# the investigation of sequences and structures. Journal of Molecular Biology 247:536-540 + import gzip -import itertools import os -import pickle import shutil from abc import ABC -from collections import OrderedDict from tempfile import NamedTemporaryFile -from typing import Any, Dict, Generator, List, Optional, Tuple, Union +from typing import Any, Dict, Generator, Optional, Tuple -import fastobo import networkx as nx import pandas as pd import requests import torch from Bio import SeqIO -from Bio.Seq import Seq from chebai.preprocessing.datasets.base import _DynamicDataset from chebai.preprocessing.reader import ProteinDataReader @@ -22,23 +29,25 @@ class _SCOPeDataExtractor(_DynamicDataset, ABC): """ - A class for extracting and processing data from the Gene Ontology (GO) dataset and the Swiss UniProt dataset. + A class for extracting and processing data from the SCOPe (Structural Classification of Proteins - extended) dataset. + + This class is designed to handle the parsing, preprocessing, and hierarchical structure extraction from various + SCOPe dataset files, such as classification (CLA), hierarchy (HIE), and description (DES) files. + Additionally, it supports downloading related data like PDB sequence files. Args: + scope_version (float): The SCOPe version to use. + scope_version_train (Optional[float]): The training SCOPe version, if different. dynamic_data_split_seed (int, optional): The seed for random data splitting. Defaults to 42. splits_file_path (str, optional): Path to the splits CSV file. Defaults to None. - max_sequence_length (int, optional): Specifies the maximum allowed sequence length for a protein, with a - default of 1002. During data preprocessing, any proteins exceeding this length will be excluded from further - processing. **kwargs: Additional keyword arguments passed to DynamicDataset and XYBaseDataModule. """ - # -- Index for columns of processed `data.pkl` (derived from `_get_swiss_to_go_mapping` & `_graph_to_raw_dataset` - # "swiss_id" at row index 0 - # "accession" at row index 1 - # "go_ids" at row index 2 - # "sequence" at row index 3 - # labels starting from row index 4 + # -- Index for columns of processed `data.pkl` (derived from `_graph_to_raw_dataset`) + # "id" at row index 0 + # "sids" at row index 1 + # "sequence" at row index 2 + # labels starting from row index 3 _ID_IDX: int = 0 _DATA_REPRESENTATION_IDX: int = 2 # here `sequence` column _LABELS_START_IDX: int = 3 @@ -97,7 +106,7 @@ def _get_scope_url(data_type: str, version_number: float) -> str: # ------------------------------ Phase: Prepare data ----------------------------------- def _download_required_data(self) -> str: """ - Downloads the required raw data related to Gene Ontology (GO) and Swiss-UniProt dataset. + Downloads the required raw data for SCOPe and PDB sequence datasets. Returns: str: Path to the downloaded data. @@ -106,6 +115,12 @@ def _download_required_data(self) -> str: return self._download_scope_raw_data() def _download_pdb_sequence_data(self) -> None: + """ + Downloads and unzips the PDB sequence dataset from the RCSB PDB repository. + + The file is downloaded as a temporary gzip file, which is then extracted to the + specified directory. + """ pdb_seq_file_path = os.path.join(self.raw_dir, self.raw_file_names_dict["PDB"]) os.makedirs(os.path.dirname(pdb_seq_file_path), exist_ok=True) @@ -141,8 +156,17 @@ def _download_pdb_sequence_data(self) -> None: print(f"Removed temporary file {temp_filename}") def _download_scope_raw_data(self) -> str: + """ + Downloads the raw SCOPe dataset files (CLA, HIE, DES, and COM). + + Each file is downloaded from the SCOPe repository and saved to the specified directory. + Files are only downloaded if they do not already exist. + + Returns: + str: A dummy path to indicate completion (can be extended for custom behavior). + """ os.makedirs(self.raw_dir, exist_ok=True) - for data_type in ["CLA", "COM", "HIE", "DES"]: + for data_type in ["CLA", "HIE", "DES"]: data_file_name = self.raw_file_names_dict[data_type] scope_path = os.path.join(self.raw_dir, data_file_name) if not os.path.isfile(scope_path): @@ -157,6 +181,15 @@ def _download_scope_raw_data(self) -> str: return "dummy/path" def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph: + """ + Extracts the class hierarchy from SCOPe data and computes its transitive closure. + + Args: + data_path (str): Path to the processed SCOPe dataset. + + Returns: + nx.DiGraph: A directed acyclic graph representing the SCOPe class hierarchy. + """ print("Extracting class hierarchy...") df_scope = self._get_scope_data() @@ -177,6 +210,12 @@ def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph: return nx.transitive_closure_dag(g) def _get_scope_data(self) -> pd.DataFrame: + """ + Merges and preprocesses the SCOPe classification, hierarchy, and description files into a unified DataFrame. + + Returns: + pd.DataFrame: A DataFrame containing combined SCOPe data with classification and hierarchy details. + """ df_cla = self._get_classification_data() df_hie = self._get_hierarchy_data() df_des = self._get_node_description_data() @@ -190,7 +229,12 @@ def _get_scope_data(self) -> pd.DataFrame: return df_all def _get_classification_data(self) -> pd.DataFrame: - # Load and preprocess CLA file + """ + Parses and processes the SCOPe CLA (classification) file. + + Returns: + pd.DataFrame: A DataFrame containing classification details, including hierarchy levels. + """ df_cla = pd.read_csv( os.path.join(self.raw_dir, self.raw_file_names_dict["CLA"]), sep="\t", @@ -222,7 +266,12 @@ def _get_classification_data(self) -> pd.DataFrame: return df_cla def _get_hierarchy_data(self) -> pd.DataFrame: - # Load and preprocess HIE file + """ + Parses and processes the SCOPe HIE (hierarchy) file. + + Returns: + pd.DataFrame: A DataFrame containing hierarchy details, including parent-child relationships. + """ df_hie = pd.read_csv( os.path.join(self.raw_dir, self.raw_file_names_dict["HIE"]), sep="\t", @@ -243,8 +292,13 @@ def _get_hierarchy_data(self) -> pd.DataFrame: df_hie["sunid"] = df_hie["sunid"].astype("int64") return df_hie - def _get_node_description_data(self): - # Load and preprocess HIE file + def _get_node_description_data(self) -> pd.DataFrame: + """ + Parses and processes the SCOPe DES (description) file. + + Returns: + pd.DataFrame: A DataFrame containing node-level descriptions from the SCOPe dataset. + """ df_des = pd.read_csv( os.path.join(self.raw_dir, self.raw_file_names_dict["DES"]), sep="\t", @@ -260,6 +314,38 @@ def _get_node_description_data(self): return df_des def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame: + """ + Processes a directed acyclic graph (DAG) to generate a raw dataset in DataFrame format. This dataset includes + chain-level sequences and their corresponding labels based on the hierarchical structure of the associated domains. + + The process: + - Extracts SCOPe domain identifiers (sids) from the graph. + - Retrieves class labels for each domain based on all applicable taxonomy levels. + - Fetches the chain-level sequences from the Protein Data Bank (PDB) for each domain. + - For each sequence, identifies all domains associated with the same chain and assigns their corresponding labels. + + Notes: + - SCOPe hierarchy levels are used as labels, with each level represented by a column. The value in each column + indicates whether a PDB chain is associated with that particular hierarchy level. + - PDB chains are treated as samples. The method considers only domains that are mapped to the selected hierarchy levels. + + Data Format: pd.DataFrame + - Column 0 : id (Unique identifier for each sequence entry) + - Column 1 : sids (List of domain identifiers associated with the sequence) + - Column 2 : sequence (Amino acid sequence of the chain) + - Column 3 to Column "n": Each column corresponds to a SCOPe class hierarchy level with a value + of True/False indicating whether the chain is associated with the corresponding level. + + Args: + graph (nx.DiGraph): The class hierarchy graph. + + Returns: + pd.DataFrame: The raw dataset created from the graph. + + Raises: + RuntimeError: If no sunids are selected. + AssertionError: If the input data is insufficient for encoding or validation fails. + """ print(f"Process graph") sids = nx.get_node_attributes(graph, "sid") @@ -278,7 +364,6 @@ def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame: # Remove root node, as it will True for all instances sun_ids.pop("root", None) - # data_df = pd.DataFrame(OrderedDict(sun_id=sun_ids, sids=sids_list)) if not sun_ids: raise RuntimeError("No sunid selected.") @@ -288,6 +373,8 @@ def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame: df_cla = df_cla[["sid", "sunid"] + hierarchy_levels] + # This filtering make sures to consider only domains that belongs to each `selected` hierarchy level + # So, that our data has domains that maps to all levels of the taxonomy for level, selected_sun_ids in sun_ids.items(): if selected_sun_ids: df_cla = df_cla[ @@ -351,15 +438,16 @@ def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame: sequence_hierarchy_df = sequence_hierarchy_df[ ["id", "sids", "sequence"] + encoded_target_columns ] - - # This filters the DataFrame to include only the rows where at least one value in the row from 5th column - # onwards is True/non-zero. - sequence_hierarchy_df = sequence_hierarchy_df[ - sequence_hierarchy_df.iloc[:, self._LABELS_START_IDX :].any(axis=1) - ] return sequence_hierarchy_df def _parse_pdb_sequence_file(self) -> Dict[str, Dict[str, str]]: + """ + Parses the PDB sequence file to create a mapping of PDB IDs and chain sequences. + + Returns: + Dict[str, Dict[str, str]]: A nested dictionary where keys are PDB IDs (lowercase), + and values are dictionaries mapping chain IDs (lowercase) to their corresponding sequences. + """ pdb_chain_seq_mapping: Dict[str, Dict[str, str]] = {} for record in SeqIO.parse( os.path.join(self.raw_dir, self.raw_file_names_dict["PDB"]), "fasta" @@ -375,6 +463,18 @@ def _parse_pdb_sequence_file(self) -> Dict[str, Dict[str, str]]: def _update_or_add_sequence( sequence, row, sequence_hierarchy_df, encoded_col_names ): + """ + Updates an existing sequence entry or adds a new one to the DataFrame. + + Args: + sequence (str): Amino acid sequence of the chain. + row (pd.Series): Row data containing SCOPe hierarchy levels and associated values. + sequence_hierarchy_df (pd.DataFrame): DataFrame storing sequences and their hierarchy labels. + encoded_col_names (Dict[str, List[str]]): Mapping of hierarchy levels to encoded column names. + + Raises: + AssertionError: If a sequence instance belongs to more than one hierarchy level. + """ if sequence in sequence_hierarchy_df.index: # Update encoded columns only if they are True for col in encoded_col_names: @@ -397,7 +497,53 @@ def _update_or_add_sequence( sequence_hierarchy_df.loc[sequence] = new_row # ------------------------------ Phase: Setup data ----------------------------------- + def setup_processed(self) -> None: + """ + Transform and prepare processed data for the SCOPe dataset. + + Main function of this method is to transform `data.pkl` into a model input data format (`data.pt`), + ensuring that the data is in a format compatible for input to the model. + The transformed data must contain the following keys: `ident`, `features`, `labels`, and `group`. + This method uses a subclass of Data Reader to perform the transformation. + + It will transform the data related to `scope_version_train`, if specified. + """ + super().setup_processed() + + # Transform the data related to "scope_version_train" to encoded data, if it doesn't exist + if self.scope_version_train is not None and not os.path.isfile( + os.path.join( + self._scope_version_train_obj.processed_dir, + self._scope_version_train_obj.processed_file_names_dict["data"], + ) + ): + print( + f"Missing encoded data related to train version: {self.scope_version_train}" + ) + print("Calling the setup method related to it") + self._scope_version_train_obj.setup() + def _load_dict(self, input_file_path: str) -> Generator[Dict[str, Any], None, None]: + """ + Loads data from a pickled file and yields individual dictionaries for each row. + + The pickled file is expected to contain rows with the following structure: + - Data at row index `self._ID_IDX`: ID of go data instance + - Data at row index `self._DATA_REPRESENTATION_IDX`: Sequence representation of protein + - Data from row index `self._LABELS_START_IDX` onwards: Labels + + This method is used by `_load_data_from_file` to generate dictionaries that are then + processed and converted into a list of dictionaries containing the features and labels. + + Args: + input_file_path (str): The path to the pickled input file. + + Yields: + Dict[str, Any]: A dictionary containing: + - `features` (str): The sequence data from the file. + - `labels` (np.ndarray): A boolean array of labels starting from row index 4. + - `ident` (Any): The identifier from row index 0. + """ with open(input_file_path, "rb") as input_file: df = pd.read_pickle(input_file) for row in df.values: @@ -412,9 +558,33 @@ def _load_dict(self, input_file_path: str) -> Generator[Dict[str, Any], None, No # ------------------------------ Phase: Dynamic Splits ----------------------------------- def _get_data_splits(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: + """ + Loads encoded/transformed data and generates training, validation, and test splits. + + This method first loads encoded data from a file named `data.pt`, which is derived from either + `scope_version` or `scope_version_train`. It then splits the data into training, validation, and test sets. + + If `scope_version_train` is provided: + - Loads additional encoded data from `scope_version_train`. + - Splits this data into training and validation sets, while using the test set from `scope_version`. + - Prunes the test set from `scope_version` to include only labels that exist in `scope_version_train`. + + If `scope_version_train` is not provided: + - Splits the data from `scope_version` into training, validation, and test sets without modification. + + Raises: + FileNotFoundError: If the required `data.pt` file(s) do not exist. Ensure that `prepare_data` + and/or `setup` methods have been called to generate the dataset files. + + Returns: + Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: A tuple containing three DataFrames: + - Training set + - Validation set + - Test set + """ try: filename = self.processed_file_names_dict["data"] - data_go = torch.load( + data_scope_version = torch.load( os.path.join(self.processed_dir, filename), weights_only=False ) except FileNotFoundError: @@ -423,20 +593,102 @@ def _get_data_splits(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: f"Please call 'prepare_data' and/or 'setup' methods to generate the dataset files" ) - df_go_data = pd.DataFrame(data_go) - train_df_go, df_test = self.get_test_split( - df_go_data, seed=self.dynamic_data_split_seed + df_scope_version = pd.DataFrame(data_scope_version) + train_df_scope_ver, df_test_scope_ver = self.get_test_split( + df_scope_version, seed=self.dynamic_data_split_seed ) - # Get all splits - df_train, df_val = self.get_train_val_splits_given_test( - train_df_go, - df_test, - seed=self.dynamic_data_split_seed, - ) + if self.scope_version_train is not None: + # Load encoded data derived from "scope_version_train" + try: + filename_train = ( + self._scope_version_train_obj.processed_file_names_dict["data"] + ) + data_scope_train_version = torch.load( + os.path.join( + self._scope_version_train_obj.processed_dir, filename_train + ), + weights_only=False, + ) + except FileNotFoundError: + raise FileNotFoundError( + f"File data.pt doesn't exists related to scope_version_train {self.scope_version_train}." + f"Please call 'prepare_data' and/or 'setup' methods to generate the dataset files" + ) + + df_scope_train_version = pd.DataFrame(data_scope_train_version) + # Get train/val split of data based on "scope_version_train", but + # using test set from "scope_version" + df_train, df_val = self.get_train_val_splits_given_test( + df_scope_train_version, + df_test_scope_ver, + seed=self.dynamic_data_split_seed, + ) + # Modify test set from "scope_version" to only include the labels that + # exists in "scope_version_train", all other entries remains same. + df_test = self._setup_pruned_test_set(df_test_scope_ver) + else: + # Get all splits based on "scope_version" + df_train, df_val = self.get_train_val_splits_given_test( + train_df_scope_ver, + df_test_scope_ver, + seed=self.dynamic_data_split_seed, + ) + df_test = df_test_scope_ver return df_train, df_val, df_test + def _setup_pruned_test_set( + self, df_test_scope_version: pd.DataFrame + ) -> pd.DataFrame: + """ + Create a test set with the same leaf nodes, but use only classes that appear in the training set. + + Args: + df_test_scope_version (pd.DataFrame): The test dataset. + + Returns: + pd.DataFrame: The pruned test dataset. + """ + # TODO: find a more efficient way to do this + filename_old = "classes.txt" + # filename_new = f"classes_v{self.scope_version_train}.txt" + # dataset = torch.load(os.path.join(self.processed_dir, "test.pt")) + + # Load original classes (from the current ChEBI version - scope_version) + with open(os.path.join(self.processed_dir_main, filename_old), "r") as file: + orig_classes = file.readlines() + + # Load new classes (from the training ChEBI version - scope_version_train) + with open( + os.path.join( + self._scope_version_train_obj.processed_dir_main, filename_old + ), + "r", + ) as file: + new_classes = file.readlines() + + # Create a mapping which give index of a class from scope_version, if the corresponding + # class exists in scope_version_train, Size = Number of classes in scope_version + mapping = [ + None if or_class not in new_classes else new_classes.index(or_class) + for or_class in orig_classes + ] + + # Iterate over each data instance in the test set which is derived from scope_version + for _, row in df_test_scope_version.iterrows(): + # Size = Number of classes in scope_version_train + new_labels = [False for _ in new_classes] + for ind, label in enumerate(row["labels"]): + # If the scope_version class exists in the scope_version_train and has a True label, + # set the corresponding label in new_labels to True + if mapping[ind] is not None and label: + new_labels[mapping[ind]] = label + # Update the labels from test instance from scope_version to the new labels, which are compatible to both versions + row["labels"] = new_labels + + return df_test_scope_version + # ------------------------------ Phase: Raw Properties ----------------------------------- @property def base_dir(self) -> str: @@ -461,7 +713,6 @@ def raw_file_names_dict(self) -> dict: "CLA": "cla.txt", "DES": "des.txt", "HIE": "hie.txt", - "COM": "com.txt", "PDB": "pdb_sequences.txt", } From c3ba8da7d46c27ef2c96ed0eae6d3d0500d9b541 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Fri, 24 Jan 2025 15:53:13 +0100 Subject: [PATCH 38/71] scope: add OverX classes and their derivaties --- chebai/preprocessing/datasets/scope/scope.py | 134 +++++++++++++++++-- 1 file changed, 124 insertions(+), 10 deletions(-) diff --git a/chebai/preprocessing/datasets/scope/scope.py b/chebai/preprocessing/datasets/scope/scope.py index 3ac8e6aa..c0a790e6 100644 --- a/chebai/preprocessing/datasets/scope/scope.py +++ b/chebai/preprocessing/datasets/scope/scope.py @@ -121,7 +121,9 @@ def _download_pdb_sequence_data(self) -> None: The file is downloaded as a temporary gzip file, which is then extracted to the specified directory. """ - pdb_seq_file_path = os.path.join(self.raw_dir, self.raw_file_names_dict["PDB"]) + pdb_seq_file_path = os.path.join( + self.scope_root_dir, self.raw_file_names_dict["PDB"] + ) os.makedirs(os.path.dirname(pdb_seq_file_path), exist_ok=True) if not os.path.isfile(pdb_seq_file_path): @@ -450,7 +452,7 @@ def _parse_pdb_sequence_file(self) -> Dict[str, Dict[str, str]]: """ pdb_chain_seq_mapping: Dict[str, Dict[str, str]] = {} for record in SeqIO.parse( - os.path.join(self.raw_dir, self.raw_file_names_dict["PDB"]), "fasta" + os.path.join(self.scope_root_dir, self.raw_file_names_dict["PDB"]), "fasta" ): pdb_id, chain = record.id.split("_") if str(record.seq): @@ -655,11 +657,11 @@ def _setup_pruned_test_set( # filename_new = f"classes_v{self.scope_version_train}.txt" # dataset = torch.load(os.path.join(self.processed_dir, "test.pt")) - # Load original classes (from the current ChEBI version - scope_version) + # Load original classes (from the current SCOPe version - scope_version) with open(os.path.join(self.processed_dir_main, filename_old), "r") as file: orig_classes = file.readlines() - # Load new classes (from the training ChEBI version - scope_version_train) + # Load new classes (from the training SCOPe version - scope_version_train) with open( os.path.join( self._scope_version_train_obj.processed_dir_main, filename_old @@ -690,15 +692,25 @@ def _setup_pruned_test_set( return df_test_scope_version # ------------------------------ Phase: Raw Properties ----------------------------------- + @property + def scope_root_dir(self) -> str: + """ + Returns the root directory of scope data. + + Returns: + str: The path to the base directory, which is "data/GO_UniProt". + """ + return os.path.join("data", "SCOPe") + @property def base_dir(self) -> str: """ - Returns the base directory path for storing GO-Uniprot data. + Returns the base directory path for storing SCOPe data. Returns: str: The path to the base directory, which is "data/GO_UniProt". """ - return os.path.join("data", "SCOPe", f"version_{self.scope_version}") + return os.path.join(self.scope_root_dir, f"version_{self.scope_version}") @property def raw_file_names_dict(self) -> dict: @@ -707,7 +719,6 @@ def raw_file_names_dict(self) -> dict: Returns: dict: A dictionary mapping dataset names to their respective file names. - For example, {"GO": "go-basic.obo", "SwissUniProt": "uniprot_sprot.dat"}. """ return { "CLA": "cla.txt", @@ -717,13 +728,32 @@ def raw_file_names_dict(self) -> dict: } -class SCOPE(_SCOPeDataExtractor): +class _SCOPeOverX(_SCOPeDataExtractor, ABC): + """ + A class for extracting data from the SCOPe dataset with a threshold for selecting classes/labels based on + the number of subclasses. + + This class is designed to filter SCOPe classes/labels based on a specified threshold, selecting only those classes + which have a certain number of subclasses in the hierarchy. + + Attributes: + READER (dr.ProteinDataReader): The reader used for reading the dataset. + THRESHOLD (int): The threshold for selecting classes/labels based on the number of subclasses. + + """ + READER = ProteinDataReader - THRESHOLD = 2143 + THRESHOLD: int = None @property def _name(self) -> str: - return "test" + """ + Returns the name of the dataset. + + Returns: + str: The dataset name, formatted with the current threshold. + """ + return f"SCOPe{self.THRESHOLD}" def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> Dict: # Filter nodes and create a dictionary of node and out-degree @@ -745,6 +775,90 @@ def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> Dict: return sorted_dict +class _SCOPeOverXPartial(_SCOPeOverX, ABC): + """ + Dataset that doesn't use the full SCOPe dataset, but extracts a part of SCOPe (subclasses of a given top class) + + Attributes: + top_class_sunid (int): The Sun-ID of the top class from which to extract subclasses. + """ + + def __init__(self, top_class_sunid: int, **kwargs): + """ + Initializes the _SCOPeOverXPartial dataset. + + Args: + top_class_sunid (int): The Sun-ID of the top class from which to extract subclasses. + **kwargs: Additional keyword arguments passed to the superclass initializer. + """ + if "top_class_sunid" not in kwargs: + kwargs["top_class_sunid"] = top_class_sunid + + self.top_class_sunid: int = top_class_sunid + super().__init__(**kwargs) + + @property + def processed_dir_main(self) -> str: + """ + Returns the main processed data directory specific to the top class. + + Returns: + str: The processed data directory path. + """ + return os.path.join( + self.base_dir, + self._name, + f"partial_{self.top_class_sunid}", + "processed", + ) + + def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph: + """ + Extracts a subset of SCOPe based on subclasses of the top class ID. + + This method calls the superclass method to extract the full class hierarchy, + then extracts the subgraph containing only the descendants of the top class ID, including itself. + + Args: + data_path (str): The file path to the SCOPe ontology file. + + Returns: + nx.DiGraph: The extracted class hierarchy as a directed graph, limited to the + descendants of the top class ID. + """ + g = super()._extract_class_hierarchy(data_path) + g = g.subgraph( + list(g.successors(self.top_class_sunid)) + [self.top_class_sunid] + ) + return g + + +class SCOPeOver2000(_SCOPeOverX): + """ + A class for extracting data from the SCOPe dataset with a threshold of 2000 for selecting classes. + + Inherits from `_SCOPeOverX` and sets the threshold for selecting classes to 2000. + + Attributes: + THRESHOLD (int): The threshold for selecting classes (2000). + """ + + THRESHOLD: int = 2000 + + +class SCOPeOverPartial2000(_SCOPeOverXPartial): + """ + A class for extracting data from the SCOPe dataset with a threshold of 2000 for selecting classes. + + Inherits from `_SCOPeOverXPartial` and sets the threshold for selecting classes to 2000. + + Attributes: + THRESHOLD (int): The threshold for selecting classes (2000). + """ + + THRESHOLD: int = 2000 + + if __name__ == "__main__": scope = SCOPE(scope_version=2.08) g = scope._extract_class_hierarchy("d") From 764b812fcf110058042f5e047be9ac7ee9f8a972 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Fri, 24 Jan 2025 22:29:50 +0100 Subject: [PATCH 39/71] scope: modify select classes and labels save operation --- chebai/preprocessing/datasets/scope/scope.py | 75 +++++++++++--------- 1 file changed, 41 insertions(+), 34 deletions(-) diff --git a/chebai/preprocessing/datasets/scope/scope.py b/chebai/preprocessing/datasets/scope/scope.py index c0a790e6..7108170a 100644 --- a/chebai/preprocessing/datasets/scope/scope.py +++ b/chebai/preprocessing/datasets/scope/scope.py @@ -13,9 +13,9 @@ import gzip import os import shutil -from abc import ABC +from abc import ABC, abstractmethod from tempfile import NamedTemporaryFile -from typing import Any, Dict, Generator, Optional, Tuple +from typing import Any, Dict, Generator, List, Optional, Tuple import networkx as nx import pandas as pd @@ -350,21 +350,7 @@ def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame: """ print(f"Process graph") - sids = nx.get_node_attributes(graph, "sid") - levels = nx.get_node_attributes(graph, "level") - - sun_ids = {} - sids_list = [] - - selected_sids_dict = self.select_classes(graph) - - for sun_id, level in levels.items(): - if sun_id in selected_sids_dict: - sun_ids.setdefault(level, []).append(sun_id) - sids_list.append(sids.get(sun_id)) - - # Remove root node, as it will True for all instances - sun_ids.pop("root", None) + sun_ids = self.select_classes(graph) if not sun_ids: raise RuntimeError("No sunid selected.") @@ -440,6 +426,10 @@ def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame: sequence_hierarchy_df = sequence_hierarchy_df[ ["id", "sids", "sequence"] + encoded_target_columns ] + + with open(os.path.join(self.processed_dir_main, "classes.txt"), "wt") as fout: + fout.writelines(str(sun_id) + "\n" for sun_id in encoded_target_columns) + return sequence_hierarchy_df def _parse_pdb_sequence_file(self) -> Dict[str, Dict[str, str]]: @@ -498,6 +488,11 @@ def _update_or_add_sequence( new_row["sids"] = [row["sid"]] sequence_hierarchy_df.loc[sequence] = new_row + @abstractmethod + def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> Dict[str, List[int]]: + # Override the return type of the method from superclass + pass + # ------------------------------ Phase: Setup data ----------------------------------- def setup_processed(self) -> None: """ @@ -755,24 +750,36 @@ def _name(self) -> str: """ return f"SCOPe{self.THRESHOLD}" - def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> Dict: - # Filter nodes and create a dictionary of node and out-degree - sun_ids_dict = { - node: g.out_degree(node) # Store node and its out-degree - for node in g.nodes - if g.out_degree(node) >= self.THRESHOLD - } + def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> Dict[str, List[int]]: + """ + Selects classes from the SCOPe dataset based on the number of successors meeting a specified threshold. - # Return a sorted dictionary (by out-degree or node id) - sorted_dict = dict( - sorted(sun_ids_dict.items(), key=lambda item: item[0], reverse=False) - ) + This method iterates over the nodes in the graph, counting the number of successors for each node. + Nodes with a number of successors greater than or equal to the defined threshold are selected. + + Note: + The input graph must be transitive closure of a directed acyclic graph. - filename = "classes.txt" - with open(os.path.join(self.processed_dir_main, filename), "wt") as fout: - fout.writelines(str(sun_id) + "\n" for sun_id in sorted_dict.keys()) + Args: + g (nx.Graph): The graph representing the dataset. + *args: Additional positional arguments (not used). + **kwargs: Additional keyword arguments (not used). + + Returns: + Dict: A dict containing selected nodes at each hierarchy level. - return sorted_dict + Notes: + - The `THRESHOLD` attribute should be defined in the subclass of this class. + """ + selected_sunids_for_level = {} + for node, attr_dict in g.nodes(data=True): + if g.out_degree(node) >= self.THRESHOLD: + selected_sunids_for_level.setdefault(attr_dict["level"], []).append( + node + ) + # Remove root node, as it will True for all instances + selected_sunids_for_level.pop("root", None) + return selected_sunids_for_level class _SCOPeOverXPartial(_SCOPeOverX, ABC): @@ -860,6 +867,6 @@ class SCOPeOverPartial2000(_SCOPeOverXPartial): if __name__ == "__main__": - scope = SCOPE(scope_version=2.08) - g = scope._extract_class_hierarchy("d") + scope = SCOPeOver2000(scope_version=2.08) + g = scope._extract_class_hierarchy("dummy/path") scope._graph_to_raw_dataset(g) From b23f1f6c552a0baf1846bedf97229eacf34b86cc Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Fri, 24 Jan 2025 22:30:19 +0100 Subject: [PATCH 40/71] scope: data config --- configs/data/scope/scope2000.yml | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 configs/data/scope/scope2000.yml diff --git a/configs/data/scope/scope2000.yml b/configs/data/scope/scope2000.yml new file mode 100644 index 00000000..92dbabde --- /dev/null +++ b/configs/data/scope/scope2000.yml @@ -0,0 +1,3 @@ +class_path: chebai.preprocessing.datasets.scope.scope.SCOPeOver2000 +init_args: + scope_version: 2.08 From 63705721533b13a538f8fafab1601ddd385b3b66 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Fri, 24 Jan 2025 22:33:20 +0100 Subject: [PATCH 41/71] deepgo: remove label_number from docstring --- chebai/preprocessing/datasets/deepGO/go_uniprot.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/chebai/preprocessing/datasets/deepGO/go_uniprot.py b/chebai/preprocessing/datasets/deepGO/go_uniprot.py index 3c957e6c..1b0eb2aa 100644 --- a/chebai/preprocessing/datasets/deepGO/go_uniprot.py +++ b/chebai/preprocessing/datasets/deepGO/go_uniprot.py @@ -616,9 +616,6 @@ class _GOUniProtOverX(_GOUniProtDataExtractor, ABC): Attributes: READER (dr.ProteinDataReader): The reader used for reading the dataset. THRESHOLD (int): The threshold for selecting classes based on the number of subclasses. - - Property: - label_number (int): The number of labels in the dataset. This property must be implemented by subclasses. """ READER: dr.ProteinDataReader = dr.ProteinDataReader From 191c979ab514c0b59e5555530e70664d13605cd5 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 26 Jan 2025 23:27:37 +0100 Subject: [PATCH 42/71] ffn: update for as per deepgo2 mlp architecture --- chebai/models/ffn.py | 120 +++++++++++++++++++++++++++++++++++++----- configs/model/ffn.yml | 2 - 2 files changed, 106 insertions(+), 16 deletions(-) diff --git a/chebai/models/ffn.py b/chebai/models/ffn.py index ca1f6f22..7ab3d5f9 100644 --- a/chebai/models/ffn.py +++ b/chebai/models/ffn.py @@ -1,28 +1,36 @@ -from typing import Dict, Any, Tuple +from typing import Any, Dict, List, Optional, Tuple -from chebai.models import ChebaiBaseNet import torch -from torch import Tensor +from torch import Tensor, nn + +from chebai.models import ChebaiBaseNet class FFN(ChebaiBaseNet): + # Reference: https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/models.py#L121-L139 NAME = "FFN" def __init__( self, - input_size: int = 1000, - num_hidden_layers: int = 3, - hidden_size: int = 128, + input_size: int, + hidden_layers: List[int] = [ + 1024, + ], **kwargs ): super().__init__(**kwargs) - self.layers = torch.nn.ModuleList() - self.layers.append(torch.nn.Linear(input_size, hidden_size)) - for _ in range(num_hidden_layers): - self.layers.append(torch.nn.Linear(hidden_size, hidden_size)) - self.layers.append(torch.nn.Linear(hidden_size, self.out_dim)) + layers = [] + current_layer_input_size = input_size + for hidden_dim in hidden_layers: + layers.append(MLPBlock(current_layer_input_size, hidden_dim)) + layers.append(Residual(MLPBlock(current_layer_input_size, hidden_dim))) + current_layer_input_size = hidden_dim + + layers.append(torch.nn.Linear(current_layer_input_size, self.out_dim)) + layers.append(nn.Sigmoid()) + self.model = nn.Sequential(*layers) def _get_prediction_and_labels(self, data, labels, model_output): d = model_output["logits"] @@ -56,6 +64,90 @@ def _process_for_loss( def forward(self, data, **kwargs): x = data["features"] - for layer in self.layers: - x = torch.relu(layer(x)) - return {"logits": x} + return {"logits": self.model(x)} + + +class Residual(nn.Module): + """ + A residual layer that adds the output of a function to its input. + + Args: + fn (nn.Module): The function to be applied to the input. + + References: + https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/base.py#L6-L35 + """ + + def __init__(self, fn): + """ + Initialize the Residual layer with a given function. + + Args: + fn (nn.Module): The function to be applied to the input. + """ + super().__init__() + self.fn = fn + + def forward(self, x): + """ + Forward pass of the Residual layer. + + Args: + x: Input tensor. + + Returns: + torch.Tensor: The input tensor added to the result of applying the function `fn` to it. + """ + return x + self.fn(x) + + +class MLPBlock(nn.Module): + """ + A basic Multi-Layer Perceptron (MLP) block with one fully connected layer. + + Args: + in_features (int): The number of input features. + output_size (int): The number of output features. + bias (boolean): Add bias to the linear layer + layer_norm (boolean): Apply layer normalization + dropout (float): The dropout value + activation (nn.Module): The activation function to be applied after each fully connected layer. + + References: + https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/base.py#L38-L73 + + Example: + ```python + # Create an MLP block with 2 hidden layers and ReLU activation + mlp_block = MLPBlock(input_size=64, output_size=10, activation=nn.ReLU()) + + # Apply the MLP block to an input tensor + input_tensor = torch.randn(32, 64) + output = mlp_block(input_tensor) + ``` + """ + + def __init__( + self, + in_features, + out_features, + bias=True, + layer_norm=True, + dropout=0.1, + activation=nn.ReLU, + ): + super().__init__() + self.linear = nn.Linear(in_features, out_features, bias) + self.activation = activation() + self.layer_norm: Optional[nn.LayerNorm] = ( + nn.LayerNorm(out_features) if layer_norm else None + ) + self.dropout: Optional[nn.Dropout] = nn.Dropout(dropout) if dropout else None + + def forward(self, x): + x = self.activation(self.linear(x)) + if self.layer_norm: + x = self.layer_norm(x) + if self.dropout: + x = self.dropout(x) + return x diff --git a/configs/model/ffn.yml b/configs/model/ffn.yml index 193c6f64..ba94a43e 100644 --- a/configs/model/ffn.yml +++ b/configs/model/ffn.yml @@ -2,6 +2,4 @@ class_path: chebai.models.ffn.FFN init_args: optimizer_kwargs: lr: 1e-3 - hidden_size: 128 - num_hidden_layers: 3 input_size: 2560 From b7ca0e54cdd31405150aa7e75dd3b00aeda4f1bc Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 27 Jan 2025 11:24:33 +0100 Subject: [PATCH 43/71] scope: map invalid amino acids to "X" --- chebai/preprocessing/datasets/scope/scope.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/chebai/preprocessing/datasets/scope/scope.py b/chebai/preprocessing/datasets/scope/scope.py index 7108170a..99840448 100644 --- a/chebai/preprocessing/datasets/scope/scope.py +++ b/chebai/preprocessing/datasets/scope/scope.py @@ -12,6 +12,7 @@ import gzip import os +import re import shutil from abc import ABC, abstractmethod from tempfile import NamedTemporaryFile @@ -441,14 +442,18 @@ def _parse_pdb_sequence_file(self) -> Dict[str, Dict[str, str]]: and values are dictionaries mapping chain IDs (lowercase) to their corresponding sequences. """ pdb_chain_seq_mapping: Dict[str, Dict[str, str]] = {} + valid_amino_acids = "".join(ProteinDataReader.AA_LETTER) + for record in SeqIO.parse( os.path.join(self.scope_root_dir, self.raw_file_names_dict["PDB"]), "fasta" ): pdb_id, chain = record.id.split("_") if str(record.seq): - pdb_chain_seq_mapping.setdefault(pdb_id.lower(), {})[chain.lower()] = ( - str(record.seq) - ) + sequence = re.sub(f"[^{valid_amino_acids}]", "X", str(record.seq)) + + pdb_chain_seq_mapping.setdefault(pdb_id.lower(), {})[ + chain.lower() + ] = sequence return pdb_chain_seq_mapping @staticmethod From ad24fa71ad357faf65b8aaf2dd27f7f8c74eee42 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Tue, 4 Feb 2025 09:40:56 +0100 Subject: [PATCH 44/71] fix MLPBlock hidden_dim --- chebai/models/ffn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chebai/models/ffn.py b/chebai/models/ffn.py index 7ab3d5f9..cd32086e 100644 --- a/chebai/models/ffn.py +++ b/chebai/models/ffn.py @@ -25,7 +25,7 @@ def __init__( current_layer_input_size = input_size for hidden_dim in hidden_layers: layers.append(MLPBlock(current_layer_input_size, hidden_dim)) - layers.append(Residual(MLPBlock(current_layer_input_size, hidden_dim))) + layers.append(Residual(MLPBlock(hidden_dim, hidden_dim))) current_layer_input_size = hidden_dim layers.append(torch.nn.Linear(current_layer_input_size, self.out_dim)) From 357752fa25ecd7e40221ea2e762d4d01bb3c9098 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Fri, 7 Feb 2025 13:20:41 +0100 Subject: [PATCH 45/71] esm2 reader: save reader to default global data dir --- chebai/preprocessing/reader.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/chebai/preprocessing/reader.py b/chebai/preprocessing/reader.py index 88e4fedd..7e943eb5 100644 --- a/chebai/preprocessing/reader.py +++ b/chebai/preprocessing/reader.py @@ -509,7 +509,7 @@ class ESM2EmbeddingReader(DataReader): def __init__( self, - save_model_dir: str, + save_model_dir: str = os.path.join("data", "esm2_reader"), model_name: str = "esm2_t36_3B_UR50D", device: Optional[torch.device] = None, truncation_length: int = 1022, @@ -617,7 +617,7 @@ def load_hub_workaround(self, url) -> torch.Tensor: return data @staticmethod - def name() -> None: + def name() -> str: """ Returns the name of the data reader. This method identifies the specific type of data reader. From a52d8dec09ba58a539df54cbf384c1d514963314 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Fri, 7 Feb 2025 13:24:28 +0100 Subject: [PATCH 46/71] configs: move configs to data specific sub-dir --- configs/data/{ => chebi}/chebi100.yml | 0 configs/data/{ => chebi}/chebi100_SELFIES.yml | 0 configs/data/{ => chebi}/chebi100_deepSMILES.yml | 0 configs/data/{ => chebi}/chebi100_mixed.yml | 0 configs/data/{ => chebi}/chebi50.yml | 0 configs/data/{ => chebi}/chebi50_mixed.yml | 0 configs/data/{ => chebi}/chebi50_partial.yml | 0 configs/data/{ => pubchem}/pubchem_SELFIES.yml | 0 configs/data/{ => pubchem}/pubchem_deepSMILES.yml | 0 configs/data/{ => pubchem}/pubchem_dissimilar.yml | 0 configs/data/{ => tox21}/tox21_moleculenet.yml | 0 11 files changed, 0 insertions(+), 0 deletions(-) rename configs/data/{ => chebi}/chebi100.yml (100%) rename configs/data/{ => chebi}/chebi100_SELFIES.yml (100%) rename configs/data/{ => chebi}/chebi100_deepSMILES.yml (100%) rename configs/data/{ => chebi}/chebi100_mixed.yml (100%) rename configs/data/{ => chebi}/chebi50.yml (100%) rename configs/data/{ => chebi}/chebi50_mixed.yml (100%) rename configs/data/{ => chebi}/chebi50_partial.yml (100%) rename configs/data/{ => pubchem}/pubchem_SELFIES.yml (100%) rename configs/data/{ => pubchem}/pubchem_deepSMILES.yml (100%) rename configs/data/{ => pubchem}/pubchem_dissimilar.yml (100%) rename configs/data/{ => tox21}/tox21_moleculenet.yml (100%) diff --git a/configs/data/chebi100.yml b/configs/data/chebi/chebi100.yml similarity index 100% rename from configs/data/chebi100.yml rename to configs/data/chebi/chebi100.yml diff --git a/configs/data/chebi100_SELFIES.yml b/configs/data/chebi/chebi100_SELFIES.yml similarity index 100% rename from configs/data/chebi100_SELFIES.yml rename to configs/data/chebi/chebi100_SELFIES.yml diff --git a/configs/data/chebi100_deepSMILES.yml b/configs/data/chebi/chebi100_deepSMILES.yml similarity index 100% rename from configs/data/chebi100_deepSMILES.yml rename to configs/data/chebi/chebi100_deepSMILES.yml diff --git a/configs/data/chebi100_mixed.yml b/configs/data/chebi/chebi100_mixed.yml similarity index 100% rename from configs/data/chebi100_mixed.yml rename to configs/data/chebi/chebi100_mixed.yml diff --git a/configs/data/chebi50.yml b/configs/data/chebi/chebi50.yml similarity index 100% rename from configs/data/chebi50.yml rename to configs/data/chebi/chebi50.yml diff --git a/configs/data/chebi50_mixed.yml b/configs/data/chebi/chebi50_mixed.yml similarity index 100% rename from configs/data/chebi50_mixed.yml rename to configs/data/chebi/chebi50_mixed.yml diff --git a/configs/data/chebi50_partial.yml b/configs/data/chebi/chebi50_partial.yml similarity index 100% rename from configs/data/chebi50_partial.yml rename to configs/data/chebi/chebi50_partial.yml diff --git a/configs/data/pubchem_SELFIES.yml b/configs/data/pubchem/pubchem_SELFIES.yml similarity index 100% rename from configs/data/pubchem_SELFIES.yml rename to configs/data/pubchem/pubchem_SELFIES.yml diff --git a/configs/data/pubchem_deepSMILES.yml b/configs/data/pubchem/pubchem_deepSMILES.yml similarity index 100% rename from configs/data/pubchem_deepSMILES.yml rename to configs/data/pubchem/pubchem_deepSMILES.yml diff --git a/configs/data/pubchem_dissimilar.yml b/configs/data/pubchem/pubchem_dissimilar.yml similarity index 100% rename from configs/data/pubchem_dissimilar.yml rename to configs/data/pubchem/pubchem_dissimilar.yml diff --git a/configs/data/tox21_moleculenet.yml b/configs/data/tox21/tox21_moleculenet.yml similarity index 100% rename from configs/data/tox21_moleculenet.yml rename to configs/data/tox21/tox21_moleculenet.yml From b2f51f9d4de0ff1525589ca7e186c42bde12cd58 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Mon, 10 Feb 2025 09:43:43 +0100 Subject: [PATCH 47/71] adding SCOPe50 dataset --- chebai/preprocessing/datasets/scope/scope.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/chebai/preprocessing/datasets/scope/scope.py b/chebai/preprocessing/datasets/scope/scope.py index 99840448..36bada01 100644 --- a/chebai/preprocessing/datasets/scope/scope.py +++ b/chebai/preprocessing/datasets/scope/scope.py @@ -858,6 +858,11 @@ class SCOPeOver2000(_SCOPeOverX): THRESHOLD: int = 2000 +class SCOPeOver50(_SCOPeOverX): + + THRESHOLD = 50 + + class SCOPeOverPartial2000(_SCOPeOverXPartial): """ A class for extracting data from the SCOPe dataset with a threshold of 2000 for selecting classes. From e0bbb0e7ec3b7ab2d02e6329cf1e6845cd23fbe2 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Mon, 10 Feb 2025 09:49:48 +0100 Subject: [PATCH 48/71] add scope50 config --- configs/data/scope/scope50.yml | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 configs/data/scope/scope50.yml diff --git a/configs/data/scope/scope50.yml b/configs/data/scope/scope50.yml new file mode 100644 index 00000000..8e64ccf5 --- /dev/null +++ b/configs/data/scope/scope50.yml @@ -0,0 +1,3 @@ +class_path: chebai.preprocessing.datasets.scope.scope.SCOPeOver50 +init_args: + scope_version: 2.08 From 22aa985ddee7f18940fca8a2b35c1aa42ce06d52 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 15 Feb 2025 12:17:06 +0100 Subject: [PATCH 49/71] scope: version number should str not float - as we are not doing any processing for number so its compatible to pass it as str --- chebai/preprocessing/datasets/scope/scope.py | 16 ++++++++-------- configs/data/scope/scope2000.yml | 2 +- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/chebai/preprocessing/datasets/scope/scope.py b/chebai/preprocessing/datasets/scope/scope.py index 99840448..a62fabb7 100644 --- a/chebai/preprocessing/datasets/scope/scope.py +++ b/chebai/preprocessing/datasets/scope/scope.py @@ -37,8 +37,8 @@ class _SCOPeDataExtractor(_DynamicDataset, ABC): Additionally, it supports downloading related data like PDB sequence files. Args: - scope_version (float): The SCOPe version to use. - scope_version_train (Optional[float]): The training SCOPe version, if different. + scope_version (str): The SCOPe version to use. + scope_version_train (Optional[str]): The training SCOPe version, if different. dynamic_data_split_seed (int, optional): The seed for random data splitting. Defaults to 42. splits_file_path (str, optional): Path to the splits CSV file. Defaults to None. **kwargs: Additional keyword arguments passed to DynamicDataset and XYBaseDataModule. @@ -70,12 +70,12 @@ class _SCOPeDataExtractor(_DynamicDataset, ABC): def __init__( self, - scope_version: float, - scope_version_train: Optional[float] = None, + scope_version: str, + scope_version_train: Optional[str] = None, **kwargs, ): - self.scope_version: float = scope_version - self.scope_version_train: float = scope_version_train + self.scope_version: str = scope_version + self.scope_version_train: str = scope_version_train super(_SCOPeDataExtractor, self).__init__(**kwargs) @@ -89,7 +89,7 @@ def __init__( ) @staticmethod - def _get_scope_url(data_type: str, version_number: float) -> str: + def _get_scope_url(data_type: str, version_number: str) -> str: """ Generates the URL for downloading SCOPe files. @@ -872,6 +872,6 @@ class SCOPeOverPartial2000(_SCOPeOverXPartial): if __name__ == "__main__": - scope = SCOPeOver2000(scope_version=2.08) + scope = SCOPeOver2000(scope_version="2.08") g = scope._extract_class_hierarchy("dummy/path") scope._graph_to_raw_dataset(g) diff --git a/configs/data/scope/scope2000.yml b/configs/data/scope/scope2000.yml index 92dbabde..d75c807f 100644 --- a/configs/data/scope/scope2000.yml +++ b/configs/data/scope/scope2000.yml @@ -1,3 +1,3 @@ class_path: chebai.preprocessing.datasets.scope.scope.SCOPeOver2000 init_args: - scope_version: 2.08 + scope_version: "2.08" From d3fd0f28829e832a19c887a76cfd100a4690537b Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 16 Feb 2025 00:47:57 +0100 Subject: [PATCH 50/71] scope: data filtering update - consider proteins domain in the dataset which maps to any selected node irrespective of the hierarchy level --- chebai/preprocessing/datasets/scope/scope.py | 113 +++++++++++-------- 1 file changed, 65 insertions(+), 48 deletions(-) diff --git a/chebai/preprocessing/datasets/scope/scope.py b/chebai/preprocessing/datasets/scope/scope.py index b949fa0d..36db85b9 100644 --- a/chebai/preprocessing/datasets/scope/scope.py +++ b/chebai/preprocessing/datasets/scope/scope.py @@ -351,9 +351,9 @@ def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame: """ print(f"Process graph") - sun_ids = self.select_classes(graph) + selected_sun_ids_per_lvl = self.select_classes(graph) - if not sun_ids: + if not selected_sun_ids_per_lvl: raise RuntimeError("No sunid selected.") df_cla = self._get_classification_data() @@ -362,38 +362,35 @@ def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame: df_cla = df_cla[["sid", "sunid"] + hierarchy_levels] - # This filtering make sures to consider only domains that belongs to each `selected` hierarchy level - # So, that our data has domains that maps to all levels of the taxonomy - for level, selected_sun_ids in sun_ids.items(): - if selected_sun_ids: - df_cla = df_cla[ - df_cla[self.SCOPE_HIERARCHY[level]].isin(selected_sun_ids) - ] - - assert ( - len(df_cla) > 1 - ), "dataframe should have more than one instance for `pd.get_dummies` to work as expected" - - df_encoded = pd.get_dummies( - df_cla, - columns=hierarchy_levels, - drop_first=False, - sparse=True, - ) + # Initialize selected target columns + df_encoded = df_cla[["sid", "sunid"]].copy() + + lvl_to_target_cols_mapping = {} + # Iterate over only the selected sun_ids (nodes) to one-hot encode them + for level, selected_sun_ids in selected_sun_ids_per_lvl.items(): + level_column = self.SCOPE_HIERARCHY[ + level + ] # Get the actual column name in df_cla + if level_column in df_cla.columns: + # Create binary encoding for only relevant sun_ids + for sun_id in selected_sun_ids: + col_name = f"{level_column}_{sun_id}" + df_encoded[col_name] = (df_cla[level_column] == sun_id).astype(bool) + lvl_to_target_cols_mapping.setdefault(level_column, []).append( + col_name + ) - pdb_chain_seq_mapping = self._parse_pdb_sequence_file() + # Filter to select only domains that atleast map to any one selected sunid in any level + df_encoded = df_encoded[df_encoded.iloc[:, 2:].any(axis=1)] - encoded_target_cols = {} - for col in hierarchy_levels: - encoded_target_cols[col] = [ - t_col for t_col in df_encoded.columns if t_col.startswith(col) - ] + pdb_chain_seq_mapping = self._parse_pdb_sequence_file() encoded_target_columns = [] for level in hierarchy_levels: - encoded_target_columns.extend(encoded_target_cols[level]) + encoded_target_columns.extend(lvl_to_target_cols_mapping[level]) sequence_hierarchy_df = pd.DataFrame(columns=["sids"] + encoded_target_columns) + df_encoded = df_encoded[["sid", "sunid"] + encoded_target_columns] for _, row in df_encoded.iterrows(): sid = row["sid"] @@ -410,14 +407,19 @@ def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame: chain_sequence = pdb_to_chain_mapping.get(chain_id, None) if chain_sequence: self._update_or_add_sequence( - chain_sequence, row, sequence_hierarchy_df, encoded_target_cols + chain_sequence, + row, + sequence_hierarchy_df, + encoded_target_columns, ) - else: # Add nodes and edges for chains in the mapping for chain, chain_sequence in pdb_to_chain_mapping.items(): self._update_or_add_sequence( - chain_sequence, row, sequence_hierarchy_df, encoded_target_cols + chain_sequence, + row, + sequence_hierarchy_df, + encoded_target_columns, ) sequence_hierarchy_df.reset_index(inplace=True) @@ -427,6 +429,10 @@ def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame: sequence_hierarchy_df = sequence_hierarchy_df[ ["id", "sids", "sequence"] + encoded_target_columns ] + # Ensure atleast one label is true for each protein sequence + sequence_hierarchy_df = sequence_hierarchy_df[ + sequence_hierarchy_df.iloc[:, self._LABELS_START_IDX :].any(axis=1) + ] with open(os.path.join(self.processed_dir_main, "classes.txt"), "wt") as fout: fout.writelines(str(sun_id) + "\n" for sun_id in encoded_target_columns) @@ -458,7 +464,10 @@ def _parse_pdb_sequence_file(self) -> Dict[str, Dict[str, str]]: @staticmethod def _update_or_add_sequence( - sequence, row, sequence_hierarchy_df, encoded_col_names + sequence: str, + row: pd.Series, + sequence_hierarchy_df: pd.DataFrame, + encoded_target_columns: List[str], ): """ Updates an existing sequence entry or adds a new one to the DataFrame. @@ -467,29 +476,25 @@ def _update_or_add_sequence( sequence (str): Amino acid sequence of the chain. row (pd.Series): Row data containing SCOPe hierarchy levels and associated values. sequence_hierarchy_df (pd.DataFrame): DataFrame storing sequences and their hierarchy labels. - encoded_col_names (Dict[str, List[str]]): Mapping of hierarchy levels to encoded column names. + encoded_target_columns (List): List of column names which must be in same order in row and sequence_hierarchy_df. Raises: AssertionError: If a sequence instance belongs to more than one hierarchy level. """ if sequence in sequence_hierarchy_df.index: - # Update encoded columns only if they are True - for col in encoded_col_names: - assert ( - sum(row[encoded_col_names[col]].tolist()) == 1 - ), "A instance can belong to only one hierarchy level" - sliced_data = row[ - encoded_col_names[col] - ] # Slice starting from the second column (index 1) - # Get the column name with the True value - true_column = sliced_data.idxmax() if sliced_data.any() else None - sequence_hierarchy_df.loc[sequence, true_column] = True - - sequence_hierarchy_df.loc[sequence, "sids"].append(row["sid"]) + # Update encoded columns using bitwise OR (ensures values remain True if they were previously True) + sequence_hierarchy_df.loc[sequence, encoded_target_columns] = ( + row[encoded_target_columns] + | sequence_hierarchy_df.loc[sequence, encoded_target_columns] + ) + + sequence_hierarchy_df.at[sequence, "sids"] = sequence_hierarchy_df.at[ + sequence, "sids" + ] + [row["sid"]] else: # Add new row with sequence as the index and hierarchy data - new_row = row + new_row = row.to_dict() new_row["sids"] = [row["sid"]] sequence_hierarchy_df.loc[sequence] = new_row @@ -859,7 +864,7 @@ class SCOPeOver2000(_SCOPeOverX): class SCOPeOver50(_SCOPeOverX): - + THRESHOLD = 50 @@ -878,5 +883,17 @@ class SCOPeOverPartial2000(_SCOPeOverXPartial): if __name__ == "__main__": scope = SCOPeOver2000(scope_version="2.08") - g = scope._extract_class_hierarchy("dummy/path") + # g = scope._extract_class_hierarchy("dummy/path") + # # Save graph + # import pickle + # with open("graph.gpickle", "wb") as f: + # pickle.dump(g, f) + + # Load graph + import pickle + + with open("graph.gpickle", "rb") as f: + g = pickle.load(f) + + # print(len([node for node in g.nodes() if g.out_degree(node) > 10000])) scope._graph_to_raw_dataset(g) From c7918931235463233e003ba4a6a1253d895f507c Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 16 Feb 2025 13:35:39 +0100 Subject: [PATCH 51/71] scope: avoid data fragmentation and add progress bar --- chebai/preprocessing/datasets/scope/scope.py | 22 ++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/chebai/preprocessing/datasets/scope/scope.py b/chebai/preprocessing/datasets/scope/scope.py index 36db85b9..5b0fdd2d 100644 --- a/chebai/preprocessing/datasets/scope/scope.py +++ b/chebai/preprocessing/datasets/scope/scope.py @@ -23,6 +23,7 @@ import requests import torch from Bio import SeqIO +from tqdm import tqdm from chebai.preprocessing.datasets.base import _DynamicDataset from chebai.preprocessing.reader import ProteinDataReader @@ -365,6 +366,9 @@ def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame: # Initialize selected target columns df_encoded = df_cla[["sid", "sunid"]].copy() + # Collect all new columns in a dictionary first (avoids fragmentation) + encoded_df_columns = {} + lvl_to_target_cols_mapping = {} # Iterate over only the selected sun_ids (nodes) to one-hot encode them for level, selected_sun_ids in selected_sun_ids_per_lvl.items(): @@ -375,11 +379,17 @@ def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame: # Create binary encoding for only relevant sun_ids for sun_id in selected_sun_ids: col_name = f"{level_column}_{sun_id}" - df_encoded[col_name] = (df_cla[level_column] == sun_id).astype(bool) + encoded_df_columns[col_name] = ( + df_cla[level_column] == sun_id + ).astype(bool) + lvl_to_target_cols_mapping.setdefault(level_column, []).append( col_name ) + # Convert the dictionary into a DataFrame and concatenate at once (prevents fragmentation) + df_encoded = pd.concat([df_encoded, pd.DataFrame(encoded_df_columns)], axis=1) + # Filter to select only domains that atleast map to any one selected sunid in any level df_encoded = df_encoded[df_encoded.iloc[:, 2:].any(axis=1)] @@ -392,7 +402,15 @@ def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame: sequence_hierarchy_df = pd.DataFrame(columns=["sids"] + encoded_target_columns) df_encoded = df_encoded[["sid", "sunid"] + encoded_target_columns] - for _, row in df_encoded.iterrows(): + print( + f"{len(encoded_target_columns)} labels has been selected for specified threshold, " + f"Max possible size of dataset is {len(df_encoded)} rows x {len(encoded_target_columns) + 1} columns" + ) + print("Constructing data.pkl file .....") + + for _, row in tqdm( + df_encoded.iterrows(), total=len(df_encoded), desc="Processing Rows" + ): sid = row["sid"] # SID: 7-char identifier ("d" + 4-char PDB ID + chain ID ('_' for none, '.' for multiple) # + domain specifier ('_' if not needed)) From aad16d9225651b51f44d982bde44764a07211090 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 16 Feb 2025 20:25:55 +0100 Subject: [PATCH 52/71] scope: vectorized operation instead of df.itterows - https://stackoverflow.com/a/24871316/17626445 --- chebai/preprocessing/datasets/scope/scope.py | 168 ++++++++----------- 1 file changed, 74 insertions(+), 94 deletions(-) diff --git a/chebai/preprocessing/datasets/scope/scope.py b/chebai/preprocessing/datasets/scope/scope.py index 5b0fdd2d..96258bd2 100644 --- a/chebai/preprocessing/datasets/scope/scope.py +++ b/chebai/preprocessing/datasets/scope/scope.py @@ -23,7 +23,6 @@ import requests import torch from Bio import SeqIO -from tqdm import tqdm from chebai.preprocessing.datasets.base import _DynamicDataset from chebai.preprocessing.reader import ProteinDataReader @@ -372,9 +371,7 @@ def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame: lvl_to_target_cols_mapping = {} # Iterate over only the selected sun_ids (nodes) to one-hot encode them for level, selected_sun_ids in selected_sun_ids_per_lvl.items(): - level_column = self.SCOPE_HIERARCHY[ - level - ] # Get the actual column name in df_cla + level_column = self.SCOPE_HIERARCHY[level] if level_column in df_cla.columns: # Create binary encoding for only relevant sun_ids for sun_id in selected_sun_ids: @@ -390,63 +387,71 @@ def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame: # Convert the dictionary into a DataFrame and concatenate at once (prevents fragmentation) df_encoded = pd.concat([df_encoded, pd.DataFrame(encoded_df_columns)], axis=1) - # Filter to select only domains that atleast map to any one selected sunid in any level - df_encoded = df_encoded[df_encoded.iloc[:, 2:].any(axis=1)] - - pdb_chain_seq_mapping = self._parse_pdb_sequence_file() - encoded_target_columns = [] for level in hierarchy_levels: encoded_target_columns.extend(lvl_to_target_cols_mapping[level]) - sequence_hierarchy_df = pd.DataFrame(columns=["sids"] + encoded_target_columns) - df_encoded = df_encoded[["sid", "sunid"] + encoded_target_columns] - print( f"{len(encoded_target_columns)} labels has been selected for specified threshold, " - f"Max possible size of dataset is {len(df_encoded)} rows x {len(encoded_target_columns) + 1} columns" ) print("Constructing data.pkl file .....") - for _, row in tqdm( - df_encoded.iterrows(), total=len(df_encoded), desc="Processing Rows" - ): - sid = row["sid"] - # SID: 7-char identifier ("d" + 4-char PDB ID + chain ID ('_' for none, '.' for multiple) - # + domain specifier ('_' if not needed)) - assert len(sid) == 7, "sid should have 7 characters" - pdb_id, chain_id = sid[1:5], sid[5] - - pdb_to_chain_mapping = pdb_chain_seq_mapping.get(pdb_id, None) - if not pdb_to_chain_mapping: - continue - - if chain_id != "_": - chain_sequence = pdb_to_chain_mapping.get(chain_id, None) - if chain_sequence: - self._update_or_add_sequence( - chain_sequence, - row, - sequence_hierarchy_df, - encoded_target_columns, - ) - else: - # Add nodes and edges for chains in the mapping - for chain, chain_sequence in pdb_to_chain_mapping.items(): - self._update_or_add_sequence( - chain_sequence, - row, - sequence_hierarchy_df, - encoded_target_columns, - ) + df_encoded = df_encoded[["sid", "sunid"] + encoded_target_columns] - sequence_hierarchy_df.reset_index(inplace=True) - sequence_hierarchy_df.rename(columns={"index": "sequence"}, inplace=True) - sequence_hierarchy_df["id"] = range(1, len(sequence_hierarchy_df) + 1) + # Filter to select only domains that atleast map to any one selected sunid in any level + df_encoded = df_encoded[df_encoded.iloc[:, 2:].any(axis=1)] + + df_encoded["pdb_id"] = df_encoded["sid"].str[1:5] + df_encoded["chain_id"] = df_encoded["sid"].str[5] + + pdb_chain_df = self._parse_pdb_sequence_file() + + # Handle `chain_id == "_"` Case** + # Split df_encoded into two: One for specific chains, one for "all chains" ("_") + df_specific_chains = df_encoded[df_encoded["chain_id"] != "_"] + df_all_chains = df_encoded[df_encoded["chain_id"] == "_"].drop( + columns=["chain_id"] + ) + + common_pdb_ids = set(df_specific_chains["pdb_id"]) & set( + df_all_chains["pdb_id"] + ) + if common_pdb_ids: + raise RuntimeError( + f"{len(common_pdb_ids)} PDB chain IDs found in specific-chains df and all-chains df" + ) + + # Merge specific chains normally + merged_specific = df_specific_chains.merge( + pdb_chain_df, on=["pdb_id", "chain_id"], how="left" + ) + + # Merge all chains case -> Join by pdb_id (not chain_id) + merged_all_chains = df_all_chains.merge(pdb_chain_df, on="pdb_id", how="left") + + # Combine both cases + sequence_hierarchy_df = pd.concat( + [merged_specific, merged_all_chains], ignore_index=True + ).dropna(subset=["sequence"]) + + # Vectorized Aggregation Instead of Row-wise Updates + sequence_hierarchy_df = ( + sequence_hierarchy_df.groupby("sequence", as_index=False) + .agg( + { + "sid": list, # Collect all SIDs per sequence + **{ + col: "max" for col in encoded_target_columns + }, # Max works as Bitwise OR for labels + } + ) + .rename(columns={"sid": "sids"}) + ) # Rename for clarity + + sequence_hierarchy_df = sequence_hierarchy_df.assign( + id=range(1, len(sequence_hierarchy_df) + 1) + )[["id", "sids", "sequence"] + encoded_target_columns] - sequence_hierarchy_df = sequence_hierarchy_df[ - ["id", "sids", "sequence"] + encoded_target_columns - ] # Ensure atleast one label is true for each protein sequence sequence_hierarchy_df = sequence_hierarchy_df[ sequence_hierarchy_df.iloc[:, self._LABELS_START_IDX :].any(axis=1) @@ -457,64 +462,39 @@ def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame: return sequence_hierarchy_df - def _parse_pdb_sequence_file(self) -> Dict[str, Dict[str, str]]: + def _parse_pdb_sequence_file(self) -> pd.DataFrame: """ - Parses the PDB sequence file to create a mapping of PDB IDs and chain sequences. + Parses the PDB sequence file and returns a DataFrame containing PDB IDs, chain IDs, and sequences. Returns: - Dict[str, Dict[str, str]]: A nested dictionary where keys are PDB IDs (lowercase), - and values are dictionaries mapping chain IDs (lowercase) to their corresponding sequences. + pd.DataFrame: A DataFrame with columns ["pdb_id", "chain_id", "sequence"]. """ - pdb_chain_seq_mapping: Dict[str, Dict[str, str]] = {} + records = [] valid_amino_acids = "".join(ProteinDataReader.AA_LETTER) for record in SeqIO.parse( os.path.join(self.scope_root_dir, self.raw_file_names_dict["PDB"]), "fasta" ): pdb_id, chain = record.id.split("_") - if str(record.seq): - sequence = re.sub(f"[^{valid_amino_acids}]", "X", str(record.seq)) - - pdb_chain_seq_mapping.setdefault(pdb_id.lower(), {})[ - chain.lower() - ] = sequence - return pdb_chain_seq_mapping - - @staticmethod - def _update_or_add_sequence( - sequence: str, - row: pd.Series, - sequence_hierarchy_df: pd.DataFrame, - encoded_target_columns: List[str], - ): - """ - Updates an existing sequence entry or adds a new one to the DataFrame. - - Args: - sequence (str): Amino acid sequence of the chain. - row (pd.Series): Row data containing SCOPe hierarchy levels and associated values. - sequence_hierarchy_df (pd.DataFrame): DataFrame storing sequences and their hierarchy labels. - encoded_target_columns (List): List of column names which must be in same order in row and sequence_hierarchy_df. + sequence = ( + re.sub(f"[^{valid_amino_acids}]", "X", str(record.seq)) + if record.seq + else "" + ) - Raises: - AssertionError: If a sequence instance belongs to more than one hierarchy level. - """ - if sequence in sequence_hierarchy_df.index: - # Update encoded columns using bitwise OR (ensures values remain True if they were previously True) - sequence_hierarchy_df.loc[sequence, encoded_target_columns] = ( - row[encoded_target_columns] - | sequence_hierarchy_df.loc[sequence, encoded_target_columns] + # Store as a dictionary entry (list of dicts -> DataFrame later) + records.append( + { + "pdb_id": pdb_id.lower(), + "chain_id": chain.lower(), + "sequence": sequence, + } ) - sequence_hierarchy_df.at[sequence, "sids"] = sequence_hierarchy_df.at[ - sequence, "sids" - ] + [row["sid"]] + # Convert list of dictionaries to a DataFrame + pdb_chain_df = pd.DataFrame.from_records(records) - else: - # Add new row with sequence as the index and hierarchy data - new_row = row.to_dict() - new_row["sids"] = [row["sid"]] - sequence_hierarchy_df.loc[sequence] = new_row + return pdb_chain_df @abstractmethod def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> Dict[str, List[int]]: From 13b8795ff4452d145771722dc7e9e09023337b36 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Fri, 21 Feb 2025 14:59:57 +0100 Subject: [PATCH 53/71] scope: fix multiple chain filtering --- chebai/preprocessing/datasets/scope/scope.py | 24 ++++++++------------ 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/chebai/preprocessing/datasets/scope/scope.py b/chebai/preprocessing/datasets/scope/scope.py index 96258bd2..6dff3156 100644 --- a/chebai/preprocessing/datasets/scope/scope.py +++ b/chebai/preprocessing/datasets/scope/scope.py @@ -347,7 +347,6 @@ def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame: Raises: RuntimeError: If no sunids are selected. - AssertionError: If the input data is insufficient for encoding or validation fails. """ print(f"Process graph") @@ -404,30 +403,27 @@ def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame: df_encoded["pdb_id"] = df_encoded["sid"].str[1:5] df_encoded["chain_id"] = df_encoded["sid"].str[5] + # "_" (underscore) means it has no chain + df_encoded = df_encoded[df_encoded["chain_id"] != "_"] + pdb_chain_df = self._parse_pdb_sequence_file() - # Handle `chain_id == "_"` Case** - # Split df_encoded into two: One for specific chains, one for "all chains" ("_") - df_specific_chains = df_encoded[df_encoded["chain_id"] != "_"] - df_all_chains = df_encoded[df_encoded["chain_id"] == "_"].drop( + # Handle chain_id == "." - Multiple chain case + # Split df_encoded into two: One for specific chains, one for "multiple chains" (".") + df_specific_chains = df_encoded[df_encoded["chain_id"] != "."] + df_multiple_chains = df_encoded[df_encoded["chain_id"] == "."].drop( columns=["chain_id"] ) - common_pdb_ids = set(df_specific_chains["pdb_id"]) & set( - df_all_chains["pdb_id"] - ) - if common_pdb_ids: - raise RuntimeError( - f"{len(common_pdb_ids)} PDB chain IDs found in specific-chains df and all-chains df" - ) - # Merge specific chains normally merged_specific = df_specific_chains.merge( pdb_chain_df, on=["pdb_id", "chain_id"], how="left" ) # Merge all chains case -> Join by pdb_id (not chain_id) - merged_all_chains = df_all_chains.merge(pdb_chain_df, on="pdb_id", how="left") + merged_all_chains = df_multiple_chains.merge( + pdb_chain_df, on="pdb_id", how="left" + ) # Combine both cases sequence_hierarchy_df = pd.concat( From 4572272c4d93002ae3655553b3358139ec5873d7 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Fri, 21 Feb 2025 16:32:08 +0100 Subject: [PATCH 54/71] scope: tutorial for scope data exploration --- tutorials/data_exploration_scope.ipynb | 1021 ++++++++++++++++++++++++ 1 file changed, 1021 insertions(+) create mode 100644 tutorials/data_exploration_scope.ipynb diff --git a/tutorials/data_exploration_scope.ipynb b/tutorials/data_exploration_scope.ipynb new file mode 100644 index 00000000..f8d053e5 --- /dev/null +++ b/tutorials/data_exploration_scope.ipynb @@ -0,0 +1,1021 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0bd757ea-a6a0-43f8-8701-cafb44f20f6b", + "metadata": {}, + "source": [ + "# Introduction\n", + "\n", + "This notebook serves as a guide for new developers using the `chebai` package. If you just want to run the experiments, you can refer to the [README.md](https://github.com/ChEB-AI/python-chebai/blob/dev/README.md) and the [wiki](https://github.com/ChEB-AI/python-chebai/wiki) for the basic commands. This notebook explains what happens under the hood for the SCOPe dataset. It covers\n", + "- how to instantiate a data class and generate data\n", + "- how the data is processed and stored\n", + "- and how to work with different molecule encodings.\n", + "\n", + "The chebai package simplifies the handling of these datasets by **automatically creating** them as needed. This means that you do not have to input any data manually; the package will generate and organize the data files based on the parameters and encodings selected. This feature ensures that the right data is available and formatted properly. You can however provide your own data files, for instance if you want to replicate a specific experiment.\n", + "\n", + "---\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "990cc6f2-6b4a-4fa7-905f-dda183c3ec4c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Changed to project root directory: G:\\github-aditya0by0\\python-chebai\n" + ] + } + ], + "source": [ + "# To run this notebook, you need to change the working directory of the jupyter notebook to root dir of the project.\n", + "import os\n", + "\n", + "# Root directory name of the project\n", + "expected_root_dir = \"python-chebai\"\n", + "\n", + "# Check if the current directory ends with the expected root directory name\n", + "if not os.getcwd().endswith(expected_root_dir):\n", + " os.chdir(\"..\") # Move up one directory level\n", + " if os.getcwd().endswith(expected_root_dir):\n", + " print(\"Changed to project root directory:\", os.getcwd())\n", + " else:\n", + " print(\"Warning: Directory change unsuccessful. Current directory:\", os.getcwd())\n", + "else:\n", + " print(\"Already in the project root directory:\", os.getcwd())" + ] + }, + { + "cell_type": "markdown", + "id": "4550d01fc7af5ae4", + "metadata": {}, + "source": [ + "# 1. Instantiation of a Data Class\n", + "\n", + "To start working with `chebai`, you first need to instantiate a SCOPe data class. This class is responsible for managing, interacting with, and preprocessing the ChEBI chemical data." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "f3a66e07-edc9-4aa2-9cd0-d4ea58914d22", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "G:\\anaconda3\\envs\\env_chebai\\lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "from chebai.preprocessing.datasets.scope.scope import SCOPeOver2000" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "a71b7301-6195-4155-a439-f5eb3183d0f3", + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-05T21:07:26.371796Z", + "start_time": "2024-10-05T21:07:26.058728Z" + } + }, + "outputs": [], + "source": [ + "scope_class = SCOPeOver2000(scope_version=\"2.08\")" + ] + }, + { + "cell_type": "markdown", + "id": "b810d7c9-4f7f-4725-9bc2-452ff2c3a89d", + "metadata": {}, + "source": [ + "\n", + "### Inheritance Hierarchy\n", + "\n", + "SCOPe data classes inherit from [`_DynamicDataset`](https://github.com/ChEB-AI/python-chebai/blob/dev/chebai/preprocessing/datasets/base.py#L598), which in turn inherits from [`XYBaseDataModule`](https://github.com/ChEB-AI/python-chebai/blob/dev/chebai/preprocessing/datasets/base.py#L23). Specifically:\n", + "\n", + "- **`_DynamicDataset`**: This class serves as an intermediate base class that provides additional functionality or customization for datasets that require dynamic behavior. It inherits from `XYBaseDataModule`, which provides the core methods for data loading and processing.\n", + "\n", + "- **`XYBaseDataModule`**: This is the base class for data modules, providing foundational properties and methods for handling and processing datasets, including data splitting, loading, and preprocessing.\n", + "\n", + "In summary, ChEBI data classes are designed to manage and preprocess chemical data effectively by leveraging the capabilities provided by `XYBaseDataModule` through the `_DynamicDataset` intermediary.\n", + "\n", + "\n", + "### Input parameters\n", + "A SCOPe data class can be configured with a range of parameters, including:\n", + "\n", + "- **scope_version (str)**: Specifies the version of the ChEBI database to be used. Specifying a version ensures the reproducibility of your experiments by using a consistent dataset.\n", + "\n", + "- **scope_version_train (str, optional)**: The version of ChEBI to use specifically for training and validation. If not set, the `scope_version` specified will be used for all data splits, including training, validation, and test. Defaults to `None`.\n", + "\n", + "- **splits_file_path (str, optional)**: Path to a CSV file containing data splits. If not provided, the class will handle splits internally. Defaults to `None`.\n", + "\n", + "### Additional Input Parameters\n", + "\n", + "To get more control over various aspects of data loading, processing, and splitting, you can refer to documentation of additional parameters in docstrings of the respective classes: [`_SCOPeDataExtractor`](https://github.com/ChEB-AI/python-chebai/blob/dev/chebai/preprocessing/datasets/scope/scope.py#L31), [`XYBaseDataModule`](https://github.com/ChEB-AI/python-chebai/blob/dev/chebai/preprocessing/datasets/base.py#L22), [`_DynamicDataset`](https://github.com/ChEB-AI/python-chebai/blob/dev/chebai/preprocessing/datasets/base.py#L597), etc.\n" + ] + }, + { + "cell_type": "markdown", + "id": "8578b7aa-1bd9-4e50-9eee-01bfc6d5464a", + "metadata": {}, + "source": [ + "# Available SCOPe Data Classes\n", + "\n", + "__Note__: Check the code implementation of classes [here](https://github.com/ChEB-AI/python-chebai/blob/dev/chebai/preprocessing/datasets/scope/scope.py):\n", + "\n", + "There is a range of available dataset classes for SCOPe. Usually, you want to use `SCOPeOver2000` or `SCOPeOver50`. The number indicates the threshold for selecting label classes: SCOPe classes which have at least 2000 / 50 subclasses will be used as labels.\n", + "\n", + "Both inherit from `SCOPeOverX`. If you need a different threshold, you can create your own subclass. By default, `SCOPeOverX` uses the Protein encoding (see Section 5).\n", + "\n", + "Finally, `SCOPeOver2000Partial` selects extracts a part of SCOPe based on a given top class, with a threshold of 2000 for selecting labels.\n", + "This class inherits from `SCOPEOverXPartial`.\n" + ] + }, + { + "cell_type": "markdown", + "id": "8456b545-88c5-401d-baa5-47e8ae710f04", + "metadata": {}, + "source": [ + "---" + ] + }, + { + "cell_type": "markdown", + "id": "ed973fb59df11849", + "metadata": {}, + "source": [ + "# 2. Preparation / Setup Methods\n", + "\n", + "Now we have a SCOPe data class with all the relevant parameters. Next, we need to generate the actual dataset." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "11f2208e-fa40-44c9-bfe7-576ca23ad366", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Checking for processed data in data\\SCOPe\\version_2.08\\SCOPe2000\\processed\n", + "Missing processed data file (`data.pkl` file)\n", + "Extracting class hierarchy...\n", + "Computing transitive closure\n", + "Process graph\n", + "101 labels has been selected for specified threshold, \n", + "Constructing data.pkl file .....\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Check for processed data in data\\SCOPe\\version_2.08\\SCOPe2000\\processed\\protein_token\n", + "Cross-validation enabled: False\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Missing transformed data (`data.pt` file). Transforming data.... \n", + "Processing 60298 lines...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|█████████████████████████████████████████████████████████████████████████| 60298/60298 [00:53<00:00, 1119.10it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Saving 21 tokens to G:\\github-aditya0by0\\python-chebai\\chebai\\preprocessing\\bin\\protein_token\\tokens.txt...\n", + "First 10 tokens: ['M', 'S', 'I', 'G', 'A', 'T', 'R', 'L', 'Q', 'N']\n" + ] + } + ], + "source": [ + "scope_class.prepare_data()\n", + "scope_class.setup()" + ] + }, + { + "cell_type": "markdown", + "id": "1655d489-25fe-46de-9feb-eeca5d36936f", + "metadata": {}, + "source": [ + "\n", + "### Automatic Execution: \n", + "These methods are executed automatically when using the training command `chebai fit`. Users do not need to call them explicitly, as the code internally manages the preparation and setup of data, ensuring that it is ready for subsequent use in training and validation processes.\n", + "\n", + "### Why is Preparation Needed?\n", + "\n", + "- **Data Availability**: The preparation step ensures that the required SCOPe data files are downloaded or loaded, which are essential for analysis.\n", + "- **Data Integrity**: It ensures that the data files are transformed into a compatible format required for model input.\n", + "\n", + "### Main Methods for Data Preprocessing\n", + "\n", + "The data preprocessing in a data class involves two main methods:\n", + "\n", + "1. **`prepare_data` Method**:\n", + " - **Purpose**: This method checks for the presence of raw data in the specified directory. If the raw data is missing, it fetches the ontology, creates a dataframe, and saves it to a file (`data.pkl`). The dataframe includes columns such as IDs, data representations, and labels. This step is independent of input encodings.\n", + " - **Documentation**: [PyTorch Lightning - `prepare_data`](https://lightning.ai/docs/pytorch/stable/data/datamodule.html#prepare-data)\n", + "\n", + "2. **`setup` Method**:\n", + " - **Purpose**: This method sets up the data module for training, validation, and testing. It checks for the processed data and, if necessary, performs additional setup to ensure the data is ready for model input. It also handles cross-validation settings if enabled.\n", + " - **Description**: Transforms `data.pkl` into a model input data format (`data.pt`), tokenizing the input according to the specified encoding. The transformed data contains the following keys: `ident`, `features`, `labels`, and `group`. This method uses a subclass of Data Reader to perform the tokenization.\n", + " - **Documentation**: [PyTorch Lightning - `setup`](https://lightning.ai/docs/pytorch/stable/data/datamodule.html#setup)\n", + "\n", + "These methods ensure that the data is correctly prepared and set up for subsequent use in training and validation processes." + ] + }, + { + "cell_type": "markdown", + "id": "f5aaa12d-5f01-4b74-8b59-72562af953bf", + "metadata": {}, + "source": [ + "---" + ] + }, + { + "cell_type": "markdown", + "id": "bb6e9a81554368f7", + "metadata": {}, + "source": [ + "# 3. Overview of the 3 preprocessing stages\n", + "\n", + "The `chebai` library follows a three-stage preprocessing pipeline, which is reflected in its file structure:\n", + "\n", + "1. **Raw Data Stage**:\n", + " - **Files**: `cla.txt`, `des.txt` and `hie.txt`. Please find description of each file [here](https://scop.berkeley.edu/help/ver=2.08#parseablefiles-2.08).\n", + " - **Description**: This stage contains the raw SCOPe data in txt format, serving as the initial input for further processing.\n", + " - **File Path**: `data/SCOPe/version_${scope_version}/raw/${filename}.txt`\n", + "\n", + "2. **Processed Data Stage 1**:\n", + " - **File**: `data.pkl`\n", + " - **Description**: This stage includes the data after initial processing. It contains protein sequence strings, class columns, and metadata but lacks data splits.\n", + " - **File Path**: `data/SCOPe/version_${scope_version}/${dataset_name}/processed/data.pkl`\n", + " - **Additional File**: `classes.txt` - A file listing the relevant SCOPe classes.\n", + "\n", + "3. **Processed Data Stage 2**:\n", + " - **File**: `data.pt`\n", + " - **Description**: This final stage includes the encoded data in a format compatible with PyTorch, ready for model input. This stage also references data splits when available.\n", + " - **File Path**: `data/SCOPe/version_${scope_version}/${dataset_name}/processed/${reader_name}/data.pt`\n", + " - **Additional File**: `splits.csv` - Contains saved splits for reproducibility.\n", + "\n", + "This structured approach to data management ensures that each stage of data processing is well-organized and documented, from raw data acquisition to the preparation of model-ready inputs. It also facilitates reproducibility and traceability across different experiments.\n", + "\n", + "### Data Splits\n", + "\n", + "- **Creation**: Data splits are generated dynamically \"on the fly\" during training and evaluation to ensure flexibility and adaptability to different tasks.\n", + "- **Reproducibility**: To maintain consistency across different runs, splits can be reproduced by comparing hashes with a fixed seed value.\n" + ] + }, + { + "cell_type": "markdown", + "id": "7e172c0d1e8bb93f", + "metadata": {}, + "source": [ + "# 4. Data Files and their structure\n", + "\n", + "`chebai` creates and manages several data files during its operation. These files store various chemical data and metadata essential for different tasks. Let’s explore these files and their content.\n" + ] + }, + { + "cell_type": "markdown", + "id": "43329709-5134-4ce5-88e7-edd2176bf84d", + "metadata": {}, + "source": [ + "## raw files\n", + "- cla.txt, des.txt and hie.txt\n", + "\n", + "For detailed description of raw files and their structures, please refer the official website [here](https://scop.berkeley.edu/help/ver=2.08#parseablefiles-2.08).\n" + ] + }, + { + "cell_type": "markdown", + "id": "558295e5a7ded456", + "metadata": {}, + "source": [ + "## data.pkl File\n", + "\n", + "**Description**: Generated by the `prepare_data` method, this file contains processed data in a dataframe format. It includes the ids, sids which are used to label corresponding sequence, protein-chain sequence, and columns for each label with boolean values." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "fd490270-59b8-4c1c-8b09-204defddf592", + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-05T21:09:01.622317Z", + "start_time": "2024-10-05T21:09:01.606698Z" + } + }, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import os" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "d7d16247-092c-4e8d-96c2-ab23931cf766", + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-05T21:11:51.296162Z", + "start_time": "2024-10-05T21:11:44.559304Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Size of the data (rows x columns): (60298, 104)\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
idsidssequenceclass_46456class_48724class_51349class_53931class_56572class_56835class_56992...protein_56252protein_190144protein_310895protein_310894species_187221species_187920species_187294species_56254species_311502species_311501
01[d4oq9a_, d4oq9b_, d4oq9c_, d4oq9d_, d4niaa_, ...AAAAAAAAAAFalseTrueFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
12[d7dxhc_]AAAAAAAAAAAAAAAAAAAAAAAFalseFalseFalseFalseFalseTrueFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
23[d1gkub1, d1gkub2, d1gkub3, d1gkub4]AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAASLCLFPEDFLLKEF...FalseFalseTrueFalseTrueFalseFalse...FalseFalseFalseTrueFalseFalseFalseFalseFalseTrue
34[d3c9wa2, d3c9wb2, d3c9wa3, d3c9wb3]AAAAAAGPEMVRGQVFDVGPRYTNLSYIGEGAYGMVCSAYDNLNKV...FalseFalseFalseTrueFalseFalseFalse...FalseFalseFalseTrueFalseFalseFalseFalseFalseTrue
45[d1xwaa1, d1xwab_, d1xwac_, d1xwad_, d1xwaa2]AAAAAMVYQVKDKADLDGQLTKASGKLVVLDFFATWCGPCKMISPK...FalseFalseTrueFalseFalseFalseFalse...FalseFalseFalseTrueFalseFalseFalseFalseFalseTrue
\n", + "

5 rows × 104 columns

\n", + "
" + ], + "text/plain": [ + " id sids \\\n", + "0 1 [d4oq9a_, d4oq9b_, d4oq9c_, d4oq9d_, d4niaa_, ... \n", + "1 2 [d7dxhc_] \n", + "2 3 [d1gkub1, d1gkub2, d1gkub3, d1gkub4] \n", + "3 4 [d3c9wa2, d3c9wb2, d3c9wa3, d3c9wb3] \n", + "4 5 [d1xwaa1, d1xwab_, d1xwac_, d1xwad_, d1xwaa2] \n", + "\n", + " sequence class_46456 \\\n", + "0 AAAAAAAAAA False \n", + "1 AAAAAAAAAAAAAAAAAAAAAAA False \n", + "2 AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAASLCLFPEDFLLKEF... False \n", + "3 AAAAAAGPEMVRGQVFDVGPRYTNLSYIGEGAYGMVCSAYDNLNKV... False \n", + "4 AAAAAMVYQVKDKADLDGQLTKASGKLVVLDFFATWCGPCKMISPK... False \n", + "\n", + " class_48724 class_51349 class_53931 class_56572 class_56835 \\\n", + "0 True False False False False \n", + "1 False False False False True \n", + "2 False True False True False \n", + "3 False False True False False \n", + "4 False True False False False \n", + "\n", + " class_56992 ... protein_56252 protein_190144 protein_310895 \\\n", + "0 False ... False False False \n", + "1 False ... False False False \n", + "2 False ... False False False \n", + "3 False ... False False False \n", + "4 False ... False False False \n", + "\n", + " protein_310894 species_187221 species_187920 species_187294 \\\n", + "0 False False False False \n", + "1 False False False False \n", + "2 True False False False \n", + "3 True False False False \n", + "4 True False False False \n", + "\n", + " species_56254 species_311502 species_311501 \n", + "0 False False False \n", + "1 False False False \n", + "2 False False True \n", + "3 False False True \n", + "4 False False True \n", + "\n", + "[5 rows x 104 columns]" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pkl_df = pd.DataFrame(\n", + " pd.read_pickle(\n", + " os.path.join(\n", + " scope_class.processed_dir_main,\n", + " scope_class.processed_main_file_names_dict[\"data\"],\n", + " )\n", + " )\n", + ")\n", + "print(\"Size of the data (rows x columns): \", pkl_df.shape)\n", + "pkl_df.head()" + ] + }, + { + "cell_type": "markdown", + "id": "322bc926-69ff-4b93-9e95-5e8b85869c38", + "metadata": {}, + "source": [ + "**File Path**: `data/SCOPe/version_${scope_version}/${dataset_name}/processed/data.pkl`\n", + "\n", + "\n", + "### Structure of `data.pkl`\n", + "`data.pkl` as following structure: \n", + "- **Column 0**: Contains the ID of eachdata instance.\n", + "- **Column 1**: Contains the `sids` which are associated with corresponding protein-chain sequence.\n", + "- **Column 2**: Contains the protein-chain sequence.\n", + "- **Column 3 and onwards**: Contains the labels, starting from column 3.\n", + "\n", + "This structure ensures that the data is organized and ready for further processing, such as further encoding.\n" + ] + }, + { + "cell_type": "markdown", + "id": "ba019d2d4324bd0b", + "metadata": {}, + "source": [ + "## data.pt File\n", + "\n", + "\n", + "**Description**: Generated by the `setup` method, this file contains encoded data in a format compatible with the PyTorch library, specifically as a list of dictionaries. Each dictionary in this list includes keys such as `ident`, `features`, `labels`, and `group`, ready for model input." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "977ddd83-b469-4b58-ab1a-8574fb8769b4", + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-05T21:12:49.338943Z", + "start_time": "2024-10-05T21:12:49.323319Z" + } + }, + "outputs": [], + "source": [ + "import torch" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "3266ade9-efdc-49fe-ae07-ed52b2eb52d0", + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-05T21:14:12.892845Z", + "start_time": "2024-10-05T21:13:59.859953Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Type of loaded data: \n" + ] + } + ], + "source": [ + "data_pt = torch.load(\n", + " os.path.join(\n", + " scope_class.processed_dir, scope_class.processed_file_names_dict[\"data\"]\n", + " ),\n", + " weights_only=False,\n", + ")\n", + "print(\"Type of loaded data:\", type(data_pt))" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "84cfa3e6-f60d-47c0-9f82-db3d5673d1e7", + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-05T21:14:21.185027Z", + "start_time": "2024-10-05T21:14:21.169358Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'features': [14, 14, 14, 14, 14, 14, 14, 14, 14, 14], 'labels': array([False, True, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, True, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False]), 'ident': 1, 'group': None}\n" + ] + } + ], + "source": [ + "for i in range(1):\n", + " print(data_pt[i])" + ] + }, + { + "cell_type": "markdown", + "id": "0d80ffbb-5f1e-4489-9bc8-d688c9be1d07", + "metadata": {}, + "source": [ + "**File Path**: `data/SCOPe/version_${scope_version}/${dataset_name}/processed/${reader_name}/data.pt`\n", + "\n", + "\n", + "### Structure of `data.pt`\n", + "\n", + "The `data.pt` file is a list where each element is a dictionary with the following keys:\n", + "\n", + "- **`features`**: \n", + " - **Description**: This key holds the input features for the model. The features are typically stored as tensors and represent the attributes used by the model for training and evaluation.\n", + "\n", + "- **`labels`**: \n", + " - **Description**: This key contains the labels or target values associated with each instance. Labels are also stored as tensors and are used by the model to learn and make predictions.\n", + "\n", + "- **`ident`**: \n", + " - **Description**: This key holds identifiers for each data instance. These identifiers help track and reference the individual samples in the dataset.\n" + ] + }, + { + "cell_type": "markdown", + "id": "186ec6f0eed6ecf7", + "metadata": {}, + "source": [ + "## classes.txt File\n", + "\n", + "**Description**: A file containing the list of selected SCOPe classes based on the specified threshold. This file is crucial for ensuring that only relevant classes are included in the dataset." + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "8d1fbe6c-beb8-4038-93d4-c56bc7628716", + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-05T21:15:19.146285Z", + "start_time": "2024-10-05T21:15:18.503284Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "class_46456\n", + "class_48724\n", + "class_51349\n", + "class_53931\n", + "class_56572\n", + "class_56835\n", + "class_56992\n", + "class_57942\n", + "class_58117\n", + "class_58231\n", + "class_310555\n", + "fold_46457\n", + "fold_46688\n", + "fold_47239\n", + "fold_47363\n" + ] + } + ], + "source": [ + "with open(os.path.join(scope_class.processed_dir_main, \"classes.txt\"), \"r\") as file:\n", + " for i in range(15):\n", + " line = file.readline()\n", + " print(line.strip())" + ] + }, + { + "cell_type": "markdown", + "id": "861da1c3-0401-49f0-a22f-109814ed95d5", + "metadata": {}, + "source": [ + "\n", + "**File Path**: `data/SCOPe/version_${scope_version}/${dataset_name}/processed/classes.txt`\n", + "\n", + "The `classes.txt` file lists selected SCOPe classes. These classes are chosen based on a specified threshold, which is typically used for filtering or categorizing the dataset. Each line in the file corresponds to a unique SCOPe class ID, identifying specific class withing SCOPe ontology along with the hierarchy level.\n", + "\n", + "This file is essential for organizing the data and ensuring that only relevant classes, as defined by the threshold, are included in subsequent processing and analysis tasks.\n" + ] + }, + { + "cell_type": "markdown", + "id": "fb72be449e52b63f", + "metadata": {}, + "source": [ + "## splits.csv File\n", + "\n", + "**Description**: Contains saved data splits from previous runs. During subsequent runs, this file is used to reconstruct the train, validation, and test splits by filtering the encoded data (`data.pt`) based on the IDs stored in `splits.csv`." + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "3ebdcae4-4344-46bd-8fc0-a82ef5d40da5", + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-05T21:15:54.575116Z", + "start_time": "2024-10-05T21:15:53.945139Z" + } + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
idsplit
01train
12train
24train
35train
46train
\n", + "
" + ], + "text/plain": [ + " id split\n", + "0 1 train\n", + "1 2 train\n", + "2 4 train\n", + "3 5 train\n", + "4 6 train" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "csv_df = pd.read_csv(os.path.join(scope_class.processed_dir_main, \"splits.csv\"))\n", + "csv_df.head()" + ] + }, + { + "cell_type": "markdown", + "id": "b058714f-e434-4367-89b9-74c129ac727f", + "metadata": {}, + "source": [ + "\n", + "\n", + "**File Path**: `data/SCOPe/version_${scope_version}/${dataset_name}/processed/splits.csv`\n", + "\n", + "The `splits.csv` file contains the saved data splits from previous runs, including the train, validation, and test sets. During subsequent runs, this file is used to reconstruct these splits by filtering the encoded data (`data.pt`) based on the IDs stored in `splits.csv`. This ensures consistency and reproducibility in data splitting, allowing for reliable evaluation and comparison of model performance across different run.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "6dc3fd6c-7cf6-47ef-812f-54319a0cdeb9", + "metadata": {}, + "outputs": [], + "source": [ + "# You can specify a literal path for the `splits_file_path`, or if another `scope_class` instance is already defined,\n", + "# you can use its existing `splits_file_path` attribute for consistency.\n", + "scope_class_with_splits = SCOPeOver2000(\n", + " scope_version=\"2.08\",\n", + " # splits_file_path=\"data/chebi_v231/ChEBI50/processed/splits.csv\", # Literal path option\n", + " splits_file_path=scope_class.splits_file_path, # Use path from an existing `chebi_class` instance\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "a5eb482c-ce5b-4efc-b2ec-85ac7b1a78ee", + "metadata": {}, + "source": [ + "---" + ] + }, + { + "cell_type": "markdown", + "id": "ab110764-216d-4d52-a9d1-4412c8ac8c9d", + "metadata": {}, + "source": [ + "## 5.1 Protein Representation Using Amino Acid Sequence Notation\n", + "\n", + "Proteins are composed of chains of amino acids, and these sequences can be represented using a one-letter notation for each amino acid. This notation provides a concise way to describe the primary structure of a protein.\n", + "\n", + "### Example Protein Sequence\n", + "\n", + "Protein-Chain: PDB ID:**1cph** Chain ID:**B** mol:protein length:30 INSULIN (PH 10)\n", + "
Refer - [1cph_B](https://www.rcsb.org/sequence/1CPH)\n", + "\n", + "- **Sequence**: `FVNQHLCGSHLVEALYLVCGERGFFYTPKA`\n", + "- **Sequence Length**: 30\n", + "\n", + "In this sequence, each letter corresponds to a specific amino acid. This notation is widely used in bioinformatics and molecular biology to represent protein sequences.\n", + "\n", + "### Tokenization and Encoding\n", + "\n", + "To tokenize and numerically encode this protein sequence, the `ProteinDataReader` class is used. This class allows for n-gram tokenization, where the `n_gram` parameter defines the size of the tokenized units. If `n_gram` is not provided (default is `None`), each amino acid letter is treated as a single token.\n", + "\n", + "For more details, you can explore the implementation of the `ProteinDataReader` class in the source code [here](https://github.com/ChEB-AI/python-chebai/blob/dev/chebai/preprocessing/reader.py)." + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "da47d47e-4560-46af-b246-235596f27d82", + "metadata": {}, + "outputs": [], + "source": [ + "from chebai.preprocessing.reader import ProteinDataReader" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "8bdbf309-29ec-4aab-a6dc-9e09bc6961a2", + "metadata": {}, + "outputs": [], + "source": [ + "protein_dr_3gram = ProteinDataReader(n_gram=3)\n", + "protein_dr = ProteinDataReader()" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "68e5c87c-79c3-4d5f-91e6-635399a84d3d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[25, 28, 19, 18, 29, 17, 24, 13, 11, 29, 17, 28, 27, 14, 17, 22, 17, 28, 24, 13, 27, 16, 13, 25, 25, 22, 15, 23, 21, 14]\n", + "[5023, 2218, 3799, 2290, 6139, 2208, 6917, 4674, 484, 439, 2737, 851, 365, 2624, 3240, 4655, 1904, 3737, 1453, 2659, 5160, 3027, 2355, 7163, 4328, 3115, 6207, 1234]\n" + ] + } + ], + "source": [ + "protein = \"FVNQHLCGSHLVEALYLVCGERGFFYTPKA\"\n", + "print(protein_dr._read_data(protein))\n", + "print(protein_dr_3gram._read_data(protein))" + ] + }, + { + "cell_type": "markdown", + "id": "5b7211ee-2ccc-46d3-8e8f-790f344726ba", + "metadata": {}, + "source": [ + "The numbers mentioned above refer to the index of each individual token from the [`tokens.txt`](https://github.com/ChEB-AI/python-chebai/blob/dev/chebai/preprocessing/bin/protein_token/tokens.txt) file, which is used by the `ProteinDataReader` class. \n", + "\n", + "Each token in the `tokens.txt` file corresponds to a specific amino-acid letter, and these tokens are referenced by their index. Additionally, the index values are offset by the `EMBEDDING_OFFSET`, ensuring that the token embeddings are adjusted appropriately during processing." + ] + }, + { + "cell_type": "markdown", + "id": "93e328cf-09f9-4694-b175-28320590937d", + "metadata": {}, + "source": [ + "---" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python (env_chebai)", + "language": "python", + "name": "env_chebai" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From eba04175d211f4d9ad005bc452f1e7218f66cbc9 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Fri, 21 Feb 2025 17:01:43 +0100 Subject: [PATCH 55/71] scope: update tutorial --- tutorials/data_exploration_scope.ipynb | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/tutorials/data_exploration_scope.ipynb b/tutorials/data_exploration_scope.ipynb index f8d053e5..a68697ae 100644 --- a/tutorials/data_exploration_scope.ipynb +++ b/tutorials/data_exploration_scope.ipynb @@ -171,6 +171,29 @@ "text": [ "Checking for processed data in data\\SCOPe\\version_2.08\\SCOPe2000\\processed\n", "Missing processed data file (`data.pkl` file)\n", + "Missing PDB raw data, Downloading PDB sequence data....\n", + "Downloading to temporary file C:\\Users\\HP\\AppData\\Local\\Temp\\tmpsif7r129\n", + "Downloaded to C:\\Users\\HP\\AppData\\Local\\Temp\\tmpsif7r129\n", + "Unzipping the file....\n", + "Unpacked and saved to data\\SCOPe\\pdb_sequences.txt\n", + "Removed temporary file C:\\Users\\HP\\AppData\\Local\\Temp\\tmpsif7r129\n", + "Missing Scope: cla.txt raw data, Downloading...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "G:\\anaconda3\\envs\\env_chebai\\lib\\site-packages\\urllib3\\connectionpool.py:1099: InsecureRequestWarning: Unverified HTTPS request is being made to host 'scop.berkeley.edu'. Adding certificate verification is strongly advised. See: https://urllib3.readthedocs.io/en/latest/advanced-usage.html#tls-warnings\n", + "warnings.warn(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Missing Scope: hie.txt raw data, Downloading...\n", + "Missing Scope: des.txt raw data, Downloading...\n", "Extracting class hierarchy...\n", "Computing transitive closure\n", "Process graph\n", From dad6f76aad685a6ab18a4328be0041842a0d62de Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Fri, 21 Feb 2025 18:02:22 +0100 Subject: [PATCH 56/71] scope: add more scope details to tutorial --- tutorials/data_exploration_scope.ipynb | 116 ++++++++++++++++++++++--- 1 file changed, 103 insertions(+), 13 deletions(-) diff --git a/tutorials/data_exploration_scope.ipynb b/tutorials/data_exploration_scope.ipynb index a68697ae..2858e43a 100644 --- a/tutorials/data_exploration_scope.ipynb +++ b/tutorials/data_exploration_scope.ipynb @@ -17,6 +17,96 @@ "---\n" ] }, + { + "cell_type": "markdown", + "id": "f6c25706-251c-438c-9915-e8002647eb94", + "metadata": {}, + "source": [ + "### Understanding [SCOPe](https://scop.berkeley.edu/) and [PDB](https://www.rcsb.org/) \n", + "\n", + "\n", + "1. **Protein domains form chains.** \n", + "2. **Chains form complexes** (protein complexes or structures). \n", + "3. These **complexes are the entries in PDB**, represented by unique identifiers like `\"1A3N\"`. \n", + "\n", + "---\n", + "\n", + "#### **Protein Domain** \n", + "A **protein domain** is a **structural and functional unit** of a protein. \n", + "\n", + "\n", + "##### Key Characteristics:\n", + "- **Domains are part of a protein chain.** \n", + "- A domain can span: \n", + " 1. **The entire chain** (single-domain protein): \n", + " - In this case, the protein domain is equivalent to the chain itself. \n", + " - Example: \n", + " - All chains of the **PDB structure \"1A3N\"** are single-domain proteins. \n", + " - Each chain has a SCOPe domain identifier. \n", + " - For example, Chain **A**: \n", + " - Domain identifier: `d1a3na_` \n", + " - Breakdown of the identifier: \n", + " - `d`: Denotes domain. \n", + " - `1a3n`: Refers to the PDB protein structure identifier. \n", + " - `a`: Specifies the chain within the structure. (`_` for None and `.` for multiple chains)\n", + " - `_`: Indicates the domain spans the entire chain (single-domain protein). \n", + " - Example: [PDB Structure 1A3N - Chain A](https://www.rcsb.org/sequence/1A3N#A)\n", + " 2. **A specific portion of the chain** (multi-domain protein): \n", + " - Here, a single chain contains multiple domains. \n", + " - Example: Chain **A** of the **PDB structure \"1PKN\"** contains three domains: `d1pkna1`, `d1pkna2`, `d1pkna3`. \n", + " - Example: [PDB Structure 1PKN - Chain A](https://www.rcsb.org/annotations/1PKN). \n", + "\n", + "---\n", + "\n", + "#### **Protein Chain** \n", + "A **protein chain** refers to the entire **polypeptide chain** observed in a protein's 3D structure (as described in PDB files). \n", + "\n", + "##### Key Points:\n", + "- A chain can consist of **one or multiple domains**:\n", + " - **Single-domain chain**: The chain and domain are identical. \n", + " - Example: Myoglobin. \n", + " - **Multi-domain chain**: Contains several domains, each with distinct structural and functional roles. \n", + "- Chains assemble to form **protein complexes** or **structures**. \n", + "\n", + "\n", + "---\n", + "\n", + "#### **Key Observations About SCOPe** \n", + "- The **fundamental classification unit** in SCOPe is the **protein domain**, not the entire protein. \n", + "- _**The taxonomy in SCOPe is not for the entire protein (i.e., the full-length amino acid sequence as encoded by a gene) but for protein domains, which are smaller, structurally and functionally distinct regions of the protein.**_\n", + "\n", + "\n", + "--- \n", + "\n", + "**SCOPe 2.08 Data Analysis:**\n", + "\n", + "The current SCOPe version (2.08) includes the following statistics based on analysis for relevant data:\n", + "\n", + "- **Classes**: 12\n", + "- **Folds**: 1485\n", + "- **Superfamilies**: 2368\n", + "- **Families**: 5431\n", + "- **Proteins**: 13,514\n", + "- **Species**: 30,294\n", + "- **Domains**: 344,851\n", + "\n", + "For more detailed statistics, please refer to the official SCOPe website:\n", + "\n", + "- [SCOPe 2.08 Statistics](https://scop.berkeley.edu/statistics/ver=2.08)\n", + "- [SCOPe 2.08 Release](https://scop.berkeley.edu/ver=2.08)\n", + "\n", + "---\n", + "\n", + "## SCOPe Labeling \n", + "\n", + "- Use SCOPe labels for protein domains.\n", + "- Map them back to their **protein-chain** sequences (protein sequence label = sum of all domain labels).\n", + "- Train on protein sequences.\n", + "- This pretraining task would be comparable to GO-based training.\n", + "\n", + "--- " + ] + }, { "cell_type": "code", "execution_count": 1, @@ -171,25 +261,25 @@ "text": [ "Checking for processed data in data\\SCOPe\\version_2.08\\SCOPe2000\\processed\n", "Missing processed data file (`data.pkl` file)\n", - "Missing PDB raw data, Downloading PDB sequence data....\n", - "Downloading to temporary file C:\\Users\\HP\\AppData\\Local\\Temp\\tmpsif7r129\n", - "Downloaded to C:\\Users\\HP\\AppData\\Local\\Temp\\tmpsif7r129\n", - "Unzipping the file....\n", + "Missing PDB raw data, Downloading PDB sequence data....\n", + "Downloading to temporary file C:\\Users\\HP\\AppData\\Local\\Temp\\tmpsif7r129\n", + "Downloaded to C:\\Users\\HP\\AppData\\Local\\Temp\\tmpsif7r129\n", + "Unzipping the file....\n", "Unpacked and saved to data\\SCOPe\\pdb_sequences.txt\n", "Removed temporary file C:\\Users\\HP\\AppData\\Local\\Temp\\tmpsif7r129\n", - "Missing Scope: cla.txt raw data, Downloading...\n" - ] - }, - { - "name": "stderr", + "Missing Scope: cla.txt raw data, Downloading...\n" + ] + }, + { + "name": "stderr", "output_type": "stream", "text": [ "G:\\anaconda3\\envs\\env_chebai\\lib\\site-packages\\urllib3\\connectionpool.py:1099: InsecureRequestWarning: Unverified HTTPS request is being made to host 'scop.berkeley.edu'. Adding certificate verification is strongly advised. See: https://urllib3.readthedocs.io/en/latest/advanced-usage.html#tls-warnings\n", "warnings.warn(\n" - ] - }, - { - "name": "stdout", + ] + }, + { + "name": "stdout", "output_type": "stream", "text": [ "Missing Scope: hie.txt raw data, Downloading...\n", From fd6dd012d205986f253288ca2e794add9ccf6260 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 23 Feb 2025 20:00:47 +0100 Subject: [PATCH 57/71] minor changes: deepgo configs + scope --- chebai/preprocessing/datasets/scope/scope.py | 2 +- configs/data/deepGO/deepgo2_esm2.yml | 2 +- configs/data/deepGO/deepgo_2_migrated_data.yml | 1 + 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/chebai/preprocessing/datasets/scope/scope.py b/chebai/preprocessing/datasets/scope/scope.py index 6dff3156..334ff20e 100644 --- a/chebai/preprocessing/datasets/scope/scope.py +++ b/chebai/preprocessing/datasets/scope/scope.py @@ -128,7 +128,7 @@ def _download_pdb_sequence_data(self) -> None: os.makedirs(os.path.dirname(pdb_seq_file_path), exist_ok=True) if not os.path.isfile(pdb_seq_file_path): - print(f"Downloading PDB sequence data....") + print(f"Missing PDB raw data, Downloading PDB sequence data....") # Create a temporary file with NamedTemporaryFile(delete=False) as tf: diff --git a/configs/data/deepGO/deepgo2_esm2.yml b/configs/data/deepGO/deepgo2_esm2.yml index 4b3ae3b1..5a0436e3 100644 --- a/configs/data/deepGO/deepgo2_esm2.yml +++ b/configs/data/deepGO/deepgo2_esm2.yml @@ -2,4 +2,4 @@ class_path: chebai.preprocessing.datasets.deepGO.go_uniprot.DeepGO2MigratedData init_args: go_branch: "MF" max_sequence_length: 1000 - use_esm2_embeddings: True \ No newline at end of file + use_esm2_embeddings: True diff --git a/configs/data/deepGO/deepgo_2_migrated_data.yml b/configs/data/deepGO/deepgo_2_migrated_data.yml index 1ed2ad09..5a0436e3 100644 --- a/configs/data/deepGO/deepgo_2_migrated_data.yml +++ b/configs/data/deepGO/deepgo_2_migrated_data.yml @@ -2,3 +2,4 @@ class_path: chebai.preprocessing.datasets.deepGO.go_uniprot.DeepGO2MigratedData init_args: go_branch: "MF" max_sequence_length: 1000 + use_esm2_embeddings: True From 4a8f8214b7144aa35c4d389ae48d1d06d6d475f1 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 27 Feb 2025 18:08:54 +0100 Subject: [PATCH 58/71] deepgo2 migration: exp_annoations not needed - prop annotations has both direct and transitive annotations --- .../migration/deep_go/migrate_deep_go_2_data.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py b/chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py index 68d7dc78..d23247c0 100644 --- a/chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py +++ b/chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py @@ -213,10 +213,8 @@ def _extract_required_data_from_splits(self) -> pd.DataFrame: "proteins", "accessions", "sequences", - # https://github.com/bio-ontology-research-group/deepgo2/blob/main/gendata/uni2pandas.py#L45-L58 - "exp_annotations", # Directly associated GO ids # https://github.com/bio-ontology-research-group/deepgo2/blob/main/gendata/uni2pandas.py#L60-L69 - "prop_annotations", # Transitively associated GO ids + "prop_annotations", # Direct and Transitively associated GO ids "esm2", ] @@ -228,10 +226,8 @@ def _extract_required_data_from_splits(self) -> pd.DataFrame: ], ignore_index=True, ) - new_df["go_ids"] = new_df.apply( - lambda row: self.extract_go_id(row["exp_annotations"]) - + self.extract_go_id(row["prop_annotations"]), - axis=1, + new_df["go_ids"] = new_df["prop_annotations"].apply( + lambda x: self.extract_go_id(x) ) data_df = pd.DataFrame( @@ -270,7 +266,7 @@ def _generate_labels(self, data_df: pd.DataFrame) -> pd.DataFrame: """ print("Generating labels based on terms.pkl file.......") parsed_go_ids: pd.Series = self._terms_df["gos"].apply( - lambda gos: DeepGO2MigratedData._parse_go_id(gos) + DeepGO2MigratedData._parse_go_id ) all_go_ids_list = parsed_go_ids.values.tolist() self._classes = all_go_ids_list From 1c432de3e953fe012d005c9b42166d2696e71155 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Mon, 3 Mar 2025 17:17:24 +0100 Subject: [PATCH 59/71] fix scope version in scope50.yml --- configs/data/scope/scope50.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/data/scope/scope50.yml b/configs/data/scope/scope50.yml index 8e64ccf5..c65028e2 100644 --- a/configs/data/scope/scope50.yml +++ b/configs/data/scope/scope50.yml @@ -1,3 +1,3 @@ class_path: chebai.preprocessing.datasets.scope.scope.SCOPeOver50 init_args: - scope_version: 2.08 + scope_version: "2.08" \ No newline at end of file From f13e9352a8cb4194725fde96e64a77eb829a3908 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Tue, 4 Mar 2025 09:51:20 +0100 Subject: [PATCH 60/71] modify notebook introduction --- tutorials/data_exploration_scope.ipynb | 247 ++++++++++++------------- 1 file changed, 116 insertions(+), 131 deletions(-) diff --git a/tutorials/data_exploration_scope.ipynb b/tutorials/data_exploration_scope.ipynb index 2858e43a..a67251cc 100644 --- a/tutorials/data_exploration_scope.ipynb +++ b/tutorials/data_exploration_scope.ipynb @@ -12,7 +12,7 @@ "- how the data is processed and stored\n", "- and how to work with different molecule encodings.\n", "\n", - "The chebai package simplifies the handling of these datasets by **automatically creating** them as needed. This means that you do not have to input any data manually; the package will generate and organize the data files based on the parameters and encodings selected. This feature ensures that the right data is available and formatted properly. You can however provide your own data files, for instance if you want to replicate a specific experiment.\n", + "The `chebai` package simplifies the handling of these datasets by **automatically downloading and processing** them as needed. This means that you do not have to input any data manually; the package will generate and organize the data files based on the parameters and encodings selected. You can however provide your own data files, for instance if you want to replicate a specific experiment.\n", "\n", "---\n" ] @@ -117,7 +117,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Changed to project root directory: G:\\github-aditya0by0\\python-chebai\n" + "Changed to project root directory: c:\\Users\\sifluegel\\PycharmProjects\\python-chebai\n" ] } ], @@ -159,7 +159,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "G:\\anaconda3\\envs\\env_chebai\\lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + "c:\\Users\\sifluegel\\PycharmProjects\\python-chebai\\venv312c\\Lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] } @@ -433,7 +433,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "id": "fd490270-59b8-4c1c-8b09-204defddf592", "metadata": { "ExecuteTime": { @@ -449,7 +449,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 5, "id": "d7d16247-092c-4e8d-96c2-ab23931cf766", "metadata": { "ExecuteTime": { @@ -462,7 +462,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Size of the data (rows x columns): (60298, 104)\n" + "Size of the data (rows x columns): (14210, 31)\n" ] }, { @@ -489,105 +489,105 @@ " id\n", " sids\n", " sequence\n", - " class_46456\n", " class_48724\n", - " class_51349\n", " class_53931\n", - " class_56572\n", - " class_56835\n", - " class_56992\n", + " class_310555\n", + " fold_48725\n", + " fold_56111\n", + " fold_56234\n", + " fold_310573\n", " ...\n", - " protein_56252\n", - " protein_190144\n", - " protein_310895\n", + " protein_190417\n", + " protein_190740\n", " protein_310894\n", + " protein_310895\n", + " species_56254\n", " species_187221\n", - " species_187920\n", " species_187294\n", - " species_56254\n", - " species_311502\n", + " species_187920\n", " species_311501\n", + " species_311502\n", " \n", " \n", " \n", " \n", " 0\n", " 1\n", - " [d4oq9a_, d4oq9b_, d4oq9c_, d4oq9d_, d4niaa_, ...\n", - " AAAAAAAAAA\n", - " False\n", + " [d6vi2a2, d6vi2c2, d6vi2a1, d6vi2c1]\n", + " SDIQMTQSPSSLSASVGDRVTITCRASQSVSSAVAWYQQKPGKAPK...\n", " True\n", " False\n", " False\n", + " True\n", " False\n", " False\n", " False\n", " ...\n", " False\n", + " True\n", " False\n", " False\n", " False\n", + " True\n", " False\n", - " False\n", - " False\n", - " False\n", + " True\n", " False\n", " False\n", " \n", " \n", " 1\n", " 2\n", - " [d7dxhc_]\n", - " AAAAAAAAAAAAAAAAAAAAAAA\n", - " False\n", - " False\n", - " False\n", + " [d4nzul2, d4nzul1]\n", + " DIEMTQSPSSLSASTGDKVTITCQASQDIAKFLDWYQQRPGKTPKL...\n", + " True\n", " False\n", " False\n", " True\n", " False\n", - " ...\n", - " False\n", " False\n", " False\n", + " ...\n", " False\n", + " True\n", " False\n", " False\n", " False\n", + " True\n", " False\n", + " True\n", " False\n", " False\n", " \n", " \n", " 2\n", " 3\n", - " [d1gkub1, d1gkub2, d1gkub3, d1gkub4]\n", - " AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAASLCLFPEDFLLKEF...\n", + " [d3tv3l2, d6b3dl2, d3tv3l1, d6b3dl1]\n", + " QSALTQPPSASGSPGQSITISCTGTSNNFVSWYQQHAGKAPKLVIY...\n", + " True\n", " False\n", " False\n", " True\n", " False\n", - " True\n", " False\n", " False\n", " ...\n", " False\n", - " False\n", - " False\n", " True\n", " False\n", " False\n", " False\n", - " False\n", + " True\n", " False\n", " True\n", + " False\n", + " False\n", " \n", " \n", " 3\n", " 4\n", - " [d3c9wa2, d3c9wb2, d3c9wa3, d3c9wb3]\n", - " AAAAAAGPEMVRGQVFDVGPRYTNLSYIGEGAYGMVCSAYDNLNKV...\n", - " False\n", + " [d2nw2a2, d2nw2a1]\n", + " QNIDQPTEMTATEGAIVQINCTYQTSGFNGLFWYQQHAGEAPTFLS...\n", + " True\n", " False\n", " False\n", " True\n", @@ -596,92 +596,85 @@ " False\n", " ...\n", " False\n", - " False\n", - " False\n", " True\n", " False\n", " False\n", " False\n", - " False\n", + " True\n", " False\n", " True\n", + " False\n", + " False\n", " \n", " \n", " 4\n", " 5\n", - " [d1xwaa1, d1xwab_, d1xwac_, d1xwad_, d1xwaa2]\n", - " AAAAAMVYQVKDKADLDGQLTKASGKLVVLDFFATWCGPCKMISPK...\n", + " [d7k3ql2, d7r6xd_]\n", + " DIVLTQTPGTLSLSPGERATLSCRASQSVSSSYLAWYQQKPGQAPR...\n", + " True\n", " False\n", " False\n", " True\n", " False\n", " False\n", " False\n", - " False\n", " ...\n", " False\n", - " False\n", - " False\n", " True\n", " False\n", " False\n", " False\n", - " False\n", + " True\n", " False\n", " True\n", + " False\n", + " False\n", " \n", " \n", "\n", - "

5 rows × 104 columns

\n", + "

5 rows × 31 columns

\n", "" ], "text/plain": [ - " id sids \\\n", - "0 1 [d4oq9a_, d4oq9b_, d4oq9c_, d4oq9d_, d4niaa_, ... \n", - "1 2 [d7dxhc_] \n", - "2 3 [d1gkub1, d1gkub2, d1gkub3, d1gkub4] \n", - "3 4 [d3c9wa2, d3c9wb2, d3c9wa3, d3c9wb3] \n", - "4 5 [d1xwaa1, d1xwab_, d1xwac_, d1xwad_, d1xwaa2] \n", + " id sids \\\n", + "0 1 [d6vi2a2, d6vi2c2, d6vi2a1, d6vi2c1] \n", + "1 2 [d4nzul2, d4nzul1] \n", + "2 3 [d3tv3l2, d6b3dl2, d3tv3l1, d6b3dl1] \n", + "3 4 [d2nw2a2, d2nw2a1] \n", + "4 5 [d7k3ql2, d7r6xd_] \n", "\n", - " sequence class_46456 \\\n", - "0 AAAAAAAAAA False \n", - "1 AAAAAAAAAAAAAAAAAAAAAAA False \n", - "2 AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAASLCLFPEDFLLKEF... False \n", - "3 AAAAAAGPEMVRGQVFDVGPRYTNLSYIGEGAYGMVCSAYDNLNKV... False \n", - "4 AAAAAMVYQVKDKADLDGQLTKASGKLVVLDFFATWCGPCKMISPK... False \n", + " sequence class_48724 class_53931 \\\n", + "0 SDIQMTQSPSSLSASVGDRVTITCRASQSVSSAVAWYQQKPGKAPK... True False \n", + "1 DIEMTQSPSSLSASTGDKVTITCQASQDIAKFLDWYQQRPGKTPKL... True False \n", + "2 QSALTQPPSASGSPGQSITISCTGTSNNFVSWYQQHAGKAPKLVIY... True False \n", + "3 QNIDQPTEMTATEGAIVQINCTYQTSGFNGLFWYQQHAGEAPTFLS... True False \n", + "4 DIVLTQTPGTLSLSPGERATLSCRASQSVSSSYLAWYQQKPGQAPR... True False \n", "\n", - " class_48724 class_51349 class_53931 class_56572 class_56835 \\\n", - "0 True False False False False \n", - "1 False False False False True \n", - "2 False True False True False \n", - "3 False False True False False \n", - "4 False True False False False \n", + " class_310555 fold_48725 fold_56111 fold_56234 fold_310573 ... \\\n", + "0 False True False False False ... \n", + "1 False True False False False ... \n", + "2 False True False False False ... \n", + "3 False True False False False ... \n", + "4 False True False False False ... \n", "\n", - " class_56992 ... protein_56252 protein_190144 protein_310895 \\\n", - "0 False ... False False False \n", - "1 False ... False False False \n", - "2 False ... False False False \n", - "3 False ... False False False \n", - "4 False ... False False False \n", + " protein_190417 protein_190740 protein_310894 protein_310895 species_56254 \\\n", + "0 False True False False False \n", + "1 False True False False False \n", + "2 False True False False False \n", + "3 False True False False False \n", + "4 False True False False False \n", "\n", - " protein_310894 species_187221 species_187920 species_187294 \\\n", - "0 False False False False \n", - "1 False False False False \n", - "2 True False False False \n", - "3 True False False False \n", - "4 True False False False \n", + " species_187221 species_187294 species_187920 species_311501 species_311502 \n", + "0 True False True False False \n", + "1 True False True False False \n", + "2 True False True False False \n", + "3 True False True False False \n", + "4 True False True False False \n", "\n", - " species_56254 species_311502 species_311501 \n", - "0 False False False \n", - "1 False False False \n", - "2 False False True \n", - "3 False False True \n", - "4 False False True \n", - "\n", - "[5 rows x 104 columns]" + "[5 rows x 31 columns]" ] }, - "execution_count": 7, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } @@ -730,7 +723,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 6, "id": "977ddd83-b469-4b58-ab1a-8574fb8769b4", "metadata": { "ExecuteTime": { @@ -745,7 +738,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 7, "id": "3266ade9-efdc-49fe-ae07-ed52b2eb52d0", "metadata": { "ExecuteTime": { @@ -774,7 +767,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 8, "id": "84cfa3e6-f60d-47c0-9f82-db3d5673d1e7", "metadata": { "ExecuteTime": { @@ -787,18 +780,10 @@ "name": "stdout", "output_type": "stream", "text": [ - "{'features': [14, 14, 14, 14, 14, 14, 14, 14, 14, 14], 'labels': array([False, True, False, False, False, False, False, False, False,\n", - " False, False, False, False, False, False, False, False, False,\n", - " False, False, False, False, False, False, True, False, False,\n", - " False, False, False, False, False, False, False, False, False,\n", - " False, False, False, False, False, False, False, False, False,\n", - " False, False, False, False, False, False, False, False, False,\n", - " False, False, False, False, False, False, False, False, False,\n", - " False, False, False, False, False, False, False, False, False,\n", - " False, False, False, False, False, False, False, False, False,\n", - " False, False, False, False, False, False, False, False, False,\n", - " False, False, False, False, False, False, False, False, False,\n", - " False, False]), 'ident': 1, 'group': None}\n" + "{'features': [11, 20, 12, 18, 10, 15, 18, 11, 23, 11, 11, 17, 11, 14, 11, 28, 13, 20, 16, 28, 15, 12, 15, 24, 16, 14, 11, 18, 11, 28, 11, 11, 14, 28, 14, 26, 22, 18, 18, 21, 23, 13, 21, 14, 23, 21, 17, 17, 12, 22, 11, 14, 11, 11, 17, 22, 11, 13, 28, 23, 11, 16, 25, 11, 13, 11, 16, 11, 13, 15, 20, 25, 15, 17, 15, 12, 11, 11, 17, 18, 23, 27, 20, 25, 14, 15, 22, 22, 24, 18, 18, 11, 28, 11, 22, 10, 13, 23, 17, 15, 25, 13, 18, 13, 15, 21, 28, 27, 12, 21, 16, 15, 28, 14, 14, 23, 11, 28, 25, 12, 25, 23, 23, 11, 20, 11, 18, 17, 21, 11, 13, 15, 14, 11, 28, 28, 24, 17, 17, 19, 19, 25, 22, 23, 16, 27, 14, 21, 28, 18, 26, 21, 28, 20, 19, 14, 17, 18, 11, 13, 19, 11, 18, 27, 11, 28, 15, 27, 18, 20, 11, 21, 20, 11, 15, 22, 11, 17, 11, 11, 15, 17, 15, 17, 11, 21, 14, 20, 22, 27, 21, 29, 21, 28, 22, 14, 24, 27, 28, 15, 29, 18, 13, 17, 11, 11, 23, 28, 15, 21, 11, 25, 19, 16, 13], 'labels': array([ True, False, False, True, False, False, False, True, False,\n", + " False, False, True, False, False, True, False, False, True,\n", + " False, True, False, False, False, True, False, True, False,\n", + " False]), 'ident': 1, 'group': None}\n" ] } ], @@ -841,7 +826,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 9, "id": "8d1fbe6c-beb8-4038-93d4-c56bc7628716", "metadata": { "ExecuteTime": { @@ -854,21 +839,21 @@ "name": "stdout", "output_type": "stream", "text": [ - "class_46456\n", "class_48724\n", - "class_51349\n", "class_53931\n", - "class_56572\n", - "class_56835\n", - "class_56992\n", - "class_57942\n", - "class_58117\n", - "class_58231\n", "class_310555\n", - "fold_46457\n", - "fold_46688\n", - "fold_47239\n", - "fold_47363\n" + "fold_48725\n", + "fold_56111\n", + "fold_56234\n", + "fold_310573\n", + "superfamily_48726\n", + "superfamily_56112\n", + "superfamily_56235\n", + "superfamily_310607\n", + "family_48942\n", + "family_56251\n", + "family_191359\n", + "family_191470\n" ] } ], @@ -904,7 +889,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 10, "id": "3ebdcae4-4344-46bd-8fc0-a82ef5d40da5", "metadata": { "ExecuteTime": { @@ -946,7 +931,7 @@ " \n", " \n", " 1\n", - " 2\n", + " 3\n", " train\n", " \n", " \n", @@ -956,12 +941,12 @@ " \n", " \n", " 3\n", - " 5\n", + " 6\n", " train\n", " \n", " \n", " 4\n", - " 6\n", + " 9\n", " train\n", " \n", " \n", @@ -971,13 +956,13 @@ "text/plain": [ " id split\n", "0 1 train\n", - "1 2 train\n", + "1 3 train\n", "2 4 train\n", - "3 5 train\n", - "4 6 train" + "3 6 train\n", + "4 9 train" ] }, - "execution_count": 19, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -1001,7 +986,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 11, "id": "6dc3fd6c-7cf6-47ef-812f-54319a0cdeb9", "metadata": {}, "outputs": [], @@ -1051,7 +1036,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 12, "id": "da47d47e-4560-46af-b246-235596f27d82", "metadata": {}, "outputs": [], @@ -1061,7 +1046,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 13, "id": "8bdbf309-29ec-4aab-a6dc-9e09bc6961a2", "metadata": {}, "outputs": [], @@ -1072,7 +1057,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 14, "id": "68e5c87c-79c3-4d5f-91e6-635399a84d3d", "metadata": {}, "outputs": [ @@ -1112,9 +1097,9 @@ ], "metadata": { "kernelspec": { - "display_name": "Python (env_chebai)", + "display_name": "venv312c", "language": "python", - "name": "env_chebai" + "name": "python3" }, "language_info": { "codemirror_mode": { @@ -1126,7 +1111,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.14" + "version": "3.12.8" } }, "nbformat": 4, From 36e6162ce6898a8b363567a86ed6f81c6b868c15 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 5 Mar 2025 13:50:59 +0100 Subject: [PATCH 61/71] ffn: fix error for loss kwargs --- chebai/models/ffn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chebai/models/ffn.py b/chebai/models/ffn.py index cd32086e..c9c6f912 100644 --- a/chebai/models/ffn.py +++ b/chebai/models/ffn.py @@ -37,7 +37,7 @@ def _get_prediction_and_labels(self, data, labels, model_output): loss_kwargs = data.get("loss_kwargs", dict()) if "non_null_labels" in loss_kwargs: n = loss_kwargs["non_null_labels"] - d = data[n] + d = d[n] return torch.sigmoid(d), labels.int() if labels is not None else None def _process_for_loss( From 93c7fc5080388ed0f278fb63c5f6604a36fd4318 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 9 Mar 2025 00:54:56 +0100 Subject: [PATCH 62/71] scope: fix for no True labels for some classes/columns --- chebai/preprocessing/datasets/scope/scope.py | 64 +++++++++++++++----- 1 file changed, 48 insertions(+), 16 deletions(-) diff --git a/chebai/preprocessing/datasets/scope/scope.py b/chebai/preprocessing/datasets/scope/scope.py index 334ff20e..fe9b5e33 100644 --- a/chebai/preprocessing/datasets/scope/scope.py +++ b/chebai/preprocessing/datasets/scope/scope.py @@ -195,22 +195,52 @@ def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph: """ print("Extracting class hierarchy...") df_scope = self._get_scope_data() + pdb_chain_df = self._parse_pdb_sequence_file() + pdb_id_set = set(pdb_chain_df["pdb_id"]) # Search time complexity - O(1) g = nx.DiGraph() - egdes = [] - for _, row in df_scope.iterrows(): - g.add_node(row["sunid"], **{"sid": row["sid"], "level": row["level"]}) - if row["parent_sunid"] != -1: - egdes.append((row["parent_sunid"], row["sunid"])) + edges = [] + node_attrs = {} + px_level_nodes = set() + + # Step 1: Build the graph and store attributes + for row in df_scope.itertuples(index=False): + if row.level == "px": + if row.sid[1:5] not in pdb_id_set: + # Don't add domain level nodes that don't have pdb_id in pdb_sequences.txt file + continue + px_level_nodes.add(row.sunid) + + node_attrs[row.sunid] = {"sid": row.sid, "level": row.level} + + if row.parent_sunid != -1: + edges.append((row.parent_sunid, row.sunid)) - for children_id in row["children_sunids"]: - egdes.append((row["sunid"], children_id)) + for child_id in row.children_sunids: + edges.append((row.sunid, child_id)) - g.add_edges_from(egdes) + g.add_nodes_from((node, attrs) for node, attrs in node_attrs.items()) + g.add_edges_from(edges) + # Step 2: Compute the transitive closure first print("Computing transitive closure") - return nx.transitive_closure_dag(g) + g_tc = nx.transitive_closure_dag(g) + + print( + "Remove node without domain descendants that don't have pdb correspondence" + ) + # Step 3: Identify and remove nodes that don’t have a "px" descendant with correspondence to pdb_sequences file + nodes_to_remove = set() + for node in g_tc.nodes: + if node not in px_level_nodes and not any( + desc in px_level_nodes for desc in g_tc.successors(node) + ): + nodes_to_remove.add(node) + + g_tc.remove_nodes_from(nodes_to_remove) + + return g_tc def _get_scope_data(self) -> pd.DataFrame: """ @@ -388,7 +418,8 @@ def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame: encoded_target_columns = [] for level in hierarchy_levels: - encoded_target_columns.extend(lvl_to_target_cols_mapping[level]) + if level in lvl_to_target_cols_mapping: + encoded_target_columns.extend(lvl_to_target_cols_mapping[level]) print( f"{len(encoded_target_columns)} labels has been selected for specified threshold, " @@ -471,12 +502,12 @@ def _parse_pdb_sequence_file(self) -> pd.DataFrame: for record in SeqIO.parse( os.path.join(self.scope_root_dir, self.raw_file_names_dict["PDB"]), "fasta" ): + + if not record.seq: + continue + pdb_id, chain = record.id.split("_") - sequence = ( - re.sub(f"[^{valid_amino_acids}]", "X", str(record.seq)) - if record.seq - else "" - ) + sequence = re.sub(f"[^{valid_amino_acids}]", "X", str(record.seq)) # Store as a dictionary entry (list of dicts -> DataFrame later) records.append( @@ -876,7 +907,8 @@ class SCOPeOverPartial2000(_SCOPeOverXPartial): if __name__ == "__main__": - scope = SCOPeOver2000(scope_version="2.08") + scope = SCOPeOver50(scope_version="2.08") + # g = scope._extract_class_hierarchy("dummy/path") # # Save graph # import pickle From 767b21010ca2c5ec37249b36240153a94b1b88fb Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 15 Mar 2025 12:22:26 +0100 Subject: [PATCH 63/71] scope: fix for true values less given threshold for some labels --- chebai/preprocessing/datasets/scope/scope.py | 99 ++++++++++++++------ 1 file changed, 72 insertions(+), 27 deletions(-) diff --git a/chebai/preprocessing/datasets/scope/scope.py b/chebai/preprocessing/datasets/scope/scope.py index fe9b5e33..3887ea9e 100644 --- a/chebai/preprocessing/datasets/scope/scope.py +++ b/chebai/preprocessing/datasets/scope/scope.py @@ -198,49 +198,91 @@ def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph: pdb_chain_df = self._parse_pdb_sequence_file() pdb_id_set = set(pdb_chain_df["pdb_id"]) # Search time complexity - O(1) - g = nx.DiGraph() - - edges = [] + # Initialize sets and dictionaries for storing edges and attributes + parent_node_edges, node_child_edges = set(), set() node_attrs = {} px_level_nodes = set() + sequence_nodes = dict() + px_to_seq_edges = set() + required_graph_nodes = set() + + # Create a lookup dictionary for PDB chain sequences + lookup_dict = ( + pdb_chain_df.groupby("pdb_id")[["chain_id", "sequence"]] + .apply(lambda x: dict(zip(x["chain_id"], x["sequence"]))) + .to_dict() + ) + + def add_sequence_nodes_edges(chain_sequence, px_sun_id): + """Adds sequence nodes and edges connecting px-level nodes to sequence nodes.""" + if chain_sequence not in sequence_nodes: + sequence_nodes[chain_sequence] = f"seq_{len(sequence_nodes)}" + px_to_seq_edges.add((px_sun_id, sequence_nodes[chain_sequence])) - # Step 1: Build the graph and store attributes + # Step 1: Build the graph structure and store node attributes for row in df_scope.itertuples(index=False): if row.level == "px": - if row.sid[1:5] not in pdb_id_set: + + pdb_id, chain_id = row.sid[1:5], row.sid[5] + + if pdb_id not in pdb_id_set or chain_id == "_": # Don't add domain level nodes that don't have pdb_id in pdb_sequences.txt file + # Also chain_id with "_" which corresponds to no chain continue px_level_nodes.add(row.sunid) + # Add edges between px-level nodes and sequence nodes + if chain_id != ".": + if chain_id not in lookup_dict[pdb_id]: + continue + add_sequence_nodes_edges(lookup_dict[pdb_id][chain_id], row.sunid) + else: + # If chain_id is '.', connect all chains of this PDB ID + for chain, chain_sequence in lookup_dict[pdb_id].items(): + add_sequence_nodes_edges(chain_sequence, row.sunid) + else: + required_graph_nodes.add(row.sunid) + node_attrs[row.sunid] = {"sid": row.sid, "level": row.level} if row.parent_sunid != -1: - edges.append((row.parent_sunid, row.sunid)) + parent_node_edges.add((row.parent_sunid, row.sunid)) for child_id in row.children_sunids: - edges.append((row.sunid, child_id)) + node_child_edges.add((row.sunid, child_id)) - g.add_nodes_from((node, attrs) for node, attrs in node_attrs.items()) - g.add_edges_from(edges) + del df_scope, pdb_chain_df, pdb_id_set - # Step 2: Compute the transitive closure first - print("Computing transitive closure") - g_tc = nx.transitive_closure_dag(g) - - print( - "Remove node without domain descendants that don't have pdb correspondence" + g = nx.DiGraph() + g.add_nodes_from(node_attrs.items()) + # Note - `add_edges` internally create a node, if a node doesn't exist already + g.add_edges_from({(p, c) for p, c in parent_node_edges if p in node_attrs}) + g.add_edges_from({(p, c) for p, c in node_child_edges if c in node_attrs}) + + seq_nodes = set(sequence_nodes.values()) + g.add_nodes_from([(seq_id, {"level": "sequence"}) for seq_id in seq_nodes]) + g.add_edges_from( + { + (px_node, seq_node) + for px_node, seq_node in px_to_seq_edges + if px_node in node_attrs and seq_node in seq_nodes + } ) - # Step 3: Identify and remove nodes that don’t have a "px" descendant with correspondence to pdb_sequences file - nodes_to_remove = set() - for node in g_tc.nodes: - if node not in px_level_nodes and not any( - desc in px_level_nodes for desc in g_tc.successors(node) - ): - nodes_to_remove.add(node) - g_tc.remove_nodes_from(nodes_to_remove) + # Step 2: Count sequence successors for required graph nodes only + for node in required_graph_nodes: + num_seq_successors = sum( + g.nodes[child]["level"] == "sequence" + for child in nx.descendants(g, node) + ) + g.nodes[node]["num_seq_successors"] = num_seq_successors - return g_tc + # Step 3: Remove nodes which are not required before computing transitive closure for better efficiency + g.remove_nodes_from(px_level_nodes | seq_nodes) + + print("Computing Transitive Closure.........") + # Transitive closure is not needed in `select_classes` method but is required in _SCOPeOverXPartial + return nx.transitive_closure_dag(g) def _get_scope_data(self) -> pd.DataFrame: """ @@ -808,12 +850,15 @@ def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> Dict[str, List[int]] """ selected_sunids_for_level = {} for node, attr_dict in g.nodes(data=True): - if g.out_degree(node) >= self.THRESHOLD: + if attr_dict["level"] in {"root", "px", "sequence"}: + # Skip nodes with level "root", "px", or "sequence" + continue + + # Check if the number of "sequence"-level successors meets or exceeds the threshold + if g.nodes[node]["num_seq_successors"] >= self.THRESHOLD: selected_sunids_for_level.setdefault(attr_dict["level"], []).append( node ) - # Remove root node, as it will True for all instances - selected_sunids_for_level.pop("root", None) return selected_sunids_for_level From 081b44d00eeb355640200914d6ad278af08cfc15 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 15 Mar 2025 12:23:33 +0100 Subject: [PATCH 64/71] go_notebook: update import statement --- tutorials/data_exploration_go.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tutorials/data_exploration_go.ipynb b/tutorials/data_exploration_go.ipynb index 1a205e37..6f67c82b 100644 --- a/tutorials/data_exploration_go.ipynb +++ b/tutorials/data_exploration_go.ipynb @@ -70,7 +70,7 @@ } }, "outputs": [], - "source": "from chebai.preprocessing.datasets.deepGO.go_uniprot import GOUniProtOver250" + "source": "from chebai.preprocessing.datasets.go_uniprot import GOUniProtOver250" }, { "cell_type": "code", From 81c1348eb29f9e88dd91cd2f134081a0306da902 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 15 Mar 2025 12:55:21 +0100 Subject: [PATCH 65/71] scope notebook: add scope description and minor changes --- tutorials/data_exploration_scope.ipynb | 279 +++++++++++++++---------- 1 file changed, 171 insertions(+), 108 deletions(-) diff --git a/tutorials/data_exploration_scope.ipynb b/tutorials/data_exploration_scope.ipynb index a67251cc..c14046ac 100644 --- a/tutorials/data_exploration_scope.ipynb +++ b/tutorials/data_exploration_scope.ipynb @@ -17,6 +17,74 @@ "---\n" ] }, + { + "cell_type": "markdown", + "id": "cca637ce-d4ea-4365-acd9-657418e0640f", + "metadata": {}, + "source": [ + "### Overview of SCOPe Data and its Usage in Protein-Related Tasks\n", + "\n", + "#### **What is SCOPe?**\n", + "\n", + "The **Structural Classification of Proteins — extended (SCOPe)** is a comprehensive database that extends the original SCOP (Structural Classification of Proteins) database. SCOPe offers a detailed classification of protein domains based on their structural and evolutionary relationships.\n", + "\n", + "The SCOPe database, like SCOP, organizes proteins into a hierarchy of domains based on structural similarities, which is crucial for understanding evolutionary patterns and functional aspects of proteins. This hierarchical structure is comparable to taxonomy in biology, where species are classified based on shared characteristics.\n", + "\n", + "#### **SCOPe Hierarchy:**\n", + "By analogy with taxonomy, SCOP was created as a hierarchy of several levels where the fundamental unit of classification is a **domain** in the experimentally determined protein structure. Starting at the bottom, the hierarchy of SCOP domains comprises the following levels:\n", + "\n", + "1. **Species**: Representing distinct protein sequences and their naturally occurring or artificially created variants.\n", + "2. **Protein**: Groups together similar sequences with essentially the same functions. These can originate from different biological species or represent isoforms within the same species.\n", + "3. **Family**: Contains proteins with similar sequences but typically distinct functions.\n", + "4. **Superfamily**: Bridges protein families with common functional and structural features, often inferred from a shared evolutionary ancestor.\n", + "5. **Fold**: Groups structurally similar superfamilies. \n", + "6. **Class**: Based on secondary structure content and organization. This level classifies proteins based on their secondary structure properties, such as alpha-helices and beta-sheets.\n", + "\n", + "\n", + "\n", + "For more details, you can refer to the [SCOPe documentation](https://scop.berkeley.edu/help/ver=2.08).\n", + "\n", + "---\n", + "\n", + "#### **Why are We Using SCOPe?**\n", + "\n", + "We are integrating the SCOPe data into our pipeline as part of an ontology pretraining task for protein-related models. SCOPe is a great fit for our goal because it is primarily **structure-based**, unlike other protein-related databases like Gene Ontology (GO), which focuses more on functional classes.\n", + "\n", + "Our primary objective is to reproduce **ontology pretraining** on a protein-related task, and SCOPe provides the structural ontology that we need for this. The steps in our pipeline are aligned as follows:\n", + "\n", + "| **Stage** | **Chemistry Task** | **Proteins Task** |\n", + "|--------------------------|-------------------------------------|------------------------------------------------|\n", + "| **Unsupervised Pretraining** | Mask pretraining (ELECTRA) | Mask pretraining (ESM2, optional) |\n", + "| **Ontology Pretraining** | ChEBI | SCOPe |\n", + "| **Finetuning Task** | Toxicity, Solubility, etc. | GO (MF, BP, CC branches) |\n", + "\n", + " \n", + "This integration will allow us to use **SCOPe** for tasks such as **protein classification** and will contribute to the success of **pretraining models** for protein structures. The data will be processed with the same approach as the GO data, with **different labels** corresponding to the SCOPe classification system.\n", + "\n", + "---\n", + "\n", + "#### **Why SCOPe is Suitable for Our Task**\n", + "\n", + "1. **Structure-Based Classification**: SCOPe is primarily concerned with the structural characteristics of proteins, making it ideal for protein structure pretraining tasks. This contrasts with other ontology databases like **GO**, which categorize proteins based on more complex functional relationships.\n", + " \n", + "2. **Manageable Size**: SCOPe contains around **140,000 entries**, making it a manageable dataset for training models. This is similar in size to **ChEBI**, which is used in the chemical domain, and ensures we can work with it effectively for pretraining." + ] + }, + { + "cell_type": "markdown", + "id": "338e452f-426c-493d-bec2-5bd51e24e4aa", + "metadata": {}, + "source": [ + "\n", + "### Protein Data Bank (PDB)\n", + "\n", + "The **Protein Data Bank (PDB)** is a global repository that stores 3D structural data of biological macromolecules like proteins and nucleic acids. It contains information obtained through experimental methods such as **X-ray crystallography**, **NMR spectroscopy**, and **cryo-EM**. The data includes atomic coordinates, secondary structure details, and experimental conditions.\n", + "\n", + "The PDB is an essential resource for **structural biology**, **bioinformatics**, and **drug discovery**, enabling scientists to understand protein functions, interactions, and mechanisms at the molecular level.\n", + "\n", + "For more details, visit the [RCSB PDB website](https://www.rcsb.org/).\n" + ] + }, { "cell_type": "markdown", "id": "f6c25706-251c-438c-9915-e8002647eb94", @@ -117,7 +185,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Changed to project root directory: c:\\Users\\sifluegel\\PycharmProjects\\python-chebai\n" + "Changed to project root directory: G:\\github-aditya0by0\\python-chebai\n" ] } ], @@ -151,26 +219,17 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 6, "id": "f3a66e07-edc9-4aa2-9cd0-d4ea58914d22", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "c:\\Users\\sifluegel\\PycharmProjects\\python-chebai\\venv312c\\Lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" - ] - } - ], + "outputs": [], "source": [ - "from chebai.preprocessing.datasets.scope.scope import SCOPeOver2000" + "from chebai.preprocessing.datasets.scope.scope import SCOPeOver50" ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 7, "id": "a71b7301-6195-4155-a439-f5eb3183d0f3", "metadata": { "ExecuteTime": { @@ -180,7 +239,7 @@ }, "outputs": [], "source": [ - "scope_class = SCOPeOver2000(scope_version=\"2.08\")" + "scope_class = SCOPeOver50(scope_version=\"2.08\")" ] }, { @@ -259,7 +318,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Checking for processed data in data\\SCOPe\\version_2.08\\SCOPe2000\\processed\n", + "Checking for processed data in data\\SCOPe\\version_2.08\\SCOPe50\\processed\n", "Missing processed data file (`data.pkl` file)\n", "Missing PDB raw data, Downloading PDB sequence data....\n", "Downloading to temporary file C:\\Users\\HP\\AppData\\Local\\Temp\\tmpsif7r129\n", @@ -295,7 +354,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "Check for processed data in data\\SCOPe\\version_2.08\\SCOPe2000\\processed\\protein_token\n", + "Check for processed data in data\\SCOPe\\version_2.08\\SCOPe50\\processed\\protein_token\n", "Cross-validation enabled: False\n" ] }, @@ -433,7 +492,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "id": "fd490270-59b8-4c1c-8b09-204defddf592", "metadata": { "ExecuteTime": { @@ -449,7 +508,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 8, "id": "d7d16247-092c-4e8d-96c2-ab23931cf766", "metadata": { "ExecuteTime": { @@ -462,7 +521,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Size of the data (rows x columns): (14210, 31)\n" + "Size of the data (rows x columns): (60424, 1035)\n" ] }, { @@ -489,105 +548,105 @@ " id\n", " sids\n", " sequence\n", + " class_46456\n", " class_48724\n", + " class_51349\n", " class_53931\n", - " class_310555\n", - " fold_48725\n", - " fold_56111\n", - " fold_56234\n", - " fold_310573\n", + " class_56572\n", + " class_56835\n", + " class_56992\n", " ...\n", - " protein_190417\n", - " protein_190740\n", - " protein_310894\n", - " protein_310895\n", - " species_56254\n", - " species_187221\n", " species_187294\n", - " species_187920\n", - " species_311501\n", + " species_56257\n", + " species_186882\n", + " species_56690\n", + " species_161316\n", + " species_57962\n", + " species_58067\n", + " species_267696\n", " species_311502\n", + " species_311501\n", " \n", " \n", " \n", " \n", " 0\n", " 1\n", - " [d6vi2a2, d6vi2c2, d6vi2a1, d6vi2c1]\n", - " SDIQMTQSPSSLSASVGDRVTITCRASQSVSSAVAWYQQKPGKAPK...\n", + " [d4oq9a_, d4oq9b_, d4oq9c_, d4oq9d_, d4niaa_, ...\n", + " AAAAAAAAAA\n", + " False\n", " True\n", " False\n", " False\n", - " True\n", " False\n", " False\n", " False\n", " ...\n", " False\n", - " True\n", " False\n", " False\n", " False\n", - " True\n", " False\n", - " True\n", + " False\n", + " False\n", + " False\n", " False\n", " False\n", " \n", " \n", " 1\n", " 2\n", - " [d4nzul2, d4nzul1]\n", - " DIEMTQSPSSLSASTGDKVTITCQASQDIAKFLDWYQQRPGKTPKL...\n", - " True\n", + " [d7dxhc_]\n", + " AAAAAAAAAAAAAAAAAAAAAAA\n", + " False\n", " False\n", " False\n", - " True\n", " False\n", " False\n", + " True\n", " False\n", " ...\n", " False\n", - " True\n", " False\n", " False\n", " False\n", - " True\n", " False\n", - " True\n", + " False\n", + " False\n", + " False\n", " False\n", " False\n", " \n", " \n", " 2\n", " 3\n", - " [d3tv3l2, d6b3dl2, d3tv3l1, d6b3dl1]\n", - " QSALTQPPSASGSPGQSITISCTGTSNNFVSWYQQHAGKAPKLVIY...\n", - " True\n", + " [d1gkub1, d1gkub2, d1gkub3, d1gkub4]\n", + " AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAASLCLFPEDFLLKEF...\n", " False\n", " False\n", " True\n", " False\n", + " True\n", " False\n", " False\n", " ...\n", " False\n", - " True\n", " False\n", " False\n", " False\n", - " True\n", " False\n", - " True\n", " False\n", " False\n", + " False\n", + " False\n", + " True\n", " \n", " \n", " 3\n", " 4\n", - " [d2nw2a2, d2nw2a1]\n", - " QNIDQPTEMTATEGAIVQINCTYQTSGFNGLFWYQQHAGEAPTFLS...\n", - " True\n", + " [d3c9wa2, d3c9wb2, d3c9wa3, d3c9wb3]\n", + " AAAAAAGPEMVRGQVFDVGPRYTNLSYIGEGAYGMVCSAYDNLNKV...\n", + " False\n", " False\n", " False\n", " True\n", @@ -596,85 +655,92 @@ " False\n", " ...\n", " False\n", - " True\n", " False\n", " False\n", " False\n", - " True\n", " False\n", - " True\n", " False\n", " False\n", + " False\n", + " False\n", + " True\n", " \n", " \n", " 4\n", " 5\n", - " [d7k3ql2, d7r6xd_]\n", - " DIVLTQTPGTLSLSPGERATLSCRASQSVSSSYLAWYQQKPGQAPR...\n", - " True\n", + " [d1xwaa1, d1xwab_, d1xwac_, d1xwad_, d1xwaa2]\n", + " AAAAAMVYQVKDKADLDGQLTKASGKLVVLDFFATWCGPCKMISPK...\n", " False\n", " False\n", " True\n", " False\n", " False\n", " False\n", + " False\n", " ...\n", " False\n", - " True\n", " False\n", " False\n", " False\n", - " True\n", " False\n", - " True\n", " False\n", " False\n", + " False\n", + " False\n", + " True\n", " \n", " \n", "\n", - "

5 rows × 31 columns

\n", + "

5 rows × 1035 columns

\n", "" ], "text/plain": [ - " id sids \\\n", - "0 1 [d6vi2a2, d6vi2c2, d6vi2a1, d6vi2c1] \n", - "1 2 [d4nzul2, d4nzul1] \n", - "2 3 [d3tv3l2, d6b3dl2, d3tv3l1, d6b3dl1] \n", - "3 4 [d2nw2a2, d2nw2a1] \n", - "4 5 [d7k3ql2, d7r6xd_] \n", + " id sids \\\n", + "0 1 [d4oq9a_, d4oq9b_, d4oq9c_, d4oq9d_, d4niaa_, ... \n", + "1 2 [d7dxhc_] \n", + "2 3 [d1gkub1, d1gkub2, d1gkub3, d1gkub4] \n", + "3 4 [d3c9wa2, d3c9wb2, d3c9wa3, d3c9wb3] \n", + "4 5 [d1xwaa1, d1xwab_, d1xwac_, d1xwad_, d1xwaa2] \n", + "\n", + " sequence class_46456 \\\n", + "0 AAAAAAAAAA False \n", + "1 AAAAAAAAAAAAAAAAAAAAAAA False \n", + "2 AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAASLCLFPEDFLLKEF... False \n", + "3 AAAAAAGPEMVRGQVFDVGPRYTNLSYIGEGAYGMVCSAYDNLNKV... False \n", + "4 AAAAAMVYQVKDKADLDGQLTKASGKLVVLDFFATWCGPCKMISPK... False \n", "\n", - " sequence class_48724 class_53931 \\\n", - "0 SDIQMTQSPSSLSASVGDRVTITCRASQSVSSAVAWYQQKPGKAPK... True False \n", - "1 DIEMTQSPSSLSASTGDKVTITCQASQDIAKFLDWYQQRPGKTPKL... True False \n", - "2 QSALTQPPSASGSPGQSITISCTGTSNNFVSWYQQHAGKAPKLVIY... True False \n", - "3 QNIDQPTEMTATEGAIVQINCTYQTSGFNGLFWYQQHAGEAPTFLS... True False \n", - "4 DIVLTQTPGTLSLSPGERATLSCRASQSVSSSYLAWYQQKPGQAPR... True False \n", + " class_48724 class_51349 class_53931 class_56572 class_56835 \\\n", + "0 True False False False False \n", + "1 False False False False True \n", + "2 False True False True False \n", + "3 False False True False False \n", + "4 False True False False False \n", "\n", - " class_310555 fold_48725 fold_56111 fold_56234 fold_310573 ... \\\n", - "0 False True False False False ... \n", - "1 False True False False False ... \n", - "2 False True False False False ... \n", - "3 False True False False False ... \n", - "4 False True False False False ... \n", + " class_56992 ... species_187294 species_56257 species_186882 \\\n", + "0 False ... False False False \n", + "1 False ... False False False \n", + "2 False ... False False False \n", + "3 False ... False False False \n", + "4 False ... False False False \n", "\n", - " protein_190417 protein_190740 protein_310894 protein_310895 species_56254 \\\n", - "0 False True False False False \n", - "1 False True False False False \n", - "2 False True False False False \n", - "3 False True False False False \n", - "4 False True False False False \n", + " species_56690 species_161316 species_57962 species_58067 \\\n", + "0 False False False False \n", + "1 False False False False \n", + "2 False False False False \n", + "3 False False False False \n", + "4 False False False False \n", "\n", - " species_187221 species_187294 species_187920 species_311501 species_311502 \n", - "0 True False True False False \n", - "1 True False True False False \n", - "2 True False True False False \n", - "3 True False True False False \n", - "4 True False True False False \n", + " species_267696 species_311502 species_311501 \n", + "0 False False False \n", + "1 False False False \n", + "2 False False True \n", + "3 False False True \n", + "4 False False True \n", "\n", - "[5 rows x 31 columns]" + "[5 rows x 1035 columns]" ] }, - "execution_count": 5, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } @@ -723,7 +789,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 9, "id": "977ddd83-b469-4b58-ab1a-8574fb8769b4", "metadata": { "ExecuteTime": { @@ -738,7 +804,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 10, "id": "3266ade9-efdc-49fe-ae07-ed52b2eb52d0", "metadata": { "ExecuteTime": { @@ -767,7 +833,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 12, "id": "84cfa3e6-f60d-47c0-9f82-db3d5673d1e7", "metadata": { "ExecuteTime": { @@ -780,15 +846,12 @@ "name": "stdout", "output_type": "stream", "text": [ - "{'features': [11, 20, 12, 18, 10, 15, 18, 11, 23, 11, 11, 17, 11, 14, 11, 28, 13, 20, 16, 28, 15, 12, 15, 24, 16, 14, 11, 18, 11, 28, 11, 11, 14, 28, 14, 26, 22, 18, 18, 21, 23, 13, 21, 14, 23, 21, 17, 17, 12, 22, 11, 14, 11, 11, 17, 22, 11, 13, 28, 23, 11, 16, 25, 11, 13, 11, 16, 11, 13, 15, 20, 25, 15, 17, 15, 12, 11, 11, 17, 18, 23, 27, 20, 25, 14, 15, 22, 22, 24, 18, 18, 11, 28, 11, 22, 10, 13, 23, 17, 15, 25, 13, 18, 13, 15, 21, 28, 27, 12, 21, 16, 15, 28, 14, 14, 23, 11, 28, 25, 12, 25, 23, 23, 11, 20, 11, 18, 17, 21, 11, 13, 15, 14, 11, 28, 28, 24, 17, 17, 19, 19, 25, 22, 23, 16, 27, 14, 21, 28, 18, 26, 21, 28, 20, 19, 14, 17, 18, 11, 13, 19, 11, 18, 27, 11, 28, 15, 27, 18, 20, 11, 21, 20, 11, 15, 22, 11, 17, 11, 11, 15, 17, 15, 17, 11, 21, 14, 20, 22, 27, 21, 29, 21, 28, 22, 14, 24, 27, 28, 15, 29, 18, 13, 17, 11, 11, 23, 28, 15, 21, 11, 25, 19, 16, 13], 'labels': array([ True, False, False, True, False, False, False, True, False,\n", - " False, False, True, False, False, True, False, False, True,\n", - " False, True, False, False, False, True, False, True, False,\n", - " False]), 'ident': 1, 'group': None}\n" + "{'features': [14, 14, 14, 14, 20, 15, 15, 28, 15, 18, 25, 17, 18, 11, 25, 21, 27, 19, 14, 27, 19, 13, 14, 17, 16, 21, 25, 22, 27, 28, 12, 10, 20, 19, 13, 13, 14, 28, 17, 20, 20, 12, 19, 11, 17, 15, 27, 28, 15, 12, 17, 14, 23, 11, 19, 27, 14, 26, 19, 11, 11, 19, 12, 19, 19, 28, 17, 16, 20, 16, 19, 21, 10, 16, 18, 12, 17, 19, 10, 29, 12, 12, 21, 20, 16, 17, 19, 28, 20, 21, 12, 16, 18, 21, 19, 14, 19, 17, 12, 14, 18, 28, 23, 15, 28, 19, 19, 19, 15, 25, 17, 22, 25, 19, 28, 16, 13, 27, 13, 11, 20, 15, 28, 12, 15, 28, 27, 13, 13, 13, 28, 19, 14, 15, 28, 12, 18, 14, 20, 28, 14, 18, 15, 19, 13, 22, 28, 29, 12, 12, 20, 29, 28, 17, 13, 28, 23, 22, 15, 15, 28, 17, 13, 21, 17, 27, 11, 20, 23, 10, 10, 11, 20, 15, 22, 21, 10, 13, 21, 25, 11, 29, 25, 19, 20, 18, 17, 19, 19, 15, 18, 16, 16, 25, 15, 22, 25, 28, 23, 16, 20, 21, 13, 26, 18, 21, 15, 27, 17, 20, 22, 23, 11, 14, 29, 21, 21, 17, 25, 10, 14, 20, 25, 11, 22, 29, 11, 21, 11, 12, 17, 27, 16, 29, 17, 14, 12, 11, 20, 21, 27, 22, 15, 10, 21, 20, 17, 28, 21, 25, 11, 18, 27, 11, 13, 11, 28, 12, 17, 23, 15, 25, 16, 20, 11, 17, 11, 12, 16, 28, 27, 27, 27, 14, 13, 16, 22, 28, 12, 12, 26, 19, 22, 21, 21, 12, 19, 28, 22, 16, 23, 20, 28, 27, 24, 15, 19, 13, 12, 12, 29, 28, 12, 20, 22, 23, 17, 17, 27, 27, 21, 20, 28, 28, 28, 14, 13, 13, 11, 14, 14, 14, 14, 14], 'labels': array([False, True, False, ..., False, False, False]), 'ident': 6, 'group': None}\n" ] } ], "source": [ - "for i in range(1):\n", + "for i in range(5, 6):\n", " print(data_pt[i])" ] }, @@ -821,7 +884,7 @@ "source": [ "## classes.txt File\n", "\n", - "**Description**: A file containing the list of selected SCOPe classes based on the specified threshold. This file is crucial for ensuring that only relevant classes are included in the dataset." + "**Description**: A file containing the list of selected SCOPe **labels** based on the specified threshold. This file is crucial for ensuring that only relevant **labels** are included in the dataset." ] }, { @@ -1097,7 +1160,7 @@ ], "metadata": { "kernelspec": { - "display_name": "venv312c", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -1111,7 +1174,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.8" + "version": "3.10.14" } }, "nbformat": 4, From 2b0ed0a7e9251e324f6263382b3de50fced76295 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 19 Mar 2025 12:30:34 +0100 Subject: [PATCH 66/71] electra config: increase max_postional_embeddings to 3000 --- configs/model/electra.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/model/electra.yml b/configs/model/electra.yml index ade89acd..23d0185b 100644 --- a/configs/model/electra.yml +++ b/configs/model/electra.yml @@ -4,7 +4,7 @@ init_args: lr: 1e-3 config: vocab_size: 8500 - max_position_embeddings: 1800 + max_position_embeddings: 3000 num_attention_heads: 8 num_hidden_layers: 6 type_vocab_size: 1 From ef4bc0b2c85eacc2b851d9f8dc163e99d51a1d45 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 29 Mar 2025 11:22:54 +0100 Subject: [PATCH 67/71] comment protein-related requirements --- setup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 8a6d3e0c..004e0d56 100644 --- a/setup.py +++ b/setup.py @@ -50,8 +50,8 @@ "chardet", "pyyaml", "torchmetrics", - "biopython", - "fair-esm", + # "biopython", + # "fair-esm", ], extras_require={"dev": ["black", "isort", "pre-commit"]}, ) From 58bcf056481b66f9f62a13c3aac3995690736cd1 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Tue, 1 Apr 2025 19:25:48 +0200 Subject: [PATCH 68/71] Revert "comment protein-related requirements" This reverts commit ef4bc0b2c85eacc2b851d9f8dc163e99d51a1d45. --- setup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 004e0d56..8a6d3e0c 100644 --- a/setup.py +++ b/setup.py @@ -50,8 +50,8 @@ "chardet", "pyyaml", "torchmetrics", - # "biopython", - # "fair-esm", + "biopython", + "fair-esm", ], extras_require={"dev": ["black", "isort", "pre-commit"]}, ) From 831f70d25c8d20369a81da558c76a31d4970acc2 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Tue, 1 Apr 2025 20:05:01 +0200 Subject: [PATCH 69/71] scope: filter out sequence with len gt than given len --- chebai/preprocessing/datasets/scope/scope.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/chebai/preprocessing/datasets/scope/scope.py b/chebai/preprocessing/datasets/scope/scope.py index 3887ea9e..e9127b25 100644 --- a/chebai/preprocessing/datasets/scope/scope.py +++ b/chebai/preprocessing/datasets/scope/scope.py @@ -72,10 +72,12 @@ def __init__( self, scope_version: str, scope_version_train: Optional[str] = None, + max_sequence_len: int = 1000, **kwargs, ): self.scope_version: str = scope_version self.scope_version_train: str = scope_version_train + self.max_sequence_len: int = max_sequence_len super(_SCOPeDataExtractor, self).__init__(**kwargs) @@ -545,7 +547,7 @@ def _parse_pdb_sequence_file(self) -> pd.DataFrame: os.path.join(self.scope_root_dir, self.raw_file_names_dict["PDB"]), "fasta" ): - if not record.seq: + if not record.seq or len(record.seq) > self.max_sequence_len: continue pdb_id, chain = record.id.split("_") From 158d6f3e91608682b75d75e9f951467be8364cdd Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Tue, 1 Apr 2025 20:06:46 +0200 Subject: [PATCH 70/71] Revert "electra config: increase max_postional_embeddings to 3000" This reverts commit 2b0ed0a7e9251e324f6263382b3de50fced76295. --- configs/model/electra.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/model/electra.yml b/configs/model/electra.yml index 23d0185b..ade89acd 100644 --- a/configs/model/electra.yml +++ b/configs/model/electra.yml @@ -4,7 +4,7 @@ init_args: lr: 1e-3 config: vocab_size: 8500 - max_position_embeddings: 3000 + max_position_embeddings: 1800 num_attention_heads: 8 num_hidden_layers: 6 type_vocab_size: 1 From bb0b4db77004475cfe43e4baea6c82e5cc31f8f3 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Tue, 1 Apr 2025 20:20:09 +0200 Subject: [PATCH 71/71] electra config: reset the vocab size to previous default value for scope - the vocab size was increased for proteins in commit a12354b527f670da28ac6b8f200b659d4d67ab43 - as we are going to move protein related code to new repo, revert this to original value --- configs/model/electra.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/model/electra.yml b/configs/model/electra.yml index ade89acd..c3cf2fdf 100644 --- a/configs/model/electra.yml +++ b/configs/model/electra.yml @@ -3,7 +3,7 @@ init_args: optimizer_kwargs: lr: 1e-3 config: - vocab_size: 8500 + vocab_size: 1400 max_position_embeddings: 1800 num_attention_heads: 8 num_hidden_layers: 6