
Commit d91e1ac

[linalg] Use query shapes for attention broadcast (#4060)
When broadcasting the mask we need to use the query shapes, not the key shapes. With grouped-query attention (GQA), the key has different batch/head dimensions than the expanded output, which follows the query.
1 parent 6706183 commit d91e1ac
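For context, a minimal sketch of why the two shape sets differ under GQA, assuming a recent PyTorch with `enable_gqa` support (the tensor sizes below are illustrative, not taken from this commit): the output of `scaled_dot_product_attention` follows the query's batch and head dimensions, while the key carries only the smaller number of KV heads.

```python
import torch

# Illustrative GQA shapes (hypothetical sizes, not from this commit):
# the query has 8 heads, while key/value share only 2 KV heads.
query = torch.randn(1, 8, 16, 64)  # [batch, num_heads, seq_q, head_dim]
key = torch.randn(1, 2, 32, 64)    # [batch, num_kv_heads, seq_k, head_dim]
value = torch.randn(1, 2, 32, 64)

out = torch.nn.functional.scaled_dot_product_attention(
    query, key, value, enable_gqa=True
)

# The expanded output follows the query's batch/head dims, not the key's,
# so a mask broadcast against the key's sizes would come out too small.
print(out.shape)  # torch.Size([1, 8, 16, 64])
```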

2 files changed, +7 −6 lines changed

lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp

Lines changed: 6 additions & 5 deletions
```diff
@@ -1841,7 +1841,7 @@ class ConvertAtenScaledDotProductAttentionOp
     int64_t rank = maskTy.getRank();
     bool needsBroadcast = false;
     for (int i = 0, s = rank - 2; i < s; ++i) {
-      needsBroadcast |= maskTy.getDimSize(i) != keyTy.getDimSize(i);
+      needsBroadcast |= maskTy.getDimSize(i) != queryTy.getDimSize(i);
     }

     if (needsBroadcast) {
@@ -1850,16 +1850,17 @@ class ConvertAtenScaledDotProductAttentionOp

       SmallVector<AffineExpr> maskExprs;
       for (int i = 0, s = rank - 2; i < s; ++i) {
-        maskShape.push_back(keyTy.getDimSize(i));
+        maskShape.push_back(queryTy.getDimSize(i));

-        if (maskTy.getDimSize(i) != keyTy.getDimSize(i)) {
+        if (maskTy.getDimSize(i) != queryTy.getDimSize(i)) {
           maskExprs.push_back(rewriter.getAffineConstantExpr(0));
         } else {
           maskExprs.push_back(rewriter.getAffineDimExpr(i));
         }

-        if (keyTy.isDynamicDim(i)) {
-          maskDynDims.push_back(rewriter.create<tensor::DimOp>(loc, key, i));
+        if (queryTy.isDynamicDim(i)) {
+          maskDynDims.push_back(
+              rewriter.create<tensor::DimOp>(loc, query, i));
         }
       }
```
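A small Python analogue of what this hunk changes (the dimension sizes are hypothetical): the mask is applied to attention scores that have one row of scores per query head, so it must be expanded along the query's leading dimensions; expanding along the key's KV-head count would not line up.

```python
import torch

num_query_heads, num_kv_heads = 8, 2              # hypothetical GQA head counts
scores = torch.randn(1, num_query_heads, 16, 32)  # [B, H_q, seq_q, seq_k]
mask = torch.zeros(1, 1, 16, 32)                  # broadcastable mask

good = mask.expand(1, num_query_heads, 16, 32)    # query-shaped: lines up
print((scores + good).shape)                      # torch.Size([1, 8, 16, 32])

bad = mask.expand(1, num_kv_heads, 16, 32)        # key-shaped: mismatched
# scores + bad  # would raise: sizes 8 and 2 don't match at dimension 1
```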
projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -5787,7 +5787,7 @@ def __init__(self):
     )
     def forward(self, query, key, value):
         return torch.ops.aten.scaled_dot_product_attention(
-            query, key, value, enable_gqa=True
+            query, key, value, enable_gqa=True, is_causal=True
        )
```
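The test now exercises GQA together with a causal mask, which drives the broadcast path fixed above. A hedged reference check of the computation involved (shapes are illustrative; `repeat_interleave` mirrors how grouped KV heads map onto query heads):

```python
import torch

# Sketch of a reference for GQA + causal SDPA (illustrative shapes):
# 8 query heads over 2 KV heads, i.e. 4 query heads per KV head.
q = torch.randn(1, 8, 16, 64)
k = torch.randn(1, 2, 16, 64)
v = torch.randn(1, 2, 16, 64)

out = torch.nn.functional.scaled_dot_product_attention(
    q, k, v, enable_gqa=True, is_causal=True
)

k_rep = k.repeat_interleave(4, dim=1)   # expand KV heads to 8 query heads
v_rep = v.repeat_interleave(4, dim=1)
scores = (q @ k_rep.transpose(-2, -1)) / (64 ** 0.5)
causal = torch.ones(16, 16, dtype=torch.bool).tril()
scores = scores.masked_fill(~causal, float("-inf"))
ref = scores.softmax(-1) @ v_rep

print(torch.allclose(out, ref, atol=1e-5))  # expected True, up to backend numerics
```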