Skip to content

Commit

Permalink
update aie test
Browse files Browse the repository at this point in the history
  • Loading branch information
newling committed Feb 19, 2025
1 parent d2ccc89 commit 314497f
Show file tree
Hide file tree
Showing 7 changed files with 152 additions and 97 deletions.
22 changes: 21 additions & 1 deletion compiler/plugins/target/AMD-AIE/aie/AMDAIECoreToStandard.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include "AIEDialect.h"
#include "Passes.h"
#include "iree-amd-aie/Transforms/Utils/AMDAIEUtils.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
Expand Down Expand Up @@ -192,6 +193,22 @@ struct AMDAIECoreToStandardPass
return true;
}

// Ensure that all aie.core ops are isolated from above, i.e. that all
// operands of ops within an aie.core are produced inside the aie.core (or are
// block arguments of the core). The expection is ops in the aie dialect --
// operands produced by for example an aie.buffer may be outside the core.
static void isolateCores(ModuleOp m) {
IRRewriter rewriter(m->getContext());
auto notAieDialect = [](Operation *op) -> bool {
StringRef dialect = op->getDialect()->getNamespace();
if (dialect == AIEDialect::getDialectNamespace()) return false;
return true;
};
m->walk([&](CoreOp coreOp) {
sinkInto(coreOp.getRegion(), rewriter, notAieDialect);
});
}

void runOnOperation() override {
ModuleOp m = getOperation();

Expand Down Expand Up @@ -222,8 +239,11 @@ struct AMDAIECoreToStandardPass
m->setAttr(LLVM::LLVMDialect::getTargetTripleAttrName(),
rewriter.getStringAttr(targetArchStr));

if (failed(lockToStd(rewriter, m, targetArchStr)))
isolateCores(m);

if (failed(lockToStd(rewriter, m, targetArchStr))) {
return signalPassFailure();
}

m.walk([&](BufferOp buffer) { bufferToStd(m, buffer, rewriter); });

Expand Down
2 changes: 2 additions & 0 deletions compiler/plugins/target/AMD-AIE/aie/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ iree_cc_library(
MLIRMemRefDialect
MLIRIR
MLIREmitCDialect
iree::compiler::Dialect::HAL::IR
iree::compiler::Dialect::HAL::IR::HALDialect
)

###############################################################################
Expand Down
62 changes: 43 additions & 19 deletions compiler/plugins/target/AMD-AIE/aie/test/lower_buffer.mlir
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
// RUN: iree-opt --amdaie-standard-lowering %s | FileCheck %s
// RUN: iree-opt --amdaie-standard-lowering --split-input-file %s | FileCheck %s

// CHECK: memref.global "public" @a : memref<4xi32>
// CHECK-LABEL: func.func @core_4_3() {
// CHECK-LABEL: @basic_test
// CHECK-DAG: memref.global "public" @a : memref<4xi32>
// CHECK-DAG: memref.global "public" @b : memref<4xi32>
// CHECK: func.func @core_3_4() {
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_0:.*]] = memref.get_global @a : memref<4xi32>
// CHECK: %[[VAL_0:.*]] = memref.get_global @b : memref<4xi32>
// CHECK: memref.assume_alignment %[[VAL_0]], 32 : memref<4xi32>
// CHECK: %[[VAL_1:.*]] = memref.load %[[VAL_0]]{{\[}}%[[C0]]] : memref<4xi32>
// CHECK: return
Expand All @@ -18,22 +20,44 @@
// CHECK: return
// CHECK: }

