Skip to content

Commit 41c50c7

Browse files
lingvo-botcopybara-github
authored andcommitted
Don't supply query_segment_id when packed_input is disabled.
PiperOrigin-RevId: 696646145
1 parent cfd911c commit 41c50c7

File tree

1 file changed

+15
-5
lines changed

1 file changed

+15
-5
lines changed

lingvo/core/rnn_layers.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,13 +55,14 @@ def GeneratePackedInputResetMask(segment_id, is_reverse=False):
5555
class IdentitySeqLayer(base_layer.BaseLayer):
5656
"""A no-op sequence layer."""
5757

58-
def __init__(self, params):
59-
super().__init__(params)
60-
6158
def zero_state(self, theta, batch_size):
59+
del theta
60+
del batch_size
6261
return py_utils.NestedMap()
6362

6463
def FPropFullSequence(self, theta, inputs, paddings):
64+
del theta
65+
del paddings
6566
return inputs
6667

6768

@@ -868,6 +869,7 @@ def zero_state(self,
868869
Returns:
869870
state0 - A `.NestedMap` containing initial states of RNN and attention.
870871
"""
872+
del atten_state_dim
871873

872874
p = self.params
873875
atten = self.atten
@@ -894,6 +896,7 @@ def zero_state(self,
894896
return state0
895897

896898
def reset_atten_state(self, theta, state, inputs):
899+
del theta
897900
state.atten = inputs.reset_mask * state.atten
898901
if isinstance(state.atten_state, py_utils.NestedMap):
899902
if 'inner' not in state.atten_state:
@@ -986,14 +989,21 @@ def CellFn(theta, state0, inputs):
986989
py_utils.NestedMap(
987990
act=act, padding=inputs.padding, reset_mask=inputs.reset_mask))
988991

992+
query_segment_id = (
993+
tf.cast(tf.squeeze(inputs.segment_id, 1), py_utils.FPropDtype(p))
994+
if p.packed_input
995+
else None
996+
)
997+
989998
state1.atten, state1.atten_probs, state1.atten_state = (
990999
self.atten.ComputeContextVectorWithSource(
9911000
theta.atten,
9921001
theta.packed_src,
9931002
self.cell.GetOutput(state1.rnn),
9941003
state0_mod.atten_state,
995-
query_segment_id=tf.cast(
996-
tf.squeeze(inputs.segment_id, 1), py_utils.FPropDtype(p))))
1004+
query_segment_id=query_segment_id,
1005+
)
1006+
)
9971007
return state1, py_utils.NestedMap()
9981008

9991009
if p.packed_input:

0 commit comments

Comments
 (0)