Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit 8310e14

Browse files
authored
Adding a ResNet56-CIFAR10 benchmark (#431)
* Added ResNet56 model and benchmark. * inputFilters now switches its default value based on whether or not CIFAR-10-sized images are used for ResNet. * On second thought, let's just remove that option and only use the defaults.
1 parent 6a687b4 commit 8310e14

File tree

5 files changed

+65
-5
lines changed

5 files changed

+65
-5
lines changed

Benchmarks/Models.swift

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15-
// Table of the benchmark models that can be selected, keyed by a string name.
// The added (`+`) declaration spells out the dictionary type as
// `[String: BenchmarkModel]` instead of relying on inference from the literal.
// NOTE(review): presumably needed because the entries now have different
// concrete types (`LeNetMNIST`, `ResNetCIFAR10`) — confirm.
let benchmarkModels = [
15+
let benchmarkModels: [String: BenchmarkModel] = [
1616
"LeNetMNIST": LeNetMNIST(),
17+
"ResNetCIFAR10": ResNetCIFAR10(),
1718
]

Benchmarks/Models/ImageClassificationInference.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ protocol ImageClassificationModel: Layer where Input == Tensor<Float>, Output ==
2121
}
2222

2323
// Empty-bodied conformances: each extension only declares that the model type
// conforms to `ImageClassificationModel`; no members are added here.
extension LeNet: ImageClassificationModel {}
24+
extension ResNet56: ImageClassificationModel {}
2425

2526
class ImageClassificationInference<Model, ClassificationDataset>: Benchmark
2627
where Model: ImageClassificationModel, ClassificationDataset: ImageClassificationDataset {

Benchmarks/Models/ResNetCIFAR10.swift

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import Datasets
16+
import ImageClassificationModels
17+
import TensorFlow
18+
19+
/// Benchmark description pairing the `ResNet56` model with the `CIFAR10`
/// dataset for both the inference and the training benchmark.
struct ResNetCIFAR10: BenchmarkModel {
20+
21+
/// Settings used for the inference benchmark when none are supplied:
/// 1000 batches of size 1, 10 iterations.
// NOTE(review): `epochs: -1` looks like a "not used" sentinel for inference —
// confirm against `BenchmarkSettings`.
var defaultInferenceSettings: BenchmarkSettings {
22+
return BenchmarkSettings(batches: 1000, batchSize: 1, iterations: 10, epochs: -1)
23+
}
24+
25+
/// Creates the inference benchmark, an `ImageClassificationInference`
/// specialised on `ResNet56` and `CIFAR10`, passing `settings` through.
func makeInferenceBenchmark(settings: BenchmarkSettings) -> Benchmark {
26+
return ImageClassificationInference<ResNet56, CIFAR10>(settings: settings)
27+
}
28+
29+
/// Settings used for the training benchmark when none are supplied:
/// batch size 128, 10 iterations, 1 epoch.
// NOTE(review): `batches: -1` looks like a "not used" sentinel for training —
// confirm against `BenchmarkSettings`.
var defaultTrainingSettings: BenchmarkSettings {
30+
return BenchmarkSettings(batches: -1, batchSize: 128, iterations: 10, epochs: 1)
31+
}
32+
33+
/// Creates the training benchmark, an `ImageClassificationTraining`
/// specialised on `ResNet56` and `CIFAR10`, passing `settings` through.
func makeTrainingBenchmark(settings: BenchmarkSettings) -> Benchmark {
34+
return ImageClassificationTraining<ResNet56, CIFAR10>(settings: settings)
35+
}
36+
}
37+
38+
/// Thin `Layer` wrapper around a `ResNet` that is always built with a fixed
/// configuration (10 classes, `.resNet56` depth, no first-stage downsampling).
// NOTE(review): presumably exists so the generic benchmarks can construct the
// model through an argument-less `init()` — confirm against the
// `ImageClassificationModel` requirements.
struct ResNet56: Layer {
39+
/// The wrapped network that performs the actual computation.
var model: ResNet
40+
41+
/// Builds the wrapped `ResNet` with the fixed configuration above.
init() {
42+
model = ResNet(classCount: 10, depth: .resNet56, downsamplingInFirstStage: false)
43+
}
44+
45+
/// Forwards `input` to the wrapped `ResNet` and returns its output unchanged.
@differentiable
46+
public func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
47+
return model(input)
48+
}
49+
}

Examples/ResNet-CIFAR10/main.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ let batchSize = 10
2121
let dataset = CIFAR10(batchSize: batchSize)
2222

2323
// Use the network sized for CIFAR-10
24-
var model = ResNet(classCount: 10, depth: .resNet50, downsamplingInFirstStage: true)
24+
var model = ResNet(classCount: 10, depth: .resNet56, downsamplingInFirstStage: false)
2525

2626
// the classic ImageNet optimizer setting diverges on CIFAR-10
2727
// let optimizer = SGD(for: model, learningRate: 0.1, momentum: 0.9)

Models/ImageClassification/ResNet.swift

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,13 +127,17 @@ public struct ResNet: Layer {
127127
/// 3x3 convolution, corresponding to the v1.5 variant of the architecture.
128128
public init(
129129
classCount: Int, depth: Depth, downsamplingInFirstStage: Bool = true,
130-
inputFilters: Int = 64, useLaterStride: Bool = true
130+
useLaterStride: Bool = true
131131
) {
132+
let inputFilters: Int
133+
132134
if downsamplingInFirstStage {
135+
inputFilters = 64
133136
initialLayer = ConvBN(
134137
filterShape: (7, 7, 3, inputFilters), strides: (2, 2), padding: .same)
135138
maxPool = MaxPool2D(poolSize: (3, 3), strides: (2, 2), padding: .same)
136139
} else {
140+
inputFilters = 16
137141
initialLayer = ConvBN(filterShape: (3, 3, 3, inputFilters), padding: .same)
138142
maxPool = MaxPool2D(poolSize: (1, 1), strides: (1, 1)) // no-op
139143
}
@@ -151,7 +155,10 @@ public struct ResNet: Layer {
151155
}
152156
}
153157

154-
classifier = Dense(inputSize: depth.usesBasicBlocks ? 512 : 2048, outputSize: classCount)
158+
let finalFilters = inputFilters * Int(pow(2.0, Double(depth.layerBlockSizes.count - 1)))
159+
classifier = Dense(
160+
inputSize: depth.usesBasicBlocks ? finalFilters : finalFilters * 4,
161+
outputSize: classCount)
155162
}
156163

157164
@differentiable
@@ -169,12 +176,13 @@ extension ResNet {
169176
case resNet18
170177
case resNet34
171178
case resNet50
179+
case resNet56
172180
case resNet101
173181
case resNet152
174182

175183
/// `true` for `.resNet18`, `.resNet34` and — on the added (`+`) line —
/// `.resNet56`; `false` for every other depth via the `default` branch.
/// The initializer in this diff reads it to choose the classifier's input
/// size (`finalFilters` when `true`, `finalFilters * 4` otherwise).
var usesBasicBlocks: Bool {
176184
switch self {
177-
case .resNet18, .resNet34: return true
185+
case .resNet18, .resNet34, .resNet56: return true
178186
default: return false
179187
}
180188
}
@@ -184,6 +192,7 @@ extension ResNet {
184192
case .resNet18: return [2, 2, 2, 2]
185193
case .resNet34: return [3, 4, 6, 3]
186194
case .resNet50: return [3, 4, 6, 3]
195+
case .resNet56: return [9, 9, 9]
187196
case .resNet101: return [3, 4, 23, 3]
188197
case .resNet152: return [3, 8, 36, 3]
189198
}

0 commit comments

Comments
 (0)