module @codegen1 {
aie.device(xcvc1902) {
%t33 = aie.tile(3, 3)
%a = aie.buffer(%t33) { sym_name = "a" } : memref<4xi32>
%core33 = aie.core(%t33) {
%0 = arith.constant 0 : index
%377 = arith.constant 377 : i32
memref.store %377, %a[%0] : memref<4xi32>
aie.end
module @basic_test {
aie.device(xcvc1902) {
%tile_3_3 = aie.tile(3, 3)
%buffer_3_3 = aie.buffer(%tile_3_3) {sym_name = "a"} : memref<4xi32>
%core_3_3 = aie.core(%tile_3_3) {
%c0 = arith.constant 0 : index
%c377_i32 = arith.constant 377 : i32
memref.store %c377_i32, %buffer_3_3[%c0] : memref<4xi32>
aie.end
}
%tile_3_4 = aie.tile(3, 4)
%buffer_3_4 = aie.buffer(%tile_3_4) {sym_name = "b"} : memref<4xi32>
%core_3_4 = aie.core(%tile_3_4) {
%c0 = arith.constant 0 : index
%0 = memref.load %buffer_3_4[%c0] : memref<4xi32>
aie.end
}
}
%t34 = aie.tile(4, 3)
}

// -----

// CHECK: func.func @core_4_3() {
// CHECK-DAG: %[[C44:.*]] = arith.constant 44 : index
// CHECK-DAG: %[[VAL_0:.*]] = memref.get_global @a : memref<4xi32>
// CHECK: %[[VAL_1:.*]] = memref.load %[[VAL_0]]{{\[}}%[[C44]]] : memref<4xi32>
// CHECK: return
// CHECK: }

%core34 = aie.core(%t34) {
%0 = arith.constant 0 : index
%1 = memref.load %a[%0] : memref<4xi32>
aie.end
// Check that the constant 44 is hoisted into the core/function.
module @isolation_test {
aie.device(xcvc1902) {
%tile_4_3 = aie.tile(4, 3)
%c44 = arith.constant 44 : index
%buffer_4_3 = aie.buffer(%tile_4_3) {sym_name = "a"} : memref<4xi32>
%core_4_3 = aie.core(%tile_4_3) {
%0 = memref.load %buffer_4_3[%c44] : memref<4xi32>
aie.end
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,12 @@
#include "iree-amd-aie/IR/AMDAIEDialect.h"
#include "iree-amd-aie/IR/AMDAIEOps.h"
#include "iree-amd-aie/Transforms/Passes.h"
#include "iree-amd-aie/Transforms/Utils/AMDAIEUtils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"

#define DEBUG_TYPE "iree-amdaie-sink-into-core"
Expand All @@ -23,62 +22,6 @@ namespace mlir::iree_compiler::AMDAIE {

namespace {

bool sinkInto(AMDAIE::CoreOp coreOp, PatternRewriter &rewriter) {
// Record if any ops are sunk into the core during this iteration.
bool changed = false;

// Collect all ops in the amdaie.core op
SmallVector<Operation *> opsInCore;
coreOp->walk([&](Operation *op) {
if (op == coreOp) return WalkResult::advance();
opsInCore.push_back(op);
return WalkResult::advance();
});

for (auto opInCore : opsInCore) {
for (Value operand : opInCore->getOperands()) {
if (!operand || !operand.getDefiningOp()) continue;
Operation *dependencyOp = operand.getDefiningOp();

// Skip if the dependency is already in the core.
if (coreOp->isAncestor(dependencyOp)) {
continue;
}

// Ops in the amdaie dialect are probably related to data movement
// and should not be sunk into the core. This might need adjustment
// later.
if (dependencyOp->getDialect()->getNamespace() ==
AMDAIE::AMDAIEDialect::getDialectNamespace()) {
continue;
}

// Create a clone of the dependency op in the core region.
Region &r = coreOp->getRegion(0);
assert(r.getBlocks().size() == 1 && "expected single block region");
rewriter.setInsertionPointToStart(&r.front());
Operation *sunkOp = rewriter.clone(*dependencyOp);

// Replace uses of the dependency op inside the core.
dependencyOp->replaceUsesWithIf(sunkOp, [&](OpOperand &use) {
return coreOp->isAncestor(use.getOwner());
});
changed = true;
}
}
return changed;
}

class SinkingPattern : public OpRewritePattern<AMDAIE::CoreOp> {
public:
using OpRewritePattern<AMDAIE::CoreOp>::OpRewritePattern;

LogicalResult matchAndRewrite(AMDAIE::CoreOp coreOp,
PatternRewriter &rewriter) const override {
return success(sinkInto(coreOp, rewriter));
}
};

class AMDAIESinkIntoCorePass
: public impl::AMDAIESinkIntoCoreBase<AMDAIESinkIntoCorePass> {
public:
Expand All @@ -87,10 +30,23 @@ class AMDAIESinkIntoCorePass
xilinx::AIE::AIEDialect, AMDAIE::AMDAIEDialect>();
}
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
patterns.insert<SinkingPattern>(&getContext());
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
signalPassFailure();
auto shouldSink = [&](Operation *op) -> bool {
// Ops in the amdaie dialect are probably related to data movement
// and should not be sunk into the core. This might need adjustment
// later.
if (op->getDialect()->getNamespace() ==
AMDAIE::AMDAIEDialect::getDialectNamespace()) {
return false;
}
return true;
};
IRRewriter rewriter(getOperation());
SmallVector<AMDAIE::CoreOp> coreOps;

getOperation()->walk(
[&](AMDAIE::CoreOp coreOp) { coreOps.push_back(coreOp); });
for (auto coreOp : coreOps) {
sinkInto(coreOp.getRegion(), rewriter, shouldSink);
}
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -397,4 +397,47 @@ int detail::findLargestFactor(int num, int max, int multiple) {
return factor ? factor : detail::findLargestFactor(num, max);
}

bool sinkInto(Region &region, IRRewriter &rewriter,
std::function<bool(Operation *)> shouldSink) {
Operation *parentOfRegion = region.getParentOp();
assert(parentOfRegion && "Region has no parent operation");
if (region.getBlocks().empty()) return false;
bool regionChanged = false;
for (Block &block : region.getBlocks()) {
// Collect all ops in the block.
SmallVector<Operation *> ops;
SmallVector<Operation *> nextIterationOps;
block.walk([&](Operation *op) { ops.push_back(op); });
while (!ops.empty()) {
for (Operation *op : ops) {
for (Value operand : op->getOperands()) {
if (!operand || !operand.getDefiningOp()) continue;
Operation *dependencyOp = operand.getDefiningOp();
// Skip if the dependency is already in the core.
if (parentOfRegion->isAncestor(dependencyOp)) continue;
if (!shouldSink(dependencyOp)) continue;
rewriter.setInsertionPointToStart(&block);
Operation *sunkOp = rewriter.clone(*dependencyOp);
nextIterationOps.push_back(sunkOp);
// Replace uses of the dependency op inside the block. Specifically,
// if `use` is in `block` then replace its operand with `sunkOp`.
auto isInBlock = [&block](OpOperand &use) {
auto op = use.getOwner();
while (op) {
if (op->getBlock() == &block) return true;
op = op->getParentOp();
}
return false;
};
dependencyOp->replaceUsesWithIf(sunkOp, isInBlock);
regionChanged = true;
}
}
std::swap(ops, nextIterationOps);
nextIterationOps.clear();
}
}
return regionChanged;
}

} // namespace mlir::iree_compiler::AMDAIE
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,20 @@ std::string getArrayString(ArrayRef<T> vs) {
/// "[not constant integers]".
std::string getConstantIntValuesString(ArrayRef<OpFoldResult> opFoldResults);

/// Consider all operations in the region, recursively. If the operation
/// has an operand that is not in the region, and the `shouldSink` function
/// returns true for that operand's producer, then replace all uses of the
/// operand inside the region with a clone of the operand in the block.
///
/// If `shouldSink` returns true for all operations, then this function will
/// make the region isolated from above. So this function essentially makes
/// the region isolated from above with respect to the set of operation types
/// defined by `shouldSink`.
///
/// \return true if the region was changed.
bool sinkInto(Region &, IRRewriter &,
std::function<bool(Operation *)> shouldSink);

} // namespace mlir::iree_compiler::AMDAIE

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,16 @@
module {
// CHECK-LABEL: func @sink_into_single_core
func.func @sink_into_single_core(%arg0: index) {
// CHECK-NOT: arith.constant 3 : index
%c0 = arith.constant 0 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%0 = arith.addi %arg0, %c3 : index
%tile = amdaie.tile(%c0, %c2)
// CHECK: amdaie.core
%1 = amdaie.core(%tile, in : [], out : []) {
// CHECK: arith.constant 3 : index
// CHECK: arith.addi
// CHECK: linalg.fill
// CHECK-NEXT: %[[C3:.*]] = arith.constant 3 : index
// CHECK-NEXT: %[[ADD:.*]] = arith.addi %arg0, %[[C3]] : index
// CHECK: linalg.fill ins(%[[ADD]] : index)
%alloc = memref.alloc() : memref<2x2xindex>
linalg.fill ins(%0 : index) outs(%alloc : memref<2x2xindex>)
amdaie.end
Expand All @@ -25,16 +24,8 @@ module {
// -----

module {
// Constants 0 and 1 are cloned into the cores, but not removed, because
// they are still used outside of the cores. Constants 2 and 3 are used only
// inside the cores, so they are cloned into the cores but then removed from
// the outer function.
// CHECK-LABEL: func @sink_into_pair_of_cores
func.func @sink_into_pair_of_cores(%arg0 : index) {
// CHECK-NOT: arith.constant 3 : index
// CHECK-NOT: arith.constant 2 : index
// CHECK-DAG: arith.constant 1 : index
// CHECK-DAG: arith.constant 0 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
Expand All @@ -43,9 +34,14 @@ module {
%tile_0 = amdaie.tile(%c0, %c1)
// CHECK: amdaie.core
%0 = amdaie.core(%tile, in : [], out : []) {
// CHECK-DAG: arith.constant 3 : index
// CHECK-DAG: arith.constant 2 : index
// CHECK-DAG: arith.constant 1 : index
// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[A0:.*]] = arith.addi %arg0, %[[C1]] : index
// CHECK: %[[A1:.*]] = arith.addi %[[C1]], %[[A0]] : index
// CHECK: %[[A2:.*]] = arith.addi %[[A1]], %[[C2]] : index
// CHECK: %[[A3:.*]] = arith.addi %[[A2]], %[[C3]] : index
// CHECK: linalg.fill ins(%[[A3]] : index)
%1 = arith.addi %arg0, %c1 : index
%2 = arith.addi %c1, %1 : index
%3 = arith.addi %2, %c2 : index
Expand Down

0 comments on commit 314497f

Please sign in to comment.