forked from Moodstocks/gtsrb
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.lua
More file actions
128 lines (102 loc) · 4.18 KB
/
train.lua
File metadata and controls
128 lines (102 loc) · 4.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
----------------------------------------------------------------------
-- This script contains the main training function
--
-- Prior to using this script, we need to generate the datasets with
-- createDataSet.lua, and pre-process it using preProcess.lua.
--
-- It uses the ConvNet model described in models/MSmodel.lua
--
-- Required :
-- + model described in models/MSmodel.lua
-- + training set loaded with dataset.lua
--
-- Hugo Duthil
----------------------------------------------------------------------
require 'torch'
require 'optim'
require 'paths'
-- NOTE(review): `sys` is not required here; presumably pulled in by 'torch' — confirm
local time = sys.clock()
-- Training parameters, read from a global `params` table that must be
-- populated before this script runs (see the project's option parsing).
-- These are deliberately global: train() below reads them.
batch_size = params.batch_size -- number of examples per SGD mini-batch
learning_rate = params.lr -- learning rate
save_model_it = params.save_model_iterations -- save model every N batch iterations (0 intended to disable)
model_file = paths.dirname(paths.thisfile()).."/"..params.model_name -- save model under this path
record_f_it = params.save_f_iterations -- record objective f every N batch iterations (0 intended to disable)
f_file = paths.dirname(paths.thisfile()).."/"..params.f_name -- save objective function graph under this path
use_3_channels = params.use_3_channels -- boolean, use 1 (Y only) or 3 channels for computation
-- history of the objective function, for visualization with itorch
saved_f = {}
-- Main training loop across the entire dataset
-- The optimization method is a classic batch sgd
----------------------------------------------------------------------
-- Runs one full pass (epoch) over `train_set` using mini-batch SGD.
--
-- Reads the globals set up above (batch_size, learning_rate,
-- save_model_it, record_f_it, model_file, f_file, use_3_channels,
-- saved_f) plus `model`, `criterion` and `train_set` loaded elsewhere.
-- Side effects: updates `model` parameters in place, periodically
-- saves the model and the objective history to disk, and saves the
-- final confusion matrix to "saves/confusion.t7".
----------------------------------------------------------------------
function train()
    -- GTSRB has 43 classes; label them with the strings "0".."42"
    classes = {}
    for i = 1, 43 do classes[i] = (i-1).."" end

    -- records the per-class confusion accumulated over this epoch
    local confusion = optim.ConfusionMatrix(classes)
    confusion:zero()

    -- flat views of the learnable parameters and of the gradient of the
    -- cost function with respect to them
    if model then
        parameters, gradParameters = model:getParameters()
    else
        print("No model found, please load a model with models/MSmodel.lua")
        -- FIX: previously fell through and crashed below on
        -- model:zeroGradParameters(); bail out cleanly instead.
        return
    end

    print("Training network")
    local m = 1 -- batch counter, drives the periodic save/record logic
    -- visit the training set in a fresh random order each epoch
    shuffle = torch.randperm(train_set:size())

    for t = 1, train_set:size(), batch_size do
        xlua.progress(t, train_set:size())

        -- gather up to batch_size examples (last batch may be smaller)
        local batch_examples = {}
        for i = t, math.min(t + batch_size - 1, train_set:size()) do
            batch_examples[#batch_examples+1] = train_set[shuffle[i]]
        end

        model:zeroGradParameters()
        -- objective function value accumulated over the batch
        local f = 0

        for i = 1, #batch_examples do
            local input
            if use_3_channels then
                input = batch_examples[i][1]
            else
                -- keep only the first (Y) channel
                input = batch_examples[i][1][{{1}, {}, {}}]
            end
            local label = batch_examples[i][2]
            local output = model:forward(input)
            -- the criterion takes classes from 1 to 43
            f = f + criterion:forward(output, label)
            -- estimate df/dw and accumulate gradients into the model
            local df_do = criterion:backward(output, label)
            model:backward(input, df_do)
            confusion:add(output, label)
        end

        -- normalize the accumulated gradients and f(X) by the batch size
        gradParameters:div(#batch_examples)
        f = f / #batch_examples

        -- periodically checkpoint the model (save_model_it == 0 disables)
        if save_model_it ~= 0 and m % save_model_it == 0 then
            torch.save(model_file, model)
        end
        -- periodically record f (record_f_it == 0 disables).
        -- FIX: the original tested the undefined global `record_f_i`
        -- (always nil, so `~= 0` was always true) instead of `record_f_it`,
        -- defeating the intended "0 disables recording" guard.
        if record_f_it ~= 0 and m % record_f_it == 0 then
            table.insert(saved_f, f)
            torch.save(f_file, saved_f)
        end

        -- vanilla SGD step: w <- w - learning_rate * dw
        model:updateParameters(learning_rate)
        m = m + 1
    end

    -- print and persist the epoch's confusion matrix, then reset it
    print(confusion)
    torch.save("saves/confusion.t7", confusion)
    confusion:zero()
end