diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Padding.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Padding.cpp
index d034b04ac8b8..86e340e5a563 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Padding.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Padding.cpp
@@ -118,15 +118,6 @@ padAttention(IREE::LinalgExt::AttentionOp attnOp,
   int64_t k2Idx = opInfo.getK2Dims().back();
   int64_t nIdx = opInfo.getNDims().back();
 
-  // Padding in K2 dimension requires to fill those K2 dimensions as -Inf during
-  // softmax(Q.KT), preemptively padding it with -Inf may cause NaNs during
-  // matmul of Q.KT.
-  if (padToMultipleOf[k2Idx] > 1) {
-    return definiteFailureHelper(transformOp, attnOp,
-                                 "Padding in K2-dim is currently unsupported "
-                                 "until attn_mask lowering is implemented.");
-  }
-
   SmallVector<OpFoldResult> padValues(domainRank, rewriter.getIndexAttr(0));
   for (auto [idx, bound] : enumerate(bounds)) {
     if (padToMultipleOf[idx] != 0) {
@@ -152,7 +143,16 @@ padAttention(IREE::LinalgExt::AttentionOp attnOp,
                        {padValues[batchIdx], padValues[mIdx], padValues[k1Idx]});
   }
 
-  // Pad K-tensor if any non-K1 dims needs padding.
+  // Pad K2-dim of K-tensor by a large negative S.T when used by softmax it will
+  // generate the correct numerics.
+  if (!isConstantIntValue(padValues[k2Idx], 0)) {
+    Type keyElType = attnOp.getKeyType().getElementType();
+    auto largeNeg = rewriter.getFloatAttr(keyElType, 0.0);
+    paddedKey = getPaddedValue(rewriter, loc, paddedKey,
+                               {zero, padValues[k2Idx], zero}, largeNeg);
+  }
+
+  // Pad K-tensor if any non-K2 dims needs padding.
   if (!isConstantIntValue(padValues[batchIdx], 0) ||
       !isConstantIntValue(padValues[k1Idx], 0)) {
     paddedKey = getPaddedValue(rewriter, loc, paddedKey,
@@ -161,9 +161,11 @@ padAttention(IREE::LinalgExt::AttentionOp attnOp,
 
   // Pad V-tensor if any of its' dims needs padding.
   if (!isConstantIntValue(padValues[batchIdx], 0) ||
+      !isConstantIntValue(padValues[k2Idx], 0) ||
       !isConstantIntValue(padValues[nIdx], 0)) {
-    paddedValue = getPaddedValue(rewriter, loc, paddedValue,
-                                 {padValues[batchIdx], zero, padValues[nIdx]});
+    paddedValue = getPaddedValue(
+        rewriter, loc, paddedValue,
+        {padValues[batchIdx], padValues[k2Idx], padValues[nIdx]});
   }
 
   // Pad Acc-tensor if any of its' dims needs padding.
@@ -186,11 +188,35 @@ padAttention(IREE::LinalgExt::AttentionOp attnOp,
     }
   }
 
+  std::optional<Value> mask = attnOp.getMask();
+  if (!isConstantIntValue(padValues[k2Idx], 0)) {
+    if (!mask.has_value()) {
+      SmallVector<OpFoldResult> mixedMaskShape = {
+          bounds[batchIdx].size, bounds[mIdx].size, bounds[k2Idx].size};
+      SmallVector<int64_t> staticMaskShape =
+          llvm::map_to_vector(mixedMaskShape, [](OpFoldResult dim) {
+            std::optional<int64_t> dimVal = getConstantIntValue(dim);
+            return dimVal.value_or(ShapedType::kDynamic);
+          });
+      auto maskElType = rewriter.getI1Type();
+      auto maskType = RankedTensorType::get(staticMaskShape, maskElType);
+      auto oneBoolAttr = rewriter.getOneAttr(maskElType);
+      mask = rewriter.createOrFold<arith::ConstantOp>(
+          loc, SplatElementsAttr::get(llvm::cast<ShapedType>(maskType),
+                                      oneBoolAttr));
+    }
+    mask = getPaddedValue(rewriter, loc, mask.value(),
+                          {zero, zero, padValues[k2Idx]});
+  }
+
+  SmallVector<Value> paddedInputs = {paddedQuery, paddedKey, paddedValue,
+                                     scale};
+  if (mask.has_value()) {
+    paddedInputs.push_back(mask.value());
+  }
   // Generate padded attention op.
   auto paddedAttnOp = rewriter.create<IREE::LinalgExt::AttentionOp>(
-      loc, paddedAcc.getType(),
-      SmallVector<Value>{paddedQuery, paddedKey, paddedValue, scale},
-      paddedAcc);
+      loc, paddedAcc.getType(), paddedInputs, paddedAcc);
 
   ops.push_back(paddedAttnOp);
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp
index 1a9ff701b706..a06197e564a4 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp
@@ -184,7 +184,6 @@ IREE::LinalgExt::AttentionOp tileAttention(IREE::LinalgExt::AttentionOp attnOp,
   Location loc = attnOp.getLoc();
   OpBuilder::InsertionGuard guard(rewriter);
   rewriter.setInsertionPoint(attnOp);
-  llvm::outs() << "high level implement tile Attention!\n";
 
   Value query = attnOp.getQuery();
   ShapedType queryType = attnOp.getQueryType();
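For context on the numerics the patch relies on, here is a minimal standalone sketch (not IREE code; every identifier below is made up for illustration). As I read the change, the newly created i1 mask is all-true over the real K2 positions and padded with false over the K2 padding, and the usual mask convention drives masked score columns to -Inf before the softmax, so they receive zero weight and the zero-filled key/value padding never reaches the output:

```cpp
// Masked softmax over one row of Q.K^T scores: 3 real K2 positions plus
// 2 padded positions. Padded columns are forced to -Inf after the matmul,
// so exp() maps them to 0 and they get zero attention weight.
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  std::vector<float> scores = {0.7f, 1.2f, -0.3f, 0.0f, 0.0f};
  std::vector<bool> mask = {true, true, true, false, false};

  // Apply the mask and track the row maximum for a stable softmax.
  float maxScore = -INFINITY;
  for (size_t i = 0; i < scores.size(); ++i) {
    if (!mask[i])
      scores[i] = -INFINITY;
    maxScore = std::fmax(maxScore, scores[i]);
  }

  // Softmax: padded columns end up with weight exactly 0.
  float sum = 0.0f;
  std::vector<float> weights(scores.size());
  for (size_t i = 0; i < scores.size(); ++i) {
    weights[i] = std::exp(scores[i] - maxScore);
    sum += weights[i];
  }
  for (size_t i = 0; i < scores.size(); ++i) {
    weights[i] /= sum;
    std::printf("weight[%zu] = %f\n", i, weights[i]);
  }
  return 0;
}
```

This is also why the previously removed guard (rejecting K2 padding until attn_mask lowering existed) is no longer needed: masking after the Q.K^T matmul avoids the NaN risk of pre-filling the key with -Inf.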