Skip to content
This repository was archived by the owner on Feb 13, 2025. It is now read-only.

Commit eaef75c

Browse files
authored
Add zero init for last BN and don't use bias (#348)
1 parent 718c4f2 commit eaef75c

File tree

2 files changed

+22
-10
lines changed

2 files changed

+22
-10
lines changed

Models/ImageClassification/ResNet.swift

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ public struct ConvBN: Layer {
3333
strides: (Int, Int) = (1, 1),
3434
padding: Padding = .valid
3535
) {
36-
self.conv = Conv2D(filterShape: filterShape, strides: strides, padding: padding)
36+
self.conv = Conv2D(filterShape: filterShape, strides: strides, padding: padding, useBias: false)
3737
self.norm = BatchNorm(featureCount: filterShape.3, momentum: 0.9, epsilon: 1e-5)
3838
}
3939

Models/ImageClassification/ResNetV2.swift

+21-9
Original file line numberDiff line numberDiff line change
@@ -27,29 +27,41 @@ import TensorFlow
2727
public struct ConvBNV2: Layer {
2828
public var conv: Conv2D<Float>
2929
public var norm: BatchNorm<Float>
30-
@noDerivative public let useRelu: Bool
30+
@noDerivative public let isLast: Bool
3131

3232
public init(
3333
inFilters: Int,
3434
outFilters: Int,
3535
kernelSize: Int = 1,
3636
stride: Int = 1,
3737
padding: Padding = .same,
38-
useRelu: Bool = true
38+
isLast: Bool = false
3939
) {
40-
// Should use no bias
4140
self.conv = Conv2D(
4241
filterShape: (kernelSize, kernelSize, inFilters, outFilters),
4342
strides: (stride, stride),
44-
padding: padding)
45-
self.norm = BatchNorm(featureCount: outFilters, momentum: 0.9, epsilon: 1e-5)
46-
self.useRelu = useRelu
43+
padding: padding,
44+
useBias: false)
45+
self.isLast = isLast
46+
if isLast {
47+
// Initialize the last BatchNorm layer with scale zero
48+
self.norm = BatchNorm(
49+
axis: -1,
50+
momentum: 0.9,
51+
offset: Tensor(zeros: [outFilters]),
52+
scale: Tensor(zeros: [outFilters]),
53+
epsilon: 1e-5,
54+
runningMean: Tensor(0),
55+
runningVariance: Tensor(1))
56+
} else {
57+
self.norm = BatchNorm(featureCount: outFilters, momentum: 0.9, epsilon: 1e-5)
58+
}
4759
}
4860

4961
@differentiable
5062
public func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
5163
let convResult = input.sequenced(through: conv, norm)
52-
return useRelu ? relu(convResult) : convResult
64+
return isLast ? convResult : relu(convResult)
5365
}
5466
}
5567

@@ -91,13 +103,13 @@ public struct ResidualBlockV2: Layer {
91103
if expansion == 1 {
92104
convs = [
93105
ConvBNV2(inFilters: inFilters, outFilters: outFilters, kernelSize: 3, stride: stride),
94-
ConvBNV2(inFilters: outFilters, outFilters: outFilters, kernelSize: 3, useRelu: false)
106+
ConvBNV2(inFilters: outFilters, outFilters: outFilters, kernelSize: 3, isLast: true)
95107
]
96108
} else {
97109
convs = [
98110
ConvBNV2(inFilters: inFilters, outFilters: outFilters/4),
99111
ConvBNV2(inFilters: outFilters/4, outFilters: outFilters/4, kernelSize: 3, stride: stride),
100-
ConvBNV2(inFilters: outFilters/4, outFilters: outFilters, useRelu: false)
112+
ConvBNV2(inFilters: outFilters/4, outFilters: outFilters, isLast: true)
101113
]
102114
}
103115
shortcut = Shortcut(inFilters: inFilters, outFilters: outFilters, stride: stride)

0 commit comments

Comments
 (0)