Skip to content

Commit

Permalink
use fast as softmax_mode for better accuracy and perf
Browse files Browse the repository at this point in the history
Signed-off-by: yan ma <[email protected]>
  • Loading branch information
yma11 committed Mar 4, 2025
1 parent ca9d699 commit 51ce85e
Showing 1 changed file with 13 additions and 1 deletion.
14 changes: 13 additions & 1 deletion vllm/model_executor/models/mllama.py
Original file line number Diff line number Diff line change
Expand Up @@ -951,7 +951,19 @@ def _attention_with_mask(
attention_mask = attention_mask.view(1, 1, q_len, kv_len)
if current_platform.is_hpu():
from habana_frameworks.torch.hpex.kernels import FusedSDPA
output = FusedSDPA.apply(q, k, v, attention_mask)
from vllm_hpu_extension.utils import ModuleFusedSDPA
fsdpa_op = ModuleFusedSDPA(FusedSDPA)
# use fp32 as softmax_mode for better accuracy
output = fsdpa_op(q,
k,
v,
attention_mask,
dropout_p=0.0,
is_causal=False,
scale=None,
softmax_mode="fast",
recompute_mode=None,
valid_sequence_lengths=None)
output = output.permute(2, 0, 1, 3).reshape(
q_len, self.num_local_heads * self.head_dim)
return output
Expand Down

0 comments on commit 51ce85e

Please sign in to comment.