Skip to content

Commit ab64bca

Browse files
authored
[AMDAIEConvertToDma] Small code simplification (#949)
Simplification to not pass more flags into functions than are actually used (logic for pack vs unpack handled before function call).
1 parent 1c7e406 commit ab64bca

File tree

1 file changed

+41
-45
lines changed

1 file changed

+41
-45
lines changed

compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEConvertToDma.cpp

Lines changed: 41 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -254,8 +254,7 @@ LogicalResult setDmaInputs(Operation *&operandOp,
254254
template <typename PackOrUnpackOp>
255255
LogicalResult rewriteAsDma(IRRewriter &rewriter, PackOrUnpackOp op, Value input,
256256
Value output, llvm::ArrayRef<int64_t> innerTiles,
257-
bool packTransposeOnSource,
258-
bool unpackTransposeOnSource) {
257+
bool transposeOnSource) {
259258
if (llvm::any_of(innerTiles,
260259
[](int64_t size) { return ShapedType::isDynamic(size); })) {
261260
op->emitError("has a non-static shape: not yet supported by this pass.");
@@ -269,43 +268,42 @@ LogicalResult rewriteAsDma(IRRewriter &rewriter, PackOrUnpackOp op, Value input,
269268

270269
// Prepare source DMA inputs.
271270
SmallVector<OpFoldResult> srcOffsets;
272-
SmallVector<OpFoldResult> srcBaseStrides;
271+
SmallVector<OpFoldResult> srcStrides;
273272
SmallVector<OpFoldResult> srcShape;
274-
275-
if (!succeeded(
276-
setDmaInputs(sourceOp, srcOffsets, srcShape, srcBaseStrides))) {
273+
if (failed(setDmaInputs(sourceOp, srcOffsets, srcShape, srcStrides))) {
277274
return failure();
278275
}
279276

280277
// Prepare destination DMA inputs.
281278
SmallVector<OpFoldResult> dstOffsets;
282-
SmallVector<OpFoldResult> dstBaseStrides;
279+
SmallVector<OpFoldResult> dstStrides;
283280
SmallVector<OpFoldResult> dstShape;
284-
if (!succeeded(setDmaInputs(dstOp, dstOffsets, dstShape, dstBaseStrides))) {
281+
if (failed(setDmaInputs(dstOp, dstOffsets, dstShape, dstStrides))) {
285282
return failure();
286283
}
287284

288285
// Update dma source or destination addressing based on the side for dma
289-
// transposition and pack/unpack operations.
290-
if (packTransposeOnSource && isa<IREE::LinalgExt::PackOp>(op)) {
291-
if (!succeeded(dmaTransposeOnLowerNumDims(op, srcOffsets, srcShape,
292-
srcBaseStrides)))
293-
return failure();
294-
} else if (!packTransposeOnSource && isa<IREE::LinalgExt::PackOp>(op)) {
295-
if (!succeeded(dmaTransposeOnHigherNumDims(op, dstOffsets, dstShape,
296-
dstBaseStrides)))
297-
return failure();
298-
} else if (unpackTransposeOnSource && isa<IREE::LinalgExt::UnPackOp>(op)) {
299-
if (!succeeded(dmaTransposeOnHigherNumDims(op, srcOffsets, srcShape,
300-
srcBaseStrides)))
301-
return failure();
302-
} else if (!unpackTransposeOnSource && isa<IREE::LinalgExt::UnPackOp>(op)) {
303-
if (!succeeded(dmaTransposeOnLowerNumDims(op, dstOffsets, dstShape,
304-
dstBaseStrides)))
305-
return failure();
306-
} else {
307-
op->emitError("unhandled option for dma addressing update.");
308-
return failure();
286+
// transposition.
287+
{
288+
SmallVector<OpFoldResult> &offsets =
289+
transposeOnSource ? srcOffsets : dstOffsets;
290+
291+
SmallVector<OpFoldResult> &shape = transposeOnSource ? srcShape : dstShape;
292+
293+
SmallVector<OpFoldResult> &strides =
294+
transposeOnSource ? srcStrides : dstStrides;
295+
296+
bool sourceIsHigherDim = dstStrides.size() <= srcStrides.size();
297+
298+
if (sourceIsHigherDim == transposeOnSource) {
299+
if (failed(dmaTransposeOnHigherNumDims(op, offsets, shape, strides))) {
300+
return failure();
301+
}
302+
} else {
303+
if (failed(dmaTransposeOnLowerNumDims(op, offsets, shape, strides))) {
304+
return failure();
305+
}
306+
}
309307
}
310308

311309
// Create logical objectFifos from source and destination memrefs.
@@ -317,27 +315,27 @@ LogicalResult rewriteAsDma(IRRewriter &rewriter, PackOrUnpackOp op, Value input,
317315
rewriter.setInsertionPointAfter(srcVal.getDefiningOp());
318316
auto src = rewriter.create<AMDAIE::LogicalObjectFifoFromMemrefOp>(
319317
rewriter.getUnknownLoc(), LogicalObjectFifoType::get(srcType), srcVal);
318+
320319
rewriter.setInsertionPointAfter(dstVal.getDefiningOp());
321320
auto dst = rewriter.create<AMDAIE::LogicalObjectFifoFromMemrefOp>(
322321
rewriter.getUnknownLoc(), LogicalObjectFifoType::get(dstType), dstVal);
323322

324323
rewriter.setInsertionPoint(op);
325324
rewriter.create<AMDAIE::DmaCpyNdOp>(op->getLoc(), dst, dstOffsets, dstShape,
326-
dstBaseStrides, src, srcOffsets, srcShape,
327-
srcBaseStrides);
325+
dstStrides, src, srcOffsets, srcShape,
326+
srcStrides);
328327
rewriter.eraseOp(op);
329328
return success();
330329
}
331330

332331
template <typename PackOrUnpackOp>
333332
LogicalResult rewriteAsDma(PackOrUnpackOp op, IRRewriter &rewriter,
334-
bool packTransposeOnSource,
335-
bool unpackTransposeOnSource) {
333+
bool tranposeOnSource) {
336334
Value input = op.getInput();
337335
Value output = op.getOutput();
338336
llvm::ArrayRef<int64_t> innerTiles = op.getStaticInnerTiles();
339337
return rewriteAsDma(rewriter, op, input, output, innerTiles,
340-
packTransposeOnSource, unpackTransposeOnSource);
338+
tranposeOnSource);
341339
}
342340

343341
/// Convert a linalg.copy operation on 2 memrefs to an equivalent pack/unpack
@@ -413,23 +411,21 @@ void AMDAIEConvertToDmaPass::runOnOperation() {
413411
});
414412
if (convertCopiesWalkResult.wasInterrupted()) return signalPassFailure();
415413

416-
auto walkResult =
417-
getOperation()->walk([&](IREE::LinalgExt::PackOp op) -> WalkResult {
418-
if (failed(rewriteAsDma(op, rewriter, packTransposeOnSource,
419-
unpackTransposeOnSource))) {
414+
WalkResult walkResult =
415+
getOperation()->walk([&](IREE::LinalgExt::PackOp packOp) {
416+
if (failed(rewriteAsDma(packOp, rewriter, packTransposeOnSource))) {
420417
return WalkResult::interrupt();
421418
}
422419
return WalkResult::advance();
423420
});
424421
if (walkResult.wasInterrupted()) signalPassFailure();
425-
walkResult = getOperation()->walk(
426-
[&](IREE::LinalgExt::UnPackOp op) -> WalkResult {
427-
if (failed(rewriteAsDma(op, rewriter, packTransposeOnSource,
428-
unpackTransposeOnSource))) {
429-
return WalkResult::interrupt();
430-
}
431-
return WalkResult::advance();
432-
});
422+
423+
walkResult = getOperation()->walk([&](IREE::LinalgExt::UnPackOp unpackOp) {
424+
if (failed(rewriteAsDma(unpackOp, rewriter, unpackTransposeOnSource))) {
425+
return WalkResult::interrupt();
426+
}
427+
return WalkResult::advance();
428+
});
433429
if (walkResult.wasInterrupted()) signalPassFailure();
434430
}
435431

0 commit comments

Comments
 (0)