Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changes needed to enable outlining by default #951

Merged
merged 4 commits into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree-amd-aie/IR/AMDAIEOps.h"
#include "iree-amd-aie/Transforms/AMDAIEUtils.h"
#include "iree-amd-aie/Transforms/Passes.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
Expand All @@ -20,128 +20,181 @@ namespace mlir::iree_compiler::AMDAIE {

namespace {

/// Utility to check if the linalg op is a known op we know we want to be able
/// to outline.
static bool mustOutline(linalg::LinalgOp linalgOp) {
return isMatmul(linalgOp) || isElementwise(linalgOp);
/// Returns true when `type` has a fully static shape and its strides are
/// exactly those of a densely packed row-major memref: the innermost stride is
/// 1 and every outer stride is the product of all inner dimension sizes. The
/// offset of `type` plays no role in the decision. Returns false if the
/// strides cannot be determined at all.
bool isContiguousMemRef(MemRefType type) {
  SmallVector<int64_t, 4> strides;
  int64_t unusedOffset;
  if (failed(getStridesAndOffset(type, strides, unusedOffset))) return false;

  // Walk the dimensions from innermost to outermost while tracking the stride
  // a dense memref would have at each of them.
  ArrayRef<int64_t> dims = type.getShape();
  int64_t denseStride = 1;
  for (int dim = dims.size(); dim-- > 0;) {
    const int64_t extent = dims[dim];
    if (extent == ShapedType::kDynamic || strides[dim] != denseStride)
      return false;
    denseStride *= extent;
  }
  return true;
}

/// Utility to check if the linalg op is a known op we know should not be
/// outlined.
static bool mustNotOutline(linalg::LinalgOp linalgOp) {
return isa<linalg::CopyOp, linalg::FillOp>(linalgOp);
/// Strips the layout from a contiguous memref type.
///
/// Returns a memref with the same shape, element type and memory space as
/// `type`, but built with a default-constructed layout attribute. Returns a
/// null `Type` when `type` is not a memref, or when it is a memref that
/// `isContiguousMemRef` rejects.
Type getTypeWithoutLayout(Type type) {
  if (auto memRef = dyn_cast<MemRefType>(type);
      memRef && isContiguousMemRef(memRef)) {
    return MemRefType::get(memRef.getShape(), memRef.getElementType(),
                           MemRefLayoutAttrInterface{},
                           memRef.getMemorySpace());
  }
  return Type();
}

/// Utility to outline the linalg compute op.
static FailureOr<func::FuncOp> outlinedToAFunction(
IRRewriter &rewriter, ModuleOp moduleOp, linalg::LinalgOp computeOp,
std::string outlineFuncName) {
if (auto outlinedFuncOp = dyn_cast_if_present<func::FuncOp>(
moduleOp.lookupSymbol(outlineFuncName))) {
return outlinedFuncOp;
}

static FailureOr<func::FuncOp> outline(IRRewriter &rewriter, ModuleOp moduleOp,
linalg::LinalgOp computeOp,
const std::string &funcName) {
// Form outlined FunctionType.
SmallVector<Type> inputTypes = llvm::map_to_vector(
computeOp.getDpsInputs(), [](Value v) { return v.getType(); });
for (Value val : computeOp.getDpsInits()) inputTypes.push_back(val.getType());
auto outlinedFuncType =
SmallVector<Type> inputTypes;
inputTypes.reserve(computeOp->getOperands().size());
for (const auto &operand : computeOp->getOperands()) {
Type withoutLayout = getTypeWithoutLayout(operand.getType());
// The op has an operand with a layout that isn't compatible with outlining.
// We currently don't support strided layouts.
if (!withoutLayout) return failure();
inputTypes.push_back(withoutLayout);
}
auto funcType =
FunctionType::get(rewriter.getContext(), inputTypes, /*outputTypes=*/{});

// Form outlined FuncSignature
// Form outlined FuncSignature.
rewriter.setInsertionPointToStart(moduleOp.getBody());
auto outlinedFunc = rewriter.create<func::FuncOp>(
moduleOp.getLoc(), outlineFuncName, outlinedFuncType);
outlinedFunc.setPrivate();
auto func =
rewriter.create<func::FuncOp>(moduleOp.getLoc(), funcName, funcType);
func.setPrivate();

// Create an entry func block and map the original operands of the compute
// op to the block arguments.
Block *outlinedFuncBody = outlinedFunc.addEntryBlock();
rewriter.setInsertionPointToStart(outlinedFuncBody);
SmallVector<BlockArgument> outlinedFuncArgs = llvm::map_to_vector(
outlinedFunc.getArguments(), [&](BlockArgument bbArg) { return bbArg; });
Block *funcBody = func.addEntryBlock();
rewriter.setInsertionPointToStart(funcBody);
SmallVector<BlockArgument> funcArgs = llvm::map_to_vector(
func.getArguments(), [&](BlockArgument bbArg) { return bbArg; });
unsigned bbArgIndex = 0;
IRMapping operandMap;
for (Value origOperand : computeOp.getDpsInputs())
operandMap.map(origOperand, outlinedFuncArgs[bbArgIndex++]);
for (Value origOperand : computeOp.getDpsInits())
operandMap.map(origOperand, outlinedFuncArgs[bbArgIndex++]);
for (Value origOperand : computeOp->getOperands())
operandMap.map(origOperand, funcArgs[bbArgIndex++]);

// Clone the compute op while mapping the operand to the function block
// arguments.
Operation *clonedComputeOp = rewriter.clone(*computeOp, operandMap);

// Create terminator op returning the cloned compute op's results.
rewriter.setInsertionPointToEnd(outlinedFuncBody);
rewriter.setInsertionPointToEnd(funcBody);
rewriter.create<func::ReturnOp>(clonedComputeOp->getLoc(), ValueRange({}));

return outlinedFunc;
return func;
}

/// Utility to check if the linalg op is one we know should not be outlined.
///
/// Returns true only for `linalg::CopyOp` and `linalg::FillOp`; every other
/// linalg op is considered a candidate for outlining.
///
/// TODO(newling) not all remaining ops should be outlined, not even all
/// remaining matmuls: below some threshold on size (m*n*k) it's not worth
/// outlining (function call overhead). We should extend the blacklist
/// here.
static bool mustNotOutline(linalg::LinalgOp linalgOp) {
  return isa<linalg::CopyOp, linalg::FillOp>(linalgOp);
}

class AMDAIELinalgFunctionOutliningPass
: public impl::AMDAIELinalgFunctionOutliningBase<
AMDAIELinalgFunctionOutliningPass> {
public:
AMDAIELinalgFunctionOutliningPass() = default;
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<AMDAIEDialect, linalg::LinalgDialect>();
registry.insert<linalg::LinalgDialect>();
}

void runOnOperation() override;
};

void AMDAIELinalgFunctionOutliningPass::runOnOperation() {
ModuleOp moduleOp = getOperation();
MLIRContext *context = &getContext();
IRRewriter rewriter(context);
private:
// Used for unique-ifing ID for generating new function names.
unsigned outlineCounter = 0;

unsigned uniqueOutlinedMatmul = 0;
unsigned uniqueOutlinedElementwise = 0;
DenseMap<Operation *, std::string> computeOpToOutlinedFuncMap;
SmallVector<Operation *> toBeErased;
moduleOp.walk([&](linalg::LinalgOp computeOp) {
if (mustNotOutline(computeOp)) {
return WalkResult::skip();
} else if (!mustOutline(computeOp)) {
computeOp->emitOpError() << "unsupported linalg op for outlining";
return WalkResult::interrupt();
}
// Form outlined function name for matmul/elementwise compute ops.
std::string outlineFuncName = "";
DenseMap<Operation *, func::FuncOp> computeOpToOutlinedFuncMap;

/// Returns the infix that `generateFuncName` places between the op name and
/// the counter of an outlined function's name:
///   - "_matmul_" for matmuls, e.g. `generic_matmul_0_outlined`,
///   - "_elementwise_" for elementwise ops, e.g.
///     `generic_elementwise_0_outlined`,
///   - "_" for everything else, e.g. `generic_0_outlined`.
static std::string getSpecializedName(linalg::LinalgOp computeOp) {
  const char *infix = "_";
  if (isMatmul(computeOp)) {
    infix = "_matmul_";
  } else if (isElementwise(computeOp)) {
    infix = "_elementwise_";
  }
  return infix;
}

/// Builds the next candidate name for an outlined function, of the form
/// `<op name without dialect><specialized infix><counter>_outlined`, and
/// increments `outlineCounter` so that the next call yields a different name.
std::string generateFuncName(linalg::LinalgOp computeOp) {
  std::string funcName = computeOp->getName().stripDialect().str();
  funcName += getSpecializedName(computeOp);
  funcName += std::to_string(outlineCounter++);
  funcName += "_outlined";
  return funcName;
}

FailureOr<func::FuncOp> retrieveOrCreate(IRRewriter &rewriter,
ModuleOp moduleOp,
linalg::LinalgOp computeOp) {
// Check if the compute op is equivalent to a previously outlined compute
// op. If yes, we replace the `outlineFuncName` of the current compute op to
// be same as the previous equivalent outlined compute op in order to lookup
// the Symbol table.
for (auto &[op, funcName] : computeOpToOutlinedFuncMap) {
// op. If it is, retrieve and return the function generated for the previous
// compute op.
for (auto &[op, func] : computeOpToOutlinedFuncMap) {
if (OperationEquivalence::isEquivalentTo(
computeOp.getOperation(), op,
OperationEquivalence::ignoreValueEquivalence, /*flags=*/nullptr,
OperationEquivalence::IgnoreLocations)) {
outlineFuncName = funcName;
break;
return func;
}
}
if (outlineFuncName == "") {
std::string computeName = "";
if (isMatmul(computeOp)) {
computeName = "_matmul_" + std::to_string(uniqueOutlinedMatmul++);
} else if (isElementwise(computeOp)) {
computeName =
"_elementwise_" + std::to_string(uniqueOutlinedElementwise++);
} else {
return WalkResult::skip();

std::string funcName = generateFuncName(computeOp);
while (moduleOp.lookupSymbol(funcName)) {
funcName = generateFuncName(computeOp);
}

FailureOr<func::FuncOp> maybeFunc =
outline(rewriter, moduleOp, computeOp, funcName);

if (succeeded(maybeFunc)) {
computeOpToOutlinedFuncMap[computeOp] = maybeFunc.value();
}

return maybeFunc;
}
};

void AMDAIELinalgFunctionOutliningPass::runOnOperation() {
ModuleOp moduleOp = getOperation();
MLIRContext *context = &getContext();
IRRewriter rewriter(context);

SmallVector<Operation *> toBeErased;
moduleOp.walk([&](linalg::LinalgOp computeOp) {
if (mustNotOutline(computeOp)) return WalkResult::skip();

FailureOr<func::FuncOp> maybeFunc =
retrieveOrCreate(rewriter, moduleOp, computeOp);
if (failed(maybeFunc)) return WalkResult::interrupt();
func::FuncOp func = maybeFunc.value();

// Create a call into the outlined function. The operands of the compute op
// might need to be cast to a different type to match the outlined function.
{
SmallVector<Value> castOperands;
castOperands.reserve(computeOp->getOperands().size());
rewriter.setInsertionPoint(computeOp);
Location loc = computeOp.getLoc();
for (auto iter : llvm::enumerate(computeOp->getOperands())) {
Type type = func.getArgumentTypes()[iter.index()];
Value cast =
rewriter.createOrFold<memref::CastOp>(loc, type, iter.value());
castOperands.push_back(cast);
}
outlineFuncName =
computeOp->getName().stripDialect().str() + computeName + "_outlined";
computeOpToOutlinedFuncMap[computeOp] = outlineFuncName;
rewriter.create<func::CallOp>(loc, func, castOperands);
}

FailureOr<func::FuncOp> outlinedFuncOp =
outlinedToAFunction(rewriter, moduleOp, computeOp, outlineFuncName);
if (failed(outlinedFuncOp)) return WalkResult::interrupt();
rewriter.setInsertionPoint(computeOp);
rewriter.create<func::CallOp>(computeOp.getLoc(), *outlinedFuncOp,
computeOp->getOperands());
// We cannot immediately erase the compute op because it'd be used for
// equivalence check.
toBeErased.push_back(computeOp);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,29 +181,96 @@ func.func @linalg_fill_copy(%A: memref<4xf32>, %B: memref<4xf32>) {

// -----

func.func @unsupported_linalg_op(%A: memref<4x8xbf16>, %B: memref<8x4xbf16>, %C: memref<4x4xf32>) {
// Test demonstrating the outlining of a linalg.generic operation other than
// a matmul or elementwise operation. Specifically, one which has not been
// 'blacklisted' like linalg.copy has (see test linalg_fill_copy above).
// CHECK: func.func private @generic_0_outlined
// CHECK-SAME: memref<4xbf16>,
// CHECK-SAME: memref<bf16>
// CHECK: linalg.generic
// CHECK-SAME: iterator_types = ["reduction"]
// CHECK: return
// CHECK: func.func @reduction
// CHECK-SAME: memref<4xbf16>
// CHECK-SAME: memref<bf16>
// CHECK: func.call @generic_0_outlined
// CHECK-SAME: (memref<4xbf16>, memref<bf16>) -> ()
// CHECK: return
func.func @reduction(%A: memref<4xbf16>, %B: memref<bf16>) {
%c2 = arith.constant 2 : index
%c1 = arith.constant 1 : index
%tile = amdaie.tile(%c1, %c2)
%tile = amdaie.tile(%c2, %c2)
%1 = amdaie.core(%tile, in : [], out : []) {
// expected-error@+1 {{unsupported linalg op for outlining}}
linalg.generic {
indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
affine_map<(d0, d1, d2) -> (d2, d1)>,
affine_map<(d0, d1, d2) -> (d0, d1)>
],
iterator_types = ["parallel", "parallel", "reduction"]
} ins(%A, %B : memref<4x8xbf16>, memref<8x4xbf16>)
outs(%C : memref<4x4xf32>) {
^bb0(%in: bf16, %in_17: bf16, %out: f32):
%1 = arith.extf %in : bf16 to f32
%2 = arith.extf %in_17 : bf16 to f32
%3 = arith.mulf %1, %2 : f32
%4 = arith.addf %out, %3 : f32
%5 = arith.addf %4, %4 : f32
linalg.yield %5 : f32
indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>],
iterator_types = ["reduction"]
} ins(%A: memref<4xbf16>) outs(%B : memref<bf16>) {
^bb0(%in: bf16, %out: bf16):
linalg.yield %in : bf16
}
amdaie.end
}
return
}


// -----

// Test demonstrating the outlining of a linalg.generic where one
// operand has an unknown offset. The memref is still contiguous, however.
// CHECK: func.func private @generic_0_outlined
// CHECK-SAME: memref<4x8xbf16>,
// CHECK-SAME: memref<bf16>
// CHECK: linalg.generic
// CHECK-SAME: iterator_types = ["reduction", "reduction"]
// CHECK: return
// CHECK: func.func @supported_linalg_op
// CHECK-SAME: memref<4x8xbf16, strided<[8, 1], offset: ?>>
// CHECK-SAME: memref<bf16>
// CHECK: %[[CAST:.*]] = memref.cast
// CHECK-SAME: memref<4x8xbf16, strided<[8, 1], offset: ?>>
// CHECK-SAME: to memref<4x8xbf16>
// CHECK: func.call @generic_0_outlined(%[[CAST]], %arg1) :
// CHECK-SAME: (memref<4x8xbf16>, memref<bf16>) -> ()
// CHECK: return
// Input IR: %A uses strides [8, 1], which are the row-major strides of a
// 4x8 memref; only its offset is dynamic.
func.func @supported_linalg_op(%A: memref<4x8xbf16, strided<[8,1], offset:?>>, %B: memref<bf16>) {
%c2 = arith.constant 2 : index
%tile = amdaie.tile(%c2, %c2)
// The linalg.generic under test lives inside an amdaie.core region.
%1 = amdaie.core(%tile, in : [], out : []) {
// Both iterators are reductions: the 2-d input %A is mapped onto the 0-d
// output %B, and the body simply yields the input element.
linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> ()>],
iterator_types = ["reduction", "reduction"]
} ins(%A: memref<4x8xbf16, strided<[8,1], offset:?>>) outs(%B : memref<bf16>) {
^bb0(%in: bf16, %out: bf16):
linalg.yield %in : bf16
}
amdaie.end
}
return
}


// -----

// Test illustrating that when a linalg.generic operation has an
// operand that is not contiguous, it is not outlined.

// CHECK-COUNT-1: func.func
// CHECK-NOT: func.func
// Input IR: %A has a row stride of 9 while each row holds only 8 elements,
// so consecutive rows are not adjacent in memory.
func.func @unsupported_linalg_op(%A: memref<4x8xbf16, strided<[9,1]>>, %B: memref<bf16>) {
%c2 = arith.constant 2 : index
%tile = amdaie.tile(%c2, %c2)
// The linalg.generic under test lives inside an amdaie.core region.
%1 = amdaie.core(%tile, in : [], out : []) {
// Same reduction as in the contiguous case above in this file: the 2-d input
// %A is mapped onto the 0-d output %B and the body yields the input element.
linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> ()>],
iterator_types = ["reduction", "reduction"]
} ins(%A: memref<4x8xbf16, strided<[9,1]>>) outs(%B : memref<bf16>) {
^bb0(%in: bf16, %out: bf16):
linalg.yield %in : bf16
}
amdaie.end
}
return
}

Loading