
Commit f38075f

raikonenfnu authored and monorimet committed
Add mask to attention op tiling and decomposition
Signed-off-by: stanley-nod <[email protected]>
1 parent 117df4f commit f38075f

6 files changed: 123 additions & 28 deletions

compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp

Lines changed: 14 additions & 6 deletions
@@ -1329,9 +1329,9 @@ LogicalResult AttentionOp::verify() {
   int numInputs = getNumDpsInputs();
   int numOutputs = getNumDpsInits();
 
-  if (numInputs != 4) {
+  if (numInputs < 4 || numInputs > 5) {
     return op->emitOpError(
-        "expected 4 input operands: Query, Key, Value and Scale");
+        "expected 4 or 5 input operands: Query, Key, Value, Scale, and Mask");
   }
 
   if (numOutputs != 1 && numOutputs != 3) {
@@ -1340,8 +1340,9 @@ LogicalResult AttentionOp::verify() {
   }
 
   bool isTiled = numOutputs == 3;
-
-  if (!llvm::all_of(llvm::drop_end(getDpsInputs()), [](Value input) {
+  SmallVector<Value> dpsInputs = getDpsInputs();
+  ArrayRef<Value> qkvValues(dpsInputs.begin(), dpsInputs.begin() + 3);
+  if (!llvm::all_of(qkvValues, [](Value input) {
         return isa<ShapedType>(input.getType());
       })) {
     return op->emitOpError(
@@ -1462,10 +1463,17 @@ SmallVector<AffineMap> AttentionOp::getIndexingMapsArray() {
         AffineMap::get(/*dimCount=*/5, /*symbolCount=*/0, {batch, k2, n}, ctx);
   }
 
+  SmallVector<AffineMap> results = {qMap, kMap, vMap};
+
+  if (getMask()) {
+    AffineMap maskMap =
+        AffineMap::get(/*dimCount=*/5, /*symbolCount=*/0, {batch, m, k2}, ctx);
+    results.push_back(maskMap);
+  }
+
   AffineMap resMap =
       AffineMap::get(/*dimCount=*/5, /*symbolCount=*/0, {batch, m, n}, ctx);
-
-  SmallVector<AffineMap> results = {qMap, kMap, vMap, resMap};
+  results.push_back(resMap);
 
   if (getMax()) {
     AffineMap maxMap =
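
For orientation (not part of the diff): with a mask operand present, the maps returned by getIndexingMapsArray() come back in the order Q, K, V, mask, result (plus max and sum for the tiled form). A sketch of the maps over the five iteration dimensions, following the Q/K/V shapes noted in IndexingUtils.cpp and the maps built above, and ignoring the transpose_v case:

    // iteration dims: (batch, m, k1, k2, n)
    Q      -> (batch, m, k1)
    K      -> (batch, k2, k1)
    V      -> (batch, k2, n)
    mask   -> (batch, m, k2)   // same space as the Q @ K.T logits
    result -> (batch, m, n)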

compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td

Lines changed: 23 additions & 6 deletions
@@ -522,11 +522,15 @@ def IREELinalgExt_AttentionOp : IREELinalgExt_Op<"attention",
   let description = [{
     Computes the scaled dot product attention function:
 
-    attention(Q, K, V, scale) = softmax(Q @ K.T * scale) @ V
+    attention(Q, K, V, scale, mask) = softmax(mask(Q @ K.T * scale)) @ V
 
     Here Q, K, V are given tensors and scale is a scalar value specifying
     the scale to use.
 
+    `mask` is an optional boolean tensor that specifies which relations
+    in attn_weight that should be ignored. This is useful for
+    causal attention, padded attention, and some special SD use cases.
+
     For self-attention, all inputs and the result have the same shape BxNxd
     where B is the batch dimension, N is the sequence length and d is head
     dimension. Typically N >>> d. Usually, this operator also performs
@@ -539,10 +543,10 @@ def IREELinalgExt_AttentionOp : IREELinalgExt_Op<"attention",
     FlashAttention 2 and optionally results in the current `max` and `sum`
     statistics while processing the current tile.
 
-    If transpose_v is speciifed, the V tensor passed as input is assumed to
+    If transpose_v is specifed, the V tensor passed as input is assumed to
     be transposed:
 
-    attention(Q, K, V, scale) = softmax(Q @ K.T * scale) @ V.T
+    attention(Q, K, V, scale, mask) = softmax(mask(Q @ K.T * scale)) @ V.T
 
     TODO: We should be moving to using a indexing map like approach so we
     can generalize which tensor is transposed and which is not.
@@ -579,6 +583,11 @@ def IREELinalgExt_AttentionOp : IREELinalgExt_Op<"attention",
     Value getScale() {
       return getDpsInputOperand(3)->get();
     }
+    std::optional<Value> getMask() {
+      if (getNumDpsInputs() < 5)
+        return std::nullopt;
+      return getDpsInputOperand(4)->get();
+    }
     Value getOutput() {
       return getDpsInitOperand(0)->get();
     }
@@ -659,18 +668,26 @@ def IREELinalgExt_AttentionOp : IREELinalgExt_Op<"attention",
     AffineMap getValueMap() {
       return cast<AffineMap>(getIndexingMapsArray()[2]);
     }
-    AffineMap getOutputMap() {
+    std::optional<AffineMap> getMaskMap() {
+      if (getNumDpsInputs() < 5)
+        return std::nullopt;
       return cast<AffineMap>(getIndexingMapsArray()[3]);
     }
+    AffineMap getOutputMap() {
+      int64_t outputIndex = getNumDpsInputs() - 1;
+      return cast<AffineMap>(getIndexingMapsArray()[outputIndex]);
+    }
     std::optional<AffineMap> getMaxMap() {
       if (getNumResults() < 2)
         return std::nullopt;
-      return cast<AffineMap>(getIndexingMapsArray()[4]);
+      int64_t maxIndex = getNumDpsInputs();
+      return cast<AffineMap>(getIndexingMapsArray()[maxIndex]);
     }
     std::optional<AffineMap> getSumMap() {
       if (getNumResults() < 3)
         return std::nullopt;
-      return cast<AffineMap>(getIndexingMapsArray()[5]);
+      int64_t sumIndex = getNumDpsInputs() + 1;
+      return cast<AffineMap>(getIndexingMapsArray()[sumIndex]);
     }
 
     int64_t getIterationDomainRank() {
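
To see the boolean-mask semantics described above numerically, here is a small standalone sketch (illustrative only, not part of the patch): once the mask is lowered to "add a large negative bias at ignored positions", as the decomposition in this commit does, the softmax weight of every masked position collapses to roughly zero.

    #include <cmath>
    #include <cstdio>

    int main() {
      // One row of (Q @ K.T) * scale and its boolean mask (1 = attend, 0 = ignore).
      float logits[4] = {1.0f, 2.0f, 3.0f, 4.0f};
      bool mask[4] = {true, true, false, true};
      float biased[4], sum = 0.0f;
      for (int i = 0; i < 4; ++i) {
        // Masked positions get a large negative bias, so exp() underflows to ~0.
        biased[i] = logits[i] + (mask[i] ? 0.0f : -1e6f);
        sum += std::exp(biased[i]);
      }
      for (int i = 0; i < 4; ++i)
        std::printf("softmax[%d] = %f\n", i, std::exp(biased[i]) / sum);
      return 0;
    }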

compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeAttention.cpp

Lines changed: 48 additions & 3 deletions
@@ -177,9 +177,49 @@ static Value truncateToF16(Value input, Value output,
   return genericOp.getResult(0);
 }
 
+static Value applyMasking(Value qkSlice, Value mask, OpBuilder &builder) {
+  ShapedType qkType = cast<ShapedType>(qkSlice.getType());
+  Location loc = qkSlice.getLoc();
+
+  // Create a fill op for scale.
+  SmallVector<OpFoldResult> qkDims =
+      tensor::getMixedSizes(builder, loc, qkSlice);
+
+  // Attention_mask is 1.0 for positions we want to attend and 0.0 for
+  // masked positions. this operation will create a tensor which is 0.0 for
+  // positions we want to attend and -10000.0 for masked positions
+  Value c0 = builder.create<arith::ConstantOp>(
+      loc, builder.getZeroAttr(qkType.getElementType()));
+
+  Value cLargeNeg = builder.create<arith::ConstantOp>(
+      loc, builder.getFloatAttr(qkType.getElementType(), -1e6));
+
+  Value empty =
+      builder.create<tensor::EmptyOp>(loc, qkDims, qkType.getElementType());
+  // Create a generic op to multiply the query by the scale.
+  SmallVector<utils::IteratorType> iteratorTypes(2,
+                                                 utils::IteratorType::parallel);
+  auto identityMap = AffineMap::getMultiDimIdentityMap(2, builder.getContext());
+  SmallVector<AffineMap> indexingMaps(3, identityMap);
+  auto applyMaskOp = builder.create<linalg::GenericOp>(
+      loc, TypeRange{empty.getType()}, ValueRange{qkSlice, mask},
+      ValueRange{empty}, indexingMaps, iteratorTypes,
+      [&](OpBuilder &b, Location loc, ValueRange args) {
+        // TODO: Add support to enable i1 throughout.
+        // currently type propagation force it to i1.
+        Value trunci =
+            b.create<arith::TruncIOp>(loc, builder.getI1Type(), args[1]);
+        Value masking = b.create<arith::SelectOp>(loc, trunci, c0, cLargeNeg);
+        Value result = b.create<arith::AddFOp>(loc, args[0], masking);
+        b.create<linalg::YieldOp>(loc, result);
+      });
+  return applyMaskOp.getResult(0);
+}
+
 static std::tuple<Value, Value, Value>
 createAttentionBody(Value keySlice, Value valueSlice, Value querySlice,
-                    Value outputSlice, Value maxSlice, Value sumSlice,
+                    std::optional<Value> maskSlice, Value outputSlice,
+                    Value maxSlice, Value sumSlice,
                     OpFoldResult sequenceTileLength,
                     OpFoldResult keyValueTileLength, OpFoldResult headDimension,
                     Type elementType, SmallVectorImpl<Operation *> &ops,
@@ -195,6 +235,11 @@ createAttentionBody(Value keySlice, Value valueSlice, Value querySlice,
   Value qkTranspose = computeQKTranspose(querySlice, keySlice, emptySquare,
                                          zero, loc, builder, ops);
 
+  // Apply masking if mask is specified.
+  if (maskSlice.has_value()) {
+    qkTranspose = applyMasking(qkTranspose, maskSlice.value(), builder);
+  }
+
   // Compute current statistics
   Value newMax = computeRowwiseReduction<arith::MaximumFOp>(
       qkTranspose, maxSlice, loc, builder, ops);
@@ -327,9 +372,9 @@ void decomposeTiledAttention(IREE::LinalgExt::AttentionOp tiledAttnOp,
   // iteration of the loop.
   querySlice = scaleQuery(querySlice, scale, rewriter);
   ops.push_back(querySlice.getDefiningOp());
-
+  std::optional<Value> maybeMask = tiledAttnOp.getMask();
   auto [result, newMax, newSum] = createAttentionBody(
-      keySlice, valueSlice, querySlice, tiledResult, max, sum,
+      keySlice, valueSlice, querySlice, maybeMask, tiledResult, max, sum,
       sequenceTileLength, keyValueTileLength, headDimension, elementType, ops,
       tiledAttnOp.getTransposeV(), loc, rewriter);
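
A scalar restatement of the masking step introduced above, as a hypothetical helper (an illustrative sketch, not the actual lowering): for each element, applyMasking truncates the mask value to i1, selects between zero and the large negative constant, and adds the result to the corresponding QK^T logit before the row-wise max and softmax statistics are computed.

    // Per-element effect of the linalg.generic built in applyMasking.
    float applyMaskElement(float qkLogit, int maskValue) {
      bool keep = (maskValue & 1) != 0; // arith.trunci to i1
      float bias = keep ? 0.0f : -1e6f; // arith.select
      return qkLogit + bias;            // arith.addf
    }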

compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp

Lines changed: 25 additions & 7 deletions
@@ -79,7 +79,8 @@ static scf::LoopNest createLoopNest(SmallVectorImpl<Value> &ivs, Value lb,
 }
 
 static Value extractSlice(Value key, ArrayRef<int64_t> keyShape,
-                          ArrayRef<Value> ivs, OpFoldResult keyValueTileLength,
+                          ArrayRef<std::optional<Value>> ivs,
+                          OpFoldResult keyValueTileLength,
                           OpFoldResult headDimension, Type elementType,
                           Location loc, OpBuilder &builder,
                           bool swapLastTwoDims = false) {
@@ -90,8 +91,10 @@ static Value extractSlice(Value key, ArrayRef<int64_t> keyShape,
   SmallVector<OpFoldResult> offsets(keyShape.size(), zero);
   sizes[1] = keyValueTileLength;
   sizes[2] = headDimension;
-  if (!ivs.empty()) {
-    offsets[1] = ivs[0];
+  for (auto [idx, iv] : llvm::enumerate(ivs)) {
+    if (!iv.has_value())
+      continue;
+    offsets[idx] = iv.value();
   }
   SmallVector<int64_t> tensorShape{keyShape[1], keyShape[2]};
   if (swapLastTwoDims) {
@@ -181,6 +184,7 @@ IREE::LinalgExt::AttentionOp tileAttention(IREE::LinalgExt::AttentionOp attnOp,
   Location loc = attnOp.getLoc();
   OpBuilder::InsertionGuard guard(rewriter);
   rewriter.setInsertionPoint(attnOp);
+  llvm::outs() << "high level implement tile Attention!\n";
 
   Value query = attnOp.getQuery();
   ShapedType queryType = attnOp.getQueryType();
@@ -253,21 +257,35 @@ IREE::LinalgExt::AttentionOp tileAttention(IREE::LinalgExt::AttentionOp attnOp,
   rewriter.setInsertionPointToStart(loopNest.loops.back().getBody());
 
   // Extract slices
-  Value keySlice = extractSlice(key, keyShape, ivs, keyValueTileLength,
+  SmallVector<std::optional<Value>> kvIvs(keyShape.size(), std::nullopt);
+  kvIvs[1] = ivs[0];
+  Value keySlice = extractSlice(key, keyShape, kvIvs, keyValueTileLength,
                                 headDimension, elementType, loc, rewriter);
   Value valueSlice =
-      extractSlice(value, keyShape, ivs, keyValueTileLength, headDimension,
+      extractSlice(value, keyShape, kvIvs, keyValueTileLength, headDimension,
                    elementType, loc, rewriter, attnOp.getTransposeV());
   Value querySlice = extractSlice(query, queryShape, {}, sequenceTileLength,
                                   headDimension, elementType, loc, rewriter);
 
   Value scale = attnOp.getScale();
+  SmallVector<Value> tiledInputs = {querySlice, keySlice, valueSlice, scale};
+
+  if (attnOp.getMask().has_value()) {
+    Value mask = attnOp.getMask().value();
+    auto maskElType = llvm::cast<ShapedType>(mask.getType()).getElementType();
+    SmallVector<std::optional<Value>> maskIvs(keyShape.size(), std::nullopt);
+    maskIvs[2] = ivs[0];
+    SmallVector<int64_t> maskShape{queryShape[0], queryShape[1], keyShape[1]};
+    Value maskSlice =
+        extractSlice(mask, maskShape, maskIvs, sequenceTileLength,
+                     keyValueTileLength, maskElType, loc, rewriter);
+    tiledInputs.push_back(maskSlice);
+  }
 
   auto tiledAttentionOp = rewriter.create<IREE::LinalgExt::AttentionOp>(
       attnOp.getLoc(),
       SmallVector<Type>{accumulatorF32.getType(), sum.getType(), max.getType()},
-      SmallVector<Value>{querySlice, keySlice, valueSlice, scale},
-      SmallVector<Value>{iterArgResult, iterArgMax, iterArgSum});
+      tiledInputs, SmallVector<Value>{iterArgResult, iterArgMax, iterArgSum});
 
   if (attnOp.getTransposeV())
     tiledAttentionOp.setTransposeVAttr(attnOp.getTransposeVAttr());
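
For orientation (not part of the diff): the untiled mask has shape B x M x K2 (queryShape[0] x queryShape[1] x keyShape[1]), and only its K2 dimension is advanced by the loop induction variable (maskIvs[2] = ivs[0]). Assuming the batch dimension is sliced to a single element, as for the other operands, each iteration extracts roughly

    offsets = [0, 0, iv]                                   // only K2 advances
    sizes   = [1, sequenceTileLength, keyValueTileLength]

so the mask tile covers the whole query tile in M and the current key/value tile in K2.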

compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TilingInterfaceImpl.cpp

Lines changed: 12 additions & 3 deletions
@@ -1716,6 +1716,14 @@ AttentionOp::getTiledImplementation(OpBuilder &builder,
   tiledOperands.emplace_back(getSlice(builder, loc, getKey(), keySlice));
   tiledOperands.emplace_back(getSlice(builder, loc, getValue(), valueSlice));
   tiledOperands.emplace_back(scale);
+
+  std::optional<Value> mask = getMask();
+  if (mask) {
+    SmallVector<Range> maskSlice =
+        getPermutedSlice(*getMaskMap(), offsets, sizes);
+    tiledOperands.emplace_back(getSlice(builder, loc, mask.value(), maskSlice));
+  }
+
   tiledOperands.emplace_back(getSlice(builder, loc, getOutput(), outputSlice));
 
   std::optional<Value> max = getMax();
@@ -1733,13 +1741,14 @@ AttentionOp::getTiledImplementation(OpBuilder &builder,
   }
 
   SmallVector<Type> resultTypes;
+  int64_t resultIndex = getNumDpsInputs();
   if (hasPureTensorSemantics()) {
-    resultTypes.push_back(tiledOperands[4].getType());
+    resultTypes.push_back(tiledOperands[resultIndex].getType());
     if (max) {
-      resultTypes.push_back(tiledOperands[5].getType());
+      resultTypes.push_back(tiledOperands[resultIndex + 1].getType());
     }
     if (sum) {
-      resultTypes.push_back(tiledOperands[6].getType());
+      resultTypes.push_back(tiledOperands[resultIndex + 2].getType());
    }
  }

compiler/src/iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.cpp

Lines changed: 1 addition & 3 deletions
@@ -31,7 +31,6 @@ void AttentionOpDetail::inferFromIndexingMaps(
   AffineMap qMap = indexingMaps[0];
   AffineMap kMap = indexingMaps[1];
   AffineMap vMap = indexingMaps[2];
-  AffineMap resMap = indexingMaps[3];
 
   // Q = B x M x K1
   // K = B x K2 x K1
@@ -40,7 +39,6 @@ void AttentionOpDetail::inferFromIndexingMaps(
   llvm::SmallDenseSet<int64_t> qSet = findPermutationsIndexingOperand(qMap);
   llvm::SmallDenseSet<int64_t> vSet = findPermutationsIndexingOperand(vMap);
   llvm::SmallDenseSet<int64_t> kSet = findPermutationsIndexingOperand(kMap);
-  llvm::SmallDenseSet<int64_t> resSet = findPermutationsIndexingOperand(resMap);
 
   // B = Q & K & V
   llvm::SmallDenseSet<int64_t> bSet = qSet;
@@ -76,7 +74,7 @@
 
 FailureOr<AttentionOpDetail>
 AttentionOpDetail::get(ArrayRef<AffineMap> indexingMaps) {
-  if (indexingMaps.size() != 4 && indexingMaps.size() != 6) {
+  if (indexingMaps.size() < 4 || indexingMaps.size() > 7) {
     return failure();
   }

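A note on the index bookkeeping in the last two files (a summary of the diff, not an addition to it): with the optional mask appended, tiledOperands in getTiledImplementation is laid out as

    [query, key, value, scale, mask?, output, max?, sum?]

so the first result type lives at index getNumDpsInputs(), which is 5 with a mask and 4 without; that is what resultIndex captures in place of the hard-coded 4/5/6. Likewise the indexing-map list becomes [Q, K, V, mask?, result, max?, sum?], i.e. 4 or 5 maps for the untiled op and 6 or 7 once max and sum are present, which is why AttentionOpDetail::get now accepts any count from 4 through 7.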