From 928f62e555211dab0cfaef81fe289260fb37abe7 Mon Sep 17 00:00:00 2001
From: Tapasvi Patel
Date: Sun, 9 Mar 2025 12:50:03 +0000
Subject: [PATCH] #2362: Added workaround for bug in full-to-shard with shard type replicate

---
 .../lib/ttnn/operations/ccl/mesh_shard.cpp   | 10 +++++++-
 .../device_parallel/device_parallel_1x2.mlir | 24 +++++++++++++++++++
 2 files changed, 33 insertions(+), 1 deletion(-)

diff --git a/runtime/lib/ttnn/operations/ccl/mesh_shard.cpp b/runtime/lib/ttnn/operations/ccl/mesh_shard.cpp
index f633f24c0c..efe6fbb223 100644
--- a/runtime/lib/ttnn/operations/ccl/mesh_shard.cpp
+++ b/runtime/lib/ttnn/operations/ccl/mesh_shard.cpp
@@ -18,8 +18,16 @@ void FullToShardShape(const ::ttnn::Tensor &input, ::ttnn::Tensor &out,
                       const std::vector<int64_t> &shardShape,
                       const std::vector<int64_t> &shardDims) {
   if (shardType == ::tt::target::ttnn::MeshShardType::Replicate) {
+
+    // TODO (tapspatel): metal issue #18842. Currently the
+    // ReplicateTensorToMesh mapper API cannot take a borrowed-storage
+    // tensor as input, so we deep copy the tensor to convert it to owned
+    // storage. https://github.com/tenstorrent/tt-metal/issues/18842
+    auto copiedTensorChunk = ::ttnn::experimental::xtensor::chunk(input, 1, 0);
+    auto copiedTensor =
+        ::ttnn::experimental::xtensor::concat(copiedTensorChunk, 0);
     out = ::ttnn::distributed::distribute_tensor(
-        input,
+        copiedTensor,
         *::ttnn::distributed::replicate_tensor_to_mesh_mapper(meshDevice));
   } else {
     DEBUG_ASSERT(input.get_logical_shape().rank() > 1,
diff --git a/test/ttmlir/Silicon/TTNN/n300/device_parallel/device_parallel_1x2.mlir b/test/ttmlir/Silicon/TTNN/n300/device_parallel/device_parallel_1x2.mlir
index afbc3573df..0077811adc 100644
--- a/test/ttmlir/Silicon/TTNN/n300/device_parallel/device_parallel_1x2.mlir
+++ b/test/ttmlir/Silicon/TTNN/n300/device_parallel/device_parallel_1x2.mlir
@@ -28,3 +28,27 @@ func.func public @jit_tensor_parallel_n300(%arg0: tensor<64x1x1024x2048xf32>, %a
   // CHECK: "ttnn.mesh_shard"
   return %14 : tensor<64x1x1024x512xf32>
 }
+
+func.func public @jit_data_parallel_n300(%arg0: tensor<64x1x1024x2048xf32>, %arg1: tensor<1x1x2048x512xf32>) -> (tensor<64x1x1024x512xf32> {jax.result_info = ""}) {
+  %0 = tensor.empty() : tensor<32x1x1024x2048xf32>
+  %1 = "ttir.mesh_shard"(%arg0, %0) <{shard_dims = array<i64: -1, 0>, shard_direction = #tt.shard_direction<full_to_shard>, shard_shape = array<i64: 2, 1, 1, 1>, shard_type = #tt.shard_type<devices>}> : (tensor<64x1x1024x2048xf32>, tensor<32x1x1024x2048xf32>) -> tensor<32x1x1024x2048xf32>
+  // CHECK: "ttnn.mesh_shard"
+  %2 = tensor.empty() : tensor<1x1x2048x512xf32>
+  %3 = "ttir.mesh_shard"(%arg1, %2) <{shard_dims = array<i64: -1>, shard_direction = #tt.shard_direction<full_to_shard>, shard_shape = array<i64: 1>, shard_type = #tt.shard_type<replicate>}> : (tensor<1x1x2048x512xf32>, tensor<1x1x2048x512xf32>) -> tensor<1x1x2048x512xf32>
+  // CHECK: "ttnn.mesh_shard"
+  %4 = tensor.empty() : tensor<32x1024x2048xf32>
+  %5 = "ttir.reshape"(%1, %4) <{shape = [32 : i32, 1024 : i32, 2048 : i32]}> : (tensor<32x1x1024x2048xf32>, tensor<32x1024x2048xf32>) -> tensor<32x1024x2048xf32>
+  // CHECK: = "ttnn.reshape"
+  %6 = tensor.empty() : tensor<1x2048x512xf32>
+  %7 = "ttir.reshape"(%3, %6) <{shape = [1 : i32, 2048 : i32, 512 : i32]}> : (tensor<1x1x2048x512xf32>, tensor<1x2048x512xf32>) -> tensor<1x2048x512xf32>
+  // CHECK: = "ttnn.reshape"
+  %8 = "ttir.dot_general"(%5, %7) <{batch_dims_lhs = array<i64>, batch_dims_rhs = array<i64>, contract_dims_lhs = array<i64: 2>, contract_dims_rhs = array<i64: 1>}> : (tensor<32x1024x2048xf32>, tensor<1x2048x512xf32>) -> tensor<32x1024x1x512xf32>
+  // CHECK: "ttnn.matmul"
+  %9 = tensor.empty() : tensor<32x1x1024x512xf32>
+  %10 = "ttir.permute"(%8, %9) <{permutation = array<i64: 0, 2, 1, 3>}> : (tensor<32x1024x1x512xf32>, tensor<32x1x1024x512xf32>) -> tensor<32x1x1024x512xf32>
+  // CHECK: "ttnn.permute"
+  %11 = tensor.empty() : tensor<64x1x1024x512xf32>
+  %12 = "ttir.mesh_shard"(%10, %11) <{shard_dims = array<i64: -1, 0>, shard_direction = #tt.shard_direction<shard_to_full>, shard_shape = array<i64: 2, 1, 1, 1>, shard_type = #tt.shard_type<devices>}> : (tensor<32x1x1024x512xf32>, tensor<64x1x1024x512xf32>) -> tensor<64x1x1024x512xf32>
+  // CHECK: "ttnn.mesh_shard"
+  return %12 : tensor<64x1x1024x512xf32>
+}
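
Note on the workaround: per the TODO comment above, chunk(input, 1, 0) produces a
single piece spanning the whole tensor along dim 0, and concat on that piece
materializes the data into a freshly allocated tensor with owned storage, so the
result no longer aliases the caller's borrowed buffer. A minimal sketch of the
same pattern in isolation; the header paths and the toOwnedCopy helper name are
illustrative assumptions, not part of the patch, and only the chunk/concat calls
are taken from it:

    // Assumed header paths for the declarations used below.
    #include "ttnn/tensor/tensor.hpp"
    #include "ttnn/tensor/xtensor/partition.hpp"

    // Hypothetical helper: force a deep copy of `input` into owned storage.
    // chunk(input, /*num_chunks=*/1, /*dim=*/0) yields one piece covering the
    // whole tensor; concat writes that piece into a newly allocated tensor.
    static ::ttnn::Tensor toOwnedCopy(const ::ttnn::Tensor &input) {
      auto chunks = ::ttnn::experimental::xtensor::chunk(input, 1, 0);
      return ::ttnn::experimental::xtensor::concat(chunks, 0);
    }

With such a helper, the replicate branch reduces to
distribute_tensor(toOwnedCopy(input), ...); once tt-metal issue #18842 is fixed,
the copy can be dropped and input passed to distribute_tensor directly.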