Skip to content

Commit 0274fa2

Browse files
lingvo-bot authored and copybara-github committed
tf.matmul supports batch multiplications, drop unnecessary reshape.
The reshape prevents a number of operator-folding optimizations when the batch and/or source_batch dimensions are dynamic. PiperOrigin-RevId: 587114803
1 parent 36a1e31 commit 0274fa2

File tree

1 file changed

+24
-32
lines changed

1 file changed

+24
-32
lines changed

lingvo/core/attention.py

Lines changed: 24 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -1442,31 +1442,26 @@ def PackSource(self,
14421442
py_utils.GetShape(source_contexts))
14431443
time_steps, batch_size = py_utils.GetShape(source_padding, 2)
14441444
# source_projected shape [time, source_batch, hidden]
1445-
with tf.name_scope('init__0a'):
1446-
source_vec_depth = py_utils.GetShape(source_vecs)[2]
1447-
with tf.name_scope('init__0b'):
1448-
if p.enable_source_proj:
1449-
source_vecs = tf.reshape(source_vecs, [-1, source_vec_depth])
1450-
source_vecs, w_source_proj = self.ToAqtInputs(
1451-
'source_proj',
1452-
act=source_vecs,
1453-
weight=theta.source_proj,
1454-
w_feature_axis=-1)
1455-
w_source_proj = self.QWeight(w_source_proj)
1456-
source_projected = tf.matmul(source_vecs, w_source_proj)
1457-
source_projected = self.QAct('source_proj_matmul', source_projected)
1458-
source_projected = self.FromAqtMatmul('source_proj', source_projected)
1459-
if p.use_bias:
1460-
source_projected = fns.qadd(
1461-
source_projected,
1462-
self.QWeight(theta.source_proj_b),
1463-
qout_name='source_proj_add')
1464-
else:
1465-
source_projected = tf.reshape(source_vecs, [-1, source_vec_depth])
1466-
if p.activation_split_dims_mapping:
1467-
source_projected = gshard_utils.MeshSplit(
1468-
source_projected, p.device_mesh,
1469-
p.activation_split_dims_mapping[1:])
1445+
if p.enable_source_proj:
1446+
source_vecs, w_source_proj = self.ToAqtInputs(
1447+
'source_proj',
1448+
act=source_vecs,
1449+
weight=theta.source_proj,
1450+
w_feature_axis=-1)
1451+
w_source_proj = self.QWeight(w_source_proj)
1452+
source_projected = tf.matmul(source_vecs, w_source_proj)
1453+
source_projected = self.QAct('source_proj_matmul', source_projected)
1454+
source_projected = self.FromAqtMatmul('source_proj', source_projected)
1455+
if p.use_bias:
1456+
source_projected = fns.qadd(
1457+
source_projected,
1458+
self.QWeight(theta.source_proj_b),
1459+
qout_name='source_proj_add')
1460+
else:
1461+
source_projected = source_vecs
1462+
source_projected = gshard_utils.MeshSplit(
1463+
source_projected, p.device_mesh,
1464+
p.activation_split_dims_mapping)
14701465
with tf.name_scope('init__1'):
14711466
num_heads = p.num_attention_heads
14721467
# => [time, source_batch * num_heads, hidden / num_heads]
@@ -1482,8 +1477,6 @@ def PackSource(self,
14821477
source_contexts_projected = source_projected
14831478
else:
14841479
if p.enable_ctx_pre_proj:
1485-
source_contexts = tf.reshape(
1486-
source_contexts, [-1, py_utils.GetShape(source_contexts)[2]])
14871480
source_contexts, w_ctx_proj = self.ToAqtInputs(
14881481
'ctx_proj',
14891482
act=source_contexts,
@@ -1501,10 +1494,9 @@ def PackSource(self,
15011494
source_contexts_projected,
15021495
self.QWeight(theta.ctx_proj_b),
15031496
qout_name='ctx_pre_proj_add')
1504-
if p.activation_split_dims_mapping:
1505-
source_contexts_projected = gshard_utils.MeshSplit(
1506-
source_contexts_projected, p.device_mesh,
1507-
p.activation_split_dims_mapping[1:])
1497+
source_contexts_projected = gshard_utils.MeshSplit(
1498+
source_contexts_projected, p.device_mesh,
1499+
p.activation_split_dims_mapping)
15081500
else:
15091501
source_contexts_projected = source_contexts
15101502

@@ -3304,7 +3296,7 @@ def Params(cls):
33043296
p.Define('source_dim', 0, 'Default source dimension.')
33053297
p.Define(
33063298
'query_dim', 0, 'Number of query nodes. Child attention params '
3307-
'must have query_dim less or euqal than 0 or equal to this value.')
3299+
'must have query_dim less or equal than 0 or equal to this value.')
33083300
p.Define(
33093301
'primary_source_key', 'source_0', 'Key for the primary source '
33103302
'whose attention probabilities will be used as an output.')

0 commit comments

Comments
 (0)