Skip to content

Commit 6066e1f

Browse files
lingvo-bot authored and copybara-github committed
Rename variable names to reduce noise in followup CLs
PiperOrigin-RevId: 678824499
1 parent e18f28e commit 6066e1f

File tree

1 file changed

+69
-64
lines changed

1 file changed

+69
-64
lines changed

lingvo/core/attention.py

Lines changed: 69 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,13 @@ def _ApplyAttentionDropout(params, x):
5757

5858
if params.atten_dropout_deterministic:
5959
seeds = py_utils.GenerateStepSeedPair(params)
60-
return py_utils.DeterministicDropout(x, 1.0 - params.atten_dropout_prob,
61-
seeds)
60+
return py_utils.DeterministicDropout(
61+
x, 1.0 - params.atten_dropout_prob, seeds
62+
)
6263
else:
6364
return tf.nn.dropout(
64-
x, rate=params.atten_dropout_prob, seed=params.random_seed)
65+
x, rate=params.atten_dropout_prob, seed=params.random_seed
66+
)
6567

6668

6769
def SafeCumprod(x, *args, **kwargs):
@@ -86,7 +88,9 @@ def SafeCumprod(x, *args, **kwargs):
8688
tiny = np.finfo(x.dtype.as_numpy_dtype).tiny
8789
return tf.exp(
8890
py_utils.CumSum(
89-
tf.math.log(tf.clip_by_value(x, tiny, 1)), *args, **kwargs))
91+
tf.math.log(tf.clip_by_value(x, tiny, 1)), *args, **kwargs
92+
)
93+
)
9094

9195

