We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent faf6326 · commit db82c61 (Copy full SHA for db82c61)
src/train.py
@@ -57,11 +57,9 @@ def train(args):
57
else:
58
device = torch.device("cpu")
59
60
- down_lr = 50
61
if "mahabharata" in args.data_dir:
62
args.batch_size = 128
63
- args.num_epochs = 50
64
- down_lr = 30
+ args.num_epochs = 100
65
66
data_loader = TextLoader(args.data_dir, args.batch_size, args.seq_length)
67
@@ -124,7 +122,7 @@ def save():
124
122
f"Epoch: {e}. Truth: " + "".join(seq_gt[0]).replace("\n", "\\n ")
125
123
)
126
127
- if e == down_lr:
+ if e == 50:
128
tqdm.write("reducing learning rate.")
129
for g in optimizer.param_groups:
130
g["lr"] = g["lr"] / 2.0
0 commit comments