-
Notifications
You must be signed in to change notification settings - Fork 73
Expand file tree
/
Copy pathmodels.lua
More file actions
71 lines (61 loc) · 2.11 KB
/
models.lua
File metadata and controls
71 lines (61 loc) · 2.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
require 'nn';
require 'cunn';
require 'MultiCrossEntropyCriterion'
-- Module table: every model constructor defined in this file is attached to
-- it, and it is the value handed back to `require` callers.
local models = {}
--- Build a VGG-style convolutional network together with its criterion.
--
-- The network reshapes its input to a single-channel 50x200 map, runs five
-- groups of 3x3 Conv -> BatchNorm -> ReLU layers (each group followed by a
-- 2x2 "ceil" max-pooling), and finishes with a two-layer fully connected
-- classifier whose output is reshaped to k x c.
--
-- @param k first dimension of the output (defaults to 5); presumably the
--          number of character positions to predict -- confirm with caller
-- @param c second dimension of the output (defaults to 36); presumably the
--          number of classes per position -- confirm with caller
-- @return the assembled nn.Sequential network
-- @return a fresh nn.MultiCrossEntropyCriterion instance
function models.cnnModel(k,c)
  -- Apply the defaults directly to the parameters instead of re-declaring
  -- shadowing locals with the same names.
  k = k or 5
  c = c or 36

  -- Will use "ceil" MaxPooling because we want to save as much
  -- space as we can
  local vgg = nn.Sequential()
  vgg:add(nn.Reshape(1,50,200))

  -- Manual switch: change to 'cudnn' to build the convolutional layers from
  -- the cudnn package instead of nn.
  local backend_name = 'nn'
  local backend
  if backend_name == 'cudnn' then
    require 'cudnn'
    backend = cudnn
  else
    backend = nn
  end

  -- Captured once here, so the pooling layers keep using the initially
  -- selected backend even after `backend` is reassigned further down.
  local MaxPooling = backend.SpatialMaxPooling

  -- building block: 3x3 convolution (stride 1, padding 1, so the spatial size
  -- is preserved) followed by batch normalization and an in-place ReLU.
  -- Reads `backend` as an upvalue at call time and returns `vgg` so further
  -- layers can be chained onto the call.
  local function ConvBNReLU(nInputPlane, nOutputPlane)
    vgg:add(backend.SpatialConvolution(nInputPlane, nOutputPlane, 3,3, 1,1, 1,1))
    vgg:add(nn.SpatialBatchNormalization(nOutputPlane,1e-3))
    vgg:add(backend.ReLU(true))
    return vgg
  end

  -- Spatial size per pooling stage (ceil mode):
  -- 50x200 -> 25x100 -> 13x50 -> 7x25 -> 4x13 -> 2x7
  ConvBNReLU(1,64)--:add(nn.Dropout(0.3,nil,true))
  ConvBNReLU(64,64)
  vgg:add(MaxPooling(2,2,2,2):ceil())

  ConvBNReLU(64,128)--:add(nn.Dropout(0.4,nil,true))
  ConvBNReLU(128,128)
  vgg:add(MaxPooling(2,2,2,2):ceil())

  ConvBNReLU(128,256)--:add(nn.Dropout(0.4,nil,true))
  ConvBNReLU(256,256)--:add(nn.Dropout(0.4,nil,true))
  ConvBNReLU(256,256)
  vgg:add(MaxPooling(2,2,2,2):ceil())

  ConvBNReLU(256,512)--:add(nn.Dropout(0.4,nil,true))
  ConvBNReLU(512,512)--:add(nn.Dropout(0.4,nil,true))
  ConvBNReLU(512,512)
  vgg:add(MaxPooling(2,2,2,2):ceil())

  -- In the last block of convolutions the inputs are smaller than
  -- the kernels and cudnn doesn't handle that, have to use cunn
  backend = nn
  ConvBNReLU(512,512)--:add(nn.Dropout(0.4,nil,true))
  ConvBNReLU(512,512)--:add(nn.Dropout(0.4,nil,true))
  ConvBNReLU(512,512)
  vgg:add(MaxPooling(2,2,2,2):ceil())

  -- Flatten the 512 feature maps of size 2x7 into one vector per sample.
  vgg:add(nn.View(512*2*7))

  -- Declared local so that building a model does not create (or overwrite)
  -- a global named `classifier`; it remains reachable as a child of `vgg`.
  local classifier = nn.Sequential()
  --classifier:add(nn.Dropout(0.5,nil,true))
  classifier:add(nn.Linear(512*2*7,512))
  classifier:add(nn.BatchNormalization(512))
  classifier:add(nn.ReLU(true))
  --classifier:add(nn.Dropout(0.5,nil,true))
  classifier:add(nn.Linear(512,k*c))
  vgg:add(classifier)

  -- Split the k*c scores into k rows of c values each.
  vgg:add(nn.Reshape(k,c))

  return vgg,nn.MultiCrossEntropyCriterion()
end
-- Hand the module table back to `require` callers.
return models