#!/usr/bin/env python
# train_steering_model.py -- forked from commaai/research.
"""
Steering angle prediction model
"""
import os
import argparse
import json
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten, Lambda, ELU
from keras.layers.convolutional import Convolution2D
from server import client_generator
def gen(hwm, host, port):
    """Yield ``(X, Y)`` pairs derived from ``client_generator``.

    ``hwm``, ``host`` and ``port`` are forwarded unchanged to
    ``client_generator``.  For every 3-tuple it produces, the third element
    is discarded and only the last entry along axis 1 of ``Y`` is kept.
    When ``X`` has length 1 along axis 1, that axis is removed from ``X`` by
    taking its last (only) entry.
    """
    stream = client_generator(hwm=hwm, host=host, port=port)
    for batch in stream:
        frames, targets, _ = batch
        targets = targets[:, -1]
        single_step = frames.shape[1] == 1  # no temporal context
        if single_step:
            frames = frames[:, -1]
        yield frames, targets
def get_model(time_len=1):
    """Build and compile the steering-angle prediction network.

    The model is a ``Sequential`` stack made of an input-scaling ``Lambda``
    (``x / 127.5 - 1``), three strided ``Convolution2D`` layers, and two
    ``Dense`` layers ending in a single output unit, with ``ELU``
    activations and ``Dropout`` in between.  It is compiled with the
    ``"adam"`` optimizer and ``"mse"`` loss before being returned.

    ``time_len`` is accepted but is not used anywhere in this function.
    """
    # Camera format: channels, rows, columns.
    frame_shape = (3, 160, 320)

    model = Sequential()
    layers = [
        # Scale the input by 1/127.5 and shift by -1.
        Lambda(lambda x: x/127.5 - 1.,
               input_shape=frame_shape,
               output_shape=frame_shape),
        Convolution2D(16, 8, 8, subsample=(4, 4), border_mode="same"),
        ELU(),
        Convolution2D(32, 5, 5, subsample=(2, 2), border_mode="same"),
        ELU(),
        Convolution2D(64, 5, 5, subsample=(2, 2), border_mode="same"),
        Flatten(),
        Dropout(.2),
        ELU(),
        Dense(512),
        Dropout(.5),
        ELU(),
        Dense(1),
    ]
    # Layers are appended in exactly the order listed above.
    for layer in layers:
        model.add(layer)

    model.compile(optimizer="adam", loss="mse")
    return model
if __name__ == "__main__":
    # Command-line entry point: parse options, train the model from the
    # data-server generators, then write weights and architecture to disk.
    parser = argparse.ArgumentParser(description='Steering angle model trainer')
    parser.add_argument('--host', type=str, default="localhost", help='Data server ip address.')
    parser.add_argument('--port', type=int, default=5557, help='Port of server.')
    parser.add_argument('--val_port', type=int, default=5556, help='Port of server for validation dataset.')
    # NOTE: --batch is parsed but never read by this script.
    parser.add_argument('--batch', type=int, default=64, help='Batch size.')
    parser.add_argument('--epoch', type=int, default=200, help='Number of epochs.')
    parser.add_argument('--epochsize', type=int, default=10000, help='How many frames per epoch.')
    parser.add_argument('--skipvalidate', dest='skipvalidate', action='store_true', help='Skip validation during training.')
    parser.set_defaults(skipvalidate=False)
    parser.set_defaults(loadweights=False)
    args = parser.parse_args()

    model = get_model()

    # Honor --skipvalidate: with no validation generator, fit_generator
    # trains without running a validation pass.
    if args.skipvalidate:
        validation_data = None
        nb_val_samples = None
    else:
        validation_data = gen(20, args.host, port=args.val_port)
        nb_val_samples = 1000

    model.fit_generator(
        gen(20, args.host, port=args.port),
        # Previously hard-coded to 10000, which made --epochsize a no-op;
        # the flag's default keeps the old behavior.
        samples_per_epoch=args.epochsize,
        nb_epoch=args.epoch,
        validation_data=validation_data,
        nb_val_samples=nb_val_samples
    )

    print("Saving model weights and configuration file.")

    if not os.path.exists("./outputs/steering_model"):
        os.makedirs("./outputs/steering_model")

    # Second positional argument is passed as True, as in the original call
    # (presumably the `overwrite` flag of save_weights -- confirm against
    # the installed Keras version).
    model.save_weights("./outputs/steering_model/steering_angle.keras", True)
    with open('./outputs/steering_model/steering_angle.json', 'w') as outfile:
        # NOTE(review): to_json() returns a string per the Keras docs, so
        # this writes a JSON-encoded string; readers need json.load() before
        # model_from_json(). Left unchanged so existing loaders keep working.
        json.dump(model.to_json(), outfile)