Skip to content

Commit 7523e8c

Browse files
raikonenfnu and monorimet
authored and committed
Refactor to use DiagnosedSilenceableFailure
Signed-off-by: stanley-nod <[email protected]>
1 parent 3a218bb commit 7523e8c

File tree

3 files changed

+40
-22
lines changed

3 files changed

+40
-22
lines changed

compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/LinalgExtExtensionsOps.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,10 +72,11 @@ LinalgExt::PadAttentionOp::applyToOne(transform::TransformRewriter &rewriter,
7272
extractFromIntegerArrayAttr<int64_t>(getPadToMultipleOf());
7373

7474
SmallVector<Operation *> ops;
75-
if (failed(LinalgExt::padAttention(attentionOp, ops, rewriter,
76-
padToMultipleOf))) {
77-
return emitSilenceableFailure(this->getOperation(),
78-
"Failed to pad attentionOp.");
75+
auto transformOp = cast<transform::TransformOpInterface>(getOperation());
76+
auto result = LinalgExt::padAttention(attentionOp, ops, rewriter, transformOp,
77+
padToMultipleOf);
78+
if (!result.succeeded()) {
79+
return result;
7980
}
8081
for (auto op : ops) {
8182
results.push_back(op);

compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Padding.cpp

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -62,39 +62,52 @@ struct PadAttentionPass : public PadAttentionBase<PadAttentionPass> {
6262
void runOnOperation() override;
6363
};
6464

65+
static DiagnosedSilenceableFailure definiteFailureHelper(
66+
std::optional<transform::TransformOpInterface> transformOp,
67+
Operation *target, const Twine &message) {
68+
if (transformOp.has_value())
69+
return transformOp->emitDefiniteFailure() << message;
70+
return emitDefiniteFailure(target, message);
71+
}
72+
6573
} // namespace
6674

6775
/// Pads iree_linalg_ext.attention.
68-
LogicalResult padAttention(IREE::LinalgExt::AttentionOp attnOp,
69-
SmallVectorImpl<Operation *> &ops,
70-
RewriterBase &rewriter,
71-
ArrayRef<int64_t> padToMultipleOf) {
76+
DiagnosedSilenceableFailure
77+
padAttention(IREE::LinalgExt::AttentionOp attnOp,
78+
SmallVectorImpl<Operation *> &ops, RewriterBase &rewriter,
79+
std::optional<transform::TransformOpInterface> transformOp,
80+
ArrayRef<int64_t> padToMultipleOf) {
7281
SmallVector<AffineMap> maps = attnOp.getIndexingMapsArray();
7382
FailureOr<IREE::LinalgExt::AttentionOpDetail> maybeOpInfo =
7483
IREE::LinalgExt::AttentionOpDetail::get(maps);
7584
if (failed(maybeOpInfo)) {
7685
// failed to infer attention dims
77-
return failure();
86+
return definiteFailureHelper(transformOp, attnOp,
87+
"Failed to infer attention dims.");
7888
}
7989
auto opInfo = maybeOpInfo.value();
8090
Location loc = attnOp.getLoc();
8191
rewriter.setInsertionPoint(attnOp);
8292

8393
int64_t domainRank = maps[0].getNumDims();
8494
if (domainRank != 5) {
85-
// Currently only support base-case of attention dims.
86-
return failure();
95+
return definiteFailureHelper(
96+
transformOp, attnOp,
97+
"Currently only support base-case of attention dims.");
8798
}
8899
if (padToMultipleOf.size() != domainRank) {
89-
// Expects pad_to_multiple to have same rank as dimensions of attention.
90-
return failure();
100+
return definiteFailureHelper(transformOp, attnOp,
101+
"Expects pad_to_multiple to have same rank as "
102+
"dimensions of attention.");
91103
}
92104

93105
bool hasValidPadding = llvm::none_of(
94106
padToMultipleOf, [](int64_t padMultiple) { return padMultiple < 0; });
95107
if (!hasValidPadding) {
96-
// pad-multiple-of cannot be a negative value.
97-
return failure();
108+
return definiteFailureHelper(transformOp, attnOp,
109+
"pad_to_multiple_of cannot be a "
110+
"negative value.");
98111
}
99112

100113
SmallVector<Range> bounds = attnOp.getIterationDomain(rewriter);
@@ -109,7 +122,9 @@ LogicalResult padAttention(IREE::LinalgExt::AttentionOp attnOp,
109122
// softmax(Q.KT), preemptively padding it with -Inf may cause NaNs during
110123
// matmul of Q.KT.
111124
if (padToMultipleOf[k2Idx] > 1) {
112-
return failure();
125+
return definiteFailureHelper(transformOp, attnOp,
126+
"Padding in K2-dim is currently unsupported "
127+
"until attn_mask lowering is implemented.");
113128
}
114129

115130
SmallVector<OpFoldResult> padValues(domainRank, rewriter.getIndexAttr(0));
@@ -191,15 +206,15 @@ LogicalResult padAttention(IREE::LinalgExt::AttentionOp attnOp,
191206

192207
rewriter.replaceOp(attnOp, extracted);
193208

194-
return success();
209+
return DiagnosedSilenceableFailure::success();
195210
}
196211

197212
void PadAttentionPass::runOnOperation() {
198213
MLIRContext *context = &getContext();
199214
IRRewriter rewriter(context);
200215
getOperation().walk([&](AttentionOp attnOp) {
201216
SmallVector<Operation *> ops;
202-
(void)padAttention(attnOp, ops, rewriter, padToMultipleOf);
217+
(void)padAttention(attnOp, ops, rewriter, std::nullopt, padToMultipleOf);
203218
});
204219
}
205220

compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
1212
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
1313
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
14+
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
1415
#include "mlir/Interfaces/FunctionInterfaces.h"
1516
#include "mlir/Pass/Pass.h"
1617

@@ -51,10 +52,11 @@ tileAttention(IREE::LinalgExt::AttentionOp attnOp,
5152
SmallVectorImpl<Operation *> &ops, RewriterBase &rewriter,
5253
std::optional<uint64_t> tileSize = std::nullopt);
5354

54-
LogicalResult padAttention(IREE::LinalgExt::AttentionOp attnOp,
55-
SmallVectorImpl<Operation *> &ops,
56-
RewriterBase &rewriter,
57-
ArrayRef<int64_t> padToMultipleOf);
55+
DiagnosedSilenceableFailure
56+
padAttention(IREE::LinalgExt::AttentionOp attnOp,
57+
SmallVectorImpl<Operation *> &ops, RewriterBase &rewriter,
58+
std::optional<transform::TransformOpInterface> transformOp,
59+
ArrayRef<int64_t> padToMultipleOf);
5860

5961
void decomposeTiledAttention(IREE::LinalgExt::AttentionOp tiledAttnOp,
6062
SmallVectorImpl<Operation *> &ops,

0 commit comments

Comments
 (0)