Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AMDAIEFuseFillIntoForall] Handle case where fill output is not sliced #976

Merged
merged 6 commits into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "iree-amd-aie/Transforms/Passes.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Pass/Pass.h"

#define DEBUG_TYPE "iree-amdaie-fuse-fill-into-forall"
Expand All @@ -29,60 +30,87 @@ class AMDAIEFuseFillIntoForallPass

void AMDAIEFuseFillIntoForallPass::runOnOperation() {
MLIRContext *context = &getContext();
mlir::FunctionOpInterface funcOp = getOperation();
IRRewriter rewriter(context);

// Find the producer op, in this case is linalg.fill.
TilingInterface tileableProducer;
funcOp->walk([&](TilingInterface op) {
if (isa<linalg::FillOp>(op)) {
tileableProducer = op;
return WalkResult::interrupt();
}
return WalkResult::advance();
});

if (!tileableProducer) {
LLVM_DEBUG(llvm::dbgs() << "There is no producer op to be fused.\n");
// Find a unique FillOp with a single output, or return.
SmallVector<linalg::FillOp> fillOps;
getOperation()->walk(
[&](linalg::FillOp fillOp) { fillOps.push_back(fillOp); });
if (fillOps.size() != 1) {
LLVM_DEBUG(llvm::dbgs() << "Expected exactly 1 fill op, but found "
<< fillOps.size() << ".\n");
return;
}

// Search the first use by a scf::ForallOp user.
scf::ForallOp forallOp;
auto itProducerUses =
llvm::find_if(tileableProducer->getUses(), [&](OpOperand &use) {
forallOp = dyn_cast<scf::ForallOp>(use.getOwner());
return forallOp;
});
linalg::FillOp fillOp = fillOps[0];
if (fillOp.getResults().size() != 1) {
LLVM_DEBUG(llvm::dbgs() << "Expected fill op to have exactly 1 result, but "
<< "found " << fillOp.getResults().size() << ".\n");

return;
};

// Confirm that there is a unique user that is a forall, and match
// the block argument tied to the fill op's single use, or return.
if (!fillOp->hasOneUse()) {
LLVM_DEBUG(llvm::dbgs()
<< "Expected exactly 1 use of fill op, but found 0 or 2+.");
return;
}
OpOperand &fillUse = *fillOp->getUses().begin();
auto forallOp = dyn_cast<scf::ForallOp>(fillUse.getOwner());
if (!forallOp) {
LLVM_DEBUG(llvm::dbgs() << "There is no forall Op.\n");
LLVM_DEBUG(llvm::dbgs() << "Expected fill op to be used by a forall op, "
<< "but unique user is "
<< fillUse.getOwner()->getName() << ".\n");
return;
}

// Search the producer slices accessed within the Forall op.
OpOperand *pUse = &(*itProducerUses);
BlockArgument bbArg = forallOp.getTiedBlockArgument(pUse);

auto itBBArgUsers = llvm::find_if(bbArg.getUsers(), [&](Operation *user) {
auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
return sliceOp;
});
if (itBBArgUsers == bbArg.getUsers().end()) {
funcOp->emitOpError("There is no extract tensor slice.");
return signalPassFailure();
BlockArgument bbArg = forallOp.getTiedBlockArgument(&fillUse);

// Find 0 or 1 ExtractSliceOps that use the block argument tied to the fill
// result, or return.
tensor::ExtractSliceOp extractSliceOp;
for (Operation *user : bbArg.getUsers()) {
if (auto nxt = dyn_cast<tensor::ExtractSliceOp>(user)) {
if (extractSliceOp) {
LLVM_DEBUG(llvm::dbgs()
<< "Expected at most 1 extract_slice op, but found 2+.\n");
return;
}
extractSliceOp = nxt;
}
}
auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*itBBArgUsers);

LoopLikeOpInterface loops =
cast<LoopLikeOpInterface>(forallOp.getOperation());

// Materialize the slice of the producer in place.
std::optional<scf::SCFFuseProducerOfSliceResult> fusedProducer =
scf::tileAndFuseProducerOfSlice(rewriter, sliceOpToTile,
MutableArrayRef(&loops, 1));
if (!fusedProducer) {
funcOp->emitOpError("Failed to fuse fill op into forall loop.");
return signalPassFailure();

if (extractSliceOp) {
LoopLikeOpInterface loops =
cast<LoopLikeOpInterface>(forallOp.getOperation());

// Materialize the slice of the producer in place.
std::optional<scf::SCFFuseProducerOfSliceResult> fusedProducer =
scf::tileAndFuseProducerOfSlice(rewriter, extractSliceOp,
MutableArrayRef(&loops, 1));
if (!fusedProducer) {
fillOp->emitOpError("could not be fused into forall");
return signalPassFailure();
}
} else {
// In the case where there are no extract_slice ops, we manually create the
// fill at the beginning of the forall body. This situation might arise
// if the extract_slice has been folded, for example if the forall is
// over a grid of size 1.
rewriter.setInsertionPointToStart(forallOp.getBody());
auto fusedFill =
rewriter.create<linalg::FillOp>(fillOp.getLoc(), fillOp.value(), bbArg);
rewriter.replaceUsesWithIf(
bbArg, fusedFill.getResult(0), [&](OpOperand &operand) {
Operation *owner = operand.getOwner();
if (owner == fusedFill || isa<tensor::ParallelInsertSliceOp>(owner)) {
return false;
}
return true;
});

// Do not use the result of the old fill.
rewriter.replaceAllUsesWith(fillOp.getResults()[0], fillOp.getOutputs()[0]);
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: iree-opt --pass-pipeline='builtin.module(func.func(iree-amdaie-fuse-fill-into-forall))' %s | FileCheck %s
// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(func.func(iree-amdaie-fuse-fill-into-forall))' %s | FileCheck %s

#map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
#map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d5, d4)>
Expand Down Expand Up @@ -29,3 +29,38 @@ func.func @fuse_fill_into_forall(%arg0: tensor<1x4x16x64xi8>, %arg1 : tensor<4x1
// CHECK: linalg.fill
// CHECK: linalg.generic
// CHECK: }

// -----

#map = affine_map<(d0) -> (d0)>
func.func @fuse_without_slice(%arg0: tensor<8xi8>) -> tensor<8xi8> {
%c7_i8 = arith.constant 7 : i8
%c3_i8 = arith.constant 3 : i8
%0 = linalg.fill ins(%c7_i8 : i8) outs(%arg0 : tensor<8xi8>) -> tensor<8xi8>
%1 = tensor.empty() : tensor<8xi8>
%2 = scf.forall (%arg1) in (1) shared_outs(%arg2 = %0) -> (tensor<8xi8>) {
%3 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg0 : tensor<8xi8>) outs(%1 : tensor<8xi8>) {
^bb0(%in: i8, %out: i8):
%4 = arith.addi %in, %c3_i8 : i8
linalg.yield %4 : i8
} -> tensor<8xi8>
scf.forall.in_parallel {
tensor.parallel_insert_slice %3 into %arg2[0] [8] [1] : tensor<8xi8> into tensor<8xi8>
}
} {mapping = [#gpu.thread<y>]}
return %2 : tensor<8xi8>
}

// CHECK: @fuse_without_slice(%[[FUNCARG:.*]]: tensor<8xi8>) -> tensor<8xi8> {
// check that the operand of scf.forall is not the filled tensor, because the
// fill will take place inside the scf.forall:
// CHECK: %[[FORALL:.*]] = scf.forall (%[[ARG1:.*]]) in (1)
// CHECK-SAME: shared_outs(%[[ARG2:.*]] = %[[FUNCARG]])
// check for the new fill:
// CHECK: %[[NEWFILL:.*]] = linalg.fill
// CHECK-SAME: outs(%[[ARG2]] : tensor<8xi8>) -> tensor<8xi8>
// CHECK: linalg.generic
// check that the parallel_insert_slice still happens on arg2, not the filled
// tensor. This is because it must match the shared_outs of the scf.forall:
// CHECK: tensor.parallel_insert_slice
// CHECK-SAME: into %[[ARG2]]
Loading