[AMDAIEFoldDmaWaits] Fold DMA wait operations across multi columns #986

Merged · 11 commits · Dec 18, 2024
@@ -200,16 +200,46 @@ LogicalResult convertOp(AMDAIE::NpuAddressPatchOp op,
}

LogicalResult convertOp(AMDAIE::NpuDmaWaitOp op, TransactionBuilder &builder) {
for (Value token : op.getAsyncTokens()) {
auto pushToQueueOp =
dyn_cast_if_present<AMDAIE::NpuPushToQueueOp>(token.getDefiningOp());
// Collect all half DMA ops from the async tokens.
SmallVector<AMDAIE::NpuPushToQueueOp> pushToQueueOps;
for (Value asyncToken : op.getAsyncTokens()) {
auto pushToQueueOp = dyn_cast_if_present<AMDAIE::NpuPushToQueueOp>(
asyncToken.getDefiningOp());
if (!pushToQueueOp) {
return op.emitOpError()
<< "should operate on an `amdaie.push_to_queue` op";
<< "should operate on an `amdaie.push_to_queue` op async token";
}
pushToQueueOps.push_back(pushToQueueOp);
}
// Sort the half DMA ops by channel, direction, row, and column.
std::sort(pushToQueueOps.begin(), pushToQueueOps.end(),
[](AMDAIE::NpuPushToQueueOp a, AMDAIE::NpuPushToQueueOp b) {
return std::make_tuple(a.getChannel(), a.getDirection(),
a.getRow(), a.getCol()) <
std::make_tuple(b.getChannel(), b.getDirection(),
b.getRow(), b.getCol());
});
// Batch DMA operations with the same row, channel, and direction into a
// single TCT sync operation, as long as they have consecutive columns.
llvm::MapVector<AMDAIE::NpuPushToQueueOp, uint32_t> columnBatches;
for (auto pushToQueueOp : pushToQueueOps) {
if (!columnBatches.empty()) {
auto &[lastPushOp, lastColNum] = columnBatches.back();
if (lastPushOp.getRow() == pushToQueueOp.getRow() &&
lastPushOp.getCol() + lastColNum == pushToQueueOp.getCol() &&
lastPushOp.getDirection() == pushToQueueOp.getDirection() &&
lastPushOp.getChannel() == pushToQueueOp.getChannel()) {
++lastColNum;
continue;
}
}
columnBatches.insert({pushToQueueOp, 1});
}
// Convert to TCT sync ops.
for (auto &[pushToQueueOp, colNum] : columnBatches) {
if (failed(builder.appendTCTSync(
pushToQueueOp.getCol(), pushToQueueOp.getRow(),
static_cast<uint32_t>(pushToQueueOp.getDirection()), 1, 1,
static_cast<uint32_t>(pushToQueueOp.getDirection()), 1, colNum,
pushToQueueOp.getChannel()))) {
return failure();
}
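
For reference, a standalone sketch of the sort-then-batch step above, written against a plain Push struct as a hypothetical stand-in for the fields read from NpuPushToQueueOp (not the actual MLIR op): pushes that share channel, direction, and row are merged for as long as their columns stay consecutive.

#include <algorithm>
#include <cstdint>
#include <tuple>
#include <utility>
#include <vector>

// Hypothetical stand-in for the fields the converter reads off each
// amdaie.npu.push_to_queue op.
struct Push {
  uint32_t channel, direction, row, col;
};

// Mirrors the columnBatches loop above: sort by (channel, direction, row,
// col), then grow a batch while the next push only advances the column.
std::vector<std::pair<Push, uint32_t>> batchByColumn(std::vector<Push> pushes) {
  std::sort(pushes.begin(), pushes.end(), [](const Push &a, const Push &b) {
    return std::tie(a.channel, a.direction, a.row, a.col) <
           std::tie(b.channel, b.direction, b.row, b.col);
  });
  std::vector<std::pair<Push, uint32_t>> batches;
  for (const Push &p : pushes) {
    if (!batches.empty()) {
      auto &[last, count] = batches.back();
      if (last.row == p.row && last.col + count == p.col &&
          last.direction == p.direction && last.channel == p.channel) {
        ++count;  // extend the run of consecutive columns
        continue;
      }
    }
    batches.push_back({p, 1});
  }
  return batches;
}

Under this grouping, four pushes at columns 0 through 3 with identical row, channel, and direction collapse into one entry with a column count of 4, which the converter then emits as a single appendTCTSync call.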
@@ -16,13 +16,49 @@ namespace mlir::iree_compiler::AMDAIE {

namespace {

using DmaBdIdKey = std::pair<AMDAIE::TileOp, AMDAIE::ConnectionOp>;
using DmaBdIdPair = std::pair<DmaBdIdKey, uint32_t>;

/// Utility function to retrieve TileOp, ConnectionOp, and BD ID from a given
/// half DMA copy operation.
FailureOr<DmaBdIdPair> retrieveDmaBdIdPair(
AMDAIE::NpuHalfDmaCpyNdOp &npuHalfDmaCpyNdOp) {
// Retrieve the connection op.
std::optional<AMDAIE::ConnectionOp> maybeConnectionOp =
npuHalfDmaCpyNdOp.getConnectionOp();
if (!maybeConnectionOp) {
return npuHalfDmaCpyNdOp.emitOpError()
<< "expected to operate on an `amdaie.connection`";
}
AMDAIE::ConnectionOp connectionOp = maybeConnectionOp.value();

// Retrieve the BD ID op.
std::optional<AMDAIE::BdIdOp> maybeBdIdOp = npuHalfDmaCpyNdOp.getBdIdOp();
if (!maybeBdIdOp) {
return npuHalfDmaCpyNdOp.emitOpError()
<< "must have a BD ID op to lower to "
"`amdaie.npu.write_bd`";
}
AMDAIE::BdIdOp bdIdOp = maybeBdIdOp.value();
uint32_t currBdIdVal = getConstantIndexOrAssert(bdIdOp.getValue());

// Retrieve the tile op.
AMDAIE::TileOp tileOp =
dyn_cast_if_present<AMDAIE::TileOp>(bdIdOp.getTile().getDefiningOp());
if (!tileOp) {
return bdIdOp.emitOpError() << "must operate on an `amdaie.tile`";
}

DmaBdIdKey currBdIdKey = {tileOp, connectionOp};
return DmaBdIdPair{currBdIdKey, currBdIdVal};
}

/// Utility function to determine whether a DMA wait op can be folded based on
/// its half DMA copy operation.
FailureOr<bool> canFoldBasedOnHalfDmaCpy(
FailureOr<bool> canFoldByQueue(
const AMDAIE::AMDAIEDeviceModel &deviceModel,
AMDAIE::NpuHalfDmaCpyNdOp &npuHalfDmaCpyNdOp,
DenseMap<std::pair<AMDAIE::TileOp, AMDAIE::ConnectionOp>,
SmallVector<uint32_t>> &tileConnectToBdIdQueue) {
DenseMap<DmaBdIdKey, SmallVector<uint32_t>> &tileConnectToBdIdQueue) {
// Retrieve the connection op.
std::optional<AMDAIE::ConnectionOp> maybeConnectionOp =
npuHalfDmaCpyNdOp.getConnectionOp();
@@ -101,13 +137,11 @@ FailureOr<bool> canFoldBasedOnHalfDmaCpy(
/// Reverse traversal simplifies handling duplicate BD IDs, preventing
/// the need to revisit and modify earlier operations after processing later
/// ones.
LogicalResult foldDmaWaits(const AMDAIE::AMDAIEDeviceModel &deviceModel,
AMDAIE::ControlCodeOp controlCodeOp) {
LogicalResult foldDmaWaitsByQueue(const AMDAIE::AMDAIEDeviceModel &deviceModel,
AMDAIE::ControlCodeOp controlCodeOp) {
IRRewriter rewriter(controlCodeOp->getContext());
std::vector<AMDAIE::NpuDmaWaitOp> waitOpsToErase;
DenseMap<std::pair<AMDAIE::TileOp, AMDAIE::ConnectionOp>,
SmallVector<uint32_t>>
tileConnectToBdIdQueue;
DenseMap<DmaBdIdKey, SmallVector<uint32_t>> tileConnectToBdIdQueue;
// Traverse the control code in reverse.
WalkResult res = controlCodeOp->walk<WalkOrder::PostOrder, ReverseIterator>(
[&](AMDAIE::NpuDmaWaitOp waitOp) {
@@ -116,7 +150,7 @@ LogicalResult foldDmaWaits(const AMDAIE::AMDAIEDeviceModel &deviceModel,
if (auto npuHalfDmaCpyNdOp =
dyn_cast_if_present<AMDAIE::NpuHalfDmaCpyNdOp>(
token.getDefiningOp())) {
FailureOr<bool> result = canFoldBasedOnHalfDmaCpy(
FailureOr<bool> result = canFoldByQueue(
deviceModel, npuHalfDmaCpyNdOp, tileConnectToBdIdQueue);
if (failed(result)) return WalkResult::interrupt();
toErase &= *result;
@@ -152,6 +186,162 @@ LogicalResult foldDmaWaits(const AMDAIE::AMDAIEDeviceModel &deviceModel,
return success();
}
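
As a rough illustration of the by-queue folding (using the same simplified notation as the pass comments; this is a sketch of the intended effect, not text from the patch, and assumes the checks in canFoldByQueue pass, e.g. no reused BD IDs on the same tile and connection): waits on earlier copies pushed onto the same (tile, connection) queue are dropped, keeping only the final wait.

Example Input:
  %0 = dma_cpy_nd(connection0, bd_id_0)
  dma_wait(%0)
  %1 = dma_cpy_nd(connection0, bd_id_1)
  dma_wait(%1)
Example Output:
  %0 = dma_cpy_nd(connection0, bd_id_0)
  %1 = dma_cpy_nd(connection0, bd_id_1)
  dma_wait(%1)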

/// For each batch, combine the async tokens into a single NpuDmaWaitOp.
LogicalResult eraseBatchOperations(IRRewriter &rewriter,
SmallVector<AMDAIE::NpuDmaWaitOp> &waitOps) {
// Skip if there are fewer than two DMA wait operations.
if (waitOps.size() < 2) return success();

SmallVector<Value> asyncTokens;
Operation *parentOp = waitOps[0]->getParentOp();
for (AMDAIE::NpuDmaWaitOp waitOp : waitOps) {
if (waitOp->getParentOp() != parentOp) {
return waitOp.emitError(
"DMA operations to be batched must belong to the same scope");
}
asyncTokens.append(waitOp.getAsyncTokens().begin(),
waitOp.getAsyncTokens().end());
}

rewriter.setInsertionPointAfter(waitOps.back());
rewriter.create<AMDAIE::NpuDmaWaitOp>(waitOps.back().getLoc(), asyncTokens);
for (AMDAIE::NpuDmaWaitOp waitOp : waitOps) rewriter.eraseOp(waitOp);
return success();
}

/// Utility function to determine if a DMA wait operation can be folded into
/// a batch based on its half DMA copy operation.
/// Can't fold wait op if:
/// (1) the current operation is not in the same scope as the batch, or
/// (2) the current connection op already occurs in the batch, or
/// (3) the batch is empty, or
/// (4) the current operation is a packet flow, or
/// (5) the current BD ID on the same tile already occurs in the batch.
FailureOr<bool> canFoldByBatch(
const Operation *batchParentOp,
const DenseSet<AMDAIE::ConnectionOp> &connectionOps,
const DenseMap<DmaBdIdKey, DenseSet<uint32_t>> &dmaBdIdsMap,
AMDAIE::NpuHalfDmaCpyNdOp currHalfDmaCpyNdOp, DmaBdIdPair currBdIdPair) {
// Not in the same scope? Can't fold.
if (currHalfDmaCpyNdOp->getParentOp() != batchParentOp) return false;

// Connection op already in the batch, or an empty batch? Can't fold.
AMDAIE::ConnectionOp connectionOp = currBdIdPair.first.second;
if (connectionOps.contains(connectionOp) || connectionOps.empty())
return false;

// Packet flow? Can't fold.
std::optional<AMDAIE::FlowOp> maybeFlowOp = connectionOp.getFlowOp();
if (!maybeFlowOp) {
return connectionOp.emitOpError()
<< "expected to operate on an `amdaie.flow`";
}
AMDAIE::FlowOp flowOp = maybeFlowOp.value();
if (flowOp.getIsPacketFlow()) return false;

// Duplicate BD ID on the same tile? Can't fold.
AMDAIE::TileOp tileOp = currBdIdPair.first.first;
uint32_t currBdIdVal = currBdIdPair.second;
bool isDuplicateBdId = llvm::any_of(dmaBdIdsMap, [&](const auto &entry) {
return entry.first.first == tileOp && entry.second.contains(currBdIdVal);
});
if (isDuplicateBdId) return false;

// Can fold.
return true;
}
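
As a side note, a minimal standalone sketch of the duplicate-BD-ID test used above, with std::string standing in for the tile and connection ops and std::map/std::set for the LLVM containers; it flags a BD ID that is already tracked for the same tile under any connection.

#include <cstdint>
#include <map>
#include <set>
#include <string>
#include <utility>

// Hypothetical stand-ins for TileOp and ConnectionOp.
using Tile = std::string;
using Connection = std::string;
using Key = std::pair<Tile, Connection>;  // mirrors DmaBdIdKey

// Returns true when `bdId` is already recorded for `tile` under any
// connection, mirroring the llvm::any_of check in canFoldByBatch.
bool isDuplicateBdId(const std::map<Key, std::set<uint32_t>> &bdIdsMap,
                     const Tile &tile, uint32_t bdId) {
  for (const auto &[key, ids] : bdIdsMap) {
    if (key.first == tile && ids.count(bdId)) return true;
  }
  return false;
}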

/// Traverses the control code in reverse, ensuring that only one DMA wait op is
/// retained for every batch of DMA copy operations.
///
/// Example Input:
/// %0 = dma_cpy_nd(connection0)
/// dma_wait(%0)
/// %1 = dma_cpy_nd(connection1)
/// %2 = dma_cpy_nd(connection2)
/// %3 = dma_cpy_nd(connection3)
/// dma_wait(%1)
/// dma_wait(%2)
/// dma_wait(%3)
/// Example Output:
/// %0 = dma_cpy_nd(connection0)
/// %1 = dma_cpy_nd(connection1)
/// %2 = dma_cpy_nd(connection2)
/// %3 = dma_cpy_nd(connection3)
/// dma_wait(%0, %1, %2, %3)
/// Reverse traversal simplifies handling duplicate connections, preventing
/// the need to revisit and modify earlier operations after processing later
/// ones.
LogicalResult foldDmaWaitsByBatch(AMDAIE::ControlCodeOp controlCodeOp) {
IRRewriter rewriter(controlCodeOp->getContext());
SmallVector<AMDAIE::NpuDmaWaitOp> waitOps;
DenseSet<AMDAIE::ConnectionOp> connectionOps;
DenseMap<DmaBdIdKey, DenseSet<uint32_t>> dmaBdIdsMap;

auto updateWithCurrBdId =
[&](bool canFold, DmaBdIdPair currBdIdPair,
DenseSet<AMDAIE::ConnectionOp> &connectionOps,
DenseMap<DmaBdIdKey, DenseSet<uint32_t>> &dmaBdIdsMap) {
DmaBdIdKey currBdIdKey = currBdIdPair.first;
uint32_t currBdIdVal = currBdIdPair.second;
if (!canFold) {
// Clear the BD IDs for all the connections in the batch.
for (auto &entry : dmaBdIdsMap) {
ConnectionOp connectionOp = entry.first.second;
DenseSet<uint32_t> &bdIds = entry.second;
if (connectionOps.contains(connectionOp)) bdIds.clear();
}
connectionOps.clear();
}
connectionOps.insert(currBdIdKey.second);
dmaBdIdsMap[currBdIdKey].insert(currBdIdVal);
};

// Traverse the control code in reverse.
WalkResult res = controlCodeOp->walk<WalkOrder::PostOrder, ReverseIterator>(
[&](AMDAIE::NpuDmaWaitOp waitOp) {
bool toBatch = true;
Operation *batchParentOp =
waitOps.empty() ? waitOp->getParentOp() : waitOps[0]->getParentOp();
for (Value token : waitOp.getAsyncTokens()) {
if (auto npuHalfDmaCpyNdOp =
dyn_cast_if_present<AMDAIE::NpuHalfDmaCpyNdOp>(
token.getDefiningOp())) {
// Retrieve the TileOp, ConnectionOp, and BD ID.
FailureOr<DmaBdIdPair> currBdIdPair =
retrieveDmaBdIdPair(npuHalfDmaCpyNdOp);
if (failed(currBdIdPair)) return WalkResult::interrupt();
// Check if the current DMA wait op can be folded into the batch.
FailureOr<bool> canFold =
canFoldByBatch(batchParentOp, connectionOps, dmaBdIdsMap,
npuHalfDmaCpyNdOp, *currBdIdPair);
if (failed(canFold)) return WalkResult::interrupt();
// Update the `connectionOps` and `dmaBdIdsMap`.
updateWithCurrBdId(*canFold, *currBdIdPair, connectionOps,
dmaBdIdsMap);
toBatch &= *canFold;
}
}
// Process the previous batch of wait ops, and start a new batch.
if (!toBatch) {
// Since the control code is traversed in reverse order, we need to
// restore the original order of the DMA operations.
std::reverse(waitOps.begin(), waitOps.end());
if (failed(eraseBatchOperations(rewriter, waitOps)))
return WalkResult::interrupt();
waitOps.clear();
}
waitOps.push_back(waitOp);
return WalkResult::advance();
});

if (res.wasInterrupted()) return failure();
// Process the remaining wait ops.
std::reverse(waitOps.begin(), waitOps.end());
if (failed(eraseBatchOperations(rewriter, waitOps))) return failure();
return success();
}
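
One more worked example of where a batch boundary falls (same simplified notation as the comment above; illustrative only, assuming circuit flows and no reused BD IDs): when a connection appears a second time, its wait cannot join a batch that already contains that connection, so only the later waits are merged.

Example Input:
  %0 = dma_cpy_nd(connection0)
  dma_wait(%0)
  %1 = dma_cpy_nd(connection1)
  dma_wait(%1)
  %2 = dma_cpy_nd(connection0)
  dma_wait(%2)
Example Output:
  %0 = dma_cpy_nd(connection0)
  dma_wait(%0)
  %1 = dma_cpy_nd(connection1)
  %2 = dma_cpy_nd(connection0)
  dma_wait(%1, %2)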

class AMDAIEFoldDmaWaitsPass
: public impl::AMDAIEFoldDmaWaitsBase<AMDAIEFoldDmaWaitsPass> {
public:
@@ -181,7 +371,10 @@ void AMDAIEFoldDmaWaitsPass::runOnOperation() {

WalkResult res = parentOp->walk([&](AMDAIE::WorkgroupOp workgroupOp) {
AMDAIE::ControlCodeOp controlCodeOp = workgroupOp.getControlCode();
if (failed(foldDmaWaits(deviceModel, controlCodeOp))) {
if (failed(foldDmaWaitsByQueue(deviceModel, controlCodeOp))) {
return WalkResult::interrupt();
}
if (failed(foldDmaWaitsByBatch(controlCodeOp))) {
return WalkResult::interrupt();
}
return WalkResult::advance();