Commit 6066e1f
Rename variable names to reduce noise in followup CLs
PiperOrigin-RevId: 678824499
lingvo-bot authored and copybara-github committed Sep 25, 2024
1 parent e18f28e commit 6066e1f
Showing 1 changed file with 69 additions and 64 deletions.

lingvo/core/attention.py
@@ -57,11 +57,13 @@ def _ApplyAttentionDropout(params, x):
 
   if params.atten_dropout_deterministic:
     seeds = py_utils.GenerateStepSeedPair(params)
-    return py_utils.DeterministicDropout(x, 1.0 - params.atten_dropout_prob,
-                                         seeds)
+    return py_utils.DeterministicDropout(
+        x, 1.0 - params.atten_dropout_prob, seeds
+    )
   else:
     return tf.nn.dropout(
-        x, rate=params.atten_dropout_prob, seed=params.random_seed)
+        x, rate=params.atten_dropout_prob, seed=params.random_seed
+    )
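
The hunk above only reflows this call, but for context: the deterministic branch exists so the dropout mask is a pure function of an explicit seed pair. A minimal sketch of that idea using TensorFlow's stateless RNG (illustrative only; not the actual `py_utils.DeterministicDropout` implementation):

```python
import tensorflow as tf

def deterministic_dropout_sketch(x, keep_prob, seed_pair):
  # A stateless RNG returns the same values for the same (seed, shape),
  # so the dropout mask is reproducible, e.g. across recomputations.
  noise = tf.random.stateless_uniform(
      tf.shape(x), seed=seed_pair, dtype=x.dtype)
  keep_mask = tf.cast(noise < keep_prob, x.dtype)
  return x * keep_mask / keep_prob  # Inverted-dropout scaling.
```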


@@ -86,7 +88,9 @@ def SafeCumprod(x, *args, **kwargs):
   tiny = np.finfo(x.dtype.as_numpy_dtype).tiny
   return tf.exp(
       py_utils.CumSum(
-          tf.math.log(tf.clip_by_value(x, tiny, 1)), *args, **kwargs))
+          tf.math.log(tf.clip_by_value(x, tiny, 1)), *args, **kwargs
+      )
+  )
 
 
 # pyformat: disable
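
`SafeCumprod` computes the cumulative product in log space: clipping to the smallest positive float keeps `log` finite, and a cumulative sum of logs avoids the underflow that a direct product of many small probabilities would hit. A plain-TF sketch of the same trick, with `tf.cumsum` standing in for lingvo's `py_utils.CumSum`:

```python
import numpy as np
import tensorflow as tf

def safe_cumprod_sketch(x, axis=0):
  # exp(cumsum(log(x))) == cumprod(x), but without directly multiplying
  # many small numbers together.
  tiny = np.finfo(x.dtype.as_numpy_dtype).tiny
  return tf.exp(
      tf.cumsum(tf.math.log(tf.clip_by_value(x, tiny, 1.0)), axis=axis))
```
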
@@ -1420,15 +1424,15 @@ def PackSource(self,
     Args:
       theta: A `.NestedMap` object containing weights' values of this layer and
         its children layers.
-      source_vecs: A tensor of shape [time, source_batch, source_dim].
-      source_contexts: A tensor of shape [time, source_batch, context_dim].
-      source_padding: A tensor of shape [time, source_batch].
-      source_segment_id: A tensor of shape [time, source_batch].
+      source_vecs: A tensor of shape [time_steps, batch_size, source_dim].
+      source_contexts: A tensor of shape [time_steps, batch_size, context_dim].
+      source_padding: A tensor of shape [time_steps, batch_size].
+      source_segment_id: A tensor of shape [time_steps, batch_size].
 
     Returns:
       A NestedMap representing packed src. It will have the same structure
-      as the one returned by the inner atten, except that source_batch will be
-      source_batch * num_heads.
+      as the one returned by the inner atten, except that batch_size will be
+      batch_size * num_heads.
     """
 
     p = self.params
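
To make the renamed Returns note concrete: packing folds the head dimension into the batch, so a [time_steps, batch_size, hidden] projection becomes [time_steps, batch_size * num_heads, hidden // num_heads]. A toy shape check (dimensions chosen arbitrarily for illustration):

```python
import tensorflow as tf

time_steps, batch_size, hidden_dim, num_heads = 7, 4, 16, 2
source_vecs = tf.zeros([time_steps, batch_size, hidden_dim])
# Splitting the last axis into num_heads chunks and folding them into the
# batch axis gives each head its own "example"; the inner single-headed
# attention then runs unchanged over the enlarged batch.
per_head = tf.reshape(source_vecs, [time_steps, -1, hidden_dim // num_heads])
print(per_head.shape)  # (7, 8, 8): batch axis is batch_size * num_heads.
```
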
@@ -1439,52 +1443,52 @@ def PackSource(self,
     assert p.query_dim == p.hidden_dim
     # Check input tensor shapes
     source_vecs = py_utils.HasRank(source_vecs, 3)
-    [time_steps, source_batch] = py_utils.GetShape(source_vecs, 2)
+    [time_steps, batch_size] = py_utils.GetShape(source_vecs, 2)
     if p.use_source_vec_as_attention_value:
       assert source_contexts is not None
       source_contexts = py_utils.HasShape(source_contexts,
-                                          [time_steps, source_batch, -1])
+                                          [time_steps, batch_size, -1])
     source_padding = py_utils.HasShape(source_padding,
-                                       [time_steps, source_batch])
+                                       [time_steps, batch_size])
     if source_segment_id is not None:
       source_segment_id = py_utils.HasShape(source_segment_id,
-                                            [time_steps, source_batch])
+                                            [time_steps, batch_size])
 
     with tf.name_scope('init__0'):
-      # source_projected shape [time * source_batch, hidden]
+      # source_vecs shape after (optional) projection is
+      # [time_steps, batch_size, hidden]
       if p.enable_source_proj:
         source_vecs, w_source_proj = self.ToAqtInputs(
             'source_proj',
             act=source_vecs,
             weight=theta.source_proj,
             w_feature_axis=-1)
         w_source_proj = self.QWeight(w_source_proj)
-        source_projected = self.QMatmul(source_vecs, w_source_proj)
-        source_projected = self.QAct('source_proj_matmul', source_projected)
-        source_projected = self.FromAqtMatmul('source_proj', source_projected)
+        source_vecs = self.QMatmul(source_vecs, w_source_proj)
+        source_vecs = self.QAct('source_proj_matmul', source_vecs)
+        source_vecs = self.FromAqtMatmul('source_proj', source_vecs)
         if p.use_bias:
-          source_projected = fns.qadd(
-              source_projected,
+          source_vecs = fns.qadd(
+              source_vecs,
              self.QWeight(theta.source_proj_b),
              qout_name='source_proj_add')
-      else:
-        source_projected = source_vecs
-      source_projected = gshard_utils.MeshSplit(
-          source_projected, p.device_mesh,
+      source_vecs = gshard_utils.MeshSplit(
+          source_vecs,
+          p.device_mesh,
          p.activation_split_dims_mapping)
     with tf.name_scope('init__1'):
       num_heads = p.num_attention_heads
-      # => [time, source_batch * num_heads, hidden / num_heads]
+      # => [time_steps, batch_size * num_heads, hidden / num_heads]
       [time_steps_vecs] = py_utils.GetShape(source_vecs, 1)
-      source_projected = tf.reshape(source_projected, [
-          time_steps_vecs, -1, symbolic.ToStatic(p.hidden_dim // num_heads)
-      ])
-      source_projected = gshard_utils.MeshSplit(source_projected, p.device_mesh,
-                                                p.activation_split_dims_mapping)
-      source_projected = self.ProcessProjectionVec(theta, source_projected,
-                                                   'source')
+      source_vecs = tf.reshape(
+          source_vecs,
+          [time_steps_vecs, -1, symbolic.ToStatic(p.hidden_dim // num_heads)])
+      source_vecs = gshard_utils.MeshSplit(source_vecs,
+                                           p.device_mesh,
+                                           p.activation_split_dims_mapping)
+      source_vecs = self.ProcessProjectionVec(theta, source_vecs, 'source')
       if p.use_source_vec_as_attention_value:
-        source_contexts_projected = source_projected
+        source_contexts = source_vecs
       else:
         if p.enable_ctx_pre_proj:
           source_contexts, w_ctx_proj = self.ToAqtInputs(
@@ -1493,50 +1497,51 @@ def PackSource(self,
               weight=theta.ctx_proj,
               w_feature_axis=-1)
           w_ctx_proj = self.QWeight(w_ctx_proj)
-          source_contexts_projected = self.QMatmul(source_contexts, w_ctx_proj)
-          source_contexts_projected = self.QAct('ctx_pre_proj_matmul',
-                                                source_contexts_projected)
-          source_contexts_projected = self.FromAqtMatmul(
-              'ctx_proj', source_contexts_projected)
+          source_contexts = self.QMatmul(source_contexts, w_ctx_proj)
+          source_contexts = self.QAct('ctx_pre_proj_matmul', source_contexts)
+          source_contexts = self.FromAqtMatmul('ctx_proj', source_contexts)
           if p.use_bias:
-            source_contexts_projected = fns.qadd(
-                source_contexts_projected,
+            source_contexts = fns.qadd(
+                source_contexts,
                self.QWeight(theta.ctx_proj_b),
                qout_name='ctx_pre_proj_add')
-          source_contexts_projected = gshard_utils.MeshSplit(
-              source_contexts_projected, p.device_mesh,
+          source_contexts = gshard_utils.MeshSplit(
+              source_contexts,
+              p.device_mesh,
              p.activation_split_dims_mapping)
-        else:
-          source_contexts_projected = source_contexts
 
-        time_steps_contexts = py_utils.GetShape(source_contexts_projected)[0]
-        source_context_depth = py_utils.GetShape(source_contexts_projected)[-1]
-        source_contexts_projected = tf.reshape(
-            source_contexts_projected,
+        time_steps_contexts = py_utils.GetShape(source_contexts)[0]
+        source_context_depth = py_utils.GetShape(source_contexts)[-1]
+        source_contexts = tf.reshape(
+            source_contexts,
            [time_steps_contexts, -1, source_context_depth // num_heads])
-        source_contexts_projected = gshard_utils.MeshSplit(
-            source_contexts_projected, p.device_mesh,
+        source_contexts = gshard_utils.MeshSplit(
+            source_contexts,
+            p.device_mesh,
            p.activation_split_dims_mapping)
-        source_contexts_projected = self.ProcessProjectionVec(
-            theta, source_contexts_projected, 'ctx')
+        source_contexts = self.ProcessProjectionVec(theta,
+                                                    source_contexts,
+                                                    'ctx')
 
     with tf.name_scope('init__2'):
       [time_steps_paddings] = py_utils.GetShape(source_padding, 1)
-      source_padding_replicated = tf.reshape(
-          tf.tile(tf.expand_dims(source_padding, 2), [1, 1, num_heads]),
-          [time_steps_paddings, -1])
+      source_padding = tf.expand_dims(source_padding, 2)
+      source_padding = tf.tile(source_padding, [1, 1, num_heads])
+      source_padding = tf.reshape(source_padding, [time_steps_paddings, -1])
       if source_segment_id is None:
-        source_segment_id_repl = tf.zeros_like(source_padding_replicated)
+        source_segment_id = tf.zeros_like(source_padding)
       else:
         [time_steps_segment_id] = py_utils.GetShape(source_segment_id, 1)
-        source_segment_id_repl = tf.reshape(
-            tf.tile(tf.expand_dims(source_segment_id, 2), [1, 1, num_heads]),
-            [time_steps_segment_id, -1])
-
-    return self.atten.PackSource(theta.atten, source_projected,
-                                 source_contexts_projected,
-                                 source_padding_replicated,
-                                 source_segment_id_repl)
+        source_segment_id = tf.expand_dims(source_segment_id, 2)
+        source_segment_id = tf.tile(source_segment_id, [1, 1, num_heads])
+        source_segment_id = tf.reshape(source_segment_id,
+                                       [time_steps_segment_id, -1])
+
+    return self.atten.PackSource(theta.atten,
+                                 source_vecs,
+                                 source_contexts,
+                                 source_padding,
+                                 source_segment_id)

   @py_utils.NameScopeDecorator('MultiHeadedAttention/ExtendSourcePacked')
   def ExtendSourcePacked(self,
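
The `init__2` block above replicates the padding (and segment ids) per head so they stay aligned with the reshaped values. A standalone sketch of the same tile-and-fold, with arbitrary example sizes:

```python
import tensorflow as tf

time_steps, batch_size, num_heads = 7, 4, 2
source_padding = tf.zeros([time_steps, batch_size])
# Every head shares its example's padding: add a head axis, tile it, then
# fold the head axis into the batch axis to match the packed source_vecs.
padding = tf.expand_dims(source_padding, 2)      # [time, batch, 1]
padding = tf.tile(padding, [1, 1, num_heads])    # [time, batch, heads]
padding = tf.reshape(padding, [time_steps, -1])  # [time, batch * heads]
```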
