Skip to content

Commit

Permalink
Use tf.strided_slice (through Python indexing syntax) instead of `t…
Browse files Browse the repository at this point in the history
…f.slice` to avoid a small subgraph calculating shapes, but also use `tf.reshape` to avoid dynamically-shaped state tensors in the output.

PiperOrigin-RevId: 678804868
  • Loading branch information
lingvo-bot authored and copybara-github committed Sep 25, 2024
1 parent f671a1a commit e18f28e
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions lingvo/core/conv_layers_with_time_padding.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,7 +584,6 @@ def StreamStep(self, theta, inputs, paddings, state0):

with tf.name_scope(p.name):
inputs = py_utils.HasShape(inputs, [-1, -1, 1, p.filter_shape[2]])
q = py_utils.GetShape(inputs)[1]

if paddings is not None:
paddings = py_utils.HasShape(paddings, py_utils.GetShape(inputs)[:2])
Expand All @@ -600,8 +599,9 @@ def StreamStep(self, theta, inputs, paddings, state0):
padding='VALID')
if p.bias:
outputs = tf.nn.bias_add(outputs, theta.b)
new_context = tf.slice(concat_inputs, [0, q, 0, 0],
tf.shape(state0.context))
state0_context_shape = py_utils.GetShape(state0.context)
new_context = concat_inputs[:, -state0_context_shape[1] :]
new_context = tf.reshape(new_context, state0_context_shape)
return outputs, paddings, py_utils.NestedMap(context=new_context)


Expand Down Expand Up @@ -801,7 +801,6 @@ def StreamStep(

with tf.name_scope(p.name):
inputs = py_utils.HasShape(inputs, [-1, -1, 1, p.filter_shape[2]])
q = py_utils.GetShape(inputs)[1]

if paddings is not None:
paddings = py_utils.HasShape(paddings, py_utils.GetShape(inputs)[:2])
Expand All @@ -825,7 +824,8 @@ def StreamStep(
if p.bias:
outputs = tf.nn.bias_add(outputs, theta.b)
state0_context_shape = py_utils.GetShape(state0_context)
new_context = tf.slice(concat_inputs, [0, q, 0, 0], state0_context_shape)
new_context = concat_inputs[:, -state0_context_shape[1] :]
new_context = tf.reshape(new_context, state0_context_shape)
if p.time_alignment is not None:
time_size = self._get_time_size()
aligned_time_size = self._get_aligned_time_size()
Expand Down

0 comments on commit e18f28e

Please sign in to comment.