Commit 5d39d5e

fix FlashAttentionKwargs rope

1 parent f3cdffe
3 files changed (+49, -6 lines)

src/transformers/models/bamba/modeling_bamba.py

Lines changed: 11 additions & 1 deletion

@@ -1117,6 +1117,13 @@ def _init_weights(self, module):
                 module.weight.data[module.padding_idx].zero_()
 
 
+def get_position_ids_from_cu_seq_lens(cu_seq_lens: torch.Tensor) -> torch.Tensor:
+    pos_ids = torch.cat(
+        [torch.arange(s, dtype=torch.int32, device=cu_seq_lens.device) for s in cu_seq_lens.diff(dim=-1)], dim=-1
+    )[None]
+    return pos_ids
+
+
 BAMBA_INPUTS_DOCSTRING = r"""
     Args:
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -1281,7 +1288,10 @@ def forward(
         if cache_position is None:
             cache_position = torch.arange(hidden_states.shape[1], device=hidden_states.device)
         if position_ids is None:
-            position_ids = cache_position.unsqueeze(0)
+            if "cu_seq_lens_q" in flash_attn_kwargs:
+                position_ids = get_position_ids_from_cu_seq_lens(flash_attn_kwargs["cu_seq_lens_q"])
+            else:
+                position_ids = cache_position.unsqueeze(0)
 
         causal_mask = self._update_causal_mask(
             attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
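
Note (not part of the diff): with packed/varlen batches, FlashAttentionKwargs carry cu_seq_lens_q, and the old fallback cache_position.unsqueeze(0) produced a single flat 0..T-1 range spanning every packed sequence, so RoPE positions did not restart at sequence boundaries. A minimal standalone sketch of the new helper with an illustrative input:

import torch

def get_position_ids_from_cu_seq_lens(cu_seq_lens: torch.Tensor) -> torch.Tensor:
    # One 0..len-1 range per packed sequence, concatenated, with a leading batch dim.
    pos_ids = torch.cat(
        [torch.arange(s, dtype=torch.int32, device=cu_seq_lens.device) for s in cu_seq_lens.diff(dim=-1)], dim=-1
    )[None]
    return pos_ids

cu_seq_lens = torch.tensor([0, 3, 7, 9])  # three packed sequences of lengths 3, 4, 2
print(get_position_ids_from_cu_seq_lens(cu_seq_lens))
# tensor([[0, 1, 2, 0, 1, 2, 3, 0, 1]], dtype=torch.int32)
# old fallback for the same 9 tokens: torch.arange(9)[None] -> [[0, 1, 2, 3, 4, 5, 6, 7, 8]]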

src/transformers/models/bamba/modular_bamba.py

Lines changed: 11 additions & 1 deletion

@@ -233,6 +233,13 @@ def get_seq_idx_from_cu_seq_lens(cu_seq_lens: torch.Tensor) -> torch.Tensor:
     return seq_idx[None]
 
 
+def get_position_ids_from_cu_seq_lens(cu_seq_lens: torch.Tensor) -> torch.Tensor:
+    pos_ids = torch.cat(
+        [torch.arange(s, dtype=torch.int32, device=cu_seq_lens.device) for s in cu_seq_lens.diff(dim=-1)], dim=-1
+    )[None]
+    return pos_ids
+
+
 # Adapted from transformers.models.mamba2.modeling_mamba2.Mamba2Mixer
 class BambaMixer(nn.Module):
     """
@@ -1029,7 +1036,10 @@ def forward(
         if cache_position is None:
             cache_position = torch.arange(hidden_states.shape[1], device=hidden_states.device)
         if position_ids is None:
-            position_ids = cache_position.unsqueeze(0)
+            if "cu_seq_lens_q" in flash_attn_kwargs:
+                position_ids = get_position_ids_from_cu_seq_lens(flash_attn_kwargs["cu_seq_lens_q"])
+            else:
+                position_ids = cache_position.unsqueeze(0)
 
         causal_mask = self._update_causal_mask(
             attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
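
Note (not part of the diff): modular_bamba.py mirrors the same helper and forward-path change. In the packed path, the Mamba mixer's seq_idx and the attention layers' RoPE position ids are now derived from the same cu_seq_lens. The commit only shows that get_seq_idx_from_cu_seq_lens returns seq_idx[None], so the repeat_interleave construction below is a hypothetical sketch for comparison, not the library's implementation:

import torch

def seq_idx_from_cu_seq_lens_sketch(cu_seq_lens: torch.Tensor) -> torch.Tensor:
    # Hypothetical: label each token with the index of the packed sequence it belongs to.
    seq_lens = cu_seq_lens.diff(dim=-1)
    seq_idx = torch.repeat_interleave(
        torch.arange(seq_lens.numel(), dtype=torch.int32, device=cu_seq_lens.device), seq_lens
    )
    return seq_idx[None]

cu_seq_lens = torch.tensor([0, 3, 7, 9])
print(seq_idx_from_cu_seq_lens_sketch(cu_seq_lens))
# tensor([[0, 0, 0, 1, 1, 1, 1, 2, 2]], dtype=torch.int32)
# get_position_ids_from_cu_seq_lens(cu_seq_lens) gives [[0, 1, 2, 0, 1, 2, 3, 0, 1]] for the same input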

tests/models/bamba/test_modeling_bamba.py

Lines changed: 27 additions & 4 deletions

@@ -22,7 +22,11 @@
 
 from transformers import AutoTokenizer, BambaConfig, is_torch_available
 from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
-from transformers.models.bamba.modular_bamba import get_cu_seq_lens_from_position_ids, get_seq_idx_from_cu_seq_lens
+from transformers.models.bamba.modular_bamba import (
+    get_cu_seq_lens_from_position_ids,
+    get_position_ids_from_cu_seq_lens,
+    get_seq_idx_from_cu_seq_lens,
+)
 from transformers.testing_utils import (
     require_flash_attn,
     require_torch,
@@ -565,9 +569,7 @@ def test_attn_mask_position_ids_flash_attn_equality(self):
         )[None]
 
         torch.testing.assert_close(position_ids_logits, attn_mask_logits_reshaped)
-        # A higher tolerance is needed for the position_ids and FlashAttentionKwargs logits to
-        # match, for unknown reasons.
-        torch.testing.assert_close(position_ids_logits, flash_attn_kwargs_logits, atol=1e-3, rtol=1e-1)
+        torch.testing.assert_close(position_ids_logits, flash_attn_kwargs_logits)
 
 
     @slow
@@ -723,3 +725,24 @@ def test_seq_idx_from_cu_seq_lens() -> None:
     )[None]
     seq_idx_pred = get_seq_idx_from_cu_seq_lens(cu_seq_lens)
     assert torch.allclose(seq_idx_pred, seq_idx)
+
+
+def test_pos_ids_from_cu_seq_lens() -> None:
+    n_chunks = 5
+    max_chunk_len = 64
+
+    seq_lens = torch.randint(1, max_chunk_len, size=(n_chunks,))
+    cu_seq_lens = torch.cat([torch.tensor([0]), seq_lens.cumsum(dim=-1)], dim=-1)
+    pos_ids = torch.cat(
+        [
+            torch.arange(
+                s,
+                dtype=torch.int32,
+                device=cu_seq_lens.device,
+            )
+            for s in cu_seq_lens.diff(dim=-1)
+        ],
+        dim=-1,
+    )[None]
+    pos_ids_pred = get_position_ids_from_cu_seq_lens(cu_seq_lens)
+    assert torch.allclose(pos_ids_pred, pos_ids)
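
Note (not part of the diff): a deterministic variant of the new test as a usage example, assuming this commit is installed; fixed lengths replace the test's random seq_lens so the expected tensor is explicit.

import torch
from transformers.models.bamba.modular_bamba import get_position_ids_from_cu_seq_lens

seq_lens = torch.tensor([2, 3, 1])
cu_seq_lens = torch.cat([torch.tensor([0]), seq_lens.cumsum(dim=-1)], dim=-1)  # tensor([0, 2, 5, 6])
expected = torch.tensor([[0, 1, 0, 1, 2, 0]], dtype=torch.int32)
assert torch.equal(get_position_ids_from_cu_seq_lens(cu_seq_lens), expected)

The randomized version can be run with: pytest tests/models/bamba/test_modeling_bamba.py -k pos_ids_from_cu_seq_lens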
