Skip to content

Commit

Permalink
[AMDAIEConvertToDma] Small code simplification (#949)
Browse files Browse the repository at this point in the history
Simplify the code by not passing more flags into functions than are actually
used (the pack vs. unpack logic is now handled before the function call).
  • Loading branch information
newling authored Dec 3, 2024
1 parent 1c7e406 commit ab64bca
Showing 1 changed file with 41 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -254,8 +254,7 @@ LogicalResult setDmaInputs(Operation *&operandOp,
template <typename PackOrUnpackOp>
LogicalResult rewriteAsDma(IRRewriter &rewriter, PackOrUnpackOp op, Value input,
Value output, llvm::ArrayRef<int64_t> innerTiles,
bool packTransposeOnSource,
bool unpackTransposeOnSource) {
bool transposeOnSource) {
if (llvm::any_of(innerTiles,
[](int64_t size) { return ShapedType::isDynamic(size); })) {
op->emitError("has a non-static shape: not yet supported by this pass.");
Expand All @@ -269,43 +268,42 @@ LogicalResult rewriteAsDma(IRRewriter &rewriter, PackOrUnpackOp op, Value input,

// Prepare source DMA inputs.
SmallVector<OpFoldResult> srcOffsets;
SmallVector<OpFoldResult> srcBaseStrides;
SmallVector<OpFoldResult> srcStrides;
SmallVector<OpFoldResult> srcShape;

if (!succeeded(
setDmaInputs(sourceOp, srcOffsets, srcShape, srcBaseStrides))) {
if (failed(setDmaInputs(sourceOp, srcOffsets, srcShape, srcStrides))) {
return failure();
}

// Prepare destination DMA inputs.
SmallVector<OpFoldResult> dstOffsets;
SmallVector<OpFoldResult> dstBaseStrides;
SmallVector<OpFoldResult> dstStrides;
SmallVector<OpFoldResult> dstShape;
if (!succeeded(setDmaInputs(dstOp, dstOffsets, dstShape, dstBaseStrides))) {
if (failed(setDmaInputs(dstOp, dstOffsets, dstShape, dstStrides))) {
return failure();
}

// Update dma source or destination addressing based on the side for dma
// transposition and pack/unpack operations.
if (packTransposeOnSource && isa<IREE::LinalgExt::PackOp>(op)) {
if (!succeeded(dmaTransposeOnLowerNumDims(op, srcOffsets, srcShape,
srcBaseStrides)))
return failure();
} else if (!packTransposeOnSource && isa<IREE::LinalgExt::PackOp>(op)) {
if (!succeeded(dmaTransposeOnHigherNumDims(op, dstOffsets, dstShape,
dstBaseStrides)))
return failure();
} else if (unpackTransposeOnSource && isa<IREE::LinalgExt::UnPackOp>(op)) {
if (!succeeded(dmaTransposeOnHigherNumDims(op, srcOffsets, srcShape,
srcBaseStrides)))
return failure();
} else if (!unpackTransposeOnSource && isa<IREE::LinalgExt::UnPackOp>(op)) {
if (!succeeded(dmaTransposeOnLowerNumDims(op, dstOffsets, dstShape,
dstBaseStrides)))
return failure();
} else {
op->emitError("unhandled option for dma addressing update.");
return failure();
// transposition.
{
SmallVector<OpFoldResult> &offsets =
transposeOnSource ? srcOffsets : dstOffsets;

SmallVector<OpFoldResult> &shape = transposeOnSource ? srcShape : dstShape;

SmallVector<OpFoldResult> &strides =
transposeOnSource ? srcStrides : dstStrides;

bool sourceIsHigherDim = dstStrides.size() <= srcStrides.size();

if (sourceIsHigherDim == transposeOnSource) {
if (failed(dmaTransposeOnHigherNumDims(op, offsets, shape, strides))) {
return failure();
}
} else {
if (failed(dmaTransposeOnLowerNumDims(op, offsets, shape, strides))) {
return failure();
}
}
}

// Create logical objectFifos from source and destination memrefs.
Expand All @@ -317,27 +315,27 @@ LogicalResult rewriteAsDma(IRRewriter &rewriter, PackOrUnpackOp op, Value input,
rewriter.setInsertionPointAfter(srcVal.getDefiningOp());
auto src = rewriter.create<AMDAIE::LogicalObjectFifoFromMemrefOp>(
rewriter.getUnknownLoc(), LogicalObjectFifoType::get(srcType), srcVal);

rewriter.setInsertionPointAfter(dstVal.getDefiningOp());
auto dst = rewriter.create<AMDAIE::LogicalObjectFifoFromMemrefOp>(
rewriter.getUnknownLoc(), LogicalObjectFifoType::get(dstType), dstVal);

rewriter.setInsertionPoint(op);
rewriter.create<AMDAIE::DmaCpyNdOp>(op->getLoc(), dst, dstOffsets, dstShape,
dstBaseStrides, src, srcOffsets, srcShape,
srcBaseStrides);
dstStrides, src, srcOffsets, srcShape,
srcStrides);
rewriter.eraseOp(op);
return success();
}

template <typename PackOrUnpackOp>
LogicalResult rewriteAsDma(PackOrUnpackOp op, IRRewriter &rewriter,
bool packTransposeOnSource,
bool unpackTransposeOnSource) {
bool tranposeOnSource) {
Value input = op.getInput();
Value output = op.getOutput();
llvm::ArrayRef<int64_t> innerTiles = op.getStaticInnerTiles();
return rewriteAsDma(rewriter, op, input, output, innerTiles,
packTransposeOnSource, unpackTransposeOnSource);
tranposeOnSource);
}

/// Convert a linalg.copy operation on 2 memrefs to an equivalent pack/unpack
Expand Down Expand Up @@ -413,23 +411,21 @@ void AMDAIEConvertToDmaPass::runOnOperation() {
});
if (convertCopiesWalkResult.wasInterrupted()) return signalPassFailure();

auto walkResult =
getOperation()->walk([&](IREE::LinalgExt::PackOp op) -> WalkResult {
if (failed(rewriteAsDma(op, rewriter, packTransposeOnSource,
unpackTransposeOnSource))) {
WalkResult walkResult =
getOperation()->walk([&](IREE::LinalgExt::PackOp packOp) {
if (failed(rewriteAsDma(packOp, rewriter, packTransposeOnSource))) {
return WalkResult::interrupt();
}
return WalkResult::advance();
});
if (walkResult.wasInterrupted()) signalPassFailure();
walkResult = getOperation()->walk(
[&](IREE::LinalgExt::UnPackOp op) -> WalkResult {
if (failed(rewriteAsDma(op, rewriter, packTransposeOnSource,
unpackTransposeOnSource))) {
return WalkResult::interrupt();
}
return WalkResult::advance();
});

walkResult = getOperation()->walk([&](IREE::LinalgExt::UnPackOp unpackOp) {
if (failed(rewriteAsDma(unpackOp, rewriter, unpackTransposeOnSource))) {
return WalkResult::interrupt();
}
return WalkResult::advance();
});
if (walkResult.wasInterrupted()) signalPassFailure();
}

Expand Down

0 comments on commit ab64bca

Please sign in to comment.