// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import TensorFlow

// Original Paper:
// "EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks"
// Mingxing Tan, Quoc V. Le
// https://arxiv.org/abs/1905.11946
// Notes: Default baseline (B0) network, see Table 1.

/// Utility functions for generating the scaled network variants.
/// Original: https://github.com/tensorflow/tpu/blob/d6f2ef3edfeb4b1c2039b81014dc5271a7753832/models/official/efficientnet/efficientnet_model.py#L138
fileprivate func resizeDepth(blockCount: Int, depth: Float) -> Int {
  // Scale the block count by the depth multiplier, rounding up.
  var newBlockCount = depth * Float(blockCount)
  newBlockCount.round(.up)
  return Int(newBlockCount)
}

/// Returns `filter` scaled by `width`, rounded to the nearest multiple of `divisor`,
/// without dropping more than 10% below the scaled value.
fileprivate func makeDivisible(filter: Int, width: Float, divisor: Float = 8.0) -> Int {
  let filterMult = Float(filter) * width
  let filterAdd = filterMult + (divisor / 2.0)
  var div = filterAdd / divisor
  div.round(.down)
  div = div * divisor
  var newFilterCount = max(1, Int(div))
  // Guard: never round down by more than 10% of the scaled filter count,
  // matching the referenced TensorFlow implementation.
  if Float(newFilterCount) < 0.9 * filterMult {
    newFilterCount += Int(divisor)
  }
  return newFilterCount
}

/// Applies `makeDivisible` to both members of an (input, output) filter pair.
fileprivate func roundFilterPair(filters: (Int, Int), width: Float) -> (Int, Int) {
  return (
    makeDivisible(filter: filters.0, width: width),
    makeDivisible(filter: filters.1, width: width)
  )
}
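
// Worked examples (a sketch of the arithmetic above, not part of the original file):
//   makeDivisible(filter: 32, width: 1.4)   // 32 * 1.4 = 44.8 -> nearest multiple of 8 = 48
//   resizeDepth(blockCount: 2, depth: 1.8)  // ceil(2 * 1.8) = ceil(3.6) = 4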

/// The initial inverted-bottleneck block (MBConv1, no expansion) from Table 1 of the paper.
struct InitialMBConvBlock: Layer {
  @noDerivative var hiddenDimension: Int
  var dConv: DepthwiseConv2D<Float>
  var batchNormDConv: BatchNorm<Float>
  var seAveragePool = GlobalAvgPool2D<Float>()
  var seReduceConv: Conv2D<Float>
  var seExpandConv: Conv2D<Float>
  var conv2: Conv2D<Float>
  var batchNormConv2: BatchNorm<Float>

  init(filters: (Int, Int), width: Float) {
    let filterMult = roundFilterPair(filters: filters, width: width)
    self.hiddenDimension = filterMult.0
    dConv = DepthwiseConv2D<Float>(
      filterShape: (3, 3, filterMult.0, 1),
      strides: (1, 1),
      padding: .same)
    seReduceConv = Conv2D<Float>(
      filterShape: (1, 1, filterMult.0, makeDivisible(filter: 8, width: width)),
      strides: (1, 1),
      padding: .same)
    seExpandConv = Conv2D<Float>(
      filterShape: (1, 1, makeDivisible(filter: 8, width: width), filterMult.0),
      strides: (1, 1),
      padding: .same)
    conv2 = Conv2D<Float>(
      filterShape: (1, 1, filterMult.0, filterMult.1),
      strides: (1, 1),
      padding: .same)
    batchNormDConv = BatchNorm(featureCount: filterMult.0)
    batchNormConv2 = BatchNorm(featureCount: filterMult.1)
  }

  @differentiable
  func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
    let depthwise = swish(batchNormDConv(dConv(input)))
    // Squeeze-and-excite: pool globally, then gate each channel of the
    // depthwise output through a reduce/expand bottleneck.
    let seAvgPoolReshaped = seAveragePool(depthwise).reshaped(to: [
      input.shape[0], 1, 1, self.hiddenDimension
    ])
    let squeezeExcite = depthwise
      * sigmoid(seExpandConv(swish(seReduceConv(seAvgPoolReshaped))))
    return batchNormConv2(conv2(squeezeExcite))
  }
}
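
// Shape sketch for B0 (width = 1.0): this block maps [N, 112, 112, 32] activations
// to [N, 112, 112, 16]; the squeeze-and-excite path pools to [N, 1, 1, 32] and
// gates the depthwise output channel-wise.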

/// The main inverted-bottleneck block (MBConv6 by default): a 1x1 pointwise
/// expansion, a depthwise convolution, squeeze-and-excite, and a linear 1x1
/// projection, with a residual connection when input and output shapes match.
struct MBConvBlock: Layer {
  @noDerivative var addResLayer: Bool
  @noDerivative var strides: (Int, Int)
  @noDerivative let zeroPad = ZeroPadding2D<Float>(padding: ((0, 1), (0, 1)))
  @noDerivative var hiddenDimension: Int

  var conv1: Conv2D<Float>
  var batchNormConv1: BatchNorm<Float>
  var dConv: DepthwiseConv2D<Float>
  var batchNormDConv: BatchNorm<Float>
  var seAveragePool = GlobalAvgPool2D<Float>()
  var seReduceConv: Conv2D<Float>
  var seExpandConv: Conv2D<Float>
  var conv2: Conv2D<Float>
  var batchNormConv2: BatchNorm<Float>

  init(
    filters: (Int, Int),
    width: Float,
    depthMultiplier: Int = 6,
    strides: (Int, Int) = (1, 1),
    kernel: (Int, Int) = (3, 3)
  ) {
    self.strides = strides
    // A residual connection is only valid when the shapes are preserved.
    self.addResLayer = filters.0 == filters.1 && strides == (1, 1)

    let filterMult = roundFilterPair(filters: filters, width: width)
    self.hiddenDimension = filterMult.0 * depthMultiplier
    let reducedDimension = max(1, filterMult.0 / 4)
    conv1 = Conv2D<Float>(
      filterShape: (1, 1, filterMult.0, hiddenDimension),
      strides: (1, 1),
      padding: .same)
    dConv = DepthwiseConv2D<Float>(
      filterShape: (kernel.0, kernel.1, hiddenDimension, 1),
      strides: strides,
      padding: strides == (1, 1) ? .same : .valid)
    seReduceConv = Conv2D<Float>(
      filterShape: (1, 1, hiddenDimension, reducedDimension),
      strides: (1, 1),
      padding: .same)
    seExpandConv = Conv2D<Float>(
      filterShape: (1, 1, reducedDimension, hiddenDimension),
      strides: (1, 1),
      padding: .same)
    conv2 = Conv2D<Float>(
      filterShape: (1, 1, hiddenDimension, filterMult.1),
      strides: (1, 1),
      padding: .same)
    batchNormConv1 = BatchNorm(featureCount: hiddenDimension)
    batchNormDConv = BatchNorm(featureCount: hiddenDimension)
    batchNormConv2 = BatchNorm(featureCount: filterMult.1)
  }

  @differentiable
  func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
    let pointwise = swish(batchNormConv1(conv1(input)))
    var depthwise: Tensor<Float>
    if self.strides == (1, 1) {
      depthwise = swish(batchNormDConv(dConv(pointwise)))
    } else {
      // Explicit asymmetric padding replicates TensorFlow's "same" padding
      // for stride-2 convolutions.
      depthwise = swish(batchNormDConv(dConv(zeroPad(pointwise))))
    }
    let seAvgPoolReshaped = seAveragePool(depthwise).reshaped(to: [
      input.shape[0], 1, 1, self.hiddenDimension
    ])
    let squeezeExcite = depthwise
      * sigmoid(seExpandConv(swish(seReduceConv(seAvgPoolReshaped))))
    // The projection is linear: no activation after the final batch norm.
    let pointwiseLinear = batchNormConv2(conv2(squeezeExcite))

    if self.addResLayer {
      return input + pointwiseLinear
    } else {
      return pointwiseLinear
    }
  }
}
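
// Shape sketch for the first block of stack 1 in B0: filters (16, 24) with a
// depth multiplier of 6 expands [N, 112, 112, 16] to 96 hidden channels, the
// stride-2 depthwise convolution halves the spatial size, and the projection
// yields [N, 56, 56, 24]; no residual is added since the shapes differ.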

/// A stack of MBConv blocks: the first block applies `initialStrides` and maps
/// `filters.0` to `filters.1`; the rest preserve shape so they can use residuals.
struct MBConvBlockStack: Layer {
  var blocks: [MBConvBlock] = []

  init(
    filters: (Int, Int),
    width: Float,
    initialStrides: (Int, Int) = (2, 2),
    kernel: (Int, Int) = (3, 3),
    blockCount: Int,
    depth: Float
  ) {
    let blockMult = resizeDepth(blockCount: blockCount, depth: depth)
    self.blocks = [
      MBConvBlock(
        filters: (filters.0, filters.1), width: width,
        strides: initialStrides, kernel: kernel)
    ]
    for _ in 1..<blockMult {
      self.blocks.append(
        MBConvBlock(
          filters: (filters.1, filters.1),
          width: width, kernel: kernel))
    }
  }

  @differentiable
  func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
    return blocks.differentiableReduce(input) { $1($0) }
  }
}
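
// For the baseline (depth = 1.0) the stack sizes below reproduce the repeat
// counts of Table 1: 2, 2, 3, 3, 4, 1. A depth multiplier inflates each count,
// e.g. resizeDepth(blockCount: 4, depth: 2.2) = ceil(8.8) = 9 blocks for B5.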

/// EfficientNet, parameterized by the compound-scaling coefficients.
public struct EfficientNet: Layer {
  @noDerivative let zeroPad = ZeroPadding2D<Float>(padding: ((0, 1), (0, 1)))
  var inputConv: Conv2D<Float>
  var inputConvBatchNorm: BatchNorm<Float>
  var initialMBConv: InitialMBConvBlock

  var residualBlockStack1: MBConvBlockStack
  var residualBlockStack2: MBConvBlockStack
  var residualBlockStack3: MBConvBlockStack
  var residualBlockStack4: MBConvBlockStack
  var residualBlockStack5: MBConvBlockStack
  var residualBlockStack6: MBConvBlockStack

  var outputConv: Conv2D<Float>
  var outputConvBatchNorm: BatchNorm<Float>
  var avgPool = GlobalAvgPool2D<Float>()
  var dropoutProb: Dropout<Float>
  var outputClassifier: Dense<Float>

  /// The default settings produce the EfficientNet-B0 (baseline) network.
  /// `resolution` documents the input size the variant was designed for;
  /// it does not configure any layer.
  public init(
    classCount: Int = 1000,
    width: Float = 1.0,
    depth: Float = 1.0,
    resolution: Int = 224,
    dropout: Double = 0.2
  ) {
    inputConv = Conv2D<Float>(
      filterShape: (3, 3, 3, makeDivisible(filter: 32, width: width)),
      strides: (2, 2),
      padding: .valid)
    inputConvBatchNorm = BatchNorm(featureCount: makeDivisible(filter: 32, width: width))

    initialMBConv = InitialMBConvBlock(filters: (32, 16), width: width)

    residualBlockStack1 = MBConvBlockStack(
      filters: (16, 24), width: width,
      blockCount: 2, depth: depth)
    residualBlockStack2 = MBConvBlockStack(
      filters: (24, 40), width: width,
      kernel: (5, 5), blockCount: 2, depth: depth)
    residualBlockStack3 = MBConvBlockStack(
      filters: (40, 80), width: width,
      blockCount: 3, depth: depth)
    residualBlockStack4 = MBConvBlockStack(
      filters: (80, 112), width: width,
      initialStrides: (1, 1), kernel: (5, 5), blockCount: 3, depth: depth)
    residualBlockStack5 = MBConvBlockStack(
      filters: (112, 192), width: width,
      kernel: (5, 5), blockCount: 4, depth: depth)
    residualBlockStack6 = MBConvBlockStack(
      filters: (192, 320), width: width,
      initialStrides: (1, 1), blockCount: 1, depth: depth)

    outputConv = Conv2D<Float>(
      filterShape: (
        1, 1,
        makeDivisible(filter: 320, width: width), makeDivisible(filter: 1280, width: width)
      ),
      strides: (1, 1),
      padding: .same)
    outputConvBatchNorm = BatchNorm(featureCount: makeDivisible(filter: 1280, width: width))

    dropoutProb = Dropout<Float>(probability: dropout)
    outputClassifier = Dense(
      inputSize: makeDivisible(filter: 1280, width: width),
      outputSize: classCount, activation: softmax)
  }

  @differentiable
  public func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
    let convolved = swish(input.sequenced(through: zeroPad, inputConv, inputConvBatchNorm))
    let initialBlock = initialMBConv(convolved)
    let backbone = initialBlock.sequenced(
      through: residualBlockStack1, residualBlockStack2,
      residualBlockStack3, residualBlockStack4, residualBlockStack5, residualBlockStack6)
    let output = swish(backbone.sequenced(through: outputConv, outputConvBatchNorm))
    return output.sequenced(through: avgPool, dropoutProb, outputClassifier)
  }
}

extension EfficientNet {
  public enum Kind {
    case efficientnetB0
    case efficientnetB1
    case efficientnetB2
    case efficientnetB3
    case efficientnetB4
    case efficientnetB5
    case efficientnetB6
    case efficientnetB7
    case efficientnetB8
    case efficientnetL2
  }

  public init(kind: Kind, classCount: Int = 1000) {
    switch kind {
    case .efficientnetB0:
      self.init(classCount: classCount, width: 1.0, depth: 1.0, resolution: 224, dropout: 0.2)
    case .efficientnetB1:
      self.init(classCount: classCount, width: 1.0, depth: 1.1, resolution: 240, dropout: 0.2)
    case .efficientnetB2:
      self.init(classCount: classCount, width: 1.1, depth: 1.2, resolution: 260, dropout: 0.3)
    case .efficientnetB3:
      self.init(classCount: classCount, width: 1.2, depth: 1.4, resolution: 300, dropout: 0.3)
    case .efficientnetB4:
      self.init(classCount: classCount, width: 1.4, depth: 1.8, resolution: 380, dropout: 0.4)
    case .efficientnetB5:
      self.init(classCount: classCount, width: 1.6, depth: 2.2, resolution: 456, dropout: 0.4)
    case .efficientnetB6:
      self.init(classCount: classCount, width: 1.8, depth: 2.6, resolution: 528, dropout: 0.5)
    case .efficientnetB7:
      self.init(classCount: classCount, width: 2.0, depth: 3.1, resolution: 600, dropout: 0.5)
    case .efficientnetB8:
      self.init(classCount: classCount, width: 2.2, depth: 3.6, resolution: 672, dropout: 0.5)
    case .efficientnetL2:
      // https://arxiv.org/abs/1911.04252
      self.init(classCount: classCount, width: 4.3, depth: 5.3, resolution: 800, dropout: 0.5)
    }
  }
}
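
// Usage sketch (a hypothetical helper, not part of the original file): builds the
// baseline B0 model and runs one random batch through it. The random tensor stands
// in for a preprocessed 224x224 RGB image in NHWC layout.
func runEfficientNetB0Demo() {
  let model = EfficientNet(kind: .efficientnetB0)
  let input = Tensor<Float>(randomNormal: [1, 224, 224, 3])
  let probabilities = model(input)
  print(probabilities.shape)  // [1, 1000]: softmax scores over the class count
}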