Skip to content

Commit db60278

Browse files
committed
Add epoch option
1 parent 0857db2 commit db60278

File tree

3 files changed

+8
-11
lines changed

3 files changed

+8
-11
lines changed

chebai/cli.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,9 @@ def cli():
1717
@click.option("-g", "--group", default="default")
1818
@click.option("--version", default=None)
1919
@click.option("--load-prefix", default=None)
20+
@click.option("--epochs", default=100)
2021
@click.argument("args", nargs=-1)
21-
def train(experiment, batch_size, group, version, load_prefix, args):
22+
def train(experiment, batch_size, group, version, load_prefix, epochs, args):
2223
"""Run experiment identified by EXPERIMENT in batches of size BATCH_SIZE."""
2324
try:
2425
ex = experiments.EXPERIMENTS[experiment](batch_size, group, version=version)
@@ -27,7 +28,7 @@ def train(experiment, batch_size, group, version, load_prefix, args):
2728
"Experiment ID not found. The following are available:"
2829
+ ", ".join(experiments.EXPERIMENTS.keys())
2930
)
30-
ex.train(batch_size, *args, load_prefix=load_prefix)
31+
ex.train(batch_size, epochs, *args, load_prefix=load_prefix)
3132

3233

3334
@click.command()

chebai/experiments.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,11 @@ def model_kwargs(self, *args) -> Dict:
4141
def build_dataset(self, batch_size) -> datasets.XYBaseDataModule:
4242
raise NotImplementedError
4343

44-
def train(self, batch_size, *args, **kwargs):
44+
def train(self, batch_size, epochs, *args, **kwargs):
4545
self.MODEL.run(
4646
self.dataset,
4747
self.MODEL.NAME,
48+
epochs,
4849
loss=self.LOSS,
4950
model_kwargs=self.model_kwargs(*args),
5051
version=self.version,
@@ -93,7 +94,6 @@ def model_kwargs(self, *args) -> Dict:
9394
type_vocab_size=1,
9495
),
9596
),
96-
epochs=100,
9797
)
9898

9999
def build_dataset(self, batch_size) -> datasets.XYBaseDataModule:
@@ -117,7 +117,6 @@ def model_kwargs(self, *args) -> Dict:
117117
num_hidden_layers=6,
118118
type_vocab_size=1,
119119
),
120-
epochs=100,
121120
)
122121

123122
def build_dataset(self, batch_size) -> datasets.XYBaseDataModule:
@@ -142,7 +141,6 @@ def model_kwargs(self, *args) -> Dict:
142141
num_hidden_layers=6,
143142
type_vocab_size=1,
144143
),
145-
epochs=100,
146144
)
147145

148146
def build_dataset(self, batch_size) -> datasets.XYBaseDataModule:
@@ -174,8 +172,7 @@ def model_kwargs(self, *args) -> Dict:
174172
num_attention_heads=8,
175173
num_hidden_layers=6,
176174
type_vocab_size=1,
177-
),
178-
epochs=100,
175+
)
179176
)
180177

181178

@@ -327,7 +324,6 @@ def model_kwargs(self, *args) -> Dict:
327324
return dict(
328325
in_length=50,
329326
hidden_length=100,
330-
epochs=100,
331327
)
332328

333329
def build_dataset(self, batch_size) -> datasets.XYBaseDataModule:
@@ -345,7 +341,6 @@ def model_kwargs(self, *args) -> Dict:
345341
return dict(
346342
in_length=50,
347343
hidden_length=100,
348-
epochs=100,
349344
)
350345

351346
def build_dataset(self, batch_size) -> datasets.XYBaseDataModule:

chebai/models/base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ def run(
142142
cls,
143143
data,
144144
name,
145+
epochs,
145146
model_args: list = None,
146147
model_kwargs: dict = None,
147148
loss=torch.nn.BCELoss,
@@ -209,7 +210,7 @@ def run(
209210

210211
trainer = pl.Trainer(
211212
logger=tb_logger,
212-
min_epochs=model_kwargs.get("epochs", 100),
213+
min_epochs=epochs,
213214
callbacks=[best_checkpoint_callback, checkpoint_callback, es],
214215
replace_sampler_ddp=False,
215216
**trainer_kwargs

0 commit comments

Comments
 (0)