Skip to content

Commit 9e28df0

Browse files
committed
Initial refactoring
1 parent 21904f8 commit 9e28df0

File tree

3 files changed

+347
-313
lines changed

3 files changed

+347
-313
lines changed

models.lua

Lines changed: 260 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,260 @@
1+
require 'nn' -- Torch neural-network package; provides every nn.* module used below

-- Registry of model factories. Each entry is a zero/one-argument function
-- returning (model, example_input) so callers can immediately run a forward pass.
local models = {}
4+
-- Tiny two-branch toy model: a ParallelTable maps {2x2, 2x2} inputs through
-- independent branches, the results are joined along dim 2 (2 + 1 = 3 cols),
-- then passed through a small Linear+ReLU head.
-- Returns (model, example_input_table).
models.basic1 = function()
   local branches = nn.ParallelTable()
   branches:add(nn.Linear(2, 2))

   -- Second branch: 2 -> 1 -> sigmoid -> 1
   local squeeze = nn.Sequential()
   squeeze:add(nn.Linear(2, 1))
   squeeze:add(nn.Sigmoid())
   squeeze:add(nn.Linear(1, 1))
   branches:add(squeeze)

   local net = nn.Sequential()
      :add(branches)
      :add(nn.JoinTable(2))
      :add(nn.Linear(3, 2))
      :add(nn.ReLU(true))

   local sample = {torch.rand(2, 2), torch.rand(2, 2)}
   return net, sample
end
17+
-- Minimal conv -> flatten -> linear model over a 1x1x32x32 input.
-- The commented-out layers are preserved from the original so they can be
-- re-enabled for experiments. Returns (model, example_input).
models.basic2 = function()
   local net = nn.Sequential()
   net:add(nn.SpatialConvolution(1, 1, 3, 3, 1, 1, 1, 1)) -- padding keeps 32x32
   -- net:add(nn.ReLU(true))
   -- net:add(nn.SpatialConvolution(1,1,3,3,1,1,1,1))
   -- net:add(nn.ReLU(true))
   net:add(nn.View(32 * 32))
   net:add(nn.Linear(32 * 32, 100))
   -- net:add(nn.ReLU(true))
   -- net:add(nn.Linear(100,10))
   return net, torch.rand(1, 1, 32, 32)
end
30+
-- Two-tower AlexNet, adapted from soumith's imagenet-multiGPU:
-- https://github.com/soumith/imagenet-multiGPU.torch/blob/master/models/alexnet.lua
-- Returns (model, example_input) for a 224x224 RGB crop.
models.alexnet = function()
   -- Build the first feature tower; the second is a clone whose conv
   -- weights are re-randomized so the branches are not tied.
   local tower = nn.Sequential()
   tower:add(nn.SpatialConvolution(3, 48, 11, 11, 4, 4, 2, 2))  -- 224 -> 55
   tower:add(nn.ReLU(true))
   tower:add(nn.SpatialMaxPooling(3, 3, 2, 2))                  -- 55 -> 27
   tower:add(nn.SpatialConvolution(48, 128, 5, 5, 1, 1, 2, 2))  -- 27 -> 27
   tower:add(nn.ReLU(true))
   tower:add(nn.SpatialMaxPooling(3, 3, 2, 2))                  -- 27 -> 13
   tower:add(nn.SpatialConvolution(128, 192, 3, 3, 1, 1, 1, 1)) -- 13 -> 13
   tower:add(nn.ReLU(true))
   tower:add(nn.SpatialConvolution(192, 192, 3, 3, 1, 1, 1, 1)) -- 13 -> 13
   tower:add(nn.ReLU(true))
   tower:add(nn.SpatialConvolution(192, 128, 3, 3, 1, 1, 1, 1)) -- 13 -> 13
   tower:add(nn.ReLU(true))
   tower:add(nn.SpatialMaxPooling(3, 3, 2, 2))                  -- 13 -> 6

   local tower2 = tower:clone()
   for _, conv in ipairs(tower2:findModules('nn.SpatialConvolution')) do
      conv:reset() -- fresh random weights for the second branch
   end

   -- Concatenate the towers along the channel dim: 128 + 128 = 256 maps at 6x6.
   local features = nn.Concat(2)
   features:add(tower)
   features:add(tower2)

   -- Fully connected classifier head (1000-way, log-probabilities out).
   local classifier = nn.Sequential()
   classifier:add(nn.View(256 * 6 * 6))
   classifier:add(nn.Dropout(0.5))
   classifier:add(nn.Linear(256 * 6 * 6, 4096))
   classifier:add(nn.Threshold(0, 1e-6))
   classifier:add(nn.Dropout(0.5))
   classifier:add(nn.Linear(4096, 4096))
   classifier:add(nn.Threshold(0, 1e-6))
   classifier:add(nn.Linear(4096, 1000))
   classifier:add(nn.LogSoftMax())

   -- Combine features and classifier into the final model.
   local model = nn.Sequential():add(features):add(classifier)
   model.imageSize = 256
   model.imageCrop = 224

   return model, torch.rand(1, 3, model.imageCrop, model.imageCrop)
