
Commit ff5d809
Convert to TrainingLoop for BERTCoLA (#679)
1 parent b6e5d01

10 files changed: +580 -536 lines

Datasets/CoLA/CoLA.swift (+3 -3)
@@ -42,7 +42,7 @@ public struct CoLA<Entropy: RandomNumberGenerator> {
   public let directoryURL: URL
 
   /// A `TextBatch` with the corresponding labels.
-  public typealias LabeledTextBatch = (data: TextBatch, label: Tensor<Int32>)
+  public typealias LabeledTextBatch = LabeledData<TextBatch, Tensor<Int32>>
   /// The type of the labeled samples.
   public typealias Samples = LazyMapSequence<[CoLAExample], LabeledTextBatch>
   /// The training texts.
@@ -158,7 +158,7 @@ extension CoLA {
       samples: trainingExamples, batchSize: batchSize / maxSequenceLength, entropy: entropy
     ).lazy.map { (batches: Batches) -> LazyMapSequence<Batches, LabeledTextBatch> in
       batches.lazy.map{
-        (
+        LabeledData(
           data: $0.map(\.data).paddedAndCollated(to: maxSequenceLength, on: device),
           label: Tensor(copying: Tensor($0.map(\.label)), to: device)
         )
@@ -167,7 +167,7 @@ extension CoLA {
 
     // Create the validation collection of batches.
     validationBatches = validationExamples.inBatches(of: batchSize / maxSequenceLength).lazy.map{
-      (
+      LabeledData(
        data: $0.map(\.data).paddedAndCollated(to: maxSequenceLength, on: device),
        label: Tensor(copying: Tensor($0.map(\.label)), to: device)
      )
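This is the core type change of the commit: the ad-hoc `(data:, label:)` tuple becomes the `LabeledData` struct, which is the sample type `TrainingLoop` consumes. A minimal sketch of the difference at a call site, assuming the swift-models `Datasets` module is imported and a `TextBatch` value is at hand (the function name here is illustrative):

import TensorFlow

// Sketch only: before this commit the sample type was the tuple
// (data: TextBatch, label: Tensor<Int32>). LabeledData keeps the same
// field names, so existing `.data` / `.label` accesses are unchanged.
func wrap(_ batch: TextBatch) -> LabeledData<TextBatch, Tensor<Int32>> {
  return LabeledData(data: batch, label: Tensor<Int32>(1))
}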

Examples/BERT-CoLA/CMakeLists.txt (+2 -1)
@@ -2,7 +2,8 @@ add_executable(BERT-CoLA
   main.swift)
 target_link_libraries(BERT-CoLA PRIVATE
   TextModels
-  Datasets)
+  Datasets
+  TrainingLoop)
 
 
 install(TARGETS BERT-CoLA

Examples/BERT-CoLA/main.swift (+78 -96)
@@ -17,23 +17,24 @@ import Foundation
 import ModelSupport
 import TensorFlow
 import TextModels
+import TrainingLoop
 import x10_optimizers_optimizer
 
 let device = Device.defaultXLA
 
 var bertPretrained: BERT.PreTrainedModel
 if CommandLine.arguments.count >= 2 {
-    if CommandLine.arguments[1].lowercased() == "albert" {
-        bertPretrained = BERT.PreTrainedModel.albertBase
-    } else if CommandLine.arguments[1].lowercased() == "roberta" {
-        bertPretrained = BERT.PreTrainedModel.robertaBase
-    } else if CommandLine.arguments[1].lowercased() == "electra" {
-        bertPretrained = BERT.PreTrainedModel.electraBase
-    } else {
-        bertPretrained = BERT.PreTrainedModel.bertBase(cased: false, multilingual: false)
-    }
-} else {
+  if CommandLine.arguments[1].lowercased() == "albert" {
+    bertPretrained = BERT.PreTrainedModel.albertBase
+  } else if CommandLine.arguments[1].lowercased() == "roberta" {
+    bertPretrained = BERT.PreTrainedModel.robertaBase
+  } else if CommandLine.arguments[1].lowercased() == "electra" {
+    bertPretrained = BERT.PreTrainedModel.electraBase
+  } else {
     bertPretrained = BERT.PreTrainedModel.bertBase(cased: false, multilingual: false)
+  }
+} else {
+  bertPretrained = BERT.PreTrainedModel.bertBase(cased: false, multilingual: false)
 }
 
 let bert = try bertPretrained.load()
@@ -54,11 +55,12 @@ bertClassifier.move(to: device)
 let maxSequenceLength = 128
 let batchSize = 1024
 let epochCount = 3
-let stepsPerEpoch = 1068 // function of training set size and batching configuration
+let stepsPerEpoch = 1068  // function of training set size and batching configuration
 let peakLearningRate: Float = 2e-5
 
-let workspaceURL = URL(fileURLWithPath: "bert_models", isDirectory: true,
-    relativeTo: URL(fileURLWithPath: NSTemporaryDirectory(),isDirectory: true))
+let workspaceURL = URL(
+  fileURLWithPath: "bert_models", isDirectory: true,
+  relativeTo: URL(fileURLWithPath: NSTemporaryDirectory(), isDirectory: true))
 
 var cola = try CoLA(
   taskDirectoryURL: workspaceURL,
@@ -69,10 +71,11 @@ var cola = try CoLA(
 ) { example in
   // In this closure, both the input and output text batches must be eager
   // since the text is not padded and x10 requires stable shapes.
-  let textBatch = bertClassifier.bert.preprocess(
+  let classifier = bertClassifier
+  let textBatch = classifier.bert.preprocess(
     sequences: [example.sentence],
     maxSequenceLength: maxSequenceLength)
-  return (data: textBatch, label: Tensor<Int32>(example.isAcceptable! ? 1 : 0))
+  return LabeledData(data: textBatch, label: Tensor<Int32>(example.isAcceptable! ? 1 : 0))
 }
 
 print("Dataset acquired.")
@@ -82,91 +85,70 @@ let beta2: Float = 0.999
 let useBiasCorrection = true
 
 var optimizer = x10_optimizers_optimizer.GeneralOptimizer(
-    for: bertClassifier,
-    TensorVisitorPlan(bertClassifier.differentiableVectorView),
-    defaultOptimizer: makeWeightDecayedAdam(
-        learningRate: peakLearningRate,
-        beta1: beta1,
-        beta2: beta2
-    )
+  for: bertClassifier,
+  TensorVisitorPlan(bertClassifier.differentiableVectorView),
+  defaultOptimizer: makeWeightDecayedAdam(
+    learningRate: peakLearningRate,
+    beta1: beta1,
+    beta2: beta2
+  )
 )
 
-var scheduledLearningRate = LinearlyDecayedParameter(
-  baseParameter: LinearlyWarmedUpParameter(
+/// Computes the sigmoidCrossEntropy loss from `logits` and `labels`.
+///
+/// This is the loss function used by the TrainingLoop; it wraps the standard
+/// sigmoidCrossEntropy, reshaping the logits to the required shape before
+/// calling it.
+@differentiable
+public func sigmoidCrossEntropyReshaped<Scalar>(logits: Tensor<Scalar>, labels: Tensor<Int32>)
+  -> Tensor<Scalar> where Scalar: TensorFlowFloatingPoint
+{
+  return sigmoidCrossEntropy(
+    logits: logits.squeezingShape(at: -1),
+    labels: Tensor<Scalar>(labels),
+    reduction: _mean)
+}
+
+/// Clips the gradients by global norm.
+///
+/// This is defined as a callback registered into the TrainingLoop.
+func clipGradByGlobalNorm<L: TrainingLoopProtocol>(_ loop: inout L, event: TrainingLoopEvent) throws
+{
+  if event == .updateStart {
+    var gradients = loop.lastStepGradient!
+    gradients.clipByGlobalNorm(clipNorm: 1)
+    loop.lastStepGradient = gradients
+  }
+}
+
+/// Returns a LinearlyDecayedParameter whose first 10 steps are linearly warmed
+/// up; the remaining steps decay at a slope of -(peakLearningRate / `totalStepCount`).
+let scheduledParameterGetter = { (_ totalStepCount: Float) -> LinearlyDecayedParameter in
+  LinearlyDecayedParameter(
+    baseParameter: LinearlyWarmedUpParameter(
      baseParameter: FixedParameter<Float>(peakLearningRate),
      warmUpStepCount: 10,
      warmUpOffset: 0),
-  slope: -(peakLearningRate / Float(stepsPerEpoch * epochCount)),  // The LR decays linearly to zero.
-  startStep: 10
-)
-
-print("Training \(bertPretrained.name) for the CoLA task!")
-for (epoch, epochBatches) in cola.trainingEpochs.prefix(epochCount).enumerated() {
-  print("[Epoch \(epoch + 1)]")
-  Context.local.learningPhase = .training
-  var trainingLossSum: Float = 0
-  var trainingBatchCount = 0
-
-  for batch in epochBatches {
-    let (documents, labels) = (batch.data, Tensor<Float>(batch.label))
-    var (loss, gradients) = valueWithGradient(at: bertClassifier) { model -> Tensor<Float> in
-      let logits = model(documents)
-      return sigmoidCrossEntropy(
-        logits: logits.squeezingShape(at: -1),
-        labels: labels,
-        reduction: { $0.mean() })
-    }
-
-    trainingLossSum += loss.scalarized()
-    trainingBatchCount += 1
-    gradients.clipByGlobalNorm(clipNorm: 1)
-
-    let step = optimizer.step + 1  // for scheduled rates and bias correction, steps start at 1
-    optimizer.learningRate = scheduledLearningRate(forStep: UInt64(step))
-    if useBiasCorrection {
-      let step = Float(step)
-      optimizer.learningRate *= sqrtf(1 - powf(beta2, step)) / (1 - powf(beta1, step))
-    }
-
-    optimizer.update(&bertClassifier, along: gradients)
-    LazyTensorBarrier()
-
-    print(
-      """
-        Training loss: \(trainingLossSum / Float(trainingBatchCount))
-      """
-    )
-  }
-
-  Context.local.learningPhase = .inference
-  var devLossSum: Float = 0
-  var devBatchCount = 0
-  var devPredictedLabels = [Bool]()
-  var devGroundTruth = [Bool]()
-  for batch in cola.validationBatches {
-    let (documents, labels) = (batch.data, Tensor<Float>(batch.label))
-    let logits = bertClassifier(documents)
-    let loss = sigmoidCrossEntropy(
-      logits: logits.squeezingShape(at: -1),
-      labels: labels,
-      reduction: { $0.mean() }
-    )
-    devLossSum += loss.scalarized()
-    devBatchCount += 1
-
-    let predictedLabels = sigmoid(logits.squeezingShape(at: -1)) .>= 0.5
-    devPredictedLabels.append(contentsOf: predictedLabels.scalars)
-    devGroundTruth.append(contentsOf: labels.scalars.map { $0 == 1 })
-  }
+    slope: -(peakLearningRate / totalStepCount),  // The LR decays linearly to zero.
+    startStep: 10
+  )
+}
 
-  let mcc = matthewsCorrelationCoefficient(
-    predictions: devPredictedLabels,
-    groundTruth: devGroundTruth)
+var trainingLoop: TrainingLoop = TrainingLoop(
+  training: cola.trainingEpochs,
+  validation: cola.validationBatches,
+  optimizer: optimizer,
+  lossFunction: sigmoidCrossEntropyReshaped,
+  metrics: [.matthewsCorrelationCoefficient],
+  callbacks: [
+    clipGradByGlobalNorm,
+    LearningRateScheduler(
+      scheduledParameterGetter: scheduledParameterGetter,
+      biasCorrectionBeta1: beta1,
+      biasCorrectionBeta2: beta2).schedule
+  ])
 
-  print(
-    """
-      MCC: \(mcc)
-      Eval loss: \(devLossSum / Float(devBatchCount))
-    """
-  )
-}
+print("Training \(bertPretrained.name) for the CoLA task!")
+try! trainingLoop.fit(&bertClassifier, epochs: epochCount, on: device)
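The two callbacks registered above show the general shape of the mechanism: any function of the form `(inout L, TrainingLoopEvent) throws` can observe and mutate the loop. As a minimal sketch, assuming `TrainingLoopEvent` also has a `.validationStart` case (the commit itself only demonstrates `.updateStart`), a hypothetical logging callback could look like:

/// Hypothetical callback: prints a line whenever a validation pass begins.
/// Sketch only; `.validationStart` is assumed to exist on TrainingLoopEvent.
func announceValidation<L: TrainingLoopProtocol>(_ loop: inout L, event: TrainingLoopEvent) throws {
  if event == .validationStart {
    print("Starting validation pass.")
  }
}

// It would be registered alongside the others:
//   callbacks: [clipGradByGlobalNorm, announceValidation, ...]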

Models/Text/CMakeLists.txt (-2)
@@ -3,13 +3,11 @@ add_library(TextModels
   BERT.swift
   BERTClassifier.swift
   BERT/BERTCheckpointReader.swift
-  Evaluation.swift
   GPT2/CheckpointWriter.swift
   GPT2/GPT2.swift
   GPT2/TransformerLM.swift
   GPT2/Operators.swift
   GPT2/PythonCheckpointReader.swift
-  ScheduledParameters.swift
   TransformerBERT.swift
   Utilities.swift
   WeightDecayedAdam.swift

Models/Text/Evaluation.swift (-42)

This file was deleted.
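Evaluation.swift carried the `matthewsCorrelationCoefficient` function used by the old hand-rolled loop; with this commit the metric comes from TrainingLoop's built-in `.matthewsCorrelationCoefficient` instead. For reference, a minimal sketch of the statistic itself (the function below is illustrative, not the deleted file's code):

/// Matthews correlation coefficient over Boolean predictions and ground truth:
/// MCC = (TP*TN - FP*FN) / sqrt((TP+FP)(TP+FN)(TN+FP)(TN+FN)).
func matthewsCorrelation(predictions: [Bool], groundTruth: [Bool]) -> Float {
  var tp = 0.0, tn = 0.0, fp = 0.0, fn = 0.0
  for (predicted, actual) in zip(predictions, groundTruth) {
    switch (predicted, actual) {
    case (true, true): tp += 1
    case (true, false): fp += 1
    case (false, true): fn += 1
    case (false, false): tn += 1
    }
  }
  let denominator = ((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)).squareRoot()
  return denominator == 0 ? 0 : Float((tp * tn - fp * fn) / denominator)
}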
