Skip to content

Commit ad4fc95

Browse files
committed
pretrain: add docstrings and typehints
1 parent 2c446dc commit ad4fc95

File tree

1 file changed

+63
-17
lines changed

1 file changed

+63
-17
lines changed

chebai/preprocessing/datasets/protein_pretraining.py

Lines changed: 63 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,37 @@
2121

2222

2323
class _ProteinPretrainingData(_DynamicDataset, ABC):
24+
"""
25+
Data module for pretraining protein sequences, specifically designed for Swiss-UniProt data. It includes methods for
26+
data preparation, loading, and dynamic splitting of protein sequences.
27+
The data is parsed and filtered to only select proteins with no associated `valid` Gene Ontology (GO) labels.
28+
A valid GO label is the one which has one of evidence codes defined in `EXPERIMENTAL_EVIDENCE_CODES`.
29+
"""
30+
2431
_ID_IDX: int = 0
25-
_DATA_REPRESENTATION_IDX: int = 1 # here `sequence` column
32+
_DATA_REPRESENTATION_IDX: int = 1 # Index of `sequence` column
2633

2734
def __init__(self, **kwargs):
28-
self._go_extractor = GOUniProtOver250()
29-
assert self._go_extractor.go_branch == GOUniProtOver250._ALL_GO_BRANCHES
35+
"""
36+
Initializes the data module with any GOUniProt extractor class object.
37+
38+
Args:
39+
**kwargs: Additional arguments for the superclass initialization.
40+
"""
41+
self._go_uniprot_extractor = GOUniProtOver250()
42+
assert self._go_uniprot_extractor.go_branch == GOUniProtOver250._ALL_GO_BRANCHES
3043
super(_ProteinPretrainingData, self).__init__(**kwargs)
3144

3245
# ------------------------------ Phase: Prepare data -----------------------------------
3346
def prepare_data(self, *args: Any, **kwargs: Any) -> None:
47+
"""
48+
Prepares the data by downloading and parsing Swiss-Prot data if not already available. Saves the processed data
49+
for further use.
50+
51+
Args:
52+
*args: Additional positional arguments.
53+
**kwargs: Additional keyword arguments.
54+
"""
3455
processed_name = self.processed_dir_main_file_names_dict["data"]
3556
if not os.path.isfile(os.path.join(self.processed_dir_main, processed_name)):
3657
print("Missing processed data file (`data.pkl` file)")
@@ -40,20 +61,29 @@ def prepare_data(self, *args: Any, **kwargs: Any) -> None:
4061
self.save_processed(protein_df, processed_name)
4162

4263
def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph:
64+
# method not required as no Swiss-UniProt has no ontological data
4365
pass
4466

4567
def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame:
68+
# method not required as no Swiss-UniProt has no ontological data
4669
pass
4770

4871
def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> List:
72+
# method not required as no Swiss-UniProt has no ontological data
4973
pass
5074

5175
def _download_required_data(self) -> str:
52-
return self._go_extractor._download_swiss_uni_prot_data()
76+
"""
77+
Downloads the required Swiss-Prot data using the GOUniProt extractor class.
78+
79+
Returns:
80+
str: Path to the downloaded data.
81+
"""
82+
return self._go_uniprot_extractor._download_swiss_uni_prot_data()
5383

5484
def _parse_protein_data_for_pretraining(self) -> pd.DataFrame:
5585
"""
56-
Parses the Swiss-Prot data and returns a DataFrame mapping Swiss-Prot records which does not have any valid
86+
Parses the Swiss-Prot data and returns a DataFrame containing Swiss-Prot proteins which does not have any valid
5787
Gene Ontology(GO) label. A valid GO label is the one which has one of the following evidence code
5888
(EXP, IDA, IPI, IMP, IGI, IEP, TAS, IC).
5989
@@ -65,18 +95,17 @@ def _parse_protein_data_for_pretraining(self) -> pd.DataFrame:
6595
We ignore proteins with ambiguous amino acid codes (B, O, J, U, X, Z) in their sequence.`
6696
6797
Returns:
68-
pd.DataFrame: A DataFrame where each row corresponds to a Swiss-Prot record with its associated GO data.
98+
pd.DataFrame: A DataFrame where each row corresponds to a Swiss-Prot record with not associated valid GO.
6999
"""
70-
71100
print("Parsing swiss uniprot raw data....")
72101

