@@ -1442,31 +1442,26 @@ def PackSource(self,
1442
1442
py_utils .GetShape (source_contexts ))
1443
1443
time_steps , batch_size = py_utils .GetShape (source_padding , 2 )
1444
1444
# source_projected shape [time * source_batch, hidden]
1445
- with tf .name_scope ('init__0a' ):
1446
- source_vec_depth = py_utils .GetShape (source_vecs )[2 ]
1447
- with tf .name_scope ('init__0b' ):
1448
- if p .enable_source_proj :
1449
- source_vecs = tf .reshape (source_vecs , [- 1 , source_vec_depth ])
1450
- source_vecs , w_source_proj = self .ToAqtInputs (
1451
- 'source_proj' ,
1452
- act = source_vecs ,
1453
- weight = theta .source_proj ,
1454
- w_feature_axis = - 1 )
1455
- w_source_proj = self .QWeight (w_source_proj )
1456
- source_projected = tf .matmul (source_vecs , w_source_proj )
1457
- source_projected = self .QAct ('source_proj_matmul' , source_projected )
1458
- source_projected = self .FromAqtMatmul ('source_proj' , source_projected )
1459
- if p .use_bias :
1460
- source_projected = fns .qadd (
1461
- source_projected ,
1462
- self .QWeight (theta .source_proj_b ),
1463
- qout_name = 'source_proj_add' )
1464
- else :
1465
- source_projected = tf .reshape (source_vecs , [- 1 , source_vec_depth ])
1466
- if p .activation_split_dims_mapping :
1467
- source_projected = gshard_utils .MeshSplit (
1468
- source_projected , p .device_mesh ,
1469
- p .activation_split_dims_mapping [1 :])
1445
+ if p .enable_source_proj :
1446
+ source_vecs , w_source_proj = self .ToAqtInputs (
1447
+ 'source_proj' ,
1448
+ act = source_vecs ,
1449
+ weight = theta .source_proj ,
1450
+ w_feature_axis = - 1 )
1451
+ w_source_proj = self .QWeight (w_source_proj )
1452
+ source_projected = tf .matmul (source_vecs , w_source_proj )
1453
+ source_projected = self .QAct ('source_proj_matmul' , source_projected )
1454
+ source_projected = self .FromAqtMatmul ('source_proj' , source_projected )
1455
+ if p .use_bias :
1456
+ source_projected = fns .qadd (
1457
+ source_projected ,
1458
+ self .QWeight (theta .source_proj_b ),
1459
+ qout_name = 'source_proj_add' )
1460
+ else :
1461
+ source_projected = source_vecs
1462
+ source_projected = gshard_utils .MeshSplit (
1463
+ source_projected , p .device_mesh ,
1464
+ p .activation_split_dims_mapping )
1470
1465
with tf .name_scope ('init__1' ):
1471
1466
num_heads = p .num_attention_heads
1472
1467
# => [time, source_batch * num_heads, hidden / num_heads]
@@ -1482,8 +1477,6 @@ def PackSource(self,
1482
1477
source_contexts_projected = source_projected
1483
1478
else :
1484
1479
if p .enable_ctx_pre_proj :
1485
- source_contexts = tf .reshape (
1486
- source_contexts , [- 1 , py_utils .GetShape (source_contexts )[2 ]])
1487
1480
source_contexts , w_ctx_proj = self .ToAqtInputs (
1488
1481
'ctx_proj' ,
1489
1482
act = source_contexts ,
@@ -1501,10 +1494,9 @@ def PackSource(self,
1501
1494
source_contexts_projected ,
1502
1495
self .QWeight (theta .ctx_proj_b ),
1503
1496
qout_name = 'ctx_pre_proj_add' )
1504
- if p .activation_split_dims_mapping :
1505
- source_contexts_projected = gshard_utils .MeshSplit (
1506
- source_contexts_projected , p .device_mesh ,
1507
- p .activation_split_dims_mapping [1 :])
1497
+ source_contexts_projected = gshard_utils .MeshSplit (
1498
+ source_contexts_projected , p .device_mesh ,
1499
+ p .activation_split_dims_mapping )
1508
1500
else :
1509
1501
source_contexts_projected = source_contexts
1510
1502
@@ -3304,7 +3296,7 @@ def Params(cls):
3304
3296
p .Define ('source_dim' , 0 , 'Default source dimension.' )
3305
3297
p .Define (
3306
3298
'query_dim' , 0 , 'Number of query nodes. Child attention params '
3307
- 'must have query_dim less or euqal than 0 or equal to this value.' )
3299
+ 'must have query_dim less or equal than 0 or equal to this value.' )
3308
3300
p .Define (
3309
3301
'primary_source_key' , 'source_0' , 'Key for the primary source '
3310
3302
'whose attention probabilities will be used as an output.' )
0 commit comments