# limitations under the License.
"""PyTorch Bamba model."""

- from typing import Optional, Tuple, Union
+ from functools import partial
+ from typing import Optional, Tuple, TypedDict, Union

import torch
import torch.utils.checkpoint

from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_utils import PreTrainedModel
- from ...utils import auto_docstring, can_return_tuple, logging
+ from ...processing_utils import Unpack
+ from ...utils import (
+     auto_docstring,
+     can_return_tuple,
+     logging,
+ )
from ...utils.import_utils import is_causal_conv1d_available, is_flash_attn_2_available, is_mamba_2_ssm_available
from .configuration_bamba import BambaConfig

logger = logging.get_logger(__name__)


+ class BambaFlashAttentionKwargs(TypedDict, total=False):
+     """
+     Keyword arguments for advanced Flash Attention, causal-conv1d, and mamba_ssm kernel usage.
+     Use cases include padding-free training and fewer `torch.compile` graph breaks.
+
+     Attributes:
+         cu_seq_lens_q (`torch.LongTensor`):
+             Cumulative sequence lengths for the query state.
+         cu_seq_lens_k (`torch.LongTensor`):
+             Cumulative sequence lengths for the key state.
+         max_length_q (`int`):
+             Maximum sequence length for the query state.
+         max_length_k (`int`):
+             Maximum sequence length for the key state.
+         seq_idx (`torch.IntTensor`):
+             Index of each packed sequence.
+     """
+
+     cu_seq_lens_q: torch.LongTensor
+     cu_seq_lens_k: torch.LongTensor
+     max_length_q: int
+     max_length_k: int
+     seq_idx: torch.IntTensor
+
+
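For orientation, below is a minimal sketch of how these kwargs might be assembled for a packed batch, assuming the usual FlashAttention varlen conventions (cumulative lengths prefixed with 0, one sequence index per token). Variable names are illustrative only:

```python
import torch

# Three packed sequences of lengths 3, 5, and 2, flattened into a single row with no padding.
seq_lens = torch.tensor([3, 5, 2])

# Cumulative boundaries [0, 3, 8, 10], shared here by the query and key states.
cu_seq_lens = torch.cat([torch.zeros(1, dtype=torch.long), seq_lens.cumsum(0)])

packed_kwargs = {
    "cu_seq_lens_q": cu_seq_lens,
    "cu_seq_lens_k": cu_seq_lens,
    "max_length_q": int(seq_lens.max()),
    "max_length_k": int(seq_lens.max()),
    # One index per token, e.g. [0, 0, 0, 1, 1, 1, 1, 1, 2, 2], so the causal-conv1d and
    # mamba_ssm kernels can reset state at sequence boundaries.
    "seq_idx": torch.repeat_interleave(torch.arange(len(seq_lens)), seq_lens).to(torch.int32)[None, :],
}
```

The exact dtypes and shapes the kernels accept can vary, so treat this as a starting point; in practice a collator that flattens examples (e.g. `DataCollatorWithFlattening`) is the usual way to produce such fields.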
# Adapted from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache for the v2 mixer
class HybridMambaAttentionDynamicCache(modeling_jamba.HybridMambaAttentionDynamicCache):
    """

@@ -657,6 +688,7 @@ def forward(
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        seq_idx: Optional[torch.IntTensor] = None,
+         **kwargs,
    ):
        if is_fast_path_available and "cuda" in self.in_proj.weight.device.type:
            return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask, seq_idx)
@@ -708,12 +740,7 @@ def forward(
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
-         cu_seq_lens_q: Optional[torch.LongTensor] = None,
-         cu_seq_lens_k: Optional[torch.LongTensor] = None,
-         max_length_q: Optional[int] = None,
-         max_length_k: Optional[int] = None,
-         seq_idx: Optional[torch.IntTensor] = None,
-         **kwargs,
+         **kwargs: Unpack[BambaFlashAttentionKwargs],
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:

@@ -732,20 +759,9 @@ def forward(
            position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
                Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
                with `head_dim` being the embedding dimension of each attention head.
-             cu_seq_lens_q (`torch.LongTensor`, *optional*):
-                 Cumulative query sequence lengths. For padding-free training with flash attention.
-             cu_seq_lens_k (`torch.LongTensor`, *optional*):
-                 Cumulative key sequence lengths. For padding-free training with flash attention.
-             max_length_q (`int`, *optional*):
-                 Maximum query length. For padding-free training with flash attention.
-             max_length_k (`int`, *optional*):
-                 Maximum key length. For padding-free training with flash attention.
-             seq_idx (`torch.IntTensor`, *optional*):
-                 Sequence index of each packed example. For padding-free training with fast mamba_ssm and causal_conv1d
-                 kernels.
            kwargs (`dict`, *optional*):
-                 Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
-                 into the model
+                 Arbitrary kwargs. Can be used to provide `BambaFlashAttentionKwargs` for
+                 padding-free training and/or to improve `torch.compile` performance.
        """

        residual = hidden_states
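To make the call pattern behind these kwargs concrete, here is a hedged end-to-end sketch (checkpoint name illustrative, flash-attn assumed installed): the packed batch is passed as a single row with per-sequence `position_ids` and no `attention_mask`.

```python
import torch
from transformers import AutoModel

# Illustrative checkpoint name; substitute any Bamba checkpoint.
model = AutoModel.from_pretrained(
    "ibm-ai-platform/Bamba-9B",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
).to("cuda")

# Two sequences of lengths 4 and 6 packed into one row.
input_ids = torch.randint(0, model.config.vocab_size, (1, 10), device="cuda")
position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 2, 3, 4, 5]], device="cuda")

outputs = model(
    input_ids=input_ids,
    position_ids=position_ids,  # restarts at 0 for each packed sequence
    cu_seq_lens_q=torch.tensor([0, 4, 10], device="cuda"),
    cu_seq_lens_k=torch.tensor([0, 4, 10], device="cuda"),
    max_length_q=6,
    max_length_k=6,
    seq_idx=torch.tensor([[0, 0, 0, 0, 1, 1, 1, 1, 1, 1]], dtype=torch.int32, device="cuda"),
)
```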
@@ -759,7 +775,7 @@ def forward(
                cache_params=past_key_value,
                cache_position=cache_position,
                attention_mask=attention_mask,
-                 seq_idx=seq_idx,
+                 **kwargs,
            )
            self_attn_weights = None
        elif self.layer_type == "attention":

@@ -772,10 +788,6 @@ def forward(
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
-                 cu_seq_lens_q=cu_seq_lens_q,
-                 cu_seq_lens_k=cu_seq_lens_k,
-                 max_length_q=max_length_q,
-                 max_length_k=max_length_k,
                **kwargs,
            )
@@ -866,12 +878,7 @@ def forward(
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
-         cu_seq_lens_q: Optional[torch.LongTensor] = None,
-         cu_seq_lens_k: Optional[torch.LongTensor] = None,
-         max_length_q: Optional[int] = None,
-         max_length_k: Optional[int] = None,
-         seq_idx: Optional[torch.IntTensor] = None,
-         **kwargs,  # NOOP kwargs, for now
+         **kwargs: Unpack[BambaFlashAttentionKwargs],
    ) -> BaseModelOutputWithPast:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (

@@ -882,22 +889,6 @@ def forward(
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

-         padding_free_kwargs = (cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k, seq_idx)
-         num_padding_free_kwargs_used = sum(k is not None for k in padding_free_kwargs)
-         if num_padding_free_kwargs_used:
-             if num_padding_free_kwargs_used != len(padding_free_kwargs):
-                 raise ValueError(
-                     "All of (cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k, seq_idx) must be specified for padding-free training."
-                 )
-             if position_ids is None:
-                 raise ValueError(
-                     "position_ids must also be specified when (cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k, seq_idx) are provided for padding-free training."
-                 )
-             if attention_mask is not None:
-                 raise ValueError(
-                     "attention_mask must be None when providing (cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k, seq_idx, position_ids) for padding-free training."
-                 )
-
        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
@@ -939,7 +930,7 @@ def forward(
            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
-                     decoder_layer.__call__,
+                     partial(decoder_layer.__call__, **kwargs),
                    hidden_states,
                    layer_mask,
                    position_ids,

@@ -948,11 +939,6 @@ def forward(
                    use_cache,
                    cache_position,
                    position_embeddings,
-                     cu_seq_lens_q,
-                     cu_seq_lens_k,
-                     max_length_q,
-                     max_length_k,
-                     seq_idx,
                )
            else:
                layer_outputs = decoder_layer(

@@ -964,11 +950,7 @@ def forward(
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
-                     cu_seq_lens_q=cu_seq_lens_q,
-                     cu_seq_lens_k=cu_seq_lens_k,
-                     max_length_q=max_length_q,
-                     max_length_k=max_length_k,
-                     seq_idx=seq_idx,
+                     **kwargs,
                )

        hidden_states = layer_outputs[0]
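One note on the gradient-checkpointing hunk above: the checkpoint wrapper re-invokes the layer with positional arguments only, and the reentrant implementation does not forward keyword arguments at all, so the flash-attention kwargs are bound into the callable with `functools.partial` before it is handed over. A standalone sketch of the same pattern, using a toy function rather than the Bamba layer:

```python
from functools import partial

import torch
from torch.utils.checkpoint import checkpoint


def layer(x, scale, *, seq_idx=None):
    # The keyword-only argument stands in for the flash-attention kwargs above.
    out = x * scale
    return out if seq_idx is None else out + seq_idx.sum()


x = torch.randn(4, 8, requires_grad=True)
seq_idx = torch.tensor([0.0, 0.0, 1.0, 1.0])

# Bind the keyword extras up front; checkpoint() then only has to replay the
# positional tensors during recomputation.
y = checkpoint(partial(layer, seq_idx=seq_idx), x, torch.tensor(2.0), use_reentrant=False)
y.sum().backward()
```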