
Commit def2eb4

Make WordSeg benchmarks support x10 as an option (#593)
Also convert the WordSeg example to x10.
1 parent 709c89e commit def2eb4
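
This commit threads an explicit `Device` through the WordSeg model, benchmarks, and example. As a rough sketch of what the x10 option amounts to on the caller's side (the `useX10` flag below is purely illustrative and not something this commit adds), picking the backend is a matter of choosing a device and moving the model onto it:

```swift
import TensorFlow

// Illustrative only: choose the eager backend or the X10 (XLA) backend.
let useX10 = true
let device = useX10 ? Device.defaultXLA : Device.default

var model = SNLM(parameters: modelParameters)  // `modelParameters` as defined in Examples/WordSeg/main.swift
model.move(to: device)                         // place the model's parameters on that device
```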

5 files changed, +60 -43 lines changed


Benchmarks/Models/WordSeg.swift

Lines changed: 11 additions & 11 deletions
@@ -78,9 +78,9 @@ let maximumSequenceLength = 18
 struct WordSegBenchmark: Benchmark {
   let batchSize: Int
   let duration: BenchmarkDuration
-  let operation: (SNLM, CharacterSequence) -> ()
+  let operation: (SNLM, CharacterSequence, Device) -> ()
 
-  init(settings: BenchmarkSettings, operation: @escaping (SNLM, CharacterSequence) -> ()) {
+  init(settings: BenchmarkSettings, operation: @escaping (SNLM, CharacterSequence, Device) -> ()) {
     self.duration = settings.duration
     self.batchSize = settings.batchSize
     self.operation = operation
@@ -129,7 +129,7 @@ struct WordSegBenchmark: Benchmark {
     }
 
     for _ in 0..<iterations {
-      operation(model, sentence)
+      operation(model, sentence, device)
       LazyTensorBarrier()
 
       batchTimings.append(durationInMilliseconds(since: beforeBatch))
@@ -152,26 +152,26 @@ extension WordSegBenchmark {
     return try CharacterSequence(alphabet: alphabet, appendingEoSTo: truncatedSentence)
   }
 
-  static func score(model: SNLM, sentence: CharacterSequence) {
-    let lattice = model.buildLattice(sentence, maxLen: maximumSequenceLength)
+  static func score(model: SNLM, sentence: CharacterSequence, device: Device) {
+    let lattice = model.buildLattice(sentence, maxLen: maximumSequenceLength, device: device)
     let score = lattice[sentence.count].semiringScore
     let _ = score.logr + score.logp
   }
 
-  static func scoreAndGradient(model: SNLM, sentence: CharacterSequence) {
+  static func scoreAndGradient(model: SNLM, sentence: CharacterSequence, device: Device) {
     let lambd: Float = 0.00075
 
-    let _ = valueWithGradient(at: model) { model -> Float in
-      let lattice = model.buildLattice(sentence, maxLen: maximumSequenceLength)
+    let _ = valueWithGradient(at: model) { model -> Tensor<Float> in
+      let lattice = model.buildLattice(sentence, maxLen: maximumSequenceLength, device: device)
       let score = lattice[sentence.count].semiringScore
       let expectedLength = exp(score.logr - score.logp)
       let loss = -1 * score.logp + lambd * expectedLength
-      return loss
+      return Tensor(loss, on: device)
     }
   }
 
-  static func viterbi(model: SNLM, sentence: CharacterSequence) {
-    var lattice = model.buildLattice(sentence, maxLen: maximumSequenceLength)
+  static func viterbi(model: SNLM, sentence: CharacterSequence, device: Device) {
+    var lattice = model.buildLattice(sentence, maxLen: maximumSequenceLength, device: device)
     let _ = lattice.viterbi(sentence: sentence)
   }
 }
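
With the added `Device` parameter, the three benchmark operations can be driven directly against either backend. A minimal usage sketch, assuming `model` has already been built and moved to `device` and `sentence` prepared the way the benchmark's setup code does:

```swift
let device = Device.defaultXLA  // or Device.default for the eager backend

WordSegBenchmark.score(model: model, sentence: sentence, device: device)
WordSegBenchmark.scoreAndGradient(model: model, sentence: sentence, device: device)
WordSegBenchmark.viterbi(model: model, sentence: sentence, device: device)
LazyTensorBarrier()  // flush pending X10 work, as the timing loop above does
```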

Examples/WordSeg/main.swift

Lines changed: 11 additions & 7 deletions
@@ -66,9 +66,13 @@ let modelParameters = SNLM.Parameters(
   order: order
 )
 
+let device = Device.defaultXLA
+
 var model = SNLM(parameters: modelParameters)
+model.move(to: device)
 
-let optimizer = Adam(for: model, learningRate: learningRate)
+var optimizer = Adam(for: model, learningRate: learningRate)
+optimizer = Adam(copying: optimizer, to: device)
 
 print("Starting training...")
 
@@ -78,18 +82,18 @@ for epoch in 1...maxEpochs {
   var trainingBatchCount = 0
   for record in dataset.training {
     let sentence = record.numericalizedText
-    let (loss, gradients) = valueWithGradient(at: model) { model -> Float in
-      let lattice = model.buildLattice(sentence, maxLen: maxLength)
+    let (loss, gradients) = valueWithGradient(at: model) { model -> Tensor<Float> in
+      let lattice = model.buildLattice(sentence, maxLen: maxLength, device: device)
       let score = lattice[sentence.count].semiringScore
       let expectedLength = exp(score.logr - score.logp)
       let loss = -1 * score.logp + lambd * expectedLength
-      return loss
+      return Tensor(loss, on: device)
     }
 
-    trainingLossSum += loss
+    trainingLossSum += loss.scalarized()
     trainingBatchCount += 1
     optimizer.update(&model, along: gradients)
-
+    LazyTensorBarrier()
     if hasNaN(gradients) {
       print("Warning: grad has NaN")
     }
@@ -129,7 +133,7 @@ for epoch in 1...maxEpochs {
   var validationPlainText: String = ""
   for record in validationDataset {
     let sentence = record.numericalizedText
-    var lattice = model.buildLattice(sentence, maxLen: maxLength)
+    var lattice = model.buildLattice(sentence, maxLen: maxLength, device: device)
     let score = lattice[sentence.count].semiringScore
 
     validationLossSum -= score.logp
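
Pulled together, the x10 conversion of the example follows a small recipe: put the model and optimizer state on the device, keep the loss on the device as a `Tensor<Float>`, and cut the lazy trace once per step. A condensed sketch of one training step, assuming `sentence`, `maxLength`, `lambd`, and `learningRate` are in scope as in main.swift:

```swift
import TensorFlow

let device = Device.defaultXLA

var model = SNLM(parameters: modelParameters)
model.move(to: device)                              // model parameters live on the X10 device

var optimizer = Adam(for: model, learningRate: learningRate)
optimizer = Adam(copying: optimizer, to: device)    // optimizer state must live there too

let (loss, gradients) = valueWithGradient(at: model) { model -> Tensor<Float> in
  let lattice = model.buildLattice(sentence, maxLen: maxLength, device: device)
  let score = lattice[sentence.count].semiringScore
  let expectedLength = exp(score.logr - score.logp)
  let loss = -1 * score.logp + lambd * expectedLength
  return Tensor(loss, on: device)                   // keep the loss as a device tensor
}
optimizer.update(&model, along: gradients)
LazyTensorBarrier()                                  // materialize the traced step
let hostLoss = loss.scalarized()                     // copy the scalar back to the host for logging
```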

Models/Text/WordSeg/Model.swift

Lines changed: 28 additions & 18 deletions
@@ -114,8 +114,8 @@ public struct SNLM: EuclideanDifferentiable, KeyPathIterable {
 
   // MARK: - Encode
   /// Returns the hidden states of the encoder LSTM applied to the given sentence.
-  public func encode(_ x: CharacterSequence) -> [Tensor<Float>] {
-    var embedded = encoderEmbedding(x.tensor)
+  public func encode(_ x: CharacterSequence, device: Device) -> [Tensor<Float>] {
+    var embedded = encoderEmbedding(x.tensor(device: device))
     embedded = dropout(embedded)
     let encoderStates = encoderLSTM(embedded.unstacked().differentiableMap { $0.rankLifted() })
     var encoderResult = Tensor(
@@ -126,7 +126,9 @@ public struct SNLM: EuclideanDifferentiable, KeyPathIterable {
 
   // MARK: - Decode
   /// Returns log probabilities for each of the candidates.
-  public func decode(_ candidates: [CharacterSequence], _ state: Tensor<Float>) -> Tensor<Float> {
+  public func decode(_ candidates: [CharacterSequence], _ state: Tensor<Float>, device: Device)
+    -> Tensor<Float>
+  {
     // TODO(TF-433): Remove closure workaround when autodiff supports non-active rethrowing
     // functions (`Array.map`).
     let maxLen = { candidates.map { $0.count }.max()! + 1 }()
@@ -148,21 +150,25 @@ public struct SNLM: EuclideanDifferentiable, KeyPathIterable {
 
     // Shapes are [time x batch] so that we can unstack the time dimension into the array that
     // the LSTM wants as input.
-    let x: Tensor<Int32> = Tensor(shape: [candidates.count, maxLen], scalars: xBatch).transposed()
-    let y: Tensor<Int32> = Tensor(shape: [candidates.count, maxLen], scalars: yBatch).transposed()
+    let x: Tensor<Int32> = Tensor(
+      shape: [candidates.count, maxLen], scalars: xBatch, on: device
+    ).transposed()
+    let y: Tensor<Int32> = Tensor(
+      shape: [candidates.count, maxLen], scalars: yBatch, on: device
+    ).transposed()
 
     // [time x batch x ndim]
     var embeddedX = decoderEmbedding(x)
     embeddedX = dropout(embeddedX)
 
     // [batch x ndim]
-    let stateBatch = state.rankLifted().tiled(multiples: Tensor([Int32(candidates.count), 1]))
+    let stateBatch = state.rankLifted().tiled(multiples: [candidates.count, 1])
 
     // [time] array of LSTM states whose `hidden` and `cell` fields have shape [batch x ndim]
     let decoderStates = decoderLSTM(
       embeddedX.unstacked(),
       initialState: LSTMCell.State(
-        cell: Tensor(zeros: stateBatch.shape),
+        cell: Tensor(zeros: stateBatch.shape, on: device),
         hidden: stateBatch))
 
     // [time x batch x ndim]
@@ -183,7 +189,11 @@ public struct SNLM: EuclideanDifferentiable, KeyPathIterable {
     ).reshaped(to: y.shape)
 
     // [time x batch]
-    let logpExcludingPad = logp * Tensor<Float>(y .!= parameters.chrVocab.pad)
+    let padScalars = [Int32](repeating: parameters.chrVocab.pad, count: candidates.count * maxLen)
+    let noPad = Tensor<Int32>(
+      y .!= Tensor(shape: y.shape, scalars: padScalars, on: device))
+    let noPadFloat = Tensor<Float>(noPad)
+    let logpExcludingPad = logp * noPadFloat
 
     // [batch]
     let candidateLogP = logpExcludingPad.transposed().sum(squeezingAxes: 1)
@@ -200,9 +210,9 @@ public struct SNLM: EuclideanDifferentiable, KeyPathIterable {
   }
 
   @differentiable
-  public func buildLattice(_ sentence: CharacterSequence, maxLen: Int) -> Lattice {
+  public func buildLattice(_ sentence: CharacterSequence, maxLen: Int, device: Device) -> Lattice {
     var lattice = Lattice(count: sentence.count)
-    let states = encode(sentence)
+    let states = encode(sentence, device: device)
     let logg_batch = mlpInterpolation(Tensor(stacking: states))
     let logp_lex_batch = mlpMemory(Tensor(stacking: states))
     for pos in 0..<sentence.count {
@@ -225,9 +235,10 @@ public struct SNLM: EuclideanDifferentiable, KeyPathIterable {
       }
 
      let current_state = states[pos]
-      let logg = logg_batch[pos].scalarsADHack  // [2]
-      let logp_lex = logp_lex_batch[pos].scalarsADHack  // [strVocab.chr.count]
-      let logp_chr = decode(candidates, current_state).scalarsADHack  // [candidates.count]
+      let logg = logg_batch[pos].scalarsADHack(device: device)  // [2]
+      let logp_lex = logp_lex_batch[pos].scalarsADHack(device: device)  // [strVocab.chr.count]
+      let logp_chr = decode(candidates, current_state, device: device)
+        .scalarsADHack(device: device)  // [candidates.count]
       if pos != 0 {
         // Cleanup: lattice[pos].recomputeSemiringScore()
         var updatedNode = lattice[pos]
@@ -315,23 +326,22 @@ extension Tensor {
   // (`Differentiable.zeroTangentVectorInitializer`) instead of static zeros
   // (`AdditiveArithmetic.zero`).
   @differentiable(where Scalar: TensorFlowFloatingPoint)
-  var scalarsADHack: [Scalar] {
+  func scalarsADHack(device: Device) -> [Scalar] {
     scalars
   }
 
   @derivative(of: scalarsADHack)
-  func vjpScalarsADHack() -> (
+  func vjpScalarsADHack(device: Device) -> (
     value: [Scalar], pullback: (Array<Scalar>.TangentVector) -> Tensor
   ) where Scalar: TensorFlowFloatingPoint {
     // In the pullback: capture only `self.shape`, not all of `self`.
     let shape = self.shape
     func pullback(_ tv: Array<Scalar>.TangentVector) -> Tensor {
       if tv.count == 0 {
-        return Tensor(zeros: shape)
+        return Tensor(zeros: shape, on: device)
       }
-      return Tensor(shape: shape, scalars: tv.base)
+      return Tensor(shape: shape, scalars: tv.base, on: device)
     }
     return (scalars, pullback)
   }
 }
-
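
The recurring theme inside the model is that every tensor an operation creates must name the device, so the X10 trace never silently falls back to the default one. A small illustration of the initializers this diff leans on; the shapes and values below are made up for the sketch:

```swift
import TensorFlow

let device = Device.defaultXLA

// Index tensors built from host scalars go straight onto the device.
let y = Tensor<Int32>(shape: [2, 3], scalars: [3, 1, 2, 0, 4, 0], on: device)

// A padding mask built the way `decode` now does it: compare against a
// same-shaped tensor of pad ids created on the same device, then cast.
let padId: Int32 = 0
let padScalars = [Int32](repeating: padId, count: 2 * 3)
let noPad = Tensor<Float>(Tensor<Int32>(y .!= Tensor(shape: y.shape, scalars: padScalars, on: device)))

// Zero tensors (for example the LSTM's initial cell state) also name the device explicitly.
let zeroState = Tensor<Float>(zeros: [2, 5], on: device)
```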

Support/Text/WordSeg/CharacterSequence.swift

Lines changed: 4 additions & 3 deletions
@@ -55,11 +55,12 @@ public struct CharacterSequence: Hashable {
     return characters[range]
   }
 
+  public func tensor(device: Device) -> Tensor<Int32> {
+    Tensor<Int32>([self.eos] + characters[0..<characters.count - 1], on: device)
+  }
+
   public var count: Int { return characters.count }
   public var last: Int32? { return characters.last }
-  public var tensor: Tensor<Int32> {
-    Tensor<Int32>([self.eos] + characters[0..<characters.count - 1])
-  }
 }
 
 extension CharacterSequence: CustomStringConvertible {
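
Call sites that used the old computed `tensor` property now pass the device explicitly. A small usage sketch, assuming `alphabet` is the corpus alphabet loaded elsewhere and script-level `try` is acceptable:

```swift
let sentence = try CharacterSequence(alphabet: alphabet, appendingEoSTo: "abab")
// Before: sentence.tensor                  (always materialized on the default device)
// After:  sentence.tensor(device:)         (placed wherever the caller asks)
let ids = sentence.tensor(device: Device.defaultXLA)  // Tensor<Int32> on the X10 device
```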

Tests/TextTests/WordSegmentationTests/ProbeLayers.swift

Lines changed: 6 additions & 4 deletions
@@ -155,10 +155,11 @@ class WordSegProbeLayerTests: XCTestCase {
         order: 5))
 
     model.setParameters(Example1.parameters)
+    let device = Device.default
 
     print("Encoding")
     let encoderStates = model.encode(
-      CharacterSequence(alphabet: chrVocab, characters: [0, 1, 0, 1]))  // "abab"
+      CharacterSequence(alphabet: chrVocab, characters: [0, 1, 0, 1]), device: device)  // "abab"
     let encoderStatesTensor = Tensor(stacking: encoderStates)
     print("Expected: \(Example1.expectedEncoding)")
     print("Actual: \(encoderStatesTensor)")
@@ -187,7 +188,8 @@ class WordSegProbeLayerTests: XCTestCase {
       CharacterSequence(alphabet: chrVocab, characters: [0, 0, 0]),  // "aaa"
       CharacterSequence(alphabet: chrVocab, characters: [0, 1]),  // "ab"
       ],
-      encoderStates[0]
+      encoderStates[0],
+      device: device
     )
     print("Expected: \(Example1.expectedDecoded)")
     print("Actual: \(decoded)")
@@ -196,12 +198,12 @@ class WordSegProbeLayerTests: XCTestCase {
 
     print("Build Lattice")
     let abab = CharacterSequence(alphabet: chrVocab, characters: [0, 1, 0, 1])
-    let lattice = model.buildLattice(abab, maxLen: 5)
+    let lattice = model.buildLattice(abab, maxLen: 5, device: device)
     XCTAssert(lattice.isAlmostEqual(to: Example1.lattice, tolerance: 1e-5))
 
     print("Gradient")
     func f(_ x: SNLM) -> Float {
-      x.buildLattice(abab, maxLen: 5)[4].semiringScore.logr
+      x.buildLattice(abab, maxLen: 5, device: device)[4].semiringScore.logr
     }
     let (_, grad) = valueWithGradient(at: model, in: f)
     let expectedGrad = tangentVector(from: Example1.gradWrtLogR, model: model)
