From 0ddb0ed9841dc3051cf8da876d326d06948590f6 Mon Sep 17 00:00:00 2001 From: James Newling Date: Tue, 3 Dec 2024 07:45:36 -0800 Subject: [PATCH 1/4] squash --- .../AMDAIELinalgFunctionOutlining.cpp | 181 ++++++++++++------ .../test/linalg_function_outlining.mlir | 85 +++++--- 2 files changed, 176 insertions(+), 90 deletions(-) diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIELinalgFunctionOutlining.cpp b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIELinalgFunctionOutlining.cpp index c7304bca9..384248756 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIELinalgFunctionOutlining.cpp +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIELinalgFunctionOutlining.cpp @@ -4,11 +4,11 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -#include "iree-amd-aie/IR/AMDAIEOps.h" #include "iree-amd-aie/Transforms/AMDAIEUtils.h" #include "iree-amd-aie/Transforms/Passes.h" #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" @@ -20,31 +20,49 @@ namespace mlir::iree_compiler::AMDAIE { namespace { -/// Utility to check if the linalg op is a known op we know we want to be able -/// to outline. -static bool mustOutline(linalg::LinalgOp linalgOp) { - return isMatmul(linalgOp) || isElementwise(linalgOp); +/// Return true if the strides of `memrefType` are contiguous. +bool isContiguousMemRef(MemRefType memrefType) { + ArrayRef shape = memrefType.getShape(); + SmallVector strides; + int64_t offset; + if (failed(getStridesAndOffset(memrefType, strides, offset))) return false; + int64_t expectedStride = 1; + for (int i = shape.size() - 1; i >= 0; --i) { + if (shape[i] == ShapedType::kDynamic) return false; + if (strides[i] != expectedStride) return false; + expectedStride *= shape[i]; + } + return true; } -/// Utility to check if the linalg op is a known op we know should not be -/// outlined. -static bool mustNotOutline(linalg::LinalgOp linalgOp) { - return isa(linalgOp); +/// If `type` is a contiguous memref, return an equivalent memref without any +/// layout attribute. Otherwise, return nullptr. +Type getIdentityLayoutType(Type type) { + auto memRefType = dyn_cast(type); + if (!memRefType) return {}; + if (!isContiguousMemRef(memRefType)) return {}; + return MemRefType::get(memRefType.getShape(), memRefType.getElementType(), + MemRefLayoutAttrInterface{}, + memRefType.getMemorySpace()); } /// Utility to outline the linalg compute op. -static FailureOr outlinedToAFunction( - IRRewriter &rewriter, ModuleOp moduleOp, linalg::LinalgOp computeOp, - std::string outlineFuncName) { - if (auto outlinedFuncOp = dyn_cast_if_present( - moduleOp.lookupSymbol(outlineFuncName))) { - return outlinedFuncOp; - } - +static FailureOr outline(IRRewriter &rewriter, ModuleOp moduleOp, + linalg::LinalgOp computeOp, + const std::string &outlineFuncName) { // Form outlined FunctionType. SmallVector inputTypes = llvm::map_to_vector( - computeOp.getDpsInputs(), [](Value v) { return v.getType(); }); - for (Value val : computeOp.getDpsInits()) inputTypes.push_back(val.getType()); + computeOp.getDpsInputs(), + [&](Value v) { return getIdentityLayoutType(v.getType()); }); + + for (Value val : computeOp.getDpsInits()) + inputTypes.push_back(getIdentityLayoutType(val.getType())); + + // If any of the input types is not set, return failure. + if (llvm::any_of(inputTypes, [](Type t) { return !t; })) + return computeOp.emitOpError( + "has inputs with types that aren't compatible with outlining"); + auto outlinedFuncType = FunctionType::get(rewriter.getContext(), inputTypes, /*outputTypes=*/{}); @@ -78,75 +96,114 @@ static FailureOr outlinedToAFunction( return outlinedFunc; } +/// Utility to check if the linalg op is one we know should not be outlined. +static bool mustNotOutline(linalg::LinalgOp linalgOp) { + return isa(linalgOp); + // TODO(newling) not all remaining ops should be outlined, not even all + // remaining matmuls: below some threshold on size (m*n*k) it's not worth + // outlining (function call overhead). +}; + class AMDAIELinalgFunctionOutliningPass : public impl::AMDAIELinalgFunctionOutliningBase< AMDAIELinalgFunctionOutliningPass> { public: AMDAIELinalgFunctionOutliningPass() = default; void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); + registry.insert(); } void runOnOperation() override; -}; -void AMDAIELinalgFunctionOutliningPass::runOnOperation() { - ModuleOp moduleOp = getOperation(); - MLIRContext *context = &getContext(); - IRRewriter rewriter(context); + private: + // Used for unique-ifing the outlined function names. + unsigned outlineCounter = 0; - unsigned uniqueOutlinedMatmul = 0; - unsigned uniqueOutlinedElementwise = 0; - DenseMap computeOpToOutlinedFuncMap; - SmallVector toBeErased; - moduleOp.walk([&](linalg::LinalgOp computeOp) { - if (mustNotOutline(computeOp)) { - return WalkResult::skip(); - } else if (!mustOutline(computeOp)) { - computeOp->emitOpError() << "unsupported linalg op for outlining"; - return WalkResult::interrupt(); - } - // Form outlined function name for matmul/elementwise compute ops. - std::string outlineFuncName = ""; + DenseMap computeOpToOutlinedFuncMap; + + static std::string getSpecializedName(linalg::LinalgOp computeOp) { + // Will result in a function name like `generic_matmul_2_outlined`: + if (isMatmul(computeOp)) return "_matmul_"; + // Will result in a function name like `generic_elementwise_2_outlined`: + if (isElementwise(computeOp)) return "_elementwise_"; + // Will result in a function name like `generic_2_outlined`: + return "_"; + } + + std::string generateFuncName(linalg::LinalgOp computeOp) { + std::string name = computeOp->getName().stripDialect().str() + + getSpecializedName(computeOp) + + std::to_string(outlineCounter) + "_outlined"; + ++outlineCounter; + return name; + } + + FailureOr retrieveOrCreate(IRRewriter &rewriter, + ModuleOp moduleOp, + linalg::LinalgOp computeOp) { // Check if the compute op is equivalent to a previously outlined compute - // op. If yes, we replace the `outlineFuncName` of the current compute op to - // be same as the previous equivalent outlined compute op in order to lookup - // the Symbol table. - for (auto &[op, funcName] : computeOpToOutlinedFuncMap) { + // op. If it is, retrieve and return the function generated for the previous + // compute op. + for (auto &[op, funcOp] : computeOpToOutlinedFuncMap) { if (OperationEquivalence::isEquivalentTo( computeOp.getOperation(), op, OperationEquivalence::ignoreValueEquivalence, /*flags=*/nullptr, OperationEquivalence::IgnoreLocations)) { - outlineFuncName = funcName; - break; + return funcOp; } } - if (outlineFuncName == "") { - std::string computeName = ""; - if (isMatmul(computeOp)) { - computeName = "_matmul_" + std::to_string(uniqueOutlinedMatmul++); - } else if (isElementwise(computeOp)) { - computeName = - "_elementwise_" + std::to_string(uniqueOutlinedElementwise++); - } else { - return WalkResult::skip(); + + std::string outlineFuncName = generateFuncName(computeOp); + while (moduleOp.lookupSymbol(outlineFuncName)) { + outlineFuncName = generateFuncName(computeOp); + } + + FailureOr maybeFuncOp = + outline(rewriter, moduleOp, computeOp, outlineFuncName); + + if (succeeded(maybeFuncOp)) + computeOpToOutlinedFuncMap[computeOp] = maybeFuncOp.value(); + + return maybeFuncOp; + } +}; + +void AMDAIELinalgFunctionOutliningPass::runOnOperation() { + ModuleOp moduleOp = getOperation(); + MLIRContext *context = &getContext(); + IRRewriter rewriter(context); + + SmallVector toBeErased; + WalkResult walkResult = moduleOp.walk([&](linalg::LinalgOp computeOp) { + if (mustNotOutline(computeOp)) return WalkResult::skip(); + + FailureOr maybeFuncOp = + retrieveOrCreate(rewriter, moduleOp, computeOp); + if (failed(maybeFuncOp)) return WalkResult::interrupt(); + func::FuncOp outlinedFuncOp = maybeFuncOp.value(); + + // Create a call into the outlined function. The operands of the compute op + // might need to be cast to a different type to match the outlined function. + { + SmallVector castOperands; + castOperands.reserve(computeOp->getOperands().size()); + rewriter.setInsertionPoint(computeOp); + Location loc = computeOp.getLoc(); + for (auto iter : llvm::enumerate(computeOp->getOperands())) { + Type type = outlinedFuncOp.getArgumentTypes()[iter.index()]; + Value cast = + rewriter.createOrFold(loc, type, iter.value()); + castOperands.push_back(cast); } - outlineFuncName = - computeOp->getName().stripDialect().str() + computeName + "_outlined"; - computeOpToOutlinedFuncMap[computeOp] = outlineFuncName; + rewriter.create(loc, outlinedFuncOp, castOperands); } - FailureOr outlinedFuncOp = - outlinedToAFunction(rewriter, moduleOp, computeOp, outlineFuncName); - if (failed(outlinedFuncOp)) return WalkResult::interrupt(); - rewriter.setInsertionPoint(computeOp); - rewriter.create(computeOp.getLoc(), *outlinedFuncOp, - computeOp->getOperands()); // We cannot immediately erase the compute op because it'd be used for // equivalence check. toBeErased.push_back(computeOp); return WalkResult::advance(); }); + if (walkResult.wasInterrupted()) return signalPassFailure(); for (Operation *op : toBeErased) { op->dropAllUses(); rewriter.eraseOp(op); diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/linalg_function_outlining.mlir b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/linalg_function_outlining.mlir index 7dbf256fe..f7968f3ed 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/linalg_function_outlining.mlir +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/linalg_function_outlining.mlir @@ -161,6 +161,63 @@ func.func @elemwise_example(%A: memref<4xf32>, %C: memref<4xbf16>, %B: memref<4x // ----- +// This is an example of a linalg.generic which is not a 'known' op, it is +// a pure reduction operation. This should be supported by the outlining pass. +// Note that the first operand A has got an unknown offset, which is +// supported by the outlining pass (the function signature drops the offset, +// this is necessary for LLVM lowering). +// CHECK: func.func private @generic_0_outlined(%arg0: memref<4x8xbf16>, %arg1: memref) { +// CHECK: linalg.generic +// CHECK-SAME: iterator_types = ["reduction", "reduction"] +// CHECK-SAME: ins(%arg0 : memref<4x8xbf16>) outs(%arg1 : memref) +// CHECK: return +// CHECK: func.func @supported_linalg_op(%arg0: memref<4x8xbf16, strided<[8, 1], offset: ?>>, %arg1: memref) { +// CHECK: %[[CAST:.*]] = memref.cast %arg0 : memref<4x8xbf16, strided<[8, 1], offset: ?>> to memref<4x8xbf16> +// CHECK: func.call @generic_0_outlined(%[[CAST]], %arg1) : (memref<4x8xbf16>, memref) -> () +// CHECK: return +func.func @supported_linalg_op(%A: memref<4x8xbf16, strided<[8,1], offset:?>>, %B: memref) { + %c2 = arith.constant 2 : index + %tile = amdaie.tile(%c2, %c2) + %1 = amdaie.core(%tile, in : [], out : []) { + linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> ()>], + iterator_types = ["reduction", "reduction"] + } ins(%A: memref<4x8xbf16, strided<[8,1], offset:?>>) outs(%B : memref) { + ^bb0(%in: bf16, %out: bf16): + linalg.yield %in : bf16 + } + amdaie.end + } + return +} + + +// ----- + +// This is an example of a linalg.generic which cannot be outlined because +// of the layout of one one of the operands. Strided layout is not supported +// in lowering to llvm (AFAIK). +func.func @unsupported_linalg_op(%A: memref<4x8xbf16, strided<[9,1], offset:?>>, %B: memref) { + %c2 = arith.constant 2 : index + %tile = amdaie.tile(%c2, %c2) + %1 = amdaie.core(%tile, in : [], out : []) { + // expected-error@+1 {{'linalg.generic' op has inputs with types that aren't compatible with outlining}} + linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> ()>], + iterator_types = ["reduction", "reduction"] + } ins(%A: memref<4x8xbf16, strided<[9,1], offset:?>>) outs(%B : memref) { + ^bb0(%in: bf16, %out: bf16): + linalg.yield %in : bf16 + } + amdaie.end + } + return +} + +// ----- + // CHECK-LABEL: @linalg_fill_copy func.func @linalg_fill_copy(%A: memref<4xf32>, %B: memref<4xf32>) { %c2 = arith.constant 2 : index @@ -179,31 +236,3 @@ func.func @linalg_fill_copy(%A: memref<4xf32>, %B: memref<4xf32>) { return } -// ----- - -func.func @unsupported_linalg_op(%A: memref<4x8xbf16>, %B: memref<8x4xbf16>, %C: memref<4x4xf32>) { - %c2 = arith.constant 2 : index - %c1 = arith.constant 1 : index - %tile = amdaie.tile(%c1, %c2) - %1 = amdaie.core(%tile, in : [], out : []) { - // expected-error@+1 {{unsupported linalg op for outlining}} - linalg.generic { - indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, - affine_map<(d0, d1, d2) -> (d2, d1)>, - affine_map<(d0, d1, d2) -> (d0, d1)> - ], - iterator_types = ["parallel", "parallel", "reduction"] - } ins(%A, %B : memref<4x8xbf16>, memref<8x4xbf16>) - outs(%C : memref<4x4xf32>) { - ^bb0(%in: bf16, %in_17: bf16, %out: f32): - %1 = arith.extf %in : bf16 to f32 - %2 = arith.extf %in_17 : bf16 to f32 - %3 = arith.mulf %1, %2 : f32 - %4 = arith.addf %out, %3 : f32 - %5 = arith.addf %4, %4 : f32 - linalg.yield %5 : f32 - } - amdaie.end - } - return -} From 01b646117089c7ab8b80a6d5998f7e545e1b421d Mon Sep 17 00:00:00 2001 From: James Newling Date: Tue, 3 Dec 2024 08:32:28 -0800 Subject: [PATCH 2/4] tweaks --- .../AMDAIELinalgFunctionOutlining.cpp | 106 +++++++++-------- .../test/linalg_function_outlining.mlir | 110 ++++++++++++------ 2 files changed, 125 insertions(+), 91 deletions(-) diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIELinalgFunctionOutlining.cpp b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIELinalgFunctionOutlining.cpp index 384248756..94b5b9898 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIELinalgFunctionOutlining.cpp +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIELinalgFunctionOutlining.cpp @@ -20,12 +20,12 @@ namespace mlir::iree_compiler::AMDAIE { namespace { -/// Return true if the strides of `memrefType` are contiguous. -bool isContiguousMemRef(MemRefType memrefType) { - ArrayRef shape = memrefType.getShape(); +/// Return true if the strides of `type` make it a contiguous memref. +bool isContiguousMemRef(MemRefType type) { + ArrayRef shape = type.getShape(); SmallVector strides; - int64_t offset; - if (failed(getStridesAndOffset(memrefType, strides, offset))) return false; + int64_t ignoredOffset; + if (failed(getStridesAndOffset(type, strides, ignoredOffset))) return false; int64_t expectedStride = 1; for (int i = shape.size() - 1; i >= 0; --i) { if (shape[i] == ShapedType::kDynamic) return false; @@ -37,7 +37,7 @@ bool isContiguousMemRef(MemRefType memrefType) { /// If `type` is a contiguous memref, return an equivalent memref without any /// layout attribute. Otherwise, return nullptr. -Type getIdentityLayoutType(Type type) { +Type getTypeWithoutLayout(Type type) { auto memRefType = dyn_cast(type); if (!memRefType) return {}; if (!isContiguousMemRef(memRefType)) return {}; @@ -49,51 +49,47 @@ Type getIdentityLayoutType(Type type) { /// Utility to outline the linalg compute op. static FailureOr outline(IRRewriter &rewriter, ModuleOp moduleOp, linalg::LinalgOp computeOp, - const std::string &outlineFuncName) { - // Form outlined FunctionType. - SmallVector inputTypes = llvm::map_to_vector( - computeOp.getDpsInputs(), - [&](Value v) { return getIdentityLayoutType(v.getType()); }); - - for (Value val : computeOp.getDpsInits()) - inputTypes.push_back(getIdentityLayoutType(val.getType())); - - // If any of the input types is not set, return failure. - if (llvm::any_of(inputTypes, [](Type t) { return !t; })) - return computeOp.emitOpError( - "has inputs with types that aren't compatible with outlining"); - - auto outlinedFuncType = + const std::string &funcName) { + // // Form outlined FunctionType. + SmallVector inputTypes; + inputTypes.reserve(computeOp->getOperands().size()); + for (const auto &operand : computeOp->getOperands()) { + Type withoutLayout = getTypeWithoutLayout(operand.getType()); + if (!withoutLayout) { + return computeOp.emitOpError("has an operand of type ") + << operand.getType() << " that isn't compatible with outlining."; + } + inputTypes.push_back(withoutLayout); + } + auto funcType = FunctionType::get(rewriter.getContext(), inputTypes, /*outputTypes=*/{}); // Form outlined FuncSignature rewriter.setInsertionPointToStart(moduleOp.getBody()); - auto outlinedFunc = rewriter.create( - moduleOp.getLoc(), outlineFuncName, outlinedFuncType); - outlinedFunc.setPrivate(); + auto func = + rewriter.create(moduleOp.getLoc(), funcName, funcType); + func.setPrivate(); // Create an entry func block and map the original operands of the compute // op to the block arguments. - Block *outlinedFuncBody = outlinedFunc.addEntryBlock(); - rewriter.setInsertionPointToStart(outlinedFuncBody); - SmallVector outlinedFuncArgs = llvm::map_to_vector( - outlinedFunc.getArguments(), [&](BlockArgument bbArg) { return bbArg; }); + Block *funcBody = func.addEntryBlock(); + rewriter.setInsertionPointToStart(funcBody); + SmallVector funcArgs = llvm::map_to_vector( + func.getArguments(), [&](BlockArgument bbArg) { return bbArg; }); unsigned bbArgIndex = 0; IRMapping operandMap; - for (Value origOperand : computeOp.getDpsInputs()) - operandMap.map(origOperand, outlinedFuncArgs[bbArgIndex++]); - for (Value origOperand : computeOp.getDpsInits()) - operandMap.map(origOperand, outlinedFuncArgs[bbArgIndex++]); + for (Value origOperand : computeOp->getOperands()) + operandMap.map(origOperand, funcArgs[bbArgIndex++]); // Clone the compute op while mapping the operand to the function block // arguments. Operation *clonedComputeOp = rewriter.clone(*computeOp, operandMap); // Create terminator op returning the cloned compute op's results. - rewriter.setInsertionPointToEnd(outlinedFuncBody); + rewriter.setInsertionPointToEnd(funcBody); rewriter.create(clonedComputeOp->getLoc(), ValueRange({})); - return outlinedFunc; + return func; } /// Utility to check if the linalg op is one we know should not be outlined. @@ -101,7 +97,8 @@ static bool mustNotOutline(linalg::LinalgOp linalgOp) { return isa(linalgOp); // TODO(newling) not all remaining ops should be outlined, not even all // remaining matmuls: below some threshold on size (m*n*k) it's not worth - // outlining (function call overhead). + // outlining (function call overhead). We should extend the blacklist + // here. }; class AMDAIELinalgFunctionOutliningPass @@ -116,17 +113,17 @@ class AMDAIELinalgFunctionOutliningPass void runOnOperation() override; private: - // Used for unique-ifing the outlined function names. + // Used for unique-ifing ID for generating new function names. unsigned outlineCounter = 0; DenseMap computeOpToOutlinedFuncMap; static std::string getSpecializedName(linalg::LinalgOp computeOp) { - // Will result in a function name like `generic_matmul_2_outlined`: + // Will result in a function name like `generic_matmul_0_outlined`: if (isMatmul(computeOp)) return "_matmul_"; - // Will result in a function name like `generic_elementwise_2_outlined`: + // Will result in a function name like `generic_elementwise_0_outlined`: if (isElementwise(computeOp)) return "_elementwise_"; - // Will result in a function name like `generic_2_outlined`: + // Will result in a function name like `generic_0_outlined`: return "_"; } @@ -144,27 +141,28 @@ class AMDAIELinalgFunctionOutliningPass // Check if the compute op is equivalent to a previously outlined compute // op. If it is, retrieve and return the function generated for the previous // compute op. - for (auto &[op, funcOp] : computeOpToOutlinedFuncMap) { + for (auto &[op, func] : computeOpToOutlinedFuncMap) { if (OperationEquivalence::isEquivalentTo( computeOp.getOperation(), op, OperationEquivalence::ignoreValueEquivalence, /*flags=*/nullptr, OperationEquivalence::IgnoreLocations)) { - return funcOp; + return func; } } - std::string outlineFuncName = generateFuncName(computeOp); - while (moduleOp.lookupSymbol(outlineFuncName)) { - outlineFuncName = generateFuncName(computeOp); + std::string funcName = generateFuncName(computeOp); + while (moduleOp.lookupSymbol(funcName)) { + funcName = generateFuncName(computeOp); } - FailureOr maybeFuncOp = - outline(rewriter, moduleOp, computeOp, outlineFuncName); + FailureOr maybeFunc = + outline(rewriter, moduleOp, computeOp, funcName); - if (succeeded(maybeFuncOp)) - computeOpToOutlinedFuncMap[computeOp] = maybeFuncOp.value(); + if (succeeded(maybeFunc)) { + computeOpToOutlinedFuncMap[computeOp] = maybeFunc.value(); + } - return maybeFuncOp; + return maybeFunc; } }; @@ -177,10 +175,10 @@ void AMDAIELinalgFunctionOutliningPass::runOnOperation() { WalkResult walkResult = moduleOp.walk([&](linalg::LinalgOp computeOp) { if (mustNotOutline(computeOp)) return WalkResult::skip(); - FailureOr maybeFuncOp = + FailureOr maybeFunc = retrieveOrCreate(rewriter, moduleOp, computeOp); - if (failed(maybeFuncOp)) return WalkResult::interrupt(); - func::FuncOp outlinedFuncOp = maybeFuncOp.value(); + if (failed(maybeFunc)) return WalkResult::interrupt(); + func::FuncOp func = maybeFunc.value(); // Create a call into the outlined function. The operands of the compute op // might need to be cast to a different type to match the outlined function. @@ -190,12 +188,12 @@ void AMDAIELinalgFunctionOutliningPass::runOnOperation() { rewriter.setInsertionPoint(computeOp); Location loc = computeOp.getLoc(); for (auto iter : llvm::enumerate(computeOp->getOperands())) { - Type type = outlinedFuncOp.getArgumentTypes()[iter.index()]; + Type type = func.getArgumentTypes()[iter.index()]; Value cast = rewriter.createOrFold(loc, type, iter.value()); castOperands.push_back(cast); } - rewriter.create(loc, outlinedFuncOp, castOperands); + rewriter.create(loc, func, castOperands); } // We cannot immediately erase the compute op because it'd be used for diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/linalg_function_outlining.mlir b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/linalg_function_outlining.mlir index f7968f3ed..73f3207f7 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/linalg_function_outlining.mlir +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/linalg_function_outlining.mlir @@ -161,29 +161,49 @@ func.func @elemwise_example(%A: memref<4xf32>, %C: memref<4xbf16>, %B: memref<4x // ----- -// This is an example of a linalg.generic which is not a 'known' op, it is -// a pure reduction operation. This should be supported by the outlining pass. -// Note that the first operand A has got an unknown offset, which is -// supported by the outlining pass (the function signature drops the offset, -// this is necessary for LLVM lowering). -// CHECK: func.func private @generic_0_outlined(%arg0: memref<4x8xbf16>, %arg1: memref) { -// CHECK: linalg.generic -// CHECK-SAME: iterator_types = ["reduction", "reduction"] -// CHECK-SAME: ins(%arg0 : memref<4x8xbf16>) outs(%arg1 : memref) -// CHECK: return -// CHECK: func.func @supported_linalg_op(%arg0: memref<4x8xbf16, strided<[8, 1], offset: ?>>, %arg1: memref) { -// CHECK: %[[CAST:.*]] = memref.cast %arg0 : memref<4x8xbf16, strided<[8, 1], offset: ?>> to memref<4x8xbf16> -// CHECK: func.call @generic_0_outlined(%[[CAST]], %arg1) : (memref<4x8xbf16>, memref) -> () -// CHECK: return -func.func @supported_linalg_op(%A: memref<4x8xbf16, strided<[8,1], offset:?>>, %B: memref) { +// CHECK-LABEL: @linalg_fill_copy +func.func @linalg_fill_copy(%A: memref<4xf32>, %B: memref<4xf32>) { + %c2 = arith.constant 2 : index + %c1 = arith.constant 1 : index + %cst = arith.constant 0.0 : f32 + %tile = amdaie.tile(%c1, %c2) + %0 = amdaie.core(%tile, in : [], out : []) { + // CHECK: linalg.fill + // CHECK-NOT: func.call @fill_elementwise_0_outlined + // CHECK: linalg.copy + // CHECK-NOT: func.call @copy_elementwise_1_outlined + linalg.fill ins(%cst : f32) outs(%A : memref<4xf32>) + linalg.copy ins(%A : memref<4xf32>) outs(%B : memref<4xf32>) + amdaie.end + } + return +} + +// ----- + +// Test demonstrating the outlining of a linalg.generic operation other than +// a matmul or elementwise operation. Specifically, one which has not been +// 'blacklisted' like linalg.copy has (see test linalg_fill_copy above). +// CHECK: func.func private @generic_0_outlined +// CHECK-SAME: memref<4xbf16>, +// CHECK-SAME: memref +// CHECK: linalg.generic +// CHECK-SAME: iterator_types = ["reduction"] +// CHECK: return +// CHECK: func.func @reduction +// CHECK-SAME: memref<4xbf16> +// CHECK-SAME: memref +// CHECK: func.call @generic_0_outlined +// CHECK-SAME: (memref<4xbf16>, memref) -> () +// CHECK: return +func.func @reduction(%A: memref<4xbf16>, %B: memref) { %c2 = arith.constant 2 : index %tile = amdaie.tile(%c2, %c2) %1 = amdaie.core(%tile, in : [], out : []) { linalg.generic { - indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, - affine_map<(d0, d1) -> ()>], - iterator_types = ["reduction", "reduction"] - } ins(%A: memref<4x8xbf16, strided<[8,1], offset:?>>) outs(%B : memref) { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>], + iterator_types = ["reduction"] + } ins(%A: memref<4xbf16>) outs(%B : memref) { ^bb0(%in: bf16, %out: bf16): linalg.yield %in : bf16 } @@ -195,19 +215,32 @@ func.func @supported_linalg_op(%A: memref<4x8xbf16, strided<[8,1], offset:?>>, % // ----- -// This is an example of a linalg.generic which cannot be outlined because -// of the layout of one one of the operands. Strided layout is not supported -// in lowering to llvm (AFAIK). -func.func @unsupported_linalg_op(%A: memref<4x8xbf16, strided<[9,1], offset:?>>, %B: memref) { +// Test demonstrating the outlining of a linalg.generic where one +// operand has an unkown offset. The memref is still contiguous, however. +// CHECK: func.func private @generic_0_outlined +// CHECK-SAME: memref<4x8xbf16>, +// CHECK-SAME: memref +// CHECK: linalg.generic +// CHECK-SAME: iterator_types = ["reduction", "reduction"] +// CHECK: return +// CHECK: func.func @supported_linalg_op +// CHECK-SAME: memref<4x8xbf16, strided<[8, 1], offset: ?>> +// CHECK-SAME: memref +// CHECK: %[[CAST:.*]] = memref.cast +// CHECK-SAME: memref<4x8xbf16, strided<[8, 1], offset: ?>> +// CHECK-SAME: to memref<4x8xbf16> +// CHECK: func.call @generic_0_outlined(%[[CAST]], %arg1) : +// CHECK-SAME: (memref<4x8xbf16>, memref) -> () +// CHECK: return +func.func @supported_linalg_op(%A: memref<4x8xbf16, strided<[8,1], offset:?>>, %B: memref) { %c2 = arith.constant 2 : index %tile = amdaie.tile(%c2, %c2) %1 = amdaie.core(%tile, in : [], out : []) { - // expected-error@+1 {{'linalg.generic' op has inputs with types that aren't compatible with outlining}} linalg.generic { indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> ()>], iterator_types = ["reduction", "reduction"] - } ins(%A: memref<4x8xbf16, strided<[9,1], offset:?>>) outs(%B : memref) { + } ins(%A: memref<4x8xbf16, strided<[8,1], offset:?>>) outs(%B : memref) { ^bb0(%in: bf16, %out: bf16): linalg.yield %in : bf16 } @@ -216,21 +249,24 @@ func.func @unsupported_linalg_op(%A: memref<4x8xbf16, strided<[9,1], offset:?>>, return } + // ----- -// CHECK-LABEL: @linalg_fill_copy -func.func @linalg_fill_copy(%A: memref<4xf32>, %B: memref<4xf32>) { +// Test illustrating the error message when a linalg.generic operation has an +// operand that is not contiguous. This is currently unsupported. +func.func @unsupported_linalg_op(%A: memref<4x8xbf16, strided<[9,1]>>, %B: memref) { %c2 = arith.constant 2 : index - %c1 = arith.constant 1 : index - %cst = arith.constant 0.0 : f32 - %tile = amdaie.tile(%c1, %c2) - %0 = amdaie.core(%tile, in : [], out : []) { - // CHECK: linalg.fill - // CHECK-NOT: func.call @fill_elementwise_0_outlined - // CHECK: linalg.copy - // CHECK-NOT: func.call @copy_elementwise_1_outlined - linalg.fill ins(%cst : f32) outs(%A : memref<4xf32>) - linalg.copy ins(%A : memref<4xf32>) outs(%B : memref<4xf32>) + %tile = amdaie.tile(%c2, %c2) + %1 = amdaie.core(%tile, in : [], out : []) { + // expected-error@+1 {{'linalg.generic' op has an operand of type 'memref<4x8xbf16, strided<[9, 1]>>' that isn't compatible with outlining.}} + linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> ()>], + iterator_types = ["reduction", "reduction"] + } ins(%A: memref<4x8xbf16, strided<[9,1]>>) outs(%B : memref) { + ^bb0(%in: bf16, %out: bf16): + linalg.yield %in : bf16 + } amdaie.end } return From 5548e5b6605cf25949d4ed93037798ef13eb90c3 Mon Sep 17 00:00:00 2001 From: James Newling Date: Tue, 3 Dec 2024 09:39:02 -0800 Subject: [PATCH 3/4] fix typo --- .../iree-amd-aie/Transforms/AMDAIELinalgFunctionOutlining.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIELinalgFunctionOutlining.cpp b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIELinalgFunctionOutlining.cpp index 94b5b9898..179827677 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIELinalgFunctionOutlining.cpp +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIELinalgFunctionOutlining.cpp @@ -50,7 +50,7 @@ Type getTypeWithoutLayout(Type type) { static FailureOr outline(IRRewriter &rewriter, ModuleOp moduleOp, linalg::LinalgOp computeOp, const std::string &funcName) { - // // Form outlined FunctionType. + // Form outlined FunctionType. SmallVector inputTypes; inputTypes.reserve(computeOp->getOperands().size()); for (const auto &operand : computeOp->getOperands()) { @@ -64,7 +64,7 @@ static FailureOr outline(IRRewriter &rewriter, ModuleOp moduleOp, auto funcType = FunctionType::get(rewriter.getContext(), inputTypes, /*outputTypes=*/{}); - // Form outlined FuncSignature + // Form outlined FuncSignature. rewriter.setInsertionPointToStart(moduleOp.getBody()); auto func = rewriter.create(moduleOp.getLoc(), funcName, funcType); From db72a6c580d160c6f2a2e0c83a398202a0580715 Mon Sep 17 00:00:00 2001 From: James Newling Date: Tue, 3 Dec 2024 21:53:41 -0800 Subject: [PATCH 4/4] do not error if strides --- .../Transforms/AMDAIELinalgFunctionOutlining.cpp | 10 ++++------ .../Transforms/test/linalg_function_outlining.mlir | 8 +++++--- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIELinalgFunctionOutlining.cpp b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIELinalgFunctionOutlining.cpp index 179827677..68d79ec13 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIELinalgFunctionOutlining.cpp +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIELinalgFunctionOutlining.cpp @@ -55,10 +55,9 @@ static FailureOr outline(IRRewriter &rewriter, ModuleOp moduleOp, inputTypes.reserve(computeOp->getOperands().size()); for (const auto &operand : computeOp->getOperands()) { Type withoutLayout = getTypeWithoutLayout(operand.getType()); - if (!withoutLayout) { - return computeOp.emitOpError("has an operand of type ") - << operand.getType() << " that isn't compatible with outlining."; - } + // The op has an operand with a layout that isn't compatible with outlining. + // We currently don't support strided layouts. + if (!withoutLayout) return failure(); inputTypes.push_back(withoutLayout); } auto funcType = @@ -172,7 +171,7 @@ void AMDAIELinalgFunctionOutliningPass::runOnOperation() { IRRewriter rewriter(context); SmallVector toBeErased; - WalkResult walkResult = moduleOp.walk([&](linalg::LinalgOp computeOp) { + moduleOp.walk([&](linalg::LinalgOp computeOp) { if (mustNotOutline(computeOp)) return WalkResult::skip(); FailureOr maybeFunc = @@ -201,7 +200,6 @@ void AMDAIELinalgFunctionOutliningPass::runOnOperation() { toBeErased.push_back(computeOp); return WalkResult::advance(); }); - if (walkResult.wasInterrupted()) return signalPassFailure(); for (Operation *op : toBeErased) { op->dropAllUses(); rewriter.eraseOp(op); diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/linalg_function_outlining.mlir b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/linalg_function_outlining.mlir index 73f3207f7..1d953d67e 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/linalg_function_outlining.mlir +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/linalg_function_outlining.mlir @@ -252,13 +252,15 @@ func.func @supported_linalg_op(%A: memref<4x8xbf16, strided<[8,1], offset:?>>, % // ----- -// Test illustrating the error message when a linalg.generic operation has an -// operand that is not contiguous. This is currently unsupported. +// Test illustrating the that when a linalg.generic operation has an +// operand that is not contiguous, it is not outlined. + +// CHECK-COUNT-1: func.func +// CHECK-NOT: func.func func.func @unsupported_linalg_op(%A: memref<4x8xbf16, strided<[9,1]>>, %B: memref) { %c2 = arith.constant 2 : index %tile = amdaie.tile(%c2, %c2) %1 = amdaie.core(%tile, in : [], out : []) { - // expected-error@+1 {{'linalg.generic' op has an operand of type 'memref<4x8xbf16, strided<[9, 1]>>' that isn't compatible with outlining.}} linalg.generic { indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> ()>],