@@ -114,8 +114,8 @@ public struct SNLM: EuclideanDifferentiable, KeyPathIterable {
 
   // MARK: - Encode
   /// Returns the hidden states of the encoder LSTM applied to the given sentence.
-  public func encode(_ x: CharacterSequence) -> [Tensor<Float>] {
-    var embedded = encoderEmbedding(x.tensor)
+  public func encode(_ x: CharacterSequence, device: Device) -> [Tensor<Float>] {
+    var embedded = encoderEmbedding(x.tensor(device: device))
     embedded = dropout(embedded)
     let encoderStates = encoderLSTM(embedded.unstacked().differentiableMap { $0.rankLifted() })
     var encoderResult = Tensor(
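
For orientation, a rough usage sketch of the device-aware `encode` (the `model` and `sentence` bindings are hypothetical; the device choice is illustrative):

    // Hypothetical caller; assumes a constructed `model: SNLM` and a
    // `sentence: CharacterSequence` over the model's character vocabulary.
    let device = Device.defaultXLA  // or Device.defaultTFEager
    let states = model.encode(sentence, device: device)
    // `states` holds one hidden-state tensor per input character, on `device`.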
@@ -126,7 +126,9 @@ public struct SNLM: EuclideanDifferentiable, KeyPathIterable {
 
   // MARK: - Decode
   /// Returns log probabilities for each of the candidates.
-  public func decode(_ candidates: [CharacterSequence], _ state: Tensor<Float>) -> Tensor<Float> {
+  public func decode(_ candidates: [CharacterSequence], _ state: Tensor<Float>, device: Device)
+    -> Tensor<Float>
+  {
     // TODO(TF-433): Remove closure workaround when autodiff supports non-active rethrowing
     // functions (`Array.map`).
    let maxLen = { candidates.map { $0.count }.max()! + 1 }()
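
A hypothetical call, to make the new signature concrete (`candidates` and `states[pos]` are assumed to come from the lattice-building loop further down):

    // Log probability of each candidate segment given an encoder state.
    let logProbs = model.decode(candidates, states[pos], device: device)
    // Shape: [candidates.count].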
@@ -148,21 +150,25 @@ public struct SNLM: EuclideanDifferentiable, KeyPathIterable {
 
     // Shapes are [time x batch] so that we can unstack the time dimension into the array that
     // the LSTM wants as input.
-    let x: Tensor<Int32> = Tensor(shape: [candidates.count, maxLen], scalars: xBatch).transposed()
-    let y: Tensor<Int32> = Tensor(shape: [candidates.count, maxLen], scalars: yBatch).transposed()
+    let x: Tensor<Int32> = Tensor(
+      shape: [candidates.count, maxLen], scalars: xBatch, on: device
+    ).transposed()
+    let y: Tensor<Int32> = Tensor(
+      shape: [candidates.count, maxLen], scalars: yBatch, on: device
+    ).transposed()
 
     // [time x batch x ndim]
     var embeddedX = decoderEmbedding(x)
     embeddedX = dropout(embeddedX)
 
     // [batch x ndim]
-    let stateBatch = state.rankLifted().tiled(multiples: Tensor([Int32(candidates.count), 1]))
+    let stateBatch = state.rankLifted().tiled(multiples: [candidates.count, 1])
 
     // [time] array of LSTM states whose `hidden` and `cell` fields have shape [batch x ndim]
     let decoderStates = decoderLSTM(
       embeddedX.unstacked(),
       initialState: LSTMCell.State(
-        cell: Tensor(zeros: stateBatch.shape),
+        cell: Tensor(zeros: stateBatch.shape, on: device),
         hidden: stateBatch))
 
     // [time x batch x ndim]
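
Two device-related details in this hunk: `tiled(multiples:)` now takes a plain `[Int]`, keeping the tile counts on the host instead of wrapping them in a `Tensor<Int32>`, and the zero cell state is created directly on `device`. A toy sketch of the tiling step (shape and values are illustrative):

    // rankLifted() turns a [ndim] vector into [1 x ndim];
    // tiling by [batch, 1] repeats that row `batch` times -> [batch x ndim].
    let s = Tensor<Float>(shape: [3], scalars: [1, 2, 3], on: device)
    let batched = s.rankLifted().tiled(multiples: [4, 1])  // [4 x 3]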
@@ -183,7 +189,11 @@ public struct SNLM: EuclideanDifferentiable, KeyPathIterable {
     ).reshaped(to: y.shape)
 
     // [time x batch]
-    let logpExcludingPad = logp * Tensor<Float>(y .!= parameters.chrVocab.pad)
+    let padScalars = [Int32](repeating: parameters.chrVocab.pad, count: candidates.count * maxLen)
+    let noPad = Tensor<Int32>(
+      y .!= Tensor(shape: y.shape, scalars: padScalars, on: device))
+    let noPadFloat = Tensor<Float>(noPad)
+    let logpExcludingPad = logp * noPadFloat
 
     // [batch]
     let candidateLogP = logpExcludingPad.transposed().sum(squeezingAxes: 1)
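
The pad mask is now computed against an explicit same-shape tensor of pad IDs, so the comparison runs on `device`. A toy version of the masking arithmetic (the pad ID and values are made up):

    let pad: Int32 = 0  // hypothetical pad ID
    let y = Tensor<Int32>(shape: [2, 2], scalars: [5, 7, pad, pad], on: device)
    let pads = Tensor<Int32>(shape: y.shape, scalars: [pad, pad, pad, pad], on: device)
    let mask = Tensor<Float>(Tensor<Int32>(y .!= pads))  // [[1, 1], [0, 0]]
    // Multiplying log-probs by `mask` zeroes the contribution of pad positions.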
@@ -200,9 +210,9 @@ public struct SNLM: EuclideanDifferentiable, KeyPathIterable {
   }
 
   @differentiable
-  public func buildLattice(_ sentence: CharacterSequence, maxLen: Int) -> Lattice {
+  public func buildLattice(_ sentence: CharacterSequence, maxLen: Int, device: Device) -> Lattice {
     var lattice = Lattice(count: sentence.count)
-    let states = encode(sentence)
+    let states = encode(sentence, device: device)
     let logg_batch = mlpInterpolation(Tensor(stacking: states))
     let logp_lex_batch = mlpMemory(Tensor(stacking: states))
     for pos in 0..<sentence.count {
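
Since `buildLattice` threads `device` through both `encode` and `decode`, a caller only picks the device once (the bindings here are hypothetical):

    let lattice = model.buildLattice(sentence, maxLen: maxWordLength, device: device)
    // Intermediate tensors, including the zeros materialized during
    // differentiation, are placed on `device`.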
@@ -225,9 +235,10 @@ public struct SNLM: EuclideanDifferentiable, KeyPathIterable {
       }
 
       let current_state = states[pos]
-      let logg = logg_batch[pos].scalarsADHack  // [2]
-      let logp_lex = logp_lex_batch[pos].scalarsADHack  // [strVocab.chr.count]
-      let logp_chr = decode(candidates, current_state).scalarsADHack  // [candidates.count]
+      let logg = logg_batch[pos].scalarsADHack(device: device)  // [2]
+      let logp_lex = logp_lex_batch[pos].scalarsADHack(device: device)  // [strVocab.chr.count]
+      let logp_chr = decode(candidates, current_state, device: device)
+        .scalarsADHack(device: device)  // [candidates.count]
       if pos != 0 {
         // Cleanup: lattice[pos].recomputeSemiringScore()
         var updatedNode = lattice[pos]
@@ -315,23 +326,22 @@ extension Tensor {
   // (`Differentiable.zeroTangentVectorInitializer`) instead of static zeros
   // (`AdditiveArithmetic.zero`).
   @differentiable(where Scalar: TensorFlowFloatingPoint)
-  var scalarsADHack: [Scalar] {
+  func scalarsADHack(device: Device) -> [Scalar] {
     scalars
   }
 
   @derivative(of: scalarsADHack)
-  func vjpScalarsADHack() -> (
+  func vjpScalarsADHack(device: Device) -> (
     value: [Scalar], pullback: (Array<Scalar>.TangentVector) -> Tensor
   ) where Scalar: TensorFlowFloatingPoint {
     // In the pullback: capture only `self.shape`, not all of `self`.
     let shape = self.shape
     func pullback(_ tv: Array<Scalar>.TangentVector) -> Tensor {
       if tv.count == 0 {
-        return Tensor(zeros: shape)
+        return Tensor(zeros: shape, on: device)
       }
-      return Tensor(shape: shape, scalars: tv.base)
+      return Tensor(shape: shape, scalars: tv.base, on: device)
     }
     return (scalars, pullback)
   }
 }
-
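
The same device-aware pullback pattern, reduced to a self-contained sketch: a minimal free-function variant under the x10 `Device` API, not the exact code above. The names `toScalars` and `vjpToScalars` are made up for illustration.

    import TensorFlow

    // Differentiable tensor-to-host transfer whose pullback rebuilds the
    // gradient on the caller's device instead of the default one.
    @differentiable
    func toScalars<Scalar: TensorFlowFloatingPoint>(
      _ t: Tensor<Scalar>, device: Device
    ) -> [Scalar] {
      t.scalars
    }

    @derivative(of: toScalars)
    func vjpToScalars<Scalar: TensorFlowFloatingPoint>(
      _ t: Tensor<Scalar>, device: Device
    ) -> (value: [Scalar], pullback: (Array<Scalar>.TangentVector) -> Tensor<Scalar>) {
      let shape = t.shape  // capture only the shape, not the whole tensor
      func pullback(_ tv: Array<Scalar>.TangentVector) -> Tensor<Scalar> {
        if tv.count == 0 {
          // An empty array tangent means "zero": materialize zeros on `device`.
          return Tensor(zeros: shape, on: device)
        }
        return Tensor(shape: shape, scalars: tv.base, on: device)
      }
      return (t.scalars, pullback)
    }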