revert to use F.scaled_dot_product_attention in cross attention
Signed-off-by: yan ma <[email protected]>
yma11 committed Feb 25, 2025
1 parent e11d833 commit 5731b5a
Showing 1 changed file with 8 additions and 15 deletions.
23 changes: 8 additions & 15 deletions vllm/model_executor/models/mllama.py
@@ -949,21 +949,14 @@ def _attention_with_mask(
             kv_len,
             self.head_dim).contiguous()
         attention_mask = attention_mask.view(1, 1, q_len, kv_len)
-        if current_platform.is_hpu():
-            from habana_frameworks.torch.hpex.kernels import FusedSDPA
-            output = FusedSDPA.apply(q, k, v, attention_mask)
-            output = output.permute(2, 0, 1, 3).reshape(
-                q_len, self.num_local_heads * self.head_dim)
-            return output
-        else:
-            output = F.scaled_dot_product_attention(q,
-                                                    k,
-                                                    v,
-                                                    attn_mask=attention_mask,
-                                                    is_causal=False)
-            output = output.permute(2, 0, 1, 3).reshape(
-                q_len, self.num_local_heads * self.head_dim)
-            return output
+        output = F.scaled_dot_product_attention(q,
+                                                k,
+                                                v,
+                                                attn_mask=attention_mask,
+                                                is_causal=False)
+        output = output.permute(2, 0, 1, 3).reshape(
+            q_len, self.num_local_heads * self.head_dim)
+        return output
 
 
 class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
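
For reference, the sketch below is a minimal standalone illustration (not the vLLM code path itself) of what the retained branch computes: F.scaled_dot_product_attention with an explicit attn_mask and is_causal=False, followed by the same permute/reshape back to (q_len, num_heads * head_dim). All sizes here (num_kv_heads, num_groups, q_len, kv_len, head_dim) are assumed for illustration, not values from mllama.py.

# Standalone sketch only -- shapes are assumed, not taken from mllama.py.
import torch
import torch.nn.functional as F

num_kv_heads, num_groups, head_dim = 2, 4, 64   # assumed sizes
q_len, kv_len = 5, 17
num_heads = num_kv_heads * num_groups

# q: (num_kv_heads, num_groups, q_len, head_dim)
# k, v: expanded per query group to (num_kv_heads, num_groups, kv_len, head_dim)
q = torch.randn(num_kv_heads, num_groups, q_len, head_dim)
k = torch.randn(num_kv_heads, 1, kv_len, head_dim).expand(
    num_kv_heads, num_groups, kv_len, head_dim).contiguous()
v = torch.randn(num_kv_heads, 1, kv_len, head_dim).expand(
    num_kv_heads, num_groups, kv_len, head_dim).contiguous()

# Additive float mask, broadcast over all heads; -inf blocks a key position.
attention_mask = torch.zeros(1, 1, q_len, kv_len)
attention_mask[..., kv_len // 2:] = float("-inf")

output = F.scaled_dot_product_attention(q,
                                        k,
                                        v,
                                        attn_mask=attention_mask,
                                        is_causal=False)
# Fold heads back into the hidden dimension, as the diff does.
output = output.permute(2, 0, 1, 3).reshape(q_len, num_heads * head_dim)
print(output.shape)  # torch.Size([5, 512])

The removed HPU branch called FusedSDPA.apply(q, k, v, attention_mask) on the same tensor layout, which is why both branches ended with the identical permute/reshape.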
