@@ -62,39 +62,52 @@ struct PadAttentionPass : public PadAttentionBase<PadAttentionPass> {
   void runOnOperation() override;
 };
 
+static DiagnosedSilenceableFailure definiteFailureHelper(
+    std::optional<transform::TransformOpInterface> transformOp,
+    Operation *target, const Twine &message) {
+  if (transformOp.has_value())
+    return transformOp->emitDefiniteFailure() << message;
+  return emitDefiniteFailure(target, message);
+}
+
 } // namespace
 
 /// Pads iree_linalg_ext.attention.
-LogicalResult padAttention(IREE::LinalgExt::AttentionOp attnOp,
-                           SmallVectorImpl<Operation *> &ops,
-                           RewriterBase &rewriter,
-                           ArrayRef<int64_t> padToMultipleOf) {
+DiagnosedSilenceableFailure
+padAttention(IREE::LinalgExt::AttentionOp attnOp,
+             SmallVectorImpl<Operation *> &ops, RewriterBase &rewriter,
+             std::optional<transform::TransformOpInterface> transformOp,
+             ArrayRef<int64_t> padToMultipleOf) {
   SmallVector<AffineMap> maps = attnOp.getIndexingMapsArray();
   FailureOr<IREE::LinalgExt::AttentionOpDetail> maybeOpInfo =
       IREE::LinalgExt::AttentionOpDetail::get(maps);
   if (failed(maybeOpInfo)) {
     // failed to infer attention dims
-    return failure();
+    return definiteFailureHelper(transformOp, attnOp,
+                                 "Failed to infer attention dims.");
   }
   auto opInfo = maybeOpInfo.value();
   Location loc = attnOp.getLoc();
   rewriter.setInsertionPoint(attnOp);
 
   int64_t domainRank = maps[0].getNumDims();
   if (domainRank != 5) {
-    // Currently only support base-case of attention dims.
-    return failure();
+    return definiteFailureHelper(
+        transformOp, attnOp,
+        "Currently only support base-case of attention dims.");
   }
   if (padToMultipleOf.size() != domainRank) {
-    // Expects pad_to_multiple to have same rank as dimensions of attention.
-    return failure();
+    return definiteFailureHelper(transformOp, attnOp,
+                                 "Expects pad_to_multiple to have same rank as "
+                                 "dimensions of attention.");
   }
 
   bool hasValidPadding = llvm::none_of(
       padToMultipleOf, [](int64_t padMultiple) { return padMultiple < 0; });
   if (!hasValidPadding) {
-    // pad-multiple-of cannot be a negative value.
-    return failure();
+    return definiteFailureHelper(transformOp, attnOp,
+                                 "pad_to_multiple_of cannot be a negative "
+                                 "value.");
   }
 
   SmallVector<Range> bounds = attnOp.getIterationDomain(rewriter);
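Note on the signature change above: returning DiagnosedSilenceableFailure instead of LogicalResult lets the same rewrite serve both the standalone pass and a transform-dialect driver, and definiteFailureHelper picks where the diagnostic lands: on the transform op when one is supplied, otherwise on the attention op itself. Below is a minimal sketch of how a transform-op caller might consume the new return type; the PadAttentionOp transform op, its applyToOne signature, and the getPadToMultipleOf() accessor are assumptions for illustration, not part of this diff:

DiagnosedSilenceableFailure
PadAttentionOp::applyToOne(transform::TransformRewriter &rewriter,
                           IREE::LinalgExt::AttentionOp target,
                           transform::ApplyToEachResultList &results,
                           transform::TransformState &state) {
  SmallVector<Operation *> ops;
  // Passing the transform op routes any definite failure through it, so the
  // diagnostic points at the transform script rather than only at the IR.
  DiagnosedSilenceableFailure diag =
      padAttention(target, ops, rewriter,
                   cast<transform::TransformOpInterface>(getOperation()),
                   getPadToMultipleOf()); // assumed ArrayRef<int64_t> accessor
  if (!diag.succeeded())
    return diag;
  for (Operation *op : ops)
    results.push_back(op);
  return DiagnosedSilenceableFailure::success();
}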
@@ -109,7 +122,9 @@ LogicalResult padAttention(IREE::LinalgExt::AttentionOp attnOp,
   // softmax(Q.KT), preemptively padding it with -Inf may cause NaNs during
   // matmul of Q.KT.
   if (padToMultipleOf[k2Idx] > 1) {
-    return failure();
+    return definiteFailureHelper(transformOp, attnOp,
+                                 "Padding in K2-dim is currently unsupported "
+                                 "until attn_mask lowering is implemented.");
   }
 
   SmallVector<OpFoldResult> padValues(domainRank, rewriter.getIndexAttr(0));
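For context on the guard above: K2 is the dimension that softmax normalizes over, so padded key positions must not contribute probability mass; masking them to -Inf is the usual fix, but -Inf cannot simply be baked into the padded operands, because IEEE 754 arithmetic turns 0 * -Inf into NaN inside the Q.KT accumulation. A tiny standalone illustration of that failure mode (not from the patch):

#include <cmath>
#include <cstdio>

int main() {
  float q = 0.0f;          // a legitimate zero element of Q
  float kPad = -INFINITY;  // a hypothetical -Inf-padded element of K
  // IEEE 754: 0 * -Inf is NaN, so a single padded entry poisons the dot
  // product, and the NaN then propagates through softmax.
  printf("%f\n", q * kPad); // prints nan
  return 0;
}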
@@ -191,15 +206,15 @@ LogicalResult padAttention(IREE::LinalgExt::AttentionOp attnOp,
 
   rewriter.replaceOp(attnOp, extracted);
 
-  return success();
+  return DiagnosedSilenceableFailure::success();
 }
 
 void PadAttentionPass::runOnOperation() {
   MLIRContext *context = &getContext();
   IRRewriter rewriter(context);
   getOperation().walk([&](AttentionOp attnOp) {
     SmallVector<Operation *> ops;
-    (void)padAttention(attnOp, ops, rewriter, padToMultipleOf);
+    (void)padAttention(attnOp, ops, rewriter, std::nullopt, padToMultipleOf);
   });
 }
 
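In the standalone pass there is no transform op to report through, so std::nullopt is passed and definiteFailureHelper emits the diagnostic directly on the attention op; the (void) cast then deliberately discards the result so one unpaddable op does not stop the walk. If the pass were instead meant to fail on the first error, a sketch of that variant (an assumption about desired behavior, not what this patch does) could look like:

getOperation().walk([&](AttentionOp attnOp) {
  SmallVector<Operation *> ops;
  DiagnosedSilenceableFailure result =
      padAttention(attnOp, ops, rewriter, std::nullopt, padToMultipleOf);
  if (!result.succeeded()) {
    // Surface the failure to the pass manager and stop walking.
    signalPassFailure();
    return WalkResult::interrupt();
  }
  return WalkResult::advance();
});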