@@ -55,13 +55,14 @@ def GeneratePackedInputResetMask(segment_id, is_reverse=False):
55
55
class IdentitySeqLayer(base_layer.BaseLayer):
  """A no-op sequence layer.

  The layer keeps no state of its own (`zero_state` returns an empty map) and
  `FPropFullSequence` hands its inputs back untouched.
  """

  def zero_state(self, theta, batch_size):
    """Returns the initial state, which is empty for this layer.

    Args:
      theta: Unused.
      batch_size: Unused.

    Returns:
      An empty `py_utils.NestedMap`.
    """
    # The arguments exist only to satisfy the shared method signature.
    del theta
    del batch_size
    return py_utils.NestedMap()

  def FPropFullSequence(self, theta, inputs, paddings):
    """Returns `inputs` unchanged.

    Args:
      theta: Unused.
      inputs: The value to pass through.
      paddings: Unused.

    Returns:
      The same `inputs` object that was passed in.
    """
    # The arguments exist only to satisfy the shared method signature.
    del theta
    del paddings
    return inputs
66
67
67
68
@@ -868,6 +869,7 @@ def zero_state(self,
868
869
Returns:
869
870
state0 - A `.NestedMap` containing initial states of RNN and attention.
870
871
"""
872
+ del atten_state_dim
871
873
872
874
p = self .params
873
875
atten = self .atten
@@ -894,6 +896,7 @@ def zero_state(self,
894
896
return state0
895
897
896
898
def reset_atten_state (self , theta , state , inputs ):
899
+ del theta
897
900
state .atten = inputs .reset_mask * state .atten
898
901
if isinstance (state .atten_state , py_utils .NestedMap ):
899
902
if 'inner' not in state .atten_state :
@@ -986,14 +989,21 @@ def CellFn(theta, state0, inputs):
986
989
py_utils .NestedMap (
987
990
act = act , padding = inputs .padding , reset_mask = inputs .reset_mask ))
988
991
992
+ query_segment_id = (
993
+ tf .cast (tf .squeeze (inputs .segment_id , 1 ), py_utils .FPropDtype (p ))
994
+ if p .packed_input
995
+ else None
996
+ )
997
+
989
998
state1 .atten , state1 .atten_probs , state1 .atten_state = (
990
999
self .atten .ComputeContextVectorWithSource (
991
1000
theta .atten ,
992
1001
theta .packed_src ,
993
1002
self .cell .GetOutput (state1 .rnn ),
994
1003
state0_mod .atten_state ,
995
- query_segment_id = tf .cast (
996
- tf .squeeze (inputs .segment_id , 1 ), py_utils .FPropDtype (p ))))
1004
+ query_segment_id = query_segment_id ,
1005
+ )
1006
+ )
997
1007
return state1 , py_utils .NestedMap ()
998
1008
999
1009
if p .packed_input :
0 commit comments