-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathvae.py
More file actions
82 lines (66 loc) · 2.64 KB
/
vae.py
File metadata and controls
82 lines (66 loc) · 2.64 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import glob
import numpy as np
from keras.layers import (
Input, Conv2D, Conv2DTranspose, Lambda, Dense, Flatten, Reshape)
from keras.models import Model
import keras.backend as K
# Input image shape: 64x64 RGB frames.
input_dim = 64, 64, 3
# Dimensionality of the latent code z.
latent_dim = 32
# Training passes over each loaded frame archive.
EPOCHS = 1
BATCH_SIZE = 32
def sampling(args):
    """Reparameterization trick: sample z = mu + sigma * eps, eps ~ N(0, I).

    args: tuple ``(z_mean, z_log_var)`` of tensors with shape
    ``(batch, latent_dim)``.  Returns a latent sample of the same shape.
    """
    z_mean, z_log_var = args
    # Draw noise matching the full shape of z_mean instead of hard-coding
    # the latent size (the original used a literal 32, silently duplicating
    # `latent_dim` and breaking if that constant is changed).
    epsilon = K.random_normal(shape=K.shape(z_mean), mean=0., stddev=1.0)
    # exp(log_var / 2) is the standard deviation.
    return z_mean + K.exp(z_log_var / 2) * epsilon
# Encoder layers: four stride-2 valid convolutions downsample the 64x64x3
# input to a 1024-dim feature vector; two Dense heads then predict the
# mean and log-variance of a diagonal-Gaussian posterior over the latent z.
inputs = Input(shape=input_dim) # 64x64x3
h = Conv2D(32, 4, strides=2, activation='relu')(inputs) # 31x31x32
h = Conv2D(64, 4, strides=2, activation='relu')(h) # 14x14x64
h = Conv2D(128, 4, strides=2, activation='relu')(h) # 6x6x128
h = Conv2D(256, 4, strides=2, activation='relu')(h) # 2x2x256
h = Flatten()(h) # 1024
z_mean = Dense(latent_dim, name='z_mean')(h) # 32
z_log_var = Dense(latent_dim, name='z_log_var')(h) # 32
# Stochastic sample from the posterior via the reparameterization trick,
# wrapped in a Lambda layer so it is part of the differentiable graph.
z = Lambda(sampling, name='sampling')([z_mean, z_log_var])
# Decoder layers are instantiated once (not yet called) so the same weights
# can be wired into both the end-to-end VAE graph and a standalone decoder.
decoder_h1 = Dense(1024, name='decoder_h1')
decoder_h2 = Reshape((1, 1, 1024), name='decoder_reshape')
decoder_h3 = Conv2DTranspose(128, 5, strides=2, activation='relu', name='decoder_h3') # 1x1 -> 5x5x128
decoder_h4 = Conv2DTranspose(64, 5, strides=2, activation='relu', name='decoder_h4') # -> 13x13x64
decoder_h5 = Conv2DTranspose(32, 6, strides=2, activation='relu', name='decoder_h5') # -> 30x30x32
decoder_outputs = Conv2DTranspose(3, 6, strides=2, activation='sigmoid', name='decoder_out') # -> 64x64x3 in [0, 1]
# VAE decoder path: decode the sampled z back to a 64x64x3 reconstruction.
# This is the branch that gets trained end to end with the encoder.
h = decoder_h1(z)
h = decoder_h2(h)
h = decoder_h3(h)
h = decoder_h4(h)
h = decoder_h5(h)
outputs = decoder_outputs(h)
# Standalone decoder path: the same layer objects (shared weights) applied
# to a fresh latent-vector input, so the decoder can be used on its own
# after training to map arbitrary z vectors to images.
_z = Input(shape=(latent_dim,), name='decoder_input')
_h = decoder_h1(_z)
_h = decoder_h2(_h)
_h = decoder_h3(_h)
_h = decoder_h4(_h)
_h = decoder_h5(_h)
_outputs = decoder_outputs(_h)
# Reconstruction term: squared error summed over all pixels/channels, halved.
l2_loss = K.sum(K.square(inputs - outputs), axis=[1, 2, 3]) / 2
# KL divergence between the diagonal-Gaussian posterior N(z_mean, exp(z_log_var))
# and the standard-normal prior, summed over latent dimensions.
kl_loss = -0.5 * K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var),
axis=-1)
vae = Model(inputs, outputs)
# The loss depends on internal tensors (z_mean, z_log_var), not just on
# y_true/y_pred, so it is attached via add_loss and compile gets loss=None.
vae_loss = K.mean(l2_loss + kl_loss)
vae.add_loss(vae_loss)
vae.compile(optimizer='adam', loss=None)
# NOTE(review): the encoder model outputs the stochastic sample z rather than
# z_mean, so encoding the same image twice gives different codes — confirm
# this is intended (z_mean is the usual deterministic choice at inference).
encoder = Model(inputs, z)
decoder = Model(_z, _outputs)
# Fit the VAE on every saved frame archive, one .npy file at a time.
for frame_file in glob.glob('./data/frames-*.npy'):
    print('Loading %s...' % frame_file)
    episodes = np.load(frame_file)
    # Collapse (episodes, frames, w, h, c) into a single batch of images
    # and rescale pixel values from [0, 255] into [0, 1].
    n_episodes, n_frames, width, height, channels = episodes.shape
    images = episodes.reshape(n_episodes * n_frames, width, height, channels) / 255.
    vae.fit(images, shuffle=True, epochs=EPOCHS, batch_size=BATCH_SIZE)

# Persist the trained networks (full VAE plus the two standalone halves).
vae.save('vae.h5')
encoder.save('encoder.h5')
decoder.save('decoder.h5')