diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td index 5afbc2739c..f9eb6e4f5a 100644 --- a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td +++ b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td @@ -2028,6 +2028,23 @@ def TTIR_ReduceScatterOp : TTIR_DPSOp<"reduce_scatter"> { let hasVerifier = 1; } +def TTIR_CollectivePermuteOp : TTIR_DPSOp<"collective_permute"> { + let summary = "Collective permute operation."; + let description = [{ + Collective permute op. Sends the input tensor shard on each source device to the corresponding target device, as given by source_target_pairs. + }]; + + let arguments = (ins AnyRankedTensor:$input, + AnyRankedTensor:$output, + I64ElementsAttr:$source_target_pairs); + + let results = (outs AnyRankedTensor:$result); + let extraClassDeclaration = [{ + MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); } + }]; + let hasVerifier = 1; +} + def TTIR_MeshShardOp : TTIR_NamedOp<"mesh_shard"> { let summary = "Mesh shard operation."; let description = [{ diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td index 4b49e7a88a..792076f4c5 100644 --- a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td +++ b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td @@ -1502,6 +1502,21 @@ def TTNN_AllReduceOp: TTNN_Op<"all_reduce"> { let hasVerifier = 1; } +def TTNN_CollectivePermuteOp: TTNN_Op<"collective_permute"> { + let summary = "Collective permute op."; + let description = [{ + Tensor collective permute operation. Sends the input tensor shard on each source device to the corresponding target device, as given by source_target_pairs. + }]; + + let arguments = (ins AnyRankedTensor:$input, + TT_Device:$device, + I64ElementsAttr:$source_target_pairs); + + let results = (outs AnyRankedTensor:$result); + + let hasVerifier = 1; +} + def TTNN_MeshShardOp: TTNN_Op<"mesh_shard"> { let summary = "Mesh shard op."; let description = [{ diff --git a/include/ttmlir/Target/TTNN/program.fbs b/include/ttmlir/Target/TTNN/program.fbs index b618a59949..e2fbf48814 100644 --- a/include/ttmlir/Target/TTNN/program.fbs +++ b/include/ttmlir/Target/TTNN/program.fbs @@ -398,6 +398,13 @@ table ReduceScatterOp { num_links: uint32; } +table CollectivePermuteOp { + in: tt.target.ttnn.TensorRef; + out: tt.target.ttnn.TensorRef; + device: tt.target.DeviceRef; + source_target_pairs: [int64]; +} + table MeshShardOp { in: tt.target.ttnn.TensorRef; out: tt.target.ttnn.TensorRef; @@ -475,6 +482,7 @@ union OpType { AllGatherOp, ReduceScatterOp, MeshShardOp, + CollectivePermuteOp, ArangeOp, UpdateCacheOp, FillCacheOp, diff --git a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp index 3d790fbfa8..5aa0d91b8f 100644 --- a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp +++ b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp @@ -1740,6 +1740,43 @@ class StableHLOToTTIRAllGatherOpConversionPattern }; } // namespace +namespace { +class StableHLOToTTIRCollectivePermuteOpConversionPattern + : public OpConversionPattern<mlir::stablehlo::CollectivePermuteOp> { + using OpConversionPattern< + mlir::stablehlo::CollectivePermuteOp>::OpConversionPattern; + +public: + LogicalResult + matchAndRewrite(mlir::stablehlo::CollectivePermuteOp srcOp, + mlir::stablehlo::CollectivePermuteOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Create the output tensor type based on inputs + auto outputType = mlir::cast<mlir::RankedTensorType>( + getTypeConverter()->convertType(srcOp.getResult().getType())); + + if (auto srcChannelHandleAttr = adaptor.getChannelHandleAttr()) { + // channelType is supposed to be DEVICE_TO_DEVICE or Invalid for CCL ops. + // Currently, we only verify that it is DEVICE_TO_DEVICE communication.
+ // Consider preserving this information in the future if the attribute + // has a non-DEVICE_TO_DEVICE value. + auto channelType = + static_cast<StableHLOChannelType>(srcChannelHandleAttr.getType()); + if (channelType != StableHLOChannelType::kChannelTypeDeviceToDevice && + channelType != StableHLOChannelType::kChannelTypeInvalid) { + return failure(); + } + } + + ttmlir::utils::replaceOpWithNewDPSOp<mlir::tt::ttir::CollectivePermuteOp>( + rewriter, srcOp, outputType, adaptor.getOperand(), + adaptor.getSourceTargetPairs()); + + return success(); + } +}; +} // namespace + namespace { class StableHLOToTTIRCustomCallOpConversionPattern : public OpConversionPattern { @@ -2413,6 +2450,8 @@ static void addCCLOpsConversionPattern(MLIRContext *ctx, patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); + patterns.add<StableHLOToTTIRCollectivePermuteOpConversionPattern>( + typeConverter, ctx); patterns.add(typeConverter, ctx); } diff --git a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp index 673c347444..43a350618d 100644 --- a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp +++ b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp @@ -1425,6 +1425,28 @@ class AllGatherOpConversionPattern op, this->getTypeConverter()->convertType(op.getType()), adaptor.getInput(), device, adaptor.getAllGatherDim(), static_cast(adaptor.getClusterAxis())); + + return success(); + } +}; +} // namespace + +namespace { +class CollectivePermuteOpConversionPattern + : public OpConversionPattern<ttir::CollectivePermuteOp> { +public: + using OpConversionPattern<ttir::CollectivePermuteOp>::OpConversionPattern; + + LogicalResult + matchAndRewrite(ttir::CollectivePermuteOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + auto device = ::ttnn::utils::getOrInsertDevice(rewriter, op); + + rewriter.replaceOpWithNewOp<ttnn::CollectivePermuteOp>( + op, this->getTypeConverter()->convertType(op.getType()), + adaptor.getInput(), device, adaptor.getSourceTargetPairs()); + + return success(); + } }; } // namespace @@ -1616,6 +1638,7 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns, AllReduceOpConversionPattern, AllGatherOpConversionPattern, ReduceScatterOpConversionPattern, + CollectivePermuteOpConversionPattern, ArangeOpConversionPattern, UpdateCacheOpConversionPattern, FillCacheOpConversionPattern, diff --git a/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp b/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp index 2498eff1e6..2df8710c91 100644 --- a/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp +++ b/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp @@ -1215,6 +1215,8 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx, ctx); patterns.add>( typeConverter, ctx); + patterns.add<DefaultOpConversionPattern<ttnn::CollectivePermuteOp>>( + typeConverter, ctx); patterns.add>(typeConverter, ctx); diff --git a/lib/Dialect/TTIR/IR/TTIROps.cpp b/lib/Dialect/TTIR/IR/TTIROps.cpp index 061d199b8d..0d1ac8fb26 100644 --- a/lib/Dialect/TTIR/IR/TTIROps.cpp +++ b/lib/Dialect/TTIR/IR/TTIROps.cpp @@ -2412,6 +2412,55 @@ ::mlir::LogicalResult mlir::tt::ttir::ReduceScatterOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// CollectivePermuteOp +//===----------------------------------------------------------------------===// + +// CollectivePermuteOp verification +::mlir::LogicalResult mlir::tt::ttir::CollectivePermuteOp::verify() { + auto sourceTargetPairs = getSourceTargetPairs().getValues<int64_t>(); + + // Check that the rank of sourceTargetPairs is 2 + llvm::ArrayRef<int64_t> sourceTargetPairsShape = + getSourceTargetPairs().getType().getShape(); + const size_t sourceTargetPairsRank = sourceTargetPairsShape.size(); + + if (sourceTargetPairsRank != 2) { + return emitOpError("The rank
of source target pairs must be 2, got rank = ") << sourceTargetPairsRank; + } + + /* Check that the 'src' values and 'dest' values in sourceTargetPairs are + unique. Given a rank-2 tensor of source target pairs, e.g. [['src', 'target'], + ['src', 'target'] ...], we need to ensure that each 'src' is unique and each + 'target' is unique. + */ + auto areElementsUnique = [](const auto &sourceTargetPairs) -> bool { + for (size_t i = 0; i < sourceTargetPairs.size(); i++) { + int count = 0; + int target = sourceTargetPairs[i]; + for (size_t j = i; j < sourceTargetPairs.size(); j += 2) { + if (sourceTargetPairs[j] == target) { + count++; + } + } + + if (count != 1) { + return false; + } + } + + return true; + }; + + if (!areElementsUnique(sourceTargetPairs)) { + return emitOpError( + "There are duplicate 'src' or 'dest' devices in source target pairs"); + } + + return success(); +} + //===----------------------------------------------------------------------===// // MeshShardOp //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/TTNN/IR/TTNNOps.cpp b/lib/Dialect/TTNN/IR/TTNNOps.cpp index 5ac0d679a5..5b30e48de4 100644 --- a/lib/Dialect/TTNN/IR/TTNNOps.cpp +++ b/lib/Dialect/TTNN/IR/TTNNOps.cpp @@ -1437,6 +1437,55 @@ ::mlir::LogicalResult AllReduceOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// CollectivePermuteOp +//===----------------------------------------------------------------------===// + +// CollectivePermuteOp verification +::mlir::LogicalResult CollectivePermuteOp::verify() { + auto sourceTargetPairs = getSourceTargetPairs().getValues<int64_t>(); + + // Check that the rank of sourceTargetPairs is 2 + llvm::ArrayRef<int64_t> sourceTargetPairsShape = + getSourceTargetPairs().getType().getShape(); + const size_t sourceTargetPairsRank = sourceTargetPairsShape.size(); + + if (sourceTargetPairsRank != 2) { + return emitOpError("The rank of source target pairs must be 2, got rank = ") + << sourceTargetPairsRank; + } + + /* Check that the 'src' values and 'dest' values in sourceTargetPairs are + unique. Given a rank-2 tensor of source target pairs, e.g. [['src', 'target'], + ['src', 'target'] ...], we need to ensure that each 'src' is unique and each + 'target' is unique.
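+ For example, [[0, 1], [0, 2]] is invalid because source device 0 appears twice, and [[0, 2], [1, 2]] is invalid because target device 2 appears twice.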
+ */ + auto areElementsUnique = [](const auto &sourceTargetPairs) -> bool { + for (size_t i = 0; i < sourceTargetPairs.size(); i++) { + int count = 0; + int target = sourceTargetPairs[i]; + for (size_t j = i; j < sourceTargetPairs.size(); j += 2) { + if (sourceTargetPairs[j] == target) { + count++; + } + } + + if (count != 1) { + return false; + } + } + + return true; + }; + + if (!areElementsUnique(sourceTargetPairs)) { + return emitOpError( + "There are duplicate 'src' or 'dest' devices in source target pairs"); + } + + return success(); +} + //===----------------------------------------------------------------------===// // MeshShardOp //===----------------------------------------------------------------------===// diff --git a/lib/Target/TTNN/TTNNToFlatbuffer.cpp b/lib/Target/TTNN/TTNNToFlatbuffer.cpp index 6ee92ab4ad..b300e060f2 100644 --- a/lib/Target/TTNN/TTNNToFlatbuffer.cpp +++ b/lib/Target/TTNN/TTNNToFlatbuffer.cpp @@ -796,6 +796,21 @@ createOp(FlatbufferObjectCache &cache, ReduceScatterOp op) { op.getClusterAxis(), op.getNumLinks()); } +::flatbuffers::Offset<::tt::target::ttnn::CollectivePermuteOp> +createOp(FlatbufferObjectCache &cache, CollectivePermuteOp op) { + auto input = cache.at<::tt::target::ttnn::TensorRef>( + getOperandThroughDPSOps(op.getInput())); + auto output = cache.getOrCreate(op.getResult(), tensorValueToFlatbuffer, + kHostAllocatedSize); + auto device = getOperandThroughDPSOps(op.getDevice()); + auto sourceTargetPairs = op.getSourceTargetPairs().getValues(); + std::vector sourceTargetPairsVec(sourceTargetPairs.begin(), + sourceTargetPairs.end()); + return ::tt::target::ttnn::CreateCollectivePermuteOp( + *cache.fbb, input, output, cache.at<::tt::target::DeviceRef>(device), + cache.fbb->CreateVector(sourceTargetPairsVec)); +} + ::flatbuffers::Offset<::tt::target::ttnn::MeshShardOp> createOp(FlatbufferObjectCache &cache, MeshShardOp op) { auto input = cache.at<::tt::target::ttnn::TensorRef>( @@ -1689,6 +1704,11 @@ emitTTNNOperation(FlatbufferObjectCache &cache, Operation *op, return createOperation(cache, createOp(cache, reduceScatterOp), debugString, locInfo); } + if (auto collectivePermuteOp = dyn_cast(op); + collectivePermuteOp) { + return createOperation(cache, createOp(cache, collectivePermuteOp), + debugString, locInfo); + } if (auto meshShardOp = dyn_cast(op); meshShardOp) { return createOperation(cache, createOp(cache, meshShardOp), debugString, locInfo); diff --git a/runtime/lib/ttnn/operations/CMakeLists.txt b/runtime/lib/ttnn/operations/CMakeLists.txt index beb54aea32..7d3cf62ef5 100644 --- a/runtime/lib/ttnn/operations/CMakeLists.txt +++ b/runtime/lib/ttnn/operations/CMakeLists.txt @@ -10,6 +10,7 @@ set(TTNN_OPS_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/include/tt/runtime/ttnn/operations/eltwise/ternary/utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/ccl/all_gather.cpp ${CMAKE_CURRENT_SOURCE_DIR}/ccl/reduce_scatter.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/ccl/collective_permute.cpp ${CMAKE_CURRENT_SOURCE_DIR}/ccl/mesh_shard.cpp ${CMAKE_CURRENT_SOURCE_DIR}/conv/conv2d.cpp ${CMAKE_CURRENT_SOURCE_DIR}/creation/arange.cpp diff --git a/runtime/lib/ttnn/operations/ccl/collective_permute.cpp b/runtime/lib/ttnn/operations/ccl/collective_permute.cpp new file mode 100644 index 0000000000..394456302f --- /dev/null +++ b/runtime/lib/ttnn/operations/ccl/collective_permute.cpp @@ -0,0 +1,122 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "operations/ccl/collective_permute.h" +#include "tt/runtime/detail/logger.h" +#include 
"tt/runtime/detail/ttnn.h" +#include "tt/runtime/ttnn/debug_apis.h" +#include "tt/runtime/ttnn/operations/utils.h" +#include "tt/runtime/ttnn/utils.h" +#include "ttnn/operations/ccl/ccl_host_types.hpp" + +/* +Currently, TTNN does not support collective permute as a first class API. +Nor do they have send/recv point to point communication support. +Therefore, this algorithm uses the host as a fallback to do the mapping. +The collective permute operation takes a list of source_target_pairs that define +how tensor shards currently living in a src device should move to the dest +device. For example, for a 1x2 mesh system, you could have [0, 1], [1, 0] +source_target_pairs list. This indicates that the device shard living in device +0 should move to device 1, and the device shard living in device 1 should move +to device 0. In the situation where you have incomplete devices as the 'dest', +those devices will acquire a device shard with all values set to 0 +*/ +namespace tt::runtime::ttnn::operations::ccl { +void run(const ::tt::target::ttnn::CollectivePermuteOp *op, + ProgramContext &context) { + ProgramTensorPool &tensorPool = context.getTensorPool(); + + const ::ttnn::Tensor &input = tensorPool.getAndValidate(op->in()); + + const auto *fbSourceTargetPairs = op->source_target_pairs(); + std::vector sourceTargetPairs(fbSourceTargetPairs->begin(), + fbSourceTargetPairs->end()); + + LOG_ASSERT(input.storage_type() == ::ttnn::StorageType::MULTI_DEVICE, + "Input of collective_permute must be multidevice storage. id:", + op->in()->global_id()); + + // Get list of individual per device tensors. + std::vector<::ttnn::Tensor> originalDeviceTensors = + ::ttnn::distributed::get_tensors_from_multi_device_storage(input); + + // Iterate through originalDeviceTensors and create mapping of device_id : + // owned_storage. Also store device id to IDevice mapping. + std::unordered_map mappedOwnedStorageTensors; + std::unordered_map mappedDeviceIds; + + for (const auto &tensor : originalDeviceTensors) { + auto *tensorDevice = tensor.device(); + auto deviceId = tensorDevice->id(); + ::ttnn::Tensor hostTensor = ::ttnn::from_device(tensor); + mappedOwnedStorageTensors[deviceId] = hostTensor; + mappedDeviceIds[deviceId] = tensorDevice; + } + + // Iterate through sourceTargetPairs and for each pair, get the source tensor + // from the map and convert to device storage with dest device. + std::vector foundDestDevices(originalDeviceTensors.size(), false); + std::vector<::ttnn::Tensor> newDeviceTensors(originalDeviceTensors.size(), + ::ttnn::Tensor()); + + for (size_t i = 0; i < sourceTargetPairs.size(); i += 2) { + int64_t src = sourceTargetPairs[i]; + int64_t dest = sourceTargetPairs[i + 1]; + + auto srcHostTensorIt = mappedOwnedStorageTensors.find(src); + LOG_ASSERT(srcHostTensorIt != mappedOwnedStorageTensors.end(), + "Could not find device id in owned storage tensor map!"); + auto srcHostTensor = srcHostTensorIt->second; + + auto deviceIt = mappedDeviceIds.find(dest); + LOG_ASSERT(deviceIt != mappedDeviceIds.end(), + "Could not find device id in device map!"); + auto *device = deviceIt->second; + + std::optional<::ttnn::MemoryConfig> memoryConfig = + srcHostTensor.memory_config(); + newDeviceTensors[dest] = + ::ttnn::to_device(srcHostTensor, device, memoryConfig); + foundDestDevices[dest] = true; + } + + // Loop through all the devices that did not participate in the swaping and + // set their tensor device shard values to 0. 
+ for (size_t i = 0; i < foundDestDevices.size(); i++) { + if (foundDestDevices[i]) { + continue; + } + + auto srcHostTensorIt = mappedOwnedStorageTensors.find(i); + LOG_ASSERT(srcHostTensorIt != mappedOwnedStorageTensors.end(), + "Could not find device id in owned storage tensor map!"); + auto srcHostTensor = srcHostTensorIt->second; + + auto deviceIt = mappedDeviceIds.find(i); + LOG_ASSERT(deviceIt != mappedDeviceIds.end(), + "Could not find device id in device map!"); + auto *device = deviceIt->second; + + // We need to memset this tensor value to 0 based on collective permute + // operation semantics + void *dstPtr = ::tt::tt_metal::get_raw_host_data_ptr(srcHostTensor); + size_t size = srcHostTensor.volume() * srcHostTensor.element_size(); + std::memset(dstPtr, 0, size); + + std::optional<::ttnn::MemoryConfig> memoryConfig = + srcHostTensor.memory_config(); + newDeviceTensors[i] = + ::ttnn::to_device(srcHostTensor, device, memoryConfig); + foundDestDevices[i] = true; + } + + // Combine all device tensor shards into a single multi device tensor with + // multi device storage type. + ::ttnn::Tensor out = ::ttnn::distributed::create_multi_device_tensor( + newDeviceTensors, ::ttnn::StorageType::MULTI_DEVICE, + ::ttnn::distributed::get_distributed_tensor_config_from_tensor(input)); + + tensorPool.insertAndValidate(op->out(), out); +} +} // namespace tt::runtime::ttnn::operations::ccl diff --git a/runtime/lib/ttnn/operations/ccl/collective_permute.h b/runtime/lib/ttnn/operations/ccl/collective_permute.h new file mode 100644 index 0000000000..71e159339d --- /dev/null +++ b/runtime/lib/ttnn/operations/ccl/collective_permute.h @@ -0,0 +1,16 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#ifndef RUNTIME_LIB_TTNN_OPERATIONS_CCL_COLLECTIVE_PERMUTE_H +#define RUNTIME_LIB_TTNN_OPERATIONS_CCL_COLLECTIVE_PERMUTE_H + +#include "tt/runtime/ttnn/types.h" +#include "ttmlir/Target/TTNN/program_generated.h" + +namespace tt::runtime::ttnn::operations::ccl { +void run(const ::tt::target::ttnn::CollectivePermuteOp *op, + ProgramContext &context); +} // namespace tt::runtime::ttnn::operations::ccl + +#endif diff --git a/runtime/lib/ttnn/program.cpp b/runtime/lib/ttnn/program.cpp index ea4b4c59bb..d766cfa2b2 100644 --- a/runtime/lib/ttnn/program.cpp +++ b/runtime/lib/ttnn/program.cpp @@ -2,6 +2,7 @@ // // SPDX-License-Identifier: Apache-2.0 #include "operations/ccl/all_gather.h" +#include "operations/ccl/collective_permute.h" #include "operations/ccl/mesh_shard.h" #include "operations/ccl/reduce_scatter.h" #include "operations/context/get_device.h" @@ -268,6 +269,9 @@ void ProgramExecutor::runOperation(const ::tt::target::ttnn::Operation *op) { case ::tt::target::ttnn::OpType::ReduceScatterOp: { return operations::ccl::run(op->type_as_ReduceScatterOp(), context); } + case ::tt::target::ttnn::OpType::CollectivePermuteOp: { + return operations::ccl::run(op->type_as_CollectivePermuteOp(), context); + } case ::tt::target::ttnn::OpType::MeshShardOp: { return operations::ccl::run(op->type_as_MeshShardOp(), context); } diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/ccl/ccl_ops_gspmd.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/ccl/ccl_ops_gspmd.mlir index 6abccf1c93..856dfa090e 100644 --- a/test/ttmlir/Conversion/StableHLOToTTIR/ccl/ccl_ops_gspmd.mlir +++ b/test/ttmlir/Conversion/StableHLOToTTIR/ccl/ccl_ops_gspmd.mlir @@ -1847,3 +1847,506 @@ module @jit_negative_basic attributes {mhlo.num_partitions = 2 : i32, mhlo.num_r return %1 : tensor<256x128xf32> } 
} + +// ----- + +module @jit_collective_permute_1x2_rank_4_cluster_1 attributes {mhlo.num_partitions = 2 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<1x1x8192x512xf32>) -> (tensor<1x1x8192x512xf32> {jax.result_info = ""}) { + %0 = stablehlo.custom_call @Sharding(%arg0) {backend_config = "", mhlo.sharding = "{devices=[1,1,1,2]<=[2]}"} : (tensor<1x1x8192x512xf32>) -> tensor<1x1x8192x512xf32> + %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<1x1x8192x512xf32>) -> tensor<1x1x8192x256xf32> + // CHECK: "ttir.mesh_shard" + // CHECK-SAME: shard_dims = array + // CHECK-SAME: shard_direction = #tt.shard_direction + // CHECK-SAME: shard_shape = array + // CHECK-SAME: shard_type = #tt.shard_type + %2 = call @shmap_body(%1) : (tensor<1x1x8192x256xf32>) -> tensor<1x1x8192x256xf32> + %3 = stablehlo.custom_call @Sharding(%2) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<1x1x8192x256xf32>) -> tensor<1x1x8192x256xf32> + %4 = stablehlo.custom_call @SPMDShardToFullShape(%3) {backend_config = "", mhlo.sharding = "{devices=[1,1,1,2]<=[2]}"} : (tensor<1x1x8192x256xf32>) -> tensor<1x1x8192x512xf32> + // CHECK: "ttir.mesh_shard" + // CHECK-SAME: shard_dims = array + // CHECK-SAME: shard_direction = #tt.shard_direction + // CHECK-SAME: shard_shape = array + // CHECK-SAME: shard_type = #tt.shard_type + return %4 : tensor<1x1x8192x512xf32> + } + func.func private @shmap_body(%arg0: tensor<1x1x8192x256xf32>) -> (tensor<1x1x8192x256xf32> {jax.result_info = "[None, None, ('batch',), ('pipeline',)]"}) { + %0 = "stablehlo.collective_permute"(%arg0) <{channel_handle = #stablehlo.channel_handle, source_target_pairs = dense<[[0, 1], [1, 0]]> : tensor<2x2xi64>}> : (tensor<1x1x8192x256xf32>) -> tensor<1x1x8192x256xf32> + // CHECK: "ttir.collective_permute" + // CHECK-SAME: source_target_pairs = dense + // CHECK-SAME: [0, 1], [1, 0] + return %0 : tensor<1x1x8192x256xf32> + } +} + +// ----- + +module @jit_collective_permute_1x2_rank_4_cluster_1_partial_target_pairs attributes {mhlo.num_partitions = 2 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<1x1x8192x512xf32>) -> (tensor<1x1x8192x512xf32> {jax.result_info = ""}) { + %0 = stablehlo.custom_call @Sharding(%arg0) {backend_config = "", mhlo.sharding = "{devices=[1,1,1,2]<=[2]}"} : (tensor<1x1x8192x512xf32>) -> tensor<1x1x8192x512xf32> + %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<1x1x8192x512xf32>) -> tensor<1x1x8192x256xf32> + // CHECK: "ttir.mesh_shard" + // CHECK-SAME: shard_dims = array + // CHECK-SAME: shard_direction = #tt.shard_direction + // CHECK-SAME: shard_shape = array + // CHECK-SAME: shard_type = #tt.shard_type + %2 = call @shmap_body(%1) : (tensor<1x1x8192x256xf32>) -> tensor<1x1x8192x256xf32> + %3 = stablehlo.custom_call @Sharding(%2) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<1x1x8192x256xf32>) -> tensor<1x1x8192x256xf32> + %4 = stablehlo.custom_call @SPMDShardToFullShape(%3) {backend_config = "", mhlo.sharding = "{devices=[1,1,1,2]<=[2]}"} : (tensor<1x1x8192x256xf32>) -> tensor<1x1x8192x512xf32> + // CHECK: "ttir.mesh_shard" + // CHECK-SAME: shard_dims = array + // CHECK-SAME: shard_direction = #tt.shard_direction + // CHECK-SAME: shard_shape = array + // CHECK-SAME: shard_type = #tt.shard_type + return %4 : tensor<1x1x8192x512xf32> + } + func.func private @shmap_body(%arg0: tensor<1x1x8192x256xf32>) -> (tensor<1x1x8192x256xf32> 
{jax.result_info = "[None, None, ('batch',), ('pipeline',)]"}) { + %0 = "stablehlo.collective_permute"(%arg0) <{channel_handle = #stablehlo.channel_handle, source_target_pairs = dense<[[0, 1]]> : tensor<1x2xi64>}> : (tensor<1x1x8192x256xf32>) -> tensor<1x1x8192x256xf32> + // CHECK: "ttir.collective_permute" + // CHECK-SAME: source_target_pairs = dense + // CHECK-SAME: [0, 1] + return %0 : tensor<1x1x8192x256xf32> + } +} + +// ----- + +module @jit_collective_permute_1x2_rank_4_cluster_0 attributes {mhlo.num_partitions = 2 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<1x1x8192x512xf32>) -> (tensor<1x1x8192x512xf32> {jax.result_info = ""}) { + %0 = stablehlo.custom_call @Sharding(%arg0) {backend_config = "", mhlo.sharding = "{devices=[1,1,2,1]<=[2]}"} : (tensor<1x1x8192x512xf32>) -> tensor<1x1x8192x512xf32> + %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<1x1x8192x512xf32>) -> tensor<1x1x4096x512xf32> + // CHECK: "ttir.mesh_shard" + // CHECK-SAME: shard_dims = array + // CHECK-SAME: shard_direction = #tt.shard_direction + // CHECK-SAME: shard_shape = array + // CHECK-SAME: shard_type = #tt.shard_type + %2 = call @shmap_body(%1) : (tensor<1x1x4096x512xf32>) -> tensor<1x1x4096x512xf32> + %3 = stablehlo.custom_call @Sharding(%2) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<1x1x4096x512xf32>) -> tensor<1x1x4096x512xf32> + %4 = stablehlo.custom_call @SPMDShardToFullShape(%3) {backend_config = "", mhlo.sharding = "{devices=[1,1,2,1]<=[2]}"} : (tensor<1x1x4096x512xf32>) -> tensor<1x1x8192x512xf32> + // CHECK: "ttir.mesh_shard" + // CHECK-SAME: shard_dims = array + // CHECK-SAME: shard_direction = #tt.shard_direction + // CHECK-SAME: shard_shape = array + // CHECK-SAME: shard_type = #tt.shard_type + return %4 : tensor<1x1x8192x512xf32> + } + func.func private @shmap_body(%arg0: tensor<1x1x4096x512xf32>) -> (tensor<1x1x4096x512xf32> {jax.result_info = "[None, None, ('batch',), ('pipeline',)]"}) { + %0 = "stablehlo.collective_permute"(%arg0) <{channel_handle = #stablehlo.channel_handle, source_target_pairs = dense<> : tensor<0x2xi64>}> : (tensor<1x1x4096x512xf32>) -> tensor<1x1x4096x512xf32> + // CHECK: "ttir.collective_permute" + // CHECK-SAME: source_target_pairs = dense + return %0 : tensor<1x1x4096x512xf32> + } +} + +// ----- + +module @jit_collective_permute_1x32_rank_4_cluster_1 attributes {mhlo.num_partitions = 32 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<1x1x8192x512xf32>) -> (tensor<1x1x8192x512xf32> {jax.result_info = ""}) { + %0 = stablehlo.custom_call @Sharding(%arg0) {backend_config = "", mhlo.sharding = "{devices=[1,1,1,32]<=[32]}"} : (tensor<1x1x8192x512xf32>) -> tensor<1x1x8192x512xf32> + %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<1x1x8192x512xf32>) -> tensor<1x1x8192x16xf32> + // CHECK: "ttir.mesh_shard" + // CHECK-SAME: shard_dims = array + // CHECK-SAME: shard_direction = #tt.shard_direction + // CHECK-SAME: shard_shape = array + // CHECK-SAME: shard_type = #tt.shard_type + %2 = call @shmap_body(%1) : (tensor<1x1x8192x16xf32>) -> tensor<1x1x8192x16xf32> + %3 = stablehlo.custom_call @Sharding(%2) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<1x1x8192x16xf32>) -> tensor<1x1x8192x16xf32> + %4 = stablehlo.custom_call @SPMDShardToFullShape(%3) {backend_config = "", mhlo.sharding = "{devices=[1,1,1,32]<=[32]}"} : (tensor<1x1x8192x16xf32>) -> 
tensor<1x1x8192x512xf32> + // CHECK: "ttir.mesh_shard" + // CHECK-SAME: shard_dims = array + // CHECK-SAME: shard_direction = #tt.shard_direction + // CHECK-SAME: shard_shape = array + // CHECK-SAME: shard_type = #tt.shard_type + return %4 : tensor<1x1x8192x512xf32> + } + func.func private @shmap_body(%arg0: tensor<1x1x8192x16xf32>) -> (tensor<1x1x8192x16xf32> {jax.result_info = "[None, None, ('batch',), ('pipeline',)]"}) { + %0 = "stablehlo.collective_permute"(%arg0) <{channel_handle = #stablehlo.channel_handle, source_target_pairs = dense<[[0, 1], [1, 2], [2, 3], [3, 4], [4, 5], [5, 6], [6, 7], [7, 8], [8, 9], [9, 10], [10, 11], [11, 12], [12, 13], [13, 14], [14, 15], [15, 16], [16, 17], [17, 18], [18, 19], [19, 20], [20, 21], [21, 22], [22, 23], [23, 24], [24, 25], [25, 26], [26, 27], [27, 28], [28, 29], [29, 30], [30, 31], [31, 0]]> : tensor<32x2xi64>}> : (tensor<1x1x8192x16xf32>) -> tensor<1x1x8192x16xf32> + // CHECK: "ttir.collective_permute" + // CHECK-SAME: source_target_pairs = dense + // CHECK-SAME: [0, 1], [1, 2], [2, 3], [3, 4], [4, 5], [5, 6], [6, 7], [7, 8], [8, 9], [9, 10], [10, 11], [11, 12], [12, 13], [13, 14], [14, 15], [15, 16], [16, 17], [17, 18], [18, 19], [19, 20], [20, 21], [21, 22], [22, 23], [23, 24], [24, 25], [25, 26], [26, 27], [27, 28], [28, 29], [29, 30], [30, 31], [31, 0] + return %0 : tensor<1x1x8192x16xf32> + } +} + +// ----- + +module @jit_collective_permute_1x32_rank_4_cluster_1_partial_target_pairs attributes {mhlo.num_partitions = 32 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<1x1x8192x512xf32>) -> (tensor<1x1x8192x512xf32> {jax.result_info = ""}) { + %0 = stablehlo.custom_call @Sharding(%arg0) {backend_config = "", mhlo.sharding = "{devices=[1,1,1,32]<=[32]}"} : (tensor<1x1x8192x512xf32>) -> tensor<1x1x8192x512xf32> + %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<1x1x8192x512xf32>) -> tensor<1x1x8192x16xf32> + // CHECK: "ttir.mesh_shard" + // CHECK-SAME: shard_dims = array + // CHECK-SAME: shard_direction = #tt.shard_direction + // CHECK-SAME: shard_shape = array + // CHECK-SAME: shard_type = #tt.shard_type + %2 = call @shmap_body(%1) : (tensor<1x1x8192x16xf32>) -> tensor<1x1x8192x16xf32> + %3 = stablehlo.custom_call @Sharding(%2) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<1x1x8192x16xf32>) -> tensor<1x1x8192x16xf32> + %4 = stablehlo.custom_call @SPMDShardToFullShape(%3) {backend_config = "", mhlo.sharding = "{devices=[1,1,1,32]<=[32]}"} : (tensor<1x1x8192x16xf32>) -> tensor<1x1x8192x512xf32> + // CHECK: "ttir.mesh_shard" + // CHECK-SAME: shard_dims = array + // CHECK-SAME: shard_direction = #tt.shard_direction + // CHECK-SAME: shard_shape = array + // CHECK-SAME: shard_type = #tt.shard_type + return %4 : tensor<1x1x8192x512xf32> + } + func.func private @shmap_body(%arg0: tensor<1x1x8192x16xf32>) -> (tensor<1x1x8192x16xf32> {jax.result_info = "[None, None, ('batch',), ('pipeline',)]"}) { + %0 = "stablehlo.collective_permute"(%arg0) <{channel_handle = #stablehlo.channel_handle, source_target_pairs = dense<[[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13], [14, 15], [16, 17], [18, 19], [20, 21], [22, 23], [24, 25], [26, 27], [28, 29], [30, 31]]> : tensor<16x2xi64>}> : (tensor<1x1x8192x16xf32>) -> tensor<1x1x8192x16xf32> + // CHECK: "ttir.collective_permute" + // CHECK-SAME: source_target_pairs = dense + // CHECK-SAME: [0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13], [14, 15], [16, 17], [18, 19], [20, 21], [22, 23], 
[24, 25], [26, 27], [28, 29], [30, 31] + return %0 : tensor<1x1x8192x16xf32> + } +} + +// ----- + +module @jit_collective_permute_1x32_rank_4_cluster_0 attributes {mhlo.num_partitions = 32 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<1x1x8192x512xf32>) -> (tensor<1x1x8192x512xf32> {jax.result_info = ""}) { + %0 = stablehlo.custom_call @Sharding(%arg0) {backend_config = "", mhlo.sharding = "{devices=[1,1,32,1]<=[32]}"} : (tensor<1x1x8192x512xf32>) -> tensor<1x1x8192x512xf32> + %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<1x1x8192x512xf32>) -> tensor<1x1x256x512xf32> + // CHECK: "ttir.mesh_shard" + // CHECK-SAME: shard_dims = array + // CHECK-SAME: shard_direction = #tt.shard_direction + // CHECK-SAME: shard_shape = array + // CHECK-SAME: shard_type = #tt.shard_type + %2 = call @shmap_body(%1) : (tensor<1x1x256x512xf32>) -> tensor<1x1x256x512xf32> + %3 = stablehlo.custom_call @Sharding(%2) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<1x1x256x512xf32>) -> tensor<1x1x256x512xf32> + %4 = stablehlo.custom_call @SPMDShardToFullShape(%3) {backend_config = "", mhlo.sharding = "{devices=[1,1,32,1]<=[32]}"} : (tensor<1x1x256x512xf32>) -> tensor<1x1x8192x512xf32> + // CHECK: "ttir.mesh_shard" + // CHECK-SAME: shard_dims = array + // CHECK-SAME: shard_direction = #tt.shard_direction + // CHECK-SAME: shard_shape = array + // CHECK-SAME: shard_type = #tt.shard_type + return %4 : tensor<1x1x8192x512xf32> + } + func.func private @shmap_body(%arg0: tensor<1x1x256x512xf32>) -> (tensor<1x1x256x512xf32> {jax.result_info = "[None, None, ('batch',), ('pipeline',)]"}) { + %0 = "stablehlo.collective_permute"(%arg0) <{channel_handle = #stablehlo.channel_handle, source_target_pairs = dense<> : tensor<0x2xi64>}> : (tensor<1x1x256x512xf32>) -> tensor<1x1x256x512xf32> + // CHECK: "ttir.collective_permute" + // CHECK-SAME: source_target_pairs = dense + return %0 : tensor<1x1x256x512xf32> + } +} + +// ----- + +module @jit_collective_permute_1x8_rank_4_cluster_1 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<1x1x8192x512xf32>) -> (tensor<1x1x8192x512xf32> {jax.result_info = ""}) { + %0 = stablehlo.custom_call @Sharding(%arg0) {backend_config = "", mhlo.sharding = "{devices=[1,1,1,8]<=[8]}"} : (tensor<1x1x8192x512xf32>) -> tensor<1x1x8192x512xf32> + %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<1x1x8192x512xf32>) -> tensor<1x1x8192x64xf32> + // CHECK: "ttir.mesh_shard" + // CHECK-SAME: shard_dims = array + // CHECK-SAME: shard_direction = #tt.shard_direction + // CHECK-SAME: shard_shape = array + // CHECK-SAME: shard_type = #tt.shard_type + %2 = call @shmap_body(%1) : (tensor<1x1x8192x64xf32>) -> tensor<1x1x8192x64xf32> + %3 = stablehlo.custom_call @Sharding(%2) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<1x1x8192x64xf32>) -> tensor<1x1x8192x64xf32> + %4 = stablehlo.custom_call @SPMDShardToFullShape(%3) {backend_config = "", mhlo.sharding = "{devices=[1,1,1,8]<=[8]}"} : (tensor<1x1x8192x64xf32>) -> tensor<1x1x8192x512xf32> + // CHECK: "ttir.mesh_shard" + // CHECK-SAME: shard_dims = array + // CHECK-SAME: shard_direction = #tt.shard_direction + // CHECK-SAME: shard_shape = array + // CHECK-SAME: shard_type = #tt.shard_type + return %4 : tensor<1x1x8192x512xf32> + } + func.func private @shmap_body(%arg0: tensor<1x1x8192x64xf32>) -> (tensor<1x1x8192x64xf32> 
{jax.result_info = "[None, None, ('batch',), ('pipeline',)]"}) { + %0 = "stablehlo.collective_permute"(%arg0) <{channel_handle = #stablehlo.channel_handle, source_target_pairs = dense<[[0, 1], [1, 2], [2, 3], [3, 4], [4, 5], [5, 6], [6, 7], [7, 0]]> : tensor<8x2xi64>}> : (tensor<1x1x8192x64xf32>) -> tensor<1x1x8192x64xf32> + // CHECK: "ttir.collective_permute" + // CHECK-SAME: source_target_pairs = dense + // CHECK-SAME: [0, 1], [1, 2], [2, 3], [3, 4], [4, 5], [5, 6], [6, 7], [7, 0] + return %0 : tensor<1x1x8192x64xf32> + } +} + +// ----- + +module @jit_collective_permute_1x8_rank_4_cluster_1_partial_target_pairs attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<1x1x8192x512xf32>) -> (tensor<1x1x8192x512xf32> {jax.result_info = ""}) { + %0 = stablehlo.custom_call @Sharding(%arg0) {backend_config = "", mhlo.sharding = "{devices=[1,1,1,8]<=[8]}"} : (tensor<1x1x8192x512xf32>) -> tensor<1x1x8192x512xf32> + %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<1x1x8192x512xf32>) -> tensor<1x1x8192x64xf32> + // CHECK: "ttir.mesh_shard" + // CHECK-SAME: shard_dims = array + // CHECK-SAME: shard_direction = #tt.shard_direction + // CHECK-SAME: shard_shape = array + // CHECK-SAME: shard_type = #tt.shard_type + %2 = call @shmap_body(%1) : (tensor<1x1x8192x64xf32>) -> tensor<1x1x8192x64xf32> + %3 = stablehlo.custom_call @Sharding(%2) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<1x1x8192x64xf32>) -> tensor<1x1x8192x64xf32> + %4 = stablehlo.custom_call @SPMDShardToFullShape(%3) {backend_config = "", mhlo.sharding = "{devices=[1,1,1,8]<=[8]}"} : (tensor<1x1x8192x64xf32>) -> tensor<1x1x8192x512xf32> + // CHECK: "ttir.mesh_shard" + // CHECK-SAME: shard_dims = array + // CHECK-SAME: shard_direction = #tt.shard_direction + // CHECK-SAME: shard_shape = array + // CHECK-SAME: shard_type = #tt.shard_type + return %4 : tensor<1x1x8192x512xf32> + } + func.func private @shmap_body(%arg0: tensor<1x1x8192x64xf32>) -> (tensor<1x1x8192x64xf32> {jax.result_info = "[None, None, ('batch',), ('pipeline',)]"}) { + %0 = "stablehlo.collective_permute"(%arg0) <{channel_handle = #stablehlo.channel_handle, source_target_pairs = dense<[[0, 1], [2, 3], [4, 5], [6, 7]]> : tensor<4x2xi64>}> : (tensor<1x1x8192x64xf32>) -> tensor<1x1x8192x64xf32> + // CHECK: "ttir.collective_permute" + // CHECK-SAME: source_target_pairs = dense + // CHECK-SAME: [0, 1], [2, 3], [4, 5], [6, 7] + return %0 : tensor<1x1x8192x64xf32> + } +} + +// ----- + +module @jit_collective_permute_1x8_rank_4_cluster_0 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<1x1x8192x512xf32>) -> (tensor<1x1x8192x512xf32> {jax.result_info = ""}) { + %0 = stablehlo.custom_call @Sharding(%arg0) {backend_config = "", mhlo.sharding = "{devices=[1,1,8,1]<=[8]}"} : (tensor<1x1x8192x512xf32>) -> tensor<1x1x8192x512xf32> + %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<1x1x8192x512xf32>) -> tensor<1x1x1024x512xf32> + // CHECK: "ttir.mesh_shard" + // CHECK-SAME: shard_dims = array + // CHECK-SAME: shard_direction = #tt.shard_direction + // CHECK-SAME: shard_shape = array + // CHECK-SAME: shard_type = #tt.shard_type + %2 = call @shmap_body(%1) : (tensor<1x1x1024x512xf32>) -> tensor<1x1x1024x512xf32> + %3 = stablehlo.custom_call @Sharding(%2) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<1x1x1024x512xf32>) -> 
tensor<1x1x1024x512xf32> + %4 = stablehlo.custom_call @SPMDShardToFullShape(%3) {backend_config = "", mhlo.sharding = "{devices=[1,1,8,1]<=[8]}"} : (tensor<1x1x1024x512xf32>) -> tensor<1x1x8192x512xf32> + // CHECK: "ttir.mesh_shard" + // CHECK-SAME: shard_dims = array + // CHECK-SAME: shard_direction = #tt.shard_direction + // CHECK-SAME: shard_shape = array + // CHECK-SAME: shard_type = #tt.shard_type + return %4 : tensor<1x1x8192x512xf32> + } + func.func private @shmap_body(%arg0: tensor<1x1x1024x512xf32>) -> (tensor<1x1x1024x512xf32> {jax.result_info = "[None, None, ('batch',), ('pipeline',)]"}) { + %0 = "stablehlo.collective_permute"(%arg0) <{channel_handle = #stablehlo.channel_handle, source_target_pairs = dense<> : tensor<0x2xi64>}> : (tensor<1x1x1024x512xf32>) -> tensor<1x1x1024x512xf32> + // CHECK: "ttir.collective_permute" + // CHECK-SAME: source_target_pairs = dense + return %0 : tensor<1x1x1024x512xf32> + } +} + +// ----- + +module @jit_collective_permute_2x4_rank_4_cluster_1 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<1x1x8192x512xf32>) -> (tensor<1x1x8192x512xf32> {jax.result_info = ""}) { + %0 = stablehlo.custom_call @Sharding(%arg0) {backend_config = "", mhlo.sharding = "{devices=[1,1,2,4]<=[8]}"} : (tensor<1x1x8192x512xf32>) -> tensor<1x1x8192x512xf32> + %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<1x1x8192x512xf32>) -> tensor<1x1x4096x128xf32> + // CHECK: "ttir.mesh_shard" + // CHECK-SAME: shard_dims = array + // CHECK-SAME: shard_direction = #tt.shard_direction + // CHECK-SAME: shard_shape = array + // CHECK-SAME: shard_type = #tt.shard_type + %2 = call @shmap_body(%1) : (tensor<1x1x4096x128xf32>) -> tensor<1x1x4096x128xf32> + %3 = stablehlo.custom_call @Sharding(%2) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<1x1x4096x128xf32>) -> tensor<1x1x4096x128xf32> + %4 = stablehlo.custom_call @SPMDShardToFullShape(%3) {backend_config = "", mhlo.sharding = "{devices=[1,1,2,4]<=[8]}"} : (tensor<1x1x4096x128xf32>) -> tensor<1x1x8192x512xf32> + // CHECK: "ttir.mesh_shard" + // CHECK-SAME: shard_dims = array + // CHECK-SAME: shard_direction = #tt.shard_direction + // CHECK-SAME: shard_shape = array + // CHECK-SAME: shard_type = #tt.shard_type + return %4 : tensor<1x1x8192x512xf32> + } + func.func private @shmap_body(%arg0: tensor<1x1x4096x128xf32>) -> (tensor<1x1x4096x128xf32> {jax.result_info = "[None, None, ('batch',), ('pipeline',)]"}) { + %0 = "stablehlo.collective_permute"(%arg0) <{channel_handle = #stablehlo.channel_handle, source_target_pairs = dense<[[0, 1], [1, 2], [2, 3], [3, 0], [4, 5], [5, 6], [6, 7], [7, 4]]> : tensor<8x2xi64>}> : (tensor<1x1x4096x128xf32>) -> tensor<1x1x4096x128xf32> + // CHECK: "ttir.collective_permute" + // CHECK-SAME: source_target_pairs = dense + // CHECK-SAME: [0, 1], [1, 2], [2, 3], [3, 0], [4, 5], [5, 6], [6, 7], [7, 4] + return %0 : tensor<1x1x4096x128xf32> + } +} + +// ----- + +module @jit_collective_permute_2x4_rank_4_cluster_1_partial_target_pairs attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<1x1x8192x512xf32>) -> (tensor<1x1x8192x512xf32> {jax.result_info = ""}) { + %0 = stablehlo.custom_call @Sharding(%arg0) {backend_config = "", mhlo.sharding = "{devices=[1,1,2,4]<=[8]}"} : (tensor<1x1x8192x512xf32>) -> tensor<1x1x8192x512xf32> + %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {backend_config = "", mhlo.sharding = 
"{manual}"} : (tensor<1x1x8192x512xf32>) -> tensor<1x1x4096x128xf32> + // CHECK: "ttir.mesh_shard" + // CHECK-SAME: shard_dims = array + // CHECK-SAME: shard_direction = #tt.shard_direction + // CHECK-SAME: shard_shape = array + // CHECK-SAME: shard_type = #tt.shard_type + %2 = call @shmap_body(%1) : (tensor<1x1x4096x128xf32>) -> tensor<1x1x4096x128xf32> + %3 = stablehlo.custom_call @Sharding(%2) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<1x1x4096x128xf32>) -> tensor<1x1x4096x128xf32> + %4 = stablehlo.custom_call @SPMDShardToFullShape(%3) {backend_config = "", mhlo.sharding = "{devices=[1,1,2,4]<=[8]}"} : (tensor<1x1x4096x128xf32>) -> tensor<1x1x8192x512xf32> + // CHECK: "ttir.mesh_shard" + // CHECK-SAME: shard_dims = array + // CHECK-SAME: shard_direction = #tt.shard_direction + // CHECK-SAME: shard_shape = array + // CHECK-SAME: shard_type = #tt.shard_type + return %4 : tensor<1x1x8192x512xf32> + } + func.func private @shmap_body(%arg0: tensor<1x1x4096x128xf32>) -> (tensor<1x1x4096x128xf32> {jax.result_info = "[None, None, ('batch',), ('pipeline',)]"}) { + %0 = "stablehlo.collective_permute"(%arg0) <{channel_handle = #stablehlo.channel_handle, source_target_pairs = dense<[[0, 1], [2, 3], [4, 5], [6, 7]]> : tensor<4x2xi64>}> : (tensor<1x1x4096x128xf32>) -> tensor<1x1x4096x128xf32> + // CHECK: "ttir.collective_permute" + // CHECK-SAME: source_target_pairs = dense + // CHECK-SAME: [0, 1], [2, 3], [4, 5], [6, 7] + return %0 : tensor<1x1x4096x128xf32> + } +} + +// ----- + +module @jit_collective_permute_2x4_rank_4_cluster_0 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<1x1x8192x512xf32>) -> (tensor<1x1x8192x512xf32> {jax.result_info = ""}) { + %0 = stablehlo.custom_call @Sharding(%arg0) {backend_config = "", mhlo.sharding = "{devices=[1,1,4,2]<=[2,4]T(1,0)}"} : (tensor<1x1x8192x512xf32>) -> tensor<1x1x8192x512xf32> + %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<1x1x8192x512xf32>) -> tensor<1x1x2048x256xf32> + // CHECK: "ttir.mesh_shard" + // CHECK-SAME: shard_dims = array + // CHECK-SAME: shard_direction = #tt.shard_direction + // CHECK-SAME: shard_shape = array + // CHECK-SAME: shard_type = #tt.shard_type + %2 = call @shmap_body(%1) : (tensor<1x1x2048x256xf32>) -> tensor<1x1x2048x256xf32> + %3 = stablehlo.custom_call @Sharding(%2) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<1x1x2048x256xf32>) -> tensor<1x1x2048x256xf32> + %4 = stablehlo.custom_call @SPMDShardToFullShape(%3) {backend_config = "", mhlo.sharding = "{devices=[1,1,4,2]<=[2,4]T(1,0)}"} : (tensor<1x1x2048x256xf32>) -> tensor<1x1x8192x512xf32> + // CHECK: "ttir.mesh_shard" + // CHECK-SAME: shard_dims = array + // CHECK-SAME: shard_direction = #tt.shard_direction + // CHECK-SAME: shard_shape = array + // CHECK-SAME: shard_type = #tt.shard_type + return %4 : tensor<1x1x8192x512xf32> + } + func.func private @shmap_body(%arg0: tensor<1x1x2048x256xf32>) -> (tensor<1x1x2048x256xf32> {jax.result_info = "[None, None, ('batch',), ('pipeline',)]"}) { + %0 = "stablehlo.collective_permute"(%arg0) <{channel_handle = #stablehlo.channel_handle, source_target_pairs = dense<[[0, 4], [4, 0], [1, 5], [5, 1], [2, 6], [6, 2], [3, 7], [7, 3]]> : tensor<8x2xi64>}> : (tensor<1x1x2048x256xf32>) -> tensor<1x1x2048x256xf32> + // CHECK: "ttir.collective_permute" + // CHECK-SAME: source_target_pairs = dense + return %0 : tensor<1x1x2048x256xf32> + } +} + +// ----- + +module 
@jit_collective_permute_2x4_rank_4_cluster_0_partial_target_pairs attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<1x1x8192x512xf32>) -> (tensor<1x1x8192x512xf32> {jax.result_info = ""}) { + %0 = stablehlo.custom_call @Sharding(%arg0) {backend_config = "", mhlo.sharding = "{devices=[1,1,4,2]<=[2,4]T(1,0)}"} : (tensor<1x1x8192x512xf32>) -> tensor<1x1x8192x512xf32> + %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<1x1x8192x512xf32>) -> tensor<1x1x2048x256xf32> + // CHECK: "ttir.mesh_shard" + // CHECK-SAME: shard_dims = array + // CHECK-SAME: shard_direction = #tt.shard_direction + // CHECK-SAME: shard_shape = array + // CHECK-SAME: shard_type = #tt.shard_type + %2 = call @shmap_body(%1) : (tensor<1x1x2048x256xf32>) -> tensor<1x1x2048x256xf32> + %3 = stablehlo.custom_call @Sharding(%2) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<1x1x2048x256xf32>) -> tensor<1x1x2048x256xf32> + %4 = stablehlo.custom_call @SPMDShardToFullShape(%3) {backend_config = "", mhlo.sharding = "{devices=[1,1,4,2]<=[2,4]T(1,0)}"} : (tensor<1x1x2048x256xf32>) -> tensor<1x1x8192x512xf32> + // CHECK: "ttir.mesh_shard" + // CHECK-SAME: shard_dims = array + // CHECK-SAME: shard_direction = #tt.shard_direction + // CHECK-SAME: shard_shape = array + // CHECK-SAME: shard_type = #tt.shard_type + return %4 : tensor<1x1x8192x512xf32> + } + func.func private @shmap_body(%arg0: tensor<1x1x2048x256xf32>) -> (tensor<1x1x2048x256xf32> {jax.result_info = "[None, None, ('batch',), ('pipeline',)]"}) { + %0 = "stablehlo.collective_permute"(%arg0) <{channel_handle = #stablehlo.channel_handle, source_target_pairs = dense<[[0, 4], [1, 5], [2, 6], [3, 7]]> : tensor<4x2xi64>}> : (tensor<1x1x2048x256xf32>) -> tensor<1x1x2048x256xf32> + // CHECK: "ttir.collective_permute" + // CHECK-SAME: source_target_pairs = dense + return %0 : tensor<1x1x2048x256xf32> + } +} + +// ----- + +module @jit_collective_permute_8x4_rank_4_cluster_1 attributes {mhlo.num_partitions = 32 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<1x1x8192x512xf32>) -> (tensor<1x1x8192x512xf32> {jax.result_info = ""}) { + %0 = stablehlo.custom_call @Sharding(%arg0) {backend_config = "", mhlo.sharding = "{devices=[1,1,8,4]<=[32]}"} : (tensor<1x1x8192x512xf32>) -> tensor<1x1x8192x512xf32> + %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<1x1x8192x512xf32>) -> tensor<1x1x1024x128xf32> + // CHECK: "ttir.mesh_shard" + // CHECK-SAME: shard_dims = array + // CHECK-SAME: shard_direction = #tt.shard_direction + // CHECK-SAME: shard_shape = array + // CHECK-SAME: shard_type = #tt.shard_type + %2 = call @shmap_body(%1) : (tensor<1x1x1024x128xf32>) -> tensor<1x1x1024x128xf32> + %3 = stablehlo.custom_call @Sharding(%2) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<1x1x1024x128xf32>) -> tensor<1x1x1024x128xf32> + %4 = stablehlo.custom_call @SPMDShardToFullShape(%3) {backend_config = "", mhlo.sharding = "{devices=[1,1,8,4]<=[32]}"} : (tensor<1x1x1024x128xf32>) -> tensor<1x1x8192x512xf32> + // CHECK: "ttir.mesh_shard" + // CHECK-SAME: shard_dims = array + // CHECK-SAME: shard_direction = #tt.shard_direction + // CHECK-SAME: shard_shape = array + // CHECK-SAME: shard_type = #tt.shard_type + return %4 : tensor<1x1x8192x512xf32> + } + func.func private @shmap_body(%arg0: tensor<1x1x1024x128xf32>) -> (tensor<1x1x1024x128xf32> {jax.result_info = "[None, 
None, ('batch',), ('pipeline',)]"}) { + %0 = "stablehlo.collective_permute"(%arg0) <{channel_handle = #stablehlo.channel_handle, source_target_pairs = dense<[[0, 1], [1, 2], [2, 3], [3, 0], [4, 5], [5, 6], [6, 7], [7, 4], [8, 9], [9, 10], [10, 11], [11, 8], [12, 13], [13, 14], [14, 15], [15, 12], [16, 17], [17, 18], [18, 19], [19, 16], [20, 21], [21, 22], [22, 23], [23, 20], [24, 25], [25, 26], [26, 27], [27, 24], [28, 29], [29, 30], [30, 31], [31, 28]]> : tensor<32x2xi64>}> : (tensor<1x1x1024x128xf32>) -> tensor<1x1x1024x128xf32> + // CHECK: "ttir.collective_permute" + // CHECK-SAME: source_target_pairs = dense + // CHECK-SAME: [0, 1], [1, 2], [2, 3], [3, 0], [4, 5], [5, 6], [6, 7], [7, 4], [8, 9], [9, 10], [10, 11], [11, 8], [12, 13], [13, 14], [14, 15], [15, 12], [16, 17], [17, 18], [18, 19], [19, 16], [20, 21], [21, 22], [22, 23], [23, 20], [24, 25], [25, 26], [26, 27], [27, 24], [28, 29], [29, 30], [30, 31], [31, 28] + return %0 : tensor<1x1x1024x128xf32> + } +} + +// ----- + +module @jit_collective_permute_8x4_rank_4_cluster_1_partial_target_pairs attributes {mhlo.num_partitions = 32 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<1x1x8192x512xf32>) -> (tensor<1x1x8192x512xf32> {jax.result_info = ""}) { + %0 = stablehlo.custom_call @Sharding(%arg0) {backend_config = "", mhlo.sharding = "{devices=[1,1,8,4]<=[32]}"} : (tensor<1x1x8192x512xf32>) -> tensor<1x1x8192x512xf32> + %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<1x1x8192x512xf32>) -> tensor<1x1x1024x128xf32> + // CHECK: "ttir.mesh_shard" + // CHECK-SAME: shard_dims = array + // CHECK-SAME: shard_direction = #tt.shard_direction + // CHECK-SAME: shard_shape = array + // CHECK-SAME: shard_type = #tt.shard_type + %2 = call @shmap_body(%1) : (tensor<1x1x1024x128xf32>) -> tensor<1x1x1024x128xf32> + %3 = stablehlo.custom_call @Sharding(%2) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<1x1x1024x128xf32>) -> tensor<1x1x1024x128xf32> + %4 = stablehlo.custom_call @SPMDShardToFullShape(%3) {backend_config = "", mhlo.sharding = "{devices=[1,1,8,4]<=[32]}"} : (tensor<1x1x1024x128xf32>) -> tensor<1x1x8192x512xf32> + // CHECK: "ttir.mesh_shard" + // CHECK-SAME: shard_dims = array + // CHECK-SAME: shard_direction = #tt.shard_direction + // CHECK-SAME: shard_shape = array + // CHECK-SAME: shard_type = #tt.shard_type + return %4 : tensor<1x1x8192x512xf32> + } + func.func private @shmap_body(%arg0: tensor<1x1x1024x128xf32>) -> (tensor<1x1x1024x128xf32> {jax.result_info = "[None, None, ('batch',), ('pipeline',)]"}) { + %0 = "stablehlo.collective_permute"(%arg0) <{channel_handle = #stablehlo.channel_handle, source_target_pairs = dense<[[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13], [14, 15], [16, 17], [18, 19], [20, 21], [22, 23], [24, 25], [26, 27], [28, 29], [30, 31]]> : tensor<16x2xi64>}> : (tensor<1x1x1024x128xf32>) -> tensor<1x1x1024x128xf32> + // CHECK: "ttir.collective_permute" + // CHECK-SAME: source_target_pairs = dense + // CHECK-SAME: [0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13], [14, 15], [16, 17], [18, 19], [20, 21], [22, 23], [24, 25], [26, 27], [28, 29], [30, 31] + return %0 : tensor<1x1x1024x128xf32> + } +} + +// ----- + +module @jit_collective_permute_8x4_rank_4_cluster_0 attributes {mhlo.num_partitions = 32 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<1x1x8192x512xf32>) -> (tensor<1x1x8192x512xf32> {jax.result_info = ""}) { + %0 = stablehlo.custom_call 
@Sharding(%arg0) {backend_config = "", mhlo.sharding = "{devices=[1,1,4,8]<=[8,4]T(1,0)}"} : (tensor<1x1x8192x512xf32>) -> tensor<1x1x8192x512xf32> + %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<1x1x8192x512xf32>) -> tensor<1x1x2048x64xf32> + // CHECK: "ttir.mesh_shard" + // CHECK-SAME: shard_dims = array + // CHECK-SAME: shard_direction = #tt.shard_direction + // CHECK-SAME: shard_shape = array + // CHECK-SAME: shard_type = #tt.shard_type + %2 = call @shmap_body(%1) : (tensor<1x1x2048x64xf32>) -> tensor<1x1x2048x64xf32> + %3 = stablehlo.custom_call @Sharding(%2) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<1x1x2048x64xf32>) -> tensor<1x1x2048x64xf32> + %4 = stablehlo.custom_call @SPMDShardToFullShape(%3) {backend_config = "", mhlo.sharding = "{devices=[1,1,4,8]<=[8,4]T(1,0)}"} : (tensor<1x1x2048x64xf32>) -> tensor<1x1x8192x512xf32> + // CHECK: "ttir.mesh_shard" + // CHECK-SAME: shard_dims = array + // CHECK-SAME: shard_direction = #tt.shard_direction + // CHECK-SAME: shard_shape = array + // CHECK-SAME: shard_type = #tt.shard_type + return %4 : tensor<1x1x8192x512xf32> + } + func.func private @shmap_body(%arg0: tensor<1x1x2048x64xf32>) -> (tensor<1x1x2048x64xf32> {jax.result_info = "[None, None, ('batch',), ('pipeline',)]"}) { + %0 = "stablehlo.collective_permute"(%arg0) <{channel_handle = #stablehlo.channel_handle, source_target_pairs = dense<[[0, 4], [4, 8], [8, 12], [12, 16], [16, 20], [20, 24], [24, 28], [28, 0], [1, 5], [5, 9], [9, 13], [13, 17], [17, 21], [21, 25], [25, 29], [29, 1], [2, 6], [6, 10], [10, 14], [14, 18], [18, 22], [22, 26], [26, 30], [30, 2], [3, 7], [7, 11], [11, 15], [15, 19], [19, 23], [23, 27], [27, 31], [31, 3]]> : tensor<32x2xi64>}> : (tensor<1x1x2048x64xf32>) -> tensor<1x1x2048x64xf32> + // CHECK: "ttir.collective_permute" + // CHECK-SAME: source_target_pairs = dense + return %0 : tensor<1x1x2048x64xf32> + } +} + +// ----- + +module @jit_collective_permute_8x4_rank_4_cluster_0_partial_target_pairs attributes {mhlo.num_partitions = 32 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<1x1x8192x512xf32>) -> (tensor<1x1x8192x512xf32> {jax.result_info = ""}) { + %0 = stablehlo.custom_call @Sharding(%arg0) {backend_config = "", mhlo.sharding = "{devices=[1,1,4,8]<=[8,4]T(1,0)}"} : (tensor<1x1x8192x512xf32>) -> tensor<1x1x8192x512xf32> + %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<1x1x8192x512xf32>) -> tensor<1x1x2048x64xf32> + // CHECK: "ttir.mesh_shard" + // CHECK-SAME: shard_dims = array + // CHECK-SAME: shard_direction = #tt.shard_direction + // CHECK-SAME: shard_shape = array + // CHECK-SAME: shard_type = #tt.shard_type + %2 = call @shmap_body(%1) : (tensor<1x1x2048x64xf32>) -> tensor<1x1x2048x64xf32> + %3 = stablehlo.custom_call @Sharding(%2) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<1x1x2048x64xf32>) -> tensor<1x1x2048x64xf32> + %4 = stablehlo.custom_call @SPMDShardToFullShape(%3) {backend_config = "", mhlo.sharding = "{devices=[1,1,4,8]<=[8,4]T(1,0)}"} : (tensor<1x1x2048x64xf32>) -> tensor<1x1x8192x512xf32> + // CHECK: "ttir.mesh_shard" + // CHECK-SAME: shard_dims = array + // CHECK-SAME: shard_direction = #tt.shard_direction + // CHECK-SAME: shard_shape = array + // CHECK-SAME: shard_type = #tt.shard_type + return %4 : tensor<1x1x8192x512xf32> + } + func.func private @shmap_body(%arg0: tensor<1x1x2048x64xf32>) -> (tensor<1x1x2048x64xf32> {jax.result_info = "[None, 
None, ('batch',), ('pipeline',)]"}) { + %0 = "stablehlo.collective_permute"(%arg0) <{channel_handle = #stablehlo.channel_handle, source_target_pairs = dense<[[0, 4], [8, 12], [16, 20], [24, 28], [1, 5], [9, 13], [17, 21], [25, 29], [2, 6], [10, 14], [18, 22], [26, 30], [3, 7], [11, 15], [19, 23], [27, 31]]> : tensor<16x2xi64>}> : (tensor<1x1x2048x64xf32>) -> tensor<1x1x2048x64xf32> + // CHECK: "ttir.collective_permute" + // CHECK-SAME: source_target_pairs = dense + return %0 : tensor<1x1x2048x64xf32> + } +} diff --git a/test/ttmlir/Dialect/TTIR/ccl/collective_permute/collective_permute_negative.mlir b/test/ttmlir/Dialect/TTIR/ccl/collective_permute/collective_permute_negative.mlir new file mode 100644 index 0000000000..c3420ffa55 --- /dev/null +++ b/test/ttmlir/Dialect/TTIR/ccl/collective_permute/collective_permute_negative.mlir @@ -0,0 +1,36 @@ +// RUN: not ttmlir-opt --split-input-file %s 2>&1 | FileCheck %s +// Unit tests for ttir collective_permute op + +// ----- + +module attributes {} { + func.func public @collective_permute_invalid_source_target_pair_rank(%arg0: tensor<1x1x8192x512xf32>) -> (tensor<1x1x8192x512xf32> {jax.result_info = ""}) { + %0 = tensor.empty() : tensor<1x1x8192x512xf32> + %1 = "ttir.collective_permute"(%arg0, %0) <{source_target_pairs = dense<[0]> : tensor<1xi64>}> : (tensor<1x1x8192x512xf32>, tensor<1x1x8192x512xf32>) -> tensor<1x1x8192x512xf32> + return %1 : tensor<1x1x8192x512xf32> + } +} +// CHECK: error: 'ttir.collective_permute' op The rank of source target pairs must be 2, got rank = 1 + +// ----- + +module attributes {} { + func.func public @collective_permute_invalid_duplicate_sources(%arg0: tensor<1x1x8192x512xf32>) -> (tensor<1x1x8192x512xf32> {jax.result_info = ""}) { + %0 = tensor.empty() : tensor<1x1x8192x512xf32> + %1 = "ttir.collective_permute"(%arg0, %0) <{source_target_pairs = dense<[[0, 1], [0, 2]]> : tensor<2x2xi64>}> : (tensor<1x1x8192x512xf32>, tensor<1x1x8192x512xf32>) -> tensor<1x1x8192x512xf32> + return %1 : tensor<1x1x8192x512xf32> + } +} +// CHECK: error: 'ttir.collective_permute' op There are duplicate 'src' or 'dest' devices in source target pairs + + +// ----- + +module attributes {} { + func.func public @collective_permute_invalid_duplicate_targets(%arg0: tensor<1x1x8192x512xf32>) -> (tensor<1x1x8192x512xf32> {jax.result_info = ""}) { + %0 = tensor.empty() : tensor<1x1x8192x512xf32> + %1 = "ttir.collective_permute"(%arg0, %0) <{source_target_pairs = dense<[[0, 2], [1, 2]]> : tensor<2x2xi64>}> : (tensor<1x1x8192x512xf32>, tensor<1x1x8192x512xf32>) -> tensor<1x1x8192x512xf32> + return %1 : tensor<1x1x8192x512xf32> + } +} +// CHECK: error: 'ttir.collective_permute' op There are duplicate 'src' or 'dest' devices in source target pairs diff --git a/test/ttmlir/Dialect/TTNN/ccl/collective_permute/collective_permute_negative.mlir b/test/ttmlir/Dialect/TTNN/ccl/collective_permute/collective_permute_negative.mlir new file mode 100644 index 0000000000..392f7ffb55 --- /dev/null +++ b/test/ttmlir/Dialect/TTNN/ccl/collective_permute/collective_permute_negative.mlir @@ -0,0 +1,44 @@ +// RUN: not ttmlir-opt --split-input-file --ttir-load-system-desc="path=%system_desc_path%" --ttir-implicit-device="force-reload=true" %s 2>&1 | FileCheck %s +// Unit tests for ttnn collective_permute op + +// ----- + +#device = #tt.device (0, d0, d1)>, l1Map = (d0, d1)[s0, s1] -> (0, d0 floordiv s0, d1 floordiv s1, (d0 mod s0) * s1 + d1 mod s1), dramMap = (d0, d1)[s0, s1] -> (0, 0, ((((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) 
floordiv 8192) mod 12, (((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) floordiv 98304 + (((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) mod 8192), meshShape = , chipIds = [0]> +#dram = #ttnn.buffer_type +#ttnn_layout1 = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<128x512x!tt.tile<32x32, f32>, #dram>, > +module @collective_permute_invalid_source_target_pair_rank attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32, tt.device = #device} { + func.func public @main(%arg0: tensor<4096x16384xf32, #ttnn_layout1>) -> (tensor<4096x16384xf32, #ttnn_layout1> {}) { + %0 = "ttnn.get_device"() <{mesh_shape = #ttnn}> : () -> !tt.device<#device> + %1 = "ttnn.collective_permute"(%arg0, %0) <{source_target_pairs = dense<[0]> : tensor<1xi64>}> : (tensor<4096x16384xf32, #ttnn_layout1>, !tt.device<#device>) -> tensor<4096x16384xf32, #ttnn_layout1> + return %1 : tensor<4096x16384xf32, #ttnn_layout1> + } +} +// CHECK: error: 'ttnn.collective_permute' op The rank of source target pairs must be 2, got rank = 1 + +// ----- + +#device = #tt.device (0, d0, d1)>, l1Map = (d0, d1)[s0, s1] -> (0, d0 floordiv s0, d1 floordiv s1, (d0 mod s0) * s1 + d1 mod s1), dramMap = (d0, d1)[s0, s1] -> (0, 0, ((((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) floordiv 8192) mod 12, (((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) floordiv 98304 + (((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) mod 8192), meshShape = , chipIds = [0]> +#dram = #ttnn.buffer_type +#ttnn_layout1 = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<128x512x!tt.tile<32x32, f32>, #dram>, > +module @collective_permute_invalid_source_target_pair_rank attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32, tt.device = #device} { + func.func public @main(%arg0: tensor<4096x16384xf32, #ttnn_layout1>) -> (tensor<4096x16384xf32, #ttnn_layout1> {}) { + %0 = "ttnn.get_device"() <{mesh_shape = #ttnn}> : () -> !tt.device<#device> + %1 = "ttnn.collective_permute"(%arg0, %0) <{source_target_pairs = dense<[[0, 1], [0, 2]]> : tensor<2x2xi64>}> : (tensor<4096x16384xf32, #ttnn_layout1>, !tt.device<#device>) -> tensor<4096x16384xf32, #ttnn_layout1> + return %1 : tensor<4096x16384xf32, #ttnn_layout1> + } +} +// CHECK: error: 'ttnn.collective_permute' op There are duplicate 'src' or 'dest' devices in source target pairs + +// ----- + +#device = #tt.device (0, d0, d1)>, l1Map = (d0, d1)[s0, s1] -> (0, d0 floordiv s0, d1 floordiv s1, (d0 mod s0) * s1 + d1 mod s1), dramMap = (d0, d1)[s0, s1] -> (0, 0, ((((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) floordiv 8192) mod 12, (((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) floordiv 98304 + (((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) mod 8192), meshShape = , chipIds = [0]> +#dram = #ttnn.buffer_type +#ttnn_layout1 = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<128x512x!tt.tile<32x32, f32>, #dram>, > +module @collective_permute_invalid_source_target_pair_rank attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32, tt.device = #device} { + func.func public @main(%arg0: tensor<4096x16384xf32, #ttnn_layout1>) -> (tensor<4096x16384xf32, #ttnn_layout1> {}) { + %0 = "ttnn.get_device"() <{mesh_shape = #ttnn}> : () -> !tt.device<#device> + %1 = "ttnn.collective_permute"(%arg0, %0) 
<{source_target_pairs = dense<[[0, 2], [1, 2]]> : tensor<2x2xi64>}> : (tensor<4096x16384xf32, #ttnn_layout1>, !tt.device<#device>) -> tensor<4096x16384xf32, #ttnn_layout1> + return %1 : tensor<4096x16384xf32, #ttnn_layout1> + } +} +// CHECK: error: 'ttnn.collective_permute' op There are duplicate 'src' or 'dest' devices in source target pairs diff --git a/test/ttmlir/Dialect/TTNN/ccl/collective_permute/collective_permute_positive.mlir b/test/ttmlir/Dialect/TTNN/ccl/collective_permute/collective_permute_positive.mlir new file mode 100644 index 0000000000..92772a7248 --- /dev/null +++ b/test/ttmlir/Dialect/TTNN/ccl/collective_permute/collective_permute_positive.mlir @@ -0,0 +1,14 @@ +// RUN: ttmlir-opt --split-input-file --ttir-to-ttnn-backend-pipeline="mesh-shape=1,1" %s | FileCheck %s +// Unit tests for ttnn collective_permute op + +// ----- + +// Verify lowering of ttir collective_permute to ttnn ops +module attributes {} { + func.func public @collective_permute(%arg0: tensor<1x1x8192x512xf32>) -> (tensor<1x1x8192x512xf32> {jax.result_info = ""}) { + %0 = tensor.empty() : tensor<1x1x8192x512xf32> + %1 = "ttir.collective_permute"(%arg0, %0) <{source_target_pairs = dense<[[0, 1], [1, 0]]> : tensor<2x2xi64>}> : (tensor<1x1x8192x512xf32>, tensor<1x1x8192x512xf32>) -> tensor<1x1x8192x512xf32> + return %1 : tensor<1x1x8192x512xf32> + } +} +// CHECK: "ttnn.collective_permute" diff --git a/test/ttmlir/Silicon/TTNN/llmbox/ccl/ccl_1x8.mlir b/test/ttmlir/Silicon/TTNN/llmbox/ccl/ccl_1x8.mlir index 252f548b37..f2b0742b5c 100644 --- a/test/ttmlir/Silicon/TTNN/llmbox/ccl/ccl_1x8.mlir +++ b/test/ttmlir/Silicon/TTNN/llmbox/ccl/ccl_1x8.mlir @@ -41,3 +41,29 @@ func.func @reduce_scatter_cluster1(%arg0: tensor<1x1x8192x2048xf32>) -> (tensor< // CHECK: "ttnn.mesh_shard" return %5 : tensor<1x1x8192x256xf32> } + +func.func public @collective_permute_cluster_1(%arg0: tensor<1x1x8192x512xf32>) -> (tensor<1x1x8192x512xf32> {jax.result_info = ""}) { + %0 = tensor.empty() : tensor<1x1x8192x64xf32> + %1 = "ttir.mesh_shard"(%arg0, %0) <{shard_dims = array, shard_direction = #tt.shard_direction, shard_shape = array, shard_type = #tt.shard_type}> : (tensor<1x1x8192x512xf32>, tensor<1x1x8192x64xf32>) -> tensor<1x1x8192x64xf32> + // CHECK: "ttnn.mesh_shard" + %2 = tensor.empty() : tensor<1x1x8192x64xf32> + %3 = "ttir.collective_permute"(%1, %2) <{source_target_pairs = dense<[[0, 1], [1, 2], [2, 3], [3, 4], [4, 5], [5, 6], [6, 7], [7, 0]]> : tensor<8x2xi64>}> : (tensor<1x1x8192x64xf32>, tensor<1x1x8192x64xf32>) -> tensor<1x1x8192x64xf32> + // CHECK: "ttnn.collective_permute" + %4 = tensor.empty() : tensor<1x1x8192x512xf32> + %5 = "ttir.mesh_shard"(%3, %4) <{shard_dims = array, shard_direction = #tt.shard_direction, shard_shape = array, shard_type = #tt.shard_type}> : (tensor<1x1x8192x64xf32>, tensor<1x1x8192x512xf32>) -> tensor<1x1x8192x512xf32> + // CHECK: "ttnn.mesh_shard" + return %5 : tensor<1x1x8192x512xf32> +} + +func.func public @collective_permute_cluster_1_partial_target_pairs(%arg0: tensor<1x1x8192x512xf32>) -> (tensor<1x1x8192x512xf32> {jax.result_info = ""}) { + %0 = tensor.empty() : tensor<1x1x8192x64xf32> + %1 = "ttir.mesh_shard"(%arg0, %0) <{shard_dims = array, shard_direction = #tt.shard_direction, shard_shape = array, shard_type = #tt.shard_type}> : (tensor<1x1x8192x512xf32>, tensor<1x1x8192x64xf32>) -> tensor<1x1x8192x64xf32> + // CHECK: "ttnn.mesh_shard" + %2 = tensor.empty() : tensor<1x1x8192x64xf32> + %3 = "ttir.collective_permute"(%1, %2) <{source_target_pairs = dense<[[0, 1],
[2, 3], [4, 5], [6, 7]]> : tensor<4x2xi64>}> : (tensor<1x1x8192x64xf32>, tensor<1x1x8192x64xf32>) -> tensor<1x1x8192x64xf32> + // CHECK: "ttnn.collective_permute" + %4 = tensor.empty() : tensor<1x1x8192x512xf32> + %5 = "ttir.mesh_shard"(%3, %4) <{shard_dims = array, shard_direction = #tt.shard_direction, shard_shape = array, shard_type = #tt.shard_type}> : (tensor<1x1x8192x64xf32>, tensor<1x1x8192x512xf32>) -> tensor<1x1x8192x512xf32> + // CHECK: "ttnn.mesh_shard" + return %5 : tensor<1x1x8192x512xf32> +} diff --git a/test/ttmlir/Silicon/TTNN/llmbox/ccl/ccl_2x4.mlir b/test/ttmlir/Silicon/TTNN/llmbox/ccl/ccl_2x4.mlir index af92d07145..b6acda2053 100644 --- a/test/ttmlir/Silicon/TTNN/llmbox/ccl/ccl_2x4.mlir +++ b/test/ttmlir/Silicon/TTNN/llmbox/ccl/ccl_2x4.mlir @@ -81,3 +81,55 @@ func.func @reduce_scatter_cluster1(%arg0: tensor<1x1x8192x512xf32>) -> (tensor<1 // CHECK: "ttnn.mesh_shard" return %5 : tensor<1x1x8192x128xf32> } + +func.func public @collective_permute_cluster_1(%arg0: tensor<1x1x8192x512xf32>) -> (tensor<1x1x8192x512xf32> {jax.result_info = ""}) { + %0 = tensor.empty() : tensor<1x1x4096x128xf32> + %1 = "ttir.mesh_shard"(%arg0, %0) <{shard_dims = array, shard_direction = #tt.shard_direction, shard_shape = array, shard_type = #tt.shard_type}> : (tensor<1x1x8192x512xf32>, tensor<1x1x4096x128xf32>) -> tensor<1x1x4096x128xf32> + // CHECK: "ttnn.mesh_shard" + %2 = tensor.empty() : tensor<1x1x4096x128xf32> + %3 = "ttir.collective_permute"(%1, %2) <{source_target_pairs = dense<[[0, 1], [1, 2], [2, 3], [3, 0], [4, 5], [5, 6], [6, 7], [7, 4]]> : tensor<8x2xi64>}> : (tensor<1x1x4096x128xf32>, tensor<1x1x4096x128xf32>) -> tensor<1x1x4096x128xf32> + // CHECK: "ttnn.collective_permute" + %4 = tensor.empty() : tensor<1x1x8192x512xf32> + %5 = "ttir.mesh_shard"(%3, %4) <{shard_dims = array, shard_direction = #tt.shard_direction, shard_shape = array, shard_type = #tt.shard_type}> : (tensor<1x1x4096x128xf32>, tensor<1x1x8192x512xf32>) -> tensor<1x1x8192x512xf32> + // CHECK: "ttnn.mesh_shard" + return %5 : tensor<1x1x8192x512xf32> +} + +func.func public @collective_permute_cluster_1_partial_target_pairs(%arg0: tensor<1x1x8192x512xf32>) -> (tensor<1x1x8192x512xf32> {jax.result_info = ""}) { + %0 = tensor.empty() : tensor<1x1x4096x128xf32> + %1 = "ttir.mesh_shard"(%arg0, %0) <{shard_dims = array, shard_direction = #tt.shard_direction, shard_shape = array, shard_type = #tt.shard_type}> : (tensor<1x1x8192x512xf32>, tensor<1x1x4096x128xf32>) -> tensor<1x1x4096x128xf32> + // CHECK: "ttnn.mesh_shard" + %2 = tensor.empty() : tensor<1x1x4096x128xf32> + %3 = "ttir.collective_permute"(%1, %2) <{source_target_pairs = dense<[[0, 1], [2, 3], [4, 5], [6, 7]]> : tensor<4x2xi64>}> : (tensor<1x1x4096x128xf32>, tensor<1x1x4096x128xf32>) -> tensor<1x1x4096x128xf32> + // CHECK: "ttnn.collective_permute" + %4 = tensor.empty() : tensor<1x1x8192x512xf32> + %5 = "ttir.mesh_shard"(%3, %4) <{shard_dims = array, shard_direction = #tt.shard_direction, shard_shape = array, shard_type = #tt.shard_type}> : (tensor<1x1x4096x128xf32>, tensor<1x1x8192x512xf32>) -> tensor<1x1x8192x512xf32> + // CHECK: "ttnn.mesh_shard" + return %5 : tensor<1x1x8192x512xf32> +} + +func.func public @collective_permute_cluster_0(%arg0: tensor<1x1x8192x512xf32>) -> (tensor<1x1x8192x512xf32> {jax.result_info = ""}) { + %0 = tensor.empty() : tensor<1x1x2048x256xf32> + %1 = "ttir.mesh_shard"(%arg0, %0) <{shard_dims = array, shard_direction = #tt.shard_direction, shard_shape = array, shard_type = #tt.shard_type}> : (tensor<1x1x8192x512xf32>, 
tensor<1x1x2048x256xf32>) -> tensor<1x1x2048x256xf32> + // CHECK: "ttnn.mesh_shard" + %2 = tensor.empty() : tensor<1x1x2048x256xf32> + %3 = "ttir.collective_permute"(%1, %2) <{source_target_pairs = dense<[[0, 4], [4, 0], [1, 5], [5, 1], [2, 6], [6, 2], [3, 7], [7, 3]]> : tensor<8x2xi64>}> : (tensor<1x1x2048x256xf32>, tensor<1x1x2048x256xf32>) -> tensor<1x1x2048x256xf32> + // CHECK: "ttnn.collective_permute" + %4 = tensor.empty() : tensor<1x1x8192x512xf32> + %5 = "ttir.mesh_shard"(%3, %4) <{shard_dims = array, shard_direction = #tt.shard_direction, shard_shape = array, shard_type = #tt.shard_type}> : (tensor<1x1x2048x256xf32>, tensor<1x1x8192x512xf32>) -> tensor<1x1x8192x512xf32> + // CHECK: "ttnn.mesh_shard" + return %5 : tensor<1x1x8192x512xf32> +} + +func.func public @collective_permute_cluster_0_partial_target_pairs(%arg0: tensor<1x1x8192x512xf32>) -> (tensor<1x1x8192x512xf32> {jax.result_info = ""}) { + %0 = tensor.empty() : tensor<1x1x2048x256xf32> + %1 = "ttir.mesh_shard"(%arg0, %0) <{shard_dims = array, shard_direction = #tt.shard_direction, shard_shape = array, shard_type = #tt.shard_type}> : (tensor<1x1x8192x512xf32>, tensor<1x1x2048x256xf32>) -> tensor<1x1x2048x256xf32> + // CHECK: "ttnn.mesh_shard" + %2 = tensor.empty() : tensor<1x1x2048x256xf32> + %3 = "ttir.collective_permute"(%1, %2) <{source_target_pairs = dense<[[0, 4], [1, 5], [2, 6], [3, 7]]> : tensor<4x2xi64>}> : (tensor<1x1x2048x256xf32>, tensor<1x1x2048x256xf32>) -> tensor<1x1x2048x256xf32> + // CHECK: "ttnn.collective_permute" + %4 = tensor.empty() : tensor<1x1x8192x512xf32> + %5 = "ttir.mesh_shard"(%3, %4) <{shard_dims = array, shard_direction = #tt.shard_direction, shard_shape = array, shard_type = #tt.shard_type}> : (tensor<1x1x2048x256xf32>, tensor<1x1x8192x512xf32>) -> tensor<1x1x8192x512xf32> + // CHECK: "ttnn.mesh_shard" + return %5 : tensor<1x1x8192x512xf32> +} diff --git a/test/ttmlir/Silicon/TTNN/n300/ccl/ccl_1x2.mlir b/test/ttmlir/Silicon/TTNN/n300/ccl/ccl_1x2.mlir index d59105bd9c..7dc49c1989 100644 --- a/test/ttmlir/Silicon/TTNN/n300/ccl/ccl_1x2.mlir +++ b/test/ttmlir/Silicon/TTNN/n300/ccl/ccl_1x2.mlir @@ -42,6 +42,32 @@ func.func @reduce_scatter_cluster1(%arg0: tensor<1x1x8192x512xf32>) -> (tensor<1 return %5 : tensor<1x1x8192x256xf32> } +func.func @collective_permute_cluster_1(%arg0: tensor<1x1x8192x512xf32>) -> (tensor<1x1x8192x512xf32> {jax.result_info = ""}) { + %0 = tensor.empty() : tensor<1x1x8192x256xf32> + %1 = "ttir.mesh_shard"(%arg0, %0) <{shard_dims = array, shard_direction = #tt.shard_direction, shard_shape = array, shard_type = #tt.shard_type}> : (tensor<1x1x8192x512xf32>, tensor<1x1x8192x256xf32>) -> tensor<1x1x8192x256xf32> + // CHECK: "ttnn.mesh_shard" + %2 = tensor.empty() : tensor<1x1x8192x256xf32> + %3 = "ttir.collective_permute"(%1, %2) <{source_target_pairs = dense<[[0, 1], [1, 0]]> : tensor<2x2xi64>}> : (tensor<1x1x8192x256xf32>, tensor<1x1x8192x256xf32>) -> tensor<1x1x8192x256xf32> + // CHECK: "ttnn.collective_permute" + %4 = tensor.empty() : tensor<1x1x8192x512xf32> + %5 = "ttir.mesh_shard"(%3, %4) <{shard_dims = array, shard_direction = #tt.shard_direction, shard_shape = array, shard_type = #tt.shard_type}> : (tensor<1x1x8192x256xf32>, tensor<1x1x8192x512xf32>) -> tensor<1x1x8192x512xf32> + // CHECK: "ttnn.mesh_shard" + return %5 : tensor<1x1x8192x512xf32> +} + +func.func public @collective_permute_cluster_1_partial_target_pairs(%arg0: tensor<1x1x8192x512xf32>) -> (tensor<1x1x8192x512xf32> {jax.result_info = ""}) { + %0 = tensor.empty() : tensor<1x1x8192x256xf32> + %1 = 
"ttir.mesh_shard"(%arg0, %0) <{shard_dims = array, shard_direction = #tt.shard_direction, shard_shape = array, shard_type = #tt.shard_type}> : (tensor<1x1x8192x512xf32>, tensor<1x1x8192x256xf32>) -> tensor<1x1x8192x256xf32> + // CHECK: "ttnn.mesh_shard" + %2 = tensor.empty() : tensor<1x1x8192x256xf32> + %3 = "ttir.collective_permute"(%1, %2) <{source_target_pairs = dense<[[0, 1]]> : tensor<1x2xi64>}> : (tensor<1x1x8192x256xf32>, tensor<1x1x8192x256xf32>) -> tensor<1x1x8192x256xf32> + // CHECK: "ttnn.collective_permute" + %4 = tensor.empty() : tensor<1x1x8192x512xf32> + %5 = "ttir.mesh_shard"(%3, %4) <{shard_dims = array, shard_direction = #tt.shard_direction, shard_shape = array, shard_type = #tt.shard_type}> : (tensor<1x1x8192x256xf32>, tensor<1x1x8192x512xf32>) -> tensor<1x1x8192x512xf32> + // CHECK: "ttnn.mesh_shard" + return %5 : tensor<1x1x8192x512xf32> +} + func.func public @main(%arg0: tensor<8192x784xf32> {mhlo.layout_mode = "default"}, %arg1: tensor<784x16384xf32> {mhlo.layout_mode = "default"}) -> (tensor<8192x16384xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) { %0 = tensor.empty() : tensor<8192x392xf32> %1 = "ttir.mesh_shard"(%arg0, %0) <{shard_dims = array, shard_direction = #tt.shard_direction, shard_shape = array, shard_type = #tt.shard_type}> : (tensor<8192x784xf32>, tensor<8192x392xf32>) -> tensor<8192x392xf32> diff --git a/test/ttmlir/Silicon/TTNN/tg/ccl/ccl_1x32.mlir b/test/ttmlir/Silicon/TTNN/tg/ccl/ccl_1x32.mlir index 24bd989e26..d813109e2f 100644 --- a/test/ttmlir/Silicon/TTNN/tg/ccl/ccl_1x32.mlir +++ b/test/ttmlir/Silicon/TTNN/tg/ccl/ccl_1x32.mlir @@ -38,3 +38,29 @@ func.func @reduce_scatter_cluster1(%arg0: tensor<1x1x512x8192xf32>) -> (tensor<1 // CHECK: "ttnn.mesh_shard" return %5 : tensor<1x1x512x256xf32> } + +func.func public @collective_permute_cluster_1(%arg0: tensor<1x1x8192x512xf32>) -> (tensor<1x1x8192x512xf32> {jax.result_info = ""}) { + %0 = tensor.empty() : tensor<1x1x8192x16xf32> + %1 = "ttir.mesh_shard"(%arg0, %0) <{shard_dims = array, shard_direction = #tt.shard_direction, shard_shape = array, shard_type = #tt.shard_type}> : (tensor<1x1x8192x512xf32>, tensor<1x1x8192x16xf32>) -> tensor<1x1x8192x16xf32> + // CHECK: "ttnn.mesh_shard" + %2 = tensor.empty() : tensor<1x1x8192x16xf32> + %3 = "ttir.collective_permute"(%1, %2) <{source_target_pairs = dense<[[0, 1], [1, 2], [2, 3], [3, 4], [4, 5], [5, 6], [6, 7], [7, 8], [8, 9], [9, 10], [10, 11], [11, 12], [12, 13], [13, 14], [14, 15], [15, 16], [16, 17], [17, 18], [18, 19], [19, 20], [20, 21], [21, 22], [22, 23], [23, 24], [24, 25], [25, 26], [26, 27], [27, 28], [28, 29], [29, 30], [30, 31], [31, 0]]> : tensor<32x2xi64>}> : (tensor<1x1x8192x16xf32>, tensor<1x1x8192x16xf32>) -> tensor<1x1x8192x16xf32> + // CHECK: "ttnn.collective_permute" + %4 = tensor.empty() : tensor<1x1x8192x512xf32> + %5 = "ttir.mesh_shard"(%3, %4) <{shard_dims = array, shard_direction = #tt.shard_direction, shard_shape = array, shard_type = #tt.shard_type}> : (tensor<1x1x8192x16xf32>, tensor<1x1x8192x512xf32>) -> tensor<1x1x8192x512xf32> + // CHECK: "ttnn.mesh_shard" + return %5 : tensor<1x1x8192x512xf32> +} + +func.func public @collective_permute_cluster_1_partial_target_pairs(%arg0: tensor<1x1x8192x512xf32>) -> (tensor<1x1x8192x512xf32> {jax.result_info = ""}) { + %0 = tensor.empty() : tensor<1x1x8192x16xf32> + %1 = "ttir.mesh_shard"(%arg0, %0) <{shard_dims = array, shard_direction = #tt.shard_direction, shard_shape = array, shard_type = #tt.shard_type}> : (tensor<1x1x8192x512xf32>, tensor<1x1x8192x16xf32>) -> 
tensor<1x1x8192x16xf32> + // CHECK: "ttnn.mesh_shard" + %2 = tensor.empty() : tensor<1x1x8192x16xf32> + %3 = "ttir.collective_permute"(%1, %2) <{source_target_pairs = dense<[[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13], [14, 15], [16, 17], [18, 19], [20, 21], [22, 23], [24, 25], [26, 27], [28, 29], [30, 31]]> : tensor<16x2xi64>}> : (tensor<1x1x8192x16xf32>, tensor<1x1x8192x16xf32>) -> tensor<1x1x8192x16xf32> + // CHECK: "ttnn.collective_permute" + %4 = tensor.empty() : tensor<1x1x8192x512xf32> + %5 = "ttir.mesh_shard"(%3, %4) <{shard_dims = array, shard_direction = #tt.shard_direction, shard_shape = array, shard_type = #tt.shard_type}> : (tensor<1x1x8192x16xf32>, tensor<1x1x8192x512xf32>) -> tensor<1x1x8192x512xf32> + // CHECK: "ttnn.mesh_shard" + return %5 : tensor<1x1x8192x512xf32> +} diff --git a/test/ttmlir/Silicon/TTNN/tg/ccl/ccl_8x4.mlir b/test/ttmlir/Silicon/TTNN/tg/ccl/ccl_8x4.mlir index b8618cf083..97eeaac849 100644 --- a/test/ttmlir/Silicon/TTNN/tg/ccl/ccl_8x4.mlir +++ b/test/ttmlir/Silicon/TTNN/tg/ccl/ccl_8x4.mlir @@ -82,29 +82,54 @@ func.func @reduce_scatter_cluster1(%arg0: tensor<1x1x8192x512xf32>) -> (tensor<1 return %5 : tensor<1x1x8192x128xf32> } -func.func public @jit_data_tensor_parallel_tg(%arg0: tensor<64x1x1024x2048xf32>, %arg1: tensor<1x1x2048x512xf32>) -> (tensor<64x1x1024x512xf32> {jax.result_info = ""}) { - %0 = tensor.empty() : tensor<8x1x1024x512xf32> - %1 = "ttir.mesh_shard"(%arg0, %0) <{shard_dims = array, shard_direction = #tt.shard_direction, shard_shape = array, shard_type = #tt.shard_type}> : (tensor<64x1x1024x2048xf32>, tensor<8x1x1024x512xf32>) -> tensor<8x1x1024x512xf32> - // CHECK: "ttnn.mesh_shard" - %2 = tensor.empty() : tensor<1x1x512x512xf32> - %3 = "ttir.mesh_shard"(%arg1, %2) <{shard_dims = array, shard_direction = #tt.shard_direction, shard_shape = array, shard_type = #tt.shard_type}> : (tensor<1x1x2048x512xf32>, tensor<1x1x512x512xf32>) -> tensor<1x1x512x512xf32> - // CHECK: "ttnn.mesh_shard" - %4 = tensor.empty() : tensor<8x1024x512xf32> - %5 = "ttir.reshape"(%1, %4) <{shape = [8 : i32, 1024 : i32, 512 : i32]}> : (tensor<8x1x1024x512xf32>, tensor<8x1024x512xf32>) -> tensor<8x1024x512xf32> - // CHECK: = "ttnn.reshape" - %6 = tensor.empty() : tensor<1x512x512xf32> - %7 = "ttir.reshape"(%3, %6) <{shape = [1 : i32, 512 : i32, 512 : i32]}> : (tensor<1x1x512x512xf32>, tensor<1x512x512xf32>) -> tensor<1x512x512xf32> - // CHECK: = "ttnn.reshape" - %8 = "ttir.dot_general"(%5, %7) <{batch_dims_lhs = array, batch_dims_rhs = array, contract_dims_lhs = array, contract_dims_rhs = array}> : (tensor<8x1024x512xf32>, tensor<1x512x512xf32>) -> tensor<8x1024x1x512xf32> - // CHECK: "ttir.matmul" - %9 = tensor.empty() : tensor<8x1x1024x512xf32> - %10 = "ttir.permute"(%8, %9) <{permutation = array}> : (tensor<8x1024x1x512xf32>, tensor<8x1x1024x512xf32>) -> tensor<8x1x1024x512xf32> - // CHECK: "ttnn.permute" - %11 = tensor.empty() : tensor<8x1x256x512xf32> - %12 = "ttir.reduce_scatter"(%10, %11) <{cluster_axis = 1 : ui32, reduce_type = #tt.reduce_type, scatter_dim = 2 : si32}> : (tensor<8x1x1024x512xf32>, tensor<8x1x256x512xf32>) -> tensor<8x1x256x512xf32> - // CHECK: "ttnn.reduce_scatter" - %13 = tensor.empty() : tensor<64x1x1024x512xf32> - %14 = "ttir.mesh_shard"(%12, %13) <{shard_dims = array, shard_direction = #tt.shard_direction, shard_shape = array, shard_type = #tt.shard_type}> : (tensor<8x1x256x512xf32>, tensor<64x1x1024x512xf32>) -> tensor<64x1x1024x512xf32> +func.func public @collective_permute_cluster_1(%arg0: tensor<1x1x8192x512xf32>) -> 
(tensor<1x1x8192x512xf32> {jax.result_info = ""}) { + %0 = tensor.empty() : tensor<1x1x1024x128xf32> + %1 = "ttir.mesh_shard"(%arg0, %0) <{shard_dims = array, shard_direction = #tt.shard_direction, shard_shape = array, shard_type = #tt.shard_type}> : (tensor<1x1x8192x512xf32>, tensor<1x1x1024x128xf32>) -> tensor<1x1x1024x128xf32> + // CHECK: "ttnn.mesh_shard" + %2 = tensor.empty() : tensor<1x1x1024x128xf32> + %3 = "ttir.collective_permute"(%1, %2) <{source_target_pairs = dense<[[0, 1], [1, 2], [2, 3], [3, 0], [4, 5], [5, 6], [6, 7], [7, 4], [8, 9], [9, 10], [10, 11], [11, 8], [12, 13], [13, 14], [14, 15], [15, 12], [16, 17], [17, 18], [18, 19], [19, 16], [20, 21], [21, 22], [22, 23], [23, 20], [24, 25], [25, 26], [26, 27], [27, 24], [28, 29], [29, 30], [30, 31], [31, 28]]> : tensor<32x2xi64>}> : (tensor<1x1x1024x128xf32>, tensor<1x1x1024x128xf32>) -> tensor<1x1x1024x128xf32> + // CHECK: "ttnn.collective_permute" + %4 = tensor.empty() : tensor<1x1x8192x512xf32> + %5 = "ttir.mesh_shard"(%3, %4) <{shard_dims = array, shard_direction = #tt.shard_direction, shard_shape = array, shard_type = #tt.shard_type}> : (tensor<1x1x1024x128xf32>, tensor<1x1x8192x512xf32>) -> tensor<1x1x8192x512xf32> + // CHECK: "ttnn.mesh_shard" + return %5 : tensor<1x1x8192x512xf32> +} + +func.func public @collective_permute_cluster_1_partial_target_pairs(%arg0: tensor<1x1x8192x512xf32>) -> (tensor<1x1x8192x512xf32> {jax.result_info = ""}) { + %0 = tensor.empty() : tensor<1x1x1024x128xf32> + %1 = "ttir.mesh_shard"(%arg0, %0) <{shard_dims = array, shard_direction = #tt.shard_direction, shard_shape = array, shard_type = #tt.shard_type}> : (tensor<1x1x8192x512xf32>, tensor<1x1x1024x128xf32>) -> tensor<1x1x1024x128xf32> + // CHECK: "ttnn.mesh_shard" + %2 = tensor.empty() : tensor<1x1x1024x128xf32> + %3 = "ttir.collective_permute"(%1, %2) <{source_target_pairs = dense<[[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13], [14, 15], [16, 17], [18, 19], [20, 21], [22, 23], [24, 25], [26, 27], [28, 29], [30, 31]]> : tensor<16x2xi64>}> : (tensor<1x1x1024x128xf32>, tensor<1x1x1024x128xf32>) -> tensor<1x1x1024x128xf32> + // CHECK: "ttnn.collective_permute" + %4 = tensor.empty() : tensor<1x1x8192x512xf32> + %5 = "ttir.mesh_shard"(%3, %4) <{shard_dims = array, shard_direction = #tt.shard_direction, shard_shape = array, shard_type = #tt.shard_type}> : (tensor<1x1x1024x128xf32>, tensor<1x1x8192x512xf32>) -> tensor<1x1x8192x512xf32> + // CHECK: "ttnn.mesh_shard" + return %5 : tensor<1x1x8192x512xf32> +} + +func.func public @collective_permute_cluster_0(%arg0: tensor<1x1x8192x512xf32>) -> (tensor<1x1x8192x512xf32> {jax.result_info = ""}) { + %0 = tensor.empty() : tensor<1x1x2048x64xf32> + %1 = "ttir.mesh_shard"(%arg0, %0) <{shard_dims = array, shard_direction = #tt.shard_direction, shard_shape = array, shard_type = #tt.shard_type}> : (tensor<1x1x8192x512xf32>, tensor<1x1x2048x64xf32>) -> tensor<1x1x2048x64xf32> + // CHECK: "ttnn.mesh_shard" + %2 = tensor.empty() : tensor<1x1x2048x64xf32> + %3 = "ttir.collective_permute"(%1, %2) <{source_target_pairs = dense<[[0, 4], [4, 8], [8, 12], [12, 16], [16, 20], [20, 24], [24, 28], [28, 0], [1, 5], [5, 9], [9, 13], [13, 17], [17, 21], [21, 25], [25, 29], [29, 1], [2, 6], [6, 10], [10, 14], [14, 18], [18, 22], [22, 26], [26, 30], [30, 2], [3, 7], [7, 11], [11, 15], [15, 19], [19, 23], [23, 27], [27, 31], [31, 3]]> : tensor<32x2xi64>}> : (tensor<1x1x2048x64xf32>, tensor<1x1x2048x64xf32>) -> tensor<1x1x2048x64xf32> + // CHECK: "ttnn.collective_permute" + %4 = tensor.empty() : 
tensor<1x1x8192x512xf32> + %5 = "ttir.mesh_shard"(%3, %4) <{shard_dims = array, shard_direction = #tt.shard_direction, shard_shape = array, shard_type = #tt.shard_type}> : (tensor<1x1x2048x64xf32>, tensor<1x1x8192x512xf32>) -> tensor<1x1x8192x512xf32> + // CHECK: "ttnn.mesh_shard" + return %5 : tensor<1x1x8192x512xf32> +} + +func.func public @collective_permute_cluster_0_partial_target_pairs(%arg0: tensor<1x1x8192x512xf32>) -> (tensor<1x1x8192x512xf32> {jax.result_info = ""}) { + %0 = tensor.empty() : tensor<1x1x2048x64xf32> + %1 = "ttir.mesh_shard"(%arg0, %0) <{shard_dims = array, shard_direction = #tt.shard_direction, shard_shape = array, shard_type = #tt.shard_type}> : (tensor<1x1x8192x512xf32>, tensor<1x1x2048x64xf32>) -> tensor<1x1x2048x64xf32> + // CHECK: "ttnn.mesh_shard" + %2 = tensor.empty() : tensor<1x1x2048x64xf32> + %3 = "ttir.collective_permute"(%1, %2) <{source_target_pairs = dense<[[0, 4], [8, 12], [16, 20], [24, 28], [1, 5], [9, 13], [17, 21], [25, 29], [2, 6], [10, 14], [18, 22], [26, 30], [3, 7], [11, 15], [19, 23], [27, 31]]> : tensor<16x2xi64>}> : (tensor<1x1x2048x64xf32>, tensor<1x1x2048x64xf32>) -> tensor<1x1x2048x64xf32> + // CHECK: "ttnn.collective_permute" + %4 = tensor.empty() : tensor<1x1x8192x512xf32> + %5 = "ttir.mesh_shard"(%3, %4) <{shard_dims = array, shard_direction = #tt.shard_direction, shard_shape = array, shard_type = #tt.shard_type}> : (tensor<1x1x2048x64xf32>, tensor<1x1x8192x512xf32>) -> tensor<1x1x8192x512xf32> // CHECK: "ttnn.mesh_shard" - return %14 : tensor<64x1x1024x512xf32> + return %5 : tensor<1x1x8192x512xf32> } diff --git a/test/ttmlir/Silicon/TTNN/tg/device_parallel/device_parallel_8x4.mlir b/test/ttmlir/Silicon/TTNN/tg/device_parallel/device_parallel_8x4.mlir index 9d086dbfa6..d0ee54e6d3 100644 --- a/test/ttmlir/Silicon/TTNN/tg/device_parallel/device_parallel_8x4.mlir +++ b/test/ttmlir/Silicon/TTNN/tg/device_parallel/device_parallel_8x4.mlir @@ -28,3 +28,30 @@ func.func public @jit_data_tensor_parallel_tg(%arg0: tensor<64x1x1024x2048xf32>, // CHECK: "ttnn.mesh_shard" return %14 : tensor<64x1x1024x512xf32> } + +func.func public @jit_data_tensor_parallel_tg(%arg0: tensor<64x1x1024x2048xf32>, %arg1: tensor<1x1x2048x512xf32>) -> (tensor<64x1x1024x512xf32> {jax.result_info = ""}) { + %0 = tensor.empty() : tensor<8x1x1024x512xf32> + %1 = "ttir.mesh_shard"(%arg0, %0) <{shard_dims = array, shard_direction = #tt.shard_direction, shard_shape = array, shard_type = #tt.shard_type}> : (tensor<64x1x1024x2048xf32>, tensor<8x1x1024x512xf32>) -> tensor<8x1x1024x512xf32> + // CHECK: "ttnn.mesh_shard" + %2 = tensor.empty() : tensor<1x1x512x512xf32> + %3 = "ttir.mesh_shard"(%arg1, %2) <{shard_dims = array, shard_direction = #tt.shard_direction, shard_shape = array, shard_type = #tt.shard_type}> : (tensor<1x1x2048x512xf32>, tensor<1x1x512x512xf32>) -> tensor<1x1x512x512xf32> + // CHECK: "ttnn.mesh_shard" + %4 = tensor.empty() : tensor<8x1024x512xf32> + %5 = "ttir.reshape"(%1, %4) <{shape = [8 : i32, 1024 : i32, 512 : i32]}> : (tensor<8x1x1024x512xf32>, tensor<8x1024x512xf32>) -> tensor<8x1024x512xf32> + // CHECK: = "ttnn.reshape" + %6 = tensor.empty() : tensor<1x512x512xf32> + %7 = "ttir.reshape"(%3, %6) <{shape = [1 : i32, 512 : i32, 512 : i32]}> : (tensor<1x1x512x512xf32>, tensor<1x512x512xf32>) -> tensor<1x512x512xf32> + // CHECK: = "ttnn.reshape" + %8 = "ttir.dot_general"(%5, %7) <{batch_dims_lhs = array, batch_dims_rhs = array, contract_dims_lhs = array, contract_dims_rhs = array}> : (tensor<8x1024x512xf32>, tensor<1x512x512xf32>) -> 
tensor<8x1024x1x512xf32> + // CHECK: "ttir.matmul" + %9 = tensor.empty() : tensor<8x1x1024x512xf32> + %10 = "ttir.permute"(%8, %9) <{permutation = array}> : (tensor<8x1024x1x512xf32>, tensor<8x1x1024x512xf32>) -> tensor<8x1x1024x512xf32> + // CHECK: "ttnn.permute" + %11 = tensor.empty() : tensor<8x1x256x512xf32> + %12 = "ttir.reduce_scatter"(%10, %11) <{cluster_axis = 1 : ui32, reduce_type = #tt.reduce_type, scatter_dim = 2 : si32}> : (tensor<8x1x1024x512xf32>, tensor<8x1x256x512xf32>) -> tensor<8x1x256x512xf32> + // CHECK: "ttnn.reduce_scatter" + %13 = tensor.empty() : tensor<64x1x1024x512xf32> + %14 = "ttir.mesh_shard"(%12, %13) <{shard_dims = array, shard_direction = #tt.shard_direction, shard_shape = array, shard_type = #tt.shard_type}> : (tensor<8x1x256x512xf32>, tensor<64x1x1024x512xf32>) -> tensor<64x1x1024x512xf32> + // CHECK: "ttnn.mesh_shard" + return %14 : tensor<64x1x1024x512xf32> +}
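
Note on the verifier rules exercised above: the negative tests pin two constraints on source_target_pairs, namely that the attribute is rank 2 and that no device id repeats as a source or as a target (a pure swap such as [[0, 1], [1, 0]] is legal). The snippet below is a minimal, self-contained C++ sketch of the duplicate-device rule only, written against a flattened [src0, dst0, src1, dst1, ...] view of the attribute values; the helper name and signature are illustrative and are not the verifier code introduced by this diff.

#include <cstdint>
#include <optional>
#include <string>
#include <unordered_set>
#include <vector>

// Sketch of the duplicate-device rule for source_target_pairs.
// `pairs` holds the flattened values of an Nx2 i64 attribute:
// [src0, dst0, src1, dst1, ...].
std::optional<std::string>
checkSourceTargetPairs(const std::vector<int64_t> &pairs) {
  if (pairs.size() % 2 != 0)
    return "source target pairs must flatten to an even number of values";
  std::unordered_set<int64_t> sources, targets;
  for (size_t i = 0; i < pairs.size(); i += 2) {
    // A device may appear at most once as a source and once as a target.
    if (!sources.insert(pairs[i]).second || !targets.insert(pairs[i + 1]).second)
      return "There are duplicate 'src' or 'dest' devices in source target pairs";
  }
  return std::nullopt; // valid permutation
}

With this rule, dense<[[0, 1], [0, 2]]> and dense<[[0, 2], [1, 2]]> are rejected (duplicate source 0 and duplicate target 2, respectively), while the partial pairings used in the silicon tests, e.g. dense<[[0, 1], [2, 3], [4, 5], [6, 7]]>, pass.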