@@ -254,8 +254,7 @@ LogicalResult setDmaInputs(Operation *&operandOp,
254
254
template <typename PackOrUnpackOp>
255
255
LogicalResult rewriteAsDma (IRRewriter &rewriter, PackOrUnpackOp op, Value input,
256
256
Value output, llvm::ArrayRef<int64_t > innerTiles,
257
- bool packTransposeOnSource,
258
- bool unpackTransposeOnSource) {
257
+ bool transposeOnSource) {
259
258
if (llvm::any_of (innerTiles,
260
259
[](int64_t size) { return ShapedType::isDynamic (size); })) {
261
260
op->emitError (" has a non-static shape: not yet supported by this pass." );
@@ -269,43 +268,42 @@ LogicalResult rewriteAsDma(IRRewriter &rewriter, PackOrUnpackOp op, Value input,
269
268
270
269
// Prepare source DMA inputs.
271
270
SmallVector<OpFoldResult> srcOffsets;
272
- SmallVector<OpFoldResult> srcBaseStrides ;
271
+ SmallVector<OpFoldResult> srcStrides ;
273
272
SmallVector<OpFoldResult> srcShape;
274
-
275
- if (!succeeded (
276
- setDmaInputs (sourceOp, srcOffsets, srcShape, srcBaseStrides))) {
273
+ if (failed (setDmaInputs (sourceOp, srcOffsets, srcShape, srcStrides))) {
277
274
return failure ();
278
275
}
279
276
280
277
// Prepare destination DMA inputs.
281
278
SmallVector<OpFoldResult> dstOffsets;
282
- SmallVector<OpFoldResult> dstBaseStrides ;
279
+ SmallVector<OpFoldResult> dstStrides ;
283
280
SmallVector<OpFoldResult> dstShape;
284
- if (! succeeded (setDmaInputs (dstOp, dstOffsets, dstShape, dstBaseStrides ))) {
281
+ if (failed (setDmaInputs (dstOp, dstOffsets, dstShape, dstStrides ))) {
285
282
return failure ();
286
283
}
287
284
288
285
// Update dma source or destination addressing based on the side for dma
289
- // transposition and pack/unpack operations.
290
- if (packTransposeOnSource && isa<IREE::LinalgExt::PackOp>(op)) {
291
- if (!succeeded (dmaTransposeOnLowerNumDims (op, srcOffsets, srcShape,
292
- srcBaseStrides)))
293
- return failure ();
294
- } else if (!packTransposeOnSource && isa<IREE::LinalgExt::PackOp>(op)) {
295
- if (!succeeded (dmaTransposeOnHigherNumDims (op, dstOffsets, dstShape,
296
- dstBaseStrides)))
297
- return failure ();
298
- } else if (unpackTransposeOnSource && isa<IREE::LinalgExt::UnPackOp>(op)) {
299
- if (!succeeded (dmaTransposeOnHigherNumDims (op, srcOffsets, srcShape,
300
- srcBaseStrides)))
301
- return failure ();
302
- } else if (!unpackTransposeOnSource && isa<IREE::LinalgExt::UnPackOp>(op)) {
303
- if (!succeeded (dmaTransposeOnLowerNumDims (op, dstOffsets, dstShape,
304
- dstBaseStrides)))
305
- return failure ();
306
- } else {
307
- op->emitError (" unhandled option for dma addressing update." );
308
- return failure ();
286
+ // transposition.
287
+ {
288
+ SmallVector<OpFoldResult> &offsets =
289
+ transposeOnSource ? srcOffsets : dstOffsets;
290
+
291
+ SmallVector<OpFoldResult> &shape = transposeOnSource ? srcShape : dstShape;
292
+
293
+ SmallVector<OpFoldResult> &strides =
294
+ transposeOnSource ? srcStrides : dstStrides;
295
+
296
+ bool sourceIsHigherDim = dstStrides.size () <= srcStrides.size ();
297
+
298
+ if (sourceIsHigherDim == transposeOnSource) {
299
+ if (failed (dmaTransposeOnHigherNumDims (op, offsets, shape, strides))) {
300
+ return failure ();
301
+ }
302
+ } else {
303
+ if (failed (dmaTransposeOnLowerNumDims (op, offsets, shape, strides))) {
304
+ return failure ();
305
+ }
306
+ }
309
307
}
310
308
311
309
// Create logical objectFifos from source and destination memrefs.
@@ -317,27 +315,27 @@ LogicalResult rewriteAsDma(IRRewriter &rewriter, PackOrUnpackOp op, Value input,
317
315
rewriter.setInsertionPointAfter (srcVal.getDefiningOp ());
318
316
auto src = rewriter.create <AMDAIE::LogicalObjectFifoFromMemrefOp>(
319
317
rewriter.getUnknownLoc (), LogicalObjectFifoType::get (srcType), srcVal);
318
+
320
319
rewriter.setInsertionPointAfter (dstVal.getDefiningOp ());
321
320
auto dst = rewriter.create <AMDAIE::LogicalObjectFifoFromMemrefOp>(
322
321
rewriter.getUnknownLoc (), LogicalObjectFifoType::get (dstType), dstVal);
323
322
324
323
rewriter.setInsertionPoint (op);
325
324
rewriter.create <AMDAIE::DmaCpyNdOp>(op->getLoc (), dst, dstOffsets, dstShape,
326
- dstBaseStrides , src, srcOffsets, srcShape,
327
- srcBaseStrides );
325
+ dstStrides , src, srcOffsets, srcShape,
326
+ srcStrides );
328
327
rewriter.eraseOp (op);
329
328
return success ();
330
329
}
331
330
332
331
template <typename PackOrUnpackOp>
333
332
LogicalResult rewriteAsDma (PackOrUnpackOp op, IRRewriter &rewriter,
334
- bool packTransposeOnSource,
335
- bool unpackTransposeOnSource) {
333
+ bool tranposeOnSource) {
336
334
Value input = op.getInput ();
337
335
Value output = op.getOutput ();
338
336
llvm::ArrayRef<int64_t > innerTiles = op.getStaticInnerTiles ();
339
337
return rewriteAsDma (rewriter, op, input, output, innerTiles,
340
- packTransposeOnSource, unpackTransposeOnSource );
338
+ tranposeOnSource );
341
339
}
342
340
343
341
/// Convert a linalg.copy operation on 2 memrefs to an equivalent pack/unpack
@@ -413,23 +411,21 @@ void AMDAIEConvertToDmaPass::runOnOperation() {
413
411
});
414
412
if (convertCopiesWalkResult.wasInterrupted ()) return signalPassFailure ();
415
413
416
- auto walkResult =
417
- getOperation ()->walk ([&](IREE::LinalgExt::PackOp op) -> WalkResult {
418
- if (failed (rewriteAsDma (op, rewriter, packTransposeOnSource,
419
- unpackTransposeOnSource))) {
414
+ WalkResult walkResult =
415
+ getOperation ()->walk ([&](IREE::LinalgExt::PackOp packOp) {
416
+ if (failed (rewriteAsDma (packOp, rewriter, packTransposeOnSource))) {
420
417
return WalkResult::interrupt ();
421
418
}
422
419
return WalkResult::advance ();
423
420
});
424
421
if (walkResult.wasInterrupted ()) signalPassFailure ();
425
- walkResult = getOperation ()->walk (
426
- [&](IREE::LinalgExt::UnPackOp op) -> WalkResult {
427
- if (failed (rewriteAsDma (op, rewriter, packTransposeOnSource,
428
- unpackTransposeOnSource))) {
429
- return WalkResult::interrupt ();
430
- }
431
- return WalkResult::advance ();
432
- });
422
+
423
+ walkResult = getOperation ()->walk ([&](IREE::LinalgExt::UnPackOp unpackOp) {
424
+ if (failed (rewriteAsDma (unpackOp, rewriter, unpackTransposeOnSource))) {
425
+ return WalkResult::interrupt ();
426
+ }
427
+ return WalkResult::advance ();
428
+ });
433
429
if (walkResult.wasInterrupted ()) signalPassFailure ();
434
430
}
435
431
0 commit comments