@@ -17,23 +17,24 @@ import Foundation
 import ModelSupport
 import TensorFlow
 import TextModels
+import TrainingLoop
 import x10_optimizers_optimizer
 
 let device = Device.defaultXLA
 
 var bertPretrained: BERT.PreTrainedModel
 if CommandLine.arguments.count >= 2 {
-    if CommandLine.arguments[1].lowercased() == "albert" {
-        bertPretrained = BERT.PreTrainedModel.albertBase
-    } else if CommandLine.arguments[1].lowercased() == "roberta" {
-        bertPretrained = BERT.PreTrainedModel.robertaBase
-    } else if CommandLine.arguments[1].lowercased() == "electra" {
-        bertPretrained = BERT.PreTrainedModel.electraBase
-    } else {
-        bertPretrained = BERT.PreTrainedModel.bertBase(cased: false, multilingual: false)
-    }
-} else {
+  if CommandLine.arguments[1].lowercased() == "albert" {
+    bertPretrained = BERT.PreTrainedModel.albertBase
+  } else if CommandLine.arguments[1].lowercased() == "roberta" {
+    bertPretrained = BERT.PreTrainedModel.robertaBase
+  } else if CommandLine.arguments[1].lowercased() == "electra" {
+    bertPretrained = BERT.PreTrainedModel.electraBase
+  } else {
     bertPretrained = BERT.PreTrainedModel.bertBase(cased: false, multilingual: false)
+  }
+} else {
+  bertPretrained = BERT.PreTrainedModel.bertBase(cased: false, multilingual: false)
 }
 
 let bert = try bertPretrained.load()
@@ -54,11 +55,12 @@ bertClassifier.move(to: device)
 let maxSequenceLength = 128
 let batchSize = 1024
 let epochCount = 3
-let stepsPerEpoch = 1068 // function of training set size and batching configuration
+let stepsPerEpoch = 1068  // function of training set size and batching configuration
 let peakLearningRate: Float = 2e-5
 
-let workspaceURL = URL(fileURLWithPath: "bert_models", isDirectory: true,
-    relativeTo: URL(fileURLWithPath: NSTemporaryDirectory(), isDirectory: true))
+let workspaceURL = URL(
+  fileURLWithPath: "bert_models", isDirectory: true,
+  relativeTo: URL(fileURLWithPath: NSTemporaryDirectory(), isDirectory: true))
 
 var cola = try CoLA(
   taskDirectoryURL: workspaceURL,
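
A note on the hard-coded 1068: as the trailing comment says, it is a function of the training set size and the batching configuration. One plausible derivation, assuming `batchSize` counts tokens rather than sequences and given that CoLA's in-domain training split has 8,551 sentences, is sketched below; this is a back-of-the-envelope check, not part of the commit.

let sequencesPerBatch = 1024 / 128  // batchSize / maxSequenceLength = 8 sequences per batch
let colaTrainingExampleCount = 8551  // size of CoLA's in-domain training split
let derivedStepsPerEpoch = colaTrainingExampleCount / sequencesPerBatch  // 8551 / 8 = 1068
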
@@ -69,10 +71,11 @@ var cola = try CoLA(
 ) { example in
   // In this closure, both the input and output text batches must be eager
   // since the text is not padded and x10 requires stable shapes.
-  let textBatch = bertClassifier.bert.preprocess(
+  let classifier = bertClassifier
+  let textBatch = classifier.bert.preprocess(
     sequences: [example.sentence],
     maxSequenceLength: maxSequenceLength)
-  return (data: textBatch, label: Tensor<Int32>(example.isAcceptable! ? 1 : 0))
+  return LabeledData(data: textBatch, label: Tensor<Int32>(example.isAcceptable! ? 1 : 0))
 }
 
 print("Dataset acquired.")
@@ -82,91 +85,70 @@ let beta2: Float = 0.999
 let useBiasCorrection = true
 
 var optimizer = x10_optimizers_optimizer.GeneralOptimizer(
-    for: bertClassifier,
-    TensorVisitorPlan(bertClassifier.differentiableVectorView),
-    defaultOptimizer: makeWeightDecayedAdam(
-        learningRate: peakLearningRate,
-        beta1: beta1,
-        beta2: beta2
-    )
+  for: bertClassifier,
+  TensorVisitorPlan(bertClassifier.differentiableVectorView),
+  defaultOptimizer: makeWeightDecayedAdam(
+    learningRate: peakLearningRate,
+    beta1: beta1,
+    beta2: beta2
+  )
 )
 
-var scheduledLearningRate = LinearlyDecayedParameter(
-  baseParameter: LinearlyWarmedUpParameter(
+/// Computes sigmoidCrossEntropy loss from `logits` and `labels`.
+///
+/// This defines the loss function used in TrainingLoop; it is a wrapper around
+/// the standard sigmoidCrossEntropy that reshapes `logits` to the required
+/// shape before calling it.
+@differentiable
+public func sigmoidCrossEntropyReshaped<Scalar>(logits: Tensor<Scalar>, labels: Tensor<Int32>)
+  -> Tensor<
+    Scalar
+  > where Scalar: TensorFlowFloatingPoint
+{
+  return sigmoidCrossEntropy(
+    logits: logits.squeezingShape(at: -1),
+    labels: Tensor<Scalar>(labels),
+    reduction: _mean)
+}
+
+/// Clips the gradients by global norm.
+///
+/// This is defined as a callback registered into TrainingLoop.
+func clipGradByGlobalNorm<L: TrainingLoopProtocol>(_ loop: inout L, event: TrainingLoopEvent) throws
+{
+  if event == .updateStart {
+    var gradients = loop.lastStepGradient!
+    gradients.clipByGlobalNorm(clipNorm: 1)
+    loop.lastStepGradient = gradients
+  }
+}
+
+/// A function that returns a LinearlyDecayedParameter whose first 10 steps are linearly
+/// warmed up; for the remaining steps it decays with slope -(peakLearningRate / `totalStepCount`).
+let scheduledParameterGetter = { (_ totalStepCount: Float) -> LinearlyDecayedParameter in
+  LinearlyDecayedParameter(
+    baseParameter: LinearlyWarmedUpParameter(
       baseParameter: FixedParameter<Float>(peakLearningRate),
       warmUpStepCount: 10,
       warmUpOffset: 0),
-  slope: -(peakLearningRate / Float(stepsPerEpoch * epochCount)),  // The LR decays linearly to zero.
-  startStep: 10
-)
-
-print("Training \(bertPretrained.name) for the CoLA task!")
-for (epoch, epochBatches) in cola.trainingEpochs.prefix(epochCount).enumerated() {
-    print("[Epoch \(epoch + 1)]")
-    Context.local.learningPhase = .training
-    var trainingLossSum: Float = 0
-    var trainingBatchCount = 0
-
-    for batch in epochBatches {
-        let (documents, labels) = (batch.data, Tensor<Float>(batch.label))
-        var (loss, gradients) = valueWithGradient(at: bertClassifier) { model -> Tensor<Float> in
-            let logits = model(documents)
-            return sigmoidCrossEntropy(
-                logits: logits.squeezingShape(at: -1),
-                labels: labels,
-                reduction: { $0.mean() })
-        }
-
-        trainingLossSum += loss.scalarized()
-        trainingBatchCount += 1
-        gradients.clipByGlobalNorm(clipNorm: 1)
-
-        let step = optimizer.step + 1  // for scheduled rates and bias correction, steps start at 1
-        optimizer.learningRate = scheduledLearningRate(forStep: UInt64(step))
-        if useBiasCorrection {
-            let step = Float(step)
-            optimizer.learningRate *= sqrtf(1 - powf(beta2, step)) / (1 - powf(beta1, step))
-        }
-
-        optimizer.update(&bertClassifier, along: gradients)
-        LazyTensorBarrier()
-
-        print(
-            """
-            Training loss: \(trainingLossSum / Float(trainingBatchCount))
-            """
-        )
-    }
-
-    Context.local.learningPhase = .inference
-    var devLossSum: Float = 0
-    var devBatchCount = 0
-    var devPredictedLabels = [Bool]()
-    var devGroundTruth = [Bool]()
-    for batch in cola.validationBatches {
-        let (documents, labels) = (batch.data, Tensor<Float>(batch.label))
-        let logits = bertClassifier(documents)
-        let loss = sigmoidCrossEntropy(
-            logits: logits.squeezingShape(at: -1),
-            labels: labels,
-            reduction: { $0.mean() }
-        )
-        devLossSum += loss.scalarized()
-        devBatchCount += 1
-
-        let predictedLabels = sigmoid(logits.squeezingShape(at: -1)) .>= 0.5
-        devPredictedLabels.append(contentsOf: predictedLabels.scalars)
-        devGroundTruth.append(contentsOf: labels.scalars.map { $0 == 1 })
-    }
+    slope: -(peakLearningRate / totalStepCount),  // The LR decays linearly to zero.
+    startStep: 10
+  )
+}
 
-    let mcc = matthewsCorrelationCoefficient(
-        predictions: devPredictedLabels,
-        groundTruth: devGroundTruth)
+var trainingLoop: TrainingLoop = TrainingLoop(
+  training: cola.trainingEpochs,
+  validation: cola.validationBatches,
+  optimizer: optimizer,
+  lossFunction: sigmoidCrossEntropyReshaped,
+  metrics: [.matthewsCorrelationCoefficient],
+  callbacks: [
+    clipGradByGlobalNorm,
+    LearningRateScheduler(
+      scheduledParameterGetter: scheduledParameterGetter,
+      biasCorrectionBeta1: beta1,
+      biasCorrectionBeta2: beta2).schedule
+  ])
 
-    print(
-        """
-        MCC: \(mcc)
-        Eval loss: \(devLossSum / Float(devBatchCount))
-        """
-    )
-}
+print("Training \(bertPretrained.name) for the CoLA task!")
+try! trainingLoop.fit(&bertClassifier, epochs: epochCount, on: device)
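
To see what the schedule above produces, here is a hypothetical spot-check of its endpoints, reusing the diff's own names (`scheduledParameterGetter`, `stepsPerEpoch`, `epochCount`) and the `forStep:` call that appears on the removed side of the diff; it is a sketch, not part of the commit.

let totalStepCount = Float(stepsPerEpoch * epochCount)  // 1068 * 3 = 3204 steps
let schedule = scheduledParameterGetter(totalStepCount)
print(schedule(forStep: 10))    // end of warm-up: ≈ peakLearningRate (2e-5)
print(schedule(forStep: 3204))  // final step: linearly decayed to ≈ 0
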
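The `callbacks:` array is TrainingLoop's extension point: any function with the `(inout L, TrainingLoopEvent) throws` shape seen in `clipGradByGlobalNorm` can be registered. As one more illustration, a sketch of a logging callback, assuming an `.epochEnd` event exists alongside the `.updateStart` used above:

/// Hypothetical callback sketch (not part of the commit): prints a marker at
/// the end of each epoch, following the same signature as clipGradByGlobalNorm.
func logEpochEnd<L: TrainingLoopProtocol>(_ loop: inout L, event: TrainingLoopEvent) throws {
  if event == .epochEnd {  // assumes an .epochEnd case; .updateStart is the one used above
    print("Finished an epoch.")
  }
}

// It would be registered the same way: callbacks: [clipGradByGlobalNorm, logEpochEnd].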