
Add masking and re-enable K2 padding
Signed-off-by: stanley-nod <[email protected]>
raikonenfnu authored and monorimet committed Jun 19, 2024
1 parent f38075f commit d39479a
Showing 2 changed files with 41 additions and 16 deletions.
56 changes: 41 additions & 15 deletions compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Padding.cpp
@@ -118,15 +118,6 @@ padAttention(IREE::LinalgExt::AttentionOp attnOp,
int64_t k2Idx = opInfo.getK2Dims().back();
int64_t nIdx = opInfo.getNDims().back();

// Padding in the K2 dimension requires filling those K2 entries with -Inf
// during softmax(Q.KT); preemptively padding K with -Inf may cause NaNs
// during the matmul of Q.KT.
if (padToMultipleOf[k2Idx] > 1) {
return definiteFailureHelper(transformOp, attnOp,
"Padding in K2-dim is currently unsupported "
"until attn_mask lowering is implemented.");
}

SmallVector<OpFoldResult> padValues(domainRank, rewriter.getIndexAttr(0));
for (auto [idx, bound] : enumerate(bounds)) {
if (padToMultipleOf[idx] != 0) {
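
The hunk above removes the bail-out whose comment notes that pre-padding K with -Inf can produce NaNs during Q.KT. That is an IEEE 754 effect: 0 * -Inf is NaN, so any zero element in (zero-padded) Q poisons the dot product. A minimal standalone C++ sketch, not part of this commit, illustrating the behavior:

#include <cmath>
#include <cstdio>
#include <limits>

int main() {
  // A zero element of (zero-padded) Q times a -Inf element of K is NaN
  // under IEEE 754, which is why K cannot simply be pre-padded with -Inf.
  float q = 0.0f;
  float k = -std::numeric_limits<float>::infinity();
  float prod = q * k;
  std::printf("0 * -inf = %f, isnan = %d\n", prod, std::isnan(prod) ? 1 : 0);
  return 0;
}
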
@@ -152,7 +143,16 @@ padAttention(IREE::LinalgExt::AttentionOp attnOp,
{padValues[batchIdx], padValues[mIdx], padValues[k1Idx]});
}

// Pad K-tensor if any non-K1 dims need padding.
// Pad the K2 dim of the K-tensor; the padded K2 columns are masked out below
// so that softmax still generates the correct numerics.
if (!isConstantIntValue(padValues[k2Idx], 0)) {
Type keyElType = attnOp.getKeyType().getElementType();
auto largeNeg = rewriter.getFloatAttr(keyElType, 0.0);
paddedKey = getPaddedValue(rewriter, loc, paddedKey,
{zero, padValues[k2Idx], zero}, largeNeg);
}

// Pad K-tensor if any non-K2 dims need padding.
if (!isConstantIntValue(padValues[batchIdx], 0) ||
!isConstantIntValue(padValues[k1Idx], 0)) {
paddedKey = getPaddedValue(rewriter, loc, paddedKey,
Expand All @@ -161,9 +161,11 @@ padAttention(IREE::LinalgExt::AttentionOp attnOp,

// Pad V-tensor if any of its dims need padding.
if (!isConstantIntValue(padValues[batchIdx], 0) ||
!isConstantIntValue(padValues[k2Idx], 0) ||
!isConstantIntValue(padValues[nIdx], 0)) {
paddedValue = getPaddedValue(rewriter, loc, paddedValue,
{padValues[batchIdx], zero, padValues[nIdx]});
paddedValue = getPaddedValue(
rewriter, loc, paddedValue,
{padValues[batchIdx], padValues[k2Idx], padValues[nIdx]});
}

// Pad Acc-tensor if any of its dims need padding.
@@ -186,11 +188,35 @@ padAttention(IREE::LinalgExt::AttentionOp attnOp,
}
}

std::optional<Value> mask = attnOp.getMask();
if (!isConstantIntValue(padValues[k2Idx], 0)) {
if (!mask.has_value()) {
SmallVector<OpFoldResult> mixedMaskShape = {
bounds[batchIdx].size, bounds[mIdx].size, bounds[k2Idx].size};
SmallVector<int64_t> staticMaskShape =
llvm::map_to_vector(mixedMaskShape, [](OpFoldResult dim) {
std::optional<int64_t> dimVal = getConstantIntValue(dim);
return dimVal.value_or(ShapedType::kDynamic);
});
auto maskElType = rewriter.getI1Type();
auto maskType = RankedTensorType::get(staticMaskShape, maskElType);
auto oneBoolAttr = rewriter.getOneAttr(maskElType);
mask = rewriter.createOrFold<arith::ConstantOp>(
loc, SplatElementsAttr::get(llvm::cast<ShapedType>(maskType),
oneBoolAttr));
}
mask = getPaddedValue(rewriter, loc, mask.value(),
{zero, zero, padValues[k2Idx]});
}

SmallVector<Value> paddedInputs = {paddedQuery, paddedKey, paddedValue,
scale};
if (mask.has_value()) {
paddedInputs.push_back(mask.value());
}
// Generate padded attention op.
auto paddedAttnOp = rewriter.create<IREE::LinalgExt::AttentionOp>(
loc, paddedAcc.getType(),
SmallVector<Value>{paddedQuery, paddedKey, paddedValue, scale},
paddedAcc);
loc, paddedAcc.getType(), paddedInputs, paddedAcc);

ops.push_back(paddedAttnOp);

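
A minimal standalone sketch, not part of this commit, of why the masking approach above keeps the numerics correct: assuming a false mask entry is lowered to a -Inf score before softmax, exp(-Inf) is 0, so the padded K2 columns receive zero attention weight and the result matches the unpadded case.

#include <cmath>
#include <cstdio>
#include <limits>
#include <vector>

int main() {
  // One row of Q.KT scores; the last entry corresponds to a padded K2 column.
  std::vector<float> scores = {0.5f, 1.0f, -0.25f, 0.0f};
  std::vector<bool> mask = {true, true, true, false};

  std::vector<float> weights(scores.size());
  float denom = 0.0f;
  for (size_t i = 0; i < scores.size(); ++i) {
    // Masked-out scores become -Inf, so exp(-Inf) == 0 and they drop out.
    float s = mask[i] ? scores[i]
                      : -std::numeric_limits<float>::infinity();
    weights[i] = std::exp(s);
    denom += weights[i];
  }
  for (size_t i = 0; i < weights.size(); ++i)
    std::printf("weight[%zu] = %f\n", i, weights[i] / denom);
  return 0;
}
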
@@ -184,7 +184,6 @@ IREE::LinalgExt::AttentionOp tileAttention(IREE::LinalgExt::AttentionOp attnOp,
Location loc = attnOp.getLoc();
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(attnOp);
llvm::outs() << "high level implement tile Attention!\n";

Value query = attnOp.getQuery();
ShapedType queryType = attnOp.getQueryType();
