Add mask to attention op tiling and decomposition
Signed-off-by: stanley-nod <[email protected]>
raikonenfnu authored and monorimet committed Jun 19, 2024
1 parent 117df4f commit f38075f
Showing 6 changed files with 123 additions and 28 deletions.
20 changes: 14 additions & 6 deletions compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -1329,9 +1329,9 @@ LogicalResult AttentionOp::verify() {
int numInputs = getNumDpsInputs();
int numOutputs = getNumDpsInits();

-  if (numInputs != 4) {
+  if (numInputs < 4 || numInputs > 5) {
return op->emitOpError(
"expected 4 input operands: Query, Key, Value and Scale");
"expected 4 or 5 input operands: Query, Key, Value, Scale, and Mask");
}

if (numOutputs != 1 && numOutputs != 3) {
@@ -1340,8 +1340,9 @@ LogicalResult AttentionOp::verify() {
}

bool isTiled = numOutputs == 3;

-  if (!llvm::all_of(llvm::drop_end(getDpsInputs()), [](Value input) {
+  SmallVector<Value> dpsInputs = getDpsInputs();
+  ArrayRef<Value> qkvValues(dpsInputs.begin(), dpsInputs.begin() + 3);
+  if (!llvm::all_of(qkvValues, [](Value input) {
return isa<ShapedType>(input.getType());
})) {
return op->emitOpError(
@@ -1462,10 +1463,17 @@ SmallVector<AffineMap> AttentionOp::getIndexingMapsArray() {
AffineMap::get(/*dimCount=*/5, /*symbolCount=*/0, {batch, k2, n}, ctx);
}

+  SmallVector<AffineMap> results = {qMap, kMap, vMap};
+
+  if (getMask()) {
+    AffineMap maskMap =
+        AffineMap::get(/*dimCount=*/5, /*symbolCount=*/0, {batch, m, k2}, ctx);
+    results.push_back(maskMap);
+  }

AffineMap resMap =
AffineMap::get(/*dimCount=*/5, /*symbolCount=*/0, {batch, m, n}, ctx);

-  SmallVector<AffineMap> results = {qMap, kMap, vMap, resMap};
+  results.push_back(resMap);

if (getMax()) {
AffineMap maxMap =
29 changes: 23 additions & 6 deletions compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
@@ -522,11 +522,15 @@ def IREELinalgExt_AttentionOp : IREELinalgExt_Op<"attention",
let description = [{
Computes the scaled dot product attention function:

-    attention(Q, K, V, scale) = softmax(Q @ K.T * scale) @ V
+    attention(Q, K, V, scale, mask) = softmax(mask(Q @ K.T * scale)) @ V

Here Q, K, V are given tensors and scale is a scalar value specifying
the scale to use.

+    `mask` is an optional boolean tensor that specifies which entries of
+    the attention weights should be ignored. This is useful for causal
+    attention, padded attention, and some specialized Stable Diffusion
+    use cases.

For self-attention, all inputs and the result have the same shape BxNxd
where B is the batch dimension, N is the sequence length and d is head
dimension. Typically N >>> d. Usually, this operator also performs
Expand All @@ -539,10 +543,10 @@ def IREELinalgExt_AttentionOp : IREELinalgExt_Op<"attention",
    FlashAttention 2 and optionally returns the current `max` and `sum`
statistics while processing the current tile.

-    If transpose_v is speciifed, the V tensor passed as input is assumed to
+    If transpose_v is specified, the V tensor passed as input is assumed to
be transposed:

-    attention(Q, K, V, scale) = softmax(Q @ K.T * scale) @ V.T
+    attention(Q, K, V, scale, mask) = softmax(mask(Q @ K.T * scale)) @ V.T

    TODO: We should move to an indexing-map-like approach so we can
    generalize which tensor is transposed and which is not.
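As a sanity check on the additive-bias semantics of the mask described above, here is a tiny standalone C++ program (illustrative only, not part of this change; the -1e6 bias constant matches the decomposition added in this commit). Adding a large negative bias to a masked logit drives its softmax weight to essentially zero, so that position contributes nothing to the weighted sum over V:

#include <cmath>
#include <cstdio>
#include <vector>

// Toy demonstration: bias masked logits by -1e6, then take softmax.
int main() {
  std::vector<double> logits = {2.0, 1.0, 0.5};
  std::vector<bool> mask = {true, false, true}; // false: mask out position 1
  const double kLargeNeg = -1e6;

  std::vector<double> weights(logits.size());
  double denom = 0.0;
  for (size_t i = 0; i < logits.size(); ++i) {
    weights[i] = std::exp(logits[i] + (mask[i] ? 0.0 : kLargeNeg));
    denom += weights[i];
  }
  for (size_t i = 0; i < weights.size(); ++i)
    std::printf("softmax[%zu] = %g\n", i, weights[i] / denom);
  // Position 1 prints as 0: the masked entry is effectively ignored.
  return 0;
}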
@@ -579,6 +583,11 @@ def IREELinalgExt_AttentionOp : IREELinalgExt_Op<"attention",
Value getScale() {
return getDpsInputOperand(3)->get();
}
+    std::optional<Value> getMask() {
+      if (getNumDpsInputs() < 5)
+        return std::nullopt;
+      return getDpsInputOperand(4)->get();
+    }
Value getOutput() {
return getDpsInitOperand(0)->get();
}
@@ -659,18 +668,26 @@ def IREELinalgExt_AttentionOp : IREELinalgExt_Op<"attention",
AffineMap getValueMap() {
return cast<AffineMap>(getIndexingMapsArray()[2]);
}
-    AffineMap getOutputMap() {
+    std::optional<AffineMap> getMaskMap() {
+      if (getNumDpsInputs() < 5)
+        return std::nullopt;
       return cast<AffineMap>(getIndexingMapsArray()[3]);
     }
+    AffineMap getOutputMap() {
+      int64_t outputIndex = getNumDpsInputs() - 1;
+      return cast<AffineMap>(getIndexingMapsArray()[outputIndex]);
+    }
std::optional<AffineMap> getMaxMap() {
if (getNumResults() < 2)
return std::nullopt;
-      return cast<AffineMap>(getIndexingMapsArray()[4]);
+      int64_t maxIndex = getNumDpsInputs();
+      return cast<AffineMap>(getIndexingMapsArray()[maxIndex]);
}
std::optional<AffineMap> getSumMap() {
if (getNumResults() < 3)
return std::nullopt;
-      return cast<AffineMap>(getIndexingMapsArray()[5]);
+      int64_t sumIndex = getNumDpsInputs() + 1;
+      return cast<AffineMap>(getIndexingMapsArray()[sumIndex]);
}

int64_t getIterationDomainRank() {
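The accessor arithmetic above assumes the indexing-map list is ordered [Q, K, V, (mask), result, (max), (sum)], with the scalar scale carrying no map; that is why every index is derived from getNumDpsInputs() instead of being hard-coded. A minimal standalone sketch of the same arithmetic (hypothetical helper, not IREE code):

#include <cassert>
#include <cstdio>

struct MapIndices {
  int mask;   // -1 when the op has no mask operand
  int result;
  int max;    // -1 for the untiled (single-result) form
  int sum;
};

MapIndices computeMapIndices(int numDpsInputs, int numResults) {
  assert(numDpsInputs == 4 || numDpsInputs == 5);
  MapIndices idx;
  idx.mask = numDpsInputs < 5 ? -1 : 3;
  idx.result = numDpsInputs - 1; // 3 without a mask, 4 with one
  idx.max = numResults < 2 ? -1 : numDpsInputs;
  idx.sum = numResults < 3 ? -1 : numDpsInputs + 1;
  return idx;
}

int main() {
  MapIndices tiledMasked = computeMapIndices(/*numDpsInputs=*/5, /*numResults=*/3);
  // Prints: mask=3 result=4 max=5 sum=6
  std::printf("mask=%d result=%d max=%d sum=%d\n", tiledMasked.mask,
              tiledMasked.result, tiledMasked.max, tiledMasked.sum);
  return 0;
}

This also explains the relaxed map-count bounds later in AttentionOpDetail::get: 4 or 5 maps for the untiled form and 6 or 7 for the tiled form.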
@@ -177,9 +177,49 @@ static Value truncateToF16(Value input, Value output,
return genericOp.getResult(0);
}

+static Value applyMasking(Value qkSlice, Value mask, OpBuilder &builder) {
+  ShapedType qkType = cast<ShapedType>(qkSlice.getType());
+  Location loc = qkSlice.getLoc();
+
+  // Compute the mixed static/dynamic sizes of the qk slice.
+  SmallVector<OpFoldResult> qkDims =
+      tensor::getMixedSizes(builder, loc, qkSlice);
+
+  // The attention mask is 1 for positions to attend to and 0 for masked
+  // positions. Build an additive bias that is 0.0 where the mask is set and
+  // a large negative value (-1e6) where it is not, so masked logits vanish
+  // after the softmax.
+  Value c0 = builder.create<arith::ConstantOp>(
+      loc, builder.getZeroAttr(qkType.getElementType()));
+  Value cLargeNeg = builder.create<arith::ConstantOp>(
+      loc, builder.getFloatAttr(qkType.getElementType(), -1e6));
+
+  Value empty =
+      builder.create<tensor::EmptyOp>(loc, qkDims, qkType.getElementType());
+
+  // Create a generic op that adds the mask bias to the qk slice.
+  SmallVector<utils::IteratorType> iteratorTypes(2,
+                                                 utils::IteratorType::parallel);
+  auto identityMap = AffineMap::getMultiDimIdentityMap(2, builder.getContext());
+  SmallVector<AffineMap> indexingMaps(3, identityMap);
+  auto applyMaskOp = builder.create<linalg::GenericOp>(
+      loc, TypeRange{empty.getType()}, ValueRange{qkSlice, mask},
+      ValueRange{empty}, indexingMaps, iteratorTypes,
+      [&](OpBuilder &b, Location loc, ValueRange args) {
+        // TODO: Support i1 masks throughout; type propagation currently
+        // widens the mask element type, so truncate it back to i1 here.
+        Value trunci =
+            b.create<arith::TruncIOp>(loc, builder.getI1Type(), args[1]);
+        Value masking = b.create<arith::SelectOp>(loc, trunci, c0, cLargeNeg);
+        Value result = b.create<arith::AddFOp>(loc, args[0], masking);
+        b.create<linalg::YieldOp>(loc, result);
+      });
+  return applyMaskOp.getResult(0);
+}

static std::tuple<Value, Value, Value>
createAttentionBody(Value keySlice, Value valueSlice, Value querySlice,
-                    Value outputSlice, Value maxSlice, Value sumSlice,
+                    std::optional<Value> maskSlice, Value outputSlice,
+                    Value maxSlice, Value sumSlice,
OpFoldResult sequenceTileLength,
OpFoldResult keyValueTileLength, OpFoldResult headDimension,
Type elementType, SmallVectorImpl<Operation *> &ops,
@@ -195,6 +235,11 @@ createAttentionBody(Value keySlice, Value valueSlice, Value querySlice,
Value qkTranspose = computeQKTranspose(querySlice, keySlice, emptySquare,
zero, loc, builder, ops);

+  // Apply masking if a mask is specified.
+  if (maskSlice.has_value()) {
+    qkTranspose = applyMasking(qkTranspose, maskSlice.value(), builder);
+  }

// Compute current statistics
Value newMax = computeRowwiseReduction<arith::MaximumFOp>(
qkTranspose, maxSlice, loc, builder, ops);
@@ -327,9 +372,9 @@ void decomposeTiledAttention(IREE::LinalgExt::AttentionOp tiledAttnOp,
// iteration of the loop.
querySlice = scaleQuery(querySlice, scale, rewriter);
ops.push_back(querySlice.getDefiningOp());

+  std::optional<Value> maybeMask = tiledAttnOp.getMask();
  auto [result, newMax, newSum] = createAttentionBody(
-      keySlice, valueSlice, querySlice, tiledResult, max, sum,
+      keySlice, valueSlice, querySlice, maybeMask, tiledResult, max, sum,
sequenceTileLength, keyValueTileLength, headDimension, elementType, ops,
tiledAttnOp.getTransposeV(), loc, rewriter);

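The placement of applyMasking in createAttentionBody matters: the bias is added to the scaled Q @ K.T logits before the running max and sum statistics are computed, so the FlashAttention-style update never observes unmasked values of masked positions. A condensed single-row sketch of one K2-tile step (plain C++ with made-up values, mirroring that order; illustrative only):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  // One query row of Q @ K.T * scale for the current K2 tile.
  std::vector<double> qk = {1.2, 0.3, 2.1, -0.5};
  std::vector<int> mask = {1, 1, 0, 1}; // 0 means masked out

  // Step 1: applyMasking, biasing masked logits before any statistics.
  for (size_t i = 0; i < qk.size(); ++i)
    qk[i] += mask[i] ? 0.0 : -1e6;

  // Step 2: rowwise max (computeRowwiseReduction<arith::MaximumFOp>).
  double rowMax = *std::max_element(qk.begin(), qk.end());

  // Step 3: partial softmax statistics for this tile.
  double rowSum = 0.0;
  for (double v : qk)
    rowSum += std::exp(v - rowMax);

  std::printf("tile max = %g, tile sum = %g\n", rowMax, rowSum); // 1.2, ~1.59
  return 0;
}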
@@ -79,7 +79,8 @@ static scf::LoopNest createLoopNest(SmallVectorImpl<Value> &ivs, Value lb,
}

static Value extractSlice(Value key, ArrayRef<int64_t> keyShape,
-                          ArrayRef<Value> ivs, OpFoldResult keyValueTileLength,
+                          ArrayRef<std::optional<Value>> ivs,
+                          OpFoldResult keyValueTileLength,
OpFoldResult headDimension, Type elementType,
Location loc, OpBuilder &builder,
bool swapLastTwoDims = false) {
@@ -90,8 +91,10 @@ static Value extractSlice(Value key, ArrayRef<int64_t> keyShape,
SmallVector<OpFoldResult> offsets(keyShape.size(), zero);
sizes[1] = keyValueTileLength;
sizes[2] = headDimension;
-  if (!ivs.empty()) {
-    offsets[1] = ivs[0];
+  for (auto [idx, iv] : llvm::enumerate(ivs)) {
+    if (!iv.has_value())
+      continue;
+    offsets[idx] = iv.value();
}
SmallVector<int64_t> tensorShape{keyShape[1], keyShape[2]};
if (swapLastTwoDims) {
@@ -181,6 +184,7 @@ IREE::LinalgExt::AttentionOp tileAttention(IREE::LinalgExt::AttentionOp attnOp,
Location loc = attnOp.getLoc();
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(attnOp);

Value query = attnOp.getQuery();
ShapedType queryType = attnOp.getQueryType();
@@ -253,21 +257,35 @@ IREE::LinalgExt::AttentionOp tileAttention(IREE::LinalgExt::AttentionOp attnOp,
rewriter.setInsertionPointToStart(loopNest.loops.back().getBody());

// Extract slices
-  Value keySlice = extractSlice(key, keyShape, ivs, keyValueTileLength,
+  SmallVector<std::optional<Value>> kvIvs(keyShape.size(), std::nullopt);
+  kvIvs[1] = ivs[0];
+  Value keySlice = extractSlice(key, keyShape, kvIvs, keyValueTileLength,
headDimension, elementType, loc, rewriter);
Value valueSlice =
-      extractSlice(value, keyShape, ivs, keyValueTileLength, headDimension,
+      extractSlice(value, keyShape, kvIvs, keyValueTileLength, headDimension,
elementType, loc, rewriter, attnOp.getTransposeV());
Value querySlice = extractSlice(query, queryShape, {}, sequenceTileLength,
headDimension, elementType, loc, rewriter);

Value scale = attnOp.getScale();
+  SmallVector<Value> tiledInputs = {querySlice, keySlice, valueSlice, scale};

+  if (attnOp.getMask().has_value()) {
+    Value mask = attnOp.getMask().value();
+    auto maskElType = llvm::cast<ShapedType>(mask.getType()).getElementType();
+    SmallVector<std::optional<Value>> maskIvs(keyShape.size(), std::nullopt);
+    maskIvs[2] = ivs[0];
+    SmallVector<int64_t> maskShape{queryShape[0], queryShape[1], keyShape[1]};
+    Value maskSlice =
+        extractSlice(mask, maskShape, maskIvs, sequenceTileLength,
+                     keyValueTileLength, maskElType, loc, rewriter);
+    tiledInputs.push_back(maskSlice);
+  }

auto tiledAttentionOp = rewriter.create<IREE::LinalgExt::AttentionOp>(
attnOp.getLoc(),
SmallVector<Type>{accumulatorF32.getType(), sum.getType(), max.getType()},
-      SmallVector<Value>{querySlice, keySlice, valueSlice, scale},
-      SmallVector<Value>{iterArgResult, iterArgMax, iterArgSum});
+      tiledInputs, SmallVector<Value>{iterArgResult, iterArgMax, iterArgSum});

if (attnOp.getTransposeV())
tiledAttentionOp.setTransposeVAttr(attnOp.getTransposeVAttr());
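The slice bookkeeping above is the subtle part: K and V are laid out B x K2 x head_dim, so the K2 loop induction variable offsets dimension 1 (kvIvs[1]), while the mask is B x M x K2, so the same induction variable offsets dimension 2 (maskIvs[2]). A small sketch of the resulting slice offsets (assumed tile start of 128; illustrative, not the IREE API):

#include <array>
#include <cstdio>

int main() {
  const int iv = 128; // current K2 tile start (ivs[0] in the loop above)
  std::array<int, 3> kvOffsets = {0, iv, 0};   // K/V: B x K2 x head_dim
  std::array<int, 3> maskOffsets = {0, 0, iv}; // mask: B x M x K2
  std::printf("key/value slice offsets: [%d, %d, %d]\n", kvOffsets[0],
              kvOffsets[1], kvOffsets[2]);
  std::printf("mask slice offsets:      [%d, %d, %d]\n", maskOffsets[0],
              maskOffsets[1], maskOffsets[2]);
  return 0;
}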
@@ -1716,6 +1716,14 @@ AttentionOp::getTiledImplementation(OpBuilder &builder,
tiledOperands.emplace_back(getSlice(builder, loc, getKey(), keySlice));
tiledOperands.emplace_back(getSlice(builder, loc, getValue(), valueSlice));
tiledOperands.emplace_back(scale);

+  std::optional<Value> mask = getMask();
+  if (mask) {
+    SmallVector<Range> maskSlice =
+        getPermutedSlice(*getMaskMap(), offsets, sizes);
+    tiledOperands.emplace_back(getSlice(builder, loc, mask.value(), maskSlice));
+  }

tiledOperands.emplace_back(getSlice(builder, loc, getOutput(), outputSlice));

std::optional<Value> max = getMax();
@@ -1733,13 +1741,14 @@ AttentionOp::getTiledImplementation(OpBuilder &builder,
}

SmallVector<Type> resultTypes;
+  int64_t resultIndex = getNumDpsInputs();
  if (hasPureTensorSemantics()) {
-    resultTypes.push_back(tiledOperands[4].getType());
+    resultTypes.push_back(tiledOperands[resultIndex].getType());
    if (max) {
-      resultTypes.push_back(tiledOperands[5].getType());
+      resultTypes.push_back(tiledOperands[resultIndex + 1].getType());
    }
    if (sum) {
-      resultTypes.push_back(tiledOperands[6].getType());
+      resultTypes.push_back(tiledOperands[resultIndex + 2].getType());
}
}

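The switch from hard-coded indices 4/5/6 to resultIndex is needed because the optional mask shifts every operand after it: the tiled operand list is [Q, K, V, scale, (mask), output, (max), (sum)], so the output always sits at index getNumDpsInputs(). A tiny sketch of that arithmetic for both operand counts (illustrative only):

#include <cstdio>

int main() {
  for (int numInputs : {4, 5}) { // without / with a mask operand
    int resultIndex = numInputs;
    std::printf("numInputs=%d -> output at %d, max at %d, sum at %d\n",
                numInputs, resultIndex, resultIndex + 1, resultIndex + 2);
  }
  return 0;
}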
@@ -31,7 +31,6 @@ void AttentionOpDetail::inferFromIndexingMaps(
AffineMap qMap = indexingMaps[0];
AffineMap kMap = indexingMaps[1];
AffineMap vMap = indexingMaps[2];
-  AffineMap resMap = indexingMaps[3];

// Q = B x M x K1
// K = B x K2 x K1
@@ -40,7 +39,6 @@ void AttentionOpDetail::inferFromIndexingMaps(
llvm::SmallDenseSet<int64_t> qSet = findPermutationsIndexingOperand(qMap);
llvm::SmallDenseSet<int64_t> vSet = findPermutationsIndexingOperand(vMap);
llvm::SmallDenseSet<int64_t> kSet = findPermutationsIndexingOperand(kMap);
-  llvm::SmallDenseSet<int64_t> resSet = findPermutationsIndexingOperand(resMap);

// B = Q & K & V
llvm::SmallDenseSet<int64_t> bSet = qSet;
@@ -76,7 +74,7 @@

FailureOr<AttentionOpDetail>
AttentionOpDetail::get(ArrayRef<AffineMap> indexingMaps) {
-  if (indexingMaps.size() != 4 && indexingMaps.size() != 6) {
+  if (indexingMaps.size() < 4 || indexingMaps.size() > 7) {
return failure();
}
