Add padding-free to bamba #35861
Conversation
cc @ArthurZucker for bamba, but let me know if you want me to take a look since it seems like quite an extensive PR!
I don't think it's very many changes, ultimately! Basically it just adds two helper functions so that the position_ids and FlashAttentionKwargs padding-free code paths can each be supported.
So, basically the above: making sure both code paths work.
Added a commit with the changes discussed above.
Hi @ArthurZucker @molbap, please let me know if I can answer any questions about this PR, thank you!
Hey! Good work and good PR. One thing: for training we do recommend using the padding-free data collator, which takes care of the flattening and passes the appropriate kwargs. This prevents us from having to make too many changes, apart from popping the cu_seq_lens if they are inputs.
@@ -940,10 +931,39 @@ def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


def get_cu_seq_lens_from_position_ids(position_ids: torch.LongTensor) -> torch.LongTensor:
Not super accurate to use position ids because they can include padding.
Yes, I thought about this for a while and it's why I am checking the non_increasing_pos_id condition.
The (just-updated; I fixed an error) helper now reads:
def get_cu_seq_lens_from_position_ids(position_ids: torch.LongTensor) -> torch.LongTensor:
    batch_size = position_ids.shape[0]
    if batch_size != 1:
        raise ValueError("Only batch size 1 is supported.")

    device = position_ids.device
    idxs = torch.arange(1, position_ids.shape[1], device=device)
    non_increasing_pos_id = position_ids[0, 1:] <= position_ids[0, :-1]
    next_pos_is_is_zero = position_ids[0, 1:] == 0
    new_seq_idxs = non_increasing_pos_id | next_pos_is_is_zero
    cu_seq_lens = torch.empty(new_seq_idxs.sum() + 2, device=device, dtype=torch.int64)
    cu_seq_lens[0], cu_seq_lens[1:-1], cu_seq_lens[-1] = 0, idxs[new_seq_idxs], position_ids.shape[-1]
    return cu_seq_lens
My goal was to treat every padding token (assumed to be encoded as a negative number, like -100) as an individual sequence, while treating the non-padding sequences correctly. So an extreme case like position_ids = [-100, 0, 1, -100, -100, 0, 1, 2] would turn into:

# get_cu_seq_lens_from_position_ids
cu_seq_lens = [0, 1, 3, 4, 5, 8]

thanks to the non-increasing pos id check. Seems reasonable to me, since:
- The non-trivial sequences are assigned the correct lengths.
- Every padding token is a new, length-1 segment, so no unnecessary compute is expended attending across a span of padding tokens.
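A quick sanity check of the worked example above, using the helper just shown (a hypothetical snippet, not part of the diff itself):

import torch

# Two real sequences ([0, 1] and [0, 1, 2]) interleaved with -100 padding tokens.
position_ids = torch.tensor([[-100, 0, 1, -100, -100, 0, 1, 2]])
print(get_cu_seq_lens_from_position_ids(position_ids))
# tensor([0, 1, 3, 4, 5, 8])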
Compare to what is done in prepare_fa2_from_position_ids, which I was trying to improve upon (transformers/src/transformers/modeling_flash_attention_utils.py, lines 174 to 179 at c772bff):
cu_seq_lens = torch.cat(
    (
        indices_q[position_ids == 0],
        torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),
    )
)
Passing the same position_ids into the above gives:

# prepare_fa2_from_position_ids
cu_seq_lens = [1, 5, 8]

which incorrectly implies there's a subsequence of length 4, since it only checks for pos_id == 0.
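For reference, a standalone snippet reproducing that older logic on the same (flattened) example; this only illustrates the quoted lines above:

import torch

# Same example as before, already flattened as the existing utility expects.
position_ids = torch.tensor([-100, 0, 1, -100, -100, 0, 1, 2])
indices_q = torch.arange(position_ids.shape[0], dtype=torch.int32)
cu_seq_lens = torch.cat(
    (
        indices_q[position_ids == 0],
        torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),
    )
)
print(cu_seq_lens)  # tensor([1, 5, 8], dtype=torch.int32)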
Thoughts? Very open to improvements!
@@ -1079,6 +1117,13 @@ def _init_weights(self, module):
            module.weight.data[module.padding_idx].zero_()


def get_position_ids_from_cu_seq_lens(cu_seq_lens: torch.Tensor) -> torch.Tensor:
same comment about cat and for loops
Yep, also changed now, thanks.
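The updated implementation isn't quoted in this thread; as a rough illustration, a cat- and loop-free inverse helper could look like the following sketch (an assumption, not necessarily the PR's final code):

import torch


def get_position_ids_from_cu_seq_lens(cu_seq_lens: torch.Tensor) -> torch.Tensor:
    # Sketch only: rebuild per-sequence position ids from cumulative sequence lengths
    # without Python loops, e.g. [0, 3, 5, 8] -> [[0, 1, 2, 0, 1, 0, 1, 2]].
    seq_starts = cu_seq_lens[:-1]
    seq_lens = cu_seq_lens[1:] - seq_starts
    start_per_token = torch.repeat_interleave(seq_starts, seq_lens)
    total_len = int(cu_seq_lens[-1])
    position_ids = torch.arange(total_len, device=cu_seq_lens.device) - start_per_token
    return position_ids[None]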
Thank you!
Yep, just trying to support all code paths. Hope I didn't misunderstand you here. I think I have addressed all comments @ArthurZucker, let me know! I also had to re-run the formatting and modular utilities.
I am aligned with you on enabling the path!
The key idea for me is that all these utilities and calls you are doing come from not properly preparing the inputs!
And IMO preparing them outside the forward loop is a lot better! You leave it to the data collator.
What I am saying here is I'd rather we add a data collator for bamba that properly takes care of initializing the seq_idx, for example, than have to add extra logic.
BUT we need to propagate the seq_idx for sure!
Does it make sense for you?
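A minimal sketch of the collator-centric flow described above, assuming transformers' DataCollatorWithFlattening; the return_seq_idx and return_flash_attn_kwargs flags correspond to the follow-up collator work referenced later in this thread and may not exist in older versions:

from transformers import DataCollatorWithFlattening

collator = DataCollatorWithFlattening(
    return_position_ids=True,
    return_seq_idx=True,            # assumed flag
    return_flash_attn_kwargs=True,  # assumed flag
)

# Two toy pre-tokenized sequences; the collator flattens them into one row.
features = [{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5, 6, 7, 8]}]
batch = collator(features)
# batch holds a single flattened row plus position_ids, seq_idx, and the
# cu_seq_lens_{q,k} / max_length_{q,k} kwargs, ready to pass to the model's forward.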
Ah, yes I agree. It would be much better to just have seq_idx prepared once at the outset in the dataloader (along with position_ids and/or cu_seq_len_x), rather than doing the conversions in the model. Though we still want the model to be able to do these conversions if a suboptimal dataloader is used, yes? Another optimization I could do is to move the conversion code out of the repeated calls. Last, if you haven't seen it, some similar conversation is happening over in #35941; see this comment.
Love the overall idea and implementation. Just my 2 cents, mainly on design and maybe how the future would look using flash attn kwargs.
Imo it would be nice to introduce some more uniform kwargs aligned with padding free, not only flash attn (likely something for the future rather than this PR).
Thanks!
Agreed, that would be very nice.
Hey! Super sorry to be a bit strict on this:
Ah, yes I agree. It would be much better to just have seq_idx prepared once at the outset in the dataloader (along with position_ids and/or cu_seq_len_x), rather than doing the conversions in the model. Though we still want the model to be able to do these conversions if a suboptimal dataloader is used, yes?

No 😓 because if we do, we open the door to doing the same for flex attention, then flash attention 3, then etc and etc!
Let's rather create a path (meaning support passing them as kwargs, typed properly!) and a data collator + add docs about this 🤗
I understand! I actually personally prefer the more-strict approach, I was just trying to change external APIs as minimally as possible. Now that I understand the goals better, here is my plan:
How does that sound @ArthurZucker? An aside: making all of these padding-free args proper individual kwargs also has the benefit of enabling (padding-free + activation checkpointing) for various models, since the latter feature needs explicit args. See transformers/src/transformers/models/llama/modeling_llama.py, lines 581 to 604 at 60226c6.
@ArthurZucker could you please advise on the above plan? Would love to implement whatever path you think is best, but would like confirmation before doing unnecessary work.
Started a draft PR for adding more return values to the padding-free data collator.
I reworked this PR so that it doesn't introduce any tensor conversions for padding-free training. The user must either provide the full set of padding-free inputs or none of them; the model code raises an error for inconsistent combinations, with one exception. @ArthurZucker @vasqu please let me know what you think.
I like the overall implementation, just some comments on the checks at the beginning of the forward and possibly add tests that ensure that errors are raised correctly.
Would let the others comment on modular :D would be nice to enable more (kw)args from a given class.
        # NOTE: @goon - padding-free kwargs intentionally omitted. Need to change LlamaForCausalLM,
        # or the code gen tool, for explicitly included padding-free kwargs to be properly processed.
        **kwargs,
modular cc @Cyrilvallez @ArthurZucker
            max_length_q=max_length_q,
            max_length_k=max_length_k,
            seq_idx=seq_idx,
            **kwargs,
            **kwargs,
Unnecessary?
if flash_attn_kwargs_all_provided and (position_ids is None or seq_idx is None):
    raise ValueError(
        "If (cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k) are provided,"
        " then position_ids and seq_idx must also be provided."
    )
Suggested change:

if flash_attn_kwargs_all_provided and position_ids is None:
    raise ValueError(
        "If (cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k) are provided,"
        " then position_ids must also be provided."
    )
Could we not have the fa kwargs and position ids? This would enable the fa path but not the mamba2 path, which shouldn't be an issue imo.
Might be nice to state which path this applies to here.
Ah, I guess that's possible if the cu_seq_lens are trivial like (0, seq_len), yeah.
The concern I had here is silent incorrectness: if non-trivial cu_seq_lens = (0, seq_len0, seq_len0 + seq_len1, ...) are passed in without a corresponding non-trivial seq_idx, the outputs would be silently wrong.
I'll sharpen up the checks here, thanks.
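A sketch of what such a sharpened check could look like, following the naming in the snippet above (an assumption, not necessarily the final code in the PR):

def _check_padding_free_kwargs(cu_seq_lens_q, seq_idx) -> None:
    # Hypothetical helper: reject non-trivial cu_seq_lens without a matching seq_idx,
    # since the mamba2 path would otherwise produce silently incorrect outputs.
    non_trivial_cu_seq_lens = cu_seq_lens_q is not None and cu_seq_lens_q.numel() > 2
    if non_trivial_cu_seq_lens and seq_idx is None:
        raise ValueError(
            "Non-trivial (cu_seq_lens_q, cu_seq_lens_k) were provided without seq_idx;"
            " please also provide seq_idx."
        )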
Argh yea that's painful :/
Yeah, sorry, this whole PR (and associated ones) is getting a bit painful and drawn out. Thank you all for your patience!
No worries take your time :)
@ArthurZucker I guess I misunderstood the ask here. By "typed kwargs" I thought you meant explicitly adding the padding-free kwargs to the forward signatures:

cu_seq_lens_q: Optional[torch.LongTensor] = None,
cu_seq_lens_k: Optional[torch.LongTensor] = None,
max_length_q: Optional[int] = None,
max_length_k: Optional[int] = None,
seq_idx: Optional[torch.IntTensor] = None,

What did you mean instead? Something like ...?
So this would work?

def forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    ...,
    **kwargs,
) -> BaseModelOutputWithPast:
    ...
    if self.gradient_checkpointing and self.training:
        layer_outputs = self._gradient_checkpointing_func(
            decoder_layer.__call__,
            hidden_states,
            ...,
            **kwargs,
        )

For some reason, I thought this wouldn't work with gradient checkpointing. EDIT: ah, I found this example, so we can just use **kwargs.
CC @vasqu for the above.
@ArthurZucker I reworked this to use **kwargs as discussed above.
Okay, now LGTM! Can you just fix the red CI and we can merge!
@@ -863,9 +891,15 @@ def forward(
        cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        seq_idx: Optional[torch.IntTensor] = None,
        **kwargs,
prob unpack here
Unpack the **kwargs? These **kwargs are just a catch-all for the non-seq_idx kwargs in BambaFlashAttentionKwargs, because the BambaMixer layer only uses seq_idx, while BambaAttention uses the rest.
Do you want me to do **kwargs: Unpack[FlashAttentionKwargs] even though the kwargs are unused?
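For context, a sketch of what the BambaFlashAttentionKwargs TypedDict might look like, based on the kwargs discussed in this thread (the exact fields and docstring in the merged code may differ):

from typing import TypedDict

import torch


class BambaFlashAttentionKwargs(TypedDict, total=False):
    # Padding-free kwargs shared between the attention and mamba (BambaMixer) layers:
    # cumulative sequence lengths and max lengths for flash attention, plus the
    # per-token sequence index consumed by the mamba kernels.
    cu_seq_lens_q: torch.LongTensor
    cu_seq_lens_k: torch.LongTensor
    max_length_q: int
    max_length_k: int
    seq_idx: torch.IntTensor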
Thanks so much @ArthurZucker for your patience!
@ArthurZucker CI is green now, just building docs.
Thanks for the contribution and your resilience! 🤗 Sorry if I am not super specific sometimes hehehe
I probably should have understood the ask better from context, no worries!
Very nice contribution, thanks @garrett361 :)
Thanks for your help and patience also @vasqu!
Commits:
- add seq_idx and fa kwargs
- update tests
- docs and grad ckpt support
- fmt
- better names
- test_raise_missing_padding_free_kwarg_errs
- + seq_idx in doc strings
- padding free training docs
- add link to pr plots
- raise err on attn_mask with padding free
- rm raising missing padding free err test
- BambaFlashAttentionKwargs
- run modular util for modular_granitemoehybrid.py
What does this PR do?
Adds padding-free training to the BambaModel, enabling more efficient training with causal masking between disjoint sequences.

Performance: approximately 2x throughput improvement over naive padding for supervised finetuning on the Tulu v3 dataset with open-instruct. Tokens/sec/gpu (plots linked in the PR) for batch_size_per_gpu = 4:
- 8 A100s: 600 --> 1200 Tok/s/gpu
- 32 A100s: 450 --> 750 Tok/s/gpu
CC @fabianlim
CC reviewers of #34982: @ArthurZucker @molbap
Notes on Code
- BambaAttention layers are untouched; only the BambaMixer mamba layer code is altered.
- The padding-free path only runs on cuda and requires the mamba kernels.
- Both the position_ids and FlashAttentionKwargs padding-free code paths are supported (see the sketch below).
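As an illustration, a hypothetical end-to-end call using the FlashAttentionKwargs path; the checkpoint name, token ids, and dtype are assumptions for the example rather than part of the PR:

import torch
from transformers import AutoModelForCausalLM

# Hypothetical example: two sequences of lengths 3 and 5, flattened into one batch row.
model = AutoModelForCausalLM.from_pretrained(
    "ibm-ai-platform/Bamba-9B",  # assumed checkpoint name
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
).cuda()
model.train()

input_ids = torch.tensor([[101, 102, 103, 201, 202, 203, 204, 205]], device="cuda")
position_ids = torch.tensor([[0, 1, 2, 0, 1, 2, 3, 4]], device="cuda")
cu_seq_lens = torch.tensor([0, 3, 8], device="cuda", dtype=torch.int32)
seq_idx = torch.tensor([[0, 0, 0, 1, 1, 1, 1, 1]], device="cuda", dtype=torch.int32)

outputs = model(
    input_ids=input_ids,
    position_ids=position_ids,
    cu_seq_lens_q=cu_seq_lens,
    cu_seq_lens_k=cu_seq_lens,
    max_length_q=5,
    max_length_k=5,
    seq_idx=seq_idx,
    labels=input_ids,
)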
Notes on Tests
On both latest main and this PR branch, the following tests/models/bamba/test_modeling_bamba.py tests fail (with RUN_SLOW=1):
- test_eager_matches_fa2_generate seems flaky: sometimes it passes, other times it fails.
- test_flash_attention_2_padding_matches_padding_free_with_position_ids: on main, this test fails because padding-free is not implemented. On this branch, it fails because the padding-free position_ids path is only taken when model.training = True, and this test explicitly calls eval() on the model. I have checked that this test passes when model.training = True. Edit: see BambaModelTest::test_attn_mask_position_ids_flash_attn_equality, also.
- test_simple_generate appears to just need a simple edit for its expected text. It consistently fails with a mismatch where the generated and expected text differ at the very end.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.