@@ -123,18 +123,8 @@ func clipGradByGlobalNorm<L: TrainingLoopProtocol>(_ loop: inout L, event: TrainingLoopEvent
   }
 }
 
-/// A function that returns a LinearlyDecayedParameter but with first 10 steps linearly warmed up;
-/// for remaining steps it decays at slope of -(peakLearningRate / `totalStepCount`).
-let scheduledParameterGetter = { (_ totalStepCount: Float) -> LinearlyDecayedParameter in
-  LinearlyDecayedParameter(
-    baseParameter: LinearlyWarmedUpParameter(
-      baseParameter: FixedParameter<Float>(peakLearningRate),
-      warmUpStepCount: 10,
-      warmUpOffset: 0),
-    slope: -(peakLearningRate / totalStepCount),  // The LR decays linearly to zero.
-    startStep: 10
-  )
-}
+/// A linear shape for the learning rate in both the warmup and decay phases.
+let linear = Shape({ $0 })
 
 var trainingLoop: TrainingLoop = TrainingLoop(
   training: cola.trainingEpochs,
@@ -144,10 +134,15 @@ var trainingLoop: TrainingLoop = TrainingLoop(
   metrics: [.matthewsCorrelationCoefficient],
   callbacks: [
     clipGradByGlobalNorm,
-    LearningRateScheduler(
-      scheduledParameterGetter: scheduledParameterGetter,
-      biasCorrectionBeta1: beta1,
-      biasCorrectionBeta2: beta2).schedule
+    learningRateScheduler(
+      schedule: makeSchedule(
+        [
+          ScheduleSegment(shape: linear, startRate: 0, endRate: peakLearningRate, stepCount: 10),
+          ScheduleSegment(shape: linear, endRate: 0)
+        ]
+      ),
+      biasCorrectionBeta: (beta1, beta2)
+    ),
   ])
 
 print("Training \(bertPretrained.name) for the CoLA task!")