#2377: Added support for collective permute operation with all supported multi device configurations
tapspatel committed Mar 8, 2025
1 parent 1bb1440 commit 1702f76
Showing 23 changed files with 1,168 additions and 24 deletions.
17 changes: 17 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
@@ -2028,6 +2028,23 @@ def TTIR_ReduceScatterOp : TTIR_DPSOp<"reduce_scatter"> {
  let hasVerifier = 1;
}

def TTIR_CollectivePermuteOp : TTIR_DPSOp<"collective_permute"> {
  let summary = "Collective permute operation.";
  let description = [{
    Collective permute op: for each (source, target) entry in
    `source_target_pairs`, the input shard on the source device is sent to
    the target device.
  }];

  let arguments = (ins AnyRankedTensor:$input,
                       AnyRankedTensor:$output,
                       I64ElementsAttr:$source_target_pairs);

  let results = (outs AnyRankedTensor:$result);
  let extraClassDeclaration = [{
    MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
  }];
  let hasVerifier = 1;
}
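For intuition, here is a host-side sketch of the op's semantics. This is my illustration, not part of the commit; it assumes the XLA/StableHLO convention that a device not named as a target in any pair receives zeros.

#include <cstdint>
#include <vector>

// Hypothetical host-side model of collective_permute: shards[d] is the input
// shard on device d, flatPairs is the flattened source_target_pairs
// attribute [src0, tgt0, src1, tgt1, ...].
std::vector<std::vector<float>>
simulateCollectivePermute(const std::vector<std::vector<float>> &shards,
                          const std::vector<int64_t> &flatPairs) {
  // Devices that are not a target of any pair keep zeros (XLA convention).
  std::vector<std::vector<float>> out(
      shards.size(), std::vector<float>(shards.front().size(), 0.0f));
  for (size_t i = 0; i + 1 < flatPairs.size(); i += 2) {
    out[flatPairs[i + 1]] = shards[flatPairs[i]]; // target gets source's shard
  }
  return out;
}

With shards for devices {0, 1} and flatPairs = {0, 1, 1, 0}, the two devices simply swap their shards.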

def TTIR_MeshShardOp : TTIR_NamedOp<"mesh_shard"> {
  let summary = "Mesh shard operation.";
  let description = [{
15 changes: 15 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -1502,6 +1502,21 @@ def TTNN_AllReduceOp: TTNN_Op<"all_reduce"> {
  let hasVerifier = 1;
}

def TTNN_CollectivePermuteOp: TTNN_Op<"collective_permute"> {
  let summary = "Collective permute op.";
  let description = [{
    Tensor collective permute operation: routes each device's input shard to
    a target device according to `source_target_pairs`.
  }];

  let arguments = (ins AnyRankedTensor:$input,
                       TT_Device:$device,
                       I64ElementsAttr:$source_target_pairs);

  let results = (outs AnyRankedTensor:$result);

  let hasVerifier = 1;
}

def TTNN_MeshShardOp: TTNN_Op<"mesh_shard"> {
  let summary = "Mesh shard op.";
  let description = [{
8 changes: 8 additions & 0 deletions include/ttmlir/Target/TTNN/program.fbs
@@ -398,6 +398,13 @@ table ReduceScatterOp {
  num_links: uint32;
}

table CollectivePermuteOp {
  in: tt.target.ttnn.TensorRef;
  out: tt.target.ttnn.TensorRef;
  device: tt.target.DeviceRef;
  source_target_pairs: [int64];
}

table MeshShardOp {
  in: tt.target.ttnn.TensorRef;
  out: tt.target.ttnn.TensorRef;
@@ -475,6 +482,7 @@ union OpType {
  AllGatherOp,
  ReduceScatterOp,
  MeshShardOp,
  CollectivePermuteOp,
  ArangeOp,
  UpdateCacheOp,
  FillCacheOp,
39 changes: 39 additions & 0 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
@@ -1740,6 +1740,43 @@ class StableHLOToTTIRAllGatherOpConversionPattern
};
} // namespace

namespace {
class StableHLOToTTIRCollectivePermuteOpConversionPattern
    : public OpConversionPattern<mlir::stablehlo::CollectivePermuteOp> {
  using OpConversionPattern<
      mlir::stablehlo::CollectivePermuteOp>::OpConversionPattern;

public:
  LogicalResult
  matchAndRewrite(mlir::stablehlo::CollectivePermuteOp srcOp,
                  mlir::stablehlo::CollectivePermuteOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Create the output tensor type based on inputs.
    auto outputType = mlir::cast<RankedTensorType>(
        getTypeConverter()->convertType(srcOp.getResult().getType()));

    if (auto srcChannelHandleAttr = adaptor.getChannelHandleAttr()) {
      // channelType is supposed to be DEVICE_TO_DEVICE or Invalid for CCL ops.
      // Currently, we only ensure that it is DEVICE_TO_DEVICE communication.
      // Consider preserving this information in the future if the attribute
      // holds non-DEVICE_TO_DEVICE values.
      auto channelType =
          static_cast<StableHLOChannelType>(srcChannelHandleAttr.getType());
      if (channelType != StableHLOChannelType::kChannelTypeDeviceToDevice &&
          channelType != StableHLOChannelType::kChannelTypeInvalid) {
        return failure();
      }
    }

    ttmlir::utils::replaceOpWithNewDPSOp<mlir::tt::ttir::CollectivePermuteOp>(
        rewriter, srcOp, outputType, adaptor.getOperand(),
        adaptor.getSourceTargetPairs());

    return success();
  }
};
} // namespace

namespace {
class StableHLOToTTIRCustomCallOpConversionPattern
    : public OpConversionPattern<mlir::stablehlo::CustomCallOp> {
@@ -2413,6 +2450,8 @@ static void addCCLOpsConversionPattern(MLIRContext *ctx,
  patterns.add<StableHLOToTTIRAllGatherOpConversionPattern>(typeConverter, ctx);
  patterns.add<StableHLOToTTIRReduceScatterOpConversionPattern>(typeConverter,
                                                                ctx);
  patterns.add<StableHLOToTTIRCollectivePermuteOpConversionPattern>(
      typeConverter, ctx);
  patterns.add<StableHLOToTTIRCustomCallOpConversionPattern>(typeConverter,
                                                             ctx);
}
23 changes: 23 additions & 0 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
@@ -1425,6 +1425,28 @@ class AllGatherOpConversionPattern
        op, this->getTypeConverter()->convertType(op.getType()),
        adaptor.getInput(), device, adaptor.getAllGatherDim(),
        static_cast<uint32_t>(adaptor.getClusterAxis()));

    return success();
  }
};
} // namespace

namespace {
class CollectivePermuteOpConversionPattern
    : public OpConversionPattern<ttir::CollectivePermuteOp> {
public:
  using OpConversionPattern<ttir::CollectivePermuteOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ttir::CollectivePermuteOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto device = ::ttnn::utils::getOrInsertDevice(rewriter, op);

    rewriter.replaceOpWithNewOp<ttnn::CollectivePermuteOp>(
        op, this->getTypeConverter()->convertType(op.getType()),
        adaptor.getInput(), device, adaptor.getSourceTargetPairs());

    return success();
  }
};
@@ -1616,6 +1638,7 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
           AllReduceOpConversionPattern,
           AllGatherOpConversionPattern,
           ReduceScatterOpConversionPattern,
           CollectivePermuteOpConversionPattern,
           ArangeOpConversionPattern,
           UpdateCacheOpConversionPattern,
           FillCacheOpConversionPattern,
2 changes: 2 additions & 0 deletions lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
@@ -1215,6 +1215,8 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
                                                                   ctx);
  patterns.add<DefaultOpConversionPattern<tt::ttnn::ReduceScatterOp>>(
      typeConverter, ctx);
  patterns.add<DefaultOpConversionPattern<tt::ttnn::CollectivePermuteOp>>(
      typeConverter, ctx);
  patterns.add<DefaultOpConversionPattern<tt::ttnn::MeshShardOp>>(typeConverter,
                                                                  ctx);
49 changes: 49 additions & 0 deletions lib/Dialect/TTIR/IR/TTIROps.cpp
@@ -2412,6 +2412,55 @@ ::mlir::LogicalResult mlir::tt::ttir::ReduceScatterOp::verify() {
  return success();
}

//===----------------------------------------------------------------------===//
// CollectivePermuteOp
//===----------------------------------------------------------------------===//

// CollectivePermuteOp verification
::mlir::LogicalResult mlir::tt::ttir::CollectivePermuteOp::verify() {
  auto sourceTargetPairs = getSourceTargetPairs().getValues<int64_t>();

  // Check that sourceTargetPairs is a rank-2 (2D) attribute.
  llvm::ArrayRef<int64_t> sourceTargetPairsShape =
      getSourceTargetPairs().getType().getShape();
  const size_t sourceTargetPairsRank = sourceTargetPairsShape.size();

  if (sourceTargetPairsRank != 2) {
    return emitOpError("The rank of source target pairs must be 2, got rank = ")
           << sourceTargetPairsRank;
  }

  /* Check that the 'src' values and 'dest' values in sourceTargetPairs are
     unique. Given a rank-2 tensor of source target pairs, e.g.
     [['src', 'target'], ['src', 'target'], ...], we need to ensure that each
     'src' is unique and each 'target' is unique. */
  auto areElementsUnique = [](const auto &sourceTargetPairs) -> bool {
    for (size_t i = 0; i < sourceTargetPairs.size(); i++) {
      int count = 0;
      // Use int64_t to avoid narrowing the attribute's element type.
      int64_t target = sourceTargetPairs[i];
      // Stepping by 2 from index i stays within i's column (sources or
      // targets), so a repeated value in that column yields count > 1.
      for (size_t j = i; j < sourceTargetPairs.size(); j += 2) {
        if (sourceTargetPairs[j] == target) {
          count++;
        }
      }

      if (count != 1) {
        return false;
      }
    }

    return true;
  };

  if (!areElementsUnique(sourceTargetPairs)) {
    return emitOpError(
        "There are duplicate 'src' or 'dest' devices in source target pairs");
  }

  return success();
}
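To see why the stride-2 scan works, here is a standalone copy of the check with a few test cases; the function and test values are mine, for illustration only.

#include <cassert>
#include <cstdint>
#include <vector>

// Mirrors the verifier's lambda on a flattened [src0, tgt0, src1, tgt1, ...]
// list: stepping by 2 from index i stays in i's column, so any repeated
// source or repeated target makes some count exceed 1.
static bool areElementsUnique(const std::vector<int64_t> &pairs) {
  for (size_t i = 0; i < pairs.size(); i++) {
    int count = 0;
    int64_t value = pairs[i];
    for (size_t j = i; j < pairs.size(); j += 2) {
      if (pairs[j] == value) {
        count++;
      }
    }
    if (count != 1) {
      return false;
    }
  }
  return true;
}

int main() {
  assert(areElementsUnique({0, 1, 1, 0}));  // {{0,1},{1,0}}: a valid swap
  assert(!areElementsUnique({0, 1, 0, 2})); // device 0 is a source twice
  assert(!areElementsUnique({0, 2, 1, 2})); // device 2 is a target twice
}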

//===----------------------------------------------------------------------===//
// MeshShardOp
//===----------------------------------------------------------------------===//
49 changes: 49 additions & 0 deletions lib/Dialect/TTNN/IR/TTNNOps.cpp
@@ -1437,6 +1437,55 @@ ::mlir::LogicalResult AllReduceOp::verify() {
  return success();
}

//===----------------------------------------------------------------------===//
// CollectivePermuteOp
//===----------------------------------------------------------------------===//

// CollectivePermuteOp verification
::mlir::LogicalResult CollectivePermuteOp::verify() {
  auto sourceTargetPairs = getSourceTargetPairs().getValues<int64_t>();

  // Check that sourceTargetPairs is a rank-2 (2D) attribute.
  llvm::ArrayRef<int64_t> sourceTargetPairsShape =
      getSourceTargetPairs().getType().getShape();
  const size_t sourceTargetPairsRank = sourceTargetPairsShape.size();

  if (sourceTargetPairsRank != 2) {
    return emitOpError("The rank of source target pairs must be 2, got rank = ")
           << sourceTargetPairsRank;
  }

  /* Check that the 'src' values and 'dest' values in sourceTargetPairs are
     unique. Given a rank-2 tensor of source target pairs, e.g.
     [['src', 'target'], ['src', 'target'], ...], we need to ensure that each
     'src' is unique and each 'target' is unique. */
  auto areElementsUnique = [](const auto &sourceTargetPairs) -> bool {
    for (size_t i = 0; i < sourceTargetPairs.size(); i++) {
      int count = 0;
      // Use int64_t to avoid narrowing the attribute's element type.
      int64_t target = sourceTargetPairs[i];
      // Stepping by 2 from index i stays within i's column (sources or
      // targets), so a repeated value in that column yields count > 1.
      for (size_t j = i; j < sourceTargetPairs.size(); j += 2) {
        if (sourceTargetPairs[j] == target) {
          count++;
        }
      }

      if (count != 1) {
        return false;
      }
    }

    return true;
  };

  if (!areElementsUnique(sourceTargetPairs)) {
    return emitOpError(
        "There are duplicate 'src' or 'dest' devices in source target pairs");
  }

  return success();
}

//===----------------------------------------------------------------------===//
// MeshShardOp
//===----------------------------------------------------------------------===//
20 changes: 20 additions & 0 deletions lib/Target/TTNN/TTNNToFlatbuffer.cpp
@@ -796,6 +796,21 @@ createOp(FlatbufferObjectCache &cache, ReduceScatterOp op) {
      op.getClusterAxis(), op.getNumLinks());
}

::flatbuffers::Offset<::tt::target::ttnn::CollectivePermuteOp>
createOp(FlatbufferObjectCache &cache, CollectivePermuteOp op) {
  auto input = cache.at<::tt::target::ttnn::TensorRef>(
      getOperandThroughDPSOps(op.getInput()));
  auto output = cache.getOrCreate(op.getResult(), tensorValueToFlatbuffer,
                                  kHostAllocatedSize);
  auto device = getOperandThroughDPSOps(op.getDevice());
  auto sourceTargetPairs = op.getSourceTargetPairs().getValues<int64_t>();
  std::vector<int64_t> sourceTargetPairsVec(sourceTargetPairs.begin(),
                                            sourceTargetPairs.end());
  return ::tt::target::ttnn::CreateCollectivePermuteOp(
      *cache.fbb, input, output, cache.at<::tt::target::DeviceRef>(device),
      cache.fbb->CreateVector<int64_t>(sourceTargetPairsVec));
}
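The flatbuffer table stores the pairs flattened in row-major order. A consumer can rebuild the 2D pairs like this; the helper below is a hypothetical illustration, not part of the commit.

#include <cstdint>
#include <utility>
#include <vector>

// Rebuild (source, target) pairs from the flat [int64] vector serialized by
// createOp, e.g. {0, 1, 1, 0} -> {{0, 1}, {1, 0}}.
std::vector<std::pair<int64_t, int64_t>>
unflattenPairs(const std::vector<int64_t> &flat) {
  std::vector<std::pair<int64_t, int64_t>> pairs;
  pairs.reserve(flat.size() / 2);
  for (size_t i = 0; i + 1 < flat.size(); i += 2) {
    pairs.emplace_back(flat[i], flat[i + 1]);
  }
  return pairs;
}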

::flatbuffers::Offset<::tt::target::ttnn::MeshShardOp>
createOp(FlatbufferObjectCache &cache, MeshShardOp op) {
  auto input = cache.at<::tt::target::ttnn::TensorRef>(
@@ -1689,6 +1704,11 @@ emitTTNNOperation(FlatbufferObjectCache &cache, Operation *op,
    return createOperation(cache, createOp(cache, reduceScatterOp), debugString,
                           locInfo);
  }
  if (auto collectivePermuteOp = dyn_cast<CollectivePermuteOp>(op);
      collectivePermuteOp) {
    return createOperation(cache, createOp(cache, collectivePermuteOp),
                           debugString, locInfo);
  }
  if (auto meshShardOp = dyn_cast<MeshShardOp>(op); meshShardOp) {
    return createOperation(cache, createOp(cache, meshShardOp), debugString,
                           locInfo);
1 change: 1 addition & 0 deletions runtime/lib/ttnn/operations/CMakeLists.txt
@@ -10,6 +10,7 @@ set(TTNN_OPS_SRCS
  ${CMAKE_CURRENT_SOURCE_DIR}/include/tt/runtime/ttnn/operations/eltwise/ternary/utils.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/ccl/all_gather.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/ccl/reduce_scatter.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/ccl/collective_permute.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/ccl/mesh_shard.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/conv/conv2d.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/creation/arange.cpp