Skip to content

Commit 5518042

Browse files
ElementwiseOpFusion: option for disable empty (#430)
* ElementwiseOpFusion: option for disable empty --------- Co-authored-by: Matthias Gehre <[email protected]>
1 parent 992dad3 commit 5518042

File tree

4 files changed

+27
-5
lines changed

4 files changed

+27
-5
lines changed

mlir/include/mlir/Dialect/Linalg/Passes.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,11 @@ def LinalgElementwiseOpFusionPass : Pass<"linalg-fuse-elementwise-ops"> {
7575
let dependentDialects = [
7676
"affine::AffineDialect", "linalg::LinalgDialect", "memref::MemRefDialect"
7777
];
78+
let options = [
79+
Option<"removeOutsDependency", "remove-outs-dependency", "bool",
80+
/*default=*/"true",
81+
"Replace out by tensor.empty">,
82+
];
7883
}
7984

8085
def LinalgNamedOpConversionPass: Pass<"linalg-named-op-conversion"> {

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1701,7 +1701,8 @@ using ControlFusionFn = std::function<bool(OpOperand *fusedOperand)>;
17011701
/// when both operations are fusable elementwise operations.
17021702
void populateElementwiseOpsFusionPatterns(
17031703
RewritePatternSet &patterns,
1704-
const ControlFusionFn &controlElementwiseOpFusion);
1704+
const ControlFusionFn &controlElementwiseOpFusion,
1705+
bool replaceOutsDependency = true);
17051706

17061707
/// Function type which is used to control propagation of tensor.pack/unpack
17071708
/// ops.

mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2134,11 +2134,13 @@ void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
21342134

21352135
void mlir::linalg::populateElementwiseOpsFusionPatterns(
21362136
RewritePatternSet &patterns,
2137-
const ControlFusionFn &controlElementwiseOpsFusion) {
2137+
const ControlFusionFn &controlElementwiseOpsFusion,
2138+
bool removeOutsDependency) {
21382139
auto *context = patterns.getContext();
21392140
patterns.add<FuseElementwiseOps>(context, controlElementwiseOpsFusion);
2140-
patterns.add<FoldFillWithGenericOp, FoldScalarOrSplatConstant,
2141-
RemoveOutsDependency>(context);
2141+
patterns.add<FoldFillWithGenericOp, FoldScalarOrSplatConstant>(context);
2142+
if (removeOutsDependency)
2143+
patterns.add<RemoveOutsDependency>(context);
21422144
// Add the patterns that clean up dead operands and results.
21432145
populateEraseUnusedOperandsAndResultsPatterns(patterns);
21442146
}
@@ -2180,7 +2182,8 @@ struct LinalgElementwiseOpFusionPass
21802182
};
21812183

21822184
// Add elementwise op fusion patterns.
2183-
populateElementwiseOpsFusionPatterns(patterns, defaultControlFn);
2185+
populateElementwiseOpsFusionPatterns(patterns, defaultControlFn,
2186+
removeOutsDependency);
21842187
populateFoldReshapeOpsByExpansionPatterns(patterns, defaultControlFn);
21852188
tensor::populateBubbleUpExpandShapePatterns(patterns);
21862189

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
// RUN: mlir-opt %s -p 'builtin.module(func.func(linalg-fuse-elementwise-ops{remove-outs-dependency=0}))' -split-input-file | FileCheck %s
2+
3+
#identity = affine_map<(d0) -> (d0)>
4+
5+
func.func @keep_outs_dependency(%arg: tensor<4xf32>) -> tensor<4xf32> {
6+
// CHECK-NOT: tensor.empty
7+
%1 = linalg.generic {indexing_maps = [#identity, #identity], iterator_types = ["parallel"] } ins(%arg: tensor<4xf32>) outs(%arg: tensor<4xf32>) {
8+
^bb0(%in: f32, %out: f32):
9+
%exp = arith.negf %in: f32
10+
linalg.yield %exp : f32
11+
} -> tensor<4xf32>
12+
return %1 : tensor<4xf32>
13+
}

0 commit comments

Comments
 (0)