@@ -57,11 +57,13 @@ def _ApplyAttentionDropout(params, x):
5757
5858 if params .atten_dropout_deterministic :
5959 seeds = py_utils .GenerateStepSeedPair (params )
60- return py_utils .DeterministicDropout (x , 1.0 - params .atten_dropout_prob ,
61- seeds )
60+ return py_utils .DeterministicDropout (
61+ x , 1.0 - params .atten_dropout_prob , seeds
62+ )
6263 else :
6364 return tf .nn .dropout (
64- x , rate = params .atten_dropout_prob , seed = params .random_seed )
65+ x , rate = params .atten_dropout_prob , seed = params .random_seed
66+ )
6567
6668
6769def SafeCumprod (x , * args , ** kwargs ):
@@ -86,7 +88,9 @@ def SafeCumprod(x, *args, **kwargs):
8688 tiny = np .finfo (x .dtype .as_numpy_dtype ).tiny
8789 return tf .exp (
8890 py_utils .CumSum (
89- tf .math .log (tf .clip_by_value (x , tiny , 1 )), * args , ** kwargs ))
91+ tf .math .log (tf .clip_by_value (x , tiny , 1 )), * args , ** kwargs
92+ )
93+ )
9094
9195
9296# pyformat: disable
@@ -1420,15 +1424,15 @@ def PackSource(self,
14201424 Args:
14211425 theta: A `.NestedMap` object containing weights' values of this layer and
14221426 its children layers.
1423- source_vecs: A tensor of shape [time, source_batch , source_dim].
1424- source_contexts: A tensor of shape [time, source_batch , context_dim].
1425- source_padding: A tensor of shape [time, source_batch ].
1426- source_segment_id: A tensor of shape [time, source_batch ].
1427+ source_vecs: A tensor of shape [time_steps, batch_size , source_dim].
1428+ source_contexts: A tensor of shape [time_steps, batch_size , context_dim].
1429+ source_padding: A tensor of shape [time_steps, batch_size ].
1430+ source_segment_id: A tensor of shape [time_steps, batch_size ].
14271431
14281432 Returns:
14291433 A NestedMap representing packed src. It will have the same structure
1430- as the one returned by the inner atten, except that source_batch will be
1431- source_batch * num_heads.
1434+ as the one returned by the inner atten, except that batch_size will be
1435+ batch_size * num_heads.
14321436 """
14331437
14341438 p = self .params
@@ -1439,52 +1443,52 @@ def PackSource(self,
14391443 assert p .query_dim == p .hidden_dim
14401444 # Check input tensor shapes
14411445 source_vecs = py_utils .HasRank (source_vecs , 3 )
1442- [time_steps , source_batch ] = py_utils .GetShape (source_vecs , 2 )
1446+ [time_steps , batch_size ] = py_utils .GetShape (source_vecs , 2 )
14431447 if p .use_source_vec_as_attention_value :
14441448 assert source_contexts is not None
14451449 source_contexts = py_utils .HasShape (source_contexts ,
1446- [time_steps , source_batch , - 1 ])
1450+ [time_steps , batch_size , - 1 ])
14471451 source_padding = py_utils .HasShape (source_padding ,
1448- [time_steps , source_batch ])
1452+ [time_steps , batch_size ])
14491453 if source_segment_id is not None :
14501454 source_segment_id = py_utils .HasShape (source_segment_id ,
1451- [time_steps , source_batch ])
1455+ [time_steps , batch_size ])
14521456
14531457 with tf .name_scope ('init__0' ):
1454- # source_projected shape [time * source_batch, hidden]
1458+ # source_vecs shape after (optional) projection is
1459+ # [time_steps, batch_size, hidden]
14551460 if p .enable_source_proj :
14561461 source_vecs , w_source_proj = self .ToAqtInputs (
14571462 'source_proj' ,
14581463 act = source_vecs ,
14591464 weight = theta .source_proj ,
14601465 w_feature_axis = - 1 )
14611466 w_source_proj = self .QWeight (w_source_proj )
1462- source_projected = self .QMatmul (source_vecs , w_source_proj )
1463- source_projected = self .QAct ('source_proj_matmul' , source_projected )
1464- source_projected = self .FromAqtMatmul ('source_proj' , source_projected )
1467+ source_vecs = self .QMatmul (source_vecs , w_source_proj )
1468+ source_vecs = self .QAct ('source_proj_matmul' , source_vecs )
1469+ source_vecs = self .FromAqtMatmul ('source_proj' , source_vecs )
14651470 if p .use_bias :
1466- source_projected = fns .qadd (
1467- source_projected ,
1471+ source_vecs = fns .qadd (
1472+ source_vecs ,
14681473 self .QWeight (theta .source_proj_b ),
14691474 qout_name = 'source_proj_add' )
1470- else :
1471- source_projected = source_vecs
1472- source_projected = gshard_utils .MeshSplit (
1473- source_projected , p .device_mesh ,
1475+ source_vecs = gshard_utils .MeshSplit (
1476+ source_vecs ,
1477+ p .device_mesh ,
14741478 p .activation_split_dims_mapping )
14751479 with tf .name_scope ('init__1' ):
14761480 num_heads = p .num_attention_heads
1477- # => [time, source_batch * num_heads, hidden / num_heads]
1481+ # => [time_steps, batch_size * num_heads, hidden / num_heads]
14781482 [time_steps_vecs ] = py_utils .GetShape (source_vecs , 1 )
1479- source_projected = tf .reshape (source_projected , [
1480- time_steps_vecs , - 1 , symbolic . ToStatic ( p . hidden_dim // num_heads )
1481- ])
1482- source_projected = gshard_utils .MeshSplit (source_projected , p . device_mesh ,
1483- p . activation_split_dims_mapping )
1484- source_projected = self . ProcessProjectionVec ( theta , source_projected ,
1485- 'source' )
1483+ source_vecs = tf .reshape (
1484+ source_vecs ,
1485+ [ time_steps_vecs , - 1 , symbolic . ToStatic ( p . hidden_dim // num_heads ) ])
1486+ source_vecs = gshard_utils .MeshSplit (source_vecs ,
1487+ p . device_mesh ,
1488+ p . activation_split_dims_mapping )
1489+ source_vecs = self . ProcessProjectionVec ( theta , source_vecs , 'source' )
14861490 if p .use_source_vec_as_attention_value :
1487- source_contexts_projected = source_projected
1491+ source_contexts = source_vecs
14881492 else :
14891493 if p .enable_ctx_pre_proj :
14901494 source_contexts , w_ctx_proj = self .ToAqtInputs (
@@ -1493,50 +1497,51 @@ def PackSource(self,
14931497 weight = theta .ctx_proj ,
14941498 w_feature_axis = - 1 )
14951499 w_ctx_proj = self .QWeight (w_ctx_proj )
1496- source_contexts_projected = self .QMatmul (source_contexts , w_ctx_proj )
1497- source_contexts_projected = self .QAct ('ctx_pre_proj_matmul' ,
1498- source_contexts_projected )
1499- source_contexts_projected = self .FromAqtMatmul (
1500- 'ctx_proj' , source_contexts_projected )
1500+ source_contexts = self .QMatmul (source_contexts , w_ctx_proj )
1501+ source_contexts = self .QAct ('ctx_pre_proj_matmul' , source_contexts )
1502+ source_contexts = self .FromAqtMatmul ('ctx_proj' , source_contexts )
15011503 if p .use_bias :
1502- source_contexts_projected = fns .qadd (
1503- source_contexts_projected ,
1504+ source_contexts = fns .qadd (
1505+ source_contexts ,
15041506 self .QWeight (theta .ctx_proj_b ),
15051507 qout_name = 'ctx_pre_proj_add' )
1506- source_contexts_projected = gshard_utils .MeshSplit (
1507- source_contexts_projected , p .device_mesh ,
1508+ source_contexts = gshard_utils .MeshSplit (
1509+ source_contexts ,
1510+ p .device_mesh ,
15081511 p .activation_split_dims_mapping )
1509- else :
1510- source_contexts_projected = source_contexts
15111512
1512- time_steps_contexts = py_utils .GetShape (source_contexts_projected )[0 ]
1513- source_context_depth = py_utils .GetShape (source_contexts_projected )[- 1 ]
1514- source_contexts_projected = tf .reshape (
1515- source_contexts_projected ,
1513+ time_steps_contexts = py_utils .GetShape (source_contexts )[0 ]
1514+ source_context_depth = py_utils .GetShape (source_contexts )[- 1 ]
1515+ source_contexts = tf .reshape (
1516+ source_contexts ,
15161517 [time_steps_contexts , - 1 , source_context_depth // num_heads ])
1517- source_contexts_projected = gshard_utils .MeshSplit (
1518- source_contexts_projected , p .device_mesh ,
1518+ source_contexts = gshard_utils .MeshSplit (
1519+ source_contexts ,
1520+ p .device_mesh ,
15191521 p .activation_split_dims_mapping )
1520- source_contexts_projected = self .ProcessProjectionVec (
1521- theta , source_contexts_projected , 'ctx' )
1522+ source_contexts = self .ProcessProjectionVec (theta ,
1523+ source_contexts ,
1524+ 'ctx' )
15221525
15231526 with tf .name_scope ('init__2' ):
15241527 [time_steps_paddings ] = py_utils .GetShape (source_padding , 1 )
1525- source_padding_replicated = tf .reshape (
1526- tf .tile (tf . expand_dims ( source_padding , 2 ), [1 , 1 , num_heads ]),
1527- [time_steps_paddings , - 1 ])
1528+ source_padding = tf .expand_dims ( source_padding , 2 )
1529+ source_padding = tf .tile (source_padding , [1 , 1 , num_heads ])
1530+ source_padding = tf . reshape ( source_padding , [time_steps_paddings , - 1 ])
15281531 if source_segment_id is None :
1529- source_segment_id_repl = tf .zeros_like (source_padding_replicated )
1532+ source_segment_id = tf .zeros_like (source_padding )
15301533 else :
15311534 [time_steps_segment_id ] = py_utils .GetShape (source_segment_id , 1 )
1532- source_segment_id_repl = tf .reshape (
1533- tf .tile (tf .expand_dims (source_segment_id , 2 ), [1 , 1 , num_heads ]),
1534- [time_steps_segment_id , - 1 ])
1535-
1536- return self .atten .PackSource (theta .atten , source_projected ,
1537- source_contexts_projected ,
1538- source_padding_replicated ,
1539- source_segment_id_repl )
1535+ source_segment_id = tf .expand_dims (source_segment_id , 2 )
1536+ source_segment_id = tf .tile (source_segment_id , [1 , 1 , num_heads ])
1537+ source_segment_id = tf .reshape (source_segment_id ,
1538+ [time_steps_segment_id , - 1 ])
1539+
1540+ return self .atten .PackSource (theta .atten ,
1541+ source_vecs ,
1542+ source_contexts ,
1543+ source_padding ,
1544+ source_segment_id )
15401545
15411546 @py_utils .NameScopeDecorator ('MultiHeadedAttention/ExtendSourcePacked' )
15421547 def ExtendSourcePacked (self ,
0 commit comments