Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit 81390ad

Browse files
Convert LeNet-MNIST example to X10 (#578)
* Convert LeNet-MNIST example to X10
* Added conditional check to avoid use of X10 on macOS for now.

Co-authored-by: Brad Larson <[email protected]>
1 parent 264bead commit 81390ad

File tree

1 file changed

+53
-26
lines changed

1 file changed

+53
-26
lines changed

Examples/LeNet-MNIST/main.swift

Lines changed: 53 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,14 @@ import Datasets
1818
let epochCount = 12
1919
let batchSize = 128
2020

21+
// Until https://github.com/tensorflow/swift-models/issues/588 is fixed, default to the eager-mode
22+
// device on macOS instead of X10.
23+
#if os(macOS)
24+
let device = Device.defaultTFEager
25+
#else
26+
let device = Device.defaultXLA
27+
#endif
28+
2129
let dataset = MNIST(batchSize: batchSize)
2230
// The LeNet-5 model, equivalent to `LeNet` in `ImageClassificationModels`.
2331
var classifier = Sequential {
@@ -30,65 +38,84 @@ var classifier = Sequential {
3038
Dense<Float>(inputSize: 120, outputSize: 84, activation: relu)
3139
Dense<Float>(inputSize: 84, outputSize: 10)
3240
}
41+
classifier.move(to: device)
3342

34-
let optimizer = SGD(for: classifier, learningRate: 0.1)
43+
var optimizer = SGD(for: classifier, learningRate: 0.1)
44+
optimizer = SGD(copying: optimizer, to: device)
3545

3646
print("Beginning training...")
3747

3848
struct Statistics {
39-
var correctGuessCount: Int = 0
40-
var totalGuessCount: Int = 0
41-
var totalLoss: Float = 0
49+
var correctGuessCount = Tensor<Int32>(0, on: Device.default)
50+
var totalGuessCount = Tensor<Int32>(0, on: Device.default)
51+
var totalLoss = Tensor<Float>(0, on: Device.default)
4252
var batches: Int = 0
53+
54+
var accuracy: Float {
55+
Float(correctGuessCount.scalarized()) / Float(totalGuessCount.scalarized()) * 100
56+
}
57+
58+
var averageLoss: Float {
59+
totalLoss.scalarized() / Float(batches)
60+
}
61+
62+
init(on device: Device = Device.default) {
63+
correctGuessCount = Tensor<Int32>(0, on: device)
64+
totalGuessCount = Tensor<Int32>(0, on: device)
65+
totalLoss = Tensor<Float>(0, on: device)
66+
}
67+
68+
mutating func update(logits: Tensor<Float>, labels: Tensor<Int32>, loss: Tensor<Float>) {
69+
let correct = logits.argmax(squeezingAxis: 1) .== labels
70+
correctGuessCount += Tensor<Int32>(correct).sum()
71+
totalGuessCount += Int32(labels.shape[0])
72+
totalLoss += loss
73+
batches += 1
74+
}
4375
}
4476

4577
// The training loop.
4678
for (epoch, epochBatches) in dataset.training.prefix(epochCount).enumerated() {
47-
var trainStats = Statistics()
48-
var testStats = Statistics()
79+
var trainStats = Statistics(on: device)
80+
var testStats = Statistics(on: device)
4981

5082
Context.local.learningPhase = .training
5183
for batch in epochBatches {
52-
let (images, labels) = (batch.data, batch.label)
84+
let (eagerImages, eagerLabels) = (batch.data, batch.label)
85+
let images = Tensor(copying: eagerImages, to: device)
86+
let labels = Tensor(copying: eagerLabels, to: device)
5387
// Compute the gradient with respect to the model.
5488
let 𝛁model = TensorFlow.gradient(at: classifier) { classifier -> Tensor<Float> in
5589
let ŷ = classifier(images)
56-
let correctPredictions = ŷ.argmax(squeezingAxis: 1) .== labels
57-
trainStats.correctGuessCount += Int(
58-
Tensor<Int32>(correctPredictions).sum().scalarized())
59-
trainStats.totalGuessCount += images.shape[0]
6090
let loss = softmaxCrossEntropy(logits: ŷ, labels: labels)
61-
trainStats.totalLoss += loss.scalarized()
62-
trainStats.batches += 1
91+
trainStats.update(logits: ŷ, labels: labels, loss: loss)
6392
return loss
6493
}
6594
// Update the model's differentiable variables along the gradient vector.
6695
optimizer.update(&classifier, along: 𝛁model)
96+
LazyTensorBarrier()
6797
}
6898

6999
Context.local.learningPhase = .inference
70100
for batch in dataset.validation {
71-
let (images, labels) = (batch.data, batch.label)
101+
let (eagerImages, eagerLabels) = (batch.data, batch.label)
102+
let images = Tensor(copying: eagerImages, to: device)
103+
let labels = Tensor(copying: eagerLabels, to: device)
72104
// Compute loss on test set
73105
let ŷ = classifier(images)
74-
let correctPredictions = ŷ.argmax(squeezingAxis: 1) .== labels
75-
testStats.correctGuessCount += Int(Tensor<Int32>(correctPredictions).sum().scalarized())
76-
testStats.totalGuessCount += images.shape[0]
77106
let loss = softmaxCrossEntropy(logits: ŷ, labels: labels)
78-
testStats.totalLoss += loss.scalarized()
79-
testStats.batches += 1
107+
LazyTensorBarrier()
108+
testStats.update(logits: ŷ, labels: labels, loss: loss)
80109
}
81110

82-
let trainAccuracy = Float(trainStats.correctGuessCount) / Float(trainStats.totalGuessCount)
83-
let testAccuracy = Float(testStats.correctGuessCount) / Float(testStats.totalGuessCount)
84111
print(
85112
"""
86113
[Epoch \(epoch + 1)] \
87-
Training Loss: \(trainStats.totalLoss / Float(trainStats.batches)), \
114+
Training Loss: \(String(format: "%.3f", trainStats.averageLoss)), \
88115
Training Accuracy: \(trainStats.correctGuessCount)/\(trainStats.totalGuessCount) \
89-
(\(trainAccuracy)), \
90-
Test Loss: \(testStats.totalLoss / Float(testStats.batches)), \
116+
(\(String(format: "%.1f", trainStats.accuracy))%), \
117+
Test Loss: \(String(format: "%.3f", testStats.averageLoss)), \
91118
Test Accuracy: \(testStats.correctGuessCount)/\(testStats.totalGuessCount) \
92-
(\(testAccuracy))
119+
(\(String(format: "%.3f", testStats.accuracy))%)
93120
""")
94-
}
121+
}

0 commit comments

Comments (0)