Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AssignPacketIds] Prioritize ID assignment for control packets #1092

Merged
merged 3 commits into from
Feb 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions compiler/plugins/target/AMD-AIE/iree-amd-aie/IR/AMDAIEOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,52 @@ std::optional<FlowOp> ConnectionOp::getFlowOp() {
// AMDAIE_FlowOp
//===----------------------------------------------------------------------===//

FailureOr<AMDAIE::ChannelOp> FlowOp::getSourceChannelOp() {
SmallVector<Value> sourceChannels = getSources();
int numSources = sourceChannels.size();
if (numSources == 0)
return emitOpError() << "with no source channel is unsupported";
if (numSources > 1)
return emitOpError() << "with multiple source channels is unsupported";
auto sourceChannelOp =
dyn_cast_if_present<AMDAIE::ChannelOp>(sourceChannels[0].getDefiningOp());
if (!sourceChannelOp)
return emitOpError() << "source should be an `amdaie.channel` op";
return sourceChannelOp;
}

FailureOr<SmallVector<AMDAIE::ChannelOp>> FlowOp::getTargetChannelOps() {
SmallVector<Value> targetChannels = getTargets();
SmallVector<AMDAIE::ChannelOp> targetChannelOps;
if (targetChannels.size() == 0)
return emitOpError() << "with no target channel is unsupported";
for (Value targetChannel : targetChannels) {
auto targetChannelOp =
dyn_cast_if_present<AMDAIE::ChannelOp>(targetChannel.getDefiningOp());
if (!targetChannelOp)
return emitOpError() << "target should be an `amdaie.channel` op";
targetChannelOps.push_back(targetChannelOp);
}
return targetChannelOps;
}

FailureOr<bool> FlowOp::isControlFlow() {
// Fetch source channel.
auto maybeSourceChannelOp = getSourceChannelOp();
if (failed(maybeSourceChannelOp)) return failure();
AMDAIE::ChannelOp sourceChannelOp = *maybeSourceChannelOp;
// Fetch target channels.
auto maybeTargetChannelOps = getTargetChannelOps();
if (failed(maybeTargetChannelOps)) return failure();
// Check source port type first.
if (sourceChannelOp.getPortType() == StrmSwPortType::CTRL) return true;
// Check if any target port type is `CTRL`.
return llvm::any_of(
*maybeTargetChannelOps, [](AMDAIE::ChannelOp targetChannelOp) {
return targetChannelOp.getPortType() == StrmSwPortType::CTRL;
});
}

LogicalResult FlowOp::verify() {
if (getSources().size() > 1 && getTargets().size() > 1) {
return emitOpError()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,12 @@ def AMDAIE_FlowOp: AMDAIE_Op<"flow", [AttrSizedOperandSegments]>,

let assemblyFormat = [{ `(` `{` $sources `}` `->` `{` $targets `}` `)` attr-dict }];
let hasVerifier = 1;

let extraClassDeclaration = [{
FailureOr<AMDAIE::ChannelOp> getSourceChannelOp();
FailureOr<SmallVector<AMDAIE::ChannelOp>> getTargetChannelOps();
FailureOr<bool> isControlFlow();
}];
}

def AMDAIE_TileOp: AMDAIE_Op<"tile", [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "iree-amd-aie/Transforms/Passes.h"
#include "iree-amd-aie/Transforms/Transforms.h"
#include "iree-amd-aie/Transforms/Utils/AMDAIEUtils.h"
#include "llvm/ADT/STLExtras.h"

#define DEBUG_TYPE "iree-amdaie-assign-packet-ids"

Expand Down Expand Up @@ -41,45 +42,50 @@ void AMDAIEAssignPacketIdsPass::runOnOperation() {
auto ui8ty =
IntegerType::get(rewriter.getContext(), 8, IntegerType::Unsigned);

// Collect all packet flow operations and categorize them into control and
// normal flows. Control packet flows will be prioritized for packet ID
// assignment.
SmallVector<AMDAIE::FlowOp> ctrlPktFlowOps;
SmallVector<AMDAIE::FlowOp> dataPktFlowOps;
WalkResult res = parentOp->walk([&](AMDAIE::FlowOp flowOp) {
if (!flowOp.getIsPacketFlow()) return WalkResult::advance();
FailureOr<bool> maybeIsControlFlow = flowOp.isControlFlow();
if (failed(maybeIsControlFlow)) return WalkResult::interrupt();
if (*maybeIsControlFlow) {
ctrlPktFlowOps.push_back(flowOp);
} else {
dataPktFlowOps.push_back(flowOp);
}
return WalkResult::advance();
});
if (res.wasInterrupted()) return signalPassFailure();
SmallVector<AMDAIE::FlowOp> allPktFlowOps = std::move(ctrlPktFlowOps);
allPktFlowOps.append(std::make_move_iterator(dataPktFlowOps.begin()),
std::make_move_iterator(dataPktFlowOps.end()));

// Perform assignment of packet IDs based on the source channels of the flow
// ops. I.e. `amdaie.flow` ops with the same source channel will get a
// different packet IDs assigned to accommodate multiple data packets being
// routed through the same ports.
DenseMap<AMDAIE::ChannelOp, size_t> channelToPktFlowIndex;
WalkResult res =
parentOp->walk([&](AMDAIE::FlowOp flowOp) {
if (!flowOp.getIsPacketFlow()) return WalkResult::advance();
SmallVector<Value> sourceChannels = flowOp.getSources();
if (sourceChannels.size() == 0) {
flowOp.emitOpError() << "with no source channel is unsupported";
return WalkResult::interrupt();
}
if (sourceChannels.size() > 1) {
flowOp.emitOpError()
<< "with multiple source channels is unsupported";
return WalkResult::interrupt();
}
auto sourceChannelOp = dyn_cast_if_present<AMDAIE::ChannelOp>(
sourceChannels[0].getDefiningOp());
if (!sourceChannelOp) {
flowOp.emitOpError() << "source should be an `amdaie.channel` op";
return WalkResult::interrupt();
}
size_t pktFlowIndex = channelToPktFlowIndex[sourceChannelOp];
if (pktFlowIndex > deviceModel.getPacketIdMaxIdx()) {
flowOp.emitOpError()
<< "ran out of packet IDs to assign for source channel";
return WalkResult::interrupt();
}
IntegerAttr pktIdAttr = IntegerAttr::get(ui8ty, pktFlowIndex);
rewriter.setInsertionPoint(flowOp);
rewriter.replaceOpWithNewOp<AMDAIE::FlowOp>(
flowOp, flowOp.getSources(), flowOp.getTargets(),
flowOp.getIsPacketFlow(), pktIdAttr);
channelToPktFlowIndex[sourceChannelOp]++;
return WalkResult::advance();
});
if (res.wasInterrupted()) return signalPassFailure();
for (AMDAIE::FlowOp flowOp : allPktFlowOps) {
FailureOr<AMDAIE::ChannelOp> maybeSourceChannelOp =
flowOp.getSourceChannelOp();
if (failed(maybeSourceChannelOp)) return signalPassFailure();
AMDAIE::ChannelOp sourceChannelOp = *maybeSourceChannelOp;
size_t pktFlowIndex = channelToPktFlowIndex[sourceChannelOp];
if (pktFlowIndex > deviceModel.getPacketIdMaxIdx()) {
flowOp.emitOpError()
<< "ran out of packet IDs to assign for source channel";
return signalPassFailure();
}
IntegerAttr pktIdAttr = IntegerAttr::get(ui8ty, pktFlowIndex);
rewriter.setInsertionPoint(flowOp);
rewriter.replaceOpWithNewOp<AMDAIE::FlowOp>(
flowOp, flowOp.getSources(), flowOp.getTargets(),
flowOp.getIsPacketFlow(), pktIdAttr);
channelToPktFlowIndex[sourceChannelOp]++;
}
}

} // namespace
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,3 +212,55 @@ module attributes {hal.executable.target = #executable_target_amdaie_xclbin_fb}
return
}
}


// -----

// Test that control packets should take priority (by getting `packet_id=0`) in the ID assignment.
// CHECK-LABEL: @assign_ctrl_packet_ids_in_priority
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[C2:.*]] = arith.constant 2 : index
// CHECK: amdaie.workgroup
// CHECK: %[[TILE_0_0:.*]] = amdaie.tile(%[[C0]], %[[C0]])
// CHECK: %[[TILE_0_1:.*]] = amdaie.tile(%[[C0]], %[[C1]])
// CHECK: %[[TILE_1_0:.*]] = amdaie.tile(%[[C1]], %[[C0]])
// CHECK: %[[TILE_1_2:.*]] = amdaie.tile(%[[C1]], %[[C2]])
// CHECK: %[[CHANNEL:.*]] = amdaie.channel(%[[TILE_0_0]], 0, port_type = DMA, direction = MM2S)
// CHECK: %[[CHANNEL_1:.*]] = amdaie.channel(%[[TILE_0_1]], 0, port_type = DMA, direction = S2MM)
// CHECK: %[[CHANNEL_2:.*]] = amdaie.channel(%[[TILE_0_1]], 0, port_type = CTRL, direction = S2MM)
// CHECK: %[[CHANNEL_3:.*]] = amdaie.channel(%[[TILE_1_0]], 1, port_type = DMA, direction = MM2S)
// CHECK: %[[CHANNEL_4:.*]] = amdaie.channel(%[[TILE_1_2]], 1, port_type = DMA, direction = S2MM)
// CHECK: %[[CHANNEL_5:.*]] = amdaie.channel(%[[TILE_1_2]], 0, port_type = CTRL, direction = S2MM)
// CHECK: amdaie.flow({%[[CHANNEL]]} -> {%[[CHANNEL_1]]}) {is_packet_flow = true, packet_id = 1 : ui8}
// CHECK: amdaie.flow({%[[CHANNEL]]} -> {%[[CHANNEL_2]]}) {is_packet_flow = true, packet_id = 0 : ui8}
// CHECK: amdaie.flow({%[[CHANNEL_3]]} -> {%[[CHANNEL_4]]}) {is_packet_flow = true, packet_id = 1 : ui8}
// CHECK: amdaie.flow({%[[CHANNEL_3]]} -> {%[[CHANNEL_5]]}) {is_packet_flow = true, packet_id = 0 : ui8}
#executable_target_amdaie_xclbin_fb = #hal.executable.target<"amd-aie", "amdaie-xclbin-fb", {target_device = "npu1_4col", ukernels = "none"}>
module attributes {hal.executable.target = #executable_target_amdaie_xclbin_fb} {
func.func @assign_ctrl_packet_ids_in_priority() {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
amdaie.workgroup {
%tile_0_0 = amdaie.tile(%c0, %c0)
%tile_0_1 = amdaie.tile(%c0, %c1)
%tile_1_0 = amdaie.tile(%c1, %c0)
%tile_1_2 = amdaie.tile(%c1, %c2)
%channel = amdaie.channel(%tile_0_0, 0, port_type = DMA, direction = MM2S)
%channel_1 = amdaie.channel(%tile_0_1, 0, port_type = DMA, direction = S2MM)
%channel_2 = amdaie.channel(%tile_0_1, 0, port_type = CTRL, direction = S2MM)
%channel_3 = amdaie.channel(%tile_1_0, 1, port_type = DMA, direction = MM2S)
%channel_4 = amdaie.channel(%tile_1_2, 1, port_type = DMA, direction = S2MM)
%channel_5 = amdaie.channel(%tile_1_2, 0, port_type = CTRL, direction = S2MM)
%0 = amdaie.flow({%channel} -> {%channel_1}) {is_packet_flow = true}
%1 = amdaie.flow({%channel} -> {%channel_2}) {is_packet_flow = true}
%2 = amdaie.flow({%channel_3} -> {%channel_4}) {is_packet_flow = true}
%3 = amdaie.flow({%channel_3} -> {%channel_5}) {is_packet_flow = true}
amdaie.controlcode {
amdaie.end
}
}
return
}
}
Loading