
Commit 390f153

Add padding-free to bamba (#35861)
* add seq_idx and fa kwargs
* update tests
* docs and grad ckpt support
* fmt
* better names
* test_raise_missing_padding_free_kwarg_errs
* + seq_idx in doc strings
* padding free training docs
* add link to pr plots
* raise err on attn_mask with padding free
* rm raising missing padding free err test
* BambaFlashAttentionKwargs
* run modular util for modular_granitemoehybrid.py
1 parent 2a79471 commit 390f153

File tree: 5 files changed (+233 −25 lines)

docs/source/en/model_doc/bamba.md (30 additions, 2 deletions)

@@ -39,7 +39,7 @@ Checkout all Bamba-9B model checkpoints [here](https://github.com/foundation-mod
 <!---
 ## Usage Tips

-Tips:
+Tips:

 - The architecture is based on Mamba-2 models.

@@ -63,7 +63,35 @@ response = model.generate(**inputs, max_new_tokens=64)
 print(tokenizer.batch_decode(response, skip_special_tokens=True)[0])
 ```

+## Padding-Free Training
+
+Bamba supports padding-free training in which distinct training examples can be concatenated
+together while nevertheless processing the inputs as though they belonged to separate batches. When
+the examples are of varying lengths, padding-free training can provide significant speed ups and
+memory savings compared to batching the examples together and using padding, as the unnecessary
+compute and memory due to padding is avoided entirely. The performance gains depend on factors such
+as the model and the data distribution, but throughput gains up to [~2x are commonly
+seen](https://github.com/huggingface/transformers/pull/35861#issue-2807873129).
+
+Using padding-free training with Bamba requires the `flash-attn`, `mamba-ssm`, and `causal-conv1d`
+packages, and the following arguments must be passed to the model in addition to `input_ids` and
+`labels`:
+* `position_ids: torch.LongTensor`: the position index of each token in each sequence.
+* `seq_idx: torch.IntTensor`: the index of each sequence in the batch.
+* Each of the [`FlashAttentionKwargs`]
+    * `cu_seq_lens_q: torch.LongTensor`: The cumulative sequence lengths of all queries.
+    * `cu_seq_lens_k: torch.LongTensor`: The cumulative sequence lengths of all keys.
+    * `max_length_q: int`: the longest query length in the batch.
+    * `max_length_k: int`: the longest key length in the batch.
+
+The `attention_mask` inputs should not be provided. The [`DataCollatorWithFlattening`] can be used
+to programmatically generate the above set of additional arguments using `return_seq_idx=True` and
+`return_flash_attn_kwargs=True`. See [this blog post](https://huggingface.co/blog/packing-with-FA2)
+for additional information.
+
+
 [[autodoc]] BambaForCausalLM
     - forward

-This HF implementation is contributed by [ani300](https://github.com/ani300) and [fabianlim](https://github.com/fabianlim).
+This HF implementation is contributed by [ani300](https://github.com/ani300) and [fabianlim](https://github.com/fabianlim).
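
The new docs section above lists the required padding-free inputs; the sketch below shows one way they could be produced with [`DataCollatorWithFlattening`], as the docs suggest. The checkpoint id, example texts, and device handling are illustrative assumptions, not part of this commit.

```python
# Minimal padding-free training sketch for Bamba, assuming the flash-attn,
# mamba-ssm, and causal-conv1d packages and a CUDA device are available.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DataCollatorWithFlattening

model_id = "ibm-ai-platform/Bamba-9B"  # illustrative checkpoint id
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)

# The collator concatenates the examples into a single row and emits position_ids,
# seq_idx, and the FlashAttentionKwargs instead of an attention_mask.
collator = DataCollatorWithFlattening(return_seq_idx=True, return_flash_attn_kwargs=True)
features = [tokenizer(text) for text in ["A short example.", "A second, somewhat longer example."]]
batch = collator(features)
# batch now holds input_ids, labels, position_ids, seq_idx, cu_seq_lens_q/k,
# and max_length_q/k -- and no attention_mask.

batch = {k: v.to(model.device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
loss = model(**batch).loss
loss.backward()
```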

src/transformers/models/bamba/modeling_bamba.py (45 additions, 9 deletions)

@@ -24,7 +24,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Callable, Optional, Tuple, Union
+from functools import partial
+from typing import Callable, Optional, Tuple, TypedDict, Union

 import torch
 from torch import nn

@@ -61,6 +62,31 @@
 logger = logging.get_logger(__name__)


+class BambaFlashAttentionKwargs(TypedDict, total=False):
+    """
+    Keyword arguments for advanced Flash Attention, causal-conv1d, and mamba_ssm kernel usage.
+    Use cases include padding-free training and fewer `torch.compile` graph breaks.
+
+    Attributes:
+        cu_seq_lens_q (`torch.LongTensor`)
+            Gets cumulative sequence length for query state.
+        cu_seq_lens_k (`torch.LongTensor`)
+            Gets cumulative sequence length for key state.
+        max_length_q (`int`):
+            Maximum sequence length for query state.
+        max_length_k (`int`):
+            Maximum sequence length for key state.
+        seq_idx (`torch.IntTensor):
+            Index of each packed sequence.
+    """
+
+    cu_seq_lens_q: torch.LongTensor
+    cu_seq_lens_k: torch.LongTensor
+    max_length_q: int
+    max_length_k: int
+    seq_idx: torch.IntTensor
+
+
 # Adapted from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache for the v2 mixer
 class HybridMambaAttentionDynamicCache(modeling_jamba.HybridMambaAttentionDynamicCache):
     """

@@ -487,6 +513,7 @@ def cuda_kernels_forward(
         cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
         cache_position: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        seq_idx: Optional[torch.IntTensor] = None,
     ):
         # 1. Gated MLP's linear projection
         hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)

@@ -569,7 +596,7 @@ def cuda_kernels_forward(
                     A,
                     D=self.D,
                     chunk_size=self.chunk_size,
-                    seq_idx=None,  # was seq_idx
+                    seq_idx=seq_idx,
                     activation=self.activation,
                     rmsnorm_weight=self.norm.weight,
                     rmsnorm_eps=self.norm.variance_epsilon,

@@ -610,6 +637,7 @@ def cuda_kernels_forward(
                         weight=self.conv1d.weight.squeeze(1),
                         bias=self.conv1d.bias,
                         activation=self.activation,
+                        seq_idx=seq_idx,
                     ).transpose(1, 2)

                 hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask)

@@ -629,7 +657,7 @@ def cuda_kernels_forward(
                     chunk_size=self.chunk_size,
                     D=self.D,
                     z=None,
-                    seq_idx=None,
+                    seq_idx=seq_idx,
                     return_final_states=True,
                     dt_bias=self.dt_bias,
                     dt_softplus=True,

@@ -863,9 +891,15 @@ def forward(
         cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
         cache_position: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        seq_idx: Optional[torch.IntTensor] = None,
+        **kwargs,
     ):
         if is_fast_path_available and "cuda" in self.in_proj.weight.device.type:
-            return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
+            return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask, seq_idx)
+        if seq_idx is not None:
+            raise NotImplementedError(
+                "`seq_idx` support requires fast path support. Please install `mamba_ssm` and `causal_conv1d`"
+            )
         dtype = hidden_states.dtype
         if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
             # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66

@@ -939,7 +973,7 @@ def forward(
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
         position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
-        **kwargs,
+        **kwargs: Unpack[BambaFlashAttentionKwargs],
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
         Args:

@@ -959,8 +993,8 @@ def forward(
                 Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
                 with `head_dim` being the embedding dimension of each attention head.
             kwargs (`dict`, *optional*):
-                Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
-                into the model
+                Arbitrary kwargs. Can be used to provide `BambaFlashAttentionKwargs` for
+                padding-free training and/or improve torch.compile performance.
         """

         residual = hidden_states

@@ -974,6 +1008,7 @@ def forward(
                 cache_params=past_key_value,
                 cache_position=cache_position,
                 attention_mask=attention_mask,
+                **kwargs,
             )
             self_attn_weights = None
         elif self.layer_type == "attention":

@@ -1076,7 +1111,7 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        **kwargs,  # NOOP kwargs, for now
+        **kwargs: Unpack[BambaFlashAttentionKwargs],
     ) -> BaseModelOutputWithPast:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (

@@ -1128,7 +1163,7 @@ def forward(

             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    partial(decoder_layer.__call__, **kwargs),
                     hidden_states,
                     layer_mask,
                     position_ids,

@@ -1148,6 +1183,7 @@ def forward(
                     use_cache=use_cache,
                     cache_position=cache_position,
                     position_embeddings=position_embeddings,
+                    **kwargs,
                 )

             hidden_states = layer_outputs[0]
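
For intuition about the `BambaFlashAttentionKwargs` introduced above, here is a hand-rolled sketch of the fields for a packed batch of three sequences. The helper function is hypothetical; in practice `DataCollatorWithFlattening` builds these tensors, and the dtypes here simply follow the annotations in the new TypedDict.

```python
# Hypothetical helper that derives the padding-free kwargs for a packed batch,
# here with sequence lengths [3, 5, 2]. Illustrative only.
import torch

def packed_kwargs(seq_lens):
    # Cumulative offsets of each sequence in the flattened token stream.
    cu_seq_lens = torch.tensor([0] + list(seq_lens)).cumsum(0)  # LongTensor [0, 3, 8, 10]
    # Which packed sequence each token belongs to, shape (1, total_tokens).
    seq_idx = torch.cat(
        [torch.full((n,), i, dtype=torch.int32) for i, n in enumerate(seq_lens)]
    ).unsqueeze(0)
    # Position ids restart at 0 for every sequence.
    position_ids = torch.cat([torch.arange(n) for n in seq_lens]).unsqueeze(0)
    return {
        "position_ids": position_ids,   # [[0, 1, 2, 0, 1, 2, 3, 4, 0, 1]]
        "seq_idx": seq_idx,             # [[0, 0, 0, 1, 1, 1, 1, 1, 2, 2]]
        "cu_seq_lens_q": cu_seq_lens,
        "cu_seq_lens_k": cu_seq_lens,
        "max_length_q": max(seq_lens),  # 5
        "max_length_k": max(seq_lens),
    }

print(packed_kwargs([3, 5, 2]))
```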

src/transformers/models/bamba/modular_bamba.py (51 additions, 10 deletions)

@@ -19,7 +19,8 @@
 # limitations under the License.
 """PyTorch Bamba model."""

-from typing import Optional, Tuple, Union
+from functools import partial
+from typing import Optional, Tuple, TypedDict, Union

 import torch
 import torch.utils.checkpoint

@@ -46,7 +47,12 @@
 from ...modeling_attn_mask_utils import AttentionMaskConverter
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from ...modeling_utils import PreTrainedModel
-from ...utils import auto_docstring, can_return_tuple, logging
+from ...processing_utils import Unpack
+from ...utils import (
+    auto_docstring,
+    can_return_tuple,
+    logging,
+)
 from ...utils.import_utils import is_causal_conv1d_available, is_flash_attn_2_available, is_mamba_2_ssm_available
 from .configuration_bamba import BambaConfig

@@ -71,6 +77,31 @@
 logger = logging.get_logger(__name__)


+class BambaFlashAttentionKwargs(TypedDict, total=False):
+    """
+    Keyword arguments for advanced Flash Attention, causal-conv1d, and mamba_ssm kernel usage.
+    Use cases include padding-free training and fewer `torch.compile` graph breaks.
+
+    Attributes:
+        cu_seq_lens_q (`torch.LongTensor`)
+            Gets cumulative sequence length for query state.
+        cu_seq_lens_k (`torch.LongTensor`)
+            Gets cumulative sequence length for key state.
+        max_length_q (`int`):
+            Maximum sequence length for query state.
+        max_length_k (`int`):
+            Maximum sequence length for key state.
+        seq_idx (`torch.IntTensor):
+            Index of each packed sequence.
+    """
+
+    cu_seq_lens_q: torch.LongTensor
+    cu_seq_lens_k: torch.LongTensor
+    max_length_q: int
+    max_length_k: int
+    seq_idx: torch.IntTensor
+
+
 # Adapted from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache for the v2 mixer
 class HybridMambaAttentionDynamicCache(modeling_jamba.HybridMambaAttentionDynamicCache):
     """

@@ -278,6 +309,7 @@ def cuda_kernels_forward(
         cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
         cache_position: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        seq_idx: Optional[torch.IntTensor] = None,
     ):
         # 1. Gated MLP's linear projection
         hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)

@@ -360,7 +392,7 @@ def cuda_kernels_forward(
                     A,
                     D=self.D,
                     chunk_size=self.chunk_size,
-                    seq_idx=None,  # was seq_idx
+                    seq_idx=seq_idx,
                     activation=self.activation,
                     rmsnorm_weight=self.norm.weight,
                     rmsnorm_eps=self.norm.variance_epsilon,

@@ -401,6 +433,7 @@ def cuda_kernels_forward(
                         weight=self.conv1d.weight.squeeze(1),
                         bias=self.conv1d.bias,
                         activation=self.activation,
+                        seq_idx=seq_idx,
                     ).transpose(1, 2)

                 hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask)

@@ -420,7 +453,7 @@ def cuda_kernels_forward(
                     chunk_size=self.chunk_size,
                     D=self.D,
                     z=None,
-                    seq_idx=None,
+                    seq_idx=seq_idx,
                     return_final_states=True,
                     dt_bias=self.dt_bias,
                     dt_softplus=True,

@@ -654,9 +687,15 @@ def forward(
         cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
         cache_position: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        seq_idx: Optional[torch.IntTensor] = None,
+        **kwargs,
     ):
         if is_fast_path_available and "cuda" in self.in_proj.weight.device.type:
-            return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
+            return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask, seq_idx)
+        if seq_idx is not None:
+            raise NotImplementedError(
+                "`seq_idx` support requires fast path support. Please install `mamba_ssm` and `causal_conv1d`"
+            )
         dtype = hidden_states.dtype
         if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
             # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66

@@ -701,7 +740,7 @@ def forward(
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
         position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
-        **kwargs,
+        **kwargs: Unpack[BambaFlashAttentionKwargs],
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
         Args:

@@ -721,8 +760,8 @@ def forward(
                 Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
                 with `head_dim` being the embedding dimension of each attention head.
             kwargs (`dict`, *optional*):
-                Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
-                into the model
+                Arbitrary kwargs. Can be used to provide `BambaFlashAttentionKwargs` for
+                padding-free training and/or improve torch.compile performance.
         """

         residual = hidden_states

@@ -736,6 +775,7 @@ def forward(
                 cache_params=past_key_value,
                 cache_position=cache_position,
                 attention_mask=attention_mask,
+                **kwargs,
             )
             self_attn_weights = None
         elif self.layer_type == "attention":

@@ -838,7 +878,7 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        **kwargs,  # NOOP kwargs, for now
+        **kwargs: Unpack[BambaFlashAttentionKwargs],
     ) -> BaseModelOutputWithPast:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (

@@ -890,7 +930,7 @@ def forward(

             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    partial(decoder_layer.__call__, **kwargs),
                     hidden_states,
                     layer_mask,
                     position_ids,

@@ -910,6 +950,7 @@ def forward(
                     use_cache=use_cache,
                     cache_position=cache_position,
                     position_embeddings=position_embeddings,
+                    **kwargs,
                 )

             hidden_states = layer_outputs[0]
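
One detail shared by both files above is the gradient-checkpointing change: the decoder layer is invoked through `_gradient_checkpointing_func` with positional arguments only, so the padding-free keyword arguments are bound up front with `functools.partial`. Below is a toy, self-contained sketch of that pattern, using a stand-in module rather than the Bamba decoder layer.

```python
# Toy illustration of binding keyword arguments before checkpointing, mirroring
# partial(decoder_layer.__call__, **kwargs) in the diffs above. The module and
# tensors here are stand-ins, not Bamba internals.
from functools import partial

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

layer = nn.Linear(8, 8)

def layer_forward(hidden_states, *, seq_idx=None):
    # seq_idx is keyword-only, so it must be bound before checkpoint() is called
    # with positional arguments.
    out = layer(hidden_states)
    if seq_idx is not None:
        out = out + seq_idx.unsqueeze(-1).to(out.dtype)
    return out

x = torch.randn(1, 4, 8, requires_grad=True)
seq_idx = torch.tensor([[0, 0, 1, 1]], dtype=torch.int32)

# Bind the kwargs first, then pass only positional tensors through checkpoint().
out = checkpoint(partial(layer_forward, seq_idx=seq_idx), x, use_reentrant=False)
out.sum().backward()
```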
