Skip to content

Add padding-free to bamba #35861

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
May 20, 2025

Conversation

garrett361
Copy link
Contributor

@garrett361 garrett361 commented Jan 23, 2025

What does this PR do?

Adds padding-free training to the BambaModel, enabling more efficient training with causal masking between disjoint sequences.

Performance: approximately 2x throughput improvements over naive padding for supervised finetuning on the Tulu v3 dataset with open-instruct. Tokens/sec/gpu plots for batch_size_per_gpu = 4:

8 A100s: 600 --> 1200 Tok/s/gpu

Scherm­afbeelding 2025-01-16 om 3 52 41 PM

32 A100s: 450 --> 750 Tok/s/gpu

Scherm­afbeelding 2025-01-16 om 3 52 33 PM

CC @fabianlim

CC reviewers of #34982: @ArthurZucker @molbap

Notes on Code

  • BambaAttention layers are untouched; only the BambaMixer mamba layer code is altered.
  • The padding-free path is only supported on cuda and requires the mamba kernels.
  • Supports both the position_ids and FlashAttentionKwargs padding-free code paths.

Notes on Tests

On both latest main and this PR branch the following tests/models/bamba/test_modeling_bamba.py tests are failing (with RUN_SLOW=1):

BambaModelTest::test_eager_matches_fa2_generate
BambaModelTest::test_flash_attention_2_padding_matches_padding_free_with_position_ids
BambaModelTest::test_sdpa_can_compile_dynamic
BambaModelTest::test_torchscript_output_attentions
BambaModelTest::test_torchscript_output_hidden_state
BambaModelTest::test_torchscript_simple
BambaModelIntegrationTest::test_simple_generate
  • The test_eager_matches_fa2_generate test seems flaky: sometimes it passes, other times it fails.
  • For test_flash_attention_2_padding_matches_padding_free_with_position_ids:
    • On main, this test fails because padding-free is not implemented.
    • On this PR branch this test fails because this PR only uses position_ids when model.training = True and this test explicitly calls eval() on the model. I have checked that this test passes when model.training = True. Edit: see BambaModelTest::test_attn_mask_position_ids_flash_attn_equality, also.
  • test_simple_generate appears to just need a simple edit for its expected text. It consistently fails with:
AssertionError: '<|be[35 chars]on this lovely evening? I hope you are all doing well. I am' != '<|be[35 chars]on this lovely evening? I hope you are all having a good time.'

where the generated and expected text differs at the very end.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@garrett361 garrett361 force-pushed the bamba-hf-padding-free-pr branch from cdaf1e6 to eab1ae1 Compare January 23, 2025 20:48
@garrett361 garrett361 closed this Jan 23, 2025
@garrett361 garrett361 reopened this Jan 24, 2025
@garrett361 garrett361 force-pushed the bamba-hf-padding-free-pr branch from eab1ae1 to c4874af Compare January 24, 2025 14:53
@Rocketknight1
Copy link
Member

cc @ArthurZucker for bamba, but let me know if you want me to take a look since it seems like quite an extensive PR!

@garrett361
Copy link
Contributor Author

it seems like quite an extensive PR!

I don't think it's very many changes, ultimately! Basically it just adds two helper functions so that position_ids and FlashAttentionKwargs get properly converted to the seq_idx arg that mamba expects:

  • get_cu_seq_lens_from_position_ids
  • get_seq_idx_from_cu_seq_lens

So, basically the above, making sure **kwargs get passed everywhere they should, and a little code cleanup.

@garrett361 garrett361 force-pushed the bamba-hf-padding-free-pr branch from 49d007c to d35bcc6 Compare January 24, 2025 20:04
@garrett361
Copy link
Contributor Author

Added a commit with BambaModelTest::test_attn_mask_position_ids_flash_attn_equality which tests the various code paths against each other.

@garrett361 garrett361 force-pushed the bamba-hf-padding-free-pr branch 4 times, most recently from dfaca13 to 5d39d5e Compare January 28, 2025 16:56
@garrett361
Copy link
Contributor Author

Hi @ArthurZucker @molbap, please let me know if I can answer any questions about this PR, thank you!

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey! Good work and good PR, one thing is that for training we do recommend using the padding-free data collator that takes care of the flattening and passing approriate kwargs. This prevents us from having to do too many changes, appart from poping the cu seqlens if they are inputs

@@ -940,10 +931,39 @@ def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


def get_cu_seq_lens_from_position_ids(position_ids: torch.LongTensor) -> torch.LongTensor:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not super accurate to use positions ids because they can include padding

Copy link
Contributor Author

@garrett361 garrett361 Feb 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I thought about this for a while and it's why I am checking the non_increasing_pos_id condition.

