@@ -118,15 +118,6 @@ padAttention(IREE::LinalgExt::AttentionOp attnOp,
118
118
int64_t k2Idx = opInfo.getK2Dims ().back ();
119
119
int64_t nIdx = opInfo.getNDims ().back ();
120
120
121
- // Padding in K2 dimension requires to fill those K2 dimensions as -Inf during
122
- // softmax(Q.KT), preemptively padding it with -Inf may cause NaNs during
123
- // matmul of Q.KT.
124
- if (padToMultipleOf[k2Idx] > 1 ) {
125
- return definiteFailureHelper (transformOp, attnOp,
126
- " Padding in K2-dim is currently unsupported "
127
- " until attn_mask lowering is implemented." );
128
- }
129
-
130
121
SmallVector<OpFoldResult> padValues (domainRank, rewriter.getIndexAttr (0 ));
131
122
for (auto [idx, bound] : enumerate(bounds)) {
132
123
if (padToMultipleOf[idx] != 0 ) {
@@ -152,7 +143,16 @@ padAttention(IREE::LinalgExt::AttentionOp attnOp,
152
143
{padValues[batchIdx], padValues[mIdx ], padValues[k1Idx]});
153
144
}
154
145
155
- // Pad K-tensor if any non-K1 dims needs padding.
146
+ // Pad K2-dim of K-tensor by a large negative S.T when used by softmax it will
147
+ // generate the correct numerics.
148
+ if (!isConstantIntValue (padValues[k2Idx], 0 )) {
149
+ Type keyElType = attnOp.getKeyType ().getElementType ();
150
+ auto largeNeg = rewriter.getFloatAttr (keyElType, 0.0 );
151
+ paddedKey = getPaddedValue (rewriter, loc, paddedKey,
152
+ {zero, padValues[k2Idx], zero}, largeNeg);
153
+ }
154
+
155
+ // Pad K-tensor if any non-K2 dims needs padding.
156
156
if (!isConstantIntValue (padValues[batchIdx], 0 ) ||
157
157
!isConstantIntValue (padValues[k1Idx], 0 )) {
158
158
paddedKey = getPaddedValue (rewriter, loc, paddedKey,
@@ -161,9 +161,11 @@ padAttention(IREE::LinalgExt::AttentionOp attnOp,
161
161
162
162
// Pad V-tensor if any of its' dims needs padding.
163
163
if (!isConstantIntValue (padValues[batchIdx], 0 ) ||
164
+ !isConstantIntValue (padValues[k2Idx], 0 ) ||
164
165
!isConstantIntValue (padValues[nIdx], 0 )) {
165
- paddedValue = getPaddedValue (rewriter, loc, paddedValue,
166
- {padValues[batchIdx], zero, padValues[nIdx]});
166
+ paddedValue = getPaddedValue (
167
+ rewriter, loc, paddedValue,
168
+ {padValues[batchIdx], padValues[k2Idx], padValues[nIdx]});
167
169
}
168
170
169
171
// Pad Acc-tensor if any of its' dims needs padding.
@@ -186,11 +188,35 @@ padAttention(IREE::LinalgExt::AttentionOp attnOp,
186
188
}
187
189
}
188
190
191
+ std::optional<Value> mask = attnOp.getMask ();
192
+ if (!isConstantIntValue (padValues[k2Idx], 0 )) {
193
+ if (!mask.has_value ()) {
194
+ SmallVector<OpFoldResult> mixedMaskShape = {
195
+ bounds[batchIdx].size , bounds[mIdx ].size , bounds[k2Idx].size };
196
+ SmallVector<int64_t > staticMaskShape =
197
+ llvm::map_to_vector (mixedMaskShape, [](OpFoldResult dim) {
198
+ std::optional<int64_t > dimVal = getConstantIntValue (dim);
199
+ return dimVal.value_or (ShapedType::kDynamic );
200
+ });
201
+ auto maskElType = rewriter.getI1Type ();
202
+ auto maskType = RankedTensorType::get (staticMaskShape, maskElType);
203
+ auto oneBoolAttr = rewriter.getOneAttr (maskElType);
204
+ mask = rewriter.createOrFold <arith::ConstantOp>(
205
+ loc, SplatElementsAttr::get (llvm::cast<ShapedType>(maskType),
206
+ oneBoolAttr));
207
+ }
208
+ mask = getPaddedValue (rewriter, loc, mask.value (),
209
+ {zero, zero, padValues[k2Idx]});
210
+ }
211
+
212
+ SmallVector<Value> paddedInputs = {paddedQuery, paddedKey, paddedValue,
213
+ scale};
214
+ if (mask.has_value ()) {
215
+ paddedInputs.push_back (mask.value ());
216
+ }
189
217
// Generate padded attention op.
190
218
auto paddedAttnOp = rewriter.create <IREE::LinalgExt::AttentionOp>(
191
- loc, paddedAcc.getType (),
192
- SmallVector<Value>{paddedQuery, paddedKey, paddedValue, scale},
193
- paddedAcc);
219
+ loc, paddedAcc.getType (), paddedInputs, paddedAcc);
194
220
195
221
ops.push_back (paddedAttnOp);
196
222
0 commit comments