Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit a3bcac8

Browse files
authored
add efficientnet (#394)
1 parent 6fecd19 commit a3bcac8

File tree

3 files changed

+383
-0
lines changed

3 files changed

+383
-0
lines changed

Models/ImageClassification/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
add_library(ImageClassificationModels
22
DenseNet121.swift
3+
EfficientNet.swift
34
LeNet-5.swift
45
MobileNetV1.swift
56
MobileNetV2.swift
Lines changed: 328 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,328 @@
1+
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import TensorFlow
16+
17+
// Original Paper:
18+
// "EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks"
19+
// Mingxing Tan, Quoc V. Le
20+
// https://arxiv.org/abs/1905.11946
21+
// Notes: Default baseline (B0) network, see table 1
22+
23+
/// Scales a stage's block count by the depth coefficient, rounding up.
/// Mirrors `round_repeats` from the reference implementation:
/// https://github.com/tensorflow/tpu/blob/d6f2ef3edfeb4b1c2039b81014dc5271a7753832/models/official/efficientnet/efficientnet_model.py#L138
fileprivate func resizeDepth(blockCount: Int, depth: Float) -> Int {
    // Multiply by the depth coefficient and take the ceiling, so any
    // fractional scaling still adds a whole block.
    let scaled = depth * Float(blockCount)
    return Int(scaled.rounded(.up))
}
31+
32+
/// Returns `filter` scaled by `width` and rounded to a multiple of `divisor`,
/// never dropping more than 10% below the scaled value.
/// Mirrors `round_filters` from the reference implementation:
/// https://github.com/tensorflow/tpu/blob/d6f2ef3edfeb4b1c2039b81014dc5271a7753832/models/official/efficientnet/efficientnet_model.py#L138
/// - Parameters:
///   - filter: base (unscaled) filter count.
///   - width: channel-width scaling coefficient.
///   - divisor: the filter count is rounded to a multiple of this value.
fileprivate func makeDivisible(filter: Int, width: Float, divisor: Float = 8.0) -> Int {
    let scaled = Float(filter) * width
    // Round to the nearest multiple of `divisor`: add half the divisor, then
    // truncate down to a whole multiple.
    var quotient = (scaled + divisor / 2.0) / divisor
    quotient.round(.down)
    // Floor at `divisor` (the reference uses max(divisor, ...); the previous
    // code used max(1, ...), which could return sub-divisor counts for tiny
    // inputs).
    var newFilterCount = max(Int(divisor), Int(quotient * divisor))
    // Guard against rounding away more than 10% of the *scaled* filter count.
    // The previous code compared against the unscaled `filter`, which
    // deviates from the reference whenever width != 1.
    if Float(newFilterCount) < 0.9 * scaled {
        newFilterCount += Int(divisor)
    }
    return newFilterCount
}
45+
46+
/// Applies the width coefficient to both members of an (input, output)
/// filter-count pair, rounding each to a divisor-aligned count.
fileprivate func roundFilterPair(filters: (Int, Int), width: Float) -> (Int, Int) {
    let (inputFilters, outputFilters) = filters
    let scaledInput = makeDivisible(filter: inputFilters, width: width)
    let scaledOutput = makeDivisible(filter: outputFilters, width: width)
    return (scaledInput, scaledOutput)
}
52+
53+
/// The first MBConv block of EfficientNet (stage 2 of table 1 in the paper).
/// Unlike the later blocks it has no 1x1 expansion convolution: the depthwise
/// convolution runs directly on the stem's output channels, followed by
/// squeeze-and-excite gating and a 1x1 projection.
struct InitialMBConvBlock: Layer {
    @noDerivative var hiddenDimension: Int
    var dConv: DepthwiseConv2D<Float>
    var batchNormDConv: BatchNorm<Float>
    var seAveragePool = GlobalAvgPool2D<Float>()
    var seReduceConv: Conv2D<Float>
    var seExpandConv: Conv2D<Float>
    var conv2: Conv2D<Float>
    var batchNormConv2: BatchNorm<Float>

    /// - Parameters:
    ///   - filters: (input, output) filter counts before width scaling.
    ///   - width: channel-width scaling coefficient.
    init(filters: (Int, Int), width: Float) {
        let (inputFilters, outputFilters) = roundFilterPair(filters: filters, width: width)
        hiddenDimension = inputFilters
        // Squeeze-and-excite bottleneck width: 8 filters before width scaling.
        let seDimension = makeDivisible(filter: 8, width: width)
        dConv = DepthwiseConv2D<Float>(
            filterShape: (3, 3, inputFilters, 1),
            strides: (1, 1),
            padding: .same)
        batchNormDConv = BatchNorm(featureCount: inputFilters)
        seReduceConv = Conv2D<Float>(
            filterShape: (1, 1, inputFilters, seDimension),
            strides: (1, 1),
            padding: .same)
        seExpandConv = Conv2D<Float>(
            filterShape: (1, 1, seDimension, inputFilters),
            strides: (1, 1),
            padding: .same)
        conv2 = Conv2D<Float>(
            filterShape: (1, 1, inputFilters, outputFilters),
            strides: (1, 1),
            padding: .same)
        batchNormConv2 = BatchNorm(featureCount: outputFilters)
    }

    @differentiable
    func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
        let depthwise = swish(batchNormDConv(dConv(input)))
        // Squeeze: global-average pool, reshaped back to a 1x1 spatial map so
        // the SE convolutions can broadcast over the depthwise activations.
        let pooled = seAveragePool(depthwise)
            .reshaped(to: [input.shape[0], 1, 1, hiddenDimension])
        // Excite: channel-wise sigmoid gate applied to the depthwise output.
        let gate = sigmoid(seExpandConv(swish(seReduceConv(pooled))))
        return batchNormConv2(conv2(depthwise * gate))
    }
}
97+
98+
/// An inverted-residual MBConv block with squeeze-and-excite (table 1 of the
/// paper): 1x1 expansion, depthwise convolution, SE gating, then a 1x1
/// projection, with a skip connection when shapes allow it.
struct MBConvBlock: Layer {
    /// True when input and output shapes match, making the residual add valid.
    @noDerivative var addResLayer: Bool
    @noDerivative var strides: (Int, Int)
    /// Explicit padding for the strided depthwise path, emulating TensorFlow
    /// "same" padding (the strided convolution itself runs as .valid).
    @noDerivative let zeroPad: ZeroPadding2D<Float>
    /// Channel count after the 1x1 expansion convolution.
    @noDerivative var hiddenDimension: Int

    var conv1: Conv2D<Float>
    var batchNormConv1: BatchNorm<Float>
    var dConv: DepthwiseConv2D<Float>
    var batchNormDConv: BatchNorm<Float>
    var seAveragePool = GlobalAvgPool2D<Float>()
    var seReduceConv: Conv2D<Float>
    var seExpandConv: Conv2D<Float>
    var conv2: Conv2D<Float>
    var batchNormConv2: BatchNorm<Float>

    /// - Parameters:
    ///   - filters: (input, output) filter counts before width scaling.
    ///   - width: channel-width scaling coefficient.
    ///   - depthMultiplier: expansion ratio of the 1x1 expansion convolution.
    ///   - strides: stride of the depthwise convolution.
    ///   - kernel: spatial size of the depthwise convolution (odd, e.g. 3 or 5).
    init(
        filters: (Int, Int),
        width: Float,
        depthMultiplier: Int = 6,
        strides: (Int, Int) = (1, 1),
        kernel: (Int, Int) = (3, 3)
    ) {
        self.strides = strides
        self.addResLayer = filters.0 == filters.1 && strides == (1, 1)

        // BUG FIX: the explicit padding must depend on the kernel size.
        // TF "same" padding for stride 2 on an even-sized input pads
        // (kernel - 2) pixels per dimension, biased bottom/right:
        // 3x3 -> ((0, 1), (0, 1)), 5x5 -> ((1, 2), (1, 2)). The previous
        // hard-coded ((0, 1), (0, 1)) under-padded the 5x5 strided blocks,
        // shrinking their output relative to the reference implementation.
        zeroPad = ZeroPadding2D<Float>(
            padding: ((kernel.0 / 2 - 1, kernel.0 / 2), (kernel.1 / 2 - 1, kernel.1 / 2)))

        let filterMult = roundFilterPair(filters: filters, width: width)
        self.hiddenDimension = filterMult.0 * depthMultiplier
        // Squeeze-and-excite bottleneck: ratio 0.25 of the scaled input filters.
        let reducedDimension = max(1, Int(filterMult.0 / 4))
        conv1 = Conv2D<Float>(
            filterShape: (1, 1, filterMult.0, hiddenDimension),
            strides: (1, 1),
            padding: .same)
        dConv = DepthwiseConv2D<Float>(
            filterShape: (kernel.0, kernel.1, hiddenDimension, 1),
            strides: strides,
            // Strided depthwise convolutions pad explicitly via zeroPad.
            padding: strides == (1, 1) ? .same : .valid)
        seReduceConv = Conv2D<Float>(
            filterShape: (1, 1, hiddenDimension, reducedDimension),
            strides: (1, 1),
            padding: .same)
        seExpandConv = Conv2D<Float>(
            filterShape: (1, 1, reducedDimension, hiddenDimension),
            strides: (1, 1),
            padding: .same)
        conv2 = Conv2D<Float>(
            filterShape: (1, 1, hiddenDimension, filterMult.1),
            strides: (1, 1),
            padding: .same)
        batchNormConv1 = BatchNorm(featureCount: hiddenDimension)
        batchNormDConv = BatchNorm(featureCount: hiddenDimension)
        batchNormConv2 = BatchNorm(featureCount: filterMult.1)
    }

    @differentiable
    func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
        let piecewise = swish(batchNormConv1(conv1(input)))
        var depthwise: Tensor<Float>
        if self.strides == (1, 1) {
            depthwise = swish(batchNormDConv(dConv(piecewise)))
        } else {
            // Strided path: pad explicitly, then run the convolution as .valid.
            depthwise = swish(batchNormDConv(dConv(zeroPad(piecewise))))
        }
        // Squeeze-and-excite: pool to 1x1, then gate the depthwise activations
        // with a channel-wise sigmoid.
        let seAvgPoolReshaped = seAveragePool(depthwise).reshaped(to: [
            input.shape[0], 1, 1, self.hiddenDimension
        ])
        let squeezeExcite = depthwise
            * sigmoid(seExpandConv(swish(seReduceConv(seAvgPoolReshaped))))
        let piecewiseLinear = batchNormConv2(conv2(squeezeExcite))

        if self.addResLayer {
            return input + piecewiseLinear
        } else {
            return piecewiseLinear
        }
    }
}
175+
176+
/// A stage of MBConv blocks sharing one configuration. The first block maps
/// `filters.0` to `filters.1` and may downsample via `initialStrides`; every
/// following block keeps `filters.1` channels at unit stride.
struct MBConvBlockStack: Layer {
    var blocks: [MBConvBlock] = []

