@@ -14,13 +14,13 @@
 from exllamav2.architecture import RopeStyle
 import math
 # from exllamav2.util import list_live_tensors, set_snapshot, diff_snapshot, print_vram_usage_peak
-# import torch.nn.functional as F
+import torch.nn.functional as F
 
 from typing import TYPE_CHECKING
 if TYPE_CHECKING:
     from exllamav2.model import ExLlamaV2
 
-# Detect flash-attn
+# Detect available options for attention
 
 has_flash_attn = False
 has_flash_attn_with_paged = False
@@ -54,6 +54,12 @@
 except ModuleNotFoundError:
     pass
 
+has_lower_right_sdpa = False
+try:
+    from torch.nn.attention.bias import causal_lower_right
+    has_lower_right_sdpa = True
+except ImportError:
+    pass
 
 def assert_paged_attn():
     global has_flash_attn_with_paged
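The new feature test gates the SDPA path on `torch.nn.attention.bias.causal_lower_right`, which only exists in recent PyTorch releases (2.2 or later, to the best of my knowledge), hence the `ImportError` fallback. A minimal standalone sketch of how that bias is used, with made-up shapes rather than code from this repo:

```python
# Minimal sketch, assuming PyTorch >= 2.2; names and shapes are illustrative only.
import torch
import torch.nn.functional as F
from torch.nn.attention.bias import causal_lower_right

bsz, heads, q_len, kv_len, head_dim = 1, 8, 4, 16, 64
q = torch.randn(bsz, heads, q_len, head_dim)
k = torch.randn(bsz, heads, kv_len, head_dim)
v = torch.randn(bsz, heads, kv_len, head_dim)

# causal_lower_right builds a causal bias aligned to the bottom-right corner of
# the (q_len, kv_len) score matrix, and SDPA accepts it directly as attn_mask.
bias = causal_lower_right(q_len, kv_len)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)
print(out.shape)  # torch.Size([1, 8, 4, 64])
```

The lower-right alignment is what matters when `q_len < kv_len` (queries appended after a cached prefix): the last query row attends to every key, which matches incremental decoding.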
@@ -666,24 +672,38 @@ def forward_paged(self,
         return hidden_states
 
 
-    def _attn_matmul(self, batch_size, q_len, q_states, k_states, v_states, attn_params, cfg):
+    def _attn_torch(self, batch_size, q_len, q_states, k_states, v_states, attn_params, cfg):
 
-        q_states = q_states.transpose(1, 2)
-        k_states = k_states.transpose(1, 2)
-        v_states = v_states.transpose(1, 2)
+        if has_lower_right_sdpa and attn_params.is_causal():
 
-        k_states = self.repeat_kv(k_states, cfg.num_key_value_groups)
-        k_states = k_states.transpose(-1, -2)
+            q_states = q_states.transpose(1, 2)
+            k_states = k_states.transpose(1, 2)
+            v_states = v_states.transpose(1, 2)
 
-        attn_weights = torch.matmul(q_states, k_states)
+            k_states = self.repeat_kv(k_states, cfg.num_key_value_groups)
+            v_states = self.repeat_kv(v_states, cfg.num_key_value_groups)
 
-        attn_weights *= 1 / math.sqrt(cfg.head_dim)
-        attn_mask = attn_params.get_attn_mask(attn_weights.device)
-        if attn_mask is not None: attn_weights = attn_weights + attn_mask
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float16)
+            attn_mask_lr = causal_lower_right(q_len, k_states.shape[2])
+            attn_output = F.scaled_dot_product_attention(q_states, k_states, v_states, attn_mask_lr)
 
-        v_states = self.repeat_kv(v_states, cfg.num_key_value_groups)
-        attn_output = torch.matmul(attn_weights, v_states)
+        else:
+
+            q_states = q_states.transpose(1, 2)
+            k_states = k_states.transpose(1, 2)
+            v_states = v_states.transpose(1, 2)
+
+            k_states = self.repeat_kv(k_states, cfg.num_key_value_groups)
+            k_states = k_states.transpose(-1, -2)
+
+            attn_weights = torch.matmul(q_states, k_states)
+
+            attn_weights *= 1 / math.sqrt(cfg.head_dim)
+            attn_mask = attn_params.get_attn_mask(attn_weights.device)
+            if attn_mask is not None: attn_weights = attn_weights + attn_mask
+            attn_weights = nn.functional.softmax(attn_weights, dim = -1, dtype = torch.float16)
+
+            v_states = self.repeat_kv(v_states, cfg.num_key_value_groups)
+            attn_output = torch.matmul(attn_weights, v_states)
 
         attn_output = attn_output.transpose(1, 2)
         attn_output = attn_output.reshape((batch_size, q_len, cfg.num_attention_heads * cfg.head_dim))
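For intuition, here is a rough standalone check (not repository code) that the two branches of `_attn_torch` compute the same thing: SDPA with a lower-right causal bias versus the explicit matmul/softmax path. The `repeat_kv` below is a stand-in for the module's own helper and assumes the usual grouped-query convention of repeating each KV head across its query-head group; all shapes are arbitrary.

```python
# Sketch only: compare SDPA-with-lower-right-bias against manual masked attention.
import math
import torch
import torch.nn.functional as F
from torch.nn.attention.bias import causal_lower_right

def repeat_kv(x, n_rep):
    # Stand-in helper: (bsz, kv_heads, seq, dim) -> (bsz, kv_heads * n_rep, seq, dim),
    # each KV head repeated once per query head in its group.
    if n_rep == 1:
        return x
    bsz, kv_heads, seq, dim = x.shape
    x = x[:, :, None, :, :].expand(bsz, kv_heads, n_rep, seq, dim)
    return x.reshape(bsz, kv_heads * n_rep, seq, dim)

bsz, q_heads, kv_heads, q_len, kv_len, head_dim = 1, 8, 2, 4, 16, 64
groups = q_heads // kv_heads
q = torch.randn(bsz, q_heads, q_len, head_dim)
k = repeat_kv(torch.randn(bsz, kv_heads, kv_len, head_dim), groups)
v = repeat_kv(torch.randn(bsz, kv_heads, kv_len, head_dim), groups)

# Branch 1: SDPA with a lower-right-aligned causal bias
out_sdpa = F.scaled_dot_product_attention(
    q, k, v, attn_mask=causal_lower_right(q_len, kv_len))

# Branch 2: explicit matmul / softmax with the equivalent lower-right causal mask
mask = torch.full((q_len, kv_len), float("-inf")).triu(1 + kv_len - q_len)
attn_weights = q @ k.transpose(-1, -2) / math.sqrt(head_dim)
attn_weights = torch.softmax(attn_weights + mask, dim=-1)
out_manual = attn_weights @ v

print(torch.allclose(out_sdpa, out_manual, atol=1e-5))  # expected: True
```

Note that the new branch repeats both `k_states` and `v_states` before the SDPA call, since this call path expects q, k and v to carry the same number of heads.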
@@ -821,7 +841,7 @@ def forward(self,
         elif (has_xformers and not cfg.no_xformers) and attn_params.is_causal():
             attn_func = self._attn_xformers
         else:
-            attn_func = self._attn_matmul
+            attn_func = self._attn_torch
 
         # Straight attention without cache
 
@@ -942,59 +962,18 @@ def forward_torch(self,
         use_flash_attn = has_flash_attn and not cfg.no_flash_attn
         use_xformers = has_xformers and not cfg.no_xformers
 
-        # Torch matmul attention
+        # Select attention function
 
         if not (use_flash_attn or use_xformers) or not attn_params.is_causal():
-
-            query_states = query_states.transpose(1, 2)
-            key_states = key_states.transpose(1, 2)
-            value_states = value_states.transpose(1, 2)
-
-            key_states = self.repeat_kv(key_states, cfg.num_key_value_groups)
-            key_states = key_states.transpose(-1, -2)
-
-            attn_weights = torch.matmul(query_states, key_states)
-            # attn_weights *= self.scale_factor / math.sqrt(head_dim)
-            # attn_mask = attn_params.get_attn_mask(hidden_states.device)
-            # if self.scale_factor != 1: attn_weights *= self.unscale_factor
-            attn_weights *= 1 / math.sqrt(head_dim)
-            attn_mask = attn_params.get_attn_mask(hidden_states.device)
-            if attn_mask is not None: attn_weights = attn_weights + attn_mask
-            attn_weights = nn.functional.softmax(attn_weights, dim = -1, dtype = torch.float16)
-
-            value_states = self.repeat_kv(value_states, cfg.num_key_value_groups)
-            attn_output = torch.matmul(attn_weights, value_states)
-
-            attn_output = attn_output.transpose(1, 2)
-            attn_output = attn_output.reshape((batch_size, q_len, cfg.num_attention_heads * cfg.head_dim))
-
-        # Flash Attention 2
-
+            attn_func = self._attn_torch
         elif use_flash_attn:
-
-            attn_output = flash_attn_func(
-                query_states,
-                key_states,
-                value_states,
-                # softmax_scale = None if self.scale_factor == 1 else self.scale_factor / math.sqrt(head_dim),
-                causal = True
-            )
-            attn_output = attn_output.reshape((batch_size, q_len, cfg.num_attention_heads * cfg.head_dim))
-
-        # Xformers attention
-
+            attn_func = self._attn_flash
         else:
+            attn_func = self._attn_xformers
 
-            key_states = key_states.transpose(1, 2)
-            value_states = value_states.transpose(1, 2)
-
-            key_states = self.repeat_kv(key_states, num_key_value_groups)
-            value_states = self.repeat_kv(value_states, num_key_value_groups)
-            key_states = key_states.transpose(1, 2)
-            value_states = value_states.transpose(1, 2)
+        # Attention
 
-            attn_output = xops.memory_efficient_attention(query_states, key_states, value_states, attn_bias = LowerTriangularFromBottomRightMask())
-            attn_output = attn_output.reshape((batch_size, q_len, cfg.num_attention_heads * cfg.head_dim))
+        attn_output = attn_func(batch_size, q_len, query_states, key_states, value_states, attn_params, cfg)
 
         # Update 8-bit/Q4 cache
 