Skip to content

Commit e4cc751

Browse files
committed
Merge remote-tracking branch 'xlnx/feature/fused-ops' into bump_to_f8eceb45
2 parents d1726f4 + 5518042 commit e4cc751

File tree

11 files changed

+156
-50
lines changed

11 files changed

+156
-50
lines changed

mlir/include/mlir/Dialect/Linalg/Passes.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,11 @@ def LinalgElementwiseOpFusionPass : Pass<"linalg-fuse-elementwise-ops"> {
7575
let dependentDialects = [
7676
"affine::AffineDialect", "linalg::LinalgDialect", "memref::MemRefDialect"
7777
];
78+
let options = [
79+
Option<"removeOutsDependency", "remove-outs-dependency", "bool",
80+
/*default=*/"true",
81+
"Replace out by tensor.empty">,
82+
];
7883
}
7984

8085
def LinalgNamedOpConversionPass: Pass<"linalg-named-op-conversion"> {

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1701,7 +1701,8 @@ using ControlFusionFn = std::function<bool(OpOperand *fusedOperand)>;
17011701
/// when both operations are fusable elementwise operations.
17021702
void populateElementwiseOpsFusionPatterns(
17031703
RewritePatternSet &patterns,
1704-
const ControlFusionFn &controlElementwiseOpFusion);
1704+
const ControlFusionFn &controlElementwiseOpFusion,
1705+
bool replaceOutsDependency = true);
17051706

17061707
/// Function type which is used to control propagation of tensor.pack/unpack
17071708
/// ops.

mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -757,7 +757,7 @@ class TruncFConversion : public OpConversionPattern<arith::TruncFOp> {
757757
return rewriter.notifyMatchFailure(castOp,
758758
"unsupported cast destination type");
759759

760-
if (!castOp.areCastCompatible(operandType, dstType))
760+
if (!emitc::CastOp::areCastCompatible(operandType, dstType))
761761
return rewriter.notifyMatchFailure(castOp, "cast-incompatible types");
762762

763763
rewriter.replaceOpWithNewOp<emitc::CastOp>(castOp, dstType,
@@ -787,7 +787,7 @@ class ExtFConversion : public OpConversionPattern<arith::ExtFOp> {
787787
return rewriter.notifyMatchFailure(castOp,
788788
"unsupported cast destination type");
789789

790-
if (!castOp.areCastCompatible(operandType, dstType))
790+
if (!emitc::CastOp::areCastCompatible(operandType, dstType))
791791
return rewriter.notifyMatchFailure(castOp, "cast-incompatible types");
792792

793793
rewriter.replaceOpWithNewOp<emitc::CastOp>(castOp, dstType,

mlir/lib/Dialect/EmitC/IR/EmitC.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,10 @@ LogicalResult emitc::AssignOp::verify() {
313313
bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
314314
Type input = inputs.front(), output = outputs.front();
315315

316+
// Opaque types are always allowed
317+
if (isa<emitc::OpaqueType>(input) || isa<emitc::OpaqueType>(output))
318+
return true;
319+
316320
// Cast to array is only possible from an array
317321
if (isa<emitc::ArrayType>(input) != isa<emitc::ArrayType>(output))
318322
return false;

mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2134,11 +2134,13 @@ void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
21342134

21352135
void mlir::linalg::populateElementwiseOpsFusionPatterns(
21362136
RewritePatternSet &patterns,
2137-
const ControlFusionFn &controlElementwiseOpsFusion) {
2137+
const ControlFusionFn &controlElementwiseOpsFusion,
2138+
bool removeOutsDependency) {
21382139
auto *context = patterns.getContext();
21392140
patterns.add<FuseElementwiseOps>(context, controlElementwiseOpsFusion);
2140-
patterns.add<FoldFillWithGenericOp, FoldScalarOrSplatConstant,
2141-
RemoveOutsDependency>(context);
2141+
patterns.add<FoldFillWithGenericOp, FoldScalarOrSplatConstant>(context);
2142+
if (removeOutsDependency)
2143+
patterns.add<RemoveOutsDependency>(context);
21422144
// Add the patterns that clean up dead operands and results.
21432145
populateEraseUnusedOperandsAndResultsPatterns(patterns);
21442146
}
@@ -2180,7 +2182,8 @@ struct LinalgElementwiseOpFusionPass
21802182
};
21812183

21822184
// Add elementwise op fusion patterns.
2183-
populateElementwiseOpsFusionPatterns(patterns, defaultControlFn);
2185+
populateElementwiseOpsFusionPatterns(patterns, defaultControlFn,
2186+
removeOutsDependency);
21842187
populateFoldReshapeOpsByExpansionPatterns(patterns, defaultControlFn);
21852188
tensor::populateBubbleUpExpandShapePatterns(patterns);
21862189

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1349,19 +1349,13 @@ OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {
13491349
return {};
13501350
}
13511351

1352-
static bool hasZeroSize(Type ty) {
1353-
auto ranked = dyn_cast<RankedTensorType>(ty);
1354-
if (!ranked)
1355-
return false;
1356-
return any_of(ranked.getShape(), [](auto d) { return d == 0; });
1357-
}
1358-
13591352
OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
13601353
/// Remove operands that have zero elements.
13611354
bool changed = false;
13621355
for (size_t i = 0; i < getInput1().size(); ) {
1363-
auto input = getInput1()[i];
1364-
if (hasZeroSize(input.getType())) {
1356+
auto input = cast<RankedTensorType>(getInput1()[i].getType());
1357+
// Ensure that we have at least one operand left.
1358+
if (input.getDimSize(getAxis()) == 0 && getInput1().size() > 1) {
13651359
getInput1Mutable().erase(i);
13661360
changed = true;
13671361
} else {

mlir/lib/Target/Cpp/TranslateToCpp.cpp

Lines changed: 42 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,16 @@ struct CppEmitter {
254254
return operandExpression == emittedExpression;
255255
};
256256

257+
/// Determine whether expression \p expressionOp should be emitted inline,
258+
/// i.e. as part of its user. This function recommends inlining of any
259+
/// expressions that can be inlined unless it is used by another expression,
260+
/// under the assumption that any expression fusion/re-materialization was
261+
/// taken care of by transformations run by the backend.
262+
bool shouldBeInlined(ExpressionOp expressionOp);
263+
264+
/// This emitter will only emit translation units whos id matches this value.
265+
StringRef willOnlyEmitTu() { return onlyTu; }
266+
257267
private:
258268
using ValueMapper = llvm::ScopedHashTable<Value, std::string>;
259269
using BlockMapper = llvm::ScopedHashTable<Block *, std::string>;
@@ -297,21 +307,22 @@ struct CppEmitter {
297307
return lowestPrecedence();
298308
return emittedExpressionPrecedence.back();
299309
}
310+
311+
/// Determine whether expression \p op should be emitted in a deferred way.
312+
bool hasDeferredEmission(Operation *op);
300313
};
301314
} // namespace
302315

303-
/// Determine whether expression \p op should be emitted in a deferred way.
304-
static bool hasDeferredEmission(Operation *op) {
316+
bool CppEmitter::hasDeferredEmission(Operation *op) {
317+
if (llvm::isa_and_nonnull<emitc::ConstantOp>(op)) {
318+
return !shouldUseConstantsAsVariables();
319+
}
320+
305321
return isa_and_nonnull<emitc::GetGlobalOp, emitc::LiteralOp, emitc::MemberOp,
306322
emitc::MemberOfPtrOp, emitc::SubscriptOp>(op);
307323
}
308324

309-
/// Determine whether expression \p expressionOp should be emitted inline, i.e.
310-
/// as part of its user. This function recommends inlining of any expressions
311-
/// that can be inlined unless it is used by another expression, under the
312-
/// assumption that any expression fusion/re-materialization was taken care of
313-
/// by transformations run by the backend.
314-
static bool shouldBeInlined(ExpressionOp expressionOp) {
325+
bool CppEmitter::shouldBeInlined(ExpressionOp expressionOp) {
315326
// Do not inline if expression is marked as such.
316327
if (expressionOp.getDoNotInline())
317328
return false;
@@ -373,6 +384,25 @@ static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation,
373384
static LogicalResult printOperation(CppEmitter &emitter,
374385
emitc::ConstantOp constantOp) {
375386
if (!emitter.shouldUseConstantsAsVariables()) {
387+
std::string out;
388+
llvm::raw_string_ostream ss(out);
389+
390+
/// Temporary emitter object that writes to our stream instead of the output
391+
/// allowing for the capture and caching of the produced string.
392+
CppEmitter sniffer = CppEmitter(ss, emitter.shouldDeclareVariablesAtTop(),
393+
emitter.willOnlyEmitTu(),
394+
emitter.shouldUseConstantsAsVariables());
395+
396+
ss << "(";
397+
if (failed(sniffer.emitType(constantOp.getLoc(), constantOp.getType())))
398+
return failure();
399+
ss << ") ";
400+
401+
if (failed(
402+
sniffer.emitAttribute(constantOp.getLoc(), constantOp.getValue())))
403+
return failure();
404+
405+
emitter.cacheDeferredOpResult(constantOp.getResult(), out);
376406
return success();
377407
}
378408

@@ -838,7 +868,7 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::CastOp castOp) {
838868

839869
static LogicalResult printOperation(CppEmitter &emitter,
840870
emitc::ExpressionOp expressionOp) {
841-
if (shouldBeInlined(expressionOp))
871+
if (emitter.shouldBeInlined(expressionOp))
842872
return success();
843873

844874
Operation &op = *expressionOp.getOperation();
@@ -892,7 +922,7 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::ForOp forOp) {
892922
dyn_cast_if_present<ExpressionOp>(value.getDefiningOp());
893923
if (!expressionOp)
894924
return false;
895-
return shouldBeInlined(expressionOp);
925+
return emitter.shouldBeInlined(expressionOp);
896926
};
897927

898928
os << "for (";
@@ -1114,7 +1144,7 @@ static LogicalResult printFunctionBody(CppEmitter &emitter,
11141144
functionOp->walk<WalkOrder::PreOrder>([&](Operation *op) -> WalkResult {
11151145
if (isa<emitc::ExpressionOp>(op->getParentOp()) ||
11161146
(isa<emitc::ExpressionOp>(op) &&
1117-
shouldBeInlined(cast<emitc::ExpressionOp>(op))))
1147+
emitter.shouldBeInlined(cast<emitc::ExpressionOp>(op))))
11181148
return WalkResult::skip();
11191149
for (OpResult result : op->getResults()) {
11201150
if (failed(emitter.emitVariableDeclaration(
@@ -1494,22 +1524,6 @@ LogicalResult CppEmitter::emitExpression(ExpressionOp expressionOp) {
14941524

14951525
LogicalResult CppEmitter::emitOperand(Value value) {
14961526
Operation *def = value.getDefiningOp();
1497-
if (!shouldUseConstantsAsVariables()) {
1498-
if (auto constant = dyn_cast_if_present<ConstantOp>(def)) {
1499-
os << "((";
1500-
1501-
if (failed(emitType(constant.getLoc(), constant.getType()))) {
1502-
return failure();
1503-
}
1504-
os << ") ";
1505-
1506-
if (failed(emitAttribute(constant.getLoc(), constant.getValue()))) {
1507-
return failure();
1508-
}
1509-
os << ")";
1510-
return success();
1511-
}
1512-
}
15131527

15141528
if (isPartOfCurrentExpression(value)) {
15151529
assert(def && "Expected operand to be defined by an operation");
@@ -1721,11 +1735,7 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
17211735
cacheDeferredOpResult(op.getResult(), op.getValue());
17221736
return success();
17231737
})
1724-
.Case<emitc::MemberOp>([&](auto op) {
1725-
cacheDeferredOpResult(op.getResult(), createMemberAccess(op));
1726-
return success();
1727-
})
1728-
.Case<emitc::MemberOfPtrOp>([&](auto op) {
1738+
.Case<emitc::MemberOp, emitc::MemberOfPtrOp>([&](auto op) {
17291739
cacheDeferredOpResult(op.getResult(), createMemberAccess(op));
17301740
return success();
17311741
})

mlir/test/Dialect/EmitC/ops.mlir

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,15 @@ emitc.func private @extern(i32) attributes {specifiers = ["extern"]}
3636

3737
func.func @cast(%arg0: i32) {
3838
%1 = emitc.cast %arg0: i32 to f32
39+
%2 = emitc.cast %1: f32 to !emitc.opaque<"some type">
40+
%3 = emitc.cast %2: !emitc.opaque<"some type"> to !emitc.size_t
3941
return
4042
}
4143

4244
func.func @cast_array(%arg : !emitc.array<4xf32>) {
4345
%1 = emitc.cast %arg: !emitc.array<4xf32> to !emitc.array<4xf32> ref
46+
%2 = emitc.cast %arg: !emitc.array<4xf32> to !emitc.opaque<"some type">
47+
%3 = emitc.cast %2: !emitc.opaque<"some type"> to !emitc.array<4xf32> ref
4448
return
4549
}
4650

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
// RUN: mlir-opt %s -p 'builtin.module(func.func(linalg-fuse-elementwise-ops{remove-outs-dependency=0}))' -split-input-file | FileCheck %s
2+
3+
#identity = affine_map<(d0) -> (d0)>
4+
5+
func.func @keep_outs_dependency(%arg: tensor<4xf32>) -> tensor<4xf32> {
6+
// CHECK-NOT: tensor.empty
7+
%1 = linalg.generic {indexing_maps = [#identity, #identity], iterator_types = ["parallel"] } ins(%arg: tensor<4xf32>) outs(%arg: tensor<4xf32>) {
8+
^bb0(%in: f32, %out: f32):
9+
%exp = arith.negf %in: f32
10+
linalg.yield %exp : f32
11+
} -> tensor<4xf32>
12+
return %1 : tensor<4xf32>
13+
}

mlir/test/Dialect/Tosa/canonicalize.mlir

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,32 @@ func.func @concat_fold_zero(%arg0: tensor<?x0xf32>, %arg1: tensor<?x1xf32>, %arg
204204
%0 = tosa.concat %arg0, %arg1, %arg2 {axis = 1 : i32}: (tensor<?x0xf32>, tensor<?x1xf32>, tensor<?x2xf32>) -> tensor<?x3xf32>
205205
return %0 : tensor<?x3xf32>
206206
}
207+
// -----
208+
209+
// CHECK-LABEL: @concat_fold_zero
210+
func.func @concat_fold_zero_all(%arg0: tensor<?x0xf32>, %arg1: tensor<?x0xf32>) -> tensor<?x0xf32> {
211+
// CHECK: return %arg1
212+
%0 = tosa.concat %arg0, %arg1 {axis = 1 : i32}: (tensor<?x0xf32>, tensor<?x0xf32>) -> tensor<?x0xf32>
213+
return %0 : tensor<?x0xf32>
214+
}
215+
216+
// -----
217+
218+
// CHECK-LABEL: @concat_fold_zero
219+
func.func @concat_fold_zero_different_axis(%arg0: tensor<0x2xf32>, %arg1: tensor<0x4xf32>) -> tensor<0x6xf32> {
220+
// CHECK: tosa.concat %arg0, %arg1
221+
%0 = tosa.concat %arg0, %arg1 {axis = 1 : i32}: (tensor<0x2xf32>, tensor<0x4xf32>) -> tensor<0x6xf32>
222+
return %0 : tensor<0x6xf32>
223+
}
224+
225+
// -----
226+
227+
// CHECK-LABEL: @concat_fold_zero_size
228+
func.func @concat_fold_zero_size(%arg0: tensor<?x0xf32>, %arg1: tensor<?x1xf32>, %arg2: tensor<?x2xf32>) -> tensor<?x3xf32> {
229+
// CHECK: tosa.concat %arg1, %arg2 {axis = 1 : i32}
230+
%0 = tosa.concat %arg0, %arg1, %arg2 {axis = 1 : i32}: (tensor<?x0xf32>, tensor<?x1xf32>, tensor<?x2xf32>) -> tensor<?x3xf32>
231+
return %0 : tensor<?x3xf32>
232+
}
207233

208234
// -----
209235

mlir/test/Target/Cpp/emitc-constants-as-variables.mlir

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,55 @@ func.func @test() {
1111

1212
return
1313
}
14+
// CPP-DEFAULT-LABEL: void test() {
15+
// CPP-DEFAULT-NEXT: for (size_t v1 = (size_t) 0; v1 < (size_t) 10; v1 += (size_t) 1) {
16+
// CPP-DEFAULT-NEXT: }
17+
// CPP-DEFAULT-NEXT: return;
18+
// CPP-DEFAULT-NEXT: }
19+
20+
// -----
21+
22+
func.func @test_subscript(%arg0: !emitc.array<4xf32>) -> (!emitc.lvalue<f32>) {
23+
%c0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
24+
%0 = emitc.subscript %arg0[%c0] : (!emitc.array<4xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
25+
return %0 : !emitc.lvalue<f32>
26+
}
27+
// CPP-DEFAULT-LABEL: float test_subscript(float v1[4]) {
28+
// CPP-DEFAULT-NEXT: return v1[(size_t) 0];
29+
// CPP-DEFAULT-NEXT: }
30+
31+
// -----
1432

15-
// CPP-DEFAULT: void test() {
16-
// CPP-DEFAULT-NEXT: for (size_t v1 = ((size_t) 0); v1 < ((size_t) 10); v1 += ((size_t) 1)) {
33+
func.func @emitc_switch_ui64() {
34+
%0 = "emitc.constant"(){value = 1 : ui64} : () -> ui64
35+
36+
emitc.switch %0 : ui64
37+
default {
38+
emitc.call_opaque "func2" (%0) : (ui64) -> ()
39+
emitc.yield
40+
}
41+
return
42+
}
43+
// CPP-DEFAULT-LABEL: void emitc_switch_ui64() {
44+
// CPP-DEFAULT: switch ((uint64_t) 1) {
45+
// CPP-DEFAULT-NEXT: default: {
46+
// CPP-DEFAULT-NEXT: func2((uint64_t) 1);
47+
// CPP-DEFAULT-NEXT: break;
1748
// CPP-DEFAULT-NEXT: }
1849
// CPP-DEFAULT-NEXT: return;
1950
// CPP-DEFAULT-NEXT: }
51+
52+
// -----
53+
54+
func.func @negative_values() {
55+
%1 = "emitc.constant"() <{value = 10 : index}> : () -> !emitc.size_t
56+
%2 = "emitc.constant"() <{value = -3000000000 : index}> : () -> !emitc.ssize_t
57+
58+
%3 = emitc.add %1, %2 : (!emitc.size_t, !emitc.ssize_t) -> !emitc.ssize_t
59+
60+
return
61+
}
62+
// CPP-DEFAULT-LABEL: void negative_values() {
63+
// CPP-DEFAULT-NEXT: ssize_t v1 = (size_t) 10 + (ssize_t) -3000000000;
64+
// CPP-DEFAULT-NEXT: return;
65+
// CPP-DEFAULT-NEXT: }

0 commit comments

Comments
 (0)