
Commit c06b058

Merge pull request #19 from ChEB-AI/code_documentation
Code documentation
2 parents 290348d + 5c2914a

4 files changed: +345 −1 lines changed

chebai/models/base.py

Lines changed: 43 additions & 0 deletions
@@ -13,6 +13,24 @@


 class ChebaiBaseNet(LightningModule):
+    """
+    Base class for Chebai neural network models, inheriting from PyTorch Lightning's LightningModule.
+
+    Args:
+        criterion (torch.nn.Module, optional): The loss criterion for the model. Defaults to None.
+        out_dim (int, optional): The output dimension of the model. Defaults to None.
+        train_metrics (torch.nn.Module, optional): The metrics to be used during training. Defaults to None.
+        val_metrics (torch.nn.Module, optional): The metrics to be used during validation. Defaults to None.
+        test_metrics (torch.nn.Module, optional): The metrics to be used during testing. Defaults to None.
+        pass_loss_kwargs (bool, optional): Whether to pass loss kwargs to the criterion. Defaults to True.
+        optimizer_kwargs (typing.Dict, optional): Additional keyword arguments for the optimizer. Defaults to None.
+        **kwargs: Additional keyword arguments.
+
+    Attributes:
+        NAME (str): The name of the model.
+        LOSS (torch.nn.Module): The loss function used by the model.
+    """
+
     NAME = None
     LOSS = torch.nn.BCEWithLogitsLoss
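
A minimal sketch of a concrete subclass, inferred from the documented constructor arguments (the class name, layer size, and forward logic are illustrative assumptions, not code from the repository):

    import torch

    class ToyNet(ChebaiBaseNet):
        NAME = "ToyNet"

        def __init__(self, input_dim=128, **kwargs):
            super().__init__(**kwargs)
            # out_dim is assumed to be stored by the base constructor
            self.layer = torch.nn.Linear(input_dim, self.out_dim)

        def forward(self, data, **kwargs):
            return {"logits": self.layer(data["features"])}

    model = ToyNet(input_dim=128, out_dim=500, criterion=torch.nn.BCEWithLogitsLoss())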

@@ -85,6 +103,20 @@ def predict_step(self, batch, batch_idx, **kwargs):
         return self._execute(batch, batch_idx, self.test_metrics, prefix="", log=False)

     def _execute(self, batch, batch_idx, metrics, prefix="", log=True, sync_dist=False):
+        """
+        Executes the model on a batch of data and returns the model output and predictions.
+
+        Args:
+            batch (XYData): The input batch of data.
+            batch_idx (int): The index of the current batch.
+            metrics (dict): A dictionary of metrics to track.
+            prefix (str, optional): A prefix to add to the metric names. Defaults to "".
+            log (bool, optional): Whether to log the metrics. Defaults to True.
+            sync_dist (bool, optional): Whether to synchronize the logged metrics across devices in distributed training. Defaults to False.
+
+        Returns:
+            dict: A dictionary containing the processed data, labels, model output, predictions, and loss (if applicable).
+        """
         assert isinstance(batch, XYData)
         batch = batch.to(self.device)
         data = self._process_batch(batch, batch_idx)
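
Based on that docstring, the result of a single step might be inspected as follows (the exact key names are an assumption drawn from the description, not verified against the implementation):

    result = model._execute(batch, 0, model.train_metrics, prefix="train_")
    # Documented contents: processed data, labels, model output,
    # predictions, and the loss when a criterion is configured.
    print(result.keys())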
@@ -134,6 +166,17 @@ def _execute(self, batch, batch_idx, metrics, prefix="", log=True, sync_dist=False):
         return d

     def _log_metrics(self, prefix, metrics, batch_size):
+        """
+        Logs the metrics for the given prefix.
+
+        Args:
+            prefix (str): The prefix to be added to the metric names.
+            metrics (dict): A dictionary containing the metrics to be logged.
+            batch_size (int): The batch size used for logging.
+
+        Returns:
+            None
+        """
         # don't use sync_dist=True if the metric is a torchmetrics-metric
         # (see https://github.com/Lightning-AI/pytorch-lightning/discussions/6501#discussioncomment-569757)
         for metric_name, metric in metrics.items():
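
The inline comment points at a known Lightning pitfall: torchmetrics objects perform their own cross-process reduction, so sync_dist=True is redundant for them and only needed for plain tensor values. A small sketch of the distinction (the metric choice and names are illustrative):

    import torch
    import torchmetrics

    f1 = torchmetrics.classification.MultilabelF1Score(num_labels=500)
    preds = torch.rand(8, 500)
    labels = torch.randint(0, 2, (8, 500))
    f1.update(preds, labels)
    # Inside a LightningModule, the metric object would be logged directly;
    # torchmetrics handles the reduction across processes:
    #     self.log(f"{prefix}f1", f1, batch_size=batch_size)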

chebai/models/electra.py

Lines changed: 106 additions & 0 deletions
@@ -22,6 +22,23 @@


 class ElectraPre(ChebaiBaseNet):
+    """
+    ElectraPre represents an Electra model for pre-training, inheriting from ChebaiBaseNet.
+
+    Args:
+        config (dict): Configuration parameters for the Electra model.
+        **kwargs: Additional keyword arguments (passed to the parent class).
+
+    Attributes:
+        NAME (str): Name of the ElectraPre model.
+        generator_config (ElectraConfig): Configuration for the generator model.
+        generator (ElectraForMaskedLM): Generator model for masked language modeling.
+        discriminator_config (ElectraConfig): Configuration for the discriminator model.
+        discriminator (ElectraForPreTraining): Discriminator model for pre-training.
+        replace_p (float): Probability of replacing tokens during training.
+    """
+
     NAME = "ElectraPre"

     def __init__(self, config=None, **kwargs):
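
A hypothetical instantiation, assuming the config dict maps onto standard Hugging Face ElectraConfig fields (all values below are illustrative):

    model = ElectraPre(
        config={"vocab_size": 1400, "hidden_size": 256, "num_attention_heads": 8},
        optimizer_kwargs={"lr": 1e-3},
    )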
@@ -34,12 +51,32 @@ def __init__(self, config=None, **kwargs):

     @property
     def as_pretrained(self):
+        """
+        Returns the discriminator model as a pre-trained model.
+
+        Returns:
+            ElectraForPreTraining: The discriminator model.
+        """
         return self.discriminator

     def _process_labels_in_batch(self, batch):
         return None

     def forward(self, data, **kwargs):
+        """
+        Forward pass of the ElectraPre model.
+
+        Args:
+            data (dict): Input data.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            tuple: A tuple containing the raw generator output and the discriminator output.
+                The generator output is a tensor of shape (batch_size, max_seq_len, vocab_size).
+                The discriminator output is a tensor of shape (batch_size, max_seq_len).
+        """
         features = data["features"]
         features = features.long()
         self.batch_size = batch_size = features.shape[0]
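
A sketch of how the documented return shapes could be checked (the model instance, batch, and dimensions are illustrative assumptions):

    import torch

    data = {"features": torch.randint(0, 1400, (4, 32))}  # (batch_size, max_seq_len)
    gen_out, disc_out = model(data)
    assert gen_out.shape == (4, 32, 1400)   # (batch_size, max_seq_len, vocab_size)
    assert disc_out.shape == (4, 32)        # (batch_size, max_seq_len)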
@@ -96,9 +133,35 @@ def filter_dict(d, filter_key):


 class Electra(ChebaiBaseNet):
+    """
+    Electra model implementation, inheriting from ChebaiBaseNet.
+
+    Args:
+        config (dict, optional): Configuration parameters for the Electra model. Defaults to None.
+        pretrained_checkpoint (str, optional): Path to the pretrained checkpoint file. Defaults to None.
+        load_prefix (str, optional): Prefix used to filter the state_dict keys from the pretrained checkpoint. Defaults to None.
+        **kwargs: Additional keyword arguments.
+
+    Attributes:
+        NAME (str): Name of the Electra model.
+    """
+
     NAME = "Electra"

     def _process_batch(self, batch, batch_idx):
+        """
+        Process a batch of data.
+
+        Args:
+            batch (XYData): The input batch of data.
+            batch_idx (int): The index of the batch (not used).
+
+        Returns:
+            dict: A dictionary containing the processed batch; keys are `features`, `labels`, `model_kwargs`,
+                `loss_kwargs` and `idents`.
+        """
         model_kwargs = dict()
         loss_kwargs = batch.additional_fields["loss_kwargs"]
         if "lens" in batch.additional_fields["model_kwargs"]:
@@ -125,6 +188,13 @@ def _process_batch(self, batch, batch_idx):

     @property
     def as_pretrained(self):
+        """
+        Get the pretrained Electra model.
+
+        Returns:
+            ElectraModel: The pretrained Electra model.
+        """
         return self.electra.electra

     def __init__(
@@ -149,6 +219,8 @@ def __init__(
             nn.Dropout(self.config.hidden_dropout_prob),
             nn.Linear(in_d, self.config.num_labels),
         )
+
+        # Load pretrained checkpoint if provided
         if pretrained_checkpoint:
             with open(pretrained_checkpoint, "rb") as fin:
                 model_dict = torch.load(fin, map_location=self.device)
@@ -163,12 +235,36 @@ def __init__(
             self.electra = ElectraModel(config=self.config)

     def _process_for_loss(self, model_output, labels, loss_kwargs):
+        """
+        Process the model output for calculating the loss.
+
+        Args:
+            model_output (dict): The output of the model.
+            labels (Tensor): The target labels.
+            loss_kwargs (dict): Additional loss arguments.
+
+        Returns:
+            tuple: A tuple containing the processed model output, labels, and loss arguments.
+        """
         kwargs_copy = dict(loss_kwargs)
         if labels is not None:
             labels = labels.float()
         return model_output["logits"], labels, kwargs_copy

     def _get_prediction_and_labels(self, data, labels, model_output):
+        """
+        Get the predictions and labels from the model output. Applies a sigmoid to the model output.
+
+        Args:
+            data (dict): The input data.
+            labels (Tensor): The target labels.
+            model_output (dict): The output of the model.
+
+        Returns:
+            tuple: A tuple containing the predictions and labels.
+        """
         d = model_output["logits"]
         loss_kwargs = data.get("loss_kwargs", dict())
         if "non_null_labels" in loss_kwargs:
@@ -177,6 +273,16 @@ def _get_prediction_and_labels(self, data, labels, model_output):
         return torch.sigmoid(d), labels.int() if labels is not None else None

     def forward(self, data, **kwargs):
+        """
+        Forward pass of the Electra model.
+
+        Args:
+            data (dict): The input data (expects a key `features`).
+            **kwargs: Additional keyword arguments for `self.electra`.
+
+        Returns:
+            dict: A dictionary containing the model output (logits and attentions).
+        """
         self.batch_size = data["features"].shape[0]
         try:
             inp = self.electra.embeddings.forward(data["features"].int())
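
Putting the pieces together, a classification call might look like this (the vocabulary size, sequence length, and out-of-trainer usage are illustrative assumptions):

    import torch

    data = {"features": torch.randint(0, 1400, (4, 32))}  # (batch_size, seq_len) token IDs
    output = model(data)                                   # {"logits": ..., "attentions": ...}
    probabilities = torch.sigmoid(output["logits"])        # per-label probabilities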

chebai/preprocessing/datasets/base.py

Lines changed: 89 additions & 0 deletions
@@ -14,6 +14,48 @@


 class XYBaseDataModule(LightningDataModule):
+    """
+    Base class for data modules.
+
+    This class provides a base implementation for loading and preprocessing datasets.
+    It inherits from `LightningDataModule` and defines common properties and methods for data loading and processing.
+
+    Args:
+        batch_size (int): The batch size for data loading. Default is 1.
+        train_split (float): The ratio of training data to total data, and of test data to (validation + test) data. Default is 0.85.
+        reader_kwargs (dict): Additional keyword arguments to be passed to the data reader. Default is None.
+        prediction_kind (str): The kind of prediction to be performed (only relevant for the predict_dataloader). Default is "test".
+        data_limit (int): The maximum number of data samples to load. If set to None, the complete dataset will be used. Default is None.
+        label_filter (int): The index of the label to filter. Default is None.
+        balance_after_filter (float): The ratio of negative samples to positive samples after filtering. Default is None.
+        num_workers (int): The number of worker processes for data loading. Default is 1.
+        chebi_version (int): The version of ChEBI to use. Default is 200.
+        inner_k_folds (int): The number of folds for inner cross-validation. Use -1 to disable inner cross-validation. Default is -1.
+        fold_index (int): The index of the fold to use for training and validation. Default is None.
+        base_dir (str): The base directory for storing processed and raw data. Default is None.
+        **kwargs: Additional keyword arguments.
+
+    Attributes:
+        READER (DataReader): The data reader class to use.
+        reader (DataReader): An instance of the data reader class.
+        train_split (float): The ratio of training data to total data.
+        batch_size (int): The batch size for data loading.
+        prediction_kind (str): The kind of prediction to be performed.
+        data_limit (int): The maximum number of data samples to load.
+        label_filter (int): The index of the label to filter.
+        balance_after_filter (float): The ratio of negative samples to positive samples after filtering.
+        num_workers (int): The number of worker processes for data loading.
+        chebi_version (int): The version of ChEBI to use.
+        inner_k_folds (int): The number of folds for inner cross-validation. If it is less than two, no cross-validation will be performed.
+        fold_index (int): The index of the fold to use for training and validation (only relevant for cross-validation).
+        _base_dir (str): The base directory for storing processed and raw data.
+        raw_dir (str): The directory for storing raw data.
+        processed_dir (str): The directory for storing processed data.
+        fold_dir (str): The name of the directory where the folds from inner cross-validation are stored.
+        _name (str): The name of the data module.
+    """
+
     READER = dr.DataReader

     def __init__(
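
A hypothetical concrete subclass and its instantiation, sketched from the documented arguments (the class name, reader, and _name value are assumptions):

    class MyChebiDataModule(XYBaseDataModule):
        READER = dr.ChemDataReader  # assumed reader class

        @property
        def _name(self):
            return "my_chebi"

    dm = MyChebiDataModule(batch_size=32, train_split=0.85, chebi_version=200, inner_k_folds=5)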
@@ -69,10 +111,12 @@ def __init__(

     @property
     def identifier(self):
+        """Identifier for the dataset."""
         return (self.reader.name(),)

     @property
     def full_identifier(self):
+        """Full identifier for the dataset."""
         return (self._name, *self.identifier)

     @property
@@ -84,10 +128,12 @@ def base_dir(self):

     @property
     def processed_dir(self):
+        """Name of the directory where the processed data is stored."""
         return os.path.join(self.base_dir, "processed", *self.identifier)

     @property
     def raw_dir(self):
+        """Name of the directory where the raw data is stored."""
         return os.path.join(self.base_dir, "raw")

     @property
@@ -100,10 +146,24 @@ def _name(self):
         raise NotImplementedError

     def _filter_labels(self, row):
+        """Filter labels based on label_filter."""
         row["labels"] = [row["labels"][self.label_filter]]
         return row

     def load_processed_data(self, kind: str = None, filename: str = None) -> List:
+        """
+        Load processed data from a file.
+
+        Args:
+            kind (str, optional): The kind of dataset to load, such as "train", "val" or "test". Defaults to None.
+            filename (str, optional): The name of the file to load the dataset from. Defaults to None.
+
+        Returns:
+            List: The loaded processed data.
+
+        Raises:
+            ValueError: If both kind and filename are None.
+        """
         if kind is None and filename is None:
             raise ValueError(
                 "Either kind or filename is required to load the correct dataset, both are None"
@@ -123,6 +183,17 @@ def load_processed_data(self, kind: str = None, filename: str = None) -> List:
         return torch.load(os.path.join(self.processed_dir, filename))

     def dataloader(self, kind, **kwargs) -> DataLoader:
+        """
+        Returns a DataLoader object for the specified kind (train, val or test) of data.
+
+        Args:
+            kind (str): Indicates whether train, val or test data should be loaded.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            DataLoader: A DataLoader object.
+        """
         dataset = self.load_processed_data(kind)
         if "ids" in kwargs:
             ids = kwargs.pop("ids")
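
A hedged usage sketch (assuming extra keyword arguments are forwarded to the underlying DataLoader):

    train_loader = dm.dataloader("train")
    for batch in train_loader:
        ...  # batches are collated by the module's reader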
@@ -155,6 +226,15 @@ def dataloader(self, kind, **kwargs) -> DataLoader:

     @staticmethod
     def _load_dict(input_file_path):
+        """
+        Load data from a file, yielding one dictionary per row.
+
+        Args:
+            input_file_path (str): The path to the input file.
+
+        Yields:
+            dict: A dictionary containing the features and labels.
+        """
         with open(input_file_path, "r") as input_file:
             for row in input_file:
                 smiles, labels = row.split("\t")
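
The split on a tab character implies the raw file is a TSV of SMILES strings and labels; a hypothetical row and its parse (the label encoding shown is an assumption):

    row = "CC(=O)OC1=CC=CC=C1C(=O)O\t[0, 1, 0]\n"
    smiles, labels = row.split("\t")
    # smiles -> "CC(=O)OC1=CC=CC=C1C(=O)O", labels -> "[0, 1, 0]\n"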
@@ -166,6 +246,15 @@ def _get_data_size(input_file_path):
         return sum(1 for _ in f)

     def _load_data_from_file(self, path):
+        """
+        Load data from a file and return a list of dictionaries.
+
+        Args:
+            path (str): The path to the input file.
+
+        Returns:
+            List: A list of dictionaries containing the features and labels.
+        """
         lines = self._get_data_size(path)
         print(f"Processing {lines} lines...")
         data = [