The (just-updated; fixed an error) helper now reads:

def get_cu_seq_lens_from_position_ids(position_ids: torch.LongTensor) -> torch.LongTensor:
    batch_size = position_ids.shape[0]
    if batch_size != 1:
        raise ValueError("Only batch size 1 is supported.")
    device = position_ids.device
    idxs = torch.arange(1, position_ids.shape[1], device=device)
    non_increasing_pos_id = position_ids[0, 1:] <= position_ids[0, :-1]
    next_pos_is_is_zero = position_ids[0, 1:] == 0
    new_seq_idxs = non_increasing_pos_id | next_pos_is_is_zero
    cu_seq_lens = torch.empty(new_seq_idxs.sum() + 2, device=device, dtype=torch.int64)
    cu_seq_lens[0], cu_seq_lens[1:-1], cu_seq_lens[-1] = 0, idxs[new_seq_idxs], position_ids.shape[-1]
    return cu_seq_lens

My goal was to treat every padding token (assumed to be encoded as a negative number, like -100) as an individual sequence, while treating the non-padding sequences correctly.

So, an extreme case like like position_ids = [-100, 0, 1, -100, -100, 0, 1, 2]) would turn into:

# get_cu_seq_lens_from_position_ids
cu_seq_lens = [0, 1, 3, 4, 5, 8]

thanks to the non-increasing pos id check. Seems reasonable to me, since:

  1. The non-trivial sequences are assigned the correct lengths
  2. Every padding tok is a new, len 1 segment, so no unnecessary compute will be expended in attending across a span of padding toks.

Compare to what is done in prepare_fa2_from_positions_ids, which I was trying to improve upon:

cu_seq_lens = torch.cat(
(
indices_q[position_ids == 0],
torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),
)
)

Passing the same position_ids into the above gives:

# prepare_fa2_from_positions_ids
cu_seq_lens = [1, 5, 8]

which incorrectly implies there's a subseq of len 4, since it only checks for pos_id = 0.

Thoughts? Very open to improvements!

@@ -1079,6 +1117,13 @@ def _init_weights(self, module):
module.weight.data[module.padding_idx].zero_()


def get_position_ids_from_cu_seq_lens(cu_seq_lens: torch.Tensor) -> torch.Tensor:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same comment about cat and for loops

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, also changed now, thanks.

@garrett361 garrett361 force-pushed the bamba-hf-padding-free-pr branch from 5d39d5e to 6fdd9a0 Compare February 5, 2025 14:08
@garrett361
Copy link
Contributor Author

Good work and good PR

Thank you!

one thing is that for training we do recommend using the padding-free data collator that takes care of the flattening and passing approriate kwargs.

Yep, just trying to support all code paths. Hope I didn't misunderstand you here.

I think I have addressed all comments @ArthurZucker , let me know!

I also had to run modular_model_convertor.py on all models to get past CI; some other PR seemed to cause minor CI breakage.

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am aligned with you on enable the path!
The key idea for me is that all these utilities, and calls you are doing are coming from not properly preparing the inputs!
And IMO preparing them outside the forward loop is a lot better! You leave it to the data collator.

What I am saying here is I'd rather we add a data collator for bamba, that properly takes care of initializing the seq_idx for example, than having to add extra logic.
BUT we need to propagate the seq_idx for sure!

Does it make sense for you?

@garrett361
Copy link
Contributor Author

The key idea for me is that all these utilities, and calls you are doing are coming from not properly preparing the inputs!
And IMO preparing them outside the forward loop is a lot better!

Ah, yes I agree. It would be much better to just have seq_idx prepared once at the outset in the dataloader (along with position_ids and/or cu_seq_len_x), rather than doing the conversions in the model. Though we still the model to be able to do these conversions if a suboptimal dataloader is used, yes?

Another optimization I could do is to move the conversion code out of the repeated BambaDecoder layers and into BambaModel, so we only convert once. I put these helper functions inside the mamba branch of the BambaDecoder layer because I was trying to be conservative and not touch the full-attention layer code path at all, but this isn't ideal for perf. Thoughts?

Last, if you haven't seen it, some similar conversation is happening over in #35941; see this comment.

@garrett361 garrett361 mentioned this pull request Feb 6, 2025
5 tasks
Copy link
Contributor

@vasqu vasqu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Love the overall idea and implementation. Just my 2 cents on mainly design and maybe how the future would look like using flash attn kwargs.

Imo it would be nice to introduce some more uniform kwargs that is aligned with padding free, not only flash attn (likely something for the future rather than this PR).

@garrett361
Copy link
Contributor Author

Love the overall idea and implementation.

Thanks!

