tf.matmul supports batch multiplications, drop unnecessary reshape.
The reshape prevents a number of operator-folding optimizations when the batch and/or source_batch dimensions are dynamic.

PiperOrigin-RevId: 587114803
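
A minimal sketch of the behavior the change relies on (TF 2.x assumed; all shapes here are made up for illustration): tf.matmul broadcasts a rank-2 weight over the leading dimensions of a rank-3 activation, so the flatten-project-unflatten round trip computes the same values while emitting extra shape ops that can block operator folding when the leading dimensions are dynamic.

import tensorflow as tf

time_steps, source_batch, depth, hidden = 5, 3, 8, 16  # made-up sizes
source_vecs = tf.random.normal([time_steps, source_batch, depth])
w = tf.random.normal([depth, hidden])

# Old path: collapse the leading dims, project, reshape back.
flat = tf.reshape(source_vecs, [-1, depth])  # [time * source_batch, depth]
old = tf.reshape(tf.matmul(flat, w), [time_steps, source_batch, hidden])

# New path: tf.matmul broadcasts the rank-2 weight across the leading
# dimensions of the rank-3 input; no reshape needed.
new = tf.matmul(source_vecs, w)  # [time, source_batch, hidden]

tf.debugging.assert_near(old, new)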
lingvo-bot authored and copybara-github committed Dec 1, 2023
1 parent 36a1e31 commit 0274fa2
Showing 1 changed file with 24 additions and 32 deletions.
lingvo/core/attention.py
@@ -1442,31 +1442,26 @@ def PackSource(self,
                                           py_utils.GetShape(source_contexts))
     time_steps, batch_size = py_utils.GetShape(source_padding, 2)
     # source_projected shape [time * source_batch, hidden]
-    with tf.name_scope('init__0a'):
-      source_vec_depth = py_utils.GetShape(source_vecs)[2]
-    with tf.name_scope('init__0b'):
-      if p.enable_source_proj:
-        source_vecs = tf.reshape(source_vecs, [-1, source_vec_depth])
-        source_vecs, w_source_proj = self.ToAqtInputs(
-            'source_proj',
-            act=source_vecs,
-            weight=theta.source_proj,
-            w_feature_axis=-1)
-        w_source_proj = self.QWeight(w_source_proj)
-        source_projected = tf.matmul(source_vecs, w_source_proj)
-        source_projected = self.QAct('source_proj_matmul', source_projected)
-        source_projected = self.FromAqtMatmul('source_proj', source_projected)
-        if p.use_bias:
-          source_projected = fns.qadd(
-              source_projected,
-              self.QWeight(theta.source_proj_b),
-              qout_name='source_proj_add')
-      else:
-        source_projected = tf.reshape(source_vecs, [-1, source_vec_depth])
-      if p.activation_split_dims_mapping:
-        source_projected = gshard_utils.MeshSplit(
-            source_projected, p.device_mesh,
-            p.activation_split_dims_mapping[1:])
+    if p.enable_source_proj:
+      source_vecs, w_source_proj = self.ToAqtInputs(
+          'source_proj',
+          act=source_vecs,
+          weight=theta.source_proj,
+          w_feature_axis=-1)
+      w_source_proj = self.QWeight(w_source_proj)
+      source_projected = tf.matmul(source_vecs, w_source_proj)
+      source_projected = self.QAct('source_proj_matmul', source_projected)
+      source_projected = self.FromAqtMatmul('source_proj', source_projected)
+      if p.use_bias:
+        source_projected = fns.qadd(
+            source_projected,
+            self.QWeight(theta.source_proj_b),
+            qout_name='source_proj_add')
+    else:
+      source_projected = source_vecs
+    source_projected = gshard_utils.MeshSplit(
+        source_projected, p.device_mesh,
+        p.activation_split_dims_mapping)
     with tf.name_scope('init__1'):
       num_heads = p.num_attention_heads
       # => [time, source_batch * num_heads, hidden / num_heads]
@@ -1482,8 +1477,6 @@ def PackSource(self,
         source_contexts_projected = source_projected
       else:
         if p.enable_ctx_pre_proj:
-          source_contexts = tf.reshape(
-              source_contexts, [-1, py_utils.GetShape(source_contexts)[2]])
           source_contexts, w_ctx_proj = self.ToAqtInputs(
               'ctx_proj',
               act=source_contexts,
@@ -1501,10 +1494,9 @@ def PackSource(self,
                 source_contexts_projected,
                 self.QWeight(theta.ctx_proj_b),
                 qout_name='ctx_pre_proj_add')
-          if p.activation_split_dims_mapping:
-            source_contexts_projected = gshard_utils.MeshSplit(
-                source_contexts_projected, p.device_mesh,
-                p.activation_split_dims_mapping[1:])
+          source_contexts_projected = gshard_utils.MeshSplit(
+              source_contexts_projected, p.device_mesh,
+              p.activation_split_dims_mapping)
         else:
           source_contexts_projected = source_contexts

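Why the [1:] slice (and the guard) disappear in the hunks above: a gshard-style split-dims mapping carries one entry per tensor dimension, and with the reshape gone the projected activation stays rank 3. A toy sketch of that invariant follows; check_split_dims_mapping and the example mapping are illustrative only, not lingvo APIs. (gshard_utils.MeshSplit also appears to be a no-op when no mapping is configured, which would make the dropped if guard redundant.)

def check_split_dims_mapping(tensor_rank, mapping):
  # One mesh-axis index (or -1 for "replicated") per tensor dimension.
  assert len(mapping) == tensor_rank, (tensor_rank, mapping)

activation_split_dims_mapping = [-1, 0, 1]  # hypothetical 3-entry mapping

# Before this commit: the activation was reshaped to rank 2, so only the
# trailing entries of the mapping applied.
check_split_dims_mapping(2, activation_split_dims_mapping[1:])

# After: the activation stays rank 3 ([time, source_batch, hidden]), so
# the full mapping applies.
check_split_dims_mapping(3, activation_split_dims_mapping)
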
@@ -3304,7 +3296,7 @@ def Params(cls):
     p.Define('source_dim', 0, 'Default source dimension.')
     p.Define(
         'query_dim', 0, 'Number of query nodes. Child attention params '
-        'must have query_dim less or euqal than 0 or equal to this value.')
+        'must have query_dim less or equal than 0 or equal to this value.')
     p.Define(
         'primary_source_key', 'source_0', 'Key for the primary source '
         'whose attention probabilities will be used as an output.')