    /// - Parameters:
    ///   - filters: (input, output) filter counts before width scaling.
    ///   - width: channel-width scaling coefficient.
    ///   - initialStrides: stride of the stage's first block.
    ///   - kernel: depthwise kernel size used by every block in the stage.
    ///   - blockCount: baseline (B0) number of blocks in the stage.
    ///   - depth: depth coefficient scaling `blockCount`.
    init(
        filters: (Int, Int),
        width: Float,
        initialStrides: (Int, Int) = (2, 2),
        kernel: (Int, Int) = (3, 3),
        blockCount: Int,
        depth: Float
    ) {
        // Scale the stage's block count by the depth coefficient.
        let scaledBlockCount = resizeDepth(blockCount: blockCount, depth: depth)
        let leadBlock = MBConvBlock(
            filters: (filters.0, filters.1), width: width,
            strides: initialStrides, kernel: kernel)
        let tailBlocks = (1..<scaledBlockCount).map { _ in
            MBConvBlock(
                filters: (filters.1, filters.1),
                width: width, kernel: kernel)
        }
        blocks = [leadBlock] + tailBlocks
    }

    @differentiable
    func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
        // Thread the input through each block in order.
        return blocks.differentiableReduce(input) { $1($0) }
    }
}
206+
207+
/// EfficientNet image classification network.
///
/// Default parameters build the B0 baseline of table 1 in the paper; the
/// `width` and `depth` compound-scaling coefficients generate the larger
/// variants. Output is a per-class distribution (the classifier applies
/// softmax).
public struct EfficientNet: Layer {
    // Stem padding: TF "same" for a 3x3 / stride-2 convolution on an
    // even-sized input pads one pixel on the bottom and right.
    @noDerivative let zeroPad = ZeroPadding2D<Float>(padding: ((0, 1), (0, 1)))
    var inputConv: Conv2D<Float>
    var inputConvBatchNorm: BatchNorm<Float>
    var initialMBConv: InitialMBConvBlock

