Skip to content

Commit 2ea0acc

Browse files
committed
Merge remote-tracking branch 'origin/feature/fused-ops' into bump_to_f4943464
2 parents a333ea1 + 7f76ec9 commit 2ea0acc

30 files changed

+521
-296
lines changed

mlir/include/mlir/Dialect/Affine/IR/AffineOps.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#define MLIR_DIALECT_AFFINE_IR_AFFINEOPS_H
1616

1717
#include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h"
18+
#include "mlir/Dialect/Affine/IR/AffineTraits.h"
1819
#include "mlir/Dialect/Arith/IR/Arith.h"
1920
#include "mlir/Dialect/Utils/StaticValueUtils.h"
2021
#include "mlir/IR/AffineMap.h"

mlir/include/mlir/Dialect/Affine/IR/AffineOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#ifndef AFFINE_OPS
1414
#define AFFINE_OPS
1515

16+
include "mlir/Dialect/Affine/IR/AffineTraits.td"
1617
include "mlir/Dialect/Arith/IR/ArithBase.td"
1718
include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td"
1819
include "mlir/Interfaces/ControlFlowInterfaces.td"
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
//===- AffineTraits.h - MLIR Affine Traits --------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file defines traits brought in by the Affine dialect.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
#ifndef AFFINE_TRAITS_H
13+
#define AFFINE_TRAITS_H
14+
15+
#include "mlir/IR/OpDefinition.h"
16+
17+
namespace mlir::OpTrait {
18+
19+
template <typename ConcreteType>
20+
class AffineDim : public TraitBase<ConcreteType, AffineDim> {
21+
public:
22+
static LogicalResult verifyTrait(Operation *op) { return success(); }
23+
};
24+
25+
} // namespace mlir::OpTrait
26+
27+
#endif // AFFINE_TRAITS_H
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
//===- AffineTraits.td - Affine dialect traits -------------*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Defines traits brought in by the MLIR Affine dialect.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
#ifndef AFFINE_TRAITS
13+
#define AFFINE_TRAITS
14+
15+
include "mlir/IR/OpBase.td"
16+
17+
// Trait to declare that an op result is an affine dimension identifier.
18+
// Prevents the result from being seen as a symbol into AffineMaps
19+
// and IntegerSets.
20+
// This is a deviation from upstream to consider linalg.index as
21+
// a dimension rather than a symbol. See this PR:
22+
// https://github.com/Xilinx/llvm-project/pull/537
23+
def AffineDim : NativeOpTrait<"AffineDim">;
24+
25+
#endif // AFFINE_TRAITS

mlir/include/mlir/Dialect/Linalg/IR/Linalg.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#define MLIR_DIALECT_LINALG_IR_LINALG_H
1111

1212
#include "mlir/Bytecode/BytecodeOpInterface.h"
13+
#include "mlir/Dialect/Affine/IR/AffineTraits.h"
1314
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
1415
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
1516
#include "mlir/IR/AffineExpr.h"

mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#ifndef LINALG_OPS
1414
#define LINALG_OPS
1515

16+
include "mlir/Dialect/Affine/IR/AffineTraits.td"
1617
include "mlir/Dialect/Linalg/IR/LinalgBase.td"
1718
include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
1819
include "mlir/Interfaces/ControlFlowInterfaces.td"
@@ -46,7 +47,7 @@ def Linalg_YieldOp : Linalg_Op<"yield", [Pure, ReturnLike, Terminator]>,
4647
let hasVerifier = 1;
4748
}
4849

