Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AMDAIEFoldDmaWaits] Fold DMA wait operations across multiple columns #986

Merged
merged 11 commits into from
Dec 18, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -200,16 +200,46 @@ LogicalResult convertOp(AMDAIE::NpuAddressPatchOp op,
}

LogicalResult convertOp(AMDAIE::NpuDmaWaitOp op, TransactionBuilder &builder) {
for (Value token : op.getAsyncTokens()) {
auto pushToQueueOp =
dyn_cast_if_present<AMDAIE::NpuPushToQueueOp>(token.getDefiningOp());
// Collect all half DMA ops from the async tokens.
SmallVector<AMDAIE::NpuPushToQueueOp> pushToQueueOps;
for (Value asyncToken : op.getAsyncTokens()) {
auto pushToQueueOp = dyn_cast_if_present<AMDAIE::NpuPushToQueueOp>(
asyncToken.getDefiningOp());
if (!pushToQueueOp) {
return op.emitOpError()
<< "should operate on an `amdaie.push_to_queue` op";
<< "should operate on an `amdaie.push_to_queue` op async token";
}
pushToQueueOps.push_back(pushToQueueOp);
}
// Sort the half DMA ops by channel, direction, row, and column.
std::sort(pushToQueueOps.begin(), pushToQueueOps.end(),
[](AMDAIE::NpuPushToQueueOp a, AMDAIE::NpuPushToQueueOp b) {
return std::make_tuple(a.getChannel(), a.getDirection(),
a.getRow(), a.getCol()) <
std::make_tuple(b.getChannel(), b.getDirection(),
b.getRow(), b.getCol());
});
// Batch DMA operations with the same row, channel, and direction into a
// single TCT sync operation, as long as they have consecutive columns.
llvm::MapVector<AMDAIE::NpuPushToQueueOp, uint32_t> columnBatches;
for (auto pushToQueueOp : pushToQueueOps) {
if (!columnBatches.empty()) {
auto &[lastPushOp, lastColNum] = columnBatches.back();
if (lastPushOp.getRow() == pushToQueueOp.getRow() &&
lastPushOp.getCol() + lastColNum == pushToQueueOp.getCol() &&
lastPushOp.getDirection() == pushToQueueOp.getDirection() &&
lastPushOp.getChannel() == pushToQueueOp.getChannel()) {
++lastColNum;
continue;
}
}
columnBatches.insert({pushToQueueOp, 1});
}
// Convert to TCT sync ops.
for (auto &[pushToQueueOp, colNum] : columnBatches) {
if (failed(builder.appendTCTSync(
pushToQueueOp.getCol(), pushToQueueOp.getRow(),
static_cast<uint32_t>(pushToQueueOp.getDirection()), 1, 1,
static_cast<uint32_t>(pushToQueueOp.getDirection()), 1, colNum,
pushToQueueOp.getChannel()))) {
return failure();
}
Expand Down
Loading
Loading