    // The six MBConv stages (stages 3-8 of table 1).
    var residualBlockStack1: MBConvBlockStack
    var residualBlockStack2: MBConvBlockStack
    var residualBlockStack3: MBConvBlockStack
    var residualBlockStack4: MBConvBlockStack
    var residualBlockStack5: MBConvBlockStack
    var residualBlockStack6: MBConvBlockStack

    // Head: 1x1 convolution, global pool, dropout, dense classifier.
    var outputConv: Conv2D<Float>
    var outputConvBatchNorm: BatchNorm<Float>
    var avgPool = GlobalAvgPool2D<Float>()
    var dropoutProb: Dropout<Float>
    var outputClassifier: Dense<Float>

    /// default settings are efficientnetB0 (baseline) network
    /// resolution is here to show what the network can take as input, it doesn't set anything!
    /// - Parameters:
    ///   - classCount: number of output classes for the final classifier.
    ///   - width: channel-width scaling coefficient applied to all filter counts.
    ///   - depth: block-count scaling coefficient applied to every stage.
    ///   - resolution: documentation only; the network is fully convolutional
    ///     up to the global pool, so nothing here fixes the input size.
    ///   - dropout: dropout probability applied before the classifier.
    public init(
        classCount: Int = 1000,
        width: Float = 1.0,
        depth: Float = 1.0,
        resolution: Int = 224,
        dropout: Double = 0.2
    ) {
        // Stem: 3x3 stride-2 convolution to 32 (width-scaled) channels.
        inputConv = Conv2D<Float>(
            filterShape: (3, 3, 3, makeDivisible(filter: 32, width: width)),
            strides: (2, 2),
            padding: .valid)
        inputConvBatchNorm = BatchNorm(featureCount: makeDivisible(filter: 32, width: width))

        initialMBConv = InitialMBConvBlock(filters: (32, 16), width: width)

        // Stage configuration (filters, kernel, stride, B0 block count)
        // follows table 1; `depth` scales each stage's block count.
        residualBlockStack1 = MBConvBlockStack(
            filters: (16, 24), width: width,
            blockCount: 2, depth: depth)
        residualBlockStack2 = MBConvBlockStack(
            filters: (24, 40), width: width,
            kernel: (5, 5), blockCount: 2, depth: depth)
        residualBlockStack3 = MBConvBlockStack(
            filters: (40, 80), width: width,
            blockCount: 3, depth: depth)
        residualBlockStack4 = MBConvBlockStack(
            filters: (80, 112), width: width,
            initialStrides: (1, 1), kernel: (5, 5), blockCount: 3, depth: depth)
        residualBlockStack5 = MBConvBlockStack(
            filters: (112, 192), width: width,
            kernel: (5, 5), blockCount: 4, depth: depth)
        residualBlockStack6 = MBConvBlockStack(
            filters: (192, 320), width: width,
            initialStrides: (1, 1), blockCount: 1, depth: depth)

        // Head: project 320 -> 1280 (both width-scaled) before pooling.
        outputConv = Conv2D<Float>(
            filterShape: (
                1, 1,
                makeDivisible(filter: 320, width: width), makeDivisible(filter: 1280, width: width)
            ),
            strides: (1, 1),
            padding: .same)
        outputConvBatchNorm = BatchNorm(featureCount: makeDivisible(filter: 1280, width: width))

        dropoutProb = Dropout<Float>(probability: dropout)
        outputClassifier = Dense(
            inputSize: makeDivisible(filter: 1280, width: width),
            outputSize: classCount, activation: softmax)
    }

