let epochCount = 12
let batchSize = 128

// Until https://github.com/tensorflow/swift-models/issues/588 is fixed, default to the eager-mode
// device on macOS instead of X10.
#if os(macOS)
let device = Device.defaultTFEager
#else
let device = Device.defaultXLA
#endif

let dataset = MNIST(batchSize: batchSize)
var classifier = Sequential {
@@ -30,65 +38,84 @@ var classifier = Sequential {
30
38
Dense < Float > ( inputSize: 120 , outputSize: 84 , activation: relu)
31
39
Dense < Float > ( inputSize: 84 , outputSize: 10 )
32
40
}
41
// Place the model's parameters on the selected device.
classifier.move(to: device)

// SGD(copying:to:) rebuilds the optimizer state on `device` so its
// accumulators live alongside the model parameters.
var optimizer = SGD(for: classifier, learningRate: 0.1)
optimizer = SGD(copying: optimizer, to: device)

print("Beginning training...")
48
/// Accumulates classification metrics (correct guesses, totals, summed loss)
/// across batches, keeping the accumulators as on-device tensors so that
/// per-batch updates stay on `device` instead of forcing a host transfer.
struct Statistics {
    // No default values here: `init(on:)` assigns all three unconditionally,
    // so defaults would only construct throwaway tensors on Device.default
    // that are immediately overwritten.
    var correctGuessCount: Tensor<Int32>
    var totalGuessCount: Tensor<Int32>
    var totalLoss: Tensor<Float>
    var batches: Int = 0

    /// Percentage of correct predictions over everything recorded so far.
    /// NOTE(review): divides by `totalGuessCount`; undefined before the first
    /// `update` call — callers only read this after a full pass, confirm.
    var accuracy: Float {
        Float(correctGuessCount.scalarized()) / Float(totalGuessCount.scalarized()) * 100
    }

    /// Mean loss per recorded batch (undefined while `batches == 0`).
    var averageLoss: Float {
        totalLoss.scalarized() / Float(batches)
    }

    /// Creates zeroed accumulators on `device`.
    init(on device: Device = Device.default) {
        correctGuessCount = Tensor<Int32>(0, on: device)
        totalGuessCount = Tensor<Int32>(0, on: device)
        totalLoss = Tensor<Float>(0, on: device)
    }

    /// Folds one batch into the running totals.
    /// - Parameters:
    ///   - logits: Raw class scores, one row per example.
    ///   - labels: Ground-truth class indices for the batch.
    ///   - loss: Scalar loss for the batch, added to `totalLoss`.
    mutating func update(logits: Tensor<Float>, labels: Tensor<Int32>, loss: Tensor<Float>) {
        let correct = logits.argmax(squeezingAxis: 1) .== labels
        correctGuessCount += Tensor<Int32>(correct).sum()
        totalGuessCount += Int32(labels.shape[0])
        totalLoss += loss
        batches += 1
    }
}
44
76
45
77
// The training loop: one pass over the training set followed by a full
// validation pass, with per-epoch metrics printed at the end.
for (epoch, epochBatches) in dataset.training.prefix(epochCount).enumerated() {
    var trainStats = Statistics(on: device)
    var testStats = Statistics(on: device)

    Context.local.learningPhase = .training
    for batch in epochBatches {
        // Batches arrive as eager tensors; copy them to the training device.
        let (eagerImages, eagerLabels) = (batch.data, batch.label)
        let images = Tensor(copying: eagerImages, to: device)
        let labels = Tensor(copying: eagerLabels, to: device)
        // Compute the gradient with respect to the model.
        let 𝛁model = TensorFlow.gradient(at: classifier) { classifier -> Tensor<Float> in
            let ŷ = classifier(images)
            let loss = softmaxCrossEntropy(logits: ŷ, labels: labels)
            // NOTE(review): mutating the stats inside the differentiated
            // closure; appears intentional so metrics share the traced
            // computation — confirm it does not perturb the gradient.
            trainStats.update(logits: ŷ, labels: labels, loss: loss)
            return loss
        }
        // Update the model's differentiable variables along the gradient vector.
        optimizer.update(&classifier, along: 𝛁model)
        // Cut the lazy (X10) trace here so each batch is materialized.
        LazyTensorBarrier()
    }

    Context.local.learningPhase = .inference
    for batch in dataset.validation {
        let (eagerImages, eagerLabels) = (batch.data, batch.label)
        let images = Tensor(copying: eagerImages, to: device)
        let labels = Tensor(copying: eagerLabels, to: device)
        // Compute loss on test set.
        let ŷ = classifier(images)
        let loss = softmaxCrossEntropy(logits: ŷ, labels: labels)
        LazyTensorBarrier()
        testStats.update(logits: ŷ, labels: labels, loss: loss)
    }

    // Accuracy fields both use "%.1f" (was "%.3f" for the test side —
    // inconsistent with the training side for the same percentage metric).
    print(
        """
        [Epoch \(epoch + 1)] \
        Training Loss: \(String(format: "%.3f", trainStats.averageLoss)), \
        Training Accuracy: \(trainStats.correctGuessCount)/\(trainStats.totalGuessCount) \
        (\(String(format: "%.1f", trainStats.accuracy))%), \
        Test Loss: \(String(format: "%.3f", testStats.averageLoss)), \
        Test Accuracy: \(testStats.correctGuessCount)/\(testStats.totalGuessCount) \
        (\(String(format: "%.1f", testStats.accuracy))%)
        """)
}
0 commit comments