end
78+
79+
-- ResNet builder (ImageNet and CIFAR-10 variants), following the
-- fb.resnet.torch reference implementation.
-- opt fields read: depth, shortcutType ('A'/'B'/'C', default 'B'),
-- dataset ('imagenet' or 'cifar10'), cudnn (optional, 'deterministic').
-- Returns (model, example_input).
models.resnet = function(opt)

local Convolution = nn.SpatialConvolution
local Avg = nn.SpatialAveragePooling
local ReLU = nn.ReLU
local Max = nn.SpatialMaxPooling
local SBatchNorm = nn.SpatialBatchNormalization

local function createModel(opt)
local depth = opt.depth
local shortcutType = opt.shortcutType or 'B'
-- iChannels tracks the running channel count; it is mutated by
-- basicblock/bottleneck every time a block is instantiated, so the
-- order in which blocks are constructed below is significant.
local iChannels

-- The shortcut layer is either identity or 1x1 convolution
local function shortcut(nInputPlane, nOutputPlane, stride)
local useConv = shortcutType == 'C' or
(shortcutType == 'B' and nInputPlane ~= nOutputPlane)
if useConv then
-- 1x1 convolution
return nn.Sequential()
:add(Convolution(nInputPlane, nOutputPlane, 1, 1, stride, stride))
:add(SBatchNorm(nOutputPlane))
elseif nInputPlane ~= nOutputPlane then
-- Strided, zero-padded identity shortcut: subsample spatially, then
-- double the channels by concatenating a zeroed copy (type-A shortcut).
return nn.Sequential()
:add(nn.SpatialAveragePooling(1, 1, stride, stride))
:add(nn.Concat(2)
:add(nn.Identity())
:add(nn.MulConstant(0)))
else
return nn.Identity()
end
end

-- The basic residual layer block for 18 and 34 layer network, and the
-- CIFAR networks
local function basicblock(n, stride)
local nInputPlane = iChannels
iChannels = n

local s = nn.Sequential()
s:add(Convolution(nInputPlane,n,3,3,stride,stride,1,1))
s:add(SBatchNorm(n))
s:add(ReLU(true))
s:add(Convolution(n,n,3,3,1,1,1,1))
s:add(SBatchNorm(n))

-- Residual add: main path + shortcut, then ReLU.
return nn.Sequential()
:add(nn.ConcatTable()
:add(s)
:add(shortcut(nInputPlane, n, stride)))
:add(nn.CAddTable(true))
:add(ReLU(true))
end

-- The bottleneck residual layer for 50, 101, and 152 layer networks
local function bottleneck(n, stride)
local nInputPlane = iChannels
iChannels = n * 4

local s = nn.Sequential()
s:add(Convolution(nInputPlane,n,1,1,1,1,0,0))
s:add(SBatchNorm(n))
s:add(ReLU(true))
s:add(Convolution(n,n,3,3,stride,stride,1,1))
s:add(SBatchNorm(n))
s:add(ReLU(true))
s:add(Convolution(n,n*4,1,1,1,1,0,0))
s:add(SBatchNorm(n * 4))

return nn.Sequential()
:add(nn.ConcatTable()
:add(s)
:add(shortcut(nInputPlane, n * 4, stride)))
:add(nn.CAddTable(true))
:add(ReLU(true))
end

