Skip to content

Commit db82c61

Browse files
authored
Update train.py
1 parent faf6326 commit db82c61

File tree

1 file changed

+2
-4
lines changed

1 file changed

+2
-4
lines changed

src/train.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,9 @@ def train(args):
5757
else:
5858
device = torch.device("cpu")
5959

60-
down_lr = 50
6160
if "mahabharata" in args.data_dir:
6261
args.batch_size = 128
63-
args.num_epochs = 50
64-
down_lr = 30
62+
args.num_epochs = 100
6563

6664
data_loader = TextLoader(args.data_dir, args.batch_size, args.seq_length)
6765

@@ -124,7 +122,7 @@ def save():
124122
f"Epoch: {e}. Truth: " + "".join(seq_gt[0]).replace("\n", "\\n ")
125123
)
126124

127-
if e == down_lr:
125+
if e == 50:
128126
tqdm.write("reducing learning rate.")
129127
for g in optimizer.param_groups:
130128
g["lr"] = g["lr"] / 2.0

0 commit comments

Comments
 (0)