    @differentiable
    public func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
        // Stem -> initial MBConv -> six MBConv stages -> head -> classifier.
        let convolved = swish(input.sequenced(through: zeroPad, inputConv, inputConvBatchNorm))
        let initialBlock = initialMBConv(convolved)
        let backbone = initialBlock.sequenced(
            through: residualBlockStack1, residualBlockStack2,
            residualBlockStack3, residualBlockStack4, residualBlockStack5, residualBlockStack6)
        let output = swish(backbone.sequenced(through: outputConv, outputConvBatchNorm))
        return output.sequenced(through: avgPool, dropoutProb, outputClassifier)
    }
}
288+
289+
extension EfficientNet {
    /// The published EfficientNet variants, each a point on the
    /// compound-scaling curve of width/depth/resolution/dropout.
    public enum Kind {
        case efficientnetB0
        case efficientnetB1
        case efficientnetB2
        case efficientnetB3
        case efficientnetB4
        case efficientnetB5
        case efficientnetB6
        case efficientnetB7
        case efficientnetB8
        case efficientnetL2
    }

    /// Creates a network using the scaling coefficients of the given variant.
    public init(kind: Kind, classCount: Int = 1000) {
        // Coefficients per variant; see table 1 of the paper for B0-B7.
        let scaling: (width: Float, depth: Float, resolution: Int, dropout: Double)
        switch kind {
        case .efficientnetB0: scaling = (1.0, 1.0, 224, 0.2)
        case .efficientnetB1: scaling = (1.0, 1.1, 240, 0.2)
        case .efficientnetB2: scaling = (1.1, 1.2, 260, 0.3)
        case .efficientnetB3: scaling = (1.2, 1.4, 300, 0.3)
        case .efficientnetB4: scaling = (1.4, 1.8, 380, 0.4)
        case .efficientnetB5: scaling = (1.6, 2.2, 456, 0.4)
        case .efficientnetB6: scaling = (1.8, 2.6, 528, 0.5)
        case .efficientnetB7: scaling = (2.0, 3.1, 600, 0.5)
        case .efficientnetB8: scaling = (2.2, 3.6, 672, 0.5)
        case .efficientnetL2:
            // Coefficients from "Self-training with Noisy Student":
            // https://arxiv.org/abs/1911.04252
            scaling = (4.3, 5.3, 800, 0.5)
        }
        self.init(
            classCount: classCount, width: scaling.width, depth: scaling.depth,
            resolution: scaling.resolution, dropout: scaling.dropout)
    }
}

0 commit comments

Comments
 (0)