49-
def Linalg_IndexOp : Linalg_Op<"index", [Pure]>,
50+
def Linalg_IndexOp : Linalg_Op<"index", [Pure, AffineDim]>,
5051
Arguments<(ins ConfinedAttr<I64Attr, [IntMinValue<0>]>:$dim)>,
5152
Results<(outs Index:$result)> {
5253
let summary = "linalg index operation";
@@ -100,7 +101,8 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax",
100101
["getIterationDomain",
101102
"getLoopIteratorTypes",
102103
"getResultTilePosition",
103-
"getTiledImplementation"]>]> {
104+
"getTiledImplementation",
105+
"generateResultTileValue"]>]> {
104106
let summary = "Softmax operator";
105107
let description = [{
106108
linalg.softmax computes a numerically stable version of softmax.

mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#define MLIR_DIALECT_SCF_TRANSFORMS_TILEUSINGINTERFACE_H
1111

1212
#include "mlir/Dialect/SCF/IR/SCF.h"
13+
#include "mlir/Dialect/Tensor/IR/Tensor.h"
1314
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
1415
#include "mlir/IR/PatternMatch.h"
1516
#include "mlir/Interfaces/LoopLikeInterface.h"
@@ -194,6 +195,21 @@ struct SCFTileAndFuseOptions {
194195
/// before fusion. This will track deleted and newly inserted
195196
/// `tensor.extract_slice` ops and update the worklist.
196197
std::optional<FrozenRewritePatternSet> cleanupPatterns = std::nullopt;
198+
199+
/// A function to insert a tilable node into a list of nodes to be tiled.
200+
/// This controls the order in which tiling and fusion happen.
201+
using WorklistInsertFnTy = std::function<void(
202+
tensor::ExtractSliceOp op, std::deque<tensor::ExtractSliceOp> &worklist)>;
203+
/// By default, simply append the op at the end of the queue.
204+
WorklistInsertFnTy worklistInsertFn =
205+
[](tensor::ExtractSliceOp op,
206+
std::deque<tensor::ExtractSliceOp> &worklist) {
207+
worklist.push_back(op);
208+
};
209+
SCFTileAndFuseOptions &setWorklistInsertFn(WorklistInsertFnTy insertFn) {
210+
worklistInsertFn = insertFn;
211+
return *this;
212+
}
197213
};
198214

199215
/// Fuse the producer of the source of `candidateSliceOp` by computing the

mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ namespace tosa {
2626

2727
// Expose Rewrite Functions that decompose TOSA Ops into further TOSA Ops.
2828
// The rewrites can be selectively added to a conversion pass.
29-
void populateTosaDecomposeConv2D(MLIRContext *ctx, RewritePatternSet &patterns);
3029
void populateTosaDecomposeTransposeConv(MLIRContext *ctx,
3130
RewritePatternSet &patterns);
3231
void populateTosaDecomposeDepthwise(MLIRContext *ctx,

mlir/lib/Dialect/Affine/IR/AffineOps.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include "mlir/Dialect/Affine/IR/AffineOps.h"
1010
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
11+
#include "mlir/Dialect/Linalg/IR/Linalg.h"
1112
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1213
#include "mlir/Dialect/UB/IR/UBOps.h"
1314
#include "mlir/Dialect/Utils/StaticValueUtils.h"
@@ -312,6 +313,10 @@ bool mlir::affine::isValidDim(Value value, Region *region) {
312313
return isa<AffineForOp, AffineParallelOp>(parentOp);
313314
}
314315

316+
// Remove me: linalg.index ops are valid affine dim identifiers
317+
if (op->hasTrait<OpTrait::AffineDim>())
318+
return true;
319+
315320
// Affine apply operation is ok if all of its operands are ok.
316321
if (auto applyOp = dyn_cast<AffineApplyOp>(op))
317322
return applyOp.isValidDim(region);
@@ -439,6 +444,10 @@ bool mlir::affine::isValidSymbol(Value value, Region *region) {
439444
return false;
440445
}
441446

447+
// Remove me: linalg.index ops are not valid affine symbols
448+
if (defOp->hasTrait<OpTrait::AffineDim>())
449+
return false;
450+
442451
// Constant operation is ok.
443452
Attribute operandCst;
444453
if (matchPattern(defOp, m_Constant(&operandCst)))

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2648,39 +2648,56 @@ SmallVector<utils::IteratorType> SoftmaxOp::getLoopIteratorTypes() {
26482648
return iteratorTypes;
26492649
}
26502650

2651-
FailureOr<TilingResult>
2652-
SoftmaxOp::getTiledImplementation(OpBuilder &builder,
2653-
ArrayRef<OpFoldResult> offsets,
2654-
ArrayRef<OpFoldResult> sizes) {
2655-
int64_t rank = getInputOperandRank();
2651+
static FailureOr<TilingResult>
2652+
implementTiledSoftMax(SoftmaxOp &op, OpBuilder &builder,
2653+
ArrayRef<OpFoldResult> offsets,
2654+
ArrayRef<OpFoldResult> sizes) {
2655+
int64_t rank = op.getInputOperandRank();
26562656
auto oneAttr = builder.getI64IntegerAttr(1);
26572657
SmallVector<OpFoldResult> strides(rank, oneAttr);
26582658
SmallVector<Value> tiledOperands;
26592659
Operation *inputSlice =
2660-
getSlice(builder, getLoc(), getInput(), offsets, sizes, strides);
2660+
getSlice(builder, op.getLoc(), op.getInput(), offsets, sizes, strides);
26612661
if (!inputSlice) {
2662-
return emitOpError("failed to compute input slice");
2662+
return op.emitOpError("failed to compute input slice");
26632663
}
26642664
tiledOperands.emplace_back(inputSlice->getResult(0));
26652665
Operation *outputSlice =
2666-
getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides);
2666+
getSlice(builder, op.getLoc(), op.getOutput(), offsets, sizes, strides);
26672667
if (!outputSlice) {
2668-
return emitOpError("failed to compute output slice");
2668+
return op.emitOpError("failed to compute output slice");
26692669
}
26702670
tiledOperands.emplace_back(outputSlice->getResult(0));
26712671

26722672
SmallVector<Type, 4> resultTypes;
2673-
if (hasPureTensorSemantics())
2673+
if (op.hasPureTensorSemantics())
26742674
resultTypes.push_back(tiledOperands[1].getType());
26752675
Operation *tiledOp =
2676-
mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
2676+
mlir::clone(builder, op.getOperation(), resultTypes, tiledOperands);
26772677

26782678
return TilingResult{
26792679
{tiledOp},
26802680
SmallVector<Value>(tiledOp->getResults()),
26812681
llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
26822682
}
26832683

2684+
FailureOr<TilingResult>
2685+
SoftmaxOp::getTiledImplementation(OpBuilder &builder,
2686+
ArrayRef<OpFoldResult> offsets,
2687+
ArrayRef<OpFoldResult> sizes) {
2688+
return implementTiledSoftMax(*this, builder, offsets, sizes);
2689+
}
2690+
2691+
FailureOr<TilingResult>
2692+
SoftmaxOp::generateResultTileValue(OpBuilder &builder, unsigned resultNumber,
2693+
ArrayRef<OpFoldResult> offsets,
2694+
ArrayRef<OpFoldResult> sizes) {
2695+
if (resultNumber != 0)
2696+
return failure();
2697+
2698+
return implementTiledSoftMax(*this, builder, offsets, sizes);
2699+
}
2700+
26842701
LogicalResult SoftmaxOp::getResultTilePosition(
26852702
OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
26862703
ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,

0 commit comments

Comments
 (0)