@@ -64,19 +64,13 @@ def multi_head_attention_forward(
) -> Tuple[Tensor, Optional[Tensor]]:

    is_batched = _mha_shape_check(query, key, value, key_padding_mask, attn_mask, num_heads)
-
-    # For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input
-    # is batched, run the computation and before returning squeeze the
-    # batch dimension so that the output doesn't carry this temporary batch dimension.
    if not is_batched:
-        # unsqueeze if the input is unbatched
        query = query.unsqueeze(1)
        key = key.unsqueeze(1)
        value = value.unsqueeze(1)
        if key_padding_mask is not None:
            key_padding_mask = key_padding_mask.unsqueeze(0)

-    # set up shape vars
    tgt_len, bsz, embed_dim = query.shape
    src_len, _, _ = key.shape

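The comments dropped above describe the unbatched-input convention: a 2-D `(L, E)` query is lifted to `(L, 1, E)`, processed as a batch of one, and squeezed back on the way out. A minimal sketch of that round trip, with hypothetical sizes chosen only for illustration:

```python
import torch

tgt_len, embed_dim = 5, 16                 # hypothetical sizes
query = torch.randn(tgt_len, embed_dim)    # unbatched input: (L, E)

q = query.unsqueeze(1)                     # pretend-batched: (L, 1, E)
# ... attention math runs on the batched shape ...
out = q                                    # stand-in for the attention output, still (L, 1, E)
out = out.squeeze(1)                       # squeezed back to the unbatched shape (L, E)
assert out.shape == (tgt_len, embed_dim)
```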
@@ -96,9 +90,6 @@ def multi_head_attention_forward(
    )

    if is_causal and key_padding_mask is None and not need_weights:
-        # when we have a kpm or need weights, we need attn_mask
-        # Otherwise, we use the is_causal hint go as is_causal
-        # indicator to SDPA.
        attn_mask = None
    else:
        attn_mask = _canonical_mask(
@@ -111,31 +102,23 @@ def multi_head_attention_forward(
        )

        if key_padding_mask is not None:
-            # We have the attn_mask, and use that to merge kpm into it.
-            # Turn off use of is_causal hint, as the merged mask is no
-            # longer causal.
            is_causal = False

    assert (
        embed_dim == embed_dim_to_check
    ), f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
    if isinstance(embed_dim, torch.Tensor):
-        # embed_dim can be a tensor when JIT tracing
        head_dim = embed_dim.div(num_heads, rounding_mode="trunc")
    else:
        head_dim = embed_dim // num_heads
    assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
    if use_separate_proj_weight:
-        # allow MHA to have different embedding dimensions when separate projection weights are used
        assert (
            key.shape[:2] == value.shape[:2]
        ), f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
    else:
        assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"

-    #
-    # compute in-projection
-    #
    if not use_separate_proj_weight:
        assert in_proj_weight is not None, "use_separate_proj_weight is False but in_proj_weight is None"
        q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
@@ -159,10 +142,7 @@ def multi_head_attention_forward(
            b_v,
        )

-    # prep attention mask
-
    if attn_mask is not None:
-        # ensure attn_mask's dim is 3
        if attn_mask.dim() == 2:
            correct_2d_size = (tgt_len, src_len)
            if attn_mask.shape != correct_2d_size:
@@ -179,7 +159,6 @@ def multi_head_attention_forward(
        else:
            raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")

-    # add bias along batch dimension (currently second)
    if bias_k is not None and bias_v is not None:
        assert static_k is None, "bias cannot be added to static key."
        assert static_v is None, "bias cannot be added to static value."
@@ -193,14 +172,10 @@ def multi_head_attention_forward(
        assert bias_k is None
        assert bias_v is None

-    #
-    # reshape q, k, v for multihead attention and make them batch first
-    #
    q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    if static_k is None:
        k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
    else:
-        # TODO finish disentangling control flow so we don't do in-projections when statics are passed
        assert (
            static_k.size(0) == bsz * num_heads
        ), f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
@@ -209,14 +184,12 @@ def multi_head_attention_forward(
    if static_v is None:
        v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
    else:
-        # TODO finish disentangling control flow so we don't do in-projections when statics are passed
        assert (
            static_v.size(0) == bsz * num_heads
        ), f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
        assert static_v.size(2) == head_dim, f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
        v = static_v

-    # add zero attention along batch dimension (now first)
    if add_zero_attn:
        zero_attn_shape = (bsz * num_heads, 1, head_dim)
        k = torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1)
@@ -226,10 +199,8 @@ def multi_head_attention_forward(
        if key_padding_mask is not None:
            key_padding_mask = pad(key_padding_mask, (0, 1))

-    # update source sequence length after adjustments
    src_len = k.size(1)

-    # merge key padding and attention masks
    if key_padding_mask is not None:
        if not torch.jit.is_scripting() and not torch.jit.is_tracing():
            _check_key_padding_mask(key_padding_mask, src_len, bsz)
@@ -242,16 +213,11 @@ def multi_head_attention_forward(
        else:
            attn_mask = attn_mask + key_padding_mask

-    # adjust dropout probability
    if not training:
        dropout_p = 0.0

-    #
-    # (deep breath) calculate attention and out projection
-    #
-
    if need_weights:
-        _B, _Nt, E = q.shape
+        _B, _Nt, E = q.shape  # noqa: F841
        q_scaled = q * math.sqrt(1.0 / float(E))

        assert not (is_causal and attn_mask is None), "FIXME: is_causal not implemented for need_weights"
@@ -270,20 +236,15 @@ def multi_head_attention_forward(
        attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
        attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))

-        # optionally average attention weights over heads
        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
        if average_attn_weights:
            attn_output_weights = attn_output_weights.mean(dim=1)

        if not is_batched:
-            # squeeze the output if input was unbatched
            attn_output = attn_output.squeeze(1)
            attn_output_weights = attn_output_weights.squeeze(0)
        return attn_output, attn_output_weights
    else:
-        # attn_mask can be either (L,S) or (N*num_heads, L, S)
-        # if attn_mask's shape is (1, L, S) we need to unsqueeze to (1, 1, L, S)
-        # in order to match the input for SDPA of (N, num_heads, L, S)
        if attn_mask is not None:
            if attn_mask.size(0) == 1 and attn_mask.dim() == 3:
                attn_mask = attn_mask.unsqueeze(0)
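The deleted comments in this hunk note that `scaled_dot_product_attention` expects a mask that broadcasts against `(N, num_heads, L, S)`, which is why a `(1, L, S)` mask picks up an extra leading dimension. A rough sketch of that broadcast, calling `torch.nn.functional.scaled_dot_product_attention` directly with hypothetical sizes:

```python
import torch
import torch.nn.functional as F

bsz, num_heads, tgt_len, src_len, head_dim = 2, 4, 5, 5, 8   # hypothetical sizes
q = torch.randn(bsz, num_heads, tgt_len, head_dim)
k = torch.randn(bsz, num_heads, src_len, head_dim)
v = torch.randn(bsz, num_heads, src_len, head_dim)

attn_mask = torch.zeros(1, tgt_len, src_len)   # (1, L, S), additive float mask
attn_mask = attn_mask.unsqueeze(0)             # (1, 1, L, S) broadcasts over (N, num_heads, L, S)

out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
assert out.shape == (bsz, num_heads, tgt_len, head_dim)
```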
@@ -300,7 +261,6 @@ def multi_head_attention_forward(
        attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
        attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
        if not is_batched:
-            # squeeze the output if input was unbatched
            attn_output = attn_output.squeeze(1)
        return attn_output, None
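For reference, a minimal sketch of exercising both return paths through the public `nn.MultiheadAttention` wrapper, which routes into `multi_head_attention_forward`; the sizes below are hypothetical:

```python
import torch
import torch.nn as nn

embed_dim, num_heads, tgt_len, bsz = 32, 4, 5, 2     # hypothetical sizes
mha = nn.MultiheadAttention(embed_dim, num_heads)    # expects (L, N, E) by default

query = torch.randn(tgt_len, bsz, embed_dim)
key = value = torch.randn(tgt_len, bsz, embed_dim)

# need_weights=True takes the explicit softmax path and returns head-averaged weights
out, weights = mha(query, key, value, need_weights=True)
assert out.shape == (tgt_len, bsz, embed_dim)
assert weights.shape == (bsz, tgt_len, tgt_len)

# need_weights=False takes the scaled_dot_product_attention path and returns no weights
out, weights = mha(query, key, value, need_weights=False)
assert weights is None
```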