function net = DnCNN_Init()
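% Builds the DnCNN denoising model as a MatConvNet DagNN:
%   Conv+ReLU  ->  15 x (Conv + BatchNorm + ReLU)  ->  Conv  ->  Sum with the input
% i.e. 17 convolutional layers with 3x3 filters and 64 feature channels, a
% residual (skip) connection, and an L2 loss between 'prediction' and 'label'.
%
% Usage sketch (assumes MatConvNet is on the path; the training driver is an
% assumption, not part of this file):
%   net = DnCNN_Init();
%   % then train with a DagNN training loop, e.g. MatConvNet's cnn_train_dag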

% by Kai Zhang (1/2018)
% https://github.com/cszn

% Create DAGNN object
net = dagnn.DagNN();

% conv + relu
blockNum = 1;
inVar    = 'input';
channel  = 1; % grayscale image
dims     = [3,3,channel,64];
pad      = [1,1];
stride   = [1,1];
lr       = [1,1];
[net, inVar, blockNum] = addConv(net, blockNum, inVar, dims, pad, stride, lr);
[net, inVar, blockNum] = addReLU(net, blockNum, inVar);

for i = 1:15
    % conv + bn + relu
    dims   = [3,3,64,64];
    pad    = [1,1];
    stride = [1,1];
    lr     = [1,0];
    [net, inVar, blockNum] = addConv(net, blockNum, inVar, dims, pad, stride, lr);
    n_ch = dims(4);
    [net, inVar, blockNum] = addBnorm(net, blockNum, inVar, n_ch);
    [net, inVar, blockNum] = addReLU(net, blockNum, inVar);
end

% conv
dims   = [3,3,64,channel];
pad    = [1,1];
stride = [1,1];
lr     = [1,0]; % or [1,1]; it does not influence the results
[net, inVar, blockNum] = addConv(net, blockNum, inVar, dims, pad, stride, lr);

% sum
inVar = {inVar, 'input'};
[net, inVar, blockNum] = addSum(net, blockNum, inVar);
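% The final conv layer outputs the learned residual; the sum layer adds it back
% to 'input' (residual / skip connection), and the summed variable is renamed to
% 'prediction' below.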

outputName = 'prediction';
net.renameVar(inVar, outputName);

% loss
net.addLayer('loss', dagnn.Loss('loss','L2'), {'prediction','label'}, {'objective'}, {});
net.vars(net.getVarIndex('prediction')).precious = 1;
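% 'precious' stops DagNN from freeing the 'prediction' variable during
% evaluation, so the denoised output can be read back after a forward pass.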

end


% Add a Concat layer
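% (not called by DnCNN_Init; concatenates its inputs along the 3rd, i.e. channel, dimension)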
function [net, inVar, blockNum] = addConcat(net, blockNum, inVar)

outVar   = sprintf('concat%d', blockNum);
layerCur = sprintf('concat%d', blockNum);

block = dagnn.Concat('dim', 3);
net.addLayer(layerCur, block, inVar, {outVar}, {});

inVar    = outVar;
blockNum = blockNum + 1;
end

% Add a loss layer
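% (not called by DnCNN_Init; the L2 loss is added directly in the main function above)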
function [net, inVar, blockNum] = addLoss(net, blockNum, inVar)

outVar   = 'objective';
layerCur = sprintf('loss%d', blockNum);

block = dagnn.Loss('loss','L2');
net.addLayer(layerCur, block, inVar, {outVar}, {});

inVar    = outVar;
blockNum = blockNum + 1;
end

% Add a sum layer
function [net, inVar, blockNum] = addSum(net, blockNum, inVar)

outVar   = sprintf('sum%d', blockNum);
layerCur = sprintf('sum%d', blockNum);

block = dagnn.Sum();
net.addLayer(layerCur, block, inVar, {outVar}, {});

inVar    = outVar;
blockNum = blockNum + 1;
end

% Add a ReLU layer
function [net, inVar, blockNum] = addReLU(net, blockNum, inVar)

outVar   = sprintf('relu%d', blockNum);
layerCur = sprintf('relu%d', blockNum);

block = dagnn.ReLU('leak', 0);
net.addLayer(layerCur, block, {inVar}, {outVar}, {});

inVar    = outVar;
blockNum = blockNum + 1;
end

% Add a bnorm layer
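% Parameters: '_g' (scale gamma), '_b' (bias beta) and '_m' (the n_ch-by-2
% running moments). Gamma is drawn from a scaled Gaussian and clipped away from
% zero (|gamma| >= b_min) so no channel starts with a vanishing scale; the
% moments are updated with the 'average' method rather than Adam.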
function [net, inVar, blockNum] = addBnorm(net, blockNum, inVar, n_ch)

trainMethod = 'adam';
outVar   = sprintf('bnorm%d', blockNum);
layerCur = sprintf('bnorm%d', blockNum);

params = {[layerCur '_g'], [layerCur '_b'], [layerCur '_m']};
net.addLayer(layerCur, dagnn.BatchNorm('numChannels', n_ch), {inVar}, {outVar}, params);

pidx  = net.getParamIndex(params);
b_min = 0.025;
net.params(pidx(1)).value        = clipping(sqrt(2/(9*n_ch))*randn(n_ch,1,'single'), b_min);
net.params(pidx(1)).learningRate = 1;
net.params(pidx(1)).weightDecay  = 0;
net.params(pidx(1)).trainMethod  = trainMethod;

net.params(pidx(2)).value        = zeros(n_ch, 1, 'single');
net.params(pidx(2)).learningRate = 1;
net.params(pidx(2)).weightDecay  = 0;
net.params(pidx(2)).trainMethod  = trainMethod;

net.params(pidx(3)).value        = [zeros(n_ch,1,'single'), 0.01*ones(n_ch,1,'single')];
net.params(pidx(3)).learningRate = 1;
net.params(pidx(3)).weightDecay  = 0;
net.params(pidx(3)).trainMethod  = 'average';

inVar    = outVar;
blockNum = blockNum + 1;
end

% Add a ConvTranspose layer
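% (not called by DnCNN_Init; kept for variants that need learned up-sampling)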
function [net, inVar, blockNum] = addConvt(net, blockNum, inVar, dims, crop, upsample, lr)
opts.cudnnWorkspaceLimit = 1024*1024*1024*2; % 2GB
convOpts    = {'CudnnWorkspaceLimit', opts.cudnnWorkspaceLimit};
trainMethod = 'adam';

outVar   = sprintf('convt%d', blockNum);
layerCur = sprintf('convt%d', blockNum);

convBlock = dagnn.ConvTranspose('size', dims, 'crop', crop, 'upsample', upsample, ...
    'hasBias', true, 'opts', convOpts);

net.addLayer(layerCur, convBlock, {inVar}, {outVar}, {[layerCur '_f'], [layerCur '_b']});

f  = net.getParamIndex([layerCur '_f']);
sc = sqrt(2/(dims(1)*dims(2)*dims(4))); % "improved Xavier" (He) initialization
net.params(f).value        = sc*randn(dims, 'single');
net.params(f).learningRate = lr(1);
net.params(f).weightDecay  = 1;
net.params(f).trainMethod  = trainMethod;

f = net.getParamIndex([layerCur '_b']);
net.params(f).value        = zeros(dims(3), 1, 'single');
net.params(f).learningRate = lr(2);
net.params(f).weightDecay  = 1;
net.params(f).trainMethod  = trainMethod;

inVar    = outVar;
blockNum = blockNum + 1;
end

% Add a Conv layer
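% dims = [filter_h, filter_w, in_channels, out_channels]; lr = [filter_lr, bias_lr].
% Filters get a He-style ("improved Xavier") random init; biases start at zero.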
function [net, inVar, blockNum] = addConv(net, blockNum, inVar, dims, pad, stride, lr)
opts.cudnnWorkspaceLimit = 1024*1024*1024*2; % 2GB
convOpts    = {'CudnnWorkspaceLimit', opts.cudnnWorkspaceLimit};
trainMethod = 'adam';

outVar   = sprintf('conv%d', blockNum);
layerCur = sprintf('conv%d', blockNum);

convBlock = dagnn.Conv('size', dims, 'pad', pad, 'stride', stride, ...
    'hasBias', true, 'opts', convOpts);

net.addLayer(layerCur, convBlock, {inVar}, {outVar}, {[layerCur '_f'], [layerCur '_b']});

f  = net.getParamIndex([layerCur '_f']);
sc = sqrt(2/(dims(1)*dims(2)*max(dims(3), dims(4)))); % "improved Xavier" (He) initialization
net.params(f).value        = sc*randn(dims, 'single');
net.params(f).learningRate = lr(1);
net.params(f).weightDecay  = 1;
net.params(f).trainMethod  = trainMethod;

f = net.getParamIndex([layerCur '_b']);
net.params(f).value        = zeros(dims(4), 1, 'single');
net.params(f).learningRate = lr(2);
net.params(f).weightDecay  = 1;
net.params(f).trainMethod  = trainMethod;

inVar    = outVar;
blockNum = blockNum + 1;
end

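% Clip small entries away from zero: values in [0,b) become b and values in
% (-b,0) become -b. Used above to keep batch-norm scale factors from starting
% near zero.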
function A = clipping(A,b)
A(A>=0 & A<b)  = b;
A(A<0  & A>-b) = -b;
end