@@ -27,29 +27,41 @@ import TensorFlow
27
27
public struct ConvBNV2 : Layer {
28
28
public var conv : Conv2D < Float >
29
29
public var norm : BatchNorm < Float >
30
- @noDerivative public let useRelu : Bool
30
+ @noDerivative public let isLast : Bool
31
31
32
32
public init (
33
33
inFilters: Int ,
34
34
outFilters: Int ,
35
35
kernelSize: Int = 1 ,
36
36
stride: Int = 1 ,
37
37
padding: Padding = . same,
38
- useRelu : Bool = true
38
+ isLast : Bool = false
39
39
) {
40
- //Should use no bias
41
40
self . conv = Conv2D (
42
41
filterShape: ( kernelSize, kernelSize, inFilters, outFilters) ,
43
42
strides: ( stride, stride) ,
44
- padding: padding)
45
- self . norm = BatchNorm ( featureCount: outFilters, momentum: 0.9 , epsilon: 1e-5 )
46
- self . useRelu = useRelu
43
+ padding: padding,
44
+ useBias: false )
45
+ self . isLast = isLast
46
+ if isLast {
47
+ //Initialize the last BatchNorm layer to scale zero
48
+ self . norm = BatchNorm (
49
+ axis: - 1 ,
50
+ momentum: 0.9 ,
51
+ offset: Tensor ( zeros: [ outFilters] ) ,
52
+ scale: Tensor ( zeros: [ outFilters] ) ,
53
+ epsilon: 1e-5 ,
54
+ runningMean: Tensor ( 0 ) ,
55
+ runningVariance: Tensor ( 1 ) )
56
+ } else {
57
+ self . norm = BatchNorm ( featureCount: outFilters, momentum: 0.9 , epsilon: 1e-5 )
58
+ }
47
59
}
48
60
49
61
@differentiable
50
62
public func callAsFunction( _ input: Tensor < Float > ) -> Tensor < Float > {
51
63
let convResult = input. sequenced ( through: conv, norm)
52
- return useRelu ? relu ( convResult) : convResult
64
+ return isLast ? convResult : relu ( convResult)
53
65
}
54
66
}
55
67
@@ -91,13 +103,13 @@ public struct ResidualBlockV2: Layer {
91
103
if expansion == 1 {
92
104
convs = [
93
105
ConvBNV2 ( inFilters: inFilters, outFilters: outFilters, kernelSize: 3 , stride: stride) ,
94
- ConvBNV2 ( inFilters: outFilters, outFilters: outFilters, kernelSize: 3 , useRelu : false )
106
+ ConvBNV2 ( inFilters: outFilters, outFilters: outFilters, kernelSize: 3 , isLast : true )
95
107
]
96
108
} else {
97
109
convs = [
98
110
ConvBNV2 ( inFilters: inFilters, outFilters: outFilters/ 4 ) ,
99
111
ConvBNV2 ( inFilters: outFilters/ 4 , outFilters: outFilters/ 4 , kernelSize: 3 , stride: stride) ,
100
- ConvBNV2 ( inFilters: outFilters/ 4 , outFilters: outFilters, useRelu : false )
112
+ ConvBNV2 ( inFilters: outFilters/ 4 , outFilters: outFilters, isLast : true )
101
113
]
102
114
}
103
115
shortcut = Shortcut ( inFilters: inFilters, outFilters: outFilters, stride: stride)
0 commit comments