Skip to content
This repository was archived by the owner on Feb 13, 2025. It is now read-only.

Commit b599958

Browse files
xihui-wuDave Abrahams
and
Dave Abrahams
authored
Refactor LearningRateScheduler (#711)
Refactor LearningRateScheduler Co-authored-by: Dave Abrahams <[email protected]>
1 parent baf0028 commit b599958

File tree

5 files changed

+260
-444
lines changed

5 files changed

+260
-444
lines changed

Examples/BERT-CoLA/main.swift

+11-16
Original file line numberDiff line numberDiff line change
@@ -123,18 +123,8 @@ func clipGradByGlobalNorm<L: TrainingLoopProtocol>(_ loop: inout L, event: Train
123123
}
124124
}
125125

126-
/// A function that returns a LinearlyDecayedParameter but with first 10 steps linearly warmed up;
127-
/// for remaining steps it decays at slope of -(peakLearningRate / `totalStepCount`).
128-
let scheduledParameterGetter = { (_ totalStepCount: Float) -> LinearlyDecayedParameter in
129-
LinearlyDecayedParameter(
130-
baseParameter: LinearlyWarmedUpParameter(
131-
baseParameter: FixedParameter<Float>(peakLearningRate),
132-
warmUpStepCount: 10,
133-
warmUpOffset: 0),
134-
slope: -(peakLearningRate / totalStepCount), // The LR decays linearly to zero.
135-
startStep: 10
136-
)
137-
}
126+
/// A linear shape to the learning rate in both warmup and decay phases.
127+
let linear = Shape({ $0 })
138128

139129
var trainingLoop: TrainingLoop = TrainingLoop(
140130
training: cola.trainingEpochs,
@@ -144,10 +134,15 @@ var trainingLoop: TrainingLoop = TrainingLoop(
144134
metrics: [.matthewsCorrelationCoefficient],
145135
callbacks: [
146136
clipGradByGlobalNorm,
147-
LearningRateScheduler(
148-
scheduledParameterGetter: scheduledParameterGetter,
149-
biasCorrectionBeta1: beta1,
150-
biasCorrectionBeta2: beta2).schedule
137+
learningRateScheduler(
138+
schedule: makeSchedule(
139+
[
140+
ScheduleSegment(shape: linear, startRate: 0, endRate: peakLearningRate, stepCount: 10),
141+
ScheduleSegment(shape: linear, endRate: 0)
142+
]
143+
),
144+
biasCorrectionBeta: (beta1, beta2)
145+
),
151146
])
152147

153148
print("Training \(bertPretrained.name) for the CoLA task!")

TrainingLoop/CMakeLists.txt

+2-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@ add_library(TrainingLoop
55
Callbacks/StatisticsRecorder.swift
66
Callbacks/ProgressPrinter.swift
77
Callbacks/CSVLogger.swift
8-
Callbacks/LearningRateScheduler.swift)
8+
Callbacks/LearningRateScheduler/LearningRateScheduler.swift
9+
Callbacks/LearningRateScheduler/LearningRateSchedule.swift)
910
target_link_libraries(TrainingLoop PUBLIC
1011
ModelSupport)
1112
set_target_properties(TrainingLoop PROPERTIES

0 commit comments

Comments
 (0)