-- Creates count residual blocks with specified number of features
local function layer(block, features, count, stride)
local s = nn.Sequential()
for i=1,count do
-- Only the first block in a group downsamples; a nil stride (as in the
-- first ImageNet group below) falls through to 1 here.
s:add(block(features, i == 1 and stride or 1))
end
return s
end

local model = nn.Sequential()
local input
if opt.dataset == 'imagenet' then
-- Configurations for ResNet:
-- num. residual blocks, num features, residual block function
local cfg = {
[18] = {{2, 2, 2, 2}, 512, basicblock},
[34] = {{3, 4, 6, 3}, 512, basicblock},
[50] = {{3, 4, 6, 3}, 2048, bottleneck},
[101] = {{3, 4, 23, 3}, 2048, bottleneck},
[152] = {{3, 8, 36, 3}, 2048, bottleneck},
}

assert(cfg[depth], 'Invalid depth: ' .. tostring(depth))
-- NOTE(review): table.unpack requires Lua 5.2+ (or LuaJIT built with
-- 5.2 compat) -- confirm the target runtime provides it.
local def, nFeatures, block = table.unpack(cfg[depth])
iChannels = 64
--print(' | ResNet-' .. depth .. ' ImageNet')

-- The ResNet ImageNet model
model:add(Convolution(3,64,7,7,2,2,3,3))
model:add(SBatchNorm(64))
model:add(ReLU(true))
model:add(Max(3,3,2,2,1,1))
model:add(layer(block, 64, def[1]))
model:add(layer(block, 128, def[2], 2))
model:add(layer(block, 256, def[3], 2))
model:add(layer(block, 512, def[4], 2))
model:add(Avg(7, 7, 1, 1))
model:add(nn.View(nFeatures):setNumInputDims(3))
model:add(nn.Linear(nFeatures, 1000))

input = torch.rand(1,3,224,224)
elseif opt.dataset == 'cifar10' then
-- Model type specifies number of layers for CIFAR-10 model
assert((depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110, 1202')
local n = (depth - 2) / 6
iChannels = 16
--print(' | ResNet-' .. depth .. ' CIFAR-10')

-- The ResNet CIFAR-10 model
model:add(Convolution(3,16,3,3,1,1,1,1))
model:add(SBatchNorm(16))
model:add(ReLU(true))
model:add(layer(basicblock, 16, n))
model:add(layer(basicblock, 32, n, 2))
model:add(layer(basicblock, 64, n, 2))
model:add(Avg(8, 8, 1, 1))
model:add(nn.View(64):setNumInputDims(3))
model:add(nn.Linear(64, 10))
input = torch.rand(1,3,32,32)
else
error('invalid dataset: ' .. opt.dataset)
end

-- He-style init for conv weights. The cudnn.version branch is
-- intentionally disabled ('false and ...'), so the bias is always
-- zeroed and the global 'cudnn' is never dereferenced (short-circuit).
local function ConvInit(name)
for k,v in pairs(model:findModules(name)) do
local n = v.kW*v.kH*v.nOutputPlane
v.weight:normal(0,math.sqrt(2/n))
if false and cudnn.version >= 4000 then
v.bias = nil
v.gradBias = nil
else
v.bias:zero()
end
end
end
-- BatchNorm init: unit scale, zero shift.
local function BNInit(name)
for k,v in pairs(model:findModules(name)) do
v.weight:fill(1)
v.bias:zero()
end
end

-- Apply the initializers to every matching module, regardless of backend.
ConvInit('cudnn.SpatialConvolution')
ConvInit('nn.SpatialConvolution')
BNInit('fbnn.SpatialBatchNormalization')
BNInit('cudnn.SpatialBatchNormalization')
BNInit('nn.SpatialBatchNormalization')
for k,v in pairs(model:findModules('nn.Linear')) do
v.bias:zero()
end

-- Optionally force deterministic cudnn algorithms on modules that support it.
if opt.cudnn == 'deterministic' then
model:apply(function(m)
if m.setMode then m:setMode(1,1,1) end
end)
end

return model, input
end

return createModel(opt)
end
259+
260+
return models -- module export: table of model-factory functions

0 commit comments

Comments
 (0)