Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AssignChannels] Prioritize channel assignment for control packets #1106

Merged
merged 3 commits into from
Feb 15, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,62 @@ namespace mlir::iree_compiler::AMDAIE {

namespace {

/// Initializes channel generators for tiles by detecting DMA channels
/// previously assigned by other passes (e.g., for control packets) and
/// registering them to prevent conflicts.
LogicalResult initializeChannelsGenerators(
AMDAIE::WorkgroupOp workgroupOp, const AMDAIEDeviceModel &deviceModel,
DenseMap<Value, ChannelGenerator> &tileToGeneratorMap) {
// Get the number of producer and consumer channels for each tile.
workgroupOp.walk([&](AMDAIE::TileOp tileOp) {
uint32_t col = getConstantIndexOrAssert(tileOp.getCol());
uint32_t row = getConstantIndexOrAssert(tileOp.getRow());
AMDAIETileType tileType = deviceModel.getTileType(col, row);
uint8_t numDmaChannels =
deviceModel.getDmaProp<uint8_t>(tileType, AMDAIEDmaProp::NumChannels);
tileToGeneratorMap[tileOp.getResult()] =
ChannelGenerator(numDmaChannels, numDmaChannels);
});

WalkResult res = workgroupOp.walk([&](AMDAIE::ConnectionOp connectionOp) {
ChannelAssignmentMode mode =
(connectionOp.getConnectionType() == AMDAIE::ConnectionType::Packet)
? ChannelAssignmentMode::RoundRobinPacketFlow
: ChannelAssignmentMode::FirstAvailableCircuitFlow;
// Check source DMA channels previously assigned by other passes,
// and register them in `ChannelGenerator` using `assignProducerDMAChannel`.
for (Value source : connectionOp.getSourceChannels()) {
auto channelOp = dyn_cast<AMDAIE::ChannelOp>(source.getDefiningOp());
if (!channelOp) {
connectionOp.emitOpError() << "expected a `amdaie.channel` op source";
return WalkResult::interrupt();
}
if (channelOp.getPortType() == StrmSwPortType::DMA) {
Value tile = channelOp.getTileOp().getResult();
tileToGeneratorMap[tile].assignProducerDMAChannel(channelOp.getValue(),
mode);
}
}
// Check target DMA channels previously assigned by other passes,
// and register them in `ChannelGenerator` using `assignConsumerDMAChannel`.
for (Value target : connectionOp.getTargetChannels()) {
auto channelOp = dyn_cast<AMDAIE::ChannelOp>(target.getDefiningOp());
if (!channelOp) {
connectionOp.emitOpError() << "expected a `amdaie.channel` op target";
return WalkResult::interrupt();
}
if (channelOp.getPortType() == StrmSwPortType::DMA) {
Value tile = channelOp.getTileOp().getResult();
tileToGeneratorMap[tile].assignConsumerDMAChannel(channelOp.getValue(),
mode);
}
}
return WalkResult::advance();
});
if (res.wasInterrupted()) return failure();
return success();
}

/// Assign channels to `amdaie.connection` ops.
LogicalResult assignChannels(AMDAIE::WorkgroupOp workgroupOp) {
IRRewriter rewriter(workgroupOp->getContext());
Expand All @@ -27,19 +83,13 @@ LogicalResult assignChannels(AMDAIE::WorkgroupOp workgroupOp) {
<< "could not find an AMDAIEDevice attribute";
}
AMDAIEDeviceModel deviceModel = AMDAIE::getDeviceModel(device.value());

// Get the number of producer and consumer channels for each tile.
// Initialize channel generators for tiles.
DenseMap<Value, ChannelGenerator> tileToGeneratorMap;
workgroupOp.walk([&](AMDAIE::TileOp tileOp) {
uint32_t col = getConstantIndexOrAssert(tileOp.getCol());
uint32_t row = getConstantIndexOrAssert(tileOp.getRow());
AMDAIETileType tileType = deviceModel.getTileType(col, row);
uint8_t numDmaChannels =
deviceModel.getDmaProp<uint8_t>(tileType, AMDAIEDmaProp::NumChannels);
tileToGeneratorMap[tileOp.getResult()] =
ChannelGenerator(numDmaChannels, numDmaChannels);
});

if (failed(initializeChannelsGenerators(workgroupOp, deviceModel,
tileToGeneratorMap))) {
return failure();
}
// Get all `amdaie.connection` ops.
SmallVector<AMDAIE::ConnectionOp> connectionOps;
workgroupOp->walk([&](AMDAIE::ConnectionOp connectionOp) {
connectionOps.push_back(connectionOp);
Expand All @@ -59,48 +109,49 @@ LogicalResult assignChannels(AMDAIE::WorkgroupOp workgroupOp) {
return connectionOp.emitOpError()
<< "expected a `LogicalObjFifoOpInterface` target";
}
std::optional<AMDAIE::ConnectionType> connectionType =
connectionOp.getConnectionType();
bool isPacketFlow = connectionType && connectionType.value() ==
AMDAIE::ConnectionType::Packet;

ChannelAssignmentMode mode =
(connectionOp.getConnectionType() == AMDAIE::ConnectionType::Packet)
? ChannelAssignmentMode::RoundRobinPacketFlow
: ChannelAssignmentMode::FirstAvailableCircuitFlow;
rewriter.setInsertionPoint(connectionOp);
SmallVector<Value> sourceChannels;
for (Value tile : sourceLogicalObjFifo.getTiles()) {
assert(tileToGeneratorMap.contains(tile) &&
"no channel generator found for tile");
std::optional<uint8_t> maybeChannel =
tileToGeneratorMap[tile].getProducerDMAChannel();
if (!maybeChannel) {
return connectionOp.emitOpError()
<< "no producer DMA channel available";
SmallVector<Value> sourceChannels = connectionOp.getSourceChannels();
// Assign source (producer) DMA channels if not already assigned.
if (sourceChannels.empty()) {
for (Value tile : sourceLogicalObjFifo.getTiles()) {
assert(tileToGeneratorMap.contains(tile) &&
"no channel generator found for tile");
std::optional<uint8_t> maybeChannel =
tileToGeneratorMap[tile].getAndAssignProducerDMAChannel(mode);
if (!maybeChannel) {
return connectionOp.emitOpError()
<< "no producer DMA channel available";
}
auto channelOp = rewriter.create<AMDAIE::ChannelOp>(
rewriter.getUnknownLoc(), tile, maybeChannel.value(),
StrmSwPortType::DMA, AMDAIE::DMAChannelDir::MM2S);
sourceChannels.push_back(channelOp.getResult());
}
// Only assign the channel if it is for circuit flow.
if (!isPacketFlow)
tileToGeneratorMap[tile].assignProducerDMAChannel(maybeChannel.value());
auto channelOp = rewriter.create<AMDAIE::ChannelOp>(
rewriter.getUnknownLoc(), tile, maybeChannel.value(),
StrmSwPortType::DMA, AMDAIE::DMAChannelDir::MM2S);
sourceChannels.push_back(channelOp.getResult());
}
SmallVector<Value> targetChannels;
for (Value tile : targetLogicalObjFifo.getTiles()) {
assert(tileToGeneratorMap.contains(tile) &&
"no channel generator found for tile");
std::optional<uint8_t> maybeChannel =
tileToGeneratorMap[tile].getConsumerDMAChannel();
if (!maybeChannel) {
return connectionOp.emitOpError()
<< "no consumer DMA channel available";
// Assign target (consumer) DMA channels if not already assigned.
SmallVector<Value> targetChannels = connectionOp.getTargetChannels();
if (targetChannels.empty()) {
for (Value tile : targetLogicalObjFifo.getTiles()) {
assert(tileToGeneratorMap.contains(tile) &&
"no channel generator found for tile");
std::optional<uint8_t> maybeChannel =
tileToGeneratorMap[tile].getAndAssignConsumerDMAChannel(mode);
if (!maybeChannel) {
return connectionOp.emitOpError()
<< "no consumer DMA channel available";
}
auto channelOp = rewriter.create<AMDAIE::ChannelOp>(
rewriter.getUnknownLoc(), tile, maybeChannel.value(),
StrmSwPortType::DMA, AMDAIE::DMAChannelDir::S2MM);
targetChannels.push_back(channelOp.getResult());
}
// Only assign the channel if it is for circuit flow.
if (!isPacketFlow)
tileToGeneratorMap[tile].assignConsumerDMAChannel(maybeChannel.value());
auto channelOp = rewriter.create<AMDAIE::ChannelOp>(
rewriter.getUnknownLoc(), tile, maybeChannel.value(),
StrmSwPortType::DMA, AMDAIE::DMAChannelDir::S2MM);
targetChannels.push_back(channelOp.getResult());
}
// Replace the `amdaie.connection` op with newly assigned `sourceChannels`
// and `targetChannels`.
rewriter.replaceOpWithNewOp<AMDAIE::ConnectionOp>(
connectionOp, connectionOp.getTarget(), targetChannels,
connectionOp.getSource(), sourceChannels,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@ namespace mlir::iree_compiler::AMDAIE {

namespace {

/// Initializes the channel generators for the shim tiles, excluding any
/// channels that are already in use by existing circuit-mode connections.
/// Initializes channel generators for shim tiles, ensuring that no shim DMA
/// MM2S channels have been assigned before. This guarantees priority for the
/// control overlay.
LogicalResult initializeChannelsGenerators(
AMDAIE::WorkgroupOp workgroupOp, const AMDAIEDeviceModel &deviceModel,
const DenseSet<TileOp> &shimTileOps,
Expand All @@ -29,40 +30,19 @@ LogicalResult initializeChannelsGenerators(
shimTileToGeneratorMap[shimTileOp.getResult()] =
ChannelGenerator(numShimDmaChannels, numShimDmaChannels);
});
// Exclude those channels that are already used by a circuit-mode connection.
workgroupOp->walk([&](AMDAIE::ConnectionOp connectionOp) {
std::optional<AMDAIE::ConnectionType> connectionType =
connectionOp.getConnectionType();
bool isPacketFlow = connectionType && connectionType.value() ==
AMDAIE::ConnectionType::Packet;
if (isPacketFlow) return WalkResult::advance();
SmallVector<AMDAIE::ChannelOp> sourceChannels;
for (Value source : connectionOp.getSourceChannels()) {
if (auto channelOp =
dyn_cast<AMDAIE::ChannelOp>(source.getDefiningOp())) {
sourceChannels.push_back(channelOp);
}
}
for (AMDAIE::ChannelOp channelOp : sourceChannels) {
AMDAIE::TileOp tileOp = channelOp.getTileOp();
uint8_t channel = channelOp.getValue();
StrmSwPortType portType = channelOp.getPortType();
AMDAIE::DMAChannelDir direction = channelOp.getDirection();
if (shimTileOps.contains(tileOp) && portType == StrmSwPortType::DMA) {
// Assign to exclude.
if (direction == AMDAIE::DMAChannelDir::MM2S) {
shimTileToGeneratorMap[tileOp.getResult()].assignProducerDMAChannel(
channel);
} else if (direction == AMDAIE::DMAChannelDir::S2MM) {
shimTileToGeneratorMap[tileOp.getResult()].assignConsumerDMAChannel(
channel);
} else {
assert(false && "unexpected DMA channel direction");
}
}
// Ensure that shim DMA MM2S channels are not already assigned.
WalkResult res = workgroupOp->walk([&](AMDAIE::ChannelOp channelOp) {
if (shimTileOps.contains(channelOp.getTileOp()) &&
channelOp.getPortType() == StrmSwPortType::DMA &&
channelOp.getDirection() == AMDAIE::DMAChannelDir::MM2S) {
channelOp.emitOpError()
<< "shim DMA MM2S channel must remain unassigned before "
"control overlay generation.";
return WalkResult::interrupt();
}
return WalkResult::advance();
});
if (res.wasInterrupted()) return failure();
return success();
}

Expand Down Expand Up @@ -114,11 +94,12 @@ LogicalResult generateControlOverlay(AMDAIE::WorkgroupOp workgroupOp,
WalkResult res = workgroupOp->walk([&](AMDAIE::TileOp tileOp) {
uint32_t col = getConstantIndexOrAssert(tileOp.getCol());
TileOp shimTileOp = columnToShimTile[col];
// Get the available channel, but do not assign it. Allow it to be
// shared across multiple packet-mode connections as needed.
// Get the available DMA channel for the shim tile, and assign it for the
// packet flow.
std::optional<uint8_t> maybeChannel =
shimTileToGeneratorMap[shimTileOp.getResult()]
.getProducerDMAChannel();
.getAndAssignProducerDMAChannel(
ChannelAssignmentMode::RoundRobinPacketFlow);
if (!maybeChannel) {
shimTileOp.emitOpError() << "no producer DMA channel available";
return WalkResult::interrupt();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -859,6 +859,8 @@ void addAMDAIEObjectFifoLoweringPasses(
passManager.addPass(createCanonicalizerPass());
passManager.addPass(createAMDAIEDmaCSEPass());

passManager.addPass(createAMDAIEGenerateControlOverlayPass());

passManager.addPass(createAMDAIEAssignChannelsPass());
passManager.addPass(createCSEPass());
passManager.addPass(createCanonicalizerPass());
Expand All @@ -881,8 +883,6 @@ void addAMDAIEObjectFifoLoweringPasses(
passManager.addPass(createAMDAIEObjFifoBufferizationPass());
passManager.addPass(createAMDAIETemporaryAllocBufferizationPass());

passManager.addPass(createAMDAIEGenerateControlOverlayPass());

passManager.addPass(createAMDAIEConnectionToFlowPass());
passManager.addPass(createAMDAIEAssignPacketIdsPass());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,40 @@ module attributes {hal.executable.target = #executable_target_amdaie_xclbin_fb}
return
}
}

// -----

// For tile (0,0), its producer (MM2S) channel 0 is already assigned
// to a control packet flow. Therefore, channel 1 is used to connect to tile (0,1).
// CHECK-LABEL: @previously_assigned
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[C1:.+]] = arith.constant 1 : index
// CHECK: amdaie.workgroup
// CHECK: %[[tile_0_0:.+]] = amdaie.tile(%[[C0]], %[[C0]])
// CHECK: %[[tile_0_1:.+]] = amdaie.tile(%[[C0]], %[[C1]])
// CHECK: %[[CHANNEL_0:.+]] = amdaie.channel(%[[tile_0_0]], 1, port_type = DMA, direction = MM2S)
// CHECK: %[[CHANNEL_1:.+]] = amdaie.channel(%[[tile_0_1]], 0, port_type = DMA, direction = S2MM)
// CHECK: amdaie.connection(%{{.+}} {%[[CHANNEL_1]]}, %{{.+}} {%[[CHANNEL_0]]})
#executable_target_amdaie_xclbin_fb = #hal.executable.target<"amd-aie", "amdaie-xclbin-fb", {target_device = "npu1_4col", ukernels = "none"}>
module attributes {hal.executable.target = #executable_target_amdaie_xclbin_fb} {
func.func @previously_assigned(%arg0: memref<1x1x8x16xi32, 1>, %arg1: memref<8x16xi32>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
amdaie.workgroup {
%tile_0_0 = amdaie.tile(%c0, %c0)
%tile_0_1 = amdaie.tile(%c0, %c1)
%0 = amdaie.logicalobjectfifo.from_memref %arg0, {%tile_0_1} : memref<1x1x8x16xi32, 1> -> !amdaie.logicalobjectfifo<memref<1x1x8x16xi32, 1>>
%1 = amdaie.logicalobjectfifo.from_memref %arg1, {%tile_0_0} : memref<8x16xi32> -> !amdaie.logicalobjectfifo<memref<8x16xi32>>
%2 = amdaie.connection(%0, %1) : (!amdaie.logicalobjectfifo<memref<1x1x8x16xi32, 1>>, !amdaie.logicalobjectfifo<memref<8x16xi32>>)
%channel = amdaie.channel(%tile_0_0, 0, port_type = DMA, direction = MM2S)
%channel_0 = amdaie.channel(%tile_0_0, 0, port_type = CTRL, direction = S2MM)
%3 = amdaie.logicalobjectfifo.placeholder{%tile_0_0} : !amdaie.logicalobjectfifo<memref<?xi32>>
%4 = amdaie.logicalobjectfifo.placeholder{%tile_0_0} : !amdaie.logicalobjectfifo<memref<?xi32>>
%5 = amdaie.connection(%4 {%channel_0}, %3 {%channel}) {connection_type = #amdaie<connection_type Packet>} : (!amdaie.logicalobjectfifo<memref<?xi32>>, !amdaie.logicalobjectfifo<memref<?xi32>>)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a test as well with a pre-existing circuit connection or update this one to include both?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added the test at:

func.func @previously_assigned_circuit(%arg0: memref<1x1x8x16xi32, 1>, %arg1: memref<8x16xi32>) {

amdaie.controlcode {
amdaie.end
}
}
return
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,28 +15,23 @@ module {

// -----

// Shim tile (0, 0) has two producer (MM2S) channels,
// both of which are already utilized by existing circuit-mode connections.
// No producer DMA channel is available for route-shim-to-tile-ctrl.
/// No shim DMA channel can be assigned before control overlay generation.
/// This ensures that control packets have priority in resource allocation
/// and makes control packet routing static.
#executable_target_amdaie_xclbin_fb = #hal.executable.target<"amd-aie", "amdaie-xclbin-fb", {target_device = "npu1_4col", ukernels = "none"}>
module attributes {hal.executable.target = #executable_target_amdaie_xclbin_fb} {
func.func @no_available_channel() {
func.func @priority_check(%arg0: memref<8x16xi32>, %arg1: memref<1x1x8x16xi32, 1>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
amdaie.workgroup {
// expected-error @+1 {{no producer DMA channel available}}
%tile_0_0 = amdaie.tile(%c0, %c0)
%tile_0_1 = amdaie.tile(%c0, %c1)
%0 = amdaie.logicalobjectfifo.placeholder{%tile_0_0} : !amdaie.logicalobjectfifo<memref<32xi32>>
%1 = amdaie.logicalobjectfifo.placeholder{%tile_0_1} : !amdaie.logicalobjectfifo<memref<32xi32>>
%2 = amdaie.logicalobjectfifo.placeholder{%tile_0_0} : !amdaie.logicalobjectfifo<memref<32xi32>>
%3 = amdaie.logicalobjectfifo.placeholder{%tile_0_1} : !amdaie.logicalobjectfifo<memref<32xi32>>
%0 = amdaie.logicalobjectfifo.from_memref %arg0, {%tile_0_0} : memref<8x16xi32> -> !amdaie.logicalobjectfifo<memref<8x16xi32>>
%1 = amdaie.logicalobjectfifo.from_memref %arg1, {%tile_0_1} : memref<1x1x8x16xi32, 1> -> !amdaie.logicalobjectfifo<memref<1x1x8x16xi32, 1>>
// expected-error @+1 {{shim DMA MM2S channel must remain unassigned before control overlay generation}}
%channel_0 = amdaie.channel(%tile_0_0, 0, port_type = DMA, direction = MM2S)
%channel_1 = amdaie.channel(%tile_0_1, 0, port_type = DMA, direction = S2MM)
%connection_0 = amdaie.connection(%1 {%channel_1}, %0 {%channel_0}) : (!amdaie.logicalobjectfifo<memref<32xi32>>, !amdaie.logicalobjectfifo<memref<32xi32>>)
%channel_2 = amdaie.channel(%tile_0_0, 1, port_type = DMA, direction = MM2S)
%channel_3 = amdaie.channel(%tile_0_1, 1, port_type = DMA, direction = S2MM)
%connection_1 = amdaie.connection(%3 {%channel_3}, %2 {%channel_2}) : (!amdaie.logicalobjectfifo<memref<32xi32>>, !amdaie.logicalobjectfifo<memref<32xi32>>)
%connection_0 = amdaie.connection(%0 {%channel_0}, %1 {%channel_1}) : (!amdaie.logicalobjectfifo<memref<8x16xi32>>, !amdaie.logicalobjectfifo<memref<1x1x8x16xi32, 1>>)
amdaie.controlcode {
amdaie.end
}
Expand Down
Loading
Loading