|
| 1 | +require 'nn' |
| 2 | + |
-- Registry of model factories; each entry returns (model, example_input).
local models = {}

--- Tiny two-branch table model: ParallelTable -> JoinTable -> Linear -> ReLU.
-- Takes a pair of 2x2 tensors; branches emit 2x2 and 2x1, joined to 2x3.
models.basic1 = function()
  local net = nn.Sequential()

  local branches = nn.ParallelTable()
  branches:add(nn.Linear(2, 2))
  local right = nn.Sequential()
  right:add(nn.Linear(2, 1))
  right:add(nn.Sigmoid())
  right:add(nn.Linear(1, 1))
  branches:add(right)

  net:add(branches)
  net:add(nn.JoinTable(2))   -- concat along feature dim: 2x2 ++ 2x1 -> 2x3
  net:add(nn.Linear(3, 2))
  net:add(nn.ReLU(true))

  local input = {torch.rand(2, 2), torch.rand(2, 2)}
  return net, input
end
--- Small conv + linear model on a 1x1x32x32 input.
-- Padded 3x3 conv keeps the 32x32 spatial size; the map is then flattened
-- and fed to a single fully-connected layer.
models.basic2 = function()
  local net = nn.Sequential()
      :add(nn.SpatialConvolution(1, 1, 3, 3, 1, 1, 1, 1))
      :add(nn.View(32 * 32))
      :add(nn.Linear(32 * 32, 100))
  local input = torch.rand(1, 1, 32, 32)
  return net, input
end
--- AlexNet, two-tower variant, taken from soumith's imagenet-multiGPU:
-- https://github.com/soumith/imagenet-multiGPU.torch/blob/master/models/alexnet.lua
-- Returns the model and a random 1x3x224x224 example input.
models.alexnet = function()
  -- First feature tower; spatial sizes shrink 224 -> 55 -> 27 -> 13 -> 6.
  local tower = nn.Sequential()
  local function conv(...)
    tower:add(nn.SpatialConvolution(...)):add(nn.ReLU(true))
  end
  local function pool()
    tower:add(nn.SpatialMaxPooling(3, 3, 2, 2))
  end
  conv(3, 48, 11, 11, 4, 4, 2, 2)    -- 224 -> 55
  pool()                             -- 55 -> 27
  conv(48, 128, 5, 5, 1, 1, 2, 2)    -- 27 -> 27
  pool()                             -- 27 -> 13
  conv(128, 192, 3, 3, 1, 1, 1, 1)   -- 13 -> 13
  conv(192, 192, 3, 3, 1, 1, 1, 1)   -- 13 -> 13
  conv(192, 128, 3, 3, 1, 1, 1, 1)   -- 13 -> 13
  pool()                             -- 13 -> 6

  -- Second tower is a structural clone with its own freshly drawn weights.
  local tower2 = tower:clone()
  for _, m in ipairs(tower2:findModules('nn.SpatialConvolution')) do
    m:reset()
  end

  -- The towers run in parallel and are concatenated along the channel dim.
  local features = nn.Concat(2)
  features:add(tower)
  features:add(tower2)

  -- Fully connected classifier head over the 2*128 = 256 channel 6x6 maps.
  local classifier = nn.Sequential()
  classifier:add(nn.View(256 * 6 * 6))
  classifier:add(nn.Dropout(0.5))
  classifier:add(nn.Linear(256 * 6 * 6, 4096))
  classifier:add(nn.Threshold(0, 1e-6))
  classifier:add(nn.Dropout(0.5))
  classifier:add(nn.Linear(4096, 4096))
  classifier:add(nn.Threshold(0, 1e-6))
  classifier:add(nn.Linear(4096, 1000))
  classifier:add(nn.LogSoftMax())

  local model = nn.Sequential():add(features):add(classifier)
  model.imageSize = 256
  model.imageCrop = 224

  return model, torch.rand(1, 3, model.imageCrop, model.imageCrop)
