@@ -57,11 +57,13 @@ def _ApplyAttentionDropout(params, x):
57
57
58
58
if params .atten_dropout_deterministic :
59
59
seeds = py_utils .GenerateStepSeedPair (params )
60
- return py_utils .DeterministicDropout (x , 1.0 - params .atten_dropout_prob ,
61
- seeds )
60
+ return py_utils .DeterministicDropout (
61
+ x , 1.0 - params .atten_dropout_prob , seeds
62
+ )
62
63
else :
63
64
return tf .nn .dropout (
64
- x , rate = params .atten_dropout_prob , seed = params .random_seed )
65
+ x , rate = params .atten_dropout_prob , seed = params .random_seed
66
+ )
65
67
66
68
67
69
def SafeCumprod (x , * args , ** kwargs ):
@@ -86,7 +88,9 @@ def SafeCumprod(x, *args, **kwargs):
86
88
tiny = np .finfo (x .dtype .as_numpy_dtype ).tiny
87
89
return tf .exp (
88
90
py_utils .CumSum (
89
- tf .math .log (tf .clip_by_value (x , tiny , 1 )), * args , ** kwargs ))
91
+ tf .math .log (tf .clip_by_value (x , tiny , 1 )), * args , ** kwargs
92
+ )
93
+ )
90
94
91
95
92
96
# pyformat: disable
@@ -1420,15 +1424,15 @@ def PackSource(self,
1420
1424
Args:
1421
1425
theta: A `.NestedMap` object containing weights' values of this layer and
1422
1426
its children layers.
1423
- source_vecs: A tensor of shape [time, source_batch , source_dim].
1424
- source_contexts: A tensor of shape [time, source_batch , context_dim].
1425
- source_padding: A tensor of shape [time, source_batch ].
1426
- source_segment_id: A tensor of shape [time, source_batch ].
1427
+ source_vecs: A tensor of shape [time_steps, batch_size , source_dim].
1428
+ source_contexts: A tensor of shape [time_steps, batch_size , context_dim].
1429
+ source_padding: A tensor of shape [time_steps, batch_size ].
1430
+ source_segment_id: A tensor of shape [time_steps, batch_size ].
1427
1431
1428
1432
Returns:
1429
1433
A NestedMap representing packed src. It will have the same structure
1430
- as the one returned by the inner atten, except that source_batch will be
1431
- source_batch * num_heads.
1434
+ as the one returned by the inner atten, except that batch_size will be
1435
+ batch_size * num_heads.
1432
1436
"""
1433
1437
1434
1438
p = self .params
@@ -1439,52 +1443,52 @@ def PackSource(self,
1439
1443
assert p .query_dim == p .hidden_dim
1440
1444
# Check input tensor shapes
1441
1445
source_vecs = py_utils .HasRank (source_vecs , 3 )
1442
- [time_steps , source_batch ] = py_utils .GetShape (source_vecs , 2 )
1446
+ [time_steps , batch_size ] = py_utils .GetShape (source_vecs , 2 )
1443
1447
if p .use_source_vec_as_attention_value :
1444
1448
assert source_contexts is not None
1445
1449
source_contexts = py_utils .HasShape (source_contexts ,
1446
- [time_steps , source_batch , - 1 ])
1450
+ [time_steps , batch_size , - 1 ])
1447
1451
source_padding = py_utils .HasShape (source_padding ,
1448
- [time_steps , source_batch ])
1452
+ [time_steps , batch_size ])
1449
1453
if source_segment_id is not None :
1450
1454
source_segment_id = py_utils .HasShape (source_segment_id ,
1451
- [time_steps , source_batch ])
1455
+ [time_steps , batch_size ])
1452
1456
1453
1457
with tf .name_scope ('init__0' ):
1454
- # source_projected shape [time * source_batch, hidden]
1458
+ # source_vecs shape after (optional) projection is
1459
+ # [time_steps, batch_size, hidden]
1455
1460
if p .enable_source_proj :
1456
1461
source_vecs , w_source_proj = self .ToAqtInputs (
1457
1462
'source_proj' ,
1458
1463
act = source_vecs ,
1459
1464
weight = theta .source_proj ,
1460
1465
w_feature_axis = - 1 )
1461
1466
w_source_proj = self .QWeight (w_source_proj )
1462
- source_projected = self .QMatmul (source_vecs , w_source_proj )
1463
- source_projected = self .QAct ('source_proj_matmul' , source_projected )
1464
- source_projected = self .FromAqtMatmul ('source_proj' , source_projected )
1467
+ source_vecs = self .QMatmul (source_vecs , w_source_proj )
1468
+ source_vecs = self .QAct ('source_proj_matmul' , source_vecs )
1469
+ source_vecs = self .FromAqtMatmul ('source_proj' , source_vecs )
1465
1470
if p .use_bias :
1466
- source_projected = fns .qadd (
1467
- source_projected ,
1471
+ source_vecs = fns .qadd (
1472
+ source_vecs ,
1468
1473
self .QWeight (theta .source_proj_b ),
1469
1474
qout_name = 'source_proj_add' )
1470
- else :
1471
- source_projected = source_vecs
1472
- source_projected = gshard_utils .MeshSplit (
1473
- source_projected , p .device_mesh ,
1475
+ source_vecs = gshard_utils .MeshSplit (
1476
+ source_vecs ,
1477
+ p .device_mesh ,
1474
1478
p .activation_split_dims_mapping )
1475
1479
with tf .name_scope ('init__1' ):
1476
1480
num_heads = p .num_attention_heads
1477
- # => [time, source_batch * num_heads, hidden / num_heads]
1481
+ # => [time_steps, batch_size * num_heads, hidden / num_heads]
1478
1482
[time_steps_vecs ] = py_utils .GetShape (source_vecs , 1 )
1479
- source_projected = tf .reshape (source_projected , [
1480
- time_steps_vecs , - 1 , symbolic . ToStatic ( p . hidden_dim // num_heads )
1481
- ])
1482
- source_projected = gshard_utils .MeshSplit (source_projected , p . device_mesh ,
1483
- p . activation_split_dims_mapping )
1484
- source_projected = self . ProcessProjectionVec ( theta , source_projected ,
1485
- 'source' )
1483
+ source_vecs = tf .reshape (
1484
+ source_vecs ,
1485
+ [ time_steps_vecs , - 1 , symbolic . ToStatic ( p . hidden_dim // num_heads ) ])
1486
+ source_vecs = gshard_utils .MeshSplit (source_vecs ,
1487
+ p . device_mesh ,
1488
+ p . activation_split_dims_mapping )
1489
+ source_vecs = self . ProcessProjectionVec ( theta , source_vecs , 'source' )
1486
1490
if p .use_source_vec_as_attention_value :
1487
- source_contexts_projected = source_projected
1491
+ source_contexts = source_vecs
1488
1492
else :
1489
1493
if p .enable_ctx_pre_proj :
1490
1494
source_contexts , w_ctx_proj = self .ToAqtInputs (
@@ -1493,50 +1497,51 @@ def PackSource(self,
1493
1497
weight = theta .ctx_proj ,
1494
1498
w_feature_axis = - 1 )
1495
1499
w_ctx_proj = self .QWeight (w_ctx_proj )
1496
- source_contexts_projected = self .QMatmul (source_contexts , w_ctx_proj )
1497
- source_contexts_projected = self .QAct ('ctx_pre_proj_matmul' ,
1498
- source_contexts_projected )
1499
- source_contexts_projected = self .FromAqtMatmul (
1500
- 'ctx_proj' , source_contexts_projected )
1500
+ source_contexts = self .QMatmul (source_contexts , w_ctx_proj )
1501
+ source_contexts = self .QAct ('ctx_pre_proj_matmul' , source_contexts )
1502
+ source_contexts = self .FromAqtMatmul ('ctx_proj' , source_contexts )
1501
1503
if p .use_bias :
1502
- source_contexts_projected = fns .qadd (
1503
- source_contexts_projected ,
1504
+ source_contexts = fns .qadd (
1505
+ source_contexts ,
1504
1506
self .QWeight (theta .ctx_proj_b ),
1505
1507
qout_name = 'ctx_pre_proj_add' )
1506
- source_contexts_projected = gshard_utils .MeshSplit (
1507
- source_contexts_projected , p .device_mesh ,
1508
+ source_contexts = gshard_utils .MeshSplit (
1509
+ source_contexts ,
1510
+ p .device_mesh ,
1508
1511
p .activation_split_dims_mapping )
1509
- else :
1510
- source_contexts_projected = source_contexts
1511
1512
1512
- time_steps_contexts = py_utils .GetShape (source_contexts_projected )[0 ]
1513
- source_context_depth = py_utils .GetShape (source_contexts_projected )[- 1 ]
1514
- source_contexts_projected = tf .reshape (
1515
- source_contexts_projected ,
1513
+ time_steps_contexts = py_utils .GetShape (source_contexts )[0 ]
1514
+ source_context_depth = py_utils .GetShape (source_contexts )[- 1 ]
1515
+ source_contexts = tf .reshape (
1516
+ source_contexts ,
1516
1517
[time_steps_contexts , - 1 , source_context_depth // num_heads ])
1517
- source_contexts_projected = gshard_utils .MeshSplit (
1518
- source_contexts_projected , p .device_mesh ,
1518
+ source_contexts = gshard_utils .MeshSplit (
1519
+ source_contexts ,
1520
+ p .device_mesh ,
1519
1521
p .activation_split_dims_mapping )
1520
- source_contexts_projected = self .ProcessProjectionVec (
1521
- theta , source_contexts_projected , 'ctx' )
1522
+ source_contexts = self .ProcessProjectionVec (theta ,
1523
+ source_contexts ,
1524
+ 'ctx' )
1522
1525
1523
1526
with tf .name_scope ('init__2' ):
1524
1527
[time_steps_paddings ] = py_utils .GetShape (source_padding , 1 )
1525
- source_padding_replicated = tf .reshape (
1526
- tf .tile (tf . expand_dims ( source_padding , 2 ), [1 , 1 , num_heads ]),
1527
- [time_steps_paddings , - 1 ])
1528
+ source_padding = tf .expand_dims ( source_padding , 2 )
1529
+ source_padding = tf .tile (source_padding , [1 , 1 , num_heads ])
1530
+ source_padding = tf . reshape ( source_padding , [time_steps_paddings , - 1 ])
1528
1531
if source_segment_id is None :
1529
- source_segment_id_repl = tf .zeros_like (source_padding_replicated )
1532
+ source_segment_id = tf .zeros_like (source_padding )
1530
1533
else :
1531
1534
[time_steps_segment_id ] = py_utils .GetShape (source_segment_id , 1 )
1532
- source_segment_id_repl = tf .reshape (
1533
- tf .tile (tf .expand_dims (source_segment_id , 2 ), [1 , 1 , num_heads ]),
1534
- [time_steps_segment_id , - 1 ])
1535
-
1536
- return self .atten .PackSource (theta .atten , source_projected ,
1537
- source_contexts_projected ,
1538
- source_padding_replicated ,
1539
- source_segment_id_repl )
1535
+ source_segment_id = tf .expand_dims (source_segment_id , 2 )
1536
+ source_segment_id = tf .tile (source_segment_id , [1 , 1 , num_heads ])
1537
+ source_segment_id = tf .reshape (source_segment_id ,
1538
+ [time_steps_segment_id , - 1 ])
1539
+
1540
+ return self .atten .PackSource (theta .atten ,
1541
+ source_vecs ,
1542
+ source_contexts ,
1543
+ source_padding ,
1544
+ source_segment_id )
1540
1545
1541
1546
@py_utils .NameScopeDecorator ('MultiHeadedAttention/ExtendSourcePacked' )
1542
1547
def ExtendSourcePacked (self ,
0 commit comments