
The following data parallel test fails on n300 #2362

Open
tapspatel opened this issue Mar 4, 2025 · 0 comments · May be fixed by #2409
Labels
bug Something isn't working

@tapspatel (Collaborator)

The following test fails with this error signature:

2025-03-04 20:35:53,820 - ERROR - ERROR: test=/code/tt-mlir/build/test/ttmlir/Silicon/TTNN/n300/device_parallel/Output/device_parallel_1x2.mlir.tmp.ttnn experienced an error with exception=std::get: wrong index for variant
func.func public @jit_data_parallel_n300(%arg0: tensor<64x1x1024x2048xf32>, %arg1: tensor<1x1x2048x512xf32>) -> (tensor<64x1x1024x512xf32> {jax.result_info = ""}) {
  %0 = tensor.empty() : tensor<32x1x1024x2048xf32>
  %1 = "ttir.mesh_shard"(%arg0, %0) <{shard_dims = array<i64: -1, 0>, shard_direction = #tt.shard_direction<full_to_shard>, shard_shape = array<i64: 2, 1, 1, 1>, shard_type = #tt.shard_type<devices>}> : (tensor<64x1x1024x2048xf32>, tensor<32x1x1024x2048xf32>) -> tensor<32x1x1024x2048xf32>
  // CHECK: "ttnn.mesh_shard"
  %2 = tensor.empty() : tensor<1x1x2048x512xf32>
  %3 = "ttir.mesh_shard"(%arg1, %2) <{shard_dims = array<i64: -1>, shard_direction = #tt.shard_direction<full_to_shard>, shard_shape = array<i64: 1>, shard_type = #tt.shard_type<replicate>}> : (tensor<1x1x2048x512xf32>, tensor<1x1x2048x512xf32>) -> tensor<1x1x2048x512xf32>
  // CHECK: "ttnn.mesh_shard"
  %4 = tensor.empty() : tensor<32x1024x2048xf32>
  %5 = "ttir.reshape"(%1, %4) <{shape = [32 : i32, 1024 : i32, 2048 : i32]}> : (tensor<32x1x1024x2048xf32>, tensor<32x1024x2048xf32>) -> tensor<32x1024x2048xf32>
  // CHECK: = "ttnn.reshape"
  %6 = tensor.empty() : tensor<1x2048x512xf32>
  %7 = "ttir.reshape"(%3, %6) <{shape = [1 : i32, 2048 : i32, 512 : i32]}> : (tensor<1x1x2048x512xf32>, tensor<1x2048x512xf32>) -> tensor<1x2048x512xf32>
  // CHECK: = "ttnn.reshape"
  %8 = "ttir.dot_general"(%5, %7) <{batch_dims_lhs = array<i64>, batch_dims_rhs = array<i64>, contract_dims_lhs = array<i64: 2>, contract_dims_rhs = array<i64: 1>}> : (tensor<32x1024x2048xf32>, tensor<1x2048x512xf32>) -> tensor<32x1024x1x512xf32>
  // CHECK: "ttnn.matmul"
  %9 = tensor.empty() : tensor<32x1x1024x512xf32>
  %10 = "ttir.permute"(%8, %9) <{permutation = array<i64: 0, 2, 1, 3>}> : (tensor<32x1024x1x512xf32>, tensor<32x1x1024x512xf32>) -> tensor<32x1x1024x512xf32>
  // CHECK: "ttnn.permute"
  %11 = tensor.empty() : tensor<64x1x1024x512xf32>
  %12 = "ttir.mesh_shard"(%10, %11) <{shard_dims = array<i64: -1, 0>, shard_direction = #tt.shard_direction<shard_to_full>, shard_shape = array<i64: 2, 1, 1, 1>, shard_type = #tt.shard_type<devices>}> : (tensor<32x1x1024x512xf32>, tensor<64x1x1024x512xf32>) -> tensor<64x1x1024x512xf32>
  // CHECK: "ttnn.mesh_shard"
  return %12 : tensor<64x1x1024x512xf32>
}
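A host-side NumPy reference of this graph can help confirm what the n300 run should return once the crash is fixed. This is a minimal sketch, not part of tt-mlir: the helper names `device_forward` and `data_parallel_reference` are hypothetical. It assumes that `mesh_shard` with `shard_type = devices`, `shard_shape = [2, 1, 1, 1]` splits dim 0 evenly across the two devices of the 1x2 mesh, that `replicate` copies `%arg1` to both devices, and that `shard_to_full` concatenates the shards back along dim 0. Under those assumptions, and because `%arg1` has a leading dimension of 1 and `dot_general` has no batch dims, the whole graph reduces to a broadcast matmul of the unsharded inputs.

```python
import numpy as np


def device_forward(x_shard: np.ndarray, w: np.ndarray) -> np.ndarray:
    """Hypothetical per-device body: reshape -> dot_general -> permute."""
    b, _, m, k = x_shard.shape
    n = w.shape[-1]
    # ttir.reshape: (b, 1, m, k) -> (b, m, k) and (1, 1, k, n) -> (1, k, n)
    x3 = x_shard.reshape(b, m, k)
    w3 = w.reshape(1, k, n)
    # ttir.dot_general with no batch dims, contracting lhs dim 2 with rhs dim 1:
    # result dims are lhs free dims then rhs free dims -> (b, m, 1, n),
    # which matches np.tensordot's output ordering.
    y = np.tensordot(x3, w3, axes=([2], [1]))
    # ttir.permute [0, 2, 1, 3]: (b, m, 1, n) -> (b, 1, m, n)
    return y.transpose(0, 2, 1, 3)


def data_parallel_reference(x: np.ndarray, w: np.ndarray, num_devices: int = 2) -> np.ndarray:
    """Hypothetical host-side model of the 1x2 mesh run."""
    # Assumed full_to_shard / devices: split dim 0 evenly across devices.
    shards = np.split(x, num_devices, axis=0)
    # Assumed replicate: every device sees the same w.
    outs = [device_forward(s, w) for s in shards]
    # Assumed shard_to_full / devices: concatenate the shards back along dim 0.
    return np.concatenate(outs, axis=0)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # Shrunk shapes; the test itself uses (64, 1, 1024, 2048) and (1, 1, 2048, 512).
    x = rng.standard_normal((4, 1, 8, 16)).astype(np.float32)
    w = rng.standard_normal((1, 1, 16, 5)).astype(np.float32)
    out = data_parallel_reference(x, w)
    # The whole graph should equal a broadcast matmul of the unsharded inputs.
    print(out.shape, np.allclose(out, np.matmul(x, w), atol=1e-4))
```

The shrunk shapes keep the check fast. The real test shapes work the same way, but the f32 input `%arg0` alone is 64 x 1024 x 2048 x 4 bytes = 512 MiB on the host.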
@tapspatel tapspatel added the bug Something isn't working label Mar 4, 2025
@tapspatel tapspatel added this to the [Multi Device 1] milestone Mar 4, 2025
@tapspatel tapspatel self-assigned this Mar 8, 2025