Commit 1b03423
[meta registration] fix _efficient_attention_forward for jagged inputs (pytorch#118657)
Fixes the meta registration for the logsumexp output, whose shape should
be defined by the size of the offsets tensor when it exists.
See the corresponding shape logic in the CUDA kernel:
https://github.com/pytorch/pytorch/blob/644f64f2d112b7c0b758b044821cf3972c0c17e9/aten/src/ATen/native/transformers/cuda/attention.cu#L1045
Differential Revision: [D53234217](https://our.internmc.facebook.com/intern/diff/D53234217)
Pull Request resolved: pytorch#118657
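A minimal sketch of the shape rule being fixed, under the assumption that the offsets tensor (`cu_seqlens_q`) holds `B + 1` cumulative sequence lengths for jagged inputs; the helper name and the `(B, M, H, K)` query layout are illustrative, not the actual meta-registration code:

```python
import torch

def logsumexp_shape(query, cu_seqlens_q=None, max_seqlen_q=None):
    # Hypothetical helper mirroring the fix: for jagged (varlen) inputs,
    # the batch dimension of logsumexp comes from the offsets tensor,
    # not from query.size(0).
    num_heads = query.size(2)  # assuming (B, M, H, K) query layout
    if cu_seqlens_q is not None:
        B = cu_seqlens_q.numel() - 1  # offsets tensor has B + 1 entries
        M = max_seqlen_q
    else:
        B = query.size(0)
        M = query.size(1)
    return (B, num_heads, M)
```

Before the fix, the jagged branch fell through to the dense `query.size(0)` path, producing a logsumexp shape that disagreed with the CUDA kernel's output.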
Approved by: https://github.com/YuqingJ1
File tree
2 files changed: +22 −1 lines
- torch
- testing/_internal