Skip to content

Commit

Permalink
Merge branch 'main' into zhewen_bdassign
Browse files Browse the repository at this point in the history
  • Loading branch information
Yu-Zhewen authored Dec 3, 2024
2 parents e0e5e30 + ab64bca commit 05afbbb
Showing 1 changed file with 41 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -254,8 +254,7 @@ LogicalResult setDmaInputs(Operation *&operandOp,
template <typename PackOrUnpackOp>
LogicalResult rewriteAsDma(IRRewriter &rewriter, PackOrUnpackOp op, Value input,
Value output, llvm::ArrayRef<int64_t> innerTiles,
bool packTransposeOnSource,
bool unpackTransposeOnSource) {
bool transposeOnSource) {
if (llvm::any_of(innerTiles,
[](int64_t size) { return ShapedType::isDynamic(size); })) {
op->emitError("has a non-static shape: not yet supported by this pass.");
Expand All @@ -269,43 +268,42 @@ LogicalResult rewriteAsDma(IRRewriter &rewriter, PackOrUnpackOp op, Value input,

// Prepare source DMA inputs.
SmallVector<OpFoldResult> srcOffsets;
SmallVector<OpFoldResult> srcBaseStrides;
SmallVector<OpFoldResult> srcStrides;
SmallVector<OpFoldResult> srcShape;

if (!succeeded(
setDmaInputs(sourceOp, srcOffsets, srcShape, srcBaseStrides))) {
if (failed(setDmaInputs(sourceOp, srcOffsets, srcShape, srcStrides))) {
return failure();
}

// Prepare destination DMA inputs.
SmallVector<OpFoldResult> dstOffsets;
SmallVector<OpFoldResult> dstBaseStrides;
SmallVector<OpFoldResult> dstStrides;
SmallVector<OpFoldResult> dstShape;
if (!succeeded(setDmaInputs(dstOp, dstOffsets, dstShape, dstBaseStrides))) {
if (failed(setDmaInputs(dstOp, dstOffsets, dstShape, dstStrides))) {
return failure();
}

// Update dma source or destination addressing based on the side for dma
// transposition and pack/unpack operations.
if (packTransposeOnSource && isa<IREE::LinalgExt::PackOp>(op)) {
if (!succeeded(dmaTransposeOnLowerNumDims(op, srcOffsets, srcShape,
srcBaseStrides)))
return failure();
} else if (!packTransposeOnSource && isa<IREE::LinalgExt::PackOp>(op)) {
if (!succeeded(dmaTransposeOnHigherNumDims(op, dstOffsets, dstShape,
dstBaseStrides)))
return failure();
} else if (unpackTransposeOnSource && isa<IREE::LinalgExt::UnPackOp>(op)) {
if (!succeeded(dmaTransposeOnHigherNumDims(op, srcOffsets, srcShape,
srcBaseStrides)))
return failure();
} else if (!unpackTransposeOnSource && isa<IREE::LinalgExt::UnPackOp>(op)) {
if (!succeeded(dmaTransposeOnLowerNumDims(op, dstOffsets, dstShape,
dstBaseStrides)))
return failure();
} else {
op->emitError("unhandled option for dma addressing update.");
return failure();
// transposition.
{
SmallVector<OpFoldResult> &offsets =
transposeOnSource ? srcOffsets : dstOffsets;

SmallVector<OpFoldResult> &shape = transposeOnSource ? srcShape : dstShape;

SmallVector<OpFoldResult> &strides =
transposeOnSource ? srcStrides : dstStrides;

bool sourceIsHigherDim = dstStrides.size() <= srcStrides.size();

if (sourceIsHigherDim == transposeOnSource) {
if (failed(dmaTransposeOnHigherNumDims(op, offsets, shape, strides))) {
return failure();
}
} else {
if (failed(dmaTransposeOnLowerNumDims(op, offsets, shape, strides))) {
return failure();
}
}
}

// Create logical objectFifos from source and destination memrefs.
Expand All @@ -317,27 +315,27 @@ LogicalResult rewriteAsDma(IRRewriter &rewriter, PackOrUnpackOp op, Value input,
rewriter.setInsertionPointAfter(srcVal.getDefiningOp());
auto src = rewriter.create<AMDAIE::LogicalObjectFifoFromMemrefOp>(
rewriter.getUnknownLoc(), LogicalObjectFifoType::get(srcType), srcVal);

rewriter.setInsertionPointAfter(dstVal.getDefiningOp());
auto dst = rewriter.create<AMDAIE::LogicalObjectFifoFromMemrefOp>(
rewriter.getUnknownLoc(), LogicalObjectFifoType::get(dstType), dstVal);

rewriter.setInsertionPoint(op);
rewriter.create<AMDAIE::DmaCpyNdOp>(op->getLoc(), dst, dstOffsets, dstShape,
dstBaseStrides, src, srcOffsets, srcShape,
srcBaseStrides);
dstStrides, src, srcOffsets, srcShape,
srcStrides);
rewriter.eraseOp(op);
return success();
}

template <typename PackOrUnpackOp>
LogicalResult rewriteAsDma(PackOrUnpackOp op, IRRewriter &rewriter,
bool packTransposeOnSource,
bool unpackTransposeOnSource) {
bool tranposeOnSource) {
Value input = op.getInput();
Value output = op.getOutput();
llvm::ArrayRef<int64_t> innerTiles = op.getStaticInnerTiles();
return rewriteAsDma(rewriter, op, input, output, innerTiles,
packTransposeOnSource, unpackTransposeOnSource);
tranposeOnSource);
}

/// Convert a linalg.copy operation on 2 memrefs to an equivalent pack/unpack
Expand Down Expand Up @@ -413,23 +411,21 @@ void AMDAIEConvertToDmaPass::runOnOperation() {
});
if (convertCopiesWalkResult.wasInterrupted()) return signalPassFailure();

auto walkResult =
getOperation()->walk([&](IREE::LinalgExt::PackOp op) -> WalkResult {
if (failed(rewriteAsDma(op, rewriter, packTransposeOnSource,
unpackTransposeOnSource))) {
WalkResult walkResult =
getOperation()->walk([&](IREE::LinalgExt::PackOp packOp) {
if (failed(rewriteAsDma(packOp, rewriter, packTransposeOnSource))) {
return WalkResult::interrupt();
}
return WalkResult::advance();
});
if (walkResult.wasInterrupted()) signalPassFailure();
walkResult = getOperation()->walk(
[&](IREE::LinalgExt::UnPackOp op) -> WalkResult {
if (failed(rewriteAsDma(op, rewriter, packTransposeOnSource,
unpackTransposeOnSource))) {
return WalkResult::interrupt();
}
return WalkResult::advance();
});

walkResult = getOperation()->walk([&](IREE::LinalgExt::UnPackOp unpackOp) {
if (failed(rewriteAsDma(unpackOp, rewriter, unpackTransposeOnSource))) {
return WalkResult::interrupt();
}
return WalkResult::advance();
});
if (walkResult.wasInterrupted()) signalPassFailure();
}

Expand Down

0 comments on commit 05afbbb

Please sign in to comment.