Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit d923300

Browse files
authored
add test and fix cifar10 normalization values (#438)
1 parent 79c347b commit d923300

File tree

2 files changed

+48
-3
lines changed

2 files changed

+48
-3
lines changed

Datasets/CIFAR10/CIFAR10.swift

+23-2
Original file line numberDiff line numberDiff line change
@@ -98,9 +98,30 @@ func loadCIFARFile(named name: String, in directory: URL, normalizing: Bool = tr
9898
// Transpose from the CIFAR-provided N(CHW) to TF's default NHWC.
9999
var imageTensor = Tensor<Float>(images.transposed(permutation: [0, 2, 3, 1]))
100100

101+
// The value of mean and std were calculated with the following Swift code:
102+
// ```
103+
// import TensorFlow
104+
// import Datasets
105+
// import Foundation
106+
// let urlString = "https://storage.googleapis.com/s4tf-hosted-binaries/datasets/CIFAR10/cifar-10-binary.tar.gz"
107+
// let cifar = CIFAR10(batchSize: 50000,
108+
// remoteBinaryArchiveLocation: URL(string: urlString)!,
109+
// normalizing: false)
110+
// for batch in cifar.training.sequenced() {
111+
// let images = Tensor<Double>(batch.first) / 255.0
112+
// let mom = images.moments(squeezingAxes: [0,1,2])
113+
// print("mean: \(mom.mean) std: \(sqrt(mom.variance))")
114+
// }
115+
// ```
101116
if normalizing {
102-
let mean = Tensor<Float>([0.485, 0.456, 0.406])
103-
let std = Tensor<Float>([0.229, 0.224, 0.225])
117+
let mean = Tensor<Float>(
118+
[0.4913996898,
119+
0.4821584196,
120+
0.4465309242])
121+
let std = Tensor<Float>(
122+
[0.2470322324,
123+
0.2434851280,
124+
0.2615878417])
104125
imageTensor = ((imageTensor / 255.0) - mean) / std
105126
}
106127

Tests/DatasetsTests/CIFAR10/CIFAR10Tests.swift

+25-1
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,34 @@ final class CIFAR10Tests: XCTestCase {
2525
}
2626
XCTAssertEqual(totalCount, 50000)
2727
}
28+
29+
func testNormalizeCIFAR10() {
30+
let dataset = CIFAR10(
31+
batchSize: 50000,
32+
remoteBinaryArchiveLocation:
33+
URL(
34+
string:
35+
"https://storage.googleapis.com/s4tf-hosted-binaries/datasets/CIFAR10/cifar-10-binary.tar.gz"
36+
)!, normalizing: true
37+
)
38+
39+
let targetMean = Tensor<Double>([0, 0, 0])
40+
let targetStd = Tensor<Double>([1, 1, 1])
41+
for batch in dataset.training.sequenced() {
42+
let images = Tensor<Double>(batch.first)
43+
let mean = images.mean(squeezingAxes: [0, 1, 2])
44+
let std = images.standardDeviation(squeezingAxes: [0, 1, 2])
45+
XCTAssertTrue(targetMean.isAlmostEqual(to: mean,
46+
tolerance: 1e-6))
47+
XCTAssertTrue(targetStd.isAlmostEqual(to: std,
48+
tolerance: 1e-5))
49+
}
50+
}
2851
}
2952

3053
extension CIFAR10Tests {
3154
static var allTests = [
3255
("testCreateCIFAR10", testCreateCIFAR10),
56+
("testNormalizeCIFAR10", testNormalizeCIFAR10),
3357
]
34-
}
58+
}

0 commit comments

Comments
 (0)