Skip to content

Commit

Permalink
#18724: temp workaround on TG resnet trace+2cq hang (#18750)
Browse files Browse the repository at this point in the history
P0 temp fix for resnet50, which is hanging ND on TG for trace+2cq, due
to the extra rt args being sent.
It removes the extra args for the sharded case.

#18724 (comment)
### Checklist
- [x] [All post commit]
https://github.com/tenstorrent/tt-metal/actions/runs/13722113079
- [x] TGG
https://github.com/tenstorrent/tt-metal/actions/runs/13722119282
- [x] TG
https://github.com/tenstorrent/tt-metal/actions/runs/13722130588
  • Loading branch information
yugaoTT authored Mar 8, 2025
1 parent a00e06e commit 15db9cc
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def test_unet_trace_perf(
"batch, groups, iterations, expected_compile_time, expected_throughput, use_async_mode",
(
(1, 2, 128, 25.0, 1450.0, True),
(1, 2, 128, 25.0, 1660.0, False),
(1, 2, 128, 25.0, 1650.0, False),
),
)
def test_unet_trace_perf_multi_device(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@ void kernel_main() {
uint32_t out_tensor_start_tile_id = get_arg_val<uint32_t>(rt_args_idx++);

// padding args (WRITER)
const uint32_t last_num_blocks_h_dim = get_arg_val<uint32_t>(rt_args_idx++);
const uint32_t last_num_blocks_w_dim = get_arg_val<uint32_t>(rt_args_idx++);
const uint32_t out_num_nonzero_subblocks_h = get_arg_val<uint32_t>(rt_args_idx++);
const uint32_t out_last_num_nonzero_subblocks_h = get_arg_val<uint32_t>(rt_args_idx++);
const uint32_t out_last_subblock_h = get_arg_val<uint32_t>(rt_args_idx++);
Expand All @@ -33,6 +31,11 @@ void kernel_main() {
const uint32_t padded_subblock_tiles_addr_skip = get_arg_val<uint32_t>(rt_args_idx++);
const uint32_t padded_block_tiles_w_skip = get_arg_val<uint32_t>(rt_args_idx++);

#ifndef OUT_SHARDED
const uint32_t last_num_blocks_h_dim = get_arg_val<uint32_t>(rt_args_idx++);
const uint32_t last_num_blocks_w_dim = get_arg_val<uint32_t>(rt_args_idx++);
#endif

// COMPILE TIME ARGS
// interleaved accessor args
constexpr bool out_is_dram = get_compile_time_arg_val(0) == 1;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ void kernel_main() {
// padding args (READER)
const uint32_t last_block_w = get_arg_val<uint32_t>(rt_args_idx++);
// padding args (WRITER)
const uint32_t last_num_blocks_w_dim = get_arg_val<uint32_t>(rt_args_idx++);
const uint32_t out_num_nonzero_subblocks_h = get_arg_val<uint32_t>(rt_args_idx++);
const uint32_t out_last_subblock_h = get_arg_val<uint32_t>(rt_args_idx++);
const uint32_t padded_block_tiles_h_skip = get_arg_val<uint32_t>(rt_args_idx++);
Expand Down Expand Up @@ -108,6 +107,9 @@ void kernel_main() {
#else
rt_args_idx += 2; // Skip over placeholders
#endif
#ifndef OUT_SHARDED
const uint32_t last_num_blocks_w_dim = get_arg_val<uint32_t>(rt_args_idx++);
#endif

constexpr bool fuse_op = (bool)get_compile_time_arg_val(31);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -838,7 +838,6 @@ tt::tt_metal::operation::ProgramWithCallbacks create_program_mcast_in0(
mm_in1_sender_writer_args.push_back(last_out_block_w);

// padding args (WRITER)
mm_in1_sender_writer_args.push_back(last_out_num_blocks_w);
mm_in1_sender_writer_args.push_back(out_block_h / out_subblock_h);
mm_in1_sender_writer_args.push_back(out_subblock_h);
mm_in1_sender_writer_args.push_back(0);
Expand All @@ -852,7 +851,6 @@ tt::tt_metal::operation::ProgramWithCallbacks create_program_mcast_in0(
mm_in1_sender_writer_args.push_back(out_block_w);

// padding args (WRITER)
mm_in1_sender_writer_args.push_back(out_num_blocks_x);
mm_in1_sender_writer_args.push_back(out_block_h / out_subblock_h);
mm_in1_sender_writer_args.push_back(out_subblock_h);
mm_in1_sender_writer_args.push_back(0);
Expand All @@ -871,6 +869,13 @@ tt::tt_metal::operation::ProgramWithCallbacks create_program_mcast_in0(
mm_in1_sender_writer_args.push_back(0);
mm_in1_sender_writer_args.push_back(0);
}
if (!output_is_sharded) {
if (output_idx_x == num_blocks_x - 1) {
mm_in1_sender_writer_args.push_back(last_out_num_blocks_w);
} else {
mm_in1_sender_writer_args.push_back(out_num_blocks_x);
}
}

if (fuse_op) {
fused_op_signaler->push_matmul_fused_op_rt_args(mm_in1_sender_writer_args, true);
Expand Down Expand Up @@ -943,7 +948,7 @@ tt::tt_metal::operation::ProgramWithCallbacks create_program_mcast_in0(
writer_runtime_args[0] = src_buffer_b->address();
writer_runtime_args[6] = dst_buffer->address();
if (bias_tensor.has_value()) {
writer_runtime_args[18] = (*bias_buffer)->address();
writer_runtime_args[17] = (*bias_buffer)->address();
}
}

Expand Down Expand Up @@ -1531,7 +1536,6 @@ tt::tt_metal::operation::ProgramWithCallbacks create_program_mcast_in1(
// padding args (READER)
(std::uint32_t)out_block_w, // last_block_w
// padding args (WRITER)
(std::uint32_t)out_num_blocks_x,
(std::uint32_t)out_block_h / out_subblock_h,
(std::uint32_t)out_subblock_h,
(std::uint32_t)0,
Expand All @@ -1545,6 +1549,12 @@ tt::tt_metal::operation::ProgramWithCallbacks create_program_mcast_in1(
mm_in1_sender_writer_args.push_back((std::uint32_t)bias_buffer->address());
mm_in1_sender_writer_args.push_back(
(std::uint32_t)per_core_N * output_idx_x); // in3_tensor_start_tile_id
} else {
mm_in1_sender_writer_args.push_back(0);
mm_in1_sender_writer_args.push_back(0);
}
if (!output_is_sharded) {
mm_in1_sender_writer_args.push_back(out_num_blocks_x);
}

tt_metal::SetRuntimeArgs(
Expand All @@ -1566,8 +1576,6 @@ tt::tt_metal::operation::ProgramWithCallbacks create_program_mcast_in1(

if (output_idx_y == num_blocks_y - 1) {
// padding args (WRITER)
mm_in1_receiver_writer_args.push_back(last_out_num_blocks_h);
mm_in1_receiver_writer_args.push_back(out_num_blocks_x);
mm_in1_receiver_writer_args.push_back(out_block_h / out_subblock_h);
mm_in1_receiver_writer_args.push_back(last_block_num_nonzero_subblocks_h);
mm_in1_receiver_writer_args.push_back(last_subblock_of_last_block_h);
Expand All @@ -1579,8 +1587,6 @@ tt::tt_metal::operation::ProgramWithCallbacks create_program_mcast_in1(
mm_in1_receiver_writer_args.push_back(0);
} else {
// padding args (WRITER)
mm_in1_receiver_writer_args.push_back(out_num_blocks_y);
mm_in1_receiver_writer_args.push_back(out_num_blocks_x);
mm_in1_receiver_writer_args.push_back(out_block_h / out_subblock_h);
mm_in1_receiver_writer_args.push_back(out_block_h / out_subblock_h);
mm_in1_receiver_writer_args.push_back(out_subblock_h);
Expand All @@ -1591,6 +1597,15 @@ tt::tt_metal::operation::ProgramWithCallbacks create_program_mcast_in1(
mm_in1_receiver_writer_args.push_back(0);
mm_in1_receiver_writer_args.push_back(0);
}
if (!output_is_sharded) {
if (output_idx_y == num_blocks_y - 1) {
mm_in1_receiver_writer_args.push_back(last_out_num_blocks_h);
mm_in1_receiver_writer_args.push_back(out_num_blocks_x);
} else {
mm_in1_receiver_writer_args.push_back(out_num_blocks_y);
mm_in1_receiver_writer_args.push_back(out_num_blocks_x);
}
}

tt_metal::SetRuntimeArgs(
program,
Expand Down Expand Up @@ -1659,7 +1674,7 @@ tt::tt_metal::operation::ProgramWithCallbacks create_program_mcast_in1(
sender_writer_runtime_args[0] = src_buffer_b->address();
sender_writer_runtime_args[6] = dst_buffer->address();
if (bias_tensor.has_value()) {
sender_writer_runtime_args[18] = (*bias_buffer)->address();
sender_writer_runtime_args[17] = (*bias_buffer)->address();
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1029,7 +1029,6 @@ tt::tt_metal::operation::ProgramWithCallbacks create_program_mcast_in0_in1(
mm_in1_sender_writer_args.push_back(last_out_block_w);

// padding args (WRITER)
mm_in1_sender_writer_args.push_back(last_out_num_blocks_w);
mm_in1_sender_writer_args.push_back(out_block_h / out_subblock_h);
mm_in1_sender_writer_args.push_back(out_subblock_h);
mm_in1_sender_writer_args.push_back(0);
Expand All @@ -1043,7 +1042,6 @@ tt::tt_metal::operation::ProgramWithCallbacks create_program_mcast_in0_in1(
mm_in1_sender_writer_args.push_back(out_block_w);

// padding args (WRITER)
mm_in1_sender_writer_args.push_back(out_num_blocks_x);
mm_in1_sender_writer_args.push_back(out_block_h / out_subblock_h);
mm_in1_sender_writer_args.push_back(out_subblock_h);
mm_in1_sender_writer_args.push_back(0);
Expand All @@ -1062,6 +1060,13 @@ tt::tt_metal::operation::ProgramWithCallbacks create_program_mcast_in0_in1(
mm_in1_sender_writer_args.push_back(0); // Placeholder; not used
mm_in1_sender_writer_args.push_back(0); // Placeholder; not used
}
if (!output_is_sharded) {
if (in1_idx == in1_end_idx) { // right cores when no transpose_mcast
mm_in1_sender_writer_args.push_back(last_out_num_blocks_w);
} else {
mm_in1_sender_writer_args.push_back(out_num_blocks_x);
}
}

if (in1_is_sharded and in1_is_dram) { // in1 is dram sharded
uint32_t num_iter_index = mm_in1_sender_writer_args.size() + 1;
Expand Down Expand Up @@ -1147,8 +1152,6 @@ tt::tt_metal::operation::ProgramWithCallbacks create_program_mcast_in0_in1(

if (in1_idx == in1_end_idx and in0_idx == in0_end_idx) { // bottom-right core when no transpose_mcast
// padding args (WRITER)
mm_in1_receiver_writer_args.push_back(last_out_num_blocks_h);
mm_in1_receiver_writer_args.push_back(last_out_num_blocks_w);
mm_in1_receiver_writer_args.push_back(out_block_h / out_subblock_h);
mm_in1_receiver_writer_args.push_back(last_block_num_nonzero_subblocks_h);
mm_in1_receiver_writer_args.push_back(last_subblock_of_last_block_h);
Expand All @@ -1160,8 +1163,6 @@ tt::tt_metal::operation::ProgramWithCallbacks create_program_mcast_in0_in1(
mm_in1_receiver_writer_args.push_back(last_block_padded_block_tiles_w_skip);
} else if (in0_idx == in0_end_idx) { // bottom cores except bottom-right when no transpose_mcast
// padding args (WRITER)
mm_in1_receiver_writer_args.push_back(last_out_num_blocks_h);
mm_in1_receiver_writer_args.push_back(out_num_blocks_x);
mm_in1_receiver_writer_args.push_back(out_block_h / out_subblock_h);
mm_in1_receiver_writer_args.push_back(last_block_num_nonzero_subblocks_h);
mm_in1_receiver_writer_args.push_back(last_subblock_of_last_block_h);
Expand All @@ -1173,8 +1174,6 @@ tt::tt_metal::operation::ProgramWithCallbacks create_program_mcast_in0_in1(
mm_in1_receiver_writer_args.push_back(0);
} else if (in1_idx == in1_end_idx) { // right cores except bottom when no transpose_mcast
// padding args (WRITER)
mm_in1_receiver_writer_args.push_back(out_num_blocks_y);
mm_in1_receiver_writer_args.push_back(last_out_num_blocks_w);
mm_in1_receiver_writer_args.push_back(out_block_h / out_subblock_h);
mm_in1_receiver_writer_args.push_back(out_block_h / out_subblock_h);
mm_in1_receiver_writer_args.push_back(out_subblock_h);
Expand All @@ -1186,8 +1185,6 @@ tt::tt_metal::operation::ProgramWithCallbacks create_program_mcast_in0_in1(
mm_in1_receiver_writer_args.push_back(last_block_padded_block_tiles_w_skip);
} else {
// padding args (WRITER)
mm_in1_receiver_writer_args.push_back(out_num_blocks_y);
mm_in1_receiver_writer_args.push_back(out_num_blocks_x);
mm_in1_receiver_writer_args.push_back(out_block_h / out_subblock_h);
mm_in1_receiver_writer_args.push_back(out_block_h / out_subblock_h);
mm_in1_receiver_writer_args.push_back(out_subblock_h);
Expand All @@ -1198,6 +1195,22 @@ tt::tt_metal::operation::ProgramWithCallbacks create_program_mcast_in0_in1(
mm_in1_receiver_writer_args.push_back(0);
mm_in1_receiver_writer_args.push_back(0);
}
if (!output_is_sharded) {
if (in1_idx == in1_end_idx and
in0_idx == in0_end_idx) { // bottom-right core when no transpose_mcast
mm_in1_receiver_writer_args.push_back(last_out_num_blocks_h);
mm_in1_receiver_writer_args.push_back(last_out_num_blocks_w);
} else if (in0_idx == in0_end_idx) { // bottom cores except bottom-right when no transpose_mcast
mm_in1_receiver_writer_args.push_back(last_out_num_blocks_h);
mm_in1_receiver_writer_args.push_back(out_num_blocks_x);
} else if (in1_idx == in1_end_idx) { // right cores except bottom when no transpose_mcast
mm_in1_receiver_writer_args.push_back(out_num_blocks_y);
mm_in1_receiver_writer_args.push_back(last_out_num_blocks_w);
} else {
mm_in1_receiver_writer_args.push_back(out_num_blocks_y);
mm_in1_receiver_writer_args.push_back(out_num_blocks_x);
}
}

// left half
if (core.x <= half_core || (transpose_mcast and core.y == start_core_y)) {
Expand Down Expand Up @@ -1270,7 +1283,7 @@ tt::tt_metal::operation::ProgramWithCallbacks create_program_mcast_in0_in1(
writer_runtime_args[0] = src_buffer_b->address();
writer_runtime_args[6] = dst_buffer->address();
if (bias_tensor.has_value()) {
writer_runtime_args[18] = (*bias_buffer)->address();
writer_runtime_args[17] = (*bias_buffer)->address();
}
}

Expand Down

0 comments on commit 15db9cc

Please sign in to comment.