end
| 78 | + |
--- ResNet factory, adapted from fb.resnet.torch (models/resnet.lua).
-- opt.depth:        network depth (imagenet: 18/34/50/101/152; cifar10: 6n+2)
-- opt.dataset:      'imagenet' or 'cifar10'
-- opt.shortcutType: optional, 'A' | 'B' | 'C' (default 'B')
-- opt.cudnn:        optional, 'deterministic' forces deterministic conv modes
-- Returns (model, example_input).
models.resnet = function(opt)
  -- Fail fast with a clear message instead of an "index a nil value" error.
  assert(type(opt) == 'table', 'models.resnet expects an options table')

  local Convolution = nn.SpatialConvolution
  local Avg = nn.SpatialAveragePooling
  local ReLU = nn.ReLU
  local Max = nn.SpatialMaxPooling
  local SBatchNorm = nn.SpatialBatchNormalization

  local function createModel(opt)
    local depth = opt.depth
    local shortcutType = opt.shortcutType or 'B'
    -- Channel count fed into the next residual block; mutated by
    -- basicblock/bottleneck as layers are stacked, so block construction
    -- order matters.
    local iChannels

    -- The shortcut layer is either identity or a (strided) 1x1 convolution.
    local function shortcut(nInputPlane, nOutputPlane, stride)
      local useConv = shortcutType == 'C' or
        (shortcutType == 'B' and nInputPlane ~= nOutputPlane)
      if useConv then
        -- 1x1 convolution projection
        return nn.Sequential()
          :add(Convolution(nInputPlane, nOutputPlane, 1, 1, stride, stride))
          :add(SBatchNorm(nOutputPlane))
      elseif nInputPlane ~= nOutputPlane then
        -- Strided, zero-padded identity shortcut: subsample, then pad the
        -- channel dim by concatenating a zeroed copy of the input.
        return nn.Sequential()
          :add(nn.SpatialAveragePooling(1, 1, stride, stride))
          :add(nn.Concat(2)
            :add(nn.Identity())
            :add(nn.MulConstant(0)))
      else
        return nn.Identity()
      end
    end

    -- Basic residual block (two 3x3 convs) for the 18- and 34-layer
    -- networks and the CIFAR networks.
    local function basicblock(n, stride)
      local nInputPlane = iChannels
      iChannels = n

      local s = nn.Sequential()
      s:add(Convolution(nInputPlane,n,3,3,stride,stride,1,1))
      s:add(SBatchNorm(n))
      s:add(ReLU(true))
      s:add(Convolution(n,n,3,3,1,1,1,1))
      s:add(SBatchNorm(n))

      return nn.Sequential()
        :add(nn.ConcatTable()
          :add(s)
          :add(shortcut(nInputPlane, n, stride)))
        :add(nn.CAddTable(true))
        :add(ReLU(true))
    end

    -- Bottleneck residual block (1x1 -> 3x3 -> 1x1, 4x channel expansion)
    -- for the 50-, 101-, and 152-layer networks.
    local function bottleneck(n, stride)
      local nInputPlane = iChannels
      iChannels = n * 4

      local s = nn.Sequential()
      s:add(Convolution(nInputPlane,n,1,1,1,1,0,0))
      s:add(SBatchNorm(n))
      s:add(ReLU(true))
      s:add(Convolution(n,n,3,3,stride,stride,1,1))
      s:add(SBatchNorm(n))
      s:add(ReLU(true))
      s:add(Convolution(n,n*4,1,1,1,1,0,0))
      s:add(SBatchNorm(n * 4))

      return nn.Sequential()
        :add(nn.ConcatTable()
          :add(s)
          :add(shortcut(nInputPlane, n * 4, stride)))
        :add(nn.CAddTable(true))
        :add(ReLU(true))
    end

    -- Stacks `count` residual blocks; only the first block gets `stride`.
    local function layer(block, features, count, stride)
      local s = nn.Sequential()
      for i=1,count do
        s:add(block(features, i == 1 and stride or 1))
      end
      return s
    end

    local model = nn.Sequential()
    local input
    if opt.dataset == 'imagenet' then
      -- Configurations for ResNet:
      -- num. residual blocks, num features, residual block function
      local cfg = {
        [18] = {{2, 2, 2, 2}, 512, basicblock},
        [34] = {{3, 4, 6, 3}, 512, basicblock},
        [50] = {{3, 4, 6, 3}, 2048, bottleneck},
        [101] = {{3, 4, 23, 3}, 2048, bottleneck},
        [152] = {{3, 8, 36, 3}, 2048, bottleneck},
      }

      assert(cfg[depth], 'Invalid depth: ' .. tostring(depth))
      -- table.unpack is Lua 5.2+; fall back to the 5.1/LuaJIT global.
      local def, nFeatures, block = (table.unpack or unpack)(cfg[depth])
      iChannels = 64

      -- The ResNet ImageNet model
      model:add(Convolution(3,64,7,7,2,2,3,3))
      model:add(SBatchNorm(64))
      model:add(ReLU(true))
      model:add(Max(3,3,2,2,1,1))
      model:add(layer(block, 64, def[1]))
      model:add(layer(block, 128, def[2], 2))
      model:add(layer(block, 256, def[3], 2))
      model:add(layer(block, 512, def[4], 2))
      model:add(Avg(7, 7, 1, 1))
      model:add(nn.View(nFeatures):setNumInputDims(3))
      model:add(nn.Linear(nFeatures, 1000))

      input = torch.rand(1,3,224,224)
    elseif opt.dataset == 'cifar10' then
      -- Model type specifies number of layers for CIFAR-10 model
      assert((depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110, 1202')
      local n = (depth - 2) / 6
      iChannels = 16

      -- The ResNet CIFAR-10 model
      model:add(Convolution(3,16,3,3,1,1,1,1))
      model:add(SBatchNorm(16))
      model:add(ReLU(true))
      model:add(layer(basicblock, 16, n))
      model:add(layer(basicblock, 32, n, 2))
      model:add(layer(basicblock, 64, n, 2))
      model:add(Avg(8, 8, 1, 1))
      model:add(nn.View(64):setNumInputDims(3))
      model:add(nn.Linear(64, 10))
      input = torch.rand(1,3,32,32)
    else
      error('invalid dataset: ' .. opt.dataset)
    end

    -- He-style init for every convolution matching `name`. findModules
    -- returns a dense array, so ipairs gives a deterministic init order.
    local function ConvInit(name)
      for _, v in ipairs(model:findModules(name)) do
        local n = v.kW*v.kH*v.nOutputPlane
        v.weight:normal(0,math.sqrt(2/n))
        -- Bias removal is intentionally disabled via `false and ...`;
        -- `cudnn` is only indexed if this branch is re-enabled with
        -- cudnn loaded.
        if false and cudnn.version >= 4000 then
          v.bias = nil
          v.gradBias = nil
        else
          v.bias:zero()
        end
      end
    end
    -- Standard batch-norm init: unit scale, zero shift.
    local function BNInit(name)
      for _, v in ipairs(model:findModules(name)) do
        v.weight:fill(1)
        v.bias:zero()
      end
    end

    ConvInit('cudnn.SpatialConvolution')
    ConvInit('nn.SpatialConvolution')
    BNInit('fbnn.SpatialBatchNormalization')
    BNInit('cudnn.SpatialBatchNormalization')
    BNInit('nn.SpatialBatchNormalization')
    for _, v in ipairs(model:findModules('nn.Linear')) do
      v.bias:zero()
    end

    if opt.cudnn == 'deterministic' then
      model:apply(function(m)
        if m.setMode then m:setMode(1,1,1) end
      end)
    end

    return model, input
  end

  return createModel(opt)
end

return models