[AMDAIEFoldDmaWaits] Fold DMA wait operations across multi columns #986

Merged: 11 commits, Dec 18, 2024
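This PR teaches the AMDAIEFoldDmaWaits pass to fold DMA wait operations across multiple columns: TCT sync operations that share a row, channel, and direction and target consecutive columns are merged into a single sync carrying a column count, and wait ops on distinct connections are batched into a single `amdaie.npu.dma_wait` op.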
@@ -200,16 +200,46 @@ LogicalResult convertOp(AMDAIE::NpuAddressPatchOp op,
 }
 
 LogicalResult convertOp(AMDAIE::NpuDmaWaitOp op, TransactionBuilder &builder) {
-  for (Value token : op.getAsyncTokens()) {
-    auto pushToQueueOp =
-        dyn_cast_if_present<AMDAIE::NpuPushToQueueOp>(token.getDefiningOp());
+  // Collect all half DMA ops from the async tokens.
+  SmallVector<AMDAIE::NpuPushToQueueOp> pushToQueueOps;
+  for (Value asyncToken : op.getAsyncTokens()) {
+    auto pushToQueueOp = dyn_cast_if_present<AMDAIE::NpuPushToQueueOp>(
+        asyncToken.getDefiningOp());
     if (!pushToQueueOp) {
       return op.emitOpError()
-             << "should operate on an `amdaie.push_to_queue` op";
+             << "should operate on an `amdaie.push_to_queue` op async token";
     }
+    pushToQueueOps.push_back(pushToQueueOp);
+  }
+  // Sort the half DMA ops by channel, direction, row, and column.
+  std::sort(pushToQueueOps.begin(), pushToQueueOps.end(),
+            [](AMDAIE::NpuPushToQueueOp a, AMDAIE::NpuPushToQueueOp b) {
+              return std::make_tuple(a.getChannel(), a.getDirection(),
+                                     a.getRow(), a.getCol()) <
+                     std::make_tuple(b.getChannel(), b.getDirection(),
+                                     b.getRow(), b.getCol());
+            });
+  // Batch DMA operations with the same row, channel, and direction into a
+  // single TCT sync operation, as long as they have consecutive columns.
+  llvm::MapVector<AMDAIE::NpuPushToQueueOp, uint32_t> columnBatches;
+  for (auto pushToQueueOp : pushToQueueOps) {
+    if (!columnBatches.empty()) {
+      auto &[lastPushOp, lastColNum] = columnBatches.back();
+      if (lastPushOp.getRow() == pushToQueueOp.getRow() &&
+          lastPushOp.getCol() + lastColNum == pushToQueueOp.getCol() &&
+          lastPushOp.getDirection() == pushToQueueOp.getDirection() &&
+          lastPushOp.getChannel() == pushToQueueOp.getChannel()) {
+        ++lastColNum;
+        continue;
+      }
+    }
+    columnBatches.insert({pushToQueueOp, 1});
+  }
+  // Convert to TCT sync ops.
+  for (auto &[pushToQueueOp, colNum] : columnBatches) {
     if (failed(builder.appendTCTSync(
             pushToQueueOp.getCol(), pushToQueueOp.getRow(),
-            static_cast<uint32_t>(pushToQueueOp.getDirection()), 1, 1,
+            static_cast<uint32_t>(pushToQueueOp.getDirection()), 1, colNum,
             pushToQueueOp.getChannel()))) {
       return failure();
     }
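The batching logic above is worth seeing in isolation: after sorting by (channel, direction, row, column), any run of ops that differ only by consecutive column values collapses into a single entry with a column count, which then becomes one TCT sync instead of several. Below is a minimal standalone sketch of that coalescing, with a plain struct standing in for NpuPushToQueueOp and a vector of pairs standing in for llvm::MapVector; all names are illustrative.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <tuple>
#include <utility>
#include <vector>

// Stand-in for NpuPushToQueueOp: just the fields the batching logic reads.
struct PushToQueue {
  uint32_t row, col, channel, direction;
};

// Coalesce ops that share row/channel/direction and occupy consecutive
// columns into (op, columnCount) batches, mirroring columnBatches above.
std::vector<std::pair<PushToQueue, uint32_t>> batchByColumn(
    std::vector<PushToQueue> ops) {
  std::sort(ops.begin(), ops.end(),
            [](const PushToQueue &a, const PushToQueue &b) {
              return std::make_tuple(a.channel, a.direction, a.row, a.col) <
                     std::make_tuple(b.channel, b.direction, b.row, b.col);
            });
  std::vector<std::pair<PushToQueue, uint32_t>> batches;
  for (const PushToQueue &op : ops) {
    if (!batches.empty()) {
      auto &[last, colNum] = batches.back();
      if (last.row == op.row && last.col + colNum == op.col &&
          last.direction == op.direction && last.channel == op.channel) {
        ++colNum;  // Extend the batch by one more consecutive column.
        continue;
      }
    }
    batches.push_back({op, 1});  // Start a new batch at this op.
  }
  return batches;
}

int main() {
  // Columns 0-3, same row/channel/direction: should fold into one batch.
  std::vector<PushToQueue> ops = {
      {0, 2, 0, 0}, {0, 0, 0, 0}, {0, 3, 0, 0}, {0, 1, 0, 0}};
  for (const auto &[op, colNum] : batchByColumn(ops))
    std::cout << "col=" << op.col << " colNum=" << colNum << "\n";
  // Prints: col=0 colNum=4
}

Forwarding the accumulated count as the column-count argument of appendTCTSync is what lets the four separate waits in this example collapse into one sync spanning columns 0 through 3.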
@@ -16,13 +16,14 @@ namespace mlir::iree_compiler::AMDAIE {

 namespace {
 
+using DmaBdIdKey = std::pair<AMDAIE::TileOp, AMDAIE::ConnectionOp>;
+
 /// Utility function to determine whether a DMA wait op can be folded based on
 /// its half DMA copy operation.
-FailureOr<bool> canFoldBasedOnHalfDmaCpy(
+FailureOr<bool> canFoldByQueue(
     const AMDAIE::AMDAIEDeviceModel &deviceModel,
     AMDAIE::NpuHalfDmaCpyNdOp &npuHalfDmaCpyNdOp,
-    DenseMap<std::pair<AMDAIE::TileOp, AMDAIE::ConnectionOp>,
-             SmallVector<uint32_t>> &tileConnectToBdIdQueue) {
+    DenseMap<DmaBdIdKey, SmallVector<uint32_t>> &tileConnectToBdIdQueue) {
   // Retrieve the connection op.
   std::optional<AMDAIE::ConnectionOp> maybeConnectionOp =
       npuHalfDmaCpyNdOp.getConnectionOp();
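The DmaBdIdKey alias introduced here names the (tile, connection) pair that keys the per-queue BD ID state, replacing the spelled-out std::pair in two signatures. A tiny sketch of the keying scheme, assuming only that BD IDs are tracked independently per (tile, connection) pair; string handles and std::map stand in for the op handles and DenseMap, and all names are illustrative.

#include <cstdint>
#include <map>
#include <string>
#include <utility>
#include <vector>

using Tile = std::string;
using Connection = std::string;
using DmaBdIdKey = std::pair<Tile, Connection>;

int main() {
  // One BD ID queue per (tile, connection), as in tileConnectToBdIdQueue.
  std::map<DmaBdIdKey, std::vector<uint32_t>> tileConnectToBdIdQueue;
  tileConnectToBdIdQueue[{"tile(0,0)", "connection0"}].push_back(0);
  tileConnectToBdIdQueue[{"tile(0,0)", "connection0"}].push_back(1);
  // A different connection on the same tile gets its own queue, so its
  // BD IDs are tracked independently of connection0's.
  tileConnectToBdIdQueue[{"tile(0,0)", "connection1"}].push_back(0);
}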
@@ -101,13 +102,11 @@ FailureOr<bool> canFoldBasedOnHalfDmaCpy(
 /// Reverse traversal simplifies handling duplicate BD IDs, preventing
 /// the need to revisit and modify earlier operations after processing later
 /// ones.
-LogicalResult foldDmaWaits(const AMDAIE::AMDAIEDeviceModel &deviceModel,
-                           AMDAIE::ControlCodeOp controlCodeOp) {
+LogicalResult foldDmaWaitsByQueue(const AMDAIE::AMDAIEDeviceModel &deviceModel,
+                                  AMDAIE::ControlCodeOp controlCodeOp) {
   IRRewriter rewriter(controlCodeOp->getContext());
   std::vector<AMDAIE::NpuDmaWaitOp> waitOpsToErase;
-  DenseMap<std::pair<AMDAIE::TileOp, AMDAIE::ConnectionOp>,
-           SmallVector<uint32_t>>
-      tileConnectToBdIdQueue;
+  DenseMap<DmaBdIdKey, SmallVector<uint32_t>> tileConnectToBdIdQueue;
   // Traverse the control code in reverse.
   WalkResult res = controlCodeOp->walk<WalkOrder::PostOrder, ReverseIterator>(
       [&](AMDAIE::NpuDmaWaitOp waitOp) {
@@ -116,7 +115,7 @@ LogicalResult foldDmaWaits(const AMDAIE::AMDAIEDeviceModel &deviceModel,
         if (auto npuHalfDmaCpyNdOp =
                 dyn_cast_if_present<AMDAIE::NpuHalfDmaCpyNdOp>(
                     token.getDefiningOp())) {
-          FailureOr<bool> result = canFoldBasedOnHalfDmaCpy(
+          FailureOr<bool> result = canFoldByQueue(
               deviceModel, npuHalfDmaCpyNdOp, tileConnectToBdIdQueue);
           if (failed(result)) return WalkResult::interrupt();
           toErase &= *result;
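Both folds rely on the same traversal trick: walking the control code post-order in reverse means that by the time a wait op is visited, every later wait that could subsume it has already been seen, so no earlier operation ever needs to be revisited. A toy standalone sketch of that pattern, keeping only the last wait per key while scanning back-to-front; the real pass folds only when BD ID reuse allows it, which this sketch deliberately ignores.

#include <iostream>
#include <set>
#include <string>
#include <vector>

int main() {
  // Waits keyed by queue: only the last wait per key needs to be kept.
  std::vector<std::string> waits = {"q0", "q1", "q0", "q1", "q0"};
  std::set<std::string> seen;
  std::vector<bool> keep(waits.size());
  // Scan in reverse: the first time a key appears, it is the last wait on
  // that key in program order, so it is kept; earlier ones fold away.
  for (size_t i = waits.size(); i-- > 0;) {
    keep[i] = seen.insert(waits[i]).second;
  }
  for (size_t i = 0; i < waits.size(); ++i)
    std::cout << waits[i] << (keep[i] ? " keep\n" : " fold\n");
}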
@@ -152,6 +151,179 @@ LogicalResult foldDmaWaits(const AMDAIE::AMDAIEDeviceModel &deviceModel,
   return success();
 }
 
+/// For each batch, combine the async tokens into a single NpuDmaWaitOp.
+LogicalResult eraseBatchOperations(IRRewriter &rewriter,
+                                   SmallVector<AMDAIE::NpuDmaWaitOp> &waitOps) {
+  // Skip if there are fewer than two DMA wait operations.
+  if (waitOps.size() < 2) return success();

+  SmallVector<Value> asyncTokens;
+  Operation *parentOp = waitOps[0]->getParentOp();
+  for (AMDAIE::NpuDmaWaitOp waitOp : waitOps) {
+    if (waitOp->getParentOp() != parentOp) {
+      return waitOp.emitError(
+          "DMA operations to be batched must belong to the same scope");
+    }
+    asyncTokens.append(waitOp.getAsyncTokens().begin(),
+                       waitOp.getAsyncTokens().end());
+  }
+
+  rewriter.setInsertionPointAfter(waitOps.back());
+  rewriter.create<AMDAIE::NpuDmaWaitOp>(waitOps.back().getLoc(), asyncTokens);
+  for (AMDAIE::NpuDmaWaitOp waitOp : waitOps) rewriter.eraseOp(waitOp);
+  return success();
+}
+
+/// Utility function to determine if a DMA wait operation can be folded into
+/// a batch based on its half DMA copy operation.
+FailureOr<bool> canFoldByBatch(
+    const Operation *batchParentOp,
+    const DenseSet<AMDAIE::ConnectionOp> &connectionOps,
+    const DenseMap<DmaBdIdKey, DenseSet<uint32_t>> &dmaBdIdsMap,
+    DmaBdIdKey &currBdIdKey, uint32_t &currBdIdVal,
+    AMDAIE::NpuHalfDmaCpyNdOp currHalfDmaCpyNdOp) {
+  // Check if the current operation is in the same scope as the rest of the
+  // batch.
+  bool isSameScope = currHalfDmaCpyNdOp->getParentOp() == batchParentOp;
+
+  // Retrieve the connection op.
+  std::optional<AMDAIE::ConnectionOp> maybeConnectionOp =
+      currHalfDmaCpyNdOp.getConnectionOp();
+  if (!maybeConnectionOp) {
+    return currHalfDmaCpyNdOp.emitOpError()
+           << "expected to operate on an `amdaie.connection`";
+  }
+  AMDAIE::ConnectionOp connectionOp = maybeConnectionOp.value();
+  bool isDuplicateConnection = connectionOps.contains(connectionOp);
+
+  // Retrieve the flow op.
+  std::optional<AMDAIE::FlowOp> maybeFlowOp = connectionOp.getFlowOp();
+  if (!maybeFlowOp) {
+    return connectionOp.emitOpError()
+           << "expected to operate on an `amdaie.flow`";
+  }
+  AMDAIE::FlowOp flowOp = maybeFlowOp.value();
+  bool isPacketFlow = flowOp.getIsPacketFlow();
+
+  // Retrieve the BD ID op.
+  std::optional<AMDAIE::BdIdOp> maybeBdIdOp = currHalfDmaCpyNdOp.getBdIdOp();
+  if (!maybeBdIdOp) {
+    return currHalfDmaCpyNdOp.emitOpError()
+           << "must have a BD ID op to lower to "
+              "`amdaie.npu.write_bd`";
+  }
+  AMDAIE::BdIdOp bdIdOp = maybeBdIdOp.value();
+  currBdIdVal = getConstantIndexOrAssert(bdIdOp.getValue());
+
+  // Retrieve the tile op.
+  AMDAIE::TileOp tileOp =
+      dyn_cast_if_present<AMDAIE::TileOp>(bdIdOp.getTile().getDefiningOp());
+  if (!tileOp) {
+    return bdIdOp.emitOpError() << "must operate on an `amdaie.tile`";
+  }
+  currBdIdKey = {tileOp, connectionOp};
+
+  bool isDuplicateBdId = llvm::any_of(dmaBdIdsMap, [&](const auto &entry) {
+    return entry.first.first == tileOp && entry.second.contains(currBdIdVal);
+  });
+
+  // Can't fold wait op if:
+  // (1) the current connection op already occurs in the batch, or
+  // (2) the current BD ID on the same tile already occurs in the batch, or
+  // (3) the current operation is a packet flow, or
+  // (4) the batch is empty, or
+  // (5) the current operation is not in the same scope as the batch.
+  return !(isDuplicateConnection || isDuplicateBdId || isPacketFlow ||
+           connectionOps.empty() || !isSameScope);
+}
+
+/// Traverses the control code in reverse, ensuring that only one DMA wait op
+/// is retained for every batch of DMA copy operations.
+///
+/// Example Input:
+///   %0 = dma_cpy_nd(connection0)
+///   dma_wait(%0)
+///   %1 = dma_cpy_nd(connection1)
+///   %2 = dma_cpy_nd(connection2)
+///   %3 = dma_cpy_nd(connection3)
+///   dma_wait(%1)
+///   dma_wait(%2)
+///   dma_wait(%3)
+/// Example Output:
+///   %0 = dma_cpy_nd(connection0)
+///   %1 = dma_cpy_nd(connection1)
+///   %2 = dma_cpy_nd(connection2)
+///   %3 = dma_cpy_nd(connection3)
+///   dma_wait(%0, %1, %2, %3)
+/// Reverse traversal simplifies handling duplicate connections, preventing
+/// the need to revisit and modify earlier operations after processing later
+/// ones.
+LogicalResult foldDmaWaitsByBatch(AMDAIE::ControlCodeOp controlCodeOp) {
+  IRRewriter rewriter(controlCodeOp->getContext());
+  SmallVector<AMDAIE::NpuDmaWaitOp> waitOps;
+  DenseSet<AMDAIE::ConnectionOp> connectionOps;
+  DenseMap<DmaBdIdKey, DenseSet<uint32_t>> dmaBdIdsMap;
+
+  auto updateWithCurrBdId =
+      [&](bool canFold, DenseSet<AMDAIE::ConnectionOp> &connectionOps,
+          DenseMap<DmaBdIdKey, DenseSet<uint32_t>> &dmaBdIdsMap,
+          DmaBdIdKey &currBdIdKey, uint32_t currBdIdVal) {
+        assert(currBdIdKey.first && "TileOp must not be null");
+        assert(currBdIdKey.second && "ConnectionOp must not be null");
+        if (!canFold) {
+          // Clear the BD IDs for all the connections in the batch.
+          for (auto &entry : dmaBdIdsMap) {
+            ConnectionOp connectionOp = entry.first.second;
+            DenseSet<uint32_t> &bdIds = entry.second;
+            if (connectionOps.contains(connectionOp)) bdIds.clear();
+          }
+          connectionOps.clear();
+        }
+        connectionOps.insert(currBdIdKey.second);
+        dmaBdIdsMap[currBdIdKey].insert(currBdIdVal);
+      };
+
+  // Traverse the control code in reverse.
+  WalkResult res = controlCodeOp->walk<WalkOrder::PostOrder, ReverseIterator>(
+      [&](AMDAIE::NpuDmaWaitOp waitOp) {
+        bool toBatch = true;
+        Operation *batchParentOp =
+            waitOps.empty() ? waitOp->getParentOp() : waitOps[0]->getParentOp();
+        for (Value token : waitOp.getAsyncTokens()) {
+          if (auto npuHalfDmaCpyNdOp =
+                  dyn_cast_if_present<AMDAIE::NpuHalfDmaCpyNdOp>(
+                      token.getDefiningOp())) {
+            DmaBdIdKey currBdIdKey = {nullptr, nullptr};
+            uint32_t currBdIdVal = 0;
+            FailureOr<bool> result =
+                canFoldByBatch(batchParentOp, connectionOps, dmaBdIdsMap,
+                               currBdIdKey, currBdIdVal, npuHalfDmaCpyNdOp);
+            if (failed(result)) return WalkResult::interrupt();
+            toBatch &= *result;
+            updateWithCurrBdId(*result, connectionOps, dmaBdIdsMap,
+                               currBdIdKey, currBdIdVal);
+          }
+        }
+        // Process the previous batch of wait ops, and start a new batch.
+        if (!toBatch) {
+          // Since the control code is traversed in reverse order, we need to
+          // restore the original order of the DMA operations.
+          std::reverse(waitOps.begin(), waitOps.end());
+          if (failed(eraseBatchOperations(rewriter, waitOps)))
+            return WalkResult::interrupt();
+          waitOps.clear();
+        }
+        waitOps.push_back(waitOp);
+        return WalkResult::advance();
+      });
+
+  if (res.wasInterrupted()) return failure();
+  // Process the remaining wait ops.
+  std::reverse(waitOps.begin(), waitOps.end());
+  if (failed(eraseBatchOperations(rewriter, waitOps))) return failure();
+  return success();
+}
+
 class AMDAIEFoldDmaWaitsPass
     : public impl::AMDAIEFoldDmaWaitsBase<AMDAIEFoldDmaWaitsPass> {
  public:
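canFoldByBatch grows a batch only while each new half-DMA op brings a previously unseen connection and a previously unseen BD ID on its tile, stays in the batch's scope, and rides a circuit (non-packet) flow; any violation flushes the current batch and starts a new one. A condensed standalone model of that accept/flush loop, with plain values in place of ops, the scope check omitted, and packet flows reduced to a boolean field; all names are illustrative.

#include <cstdint>
#include <iostream>
#include <set>
#include <string>
#include <utility>
#include <vector>

struct HalfDma {
  std::string connection;
  std::string tile;
  uint32_t bdId;
  bool isPacketFlow;
};

int main() {
  std::vector<HalfDma> dmas = {
      {"connection0", "tile(0,0)", 0, false},
      {"connection1", "tile(1,0)", 0, false},  // new connection: batches.
      {"connection0", "tile(0,0)", 0, false},  // duplicate: flushes batch.
  };
  std::set<std::string> connections;
  std::set<std::pair<std::string, uint32_t>> bdIds;  // (tile, BD ID)
  size_t batchSize = 0;
  for (const HalfDma &dma : dmas) {
    // Mirrors the five fold conditions: unseen connection, unseen BD ID on
    // this tile, circuit flow, and a non-empty batch to join.
    bool canFold = !connections.count(dma.connection) &&
                   !bdIds.count({dma.tile, dma.bdId}) && !dma.isPacketFlow &&
                   !connections.empty();
    if (!canFold) {
      if (batchSize > 1) std::cout << "fold " << batchSize << " waits\n";
      connections.clear();
      bdIds.clear();
      batchSize = 0;
    }
    connections.insert(dma.connection);
    bdIds.insert({dma.tile, dma.bdId});
    ++batchSize;
  }
  if (batchSize > 1) std::cout << "fold " << batchSize << " waits\n";
  // Prints: fold 2 waits
}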
@@ -181,7 +353,10 @@ void AMDAIEFoldDmaWaitsPass::runOnOperation() {

   WalkResult res = parentOp->walk([&](AMDAIE::WorkgroupOp workgroupOp) {
     AMDAIE::ControlCodeOp controlCodeOp = workgroupOp.getControlCode();
-    if (failed(foldDmaWaits(deviceModel, controlCodeOp))) {
+    if (failed(foldDmaWaitsByQueue(deviceModel, controlCodeOp))) {
       return WalkResult::interrupt();
     }
+    if (failed(foldDmaWaitsByBatch(controlCodeOp))) {
+      return WalkResult::interrupt();
+    }
     return WalkResult::advance();