@@ -60,6 +60,7 @@ def get_config(self):
         }

     def build(self, inputs_shape):
+
         # Transformation for linearly projecting the queries, keys, and values.
         self.q_transformation = self._get_weights(
             "q_project", shape=(self.hidden_size, self.hidden_size), init=tf.initializers.get('glorot_uniform')
@@ -75,20 +76,7 @@ def build(self, inputs_shape):
         )

     def split_heads(self, x):
-        """Split x into different heads, and transpose the resulting value.
-
-        The tensor is transposed to insure the inner dimensions hold the correct
-        values during the matrix multiplication.

-        Parameters
-        -----------
-
-        x: A tensor with shape [batch_size, length, hidden_size]
-
-        Returns:
-        -----------
-        A tensor with shape [batch_size, num_heads, length, hidden_size/num_heads]
-        """
         with tf.name_scope("split_heads"):
             batch_size = tf.shape(x)[0]
             length = tf.shape(x)[1]
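As a standalone illustration of what the hunks above set up (the glorot-initialised query projection created in build(), with the k and v projections presumably built the same way, and the head split whose docstring was trimmed), here is a plain-TensorFlow sketch. The sizes and variable names are assumptions for the example, not the layer's own _get_weights API:

import tensorflow as tf

batch_size, length, num_heads, hidden_size = 2, 6, 8, 512   # illustrative sizes
depth = hidden_size // num_heads                             # per-head width

# A square projection weight, analogous to the "q_project" weight in build().
init = tf.initializers.get('glorot_uniform')
q_project = tf.Variable(init(shape=(hidden_size, hidden_size)), name="q_project")

x = tf.random.normal([batch_size, length, hidden_size])      # [batch, length, hidden_size]
q = tf.einsum('bld,dh->blh', x, q_project)                   # per-position linear projection

# split_heads: [batch, length, hidden_size] -> [batch, num_heads, length, depth]
q = tf.reshape(q, [batch_size, length, num_heads, depth])
q = tf.transpose(q, [0, 2, 1, 3])
print(q.shape)                                               # (2, 8, 6, 64)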
@@ -103,40 +91,15 @@ def split_heads(self, x):
             return tf.transpose(x, [0, 2, 1, 3])

     def combine_heads(self, x):
-        """Combine tensor that has been split.
-
-        Args:
-        x: A tensor [batch_size, num_heads, length, hidden_size/num_heads]

-        Returns:
-        -----------
-        A tensor with shape [batch_size, length, hidden_size]
-        """
         with tf.name_scope("combine_heads"):
             batch_size = tf.shape(x)[0]
             length = tf.shape(x)[2]
             x = tf.transpose(x, [0, 2, 1, 3])  # --> [batch, length, num_heads, depth]
             return tf.reshape(x, [batch_size, length, self.hidden_size])

     def forward(self, x, y, mask, cache=None):
-        """Apply attention mechanism to x and y.
-
-        Args:
-        x: a tensor with shape [batch_size, length_x, hidden_size]
-        y: a tensor with shape [batch_size, length_y, hidden_size]
-        mask: attention bias that will be added to the result of the dot product.
-        training: boolean, whether in training mode or not.
-        cache: (Used during prediction) dictionary with tensors containing results
-            of previous attentions. The dictionary must have the items:
-            {"k": tensor with shape [batch_size, i, key_channels],
-             "v": tensor with shape [batch_size, i, value_channels]}
-            where i is the current decoded length.
-
-        Returns:
-        -----------
-        Attention layer output with shape [batch_size, length_x, hidden_size]
-        Attention weights with shape [batch_size, number_of_head, length_x, length_y]
-        """
+        """Apply attention mechanism to x and y."""
         # Linearly project the query (q), key (k) and value (v) using different
         # learned projections. This is in preparation of splitting them into
         # multiple heads. Multi-head attention uses multiple queries, keys, and
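The removed forward docstring spelled out the contract: project q/k/v, split heads, apply masked scaled dot-product attention, recombine heads, and return the output together with the per-head attention weights. The sketch below reproduces that flow on already-projected inputs; the function and helper names are assumptions for illustration, and the decoding cache described in the removed docstring is omitted:

import tensorflow as tf

def attention_sketch(q, k, v, mask, num_heads):
    """Masked scaled dot-product attention over already-projected q, k, v (illustrative)."""
    hidden_size = q.shape[-1]
    depth = hidden_size // num_heads

    def split(t):                                    # [batch, len, hidden] -> [batch, heads, len, depth]
        b, l = tf.shape(t)[0], tf.shape(t)[1]
        return tf.transpose(tf.reshape(t, [b, l, num_heads, depth]), [0, 2, 1, 3])

    q, k, v = split(q), split(k), split(v)
    q *= depth ** -0.5                               # scale queries before the dot product
    logits = tf.matmul(q, k, transpose_b=True)       # [batch, heads, len_q, len_k]
    logits += mask                                   # additive attention bias, broadcast over heads
    weights = tf.nn.softmax(logits, axis=-1)
    out = tf.matmul(weights, v)                      # [batch, heads, len_q, depth]

    # combine_heads: [batch, heads, len_q, depth] -> [batch, len_q, hidden_size]
    b, l = tf.shape(out)[0], tf.shape(out)[2]
    out = tf.reshape(tf.transpose(out, [0, 2, 1, 3]), [b, l, hidden_size])
    return out, weights                              # output plus attention weights, as in the removed docstring

# Self-attention usage with a no-op mask (illustrative shapes):
q = k = v = tf.random.normal([2, 6, 512])
out, w = attention_sketch(q, k, v, mask=tf.zeros([2, 1, 1, 6]), num_heads=8)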