Imo it would be nice to introduce some more uniform kwargs that is aligned with padding free, not only flash attn (likely something for the future rather than this PR).

Agreed, that would be very nice.

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey! Super sorry to be a bit strict on this:

Ah, yes I agree. It would be much better to just have seq_idx prepared once at the outset in the dataloader (along with position_ids and/or cu_seq_len_x), rather than doing the conversions in the model. Though we still the model to be able to do these conversions if a suboptimal dataloader is used, yes?

No 😓 because if we do, we open the door to doing the same for flex attention, then flash attention3 then etc and etc!

Let's rather create a path (meaning support passing them as kwargs typed properly!) and a datacollator + add doc about this 🤗

@garrett361
Copy link
Contributor Author

Super sorry to be a bit strict on this:

I understand! I actually personally prefer the more-strict approach, I was just trying to change external APIs as minimally as possible.

Now that I understand the goals better, here is my plan:

  1. Rework this PR so that all padding-free related tensors appear as proper kwargs, raising ValueErrors if improper combinations are passed in.
  2. Open a separate PR which enables DataCollatorWithFlattening to optionally produce seq_idx and/or the FlashAttentionKwargs.
  3. Open a similar PR in trl for DataCollatorForCompletionOnlyLM

How does that sound @ArthurZucker ?

An aside: making all of these padding-free args proper individual kwargs also has the benefit of enabling (padding-free + activation checkpointing) for various models, since the latter feature needs explicit args. The FlashAttentionKwargs path for llama models does not support grad checkpointing: **flash_attn_kwargs are passed to the non-grad-ckpt layer, but no the checkpointed one, since checkpointing doesn't support kwargs.

if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
causal_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
position_embeddings,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)

@garrett361
Copy link
Contributor Author

@ArthurZucker could you please advise on the above plan? Would love to implement whatever path you think is best, but would like confirmation before doing unnecessary work.

@garrett361
Copy link
Contributor Author

Started a draft PR for adding more return values to the DataCollatorWithFlattening here: #36456

@garrett361 garrett361 force-pushed the bamba-hf-padding-free-pr branch 2 times, most recently from 539f1bb to 6260a00 Compare March 3, 2025 17:37
@garrett361
Copy link
Contributor Author

garrett361 commented Mar 3, 2025

I reworked this PR so that it doesn't introduce any tensor conversions for padding-free training. The user must either provide:

  • seq_idx and position_ids
  • seq_idx, position_ids and all of the FlashAttentionKwargs (for better torch.compile compatibility).

The model code raises a ValueError if improper combinations are provided. The intention is that padding-free for bamba will be used along with DataCollatorWithFlattenting with additional return values which will be provided by #36456.

I made the forward signature explicit nearly everywhere(*) and added some doc strings.

(*)The one exception is BambaForCausalLM where all of the padding-free args are still processed through **kwargs. Adding seq_idx, cu_seq_lens_q, ... to BambaForCasualLM in modular_bamba.py does not cause utils/modular_model_coverter.py to propagate these args non-trivially to the BambaForCausalLM class in modeling_bamba.py, presumably because BambaForCasualLM inherits from LlamaForCausalLM which doesn't use these args. Do we want to make all of these FlashAttentionKwargs explicit in LlamaForCausalLM, too, eventually?

@ArthurZucker @vasqu please let me know what you think.

Copy link
Contributor

@vasqu vasqu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like the overall implementation, just some comments on the checks at the beginning of the forward and possibly add tests that ensure that errors are raised correctly.

Would let the others comment on modular :D would be nice to enable more (kw)args from a given class.

Comment on lines 1286 to 1123
# NOTE: @goon - padding-free kwargs intentionally omitted. Need to change LlamaForCausalLM,
# or the code gen tool, for explicitly included padding-free kwargs be properly processed.
**kwargs,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

max_length_q=max_length_q,
max_length_k=max_length_k,
seq_idx=seq_idx,
**kwargs,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
**kwargs,

Unnecessary?

Comment on lines 1014 to 1018
if flash_attn_kwargs_all_provided and (position_ids is None or seq_idx is None):
raise ValueError(
"If (cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k) are provided,"
" then position_ids and seq_idx must also be provided."
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if flash_attn_kwargs_all_provided and (position_ids is None or seq_idx is None):
raise ValueError(
"If (cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k) are provided,"
" then position_ids and seq_idx must also be provided."
)
if flash_attn_kwargs_all_provided and position_ids is None:
raise ValueError(
"If (cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k) are provided,"
" then position_ids must also be provided."
)

Could we not have fa kwarg and position ids? This would enable fa path but not the mamba2 path - which shouldn't be an issue imo.

Might be nice to add for which path here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I guess that's possible if the cu_seq_lens are trivial like (0, seq_len), yeah.

The concern I had here is silent incorrectness if non-trival cu_seq_lens = (0, seq_len0, seq_len0 + seq_len1, ...) are passed in without a corresponding non-trivial seq_idx, in which case the outputs would be silently incorrect.

I'll sharpen up the checks here, thanks.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Argh yea that's painful :/

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, sorry this whole PR (and associated ones) and getting a bit painful and drawn out. Thank you all for your patience!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No worries take your time :)

