Commit b59e5c9

BambaFlashAttentionKwargs
1 parent a47588a commit b59e5c9

File tree

2 files changed (+76, -117 lines):

    src/transformers/models/bamba/modeling_bamba.py
    src/transformers/models/bamba/modular_bamba.py


src/transformers/models/bamba/modeling_bamba.py

Lines changed: 35 additions & 58 deletions
@@ -24,7 +24,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Callable, Optional, Tuple, Union
+from functools import partial
+from typing import Callable, Optional, Tuple, TypedDict, Union
 
 import torch
 from torch import nn
@@ -61,6 +62,31 @@
 logger = logging.get_logger(__name__)
 
 
+class BambaFlashAttentionKwargs(TypedDict, total=False):
+    """
+    Keyword arguments for advanced Flash Attention, causal-conv1d, and mamba_ssm kernel usage.
+    Use cases include padding-free training and fewer `torch.compile` graph breaks.
+
+    Attributes:
+        cu_seq_lens_q (`torch.LongTensor`):
+            Cumulative sequence lengths for the query state.
+        cu_seq_lens_k (`torch.LongTensor`):
+            Cumulative sequence lengths for the key state.
+        max_length_q (`int`):
+            Maximum sequence length for the query state.
+        max_length_k (`int`):
+            Maximum sequence length for the key state.
+        seq_idx (`torch.IntTensor`):
+            Index of each packed sequence.
+    """
+
+    cu_seq_lens_q: torch.LongTensor
+    cu_seq_lens_k: torch.LongTensor
+    max_length_q: int
+    max_length_k: int
+    seq_idx: torch.IntTensor
+
+
 # Adapted from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache for the v2 mixer
 class HybridMambaAttentionDynamicCache(modeling_jamba.HybridMambaAttentionDynamicCache):
     """
@@ -866,6 +892,7 @@ def forward(
         cache_position: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         seq_idx: Optional[torch.IntTensor] = None,
+        **kwargs,
     ):
         if is_fast_path_available and "cuda" in self.in_proj.weight.device.type:
             return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask, seq_idx)
@@ -946,12 +973,7 @@ def forward(
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
         position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
-        cu_seq_lens_q: Optional[torch.LongTensor] = None,
-        cu_seq_lens_k: Optional[torch.LongTensor] = None,
-        max_length_q: Optional[int] = None,
-        max_length_k: Optional[int] = None,
-        seq_idx: Optional[torch.IntTensor] = None,
-        **kwargs,
+        **kwargs: Unpack[BambaFlashAttentionKwargs],
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
         Args:
@@ -970,20 +992,9 @@ def forward(
             position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
                 Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
                 with `head_dim` being the embedding dimension of each attention head.
-            cu_seq_lens_q (`torch.LongTensor`, *optional*):
-                Cumulative query sequence lengths. For padding-free training with flash attention.
-            cu_seq_lens_k (`torch.LongTensor`, *optional*):
-                Cumulative key sequence lengths. For padding-free training with flash attention.
-            max_length_q (`int`, *optional*):
-                Maximum query length. For padding-free training with flash attention.
-            max_length_k (`int`, *optional*):
-                Maximum key length. For padding-free training with flash attention.
-            seq_idx (`torch.IntTensor`, *optional*):
-                Sequence index of each packed example. For padding-free training with fast mamba_ssm and causal_conv1d
-                kernels.
             kwargs (`dict`, *optional*):
-                Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
-                into the model
+                Arbitrary kwargs. Can be used to provide `BambaFlashAttentionKwargs` for
+                padding-free training and/or improve torch.compile performance.
         """
 
         residual = hidden_states
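The `**kwargs: Unpack[...]` annotation only affects static type checking: at runtime the kwargs remain an ordinary dict, but a type checker can verify the keyword names and value types against the TypedDict (PEP 692). A small self-contained sketch of the pattern, using a hypothetical `ExampleKwargs` unrelated to the Bamba code:

    from typing_extensions import TypedDict, Unpack  # Unpack also lives in typing on newer Pythons


    class ExampleKwargs(TypedDict, total=False):
        max_length_q: int
        max_length_k: int


    def forward_stub(**kwargs: Unpack[ExampleKwargs]) -> None:
        # At runtime this is a plain dict; the annotation documents which
        # keyword arguments a type checker should accept.
        print(kwargs.get("max_length_q"))


    forward_stub(max_length_q=128)      # accepted
    forward_stub(max_length_q="oops")   # still runs, but flagged by a type checker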
@@ -997,7 +1008,7 @@ def forward(
                 cache_params=past_key_value,
                 cache_position=cache_position,
                 attention_mask=attention_mask,
-                seq_idx=seq_idx,
+                **kwargs,
             )
             self_attn_weights = None
         elif self.layer_type == "attention":
@@ -1010,10 +1021,6 @@ def forward(
                 use_cache=use_cache,
                 cache_position=cache_position,
                 position_embeddings=position_embeddings,
-                cu_seq_lens_q=cu_seq_lens_q,
-                cu_seq_lens_k=cu_seq_lens_k,
-                max_length_q=max_length_q,
-                max_length_k=max_length_k,
                 **kwargs,
             )
 
@@ -1104,12 +1111,7 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        cu_seq_lens_q: Optional[torch.LongTensor] = None,
-        cu_seq_lens_k: Optional[torch.LongTensor] = None,
-        max_length_q: Optional[int] = None,
-        max_length_k: Optional[int] = None,
-        seq_idx: Optional[torch.IntTensor] = None,
-        **kwargs,  # NOOP kwargs, for now
+        **kwargs: Unpack[BambaFlashAttentionKwargs],
     ) -> BaseModelOutputWithPast:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -1120,22 +1122,6 @@ def forward(
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
 
-        padding_free_kwargs = (cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k, seq_idx)
-        num_padding_free_kwargs_used = sum(k is not None for k in padding_free_kwargs)
-        if num_padding_free_kwargs_used:
-            if num_padding_free_kwargs_used != len(padding_free_kwargs):
-                raise ValueError(
-                    "All of (cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k, seq_idx) must be specified for padding-free training."
-                )
-            if position_ids is None:
-                raise ValueError(
-                    "position_ids must also be specified when (cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k, seq_idx) are provided for padding-free training."
-                )
-            if attention_mask is not None:
-                raise ValueError(
-                    "attention_mask must be None when providing (cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k, seq_idx, position_ids) for padding-free training."
-                )
-
         if self.gradient_checkpointing and self.training and use_cache:
             logger.warning_once(
                 "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
@@ -1177,7 +1163,7 @@ def forward(
 
         if self.gradient_checkpointing and self.training:
             layer_outputs = self._gradient_checkpointing_func(
-                decoder_layer.__call__,
+                partial(decoder_layer.__call__, **kwargs),
                 hidden_states,
                 layer_mask,
                 position_ids,
@@ -1186,11 +1172,6 @@ def forward(
                     use_cache,
                     cache_position,
                     position_embeddings,
-                    cu_seq_lens_q,
-                    cu_seq_lens_k,
-                    max_length_q,
-                    max_length_k,
-                    seq_idx,
                 )
             else:
                 layer_outputs = decoder_layer(
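The switch to `partial(decoder_layer.__call__, **kwargs)` works because the gradient-checkpointing wrapper forwards the positional arguments it is given; binding the keyword-only extras up front lets them reach the layer on both the forward and the recomputation pass. A standalone sketch of the same idea with plain torch.utils.checkpoint (the layer function here is a stand-in, not the Bamba decoder layer):

    from functools import partial

    import torch
    from torch.utils.checkpoint import checkpoint


    def layer_fn(hidden_states, scale=1.0):
        # Stand-in for a decoder layer; `scale` plays the role of the extra kwargs.
        return hidden_states * scale


    x = torch.randn(2, 4, requires_grad=True)

    # Bind the keyword argument first, then hand a positional-only callable
    # to the checkpoint wrapper, mirroring partial(decoder_layer.__call__, **kwargs).
    out = checkpoint(partial(layer_fn, scale=0.5), x, use_reentrant=False)
    out.sum().backward()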
@@ -1202,11 +1183,7 @@ def forward(
                     use_cache=use_cache,
                     cache_position=cache_position,
                     position_embeddings=position_embeddings,
-                    cu_seq_lens_q=cu_seq_lens_q,
-                    cu_seq_lens_k=cu_seq_lens_k,
-                    max_length_q=max_length_q,
-                    max_length_k=max_length_k,
-                    seq_idx=seq_idx,
+                    **kwargs,
                 )
 
             hidden_states = layer_outputs[0]

src/transformers/models/bamba/modular_bamba.py

Lines changed: 41 additions & 59 deletions
@@ -19,7 +19,8 @@
 # limitations under the License.
 """PyTorch Bamba model."""
 
-from typing import Optional, Tuple, Union
+from functools import partial
+from typing import Optional, Tuple, TypedDict, Union
 
 import torch
 import torch.utils.checkpoint
@@ -46,7 +47,12 @@
 from ...modeling_attn_mask_utils import AttentionMaskConverter
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from ...modeling_utils import PreTrainedModel
-from ...utils import auto_docstring, can_return_tuple, logging
+from ...processing_utils import Unpack
+from ...utils import (
+    auto_docstring,
+    can_return_tuple,
+    logging,
+)
 from ...utils.import_utils import is_causal_conv1d_available, is_flash_attn_2_available, is_mamba_2_ssm_available
 from .configuration_bamba import BambaConfig
 
@@ -71,6 +77,31 @@
 logger = logging.get_logger(__name__)
 
 
+class BambaFlashAttentionKwargs(TypedDict, total=False):
+    """
+    Keyword arguments for advanced Flash Attention, causal-conv1d, and mamba_ssm kernel usage.
+    Use cases include padding-free training and fewer `torch.compile` graph breaks.
+
+    Attributes:
+        cu_seq_lens_q (`torch.LongTensor`):
+            Cumulative sequence lengths for the query state.
+        cu_seq_lens_k (`torch.LongTensor`):
+            Cumulative sequence lengths for the key state.
+        max_length_q (`int`):
+            Maximum sequence length for the query state.
+        max_length_k (`int`):
+            Maximum sequence length for the key state.
+        seq_idx (`torch.IntTensor`):
+            Index of each packed sequence.
+    """
+
+    cu_seq_lens_q: torch.LongTensor
+    cu_seq_lens_k: torch.LongTensor
+    max_length_q: int
+    max_length_k: int
+    seq_idx: torch.IntTensor
+
+
 # Adapted from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache for the v2 mixer
 class HybridMambaAttentionDynamicCache(modeling_jamba.HybridMambaAttentionDynamicCache):
     """
@@ -657,6 +688,7 @@ def forward(
         cache_position: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         seq_idx: Optional[torch.IntTensor] = None,
+        **kwargs,
     ):
         if is_fast_path_available and "cuda" in self.in_proj.weight.device.type:
             return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask, seq_idx)
@@ -708,12 +740,7 @@ def forward(
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
         position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
-        cu_seq_lens_q: Optional[torch.LongTensor] = None,
-        cu_seq_lens_k: Optional[torch.LongTensor] = None,
-        max_length_q: Optional[int] = None,
-        max_length_k: Optional[int] = None,
-        seq_idx: Optional[torch.IntTensor] = None,
-        **kwargs,
+        **kwargs: Unpack[BambaFlashAttentionKwargs],
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
         Args:
@@ -732,20 +759,9 @@ def forward(
             position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
                 Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
                 with `head_dim` being the embedding dimension of each attention head.
-            cu_seq_lens_q (`torch.LongTensor`, *optional*):
-                Cumulative query sequence lengths. For padding-free training with flash attention.
-            cu_seq_lens_k (`torch.LongTensor`, *optional*):
-                Cumulative key sequence lengths. For padding-free training with flash attention.
-            max_length_q (`int`, *optional*):
-                Maximum query length. For padding-free training with flash attention.
-            max_length_k (`int`, *optional*):
-                Maximum key length. For padding-free training with flash attention.
-            seq_idx (`torch.IntTensor`, *optional*):
-                Sequence index of each packed example. For padding-free training with fast mamba_ssm and causal_conv1d
-                kernels.
             kwargs (`dict`, *optional*):
-                Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
-                into the model
+                Arbitrary kwargs. Can be used to provide `BambaFlashAttentionKwargs` for
+                padding-free training and/or improve torch.compile performance.
         """
 
         residual = hidden_states
@@ -759,7 +775,7 @@ def forward(
                 cache_params=past_key_value,
                 cache_position=cache_position,
                 attention_mask=attention_mask,
-                seq_idx=seq_idx,
+                **kwargs,
             )
             self_attn_weights = None
         elif self.layer_type == "attention":
@@ -772,10 +788,6 @@ def forward(
                 use_cache=use_cache,
                 cache_position=cache_position,
                 position_embeddings=position_embeddings,
-                cu_seq_lens_q=cu_seq_lens_q,
-                cu_seq_lens_k=cu_seq_lens_k,
-                max_length_q=max_length_q,
-                max_length_k=max_length_k,
                 **kwargs,
             )
 
@@ -866,12 +878,7 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        cu_seq_lens_q: Optional[torch.LongTensor] = None,
-        cu_seq_lens_k: Optional[torch.LongTensor] = None,
-        max_length_q: Optional[int] = None,
-        max_length_k: Optional[int] = None,
-        seq_idx: Optional[torch.IntTensor] = None,
-        **kwargs,  # NOOP kwargs, for now
+        **kwargs: Unpack[BambaFlashAttentionKwargs],
     ) -> BaseModelOutputWithPast:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -882,22 +889,6 @@ def forward(
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
 
-        padding_free_kwargs = (cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k, seq_idx)
-        num_padding_free_kwargs_used = sum(k is not None for k in padding_free_kwargs)
-        if num_padding_free_kwargs_used:
-            if num_padding_free_kwargs_used != len(padding_free_kwargs):
-                raise ValueError(
-                    "All of (cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k, seq_idx) must be specified for padding-free training."
-                )
-            if position_ids is None:
-                raise ValueError(
-                    "position_ids must also be specified when (cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k, seq_idx) are provided for padding-free training."
-                )
-            if attention_mask is not None:
-                raise ValueError(
-                    "attention_mask must be None when providing (cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k, seq_idx, position_ids) for padding-free training."
-                )
-
         if self.gradient_checkpointing and self.training and use_cache:
             logger.warning_once(
                 "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
@@ -939,7 +930,7 @@ def forward(
 
         if self.gradient_checkpointing and self.training:
             layer_outputs = self._gradient_checkpointing_func(
-                decoder_layer.__call__,
+                partial(decoder_layer.__call__, **kwargs),
                 hidden_states,
                 layer_mask,
                 position_ids,
@@ -948,11 +939,6 @@ def forward(
                     use_cache,
                     cache_position,
                     position_embeddings,
-                    cu_seq_lens_q,
-                    cu_seq_lens_k,
-                    max_length_q,
-                    max_length_k,
-                    seq_idx,
                 )
             else:
                 layer_outputs = decoder_layer(
@@ -964,11 +950,7 @@ def forward(
                     use_cache=use_cache,
                     cache_position=cache_position,
                     position_embeddings=position_embeddings,
-                    cu_seq_lens_q=cu_seq_lens_q,
-                    cu_seq_lens_k=cu_seq_lens_k,
-                    max_length_q=max_length_q,
-                    max_length_k=max_length_k,
-                    seq_idx=seq_idx,
+                    **kwargs,
                 )
 
             hidden_states = layer_outputs[0]
