Commit cd75438

Enable SDPA for torch>=2.3.0 since it now supports lower-right masking
1 parent 4af022a commit cd75438

File tree: 1 file changed, +42 -63 lines

exllamav2/attn.py

Lines changed: 42 additions & 63 deletions
@@ -14,13 +14,13 @@
 from exllamav2.architecture import RopeStyle
 import math
 # from exllamav2.util import list_live_tensors, set_snapshot, diff_snapshot, print_vram_usage_peak
-# import torch.nn.functional as F
+import torch.nn.functional as F

 from typing import TYPE_CHECKING
 if TYPE_CHECKING:
     from exllamav2.model import ExLlamaV2

-# Detect flash-attn
+# Detect available options for attention

 has_flash_attn = False
 has_flash_attn_with_paged = False
@@ -54,6 +54,12 @@
 except ModuleNotFoundError:
     pass

+has_lower_right_sdpa = False
+try:
+    from torch.nn.attention.bias import causal_lower_right
+    has_lower_right_sdpa = True
+except ImportError:
+    pass

 def assert_paged_attn():
     global has_flash_attn_with_paged
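
The new detection block doubles as a version gate: torch.nn.attention.bias.causal_lower_right only exists in PyTorch 2.3.0+, so a failed import simply leaves has_lower_right_sdpa as False and the matmul fallback in place. A minimal standalone sketch of the same probe, with illustrative shapes that are not taken from the diff:

import torch
import torch.nn.functional as F

# Probe for lower-right causal bias support (ships with torch >= 2.3.0).
has_lower_right_sdpa = False
try:
    from torch.nn.attention.bias import causal_lower_right
    has_lower_right_sdpa = True
except ImportError:
    pass

if has_lower_right_sdpa:
    # Illustrative: 1 batch, 2 heads, 4 new queries over 10 cached key/value positions.
    q = torch.randn(1, 2, 4, 64)
    k = torch.randn(1, 2, 10, 64)
    v = torch.randn(1, 2, 10, 64)
    bias = causal_lower_right(4, 10)   # causal mask aligned to the bottom-right corner
    out = F.scaled_dot_product_attention(q, k, v, bias)
    print(out.shape)                   # torch.Size([1, 2, 4, 64])
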
@@ -666,24 +672,38 @@ def forward_paged(self,
         return hidden_states


-    def _attn_matmul(self, batch_size, q_len, q_states, k_states, v_states, attn_params, cfg):
+    def _attn_torch(self, batch_size, q_len, q_states, k_states, v_states, attn_params, cfg):

-        q_states = q_states.transpose(1, 2)
-        k_states = k_states.transpose(1, 2)
-        v_states = v_states.transpose(1, 2)
+        if has_lower_right_sdpa and attn_params.is_causal():

-        k_states = self.repeat_kv(k_states, cfg.num_key_value_groups)
-        k_states = k_states.transpose(-1, -2)
+            q_states = q_states.transpose(1, 2)
+            k_states = k_states.transpose(1, 2)
+            v_states = v_states.transpose(1, 2)

-        attn_weights = torch.matmul(q_states, k_states)
+            k_states = self.repeat_kv(k_states, cfg.num_key_value_groups)
+            v_states = self.repeat_kv(v_states, cfg.num_key_value_groups)

-        attn_weights *= 1 / math.sqrt(cfg.head_dim)
-        attn_mask = attn_params.get_attn_mask(attn_weights.device)
-        if attn_mask is not None: attn_weights = attn_weights + attn_mask
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float16)
+            attn_mask_lr = causal_lower_right(q_len, k_states.shape[2])
+            attn_output = F.scaled_dot_product_attention(q_states, k_states, v_states, attn_mask_lr)

-        v_states = self.repeat_kv(v_states, cfg.num_key_value_groups)
-        attn_output = torch.matmul(attn_weights, v_states)
+        else:
+
+            q_states = q_states.transpose(1, 2)
+            k_states = k_states.transpose(1, 2)
+            v_states = v_states.transpose(1, 2)
+
+            k_states = self.repeat_kv(k_states, cfg.num_key_value_groups)
+            k_states = k_states.transpose(-1, -2)
+
+            attn_weights = torch.matmul(q_states, k_states)
+
+            attn_weights *= 1 / math.sqrt(cfg.head_dim)
+            attn_mask = attn_params.get_attn_mask(attn_weights.device)
+            if attn_mask is not None: attn_weights = attn_weights + attn_mask
+            attn_weights = nn.functional.softmax(attn_weights, dim = -1, dtype = torch.float16)
+
+            v_states = self.repeat_kv(v_states, cfg.num_key_value_groups)
+            attn_output = torch.matmul(attn_weights, v_states)

         attn_output = attn_output.transpose(1, 2)
         attn_output = attn_output.reshape((batch_size, q_len, cfg.num_attention_heads * cfg.head_dim))
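
The refactored _attn_torch prefers SDPA with a lower-right-aligned causal bias whenever it is available and the mask is purely causal, and keeps the original matmul/softmax path as the fallback. The lower-right alignment is what makes this correct during cached generation, where only the newest q_len positions are queried against kv_len >= q_len keys. A hedged standalone sketch (shapes and the manual reference below are illustrative, not code from the diff) showing that the two formulations agree:

import math
import torch
import torch.nn.functional as F
from torch.nn.attention.bias import causal_lower_right  # torch >= 2.3.0

q_len, kv_len, heads, head_dim = 3, 8, 2, 16
q = torch.randn(1, heads, q_len, head_dim)
k = torch.randn(1, heads, kv_len, head_dim)
v = torch.randn(1, heads, kv_len, head_dim)

# SDPA with a lower-right causal bias: query row i may attend to keys
# 0 .. (kv_len - q_len + i), so the newest query sees the whole cache.
out_sdpa = F.scaled_dot_product_attention(q, k, v, causal_lower_right(q_len, kv_len))

# Manual reference path: explicit mask, matmul, softmax, matmul.
scores = (q @ k.transpose(-1, -2)) / math.sqrt(head_dim)
mask = torch.ones(q_len, kv_len).tril(diagonal=kv_len - q_len).bool()
scores = scores.masked_fill(~mask, float("-inf"))
out_manual = torch.softmax(scores, dim=-1) @ v

print(torch.allclose(out_sdpa, out_manual, atol=1e-5))  # True, up to numerics
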
@@ -821,7 +841,7 @@ def forward(self,
         elif (has_xformers and not cfg.no_xformers) and attn_params.is_causal():
             attn_func = self._attn_xformers
         else:
-            attn_func = self._attn_matmul
+            attn_func = self._attn_torch

         # Straight attention without cache

@@ -942,59 +962,18 @@ def forward_torch(self,
         use_flash_attn = has_flash_attn and not cfg.no_flash_attn
         use_xformers = has_xformers and not cfg.no_xformers

-        # Torch matmul attention
+        # Select attention function

         if not (use_flash_attn or use_xformers) or not attn_params.is_causal():
-
-            query_states = query_states.transpose(1, 2)
-            key_states = key_states.transpose(1, 2)
-            value_states = value_states.transpose(1, 2)
-
-            key_states = self.repeat_kv(key_states, cfg.num_key_value_groups)
-            key_states = key_states.transpose(-1, -2)
-
-            attn_weights = torch.matmul(query_states, key_states)
-            # attn_weights *= self.scale_factor / math.sqrt(head_dim)
-            # attn_mask = attn_params.get_attn_mask(hidden_states.device)
-            # if self.scale_factor != 1: attn_weights *= self.unscale_factor
-            attn_weights *= 1 / math.sqrt(head_dim)
-            attn_mask = attn_params.get_attn_mask(hidden_states.device)
-            if attn_mask is not None: attn_weights = attn_weights + attn_mask
-            attn_weights = nn.functional.softmax(attn_weights, dim = -1, dtype = torch.float16)
-
-            value_states = self.repeat_kv(value_states, cfg.num_key_value_groups)
-            attn_output = torch.matmul(attn_weights, value_states)
-
-            attn_output = attn_output.transpose(1, 2)
-            attn_output = attn_output.reshape((batch_size, q_len, cfg.num_attention_heads * cfg.head_dim))
-
-        # Flash Attention 2
-
+            attn_func = self._attn_torch
         elif use_flash_attn:
-
-            attn_output = flash_attn_func(
-                query_states,
-                key_states,
-                value_states,
-                # softmax_scale = None if self.scale_factor == 1 else self.scale_factor / math.sqrt(head_dim),
-                causal = True
-            )
-            attn_output = attn_output.reshape((batch_size, q_len, cfg.num_attention_heads * cfg.head_dim))
-
-        # Xformers attention
-
+            attn_func = self._attn_flash
         else:
+            attn_func = self._attn_xformers

-            key_states = key_states.transpose(1, 2)
-            value_states = value_states.transpose(1, 2)
-
-            key_states = self.repeat_kv(key_states, num_key_value_groups)
-            value_states = self.repeat_kv(value_states, num_key_value_groups)
-            key_states = key_states.transpose(1, 2)
-            value_states = value_states.transpose(1, 2)
+        # Attention

-            attn_output = xops.memory_efficient_attention(query_states, key_states, value_states, attn_bias = LowerTriangularFromBottomRightMask())
-            attn_output = attn_output.reshape((batch_size, q_len, cfg.num_attention_heads * cfg.head_dim))
+        attn_output = attn_func(batch_size, q_len, query_states, key_states, value_states, attn_params, cfg)

         # Update 8-bit/Q4 cache

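With the inline branches removed, forward_torch only chooses a backend; all three _attn_* methods now share the (batch_size, q_len, q, k, v, attn_params, cfg) calling convention, so the call site stays identical regardless of which one is picked. A rough sketch of that selection logic in isolation (the flag and attribute names mirror the diff; the helper itself and the module argument are hypothetical):

def select_attn_func(module, cfg, attn_params, has_flash_attn, has_xformers):
    """Hypothetical helper mirroring the dispatch added in forward_torch."""
    use_flash_attn = has_flash_attn and not cfg.no_flash_attn
    use_xformers = has_xformers and not cfg.no_xformers

    # Non-causal attention, or no optional backend installed: use the torch
    # path, which itself picks SDPA or matmul (see _attn_torch above).
    if not (use_flash_attn or use_xformers) or not attn_params.is_causal():
        return module._attn_torch
    elif use_flash_attn:
        return module._attn_flash
    else:
        return module._attn_xformers

# Usage: attn_output = attn_func(batch_size, q_len, q, k, v, attn_params, cfg)
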