-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtraining.py
More file actions
48 lines (39 loc) · 1.39 KB
/
training.py
File metadata and controls
48 lines (39 loc) · 1.39 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
def datagenerator(text, labels, batchsize):
    """Yield ``(x, y)`` mini-batches from ``text`` and ``labels`` endlessly.

    Each pass walks the data in order in slices of ``batchsize``. The last
    slice of a pass is shorter when ``len(text)`` is not a multiple of
    ``batchsize``. After the last slice the generator restarts from index 0,
    so it never exhausts. This matches the endless-generator contract used by
    Keras ``fit_generator``, which pulls a fixed number of steps per epoch.

    Args:
        text: Sliceable inputs with ``len()`` (e.g. a NumPy array or list).
        labels: Targets aligned with ``text``; sliced with the same bounds.
        batchsize: Number of samples per batch; must be positive.

    Yields:
        Tuples ``(text[i:i + batchsize], labels[i:i + batchsize])``.

    Raises:
        ValueError: On the first ``next()``, if ``batchsize`` is not positive
            (would yield empty/garbage slices forever) or ``text`` is empty
            (would spin forever without ever yielding).
    """
    n_samples = len(text)
    if batchsize <= 0:
        raise ValueError(f"batchsize must be positive, got {batchsize}")
    if n_samples == 0:
        raise ValueError("text is empty; cannot generate batches")
    while True:
        # One in-order pass over the data; the outer loop restarts it.
        for start in range(0, n_samples, batchsize):
            stop = start + batchsize
            # In-memory slicing; replace with on-disk loading if data is large.
            yield text[start:stop], labels[start:stop]
def train_model(model, x_train1, y_train1, x_test1, y_test1, batch_size, epochs, args):
    """Train ``model`` on generator-streamed batches and return the history.

    Training and validation data are both fed through ``datagenerator`` in
    batches of ``batch_size``. Three callbacks are attached, all watching
    ``val_loss``:

    - ``EarlyStopping`` with patience ``args.patience``.
    - ``ModelCheckpoint`` writing the best weights seen so far to
      ``models/weights_{args.suffix}.hdf5``.
    - ``ReduceLROnPlateau``, which halves the learning rate after 3 stagnant
      epochs, down to ``1e-4``.

    Args:
        model: A compiled Keras model providing ``fit_generator``.
        x_train1: Training inputs (anything with ``len()`` and slicing).
        y_train1: Training targets aligned with ``x_train1``.
        x_test1: Validation inputs.
        y_test1: Validation targets aligned with ``x_test1``.
        batch_size: Samples per batch for both streams; must be positive.
        epochs: Maximum number of training epochs.
        args: Namespace providing ``suffix`` (checkpoint file-name tag) and
            ``patience`` (early-stopping patience in epochs).

    Returns:
        Whatever ``model.fit_generator`` returns (a Keras ``History``).

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    reduce_lr = ReduceLROnPlateau(
        monitor="val_loss", verbose=1, factor=0.5, patience=3, min_lr=0.0001
    )
    # save_best_only=True: with a fixed filepath and save_best_only=False the
    # file was overwritten every epoch (keeping the *last* weights, typically
    # `patience` epochs past the optimum) and monitor/mode were ignored.
    checkpointer = ModelCheckpoint(
        filepath=f"models/weights_{args.suffix}.hdf5",
        monitor="val_loss",
        verbose=1,
        mode="min",
        save_best_only=True,
    )
    early_stopper = EarlyStopping(
        monitor="val_loss", min_delta=0, patience=args.patience, verbose=1, mode="min"
    )

    # Ceil-divide so one epoch consumes exactly one generator pass, including
    # the final partial batch. With floor division, each epoch started at a
    # drifting offset, and val_loss was measured on a different subset every
    # epoch, which destabilised the callbacks. It also gave 0 steps whenever
    # the dataset was smaller than one batch.
    steps = (len(x_train1) + batch_size - 1) // batch_size
    val_steps = (len(x_test1) + batch_size - 1) // batch_size

    # NOTE: fit_generator is deprecated in TF2 Keras (use model.fit with a
    # generator) and removed in Keras 3; kept here for the Keras version this
    # file targets.
    hist = model.fit_generator(
        datagenerator(x_train1, y_train1, batch_size),
        validation_data=datagenerator(x_test1, y_test1, batch_size),
        callbacks=[early_stopper, checkpointer, reduce_lr],
        verbose=1,
        steps_per_epoch=steps,
        validation_steps=val_steps,
        epochs=epochs,
    )
    return hist