Skip to content

Commit 5c2914a

Browse files
author
sfluegel
committed
review docstrings
1 parent a8cc1dc commit 5c2914a

File tree

4 files changed

+28
-26
lines changed

4 files changed

+28
-26
lines changed

chebai/models/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def _execute(self, batch, batch_idx, metrics, prefix="", log=True, sync_dist=Fal
115115
sync_dist (bool, optional): Whether to synchronize distributed training. Defaults to False.
116116
117117
Returns:
118-
dict: A dictionary containing the processed data, labels, model output, predictions, and loss (if applicable).
118+
dict: A dictionary containing the processed data, labels, model_output, predictions, and loss (if applicable).
119119
"""
120120
assert isinstance(batch, XYData)
121121
batch = batch.to(self.device)

chebai/models/electra.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,11 @@
2323

2424
class ElectraPre(ChebaiBaseNet):
2525
"""
26-
ElectraPre class represents a pre-trained Electra model for pre-training inherited from ChebaiBaseNet.
26+
ElectraPre class represents an Electra model for pre-training inherited from ChebaiBaseNet.
2727
2828
Args:
2929
config (dict): Configuration parameters for the Electra model.
30-
**kwargs: Additional keyword arguments.
30+
**kwargs: Additional keyword arguments (passed to parent class).
3131
3232
Attributes:
3333
NAME (str): Name of the ElectraPre model.
@@ -155,10 +155,11 @@ def _process_batch(self, batch, batch_idx):
155155
156156
Args:
157157
batch (XYData): The input batch of data.
158-
batch_idx (int): The index of the batch.
158+
batch_idx (int): The index of the batch (not used).
159159
160160
Returns:
161-
dict: A dictionary containing the processed batch.
161+
dict: A dictionary containing the processed batch, keys are `features`, `labels`, `model_kwargs`,
162+
`loss_kwargs` and `idents`.
162163
163164
"""
164165
model_kwargs = dict()
@@ -253,7 +254,7 @@ def _process_for_loss(self, model_output, labels, loss_kwargs):
253254

254255
def _get_prediction_and_labels(self, data, labels, model_output):
255256
"""
256-
Get the predictions and labels from the model output.
257+
Get the predictions and labels from the model output. Applies a sigmoid to the model output.
257258
258259
Args:
259260
data (dict): The input data.
@@ -276,12 +277,11 @@ def forward(self, data, **kwargs):
276277
Forward pass of the Electra model.
277278
278279
Args:
279-
data (dict): The input data.
280-
**kwargs: Additional keyword arguments.
280+
data (dict): The input data (expects a key `features`).
281+
**kwargs: Additional keyword arguments for `self.electra`.
281282
282283
Returns:
283-
dict: A dictionary containing the model output.
284-
284+
dict: A dictionary containing the model output (logits and attentions).
285285
"""
286286
self.batch_size = data["features"].shape[0]
287287
try:

chebai/preprocessing/datasets/base.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,14 @@ class XYBaseDataModule(LightningDataModule):
2222
2323
Args:
2424
batch_size (int): The batch size for data loading. Default is 1.
25-
train_split (float): The ratio of training data to total data. Default is 0.85.
25+
train_split (float): The ratio of training data to total data and of test data to (validation + test) data. Default is 0.85.
2626
reader_kwargs (dict): Additional keyword arguments to be passed to the data reader. Default is None.
27-
prediction_kind (str): The kind of prediction to be performed. Default is "test".
28-
data_limit (int): The maximum number of data samples to load. Default is None.
27+
prediction_kind (str): The kind of prediction to be performed (only relevant for the predict_dataloader). Default is "test".
28+
data_limit (int): The maximum number of data samples to load. If set to None, the complete dataset will be used. Default is None.
2929
label_filter (int): The index of the label to filter. Default is None.
3030
balance_after_filter (float): The ratio of negative samples to positive samples after filtering. Default is None.
3131
num_workers (int): The number of worker processes for data loading. Default is 1.
32-
chebi_version (int): The version of ChEBI database to use. Default is 200.
32+
chebi_version (int): The version of ChEBI to use. Default is 200.
3333
inner_k_folds (int): The number of folds for inner cross-validation. Use -1 to disable inner cross-validation. Default is -1.
3434
fold_index (int): The index of the fold to use for training and validation. Default is None.
3535
base_dir (str): The base directory for storing processed and raw data. Default is None.
@@ -45,9 +45,9 @@ class XYBaseDataModule(LightningDataModule):
4545
label_filter (int): The index of the label to filter.
4646
balance_after_filter (float): The ratio of negative samples to positive samples after filtering.
4747
num_workers (int): The number of worker processes for data loading.
48-
chebi_version (int): The version of ChEBI database to use.
49-
inner_k_folds (int): The number of folds for inner cross-validation.
50-
fold_index (int): The index of the fold to use for training and validation.
48+
chebi_version (int): The version of ChEBI to use.
49+
inner_k_folds (int): The number of folds for inner cross-validation. If it is less than two, no cross-validation will be performed.
50+
fold_index (int): The index of the fold to use for training and validation (only relevant for cross-validation).
5151
_base_dir (str): The base directory for storing processed and raw data.
5252
raw_dir (str): The directory for storing raw data.
5353
processed_dir (str): The directory for storing processed data.

chebai/preprocessing/datasets/chebi.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -114,9 +114,11 @@ class _ChEBIDataExtractor(XYBaseDataModule, ABC):
114114
A class for extracting and processing data from the ChEBI dataset.
115115
116116
Args:
117-
chebi_version_train (int, optional): The version of ChEBI to use for training and validation. Defaults to None.
118-
single_class (int, optional): The ID of the single class to predict. Defaults to None.
119-
**kwargs: Additional keyword arguments.
117+
chebi_version_train (int, optional): The version of ChEBI to use for training and validation. If not set,
118+
chebi_version will be used for training, validation and test. Defaults to None.
119+
single_class (int, optional): The ID of the single class to predict. If not set, all available labels will be
120+
predicted. Defaults to None.
121+
**kwargs: Additional keyword arguments (passed to XYBaseDataModule).
120122
121123
Attributes:
122124
single_class (int): The ID of the single class to predict.
@@ -135,10 +137,10 @@ def __init__(
135137

136138
def extract_class_hierarchy(self, chebi_path):
137139
"""
138-
Extract the class hierarchy from the ChEBI dataset.
140+
Extracts the class hierarchy from the ChEBI ontology.
139141
140142
Args:
141-
chebi_path (str): The path to the ChEBI dataset.
143+
chebi_path (str): The path to the ChEBI ontology.
142144
143145
Returns:
144146
nx.DiGraph: The class hierarchy.
@@ -200,7 +202,7 @@ def _load_dict(self, input_file_path):
200202
input_file_path (str): The path to the file.
201203
202204
Yields:
203-
dict: The dictionary.
205+
dict: The dictionary, keys are `features`, `labels` and `ident`.
204206
"""
205207
with open(input_file_path, "rb") as input_file:
206208
df = pickle.load(input_file)
@@ -434,7 +436,7 @@ def prepare_data(self, *args, **kwargs):
434436
Prepares the data for the Chebi dataset.
435437
436438
This method checks for the presence of raw data in the specified directory.
437-
If the raw data is missing, it fetches the data and creates the test set.
439+
If the raw data is missing, it fetches the ontology and creates a test set.
438440
If the test set already exists, it loads it from the file.
439441
Then, it creates the train/validation split based on the test set.
440442
@@ -532,8 +534,8 @@ def select_classes(self, g, split_name, *args, **kwargs):
532534
Args:
533535
g (Graph): The graph representing the dataset.
534536
split_name (str): The name of the split.
535-
*args: Additional arguments.
536-
**kwargs: Additional keyword arguments.
537+
*args: Additional arguments (not used).
538+
**kwargs: Additional keyword arguments (not used).
537539
538540
Returns:
539541
list: The list of selected classes.

0 commit comments

Comments
 (0)