-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathtrain.py
More file actions
65 lines (47 loc) · 2.18 KB
/
train.py
File metadata and controls
65 lines (47 loc) · 2.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import tensorflow as tf
import numpy as np
import os
from model.wavenet import WaveNet
from model.module import CrossEntropyLoss
from dataset import get_train_data
import hparams
@tf.function
def train_step(model, x, mel_sp, y, loss_fn, optimizer):
    """Run a single optimization step and return the scalar loss.

    Args:
        model: WaveNet model called as ``model(x, mel_sp)``.
        x: Input waveform batch.
        mel_sp: Mel-spectrogram conditioning features.
        y: Target values passed to ``loss_fn``.
        loss_fn: Loss callable invoked as ``loss_fn(y, y_hat)``.
        optimizer: Optimizer whose ``apply_gradients`` is used.

    Returns:
        The loss tensor for this batch.
    """
    # Record the forward pass so gradients can be taken afterwards.
    with tf.GradientTape() as tape:
        prediction = model(x, mel_sp)
        batch_loss = loss_fn(y, prediction)
    grads = tape.gradient(batch_loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return batch_loss
def train():
    """Train the WaveNet model, logging loss and checkpointing weights.

    Reads all configuration from the ``hparams`` module. Writes TensorBoard
    summaries to ``hparams.result_dir`` and saves weights plus the current
    global step under ``<result_dir>/weights/``. If ``hparams.load_path`` is
    set, resumes from those weights and the previously saved step counter.
    """
    # os.path.join is robust whether or not result_dir ends with a slash
    # (the original string concatenation silently broke without one).
    weights_dir = os.path.join(hparams.result_dir, "weights")
    os.makedirs(weights_dir, exist_ok=True)
    summary_writer = tf.summary.create_file_writer(hparams.result_dir)

    wavenet = WaveNet(hparams.num_mels, hparams.upsample_scales)
    loss_fn = CrossEntropyLoss(num_classes=256)
    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        hparams.learning_rate,
        decay_steps=hparams.exponential_decay_steps,
        decay_rate=hparams.exponential_decay_rate)
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule,
                                         beta_1=hparams.beta_1)

    if hparams.load_path is not None:
        wavenet.load_weights(hparams.load_path)
        # np.load returns a 0-d ndarray; convert to a plain int so the
        # step counter behaves as a Python integer (the original
        # `step = step` no-op was clearly meant to be this conversion).
        step = int(np.load(os.path.join(weights_dir, "step.npy")))
        print(f"weights load: {hparams.load_path}")
    else:
        step = 0

    for epoch in range(hparams.epoch):
        # Re-create the dataset each epoch (mirrors original behavior;
        # presumably reshuffles — confirm against dataset.get_train_data).
        train_data = get_train_data()
        for x, mel_sp, y in train_data:
            loss = train_step(wavenet, x, mel_sp, y, loss_fn, optimizer)
            with summary_writer.as_default():
                tf.summary.scalar('train/loss', loss, step=step)
            step += 1
        # Periodic checkpoint every `save_interval` epochs.
        if epoch % hparams.save_interval == 0:
            print(f'Step {step}, Loss: {loss}')
            np.save(os.path.join(weights_dir, "step.npy"), np.array(step))
            wavenet.save_weights(
                os.path.join(weights_dir, f"wavenet_{epoch:04}"))
    # Final checkpoint after the last epoch.
    np.save(os.path.join(weights_dir, "step.npy"), np.array(step))
    wavenet.save_weights(os.path.join(weights_dir, f"wavenet_{epoch:04}"))
# Script entry point: run training when executed directly.
if __name__ == '__main__':
    train()