@@ -118,15 +118,6 @@ padAttention(IREE::LinalgExt::AttentionOp attnOp,
   int64_t k2Idx = opInfo.getK2Dims().back();
   int64_t nIdx = opInfo.getNDims().back();

-  // Padding in K2 dimension requires to fill those K2 dimensions as -Inf during
-  // softmax(Q.KT), preemptively padding it with -Inf may cause NaNs during
-  // matmul of Q.KT.
-  if (padToMultipleOf[k2Idx] > 1) {
-    return definiteFailureHelper(transformOp, attnOp,
-                                 "Padding in K2-dim is currently unsupported "
-                                 "until attn_mask lowering is implemented.");
-  }
-
   SmallVector<OpFoldResult> padValues(domainRank, rewriter.getIndexAttr(0));
   for (auto [idx, bound] : enumerate(bounds)) {
     if (padToMultipleOf[idx] != 0) {
@@ -152,7 +143,16 @@ padAttention(IREE::LinalgExt::AttentionOp attnOp,
         {padValues[batchIdx], padValues[mIdx], padValues[k1Idx]});
   }

-  // Pad K-tensor if any non-K1 dims needs padding.
+  // Pad K2-dim of K-tensor by a large negative value s.t. when used by softmax
+  // it will generate the correct numerics.
+  if (!isConstantIntValue(padValues[k2Idx], 0)) {
+    Type keyElType = attnOp.getKeyType().getElementType();
+    auto largeNeg = rewriter.getFloatAttr(keyElType, 0.0);
+    paddedKey = getPaddedValue(rewriter, loc, paddedKey,
+                               {zero, padValues[k2Idx], zero}, largeNeg);
+  }
+
+  // Pad K-tensor if any non-K2 dims needs padding.
   if (!isConstantIntValue(padValues[batchIdx], 0) ||
       !isConstantIntValue(padValues[k1Idx], 0)) {
     paddedKey = getPaddedValue(rewriter, loc, paddedKey,
@@ -161,9 +161,11 @@ padAttention(IREE::LinalgExt::AttentionOp attnOp,

   // Pad V-tensor if any of its' dims needs padding.
   if (!isConstantIntValue(padValues[batchIdx], 0) ||
+      !isConstantIntValue(padValues[k2Idx], 0) ||
       !isConstantIntValue(padValues[nIdx], 0)) {
-    paddedValue = getPaddedValue(rewriter, loc, paddedValue,
-                                 {padValues[batchIdx], zero, padValues[nIdx]});
+    paddedValue = getPaddedValue(
+        rewriter, loc, paddedValue,
+        {padValues[batchIdx], padValues[k2Idx], padValues[nIdx]});
   }

   // Pad Acc-tensor if any of its' dims needs padding.
@@ -186,11 +188,35 @@ padAttention(IREE::LinalgExt::AttentionOp attnOp,
     }
   }

+  std::optional<Value> mask = attnOp.getMask();
+  if (!isConstantIntValue(padValues[k2Idx], 0)) {
+    if (!mask.has_value()) {
+      SmallVector<OpFoldResult> mixedMaskShape = {
+          bounds[batchIdx].size, bounds[mIdx].size, bounds[k2Idx].size};
+      SmallVector<int64_t> staticMaskShape =
+          llvm::map_to_vector(mixedMaskShape, [](OpFoldResult dim) {
+            std::optional<int64_t> dimVal = getConstantIntValue(dim);
+            return dimVal.value_or(ShapedType::kDynamic);
+          });
+      auto maskElType = rewriter.getI1Type();
+      auto maskType = RankedTensorType::get(staticMaskShape, maskElType);
+      auto oneBoolAttr = rewriter.getOneAttr(maskElType);
+      mask = rewriter.createOrFold<arith::ConstantOp>(
+          loc, SplatElementsAttr::get(llvm::cast<ShapedType>(maskType),
+                                      oneBoolAttr));
+    }
+    mask = getPaddedValue(rewriter, loc, mask.value(),
+                          {zero, zero, padValues[k2Idx]});
+  }
+
+  SmallVector<Value> paddedInputs = {paddedQuery, paddedKey, paddedValue,
+                                     scale};
+  if (mask.has_value()) {
+    paddedInputs.push_back(mask.value());
+  }
   // Generate padded attention op.
   auto paddedAttnOp = rewriter.create<IREE::LinalgExt::AttentionOp>(
-      loc, paddedAcc.getType(),
-      SmallVector<Value>{paddedQuery, paddedKey, paddedValue, scale},
-      paddedAcc);
+      loc, paddedAcc.getType(), paddedInputs, paddedAcc);

   ops.push_back(paddedAttnOp);

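Note on the numerics this patch relies on: instead of rejecting K2 padding, the padded key columns are excluded through an i1 attention mask that is all-true over the original K2 extent and false (zero-filled) in the region added by getPaddedValue. The standalone C++ sketch below is not IREE code; maskedSoftmax and valid are illustrative names and the toy scores are made up. It only demonstrates the intended effect: columns whose mask entry is false are driven to -inf before softmax, so they receive zero weight and the zero-padded key entries cannot change the attention output.

// Standalone sketch (not IREE code): why masked K2 padding is numerically safe.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <limits>
#include <vector>

// Softmax over one row of Q.K^T logits; positions with valid[i] == false
// (the padded K2 columns) are forced to -inf so they get zero probability.
static std::vector<float> maskedSoftmax(const std::vector<float> &scores,
                                        const std::vector<bool> &valid) {
  const float negInf = -std::numeric_limits<float>::infinity();
  std::vector<float> logits(scores.size());
  for (size_t i = 0; i < scores.size(); ++i)
    logits[i] = valid[i] ? scores[i] : negInf;

  // Subtracting the row max (taken over the valid logits, since padded ones
  // are -inf) keeps the exponentials numerically stable.
  float maxLogit = *std::max_element(logits.begin(), logits.end());
  std::vector<float> probs(logits.size());
  float sum = 0.0f;
  for (size_t i = 0; i < logits.size(); ++i) {
    probs[i] = std::exp(logits[i] - maxLogit); // exp(-inf) == 0 for padding
    sum += probs[i];
  }
  for (float &p : probs)
    p /= sum;
  return probs;
}

int main() {
  // Three real K2 columns padded to four. The padded column's score is 0
  // because the key was zero-padded, but the mask removes it entirely.
  std::vector<float> scores = {1.5f, -0.25f, 0.75f, 0.0f};
  std::vector<bool> valid = {true, true, true, false};
  for (float p : maskedSoftmax(scores, valid))
    std::printf("%.4f ", p); // the padded column prints 0.0000
  std::printf("\n");
  return 0;
}

Compiled with any C++17 compiler and run, this prints a distribution over the three real columns and exactly zero for the padded one, which is the behavior the false-padded mask above is meant to reproduce inside the padded attention op.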