@@ -63,9 +63,7 @@ def multi_head_attention_forward(
     is_causal: bool = False,
 ) -> Tuple[Tensor, Optional[Tensor]]:
 
-    is_batched = _mha_shape_check(
-        query, key, value, key_padding_mask, attn_mask, num_heads
-    )
+    is_batched = _mha_shape_check(query, key, value, key_padding_mask, attn_mask, num_heads)
 
     # For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input
     # is batched, run the computation and before returning squeeze the
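A minimal sketch (not part of the diff) of the unbatched path that the comment above describes, with invented shapes: the (L, E) input gets a fake batch dimension, the batched computation runs, and the batch dimension is squeezed off before returning.

import torch

query = torch.randn(5, 16)            # (tgt_len, embed_dim), unbatched input
query = query.unsqueeze(1)            # (tgt_len, 1, embed_dim): pretend batch size 1
# ... run the batched attention computation here ...
attn_output = query                   # stand-in for the batched result
attn_output = attn_output.squeeze(1)  # back to (tgt_len, embed_dim) before returning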
@@ -126,37 +124,25 @@ def multi_head_attention_forward(
         head_dim = embed_dim.div(num_heads, rounding_mode="trunc")
     else:
         head_dim = embed_dim // num_heads
-    assert (
-        head_dim * num_heads == embed_dim
-    ), f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
+    assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
     if use_separate_proj_weight:
         # allow MHA to have different embedding dimensions when separate projection weights are used
         assert (
             key.shape[:2] == value.shape[:2]
         ), f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
     else:
-        assert (
-            key.shape == value.shape
-        ), f"key shape {key.shape} does not match value shape {value.shape}"
+        assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"
 
     #
     # compute in-projection
     #
     if not use_separate_proj_weight:
-        assert (
-            in_proj_weight is not None
-        ), "use_separate_proj_weight is False but in_proj_weight is None"
+        assert in_proj_weight is not None, "use_separate_proj_weight is False but in_proj_weight is None"
         q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
     else:
-        assert (
-            q_proj_weight is not None
-        ), "use_separate_proj_weight is True but q_proj_weight is None"
-        assert (
-            k_proj_weight is not None
-        ), "use_separate_proj_weight is True but k_proj_weight is None"
-        assert (
-            v_proj_weight is not None
-        ), "use_separate_proj_weight is True but v_proj_weight is None"
+        assert q_proj_weight is not None, "use_separate_proj_weight is True but q_proj_weight is None"
+        assert k_proj_weight is not None, "use_separate_proj_weight is True but k_proj_weight is None"
+        assert v_proj_weight is not None, "use_separate_proj_weight is True but v_proj_weight is None"
     if in_proj_bias is None:
         b_q = b_k = b_v = None
     else:
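For context, a small sketch (not part of the diff) of what the packed in-projection amounts to in the self-attention case: a single (3*E, E) weight is applied once and the result is split into q, k, v. This is a simplified stand-in for _in_projection_packed, not its exact implementation; sizes are invented.

import torch
import torch.nn.functional as F

embed_dim, num_heads = 16, 4
head_dim = embed_dim // num_heads
assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"

x = torch.randn(5, 2, embed_dim)                        # (seq_len, bsz, embed_dim)
in_proj_weight = torch.randn(3 * embed_dim, embed_dim)  # packed q/k/v projection
in_proj_bias = torch.zeros(3 * embed_dim)
q, k, v = F.linear(x, in_proj_weight, in_proj_bias).chunk(3, dim=-1)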
@@ -191,9 +177,7 @@ def multi_head_attention_forward(
                 f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}."
             )
         else:
-            raise RuntimeError(
-                f"attn_mask's dimension {attn_mask.dim()} is not supported"
-            )
+            raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")
 
     # add bias along batch dimension (currently second)
     if bias_k is not None and bias_v is not None:
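A short sketch (not part of the diff) of the mask shapes these checks accept, with made-up sizes: a 2D mask is broadcast over every batch-head slice, while a 3D mask must already match (bsz * num_heads, tgt_len, src_len); any other rank hits the RuntimeError above.

import torch

bsz, num_heads, tgt_len, src_len = 2, 4, 5, 7
mask_2d = torch.zeros(tgt_len, src_len, dtype=torch.bool)                   # broadcast to all heads
mask_3d = torch.zeros(bsz * num_heads, tgt_len, src_len, dtype=torch.bool)  # per-head mask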
@@ -220,9 +204,7 @@ def multi_head_attention_forward(
         assert (
             static_k.size(0) == bsz * num_heads
         ), f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
-        assert (
-            static_k.size(2) == head_dim
-        ), f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
+        assert static_k.size(2) == head_dim, f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
         k = static_k
     if static_v is None:
         v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
@@ -231,20 +213,14 @@ def multi_head_attention_forward(
         assert (
             static_v.size(0) == bsz * num_heads
         ), f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
-        assert (
-            static_v.size(2) == head_dim
-        ), f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
+        assert static_v.size(2) == head_dim, f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
         v = static_v
 
     # add zero attention along batch dimension (now first)
     if add_zero_attn:
         zero_attn_shape = (bsz * num_heads, 1, head_dim)
-        k = torch.cat(
-            [k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1
-        )
-        v = torch.cat(
-            [v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1
-        )
+        k = torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1)
+        v = torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1)
         if attn_mask is not None:
             attn_mask = pad(attn_mask, (0, 1))
         if key_padding_mask is not None:
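A brief sketch (not part of the diff) of the add_zero_attn branch with toy sizes: one all-zero key/value slot is appended per batch-head row, and the masks are padded by one column so the extra slot stays attendable.

import torch
from torch.nn.functional import pad

bsz_x_heads, tgt_len, src_len, head_dim = 8, 5, 7, 4
k = torch.randn(bsz_x_heads, src_len, head_dim)
v = torch.randn(bsz_x_heads, src_len, head_dim)
attn_mask = torch.zeros(bsz_x_heads, tgt_len, src_len)

zero_attn_shape = (bsz_x_heads, 1, head_dim)
k = torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1)
v = torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1)
attn_mask = pad(attn_mask, (0, 1))   # last dim grows from src_len to src_len + 1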
@@ -259,9 +235,7 @@ def multi_head_attention_forward(
             _check_key_padding_mask(key_padding_mask, src_len, bsz)
 
         key_padding_mask = (
-            key_padding_mask.view(bsz, 1, 1, src_len)
-            .expand(-1, num_heads, -1, -1)
-            .reshape(bsz * num_heads, 1, src_len)
+            key_padding_mask.view(bsz, 1, 1, src_len).expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
         )
         if attn_mask is None:
             attn_mask = key_padding_mask
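A small sketch (not part of the diff) of the padding-mask expansion above, with toy sizes: a (bsz, src_len) mask becomes one row per attention head, matching the (bsz * num_heads, 1, src_len) layout the bmm-based attention path expects.

import torch

bsz, num_heads, src_len = 2, 4, 7
key_padding_mask = torch.zeros(bsz, src_len, dtype=torch.bool)
key_padding_mask = (
    key_padding_mask.view(bsz, 1, 1, src_len)
    .expand(-1, num_heads, -1, -1)
    .reshape(bsz * num_heads, 1, src_len)
)
assert key_padding_mask.shape == (bsz * num_heads, 1, src_len)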
@@ -280,14 +254,10 @@ def multi_head_attention_forward(
         _B, _Nt, E = q.shape
         q_scaled = q * math.sqrt(1.0 / float(E))
 
-        assert not (
-            is_causal and attn_mask is None
-        ), "FIXME: is_causal not implemented for need_weights"
+        assert not (is_causal and attn_mask is None), "FIXME: is_causal not implemented for need_weights"
 
         if attn_mask is not None:
-            attn_output_weights = torch.baddbmm(
-                attn_mask, q_scaled, k.transpose(-2, -1)
-            )
+            attn_output_weights = torch.baddbmm(attn_mask, q_scaled, k.transpose(-2, -1))
         else:
             attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
         attn_output_weights = softmax(attn_output_weights, dim=-1)
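A quick check (not part of the diff), on toy tensors, that the baddbmm form above is the usual masked score computation: with the default alpha and beta, baddbmm(mask, q_scaled, k^T) equals mask + q_scaled @ k^T, which the softmax line then turns into softmax(Q K^T / sqrt(E) + mask).

import math
import torch

B, Nt, Ns, E = 8, 5, 7, 4
q, k = torch.randn(B, Nt, E), torch.randn(B, Ns, E)
attn_mask = torch.zeros(B, Nt, Ns)

q_scaled = q * math.sqrt(1.0 / float(E))
w_baddbmm = torch.baddbmm(attn_mask, q_scaled, k.transpose(-2, -1))
w_manual = attn_mask + torch.bmm(q_scaled, k.transpose(-2, -1))
torch.testing.assert_close(w_baddbmm, w_manual)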
@@ -296,9 +266,7 @@ def multi_head_attention_forward(
 
         attn_output = torch.bmm(attn_output_weights, v)
 
-        attn_output = (
-            attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
-        )
+        attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
         attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
         attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
 
@@ -326,12 +294,8 @@ def multi_head_attention_forward(
         k = k.view(bsz, num_heads, src_len, head_dim)
         v = v.view(bsz, num_heads, src_len, head_dim)
 
-        attn_output = scaled_dot_product_attention(
-            q, k, v, attn_mask, dropout_p, is_causal
-        )
-        attn_output = (
-            attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
-        )
+        attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
+        attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
 
         attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
         attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
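A final sketch (not part of the diff) of the no-weights fast path with toy sizes: q, k, v are viewed as (bsz, num_heads, seq, head_dim), scaled_dot_product_attention runs, and the heads are flattened back into embed_dim before the output projection.

import torch
from torch.nn.functional import scaled_dot_product_attention

bsz, num_heads, tgt_len, src_len, head_dim = 2, 4, 5, 7, 8
embed_dim = num_heads * head_dim
q = torch.randn(bsz, num_heads, tgt_len, head_dim)
k = torch.randn(bsz, num_heads, src_len, head_dim)
v = torch.randn(bsz, num_heads, src_len, head_dim)

attn_output = scaled_dot_product_attention(q, k, v, None, 0.0, False)
attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
assert attn_output.shape == (bsz * tgt_len, embed_dim)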