@garrett361
Copy link
Contributor Author

garrett361 commented May 20, 2025

Super happy to have padding free, but it should follow what we have for other model: typed kwargs

@ArthurZucker I guess I misunderstood the ask here. By "typed kwargs" I thought you meant explicitly adding the padding-free kwargs to the forward method with individual type annotations, like

        cu_seq_lens_q: Optional[torch.LongTensor] = None,
        cu_seq_lens_k: Optional[torch.LongTensor] = None,
        max_length_q: Optional[int] = None,
        max_length_k: Optional[int] = None,
        seq_idx: Optional[torch.IntTensor] = None,

What did you mean instead? Something like **kwargs: Unpack[FlashAttentionKwargs] but with a new class there?

@garrett361
Copy link
Contributor Author

garrett361 commented May 20, 2025

This should still work with ckping (partial) and should not need much new code!

So this would work?

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        ...,
        **kwargs,  
    ) -> BaseModelOutputWithPast:
        ...
        if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    ...,
                    **kwargs
                )

For some reason, I thought **kwargs weren't supported in grad ckpting.

EDIT: ah, I found this example, so we can just use partial(decoder_layer.__call__, **kwargs) for grad ckpt + kwargs

@garrett361
Copy link
Contributor Author

CC @vasqu for the above

@garrett361 garrett361 force-pushed the bamba-hf-padding-free-pr branch from bfc123d to 56f508f Compare May 20, 2025 14:11
@garrett361
Copy link
Contributor Author

@ArthurZucker I reworked this to use **kwargs: Unpack[BambaFlashAttentionKwargs], with the new class defined to add seq_idx to the list of kwargs usually in FlashAttentionKwargs. I also removed the arg checking code. Let me know if I'm still misunderstanding the ask, here, thanks!

@garrett361 garrett361 force-pushed the bamba-hf-padding-free-pr branch from 56f508f to b59e5c9 Compare May 20, 2025 14:17
Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay now LGTM! can you just fix the red ci and we can merge!

@@ -863,9 +891,15 @@ def forward(
cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
seq_idx: Optional[torch.IntTensor] = None,
**kwargs,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

prob unpack here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unpack the **kwargs? These **kwargs are just a catch-all for the non-seq_idx kwargs in BambaFlashAttentionKwargs because the BambaMixer layer only uses seq_idx, while BambaAttention uses the rest.

Do you want me to do **kwargs: Unpack[FlashAttentionKwargs] even though the kwargs are unused?

@garrett361
Copy link
Contributor Author

Thanks so much @ArthurZucker for your patience!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@garrett361
Copy link
Contributor Author

@ArthurZucker CI is green now, CI just building docs

@ArthurZucker ArthurZucker merged commit 390f153 into huggingface:main May 20, 2025
12 checks passed
@ArthurZucker
Copy link
Collaborator

Thanks for the contribution and your resilience! 🤗 sorry if I am not super specific sometimes hehehe

@garrett361
Copy link
Contributor Author

I probably should have understood the ask better from context, no worries!

@vasqu
Copy link
Contributor

vasqu commented May 20, 2025

Very nice contribution, thanks @garrett361:)

@garrett361
Copy link
Contributor Author

Thanks for your help and patience also @vasqu !

faaany pushed a commit to faaany/transformers that referenced this pull request May 21, 2025
* add seq_idx and fa kwargs

* update tests

* docs and grad ckpt support

* fmt

* better names

* test_raise_missing_padding_free_kwarg_errs

* + seq_idx in doc strings

* padding free training docs

* add link to pr plots

* raise err on attn_mask with padding free

* rm raising missing padding free err test

* BambaFlashAttentionKwargs

* run modular util for modular_granitemoehybrid.py
xvyv99 pushed a commit to xvyv99/transformers that referenced this pull request May 21, 2025
* add seq_idx and fa kwargs

* update tests

* docs and grad ckpt support

* fmt

* better names

* test_raise_missing_padding_free_kwarg_errs

* + seq_idx in doc strings

* padding free training docs

* add link to pr plots

* raise err on attn_mask with padding free

* rm raising missing padding free err test

* BambaFlashAttentionKwargs

* run modular util for modular_granitemoehybrid.py
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants