

class ElectraPre(ChebaiBaseNet):
+    """
+    ElectraPre represents an Electra model for pre-training, inheriting from ChebaiBaseNet.
+
+    Args:
+        config (dict): Configuration parameters for the Electra model.
+        **kwargs: Additional keyword arguments (passed to the parent class).
+
+    Attributes:
+        NAME (str): Name of the ElectraPre model.
+        generator_config (ElectraConfig): Configuration for the generator model.
+        generator (ElectraForMaskedLM): Generator model for masked language modeling.
+        discriminator_config (ElectraConfig): Configuration for the discriminator model.
+        discriminator (ElectraForPreTraining): Discriminator model for pre-training.
+        replace_p (float): Probability of replacing tokens during training.
+    """
+
    NAME = "ElectraPre"

    def __init__(self, config=None, **kwargs):
@@ -34,12 +51,32 @@ def __init__(self, config=None, **kwargs):

    @property
    def as_pretrained(self):
+        """
+        Returns the discriminator model as a pre-trained model.
+
+        Returns:
+            ElectraForPreTraining: The discriminator model.
+        """
        return self.discriminator

    def _process_labels_in_batch(self, batch):
        return None

    def forward(self, data, **kwargs):
+        """
+        Forward pass of the ElectraPre model.
+
+        Args:
+            data (dict): Input data.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            tuple: A tuple containing the raw generator output and the discriminator output.
+                The generator output is a tensor of shape (batch_size, max_seq_len, vocab_size).
+                The discriminator output is a tensor of shape (batch_size, max_seq_len).
+        """
        features = data["features"]
        features = features.long()
        self.batch_size = batch_size = features.shape[0]
@@ -96,9 +133,35 @@ def filter_dict(d, filter_key):


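Reviewer note: a minimal usage sketch of the pre-training model, assembled only from the docstrings added in this diff; the config schema, vocabulary size, and sequence length below are hypothetical placeholders, not values taken from this repository, and the exact call convention is a guess based on ChebaiBaseNet being an nn.Module subclass.

    # Hypothetical sketch (not part of the patch): exercising ElectraPre.forward
    # as described by the docstring above. All concrete numbers are placeholders.
    import torch

    model = ElectraPre(config={"vocab_size": 1400})  # placeholder config schema
    features = torch.randint(0, 1400, (8, 256))      # (batch_size, max_seq_len) token ids
    gen_out, disc_out = model(dict(features=features))
    # gen_out:  (batch_size, max_seq_len, vocab_size) raw generator logits
    # disc_out: (batch_size, max_seq_len) per-token replaced-token scores
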
class Electra(ChebaiBaseNet):
+    """
+    Electra model implementation, inheriting from ChebaiBaseNet.
+
+    Args:
+        config (dict, optional): Configuration parameters for the Electra model. Defaults to None.
+        pretrained_checkpoint (str, optional): Path to the pretrained checkpoint file. Defaults to None.
+        load_prefix (str, optional): Prefix used to filter the state_dict keys loaded from the pretrained checkpoint. Defaults to None.
+        **kwargs: Additional keyword arguments.
+
+    Attributes:
+        NAME (str): Name of the Electra model.
+    """
+
    NAME = "Electra"

    def _process_batch(self, batch, batch_idx):
+        """
+        Process a batch of data.
+
+        Args:
+            batch (XYData): The input batch of data.
+            batch_idx (int): The index of the batch (not used).
+
+        Returns:
+            dict: A dictionary containing the processed batch, with keys `features`, `labels`,
+                `model_kwargs`, `loss_kwargs` and `idents`.
+        """
        model_kwargs = dict()
        loss_kwargs = batch.additional_fields["loss_kwargs"]
        if "lens" in batch.additional_fields["model_kwargs"]:
@@ -125,6 +188,13 @@ def _process_batch(self, batch, batch_idx):

    @property
    def as_pretrained(self):
+        """
+        Get the pretrained Electra model.
+
+        Returns:
+            ElectraModel: The pretrained Electra model.
+        """
        return self.electra.electra

    def __init__(
@@ -149,6 +219,8 @@ def __init__(
            nn.Dropout(self.config.hidden_dropout_prob),
            nn.Linear(in_d, self.config.num_labels),
        )
+
+        # Load pretrained checkpoint if provided
        if pretrained_checkpoint:
            with open(pretrained_checkpoint, "rb") as fin:
                model_dict = torch.load(fin, map_location=self.device)
@@ -163,12 +235,36 @@ def __init__(
            self.electra = ElectraModel(config=self.config)

    def _process_for_loss(self, model_output, labels, loss_kwargs):
+        """
+        Process the model output for calculating the loss.
+
+        Args:
+            model_output (dict): The output of the model.
+            labels (Tensor): The target labels.
+            loss_kwargs (dict): Additional loss arguments.
+
+        Returns:
+            tuple: A tuple containing the processed model output (logits), labels, and loss arguments.
+        """
        kwargs_copy = dict(loss_kwargs)
        if labels is not None:
            labels = labels.float()
        return model_output["logits"], labels, kwargs_copy

    def _get_prediction_and_labels(self, data, labels, model_output):
+        """
+        Get the predictions and labels from the model output. Applies a sigmoid to the model output.
+
+        Args:
+            data (dict): The input data.
+            labels (Tensor): The target labels.
+            model_output (dict): The output of the model.
+
+        Returns:
+            tuple: A tuple containing the predictions and labels.
+        """
        d = model_output["logits"]
        loss_kwargs = data.get("loss_kwargs", dict())
        if "non_null_labels" in loss_kwargs:
@@ -177,6 +273,16 @@ def _get_prediction_and_labels(self, data, labels, model_output):
        return torch.sigmoid(d), labels.int() if labels is not None else None

    def forward(self, data, **kwargs):
+        """
+        Forward pass of the Electra model.
+
+        Args:
+            data (dict): The input data (expects a key `features`).
+            **kwargs: Additional keyword arguments for `self.electra`.
+
+        Returns:
+            dict: A dictionary containing the model output (logits and attentions).
+        """
        self.batch_size = data["features"].shape[0]
        try:
            inp = self.electra.embeddings.forward(data["features"].int())
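Reviewer note: the matching sketch for the fine-tuning model, again based only on the docstrings added here; the checkpoint path, state_dict prefix, and config values are hypothetical placeholders, and the 0.5 threshold is illustrative rather than taken from this codebase.

    # Hypothetical sketch (not part of the patch): fine-tuning usage of Electra
    # as described by the docstrings above. Paths and values are placeholders.
    import torch

    model = Electra(
        config={"num_labels": 854},                       # placeholder label count
        pretrained_checkpoint="electra_pretrained.ckpt",  # placeholder path
        load_prefix="generator.",                         # placeholder key prefix
    )
    out = model(dict(features=torch.randint(0, 1400, (4, 256))))
    preds = torch.sigmoid(out["logits"])  # as done in _get_prediction_and_labels
    multi_hot = (preds > 0.5).int()       # illustrative per-class decision threshold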