
Commit d39479a

raikonenfnu authored and monorimet committed
Add masking and re-enable K2 padding
Signed-off-by: stanley-nod <[email protected]>
1 parent f38075f · commit d39479a

File tree

2 files changed: +41 -16 lines changed


compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Padding.cpp

Lines changed: 41 additions & 15 deletions
@@ -118,15 +118,6 @@ padAttention(IREE::LinalgExt::AttentionOp attnOp,
   int64_t k2Idx = opInfo.getK2Dims().back();
   int64_t nIdx = opInfo.getNDims().back();
 
-  // Padding in K2 dimension requires to fill those K2 dimensions as -Inf during
-  // softmax(Q.KT), preemptively padding it with -Inf may cause NaNs during
-  // matmul of Q.KT.
-  if (padToMultipleOf[k2Idx] > 1) {
-    return definiteFailureHelper(transformOp, attnOp,
-                                 "Padding in K2-dim is currently unsupported "
-                                 "until attn_mask lowering is implemented.");
-  }
-
   SmallVector<OpFoldResult> padValues(domainRank, rewriter.getIndexAttr(0));
   for (auto [idx, bound] : enumerate(bounds)) {
     if (padToMultipleOf[idx] != 0) {
@@ -152,7 +143,16 @@ padAttention(IREE::LinalgExt::AttentionOp attnOp,
         {padValues[batchIdx], padValues[mIdx], padValues[k1Idx]});
   }
 
-  // Pad K-tensor if any non-K1 dims needs padding.
+  // Pad K2-dim of K-tensor by a large negative S.T when used by softmax it will
+  // generate the correct numerics.
+  if (!isConstantIntValue(padValues[k2Idx], 0)) {
+    Type keyElType = attnOp.getKeyType().getElementType();
+    auto largeNeg = rewriter.getFloatAttr(keyElType, 0.0);
+    paddedKey = getPaddedValue(rewriter, loc, paddedKey,
+                               {zero, padValues[k2Idx], zero}, largeNeg);
+  }
+
+  // Pad K-tensor if any non-K2 dims needs padding.
   if (!isConstantIntValue(padValues[batchIdx], 0) ||
       !isConstantIntValue(padValues[k1Idx], 0)) {
     paddedKey = getPaddedValue(rewriter, loc, paddedKey,
@@ -161,9 +161,11 @@ padAttention(IREE::LinalgExt::AttentionOp attnOp,
 
   // Pad V-tensor if any of its' dims needs padding.
   if (!isConstantIntValue(padValues[batchIdx], 0) ||
+      !isConstantIntValue(padValues[k2Idx], 0) ||
       !isConstantIntValue(padValues[nIdx], 0)) {
-    paddedValue = getPaddedValue(rewriter, loc, paddedValue,
-                                 {padValues[batchIdx], zero, padValues[nIdx]});
+    paddedValue = getPaddedValue(
+        rewriter, loc, paddedValue,
+        {padValues[batchIdx], padValues[k2Idx], padValues[nIdx]});
   }
 
   // Pad Acc-tensor if any of its' dims needs padding.
@@ -186,11 +188,35 @@ padAttention(IREE::LinalgExt::AttentionOp attnOp,
     }
   }
 
+  std::optional<Value> mask = attnOp.getMask();
+  if (!isConstantIntValue(padValues[k2Idx], 0)) {
+    if (!mask.has_value()) {
+      SmallVector<OpFoldResult> mixedMaskShape = {
+          bounds[batchIdx].size, bounds[mIdx].size, bounds[k2Idx].size};
+      SmallVector<int64_t> staticMaskShape =
+          llvm::map_to_vector(mixedMaskShape, [](OpFoldResult dim) {
+            std::optional<int64_t> dimVal = getConstantIntValue(dim);
+            return dimVal.value_or(ShapedType::kDynamic);
+          });
+      auto maskElType = rewriter.getI1Type();
+      auto maskType = RankedTensorType::get(staticMaskShape, maskElType);
+      auto oneBoolAttr = rewriter.getOneAttr(maskElType);
+      mask = rewriter.createOrFold<arith::ConstantOp>(
+          loc, SplatElementsAttr::get(llvm::cast<ShapedType>(maskType),
+                                      oneBoolAttr));
+    }
+    mask = getPaddedValue(rewriter, loc, mask.value(),
+                          {zero, zero, padValues[k2Idx]});
+  }
+
+  SmallVector<Value> paddedInputs = {paddedQuery, paddedKey, paddedValue,
+                                     scale};
+  if (mask.has_value()) {
+    paddedInputs.push_back(mask.value());
+  }
   // Generate padded attention op.
   auto paddedAttnOp = rewriter.create<IREE::LinalgExt::AttentionOp>(
-      loc, paddedAcc.getType(),
-      SmallVector<Value>{paddedQuery, paddedKey, paddedValue, scale},
-      paddedAcc);
+      loc, paddedAcc.getType(), paddedInputs, paddedAcc);
 
   ops.push_back(paddedAttnOp);
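
Why the added mask matters: softmax normalizes over the K2 dimension of the Q.K^T scores, so any padded K2 columns left unmasked still receive weight exp(score) > 0 and skew the attention distribution (the deleted comment above gave this as the reason K2 padding was previously rejected). Below is a minimal standalone C++ sketch of that effect; it assumes nothing from IREE, and maskedSoftmax is a hypothetical helper written only for this illustration.

// Standalone sketch, not IREE code: contrast softmax over a zero-padded
// score row with softmax over the same row where the padded lane is masked
// (treated as -inf, so it contributes exp(-inf) = 0 to the normalization).
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <limits>
#include <vector>

static std::vector<double> maskedSoftmax(const std::vector<double> &scores,
                                         const std::vector<bool> &mask) {
  // Max over unmasked lanes only, for numerical stability.
  double maxVal = -std::numeric_limits<double>::infinity();
  for (std::size_t i = 0; i < scores.size(); ++i)
    if (mask[i])
      maxVal = std::max(maxVal, scores[i]);
  std::vector<double> out(scores.size(), 0.0);
  double sum = 0.0;
  for (std::size_t i = 0; i < scores.size(); ++i) {
    if (!mask[i])
      continue; // masked (padded) lane: its weight stays exactly 0
    out[i] = std::exp(scores[i] - maxVal);
    sum += out[i];
  }
  for (double &v : out)
    v /= sum;
  return out;
}

int main() {
  // Original K2 extent is 3; padding to 4 appends one fake score of 0.0.
  std::vector<double> padded = {1.0, 2.0, 3.0, /*pad*/ 0.0};
  // Unmasked: the padded lane gets exp(0 - max) > 0 and steals weight.
  std::vector<double> unmasked =
      maskedSoftmax(padded, {true, true, true, true});
  // Masked: identical to softmax over the original three scores.
  std::vector<double> masked =
      maskedSoftmax(padded, {true, true, true, false});
  for (std::size_t i = 0; i < padded.size(); ++i)
    std::printf("lane %zu: unmasked %.4f  masked %.4f\n", i, unmasked[i],
                masked[i]);
  return 0;
}

With the mask handling the numerics, the key itself can be padded with a neutral value (note that the largeNeg attribute above is actually 0.0), and the mask, padded along K2 (presumably with false, the zero of i1), drops those lanes from the softmax.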
compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp

Lines changed: 0 additions & 1 deletion
@@ -184,7 +184,6 @@ IREE::LinalgExt::AttentionOp tileAttention(IREE::LinalgExt::AttentionOp attnOp,
   Location loc = attnOp.getLoc();
   OpBuilder::InsertionGuard guard(rewriter);
   rewriter.setInsertionPoint(attnOp);
-  llvm::outs() << "high level implement tile Attention!\n";
 
   Value query = attnOp.getQuery();
   ShapedType queryType = attnOp.getQueryType();
