Skip to content

Commit e18f28e

Browse files
lingvo-botcopybara-github
authored andcommitted
Use tf.strided_slice (through Python indexing syntax) instead of tf.slice to avoid a small subgraph calculating shapes, but also use tf.reshape to avoid dynamically-shaped state tensors in the output.
PiperOrigin-RevId: 678804868
1 parent f671a1a commit e18f28e

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

lingvo/core/conv_layers_with_time_padding.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -584,7 +584,6 @@ def StreamStep(self, theta, inputs, paddings, state0):
584584

585585
with tf.name_scope(p.name):
586586
inputs = py_utils.HasShape(inputs, [-1, -1, 1, p.filter_shape[2]])
587-
q = py_utils.GetShape(inputs)[1]
588587

589588
if paddings is not None:
590589
paddings = py_utils.HasShape(paddings, py_utils.GetShape(inputs)[:2])
@@ -600,8 +599,9 @@ def StreamStep(self, theta, inputs, paddings, state0):
600599
padding='VALID')
601600
if p.bias:
602601
outputs = tf.nn.bias_add(outputs, theta.b)
603-
new_context = tf.slice(concat_inputs, [0, q, 0, 0],
604-
tf.shape(state0.context))
602+
state0_context_shape = py_utils.GetShape(state0.context)
603+
new_context = concat_inputs[:, -state0_context_shape[1] :]
604+
new_context = tf.reshape(new_context, state0_context_shape)
605605
return outputs, paddings, py_utils.NestedMap(context=new_context)
606606

607607

@@ -801,7 +801,6 @@ def StreamStep(
801801

802802
with tf.name_scope(p.name):
803803
inputs = py_utils.HasShape(inputs, [-1, -1, 1, p.filter_shape[2]])
804-
q = py_utils.GetShape(inputs)[1]
805804

806805
if paddings is not None:
807806
paddings = py_utils.HasShape(paddings, py_utils.GetShape(inputs)[:2])
@@ -825,7 +824,8 @@ def StreamStep(
825824
if p.bias:
826825
outputs = tf.nn.bias_add(outputs, theta.b)
827826
state0_context_shape = py_utils.GetShape(state0_context)
828-
new_context = tf.slice(concat_inputs, [0, q, 0, 0], state0_context_shape)
827+
new_context = concat_inputs[:, -state0_context_shape[1] :]
828+
new_context = tf.reshape(new_context, state0_context_shape)
829829
if p.time_alignment is not None:
830830
time_size = self._get_time_size()
831831
aligned_time_size = self._get_aligned_time_size()

0 commit comments

Comments
 (0)