Skip to content

Commit

Permalink
Add asserts for APIs not supported during MeshTrace capture
Browse files Browse the repository at this point in the history
  • Loading branch information
tt-asaigal committed Mar 10, 2025
1 parent 7c1bd85 commit 53a2df1
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 0 deletions.
16 changes: 16 additions & 0 deletions tests/tt_metal/distributed/test_mesh_trace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -580,5 +580,21 @@ TEST_F(MeshTraceTestSuite, DataCopyOnSubDevicesTrace) {
ReleaseTrace(mesh_device_.get(), trace_id);
}

TEST_F(MeshTraceTestSuite, MeshTraceAsserts) {
auto random_seed = 10;
uint32_t seed = tt::parse_env("TT_METAL_SEED", random_seed);
log_info(tt::LogTest, "Using Test Seed: {}", seed);
srand(seed);
MeshCoordinateRange all_devices(mesh_device_->shape());
auto workload = std::make_shared<MeshWorkload>();
auto programs = tt::tt_metal::distributed::test::utils::create_random_programs(
1, mesh_device_->compute_with_storage_grid_size(), seed);
AddProgramToMeshWorkload(*workload, std::move(*programs[0]), all_devices);
auto trace_id = BeginTraceCapture(mesh_device_.get(), 0);
EXPECT_THROW(EnqueueMeshWorkload(mesh_device_->mesh_command_queue(), *workload, true), std::runtime_error);
EXPECT_THROW(Finish(mesh_device_->mesh_command_queue()), std::runtime_error);
EndTraceCapture(mesh_device_.get(), 0, trace_id);
}

} // namespace
} // namespace tt::tt_metal::distributed::test
4 changes: 4 additions & 0 deletions tt_metal/distributed/mesh_command_queue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ void MeshCommandQueue::write_shard_to_device(
const void* src,
const BufferRegion& region,
tt::stl::Span<const SubDeviceId> sub_device_ids) {
TT_FATAL(!trace_id_.has_value(), "Writes are not supported during trace capture.");
auto device = shard_view->device();
sub_device_ids = buffer_dispatch::select_sub_device_ids(mesh_device_, sub_device_ids);
buffer_dispatch::write_to_device_buffer(
Expand All @@ -247,6 +248,7 @@ void MeshCommandQueue::read_shard_from_device(
const BufferRegion& region,
std::unordered_map<IDevice*, uint32_t>& num_txns_per_device,
tt::stl::Span<const SubDeviceId> sub_device_ids) {
TT_FATAL(!trace_id_.has_value(), "Reads are not supported during trace capture.");
auto device = shard_view->device();
sub_device_ids = buffer_dispatch::select_sub_device_ids(mesh_device_, sub_device_ids);

Expand Down Expand Up @@ -529,6 +531,7 @@ MeshEvent MeshCommandQueue::enqueue_record_event_helper(
tt::stl::Span<const SubDeviceId> sub_device_ids,
bool notify_host,
const std::optional<MeshCoordinateRange>& device_range) {
TT_FATAL(!trace_id_.has_value(), "Event Synchronization is not supported during trace capture.");
auto& sysmem_manager = this->reference_sysmem_manager();
auto event = MeshEvent(
sysmem_manager.get_next_event(id_),
Expand Down Expand Up @@ -567,6 +570,7 @@ MeshEvent MeshCommandQueue::enqueue_record_event_to_host(
}

void MeshCommandQueue::enqueue_wait_for_event(const MeshEvent& sync_event) {
TT_FATAL(!trace_id_.has_value(), "Event Synchronization is not supported during trace capture.");
for (const auto& coord : sync_event.device_range()) {
event_dispatch::issue_wait_for_event_commands(
id_, sync_event.mesh_cq_id(), mesh_device_->get_device(coord)->sysmem_manager(), sync_event.id());
Expand Down

0 comments on commit 53a2df1

Please sign in to comment.