@@ -114,8 +114,8 @@ public struct SNLM: EuclideanDifferentiable, KeyPathIterable {
 
   // MARK: - Encode
   /// Returns the hidden states of the encoder LSTM applied to the given sentence.
-  public func encode(_ x: CharacterSequence) -> [Tensor<Float>] {
-    var embedded = encoderEmbedding(x.tensor)
+  public func encode(_ x: CharacterSequence, device: Device) -> [Tensor<Float>] {
+    var embedded = encoderEmbedding(x.tensor(device: device))
     embedded = dropout(embedded)
     let encoderStates = encoderLSTM(embedded.unstacked().differentiableMap { $0.rankLifted() })
     var encoderResult = Tensor(
@@ -126,7 +126,9 @@ public struct SNLM: EuclideanDifferentiable, KeyPathIterable {
 
   // MARK: - Decode
   /// Returns log probabilities for each of the candidates.
-  public func decode(_ candidates: [CharacterSequence], _ state: Tensor<Float>) -> Tensor<Float> {
+  public func decode(_ candidates: [CharacterSequence], _ state: Tensor<Float>, device: Device)
+    -> Tensor<Float>
+  {
     // TODO(TF-433): Remove closure workaround when autodiff supports non-active rethrowing
     // functions (`Array.map`).
     let maxLen = { candidates.map { $0.count }.max()! + 1 }()
@@ -148,21 +150,25 @@ public struct SNLM: EuclideanDifferentiable, KeyPathIterable {
 
     // Shapes are [time x batch] so that we can unstack the time dimension into the array that
     // the LSTM wants as input.
-    let x: Tensor<Int32> = Tensor(shape: [candidates.count, maxLen], scalars: xBatch).transposed()
-    let y: Tensor<Int32> = Tensor(shape: [candidates.count, maxLen], scalars: yBatch).transposed()
+    let x: Tensor<Int32> = Tensor(
+      shape: [candidates.count, maxLen], scalars: xBatch, on: device
+    ).transposed()
+    let y: Tensor<Int32> = Tensor(
+      shape: [candidates.count, maxLen], scalars: yBatch, on: device
+    ).transposed()
 
     // [time x batch x ndim]
     var embeddedX = decoderEmbedding(x)
     embeddedX = dropout(embeddedX)
 
     // [batch x ndim]
-    let stateBatch = state.rankLifted().tiled(multiples: Tensor([Int32(candidates.count), 1]))
+    let stateBatch = state.rankLifted().tiled(multiples: [candidates.count, 1])
 
     // [time] array of LSTM states whose `hidden` and `cell` fields have shape [batch x ndim]
     let decoderStates = decoderLSTM(
       embeddedX.unstacked(),
       initialState: LSTMCell.State(
-        cell: Tensor(zeros: stateBatch.shape),
+        cell: Tensor(zeros: stateBatch.shape, on: device),
         hidden: stateBatch))
 
     // [time x batch x ndim]
@@ -183,7 +189,11 @@ public struct SNLM: EuclideanDifferentiable, KeyPathIterable {
     ).reshaped(to: y.shape)
 
     // [time x batch]
-    let logpExcludingPad = logp * Tensor<Float>(y .!= parameters.chrVocab.pad)
+    let padScalars = [Int32](repeating: parameters.chrVocab.pad, count: candidates.count * maxLen)
+    let noPad = Tensor<Int32>(
+      y .!= Tensor(shape: y.shape, scalars: padScalars, on: device))
+    let noPadFloat = Tensor<Float>(noPad)
+    let logpExcludingPad = logp * noPadFloat
 
     // [batch]
     let candidateLogP = logpExcludingPad.transposed().sum(squeezingAxes: 1)
@@ -200,9 +210,9 @@ public struct SNLM: EuclideanDifferentiable, KeyPathIterable {
   }
 
   @differentiable
-  public func buildLattice(_ sentence: CharacterSequence, maxLen: Int) -> Lattice {
+  public func buildLattice(_ sentence: CharacterSequence, maxLen: Int, device: Device) -> Lattice {
     var lattice = Lattice(count: sentence.count)
-    let states = encode(sentence)
+    let states = encode(sentence, device: device)
     let logg_batch = mlpInterpolation(Tensor(stacking: states))
     let logp_lex_batch = mlpMemory(Tensor(stacking: states))
     for pos in 0..<sentence.count {
@@ -225,9 +235,10 @@ public struct SNLM: EuclideanDifferentiable, KeyPathIterable {
       }
 
       let current_state = states[pos]
-      let logg = logg_batch[pos].scalarsADHack // [2]
-      let logp_lex = logp_lex_batch[pos].scalarsADHack // [strVocab.chr.count]
-      let logp_chr = decode(candidates, current_state).scalarsADHack // [candidates.count]
+      let logg = logg_batch[pos].scalarsADHack(device: device) // [2]
+      let logp_lex = logp_lex_batch[pos].scalarsADHack(device: device) // [strVocab.chr.count]
+      let logp_chr = decode(candidates, current_state, device: device)
+        .scalarsADHack(device: device) // [candidates.count]
      if pos != 0 {
        // Cleanup: lattice[pos].recomputeSemiringScore()
        var updatedNode = lattice[pos]
@@ -315,23 +326,22 @@ extension Tensor {
   // (`Differentiable.zeroTangentVectorInitializer`) instead of static zeros
   // (`AdditiveArithmetic.zero`).
   @differentiable(where Scalar: TensorFlowFloatingPoint)
-  var scalarsADHack: [Scalar] {
+  func scalarsADHack(device: Device) -> [Scalar] {
     scalars
   }
 
   @derivative(of: scalarsADHack)
-  func vjpScalarsADHack() -> (
+  func vjpScalarsADHack(device: Device) -> (
     value: [Scalar], pullback: (Array<Scalar>.TangentVector) -> Tensor
   ) where Scalar: TensorFlowFloatingPoint {
     // In the pullback: capture only `self.shape`, not all of `self`.
     let shape = self.shape
     func pullback(_ tv: Array<Scalar>.TangentVector) -> Tensor {
       if tv.count == 0 {
-        return Tensor(zeros: shape)
+        return Tensor(zeros: shape, on: device)
       }
-      return Tensor(shape: shape, scalars: tv.base)
+      return Tensor(shape: shape, scalars: tv.base, on: device)
     }
     return (scalars, pullback)
   }
 }
-
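
For context on the call-site change, a minimal sketch of threading a `Device` through the updated entry points; the `model`, `sentence`, `maxLen`, and `Device.default` values below are illustrative assumptions, not part of this commit:

  // Illustrative sketch only: assumes an existing `model: SNLM`,
  // a `sentence: CharacterSequence`, and a `maxLen: Int` defined elsewhere.
  let device = Device.default
  let states = model.encode(sentence, device: device)  // encoder hidden states placed on `device`
  let lattice = model.buildLattice(sentence, maxLen: maxLen, device: device)

The same `device` argument is what `decode` and `scalarsADHack(device:)` use internally when materializing new tensors, so the lattice computation does not silently fall back to the default device.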