73102
swiss_ids, sequences = [], []
74103

75104
swiss_data = SwissProt.parse(
76105
open(
77106
os.path.join(
78-
self._go_extractor.raw_dir,
79-
self._go_extractor.raw_file_names_dict["SwissUniProt"],
107+
self._go_uniprot_extractor.raw_dir,
108+
self._go_uniprot_extractor.raw_file_names_dict["SwissUniProt"],
80109
),
81110
"r",
82111
)
@@ -88,7 +117,7 @@ def _parse_protein_data_for_pretraining(self) -> pd.DataFrame:
88117
continue
89118

90119
if not record.sequence:
91-
# Consider protein with only sequence representation and seq. length not greater than max seq. length
120+
# Consider protein with only sequence representation
92121
continue
93122

94123
if any(aa in AMBIGUOUS_AMINO_ACIDS for aa in record.sequence):
@@ -97,7 +126,7 @@ def _parse_protein_data_for_pretraining(self) -> pd.DataFrame:
97126

98127
has_valid_associated_go_label = False
99128
for cross_ref in record.cross_references:
100-
if cross_ref[0] == self._go_extractor._GO_DATA_INIT:
129+
if cross_ref[0] == self._go_uniprot_extractor._GO_DATA_INIT:
101130

102131
if len(cross_ref) <= 3:
103132
# No evidence code
@@ -146,8 +175,6 @@ def _load_dict(self, input_file_path: str) -> Generator[Dict[str, Any], None, No
146175
with open(input_file_path, "rb") as input_file:
147176
df = pd.read_pickle(input_file)
148177
for row in df.values:
149-
# chebai.preprocessing.reader.DataReader only needs features, labels, ident, group
150-
# "group" set to None, by default as no such entity for this data
151178
yield dict(
152179
features=row[self._DATA_REPRESENTATION_IDX],
153180
ident=row[self._ID_IDX],
@@ -184,32 +211,51 @@ def _get_data_splits(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
184211

185212
df_go_data = pd.DataFrame(data_go)
186213
train_df_go, df_test = train_test_split(
187-
df_go_data, seed=self.dynamic_data_split_seed
214+
df_go_data,
215+
train_size=self.train_split,
216+
random_state=self.dynamic_data_split_seed,
188217
)
189218

190219
# Get all splits
191220
df_train, df_val = train_test_split(
192221
train_df_go,
193-
seed=self.dynamic_data_split_seed,
222+
train_size=self.train_split,
223+
random_state=self.dynamic_data_split_seed,
194224
)
195225

196226
return df_train, df_val, df_test
197227

198228
# ------------------------------ Phase: Raw Properties -----------------------------------
199229
@property
200230
def base_dir(self) -> str:
201-
return os.path.join(self._go_extractor.base_dir, "Pretraining")
231+
"""
232+
str: The base directory for pretraining data storage.
233+
"""
234+
return os.path.join(self._go_uniprot_extractor.base_dir, "Pretraining")
202235

203236
@property
204237
def raw_dir(self) -> str:
205238
"""Name of the directory where the raw data is stored."""
206-
return self._go_extractor.raw_dir
239+
return self._go_uniprot_extractor.raw_dir
207240

208241

209242
class SwissProteinPretrain(_ProteinPretrainingData):
243+
"""
244+
Data module for Swiss-Prot protein pretraining, inheriting from `_ProteinPretrainingData`.
245+
This class is specifically designed to handle data processing and loading for Swiss-Prot-based protein datasets.
246+
247+
Attributes:
248+
READER (Type): The data reader class used to load and process protein pretraining data.
249+
"""
210250

211251
READER = ProteinPretrainReader
212252

213253
@property
214254
def _name(self) -> str:
255+
"""
256+
The name identifier for this data module.
257+
258+
Returns:
259+
str: A string identifier, "SwissProteinPretrain", representing the name of this data module.
260+
"""
215261
return "SwissProteinPretrain"

0 commit comments

Comments
 (0)