
Commit 1b03423

davidberard98 authored and pytorchmergebot committed
[meta registration] fix _efficient_attention_forward for jagged inputs (pytorch#118657)
Fixes the meta registration for the logsumexp output, whose shape should be defined by the size of the offsets tensor when it exists. See https://github.com/pytorch/pytorch/blob/644f64f2d112b7c0b758b044821cf3972c0c17e9/aten/src/ATen/native/transformers/cuda/attention.cu#L1045

Differential Revision: [D53234217](https://our.internmc.facebook.com/intern/diff/D53234217)
Pull Request resolved: pytorch#118657
Approved by: https://github.com/YuqingJ
1 parent 6fa162e commit 1b03423
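
For context, here is a minimal sketch of the shape rule this fix implements (not the actual meta kernel; the helper name and example numbers are illustrative): when `cu_seqlens_q` offsets are present, the logsumexp batch dimension is the number of packed sequences, `cu_seqlens_q.size(0) - 1`, rather than the dense batch size `B`.

```python
import math
import torch

def logsumexp_shape(B, M, num_heads, compute_log_sumexp, cu_seqlens_q=None):
    # Illustrative helper, not from the PR: mirrors the meta registration's locals.
    # With jagged inputs, cu_seqlens_q holds cumulative offsets, so the number of
    # sequences is size(0) - 1; otherwise the dense batch size B is used.
    batch_dim = cu_seqlens_q.size(0) - 1 if cu_seqlens_q is not None else B
    logsumexp_dim = math.ceil(M / 32) * 32 if compute_log_sumexp else 0
    return (batch_dim, num_heads, logsumexp_dim)

# Dense case: batch dim is B.
print(logsumexp_shape(B=4, M=64, num_heads=8, compute_log_sumexp=True))  # (4, 8, 64)

# Jagged case: 4 sequences packed into one batch, described by 5 offsets.
offsets = torch.tensor((0, 2, 4, 6, 8), dtype=torch.int32)
print(logsumexp_shape(B=1, M=8, num_heads=8, compute_log_sumexp=True,
                      cu_seqlens_q=offsets))  # (4, 8, 32)
```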

File tree

2 files changed: +22 -1 lines changed


torch/_meta_registrations.py (+2 -1)

@@ -5447,9 +5447,10 @@ def meta__efficient_attention_forward(
 
     res = torch.empty(B, M, num_heads, Kv, dtype=query.dtype, device=query.device)
 
+    logsumexp_batch_dim = cu_seqlens_q.size(0) - 1 if (cu_seqlens_q is not None) else B
     logsumexp_dim = math.ceil(M / 32) * 32 if compute_log_sumexp else 0
     logsum_exp = torch.empty(
-        (B, num_heads, logsumexp_dim),
+        (logsumexp_batch_dim, num_heads, logsumexp_dim),
         dtype=torch.float,
         device=query.device,
     )

torch/testing/_internal/common_methods_invocations.py (+20)

@@ -8531,6 +8531,26 @@ def sample_inputs_efficient_attention_forward(op_info, device, dtype, requires_g
         )
     )
 
+    # jagged (with query/keys offsets)
+    samples.append(
+        SampleInput(
+            make((4, 2, 64)).view(-1, 8, 8).unsqueeze(0),
+            make((6, 64)).view(-1, 8, 8).unsqueeze(0),
+            make((6, 64)).view(-1, 8, 8).unsqueeze(0),
+            bias=None,
+            cu_seqlens_q=torch.tensor((0, 2, 4, 6, 8), dtype=torch.int32, device=device),
+            cu_seqlens_k=torch.tensor((0, 1, 3, 5, 6), dtype=torch.int32, device=device),
+            max_seqlen_q=2,
+            max_seqlen_k=2,
+            dropout_p=0.0,
+            custom_mask_type=0,  # No Mask
+            compute_log_sumexp=requires_grad,
+            scale=None,
+            causal_diagonal=None,
+            seqlen_k=None,
+        )
+    )
+
     yield from samples
 
 
 def sample_inputs_flash_attention_forward(op_info, device, dtype, requires_grad, **kwargs):
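
As a sanity check on the new jagged sample (illustrative only; `make` below is a stand-in for the OpInfo tensor-maker used in the test file), the packed shapes and offsets line up as follows:

```python
import torch

def make(shape):
    # Stand-in for the OpInfo `make` helper used in common_methods_invocations.py.
    return torch.randn(shape)

query = make((4, 2, 64)).view(-1, 8, 8).unsqueeze(0)  # 512 elements -> (1, 8, 8, 8): 8 packed query tokens
key = make((6, 64)).view(-1, 8, 8).unsqueeze(0)       # 384 elements -> (1, 6, 8, 8): 6 packed key tokens

cu_seqlens_q = torch.tensor((0, 2, 4, 6, 8), dtype=torch.int32)  # 4 sequences of length 2 -> 8 query tokens
cu_seqlens_k = torch.tensor((0, 1, 3, 5, 6), dtype=torch.int32)  # lengths 1, 2, 2, 1 -> 6 key tokens

assert query.shape == (1, 8, 8, 8)
assert key.shape == (1, 6, 8, 8)
# With offsets present, the meta registration now sizes logsumexp's batch
# dimension as cu_seqlens_q.size(0) - 1 = 4 instead of the packed B = 1.
assert cu_seqlens_q.size(0) - 1 == 4
```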
