2121
2222
2323class _ProteinPretrainingData (_DynamicDataset , ABC ):
24+ """
25+ Data module for pretraining protein sequences, specifically designed for Swiss-UniProt data. It includes methods for
26+ data preparation, loading, and dynamic splitting of protein sequences.
27+ The data is parsed and filtered to only select proteins with no associated `valid` Gene Ontology (GO) labels.
28+ A valid GO label is one which has one of the evidence codes defined in `EXPERIMENTAL_EVIDENCE_CODES`.
29+ """
30+
2431 _ID_IDX : int = 0
25- _DATA_REPRESENTATION_IDX : int = 1 # here `sequence` column
32+ _DATA_REPRESENTATION_IDX : int = 1 # Index of `sequence` column
2633
2734 def __init__ (self , ** kwargs ):
28- self ._go_extractor = GOUniProtOver250 ()
29- assert self ._go_extractor .go_branch == GOUniProtOver250 ._ALL_GO_BRANCHES
35+ """
36+ Initializes the data module with a `GOUniProtOver250` extractor object that covers all GO branches.
37+
38+ Args:
39+ **kwargs: Additional arguments for the superclass initialization.
40+ """
41+ self ._go_uniprot_extractor = GOUniProtOver250 ()
42+ assert self ._go_uniprot_extractor .go_branch == GOUniProtOver250 ._ALL_GO_BRANCHES
3043 super (_ProteinPretrainingData , self ).__init__ (** kwargs )
3144
3245 # ------------------------------ Phase: Prepare data -----------------------------------
3346 def prepare_data (self , * args : Any , ** kwargs : Any ) -> None :
47+ """
48+ Prepares the data by downloading and parsing Swiss-Prot data if not already available. Saves the processed data
49+ for further use.
50+
51+ Args:
52+ *args: Additional positional arguments.
53+ **kwargs: Additional keyword arguments.
54+ """
3455 processed_name = self .processed_dir_main_file_names_dict ["data" ]
3556 if not os .path .isfile (os .path .join (self .processed_dir_main , processed_name )):
3657 print ("Missing processed data file (`data.pkl` file)" )
@@ -40,20 +61,29 @@ def prepare_data(self, *args: Any, **kwargs: Any) -> None:
4061 self .save_processed (protein_df , processed_name )
4162
4263 def _extract_class_hierarchy (self , data_path : str ) -> nx .DiGraph :
64+ # method not required as Swiss-UniProt has no ontological data
4365 pass
4466
4567 def _graph_to_raw_dataset (self , graph : nx .DiGraph ) -> pd .DataFrame :
68+ # method not required as Swiss-UniProt has no ontological data
4669 pass
4770
4871 def select_classes (self , g : nx .DiGraph , * args , ** kwargs ) -> List :
72+ # method not required as Swiss-UniProt has no ontological data
4973 pass
5074
5175 def _download_required_data (self ) -> str :
52- return self ._go_extractor ._download_swiss_uni_prot_data ()
76+ """
77+ Downloads the required Swiss-Prot data using the GOUniProt extractor class.
78+
79+ Returns:
80+ str: Path to the downloaded data.
81+ """
82+ return self ._go_uniprot_extractor ._download_swiss_uni_prot_data ()
5383
5484 def _parse_protein_data_for_pretraining (self ) -> pd .DataFrame :
5585 """
56- Parses the Swiss-Prot data and returns a DataFrame mapping Swiss-Prot records which does not have any valid
86+ Parses the Swiss-Prot data and returns a DataFrame containing Swiss-Prot proteins which do not have any valid
5787 Gene Ontology(GO) label. A valid GO label is the one which has one of the following evidence code
5888 (EXP, IDA, IPI, IMP, IGI, IEP, TAS, IC).
5989
@@ -65,18 +95,17 @@ def _parse_protein_data_for_pretraining(self) -> pd.DataFrame:
6595 We ignore proteins with ambiguous amino acid codes (B, O, J, U, X, Z) in their sequence.`
6696
6797 Returns:
68- pd.DataFrame: A DataFrame where each row corresponds to a Swiss-Prot record with its associated GO data .
98+ pd.DataFrame: A DataFrame where each row corresponds to a Swiss-Prot record with no associated valid GO label.
6999 """
70-
71100 print ("Parsing swiss uniprot raw data...." )
72101
73102 swiss_ids , sequences = [], []
74103
75104 swiss_data = SwissProt .parse (
76105 open (
77106 os .path .join (
78- self ._go_extractor .raw_dir ,
79- self ._go_extractor .raw_file_names_dict ["SwissUniProt" ],
107+ self ._go_uniprot_extractor .raw_dir ,
108+ self ._go_uniprot_extractor .raw_file_names_dict ["SwissUniProt" ],
80109 ),
81110 "r" ,
82111 )
@@ -88,7 +117,7 @@ def _parse_protein_data_for_pretraining(self) -> pd.DataFrame:
88117 continue
89118
90119 if not record .sequence :
91- # Consider protein with only sequence representation and seq. length not greater than max seq. length
120+ # Skip records without a sequence; only proteins with a sequence representation are considered
92121 continue
93122
94123 if any (aa in AMBIGUOUS_AMINO_ACIDS for aa in record .sequence ):
@@ -97,7 +126,7 @@ def _parse_protein_data_for_pretraining(self) -> pd.DataFrame:
97126
98127 has_valid_associated_go_label = False
99128 for cross_ref in record .cross_references :
100- if cross_ref [0 ] == self ._go_extractor ._GO_DATA_INIT :
129+ if cross_ref [0 ] == self ._go_uniprot_extractor ._GO_DATA_INIT :
101130
102131 if len (cross_ref ) <= 3 :
103132 # No evidence code
@@ -146,8 +175,6 @@ def _load_dict(self, input_file_path: str) -> Generator[Dict[str, Any], None, No
146175 with open (input_file_path , "rb" ) as input_file :
147176 df = pd .read_pickle (input_file )
148177 for row in df .values :
149- # chebai.preprocessing.reader.DataReader only needs features, labels, ident, group
150- # "group" set to None, by default as no such entity for this data
151178 yield dict (
152179 features = row [self ._DATA_REPRESENTATION_IDX ],
153180 ident = row [self ._ID_IDX ],
@@ -184,32 +211,51 @@ def _get_data_splits(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
184211
185212 df_go_data = pd .DataFrame (data_go )
186213 train_df_go , df_test = train_test_split (
187- df_go_data , seed = self .dynamic_data_split_seed
214+ df_go_data ,
215+ train_size = self .train_split ,
216+ random_state = self .dynamic_data_split_seed ,
188217 )
189218
190219 # Get all splits
191220 df_train , df_val = train_test_split (
192221 train_df_go ,
193- seed = self .dynamic_data_split_seed ,
222+ train_size = self .train_split ,
223+ random_state = self .dynamic_data_split_seed ,
194224 )
195225
196226 return df_train , df_val , df_test
197227
198228 # ------------------------------ Phase: Raw Properties -----------------------------------
199229 @property
200230 def base_dir (self ) -> str :
201- return os .path .join (self ._go_extractor .base_dir , "Pretraining" )
231+ """
232+ str: The base directory for pretraining data storage.
233+ """
234+ return os .path .join (self ._go_uniprot_extractor .base_dir , "Pretraining" )
202235
203236 @property
204237 def raw_dir (self ) -> str :
205238 """Name of the directory where the raw data is stored."""
206- return self ._go_extractor .raw_dir
239+ return self ._go_uniprot_extractor .raw_dir
207240
208241
209242class SwissProteinPretrain (_ProteinPretrainingData ):
243+ """
244+ Data module for Swiss-Prot protein pretraining, inheriting from `_ProteinPretrainingData`.
245+ This class is specifically designed to handle data processing and loading for Swiss-Prot-based protein datasets.
246+
247+ Attributes:
248+ READER (Type): The data reader class used to load and process protein pretraining data.
249+ """
210250
211251 READER = ProteinPretrainReader
212252
213253 @property
214254 def _name (self ) -> str :
255+ """
256+ The name identifier for this data module.
257+
258+ Returns:
259+ str: A string identifier, "SwissProteinPretrain", representing the name of this data module.
260+ """
215261 return "SwissProteinPretrain"
0 commit comments