9296
# pyformat: disable
@@ -1420,15 +1424,15 @@ def PackSource(self,
14201424
Args:
14211425
theta: A `.NestedMap` object containing weights' values of this layer and
14221426
its children layers.
1423-
source_vecs: A tensor of shape [time, source_batch, source_dim].
1424-
source_contexts: A tensor of shape [time, source_batch, context_dim].
1425-
source_padding: A tensor of shape [time, source_batch].
1426-
source_segment_id: A tensor of shape [time, source_batch].
1427+
source_vecs: A tensor of shape [time_steps, batch_size, source_dim].
1428+
source_contexts: A tensor of shape [time_steps, batch_size, context_dim].
1429+
source_padding: A tensor of shape [time_steps, batch_size].
1430+
source_segment_id: A tensor of shape [time_steps, batch_size].
14271431
14281432
Returns:
14291433
A NestedMap representing packed src. It will have the same structure
1430-
as the one returned by the inner atten, except that source_batch will be
1431-
source_batch * num_heads.
1434+
as the one returned by the inner atten, except that batch_size will be
1435+
batch_size * num_heads.
14321436
"""
14331437

14341438
p = self.params
@@ -1439,52 +1443,52 @@ def PackSource(self,
14391443
assert p.query_dim == p.hidden_dim
14401444
# Check input tensor shapes
14411445
source_vecs = py_utils.HasRank(source_vecs, 3)
1442-
[time_steps, source_batch] = py_utils.GetShape(source_vecs, 2)
1446+
[time_steps, batch_size] = py_utils.GetShape(source_vecs, 2)
14431447
if p.use_source_vec_as_attention_value:
14441448
assert source_contexts is not None
14451449
source_contexts = py_utils.HasShape(source_contexts,
1446-
[time_steps, source_batch, -1])
1450+
[time_steps, batch_size, -1])
14471451
source_padding = py_utils.HasShape(source_padding,
1448-
[time_steps, source_batch])
1452+
[time_steps, batch_size])
14491453
if source_segment_id is not None:
14501454
source_segment_id = py_utils.HasShape(source_segment_id,
1451-
[time_steps, source_batch])
1455+
[time_steps, batch_size])
14521456

14531457
with tf.name_scope('init__0'):
1454-
# source_projected shape [time * source_batch, hidden]
1458+
# source_vecs shape after (optional) projection is
1459+
# [time_steps, batch_size, hidden]
14551460
if p.enable_source_proj:
14561461
source_vecs, w_source_proj = self.ToAqtInputs(
14571462
'source_proj',
14581463
act=source_vecs,
14591464
weight=theta.source_proj,
14601465
w_feature_axis=-1)
14611466
w_source_proj = self.QWeight(w_source_proj)
1462-
source_projected = self.QMatmul(source_vecs, w_source_proj)
1463-
source_projected = self.QAct('source_proj_matmul', source_projected)
1464-
source_projected = self.FromAqtMatmul('source_proj', source_projected)
1467+
source_vecs = self.QMatmul(source_vecs, w_source_proj)
1468+
source_vecs = self.QAct('source_proj_matmul', source_vecs)
1469+
source_vecs = self.FromAqtMatmul('source_proj', source_vecs)
14651470
if p.use_bias:
1466-
source_projected = fns.qadd(
1467-
source_projected,
1471+
source_vecs = fns.qadd(
1472+
source_vecs,
14681473
self.QWeight(theta.source_proj_b),
14691474
qout_name='source_proj_add')
1470-
else:
1471-
source_projected = source_vecs
1472-
source_projected = gshard_utils.MeshSplit(
1473-
source_projected, p.device_mesh,
1475+
source_vecs = gshard_utils.MeshSplit(
1476+
source_vecs,
1477+
p.device_mesh,
14741478
p.activation_split_dims_mapping)
14751479
with tf.name_scope('init__1'):
14761480
num_heads = p.num_attention_heads
1477-
# => [time, source_batch * num_heads, hidden / num_heads]
1481+
# => [time_steps, batch_size * num_heads, hidden / num_heads]
14781482
[time_steps_vecs] = py_utils.GetShape(source_vecs, 1)
1479-
source_projected = tf.reshape(source_projected, [
1480-
time_steps_vecs, -1, symbolic.ToStatic(p.hidden_dim // num_heads)
1481-
])
1482-
source_projected = gshard_utils.MeshSplit(source_projected, p.device_mesh,
1483-
p.activation_split_dims_mapping)
1484-
source_projected = self.ProcessProjectionVec(theta, source_projected,
1485-
'source')
1483+
source_vecs = tf.reshape(
1484+
source_vecs,
1485+
[time_steps_vecs, -1, symbolic.ToStatic(p.hidden_dim // num_heads)])
1486+
source_vecs = gshard_utils.MeshSplit(source_vecs,
1487+
p.device_mesh,
1488+
p.activation_split_dims_mapping)
1489+
source_vecs = self.ProcessProjectionVec(theta, source_vecs, 'source')
14861490
if p.use_source_vec_as_attention_value:
1487-
source_contexts_projected = source_projected
1491+
source_contexts = source_vecs
14881492
else:
14891493
if p.enable_ctx_pre_proj:
14901494
source_contexts, w_ctx_proj = self.ToAqtInputs(
@@ -1493,50 +1497,51 @@ def PackSource(self,
14931497
weight=theta.ctx_proj,
14941498
w_feature_axis=-1)
14951499
w_ctx_proj = self.QWeight(w_ctx_proj)
1496-
source_contexts_projected = self.QMatmul(source_contexts, w_ctx_proj)
1497-
source_contexts_projected = self.QAct('ctx_pre_proj_matmul',
1498-
source_contexts_projected)
1499-
source_contexts_projected = self.FromAqtMatmul(
1500-
'ctx_proj', source_contexts_projected)
1500+
source_contexts = self.QMatmul(source_contexts, w_ctx_proj)
1501+
source_contexts = self.QAct('ctx_pre_proj_matmul', source_contexts)
1502+
source_contexts = self.FromAqtMatmul('ctx_proj', source_contexts)
15011503
if p.use_bias:
1502-
source_contexts_projected = fns.qadd(
1503-
source_contexts_projected,
1504+
source_contexts = fns.qadd(
1505+
source_contexts,
15041506
self.QWeight(theta.ctx_proj_b),
15051507
qout_name='ctx_pre_proj_add')
1506-
source_contexts_projected = gshard_utils.MeshSplit(
1507-
source_contexts_projected, p.device_mesh,
1508+
source_contexts = gshard_utils.MeshSplit(
1509+
source_contexts,
1510+
p.device_mesh,
15081511
p.activation_split_dims_mapping)
1509-
else:
1510-
source_contexts_projected = source_contexts
15111512

1512-
time_steps_contexts = py_utils.GetShape(source_contexts_projected)[0]
1513-
source_context_depth = py_utils.GetShape(source_contexts_projected)[-1]
1514-
source_contexts_projected = tf.reshape(
1515-
source_contexts_projected,
1513+
time_steps_contexts = py_utils.GetShape(source_contexts)[0]
1514+
source_context_depth = py_utils.GetShape(source_contexts)[-1]
1515+
source_contexts = tf.reshape(
1516+
source_contexts,
15161517
[time_steps_contexts, -1, source_context_depth // num_heads])
1517-
source_contexts_projected = gshard_utils.MeshSplit(
1518-
source_contexts_projected, p.device_mesh,
1518+
source_contexts = gshard_utils.MeshSplit(
1519+
source_contexts,
1520+
p.device_mesh,
15191521
p.activation_split_dims_mapping)
1520-
source_contexts_projected = self.ProcessProjectionVec(
1521-
theta, source_contexts_projected, 'ctx')
1522+
source_contexts = self.ProcessProjectionVec(theta,
1523+
source_contexts,
1524+
'ctx')
15221525

15231526
with tf.name_scope('init__2'):
15241527
[time_steps_paddings] = py_utils.GetShape(source_padding, 1)
1525-
source_padding_replicated = tf.reshape(
1526-
tf.tile(tf.expand_dims(source_padding, 2), [1, 1, num_heads]),
1527-
[time_steps_paddings, -1])
1528+
source_padding = tf.expand_dims(source_padding, 2)
1529+
source_padding = tf.tile(source_padding, [1, 1, num_heads])
1530+
source_padding = tf.reshape(source_padding, [time_steps_paddings, -1])
15281531
if source_segment_id is None:
1529-
source_segment_id_repl = tf.zeros_like(source_padding_replicated)
1532+
source_segment_id = tf.zeros_like(source_padding)
15301533
else:
15311534
[time_steps_segment_id] = py_utils.GetShape(source_segment_id, 1)
1532-
source_segment_id_repl = tf.reshape(
1533-
tf.tile(tf.expand_dims(source_segment_id, 2), [1, 1, num_heads]),
1534-
[time_steps_segment_id, -1])
1535-
1536-
return self.atten.PackSource(theta.atten, source_projected,
1537-
source_contexts_projected,
1538-
source_padding_replicated,
1539-
source_segment_id_repl)
1535+
source_segment_id = tf.expand_dims(source_segment_id, 2)
1536+
source_segment_id = tf.tile(source_segment_id, [1, 1, num_heads])
1537+
source_segment_id = tf.reshape(source_segment_id,
1538+
[time_steps_segment_id, -1])
1539+
1540+
return self.atten.PackSource(theta.atten,
1541+
source_vecs,
1542+
source_contexts,
1543+
source_padding,
1544+
source_segment_id)
15401545

15411546
@py_utils.NameScopeDecorator('MultiHeadedAttention/ExtendSourcePacked')
15421547
def ExtendSourcePacked(self,

0 commit comments

Comments (0)