Commit

#2362: Added workaround to bug in full to shard with shard type replicate
tapspatel committed Mar 9, 2025
1 parent c16d622 commit 928f62e
Showing 2 changed files with 33 additions and 1 deletion.
10 changes: 9 additions & 1 deletion runtime/lib/ttnn/operations/ccl/mesh_shard.cpp
@@ -18,8 +18,16 @@ void FullToShardShape(const ::ttnn::Tensor &input, ::ttnn::Tensor &out,
                        const std::vector<int64_t> &shardShape,
                        const std::vector<int64_t> &shardDims) {
   if (shardType == ::tt::target::ttnn::MeshShardType::Replicate) {
+
+    // todo: (tapspatel) - metal issue #18842. Currently ReplicateTensorToMesh
+    // map API cannot take as input a borrowed storage tensor, so we need to
+    // create a deep copy of this tensor and convert to owned storage.
+    // https://github.com/tenstorrent/tt-metal/issues/18842
+    auto copiedTensorChunk = ::ttnn::experimental::xtensor::chunk(input, 1, 0);
+    auto copiedTensor =
+        ::ttnn::experimental::xtensor::concat(copiedTensorChunk, 0);
     out = ::ttnn::distributed::distribute_tensor(
-        input,
+        copiedTensor,
         *::ttnn::distributed::replicate_tensor_to_mesh_mapper(meshDevice));
   } else {
     DEBUG_ASSERT(input.get_logical_shape().rank() > 1,
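
For context, the workaround amounts to forcing the tensor into owned storage (via a chunk/concat round trip) before it is handed to the replicate mapper. Below is a minimal sketch of that same pattern pulled out into a standalone helper; the helper name is hypothetical, and it assumes the headers already included by mesh_shard.cpp plus a ::ttnn::MeshDevice reference like the meshDevice already in scope in FullToShardShape.

// Hypothetical helper illustrating the workaround in the diff above:
// round-trip the tensor through chunk/concat to obtain a deep copy with
// owned storage (tt-metal #18842), then replicate that copy across the mesh.
static ::ttnn::Tensor replicateWithOwnedCopy(const ::ttnn::Tensor &input,
                                             ::ttnn::MeshDevice &meshDevice) {
  // chunk(input, 1, 0) produces a single chunk along dim 0; concatenating
  // that one chunk yields a deep copy, converting borrowed storage to owned.
  auto chunks = ::ttnn::experimental::xtensor::chunk(input, 1, 0);
  auto ownedCopy = ::ttnn::experimental::xtensor::concat(chunks, 0);
  return ::ttnn::distributed::distribute_tensor(
      ownedCopy,
      *::ttnn::distributed::replicate_tensor_to_mesh_mapper(meshDevice));
}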
@@ -28,3 +28,27 @@ func.func public @jit_tensor_parallel_n300(%arg0: tensor<64x1x1024x2048xf32>, %a
   // CHECK: "ttnn.mesh_shard"
   return %14 : tensor<64x1x1024x512xf32>
 }
+
+func.func public @jit_data_parallel_n300(%arg0: tensor<64x1x1024x2048xf32>, %arg1: tensor<1x1x2048x512xf32>) -> (tensor<64x1x1024x512xf32> {jax.result_info = ""}) {
+  %0 = tensor.empty() : tensor<32x1x1024x2048xf32>
+  %1 = "ttir.mesh_shard"(%arg0, %0) <{shard_dims = array<i64: -1, 0>, shard_direction = #tt.shard_direction<full_to_shard>, shard_shape = array<i64: 2, 1, 1, 1>, shard_type = #tt.shard_type<devices>}> : (tensor<64x1x1024x2048xf32>, tensor<32x1x1024x2048xf32>) -> tensor<32x1x1024x2048xf32>
+  // CHECK: "ttnn.mesh_shard"
+  %2 = tensor.empty() : tensor<1x1x2048x512xf32>
+  %3 = "ttir.mesh_shard"(%arg1, %2) <{shard_dims = array<i64: -1>, shard_direction = #tt.shard_direction<full_to_shard>, shard_shape = array<i64: 1>, shard_type = #tt.shard_type<replicate>}> : (tensor<1x1x2048x512xf32>, tensor<1x1x2048x512xf32>) -> tensor<1x1x2048x512xf32>
+  // CHECK: "ttnn.mesh_shard"
+  %4 = tensor.empty() : tensor<32x1024x2048xf32>
+  %5 = "ttir.reshape"(%1, %4) <{shape = [32 : i32, 1024 : i32, 2048 : i32]}> : (tensor<32x1x1024x2048xf32>, tensor<32x1024x2048xf32>) -> tensor<32x1024x2048xf32>
+  // CHECK: = "ttnn.reshape"
+  %6 = tensor.empty() : tensor<1x2048x512xf32>
+  %7 = "ttir.reshape"(%3, %6) <{shape = [1 : i32, 2048 : i32, 512 : i32]}> : (tensor<1x1x2048x512xf32>, tensor<1x2048x512xf32>) -> tensor<1x2048x512xf32>
+  // CHECK: = "ttnn.reshape"
+  %8 = "ttir.dot_general"(%5, %7) <{batch_dims_lhs = array<i64>, batch_dims_rhs = array<i64>, contract_dims_lhs = array<i64: 2>, contract_dims_rhs = array<i64: 1>}> : (tensor<32x1024x2048xf32>, tensor<1x2048x512xf32>) -> tensor<32x1024x1x512xf32>
+  // CHECK: "ttnn.matmul"
+  %9 = tensor.empty() : tensor<32x1x1024x512xf32>
+  %10 = "ttir.permute"(%8, %9) <{permutation = array<i64: 0, 2, 1, 3>}> : (tensor<32x1024x1x512xf32>, tensor<32x1x1024x512xf32>) -> tensor<32x1x1024x512xf32>
+  // CHECK: "ttnn.permute"
+  %11 = tensor.empty() : tensor<64x1x1024x512xf32>
+  %12 = "ttir.mesh_shard"(%10, %11) <{shard_dims = array<i64: -1, 0>, shard_direction = #tt.shard_direction<shard_to_full>, shard_shape = array<i64: 2, 1, 1, 1>, shard_type = #tt.shard_type<devices>}> : (tensor<32x1x1024x512xf32>, tensor<64x1x1024x512xf32>) -> tensor<64x1x1024x512xf32>
+  // CHECK: "ttnn.mesh_shard"
+  return %12 : tensor<64x1x1024x512xf32>
+}
