@@ -63,9 +63,7 @@ def multi_head_attention_forward(
     is_causal: bool = False,
 ) -> Tuple[Tensor, Optional[Tensor]]:

-    is_batched = _mha_shape_check(
-        query, key, value, key_padding_mask, attn_mask, num_heads
-    )
+    is_batched = _mha_shape_check(query, key, value, key_padding_mask, attn_mask, num_heads)

     # For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input
     # is batched, run the computation and before returning squeeze the
@@ -126,37 +124,25 @@ def multi_head_attention_forward(
         head_dim = embed_dim.div(num_heads, rounding_mode="trunc")
     else:
         head_dim = embed_dim // num_heads
-    assert (
-        head_dim * num_heads == embed_dim
-    ), f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
+    assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
     if use_separate_proj_weight:
         # allow MHA to have different embedding dimensions when separate projection weights are used
         assert (
             key.shape[:2] == value.shape[:2]
         ), f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
     else:
-        assert (
-            key.shape == value.shape
-        ), f"key shape {key.shape} does not match value shape {value.shape}"
+        assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"

     #
     # compute in-projection
     #
     if not use_separate_proj_weight:
-        assert (
-            in_proj_weight is not None
-        ), "use_separate_proj_weight is False but in_proj_weight is None"
+        assert in_proj_weight is not None, "use_separate_proj_weight is False but in_proj_weight is None"
         q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
     else:
-        assert (
-            q_proj_weight is not None
-        ), "use_separate_proj_weight is True but q_proj_weight is None"
-        assert (
-            k_proj_weight is not None
-        ), "use_separate_proj_weight is True but k_proj_weight is None"
-        assert (
-            v_proj_weight is not None
-        ), "use_separate_proj_weight is True but v_proj_weight is None"
+        assert q_proj_weight is not None, "use_separate_proj_weight is True but q_proj_weight is None"
+        assert k_proj_weight is not None, "use_separate_proj_weight is True but k_proj_weight is None"
+        assert v_proj_weight is not None, "use_separate_proj_weight is True but v_proj_weight is None"
         if in_proj_bias is None:
             b_q = b_k = b_v = None
         else:
@@ -191,9 +177,7 @@ def multi_head_attention_forward(
                     f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}."
                 )
         else:
-            raise RuntimeError(
-                f"attn_mask's dimension {attn_mask.dim()} is not supported"
-            )
+            raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")

     # add bias along batch dimension (currently second)
     if bias_k is not None and bias_v is not None:
@@ -220,9 +204,7 @@ def multi_head_attention_forward(
         assert (
             static_k.size(0) == bsz * num_heads
         ), f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
-        assert (
-            static_k.size(2) == head_dim
-        ), f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
+        assert static_k.size(2) == head_dim, f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
         k = static_k
     if static_v is None:
         v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
@@ -231,20 +213,14 @@ def multi_head_attention_forward(
         assert (
             static_v.size(0) == bsz * num_heads
         ), f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
-        assert (
-            static_v.size(2) == head_dim
-        ), f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
+        assert static_v.size(2) == head_dim, f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
         v = static_v

     # add zero attention along batch dimension (now first)
     if add_zero_attn:
         zero_attn_shape = (bsz * num_heads, 1, head_dim)
-        k = torch.cat(
-            [k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1
-        )
-        v = torch.cat(
-            [v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1
-        )
+        k = torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1)
+        v = torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1)
         if attn_mask is not None:
             attn_mask = pad(attn_mask, (0, 1))
         if key_padding_mask is not None:
@@ -259,9 +235,7 @@ def multi_head_attention_forward(
             _check_key_padding_mask(key_padding_mask, src_len, bsz)

         key_padding_mask = (
-            key_padding_mask.view(bsz, 1, 1, src_len)
-            .expand(-1, num_heads, -1, -1)
-            .reshape(bsz * num_heads, 1, src_len)
+            key_padding_mask.view(bsz, 1, 1, src_len).expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
         )
         if attn_mask is None:
             attn_mask = key_padding_mask
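
For reference, the chained `view`/`expand`/`reshape` that this hunk joins onto one line broadcasts a per-batch padding mask to one row per attention head. A minimal sketch with invented sizes (not taken from this diff):

```python
import torch

# Hypothetical sizes, only to illustrate the broadcast above.
bsz, num_heads, src_len = 2, 4, 5
key_padding_mask = torch.zeros(bsz, src_len, dtype=torch.bool)

expanded = (
    key_padding_mask.view(bsz, 1, 1, src_len)
    .expand(-1, num_heads, -1, -1)          # repeat the mask for every head (a view, no copy yet)
    .reshape(bsz * num_heads, 1, src_len)   # flatten to one mask row per (batch, head) pair
)
print(expanded.shape)  # torch.Size([8, 1, 5])
```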
@@ -280,14 +254,10 @@ def multi_head_attention_forward(
         _B, _Nt, E = q.shape
         q_scaled = q * math.sqrt(1.0 / float(E))

-        assert not (
-            is_causal and attn_mask is None
-        ), "FIXME: is_causal not implemented for need_weights"
+        assert not (is_causal and attn_mask is None), "FIXME: is_causal not implemented for need_weights"

         if attn_mask is not None:
-            attn_output_weights = torch.baddbmm(
-                attn_mask, q_scaled, k.transpose(-2, -1)
-            )
+            attn_output_weights = torch.baddbmm(attn_mask, q_scaled, k.transpose(-2, -1))
         else:
             attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
         attn_output_weights = softmax(attn_output_weights, dim=-1)
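
The `torch.baddbmm` call collapsed above fuses the additive mask with the score matmul; it should be numerically the same as adding the mask after a plain `torch.bmm`. A small sketch with arbitrary illustrative shapes:

```python
import torch

# Arbitrary shapes: (batch*heads, tgt_len, head_dim) queries/keys plus a matching float mask.
B, L, S, E = 8, 3, 5, 16
q_scaled = torch.randn(B, L, E) / E ** 0.5
k = torch.randn(B, S, E)
attn_mask = torch.randn(B, L, S)

fused = torch.baddbmm(attn_mask, q_scaled, k.transpose(-2, -1))   # attn_mask + q_scaled @ k^T in one call
unfused = attn_mask + torch.bmm(q_scaled, k.transpose(-2, -1))
print(torch.allclose(fused, unfused, atol=1e-5))  # True
```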
@@ -296,9 +266,7 @@ def multi_head_attention_forward(

         attn_output = torch.bmm(attn_output_weights, v)

-        attn_output = (
-            attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
-        )
+        attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
         attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
         attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))

@@ -326,12 +294,8 @@ def multi_head_attention_forward(
         k = k.view(bsz, num_heads, src_len, head_dim)
         v = v.view(bsz, num_heads, src_len, head_dim)

-        attn_output = scaled_dot_product_attention(
-            q, k, v, attn_mask, dropout_p, is_causal
-        )
-        attn_output = (
-            attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
-        )
+        attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
+        attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)

         attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
         attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
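
For context on the last hunk, the reformatted SDPA path takes per-head tensors of shape `(bsz, num_heads, len, head_dim)` and flattens the result back to `(tgt_len * bsz, embed_dim)` before the output projection. A minimal standalone sketch with invented sizes:

```python
import torch
from torch.nn.functional import scaled_dot_product_attention

# Invented sizes, just to trace the shapes through the reformatted lines.
bsz, num_heads, tgt_len, src_len, head_dim = 2, 4, 3, 5, 8
embed_dim = num_heads * head_dim

q = torch.randn(bsz, num_heads, tgt_len, head_dim)
k = torch.randn(bsz, num_heads, src_len, head_dim)
v = torch.randn(bsz, num_heads, src_len, head_dim)

attn_output = scaled_dot_product_attention(q, k, v, None, 0.0, False)  # (bsz, num_heads, tgt_len, head_dim)
attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
print(attn_output.shape)  # torch.Size([6, 32]), ready for the output projection
```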