# limitations under the License.
"""PyTorch Bamba model."""

-from typing import Optional, Tuple, Union
+from functools import partial
+from typing import Optional, Tuple, TypedDict, Union

import torch
import torch.utils.checkpoint

from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_utils import PreTrainedModel
-from ...utils import auto_docstring, can_return_tuple, logging
+from ...processing_utils import Unpack
+from ...utils import (
+    auto_docstring,
+    can_return_tuple,
+    logging,
+)
from ...utils.import_utils import is_causal_conv1d_available, is_flash_attn_2_available, is_mamba_2_ssm_available
from .configuration_bamba import BambaConfig


logger = logging.get_logger(__name__)

+class BambaFlashAttentionKwargs(TypedDict, total=False):
+    """
+    Keyword arguments for advanced Flash Attention, causal-conv1d, and mamba_ssm kernel usage.
+    Use cases include padding-free training and fewer `torch.compile` graph breaks.
+
+    Attributes:
+        cu_seq_lens_q (`torch.LongTensor`):
+            Cumulative sequence lengths for the query state.
+        cu_seq_lens_k (`torch.LongTensor`):
+            Cumulative sequence lengths for the key state.
+        max_length_q (`int`):
+            Maximum sequence length for the query state.
+        max_length_k (`int`):
+            Maximum sequence length for the key state.
+        seq_idx (`torch.IntTensor`):
+            Index of each packed sequence.
+    """
+
+    cu_seq_lens_q: torch.LongTensor
+    cu_seq_lens_k: torch.LongTensor
+    max_length_q: int
+    max_length_k: int
+    seq_idx: torch.IntTensor
+
+
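As an illustration of how these kwargs might be assembled for padding-free training over packed sequences (a sketch with assumed sequence lengths, shapes, and variable names, not code taken from this patch):

# Hypothetical setup: three sequences of lengths [3, 5, 2] packed into one row of shape (1, 10).
import torch

seq_lens = torch.tensor([3, 5, 2], dtype=torch.int32)
cu_seq_lens = torch.nn.functional.pad(torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0))
flash_kwargs = {
    "cu_seq_lens_q": cu_seq_lens.to(torch.long),  # tensor([0, 3, 8, 10])
    "cu_seq_lens_k": cu_seq_lens.to(torch.long),
    "max_length_q": int(seq_lens.max()),
    "max_length_k": int(seq_lens.max()),
    # One entry per packed token, identifying which sequence it belongs to.
    "seq_idx": torch.repeat_interleave(torch.arange(len(seq_lens), dtype=torch.int32), seq_lens).unsqueeze(0),
}
# The dict matches the BambaFlashAttentionKwargs fields and would be forwarded as keyword
# arguments, e.g. model(input_ids, **flash_kwargs).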
# Adapted from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache for the v2 mixer
class HybridMambaAttentionDynamicCache(modeling_jamba.HybridMambaAttentionDynamicCache):
    """
@@ -278,6 +309,7 @@ def cuda_kernels_forward(
        cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
+        seq_idx: Optional[torch.IntTensor] = None,
    ):
        # 1. Gated MLP's linear projection
        hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
@@ -360,7 +392,7 @@ def cuda_kernels_forward(
                A,
                D=self.D,
                chunk_size=self.chunk_size,
-                seq_idx=None,  # was seq_idx
+                seq_idx=seq_idx,
                activation=self.activation,
                rmsnorm_weight=self.norm.weight,
                rmsnorm_eps=self.norm.variance_epsilon,
@@ -401,6 +433,7 @@ def cuda_kernels_forward(
                    weight=self.conv1d.weight.squeeze(1),
                    bias=self.conv1d.bias,
                    activation=self.activation,
+                    seq_idx=seq_idx,
                ).transpose(1, 2)

            hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask)
@@ -420,7 +453,7 @@ def cuda_kernels_forward(
                chunk_size=self.chunk_size,
                D=self.D,
                z=None,
-                seq_idx=None,
+                seq_idx=seq_idx,
                return_final_states=True,
                dt_bias=self.dt_bias,
                dt_softplus=True,
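For intuition only (not part of the patch): `seq_idx` tells the fused kernels where each packed sequence begins, so recurrent state is not carried across sequence boundaries. A toy segmented cumulative sum illustrates the effect; the real chunked scan is far more involved:

import torch

# Three sequences of lengths [3, 5, 2] packed into one row of ten tokens.
values = torch.arange(1, 11, dtype=torch.float32)
seq_idx = torch.repeat_interleave(torch.arange(3), torch.tensor([3, 5, 2]))

# A plain scan would mix state across sequences; restarting at each boundary keeps them independent.
segmented = torch.cat([values[seq_idx == i].cumsum(0) for i in range(3)])
print(segmented)  # the running sum resets at positions 3 and 8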
@@ -654,9 +687,15 @@ def forward(
        cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
+        seq_idx: Optional[torch.IntTensor] = None,
+        **kwargs,
    ):
        if is_fast_path_available and "cuda" in self.in_proj.weight.device.type:
-            return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
+            return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask, seq_idx)
+        if seq_idx is not None:
+            raise NotImplementedError(
+                "`seq_idx` support requires fast path support. Please install `mamba_ssm` and `causal_conv1d`"
+            )
        dtype = hidden_states.dtype
        if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
            # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
@@ -701,7 +740,7 @@ def forward(
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
-        **kwargs,
+        **kwargs: Unpack[BambaFlashAttentionKwargs],
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
@@ -721,8 +760,8 @@ def forward(
                Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
                with `head_dim` being the embedding dimension of each attention head.
            kwargs (`dict`, *optional*):
-                Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
-                into the model
+                Arbitrary kwargs. Can be used to provide `BambaFlashAttentionKwargs` for
+                padding-free training and/or to improve `torch.compile` performance.
        """

        residual = hidden_states
@@ -736,6 +775,7 @@ def forward(
                cache_params=past_key_value,
                cache_position=cache_position,
                attention_mask=attention_mask,
+                **kwargs,
            )
            self_attn_weights = None
        elif self.layer_type == "attention":
@@ -838,7 +878,7 @@ def forward(
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
-        **kwargs,  # NOOP kwargs, for now
+        **kwargs: Unpack[BambaFlashAttentionKwargs],
    ) -> BaseModelOutputWithPast:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
@@ -890,7 +930,7 @@ def forward(

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    partial(decoder_layer.__call__, **kwargs),
                    hidden_states,
                    layer_mask,
                    position_ids,
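A note on the `partial` change above: the model passes only positional arguments to `_gradient_checkpointing_func`, so the extra flash-attention kwargs are bound to the layer call before checkpointing. A minimal sketch of the pattern, with a made-up `layer_fn` standing in for `decoder_layer.__call__`:

from functools import partial

import torch
import torch.utils.checkpoint


def layer_fn(x, scale, *, seq_idx=None):
    # Stand-in for a decoder layer; `seq_idx` mimics a keyword-only extra.
    return x * scale if seq_idx is None else x * scale + seq_idx.float().mean()


x = torch.randn(2, 4, requires_grad=True)
extra_kwargs = {"seq_idx": torch.zeros(2, 4, dtype=torch.int32)}

# Binding the kwargs first lets checkpoint() make a purely positional call.
out = torch.utils.checkpoint.checkpoint(
    partial(layer_fn, **extra_kwargs), x, torch.tensor(2.0), use_reentrant=False
)
out.sum().backward()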
@@ -910,6 +950,7 @@ def forward(
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
+                    **kwargs,
                )

            hidden_states = layer_outputs[0]