@@ -41,10 +41,11 @@ def model_kwargs(self, *args) -> Dict:
4141 def build_dataset (self , batch_size ) -> datasets .XYBaseDataModule :
4242 raise NotImplementedError
4343
44- def train (self , batch_size , * args , ** kwargs ):
44+ def train (self , batch_size , epochs , * args , ** kwargs ):
4545 self .MODEL .run (
4646 self .dataset ,
4747 self .MODEL .NAME ,
48+ epochs ,
4849 loss = self .LOSS ,
4950 model_kwargs = self .model_kwargs (* args ),
5051 version = self .version ,
@@ -93,7 +94,6 @@ def model_kwargs(self, *args) -> Dict:
9394 type_vocab_size = 1 ,
9495 ),
9596 ),
96- epochs = 100 ,
9797 )
9898
9999 def build_dataset (self , batch_size ) -> datasets .XYBaseDataModule :
@@ -117,7 +117,6 @@ def model_kwargs(self, *args) -> Dict:
117117 num_hidden_layers = 6 ,
118118 type_vocab_size = 1 ,
119119 ),
120- epochs = 100 ,
121120 )
122121
123122 def build_dataset (self , batch_size ) -> datasets .XYBaseDataModule :
@@ -142,7 +141,6 @@ def model_kwargs(self, *args) -> Dict:
142141 num_hidden_layers = 6 ,
143142 type_vocab_size = 1 ,
144143 ),
145- epochs = 100 ,
146144 )
147145
148146 def build_dataset (self , batch_size ) -> datasets .XYBaseDataModule :
@@ -174,8 +172,7 @@ def model_kwargs(self, *args) -> Dict:
174172 num_attention_heads = 8 ,
175173 num_hidden_layers = 6 ,
176174 type_vocab_size = 1 ,
177- ),
178- epochs = 100 ,
175+ )
179176 )
180177
181178
@@ -327,7 +324,6 @@ def model_kwargs(self, *args) -> Dict:
327324 return dict (
328325 in_length = 50 ,
329326 hidden_length = 100 ,
330- epochs = 100 ,
331327 )
332328
333329 def build_dataset (self , batch_size ) -> datasets .XYBaseDataModule :
@@ -345,7 +341,6 @@ def model_kwargs(self, *args) -> Dict:
345341 return dict (
346342 in_length = 50 ,
347343 hidden_length = 100 ,
348- epochs = 100 ,
349344 )
350345
351346 def build_dataset (self , batch_size ) -> datasets .XYBaseDataModule :
0 commit comments