From 53a2df1bdd0f989244015d3f5bcbecb549666362 Mon Sep 17 00:00:00 2001
From: asaigal
Date: Fri, 28 Feb 2025 18:19:04 +0000
Subject: [PATCH] Add asserts for APIs not supported during MeshTrace capture

Reject MeshCommandQueue operations that are not supported during trace
capture. While trace_id_ holds a value, the following now fail with
TT_FATAL instead of proceeding:

  - write_shard_to_device        (writes)
  - read_shard_from_device       (reads)
  - enqueue_record_event_helper  (event synchronization)
  - enqueue_wait_for_event       (event synchronization)

Add the MeshTraceAsserts test, which begins a trace capture and expects
EnqueueMeshWorkload (last argument true) and Finish on the mesh command
queue to throw std::runtime_error before the capture is ended.
---
 tests/tt_metal/distributed/test_mesh_trace.cpp | 16 ++++++++++++++++
 tt_metal/distributed/mesh_command_queue.cpp    |  4 ++++
 2 files changed, 20 insertions(+)

diff --git a/tests/tt_metal/distributed/test_mesh_trace.cpp b/tests/tt_metal/distributed/test_mesh_trace.cpp
index 5142599b455..379f94e592c 100644
--- a/tests/tt_metal/distributed/test_mesh_trace.cpp
+++ b/tests/tt_metal/distributed/test_mesh_trace.cpp
@@ -580,5 +580,21 @@ TEST_F(MeshTraceTestSuite, DataCopyOnSubDevicesTrace) {
     ReleaseTrace(mesh_device_.get(), trace_id);
 }
 
+TEST_F(MeshTraceTestSuite, MeshTraceAsserts) {
+    auto random_seed = 10;
+    uint32_t seed = tt::parse_env("TT_METAL_SEED", random_seed);
+    log_info(tt::LogTest, "Using Test Seed: {}", seed);
+    srand(seed);
+    MeshCoordinateRange all_devices(mesh_device_->shape());
+    auto workload = std::make_shared<MeshWorkload>();
+    auto programs = tt::tt_metal::distributed::test::utils::create_random_programs(
+        1, mesh_device_->compute_with_storage_grid_size(), seed);
+    AddProgramToMeshWorkload(*workload, std::move(*programs[0]), all_devices);
+    auto trace_id = BeginTraceCapture(mesh_device_.get(), 0);
+    EXPECT_THROW(EnqueueMeshWorkload(mesh_device_->mesh_command_queue(), *workload, true), std::runtime_error);
+    EXPECT_THROW(Finish(mesh_device_->mesh_command_queue()), std::runtime_error);
+    EndTraceCapture(mesh_device_.get(), 0, trace_id);
+}
+
 }  // namespace
 }  // namespace tt::tt_metal::distributed::test
diff --git a/tt_metal/distributed/mesh_command_queue.cpp b/tt_metal/distributed/mesh_command_queue.cpp
index 0902cd8f69a..c431a0a3b30 100644
--- a/tt_metal/distributed/mesh_command_queue.cpp
+++ b/tt_metal/distributed/mesh_command_queue.cpp
@@ -235,6 +235,7 @@ void MeshCommandQueue::write_shard_to_device(
     const void* src,
     const BufferRegion& region,
     tt::stl::Span<const SubDeviceId> sub_device_ids) {
+    TT_FATAL(!trace_id_.has_value(), "Writes are not supported during trace capture.");
     auto device = shard_view->device();
     sub_device_ids = buffer_dispatch::select_sub_device_ids(mesh_device_, sub_device_ids);
     buffer_dispatch::write_to_device_buffer(
@@ -247,6 +248,7 @@ void MeshCommandQueue::read_shard_from_device(
     const BufferRegion& region,
     std::unordered_map<IDevice*, uint32_t>& num_txns_per_device,
     tt::stl::Span<const SubDeviceId> sub_device_ids) {
+    TT_FATAL(!trace_id_.has_value(), "Reads are not supported during trace capture.");
     auto device = shard_view->device();
     sub_device_ids = buffer_dispatch::select_sub_device_ids(mesh_device_, sub_device_ids);
 
@@ -529,6 +531,7 @@ MeshEvent MeshCommandQueue::enqueue_record_event_helper(
     tt::stl::Span<const SubDeviceId> sub_device_ids,
     bool notify_host,
     const std::optional<MeshCoordinateRange>& device_range) {
+    TT_FATAL(!trace_id_.has_value(), "Event Synchronization is not supported during trace capture.");
     auto& sysmem_manager = this->reference_sysmem_manager();
     auto event = MeshEvent(
         sysmem_manager.get_next_event(id_),
@@ -567,6 +570,7 @@ MeshEvent MeshCommandQueue::enqueue_record_event_to_host(
 }
 
 void MeshCommandQueue::enqueue_wait_for_event(const MeshEvent& sync_event) {
+    TT_FATAL(!trace_id_.has_value(), "Event Synchronization is not supported during trace capture.");
     for (const auto& coord : sync_event.device_range()) {
         event_dispatch::issue_wait_for_event_commands(
             id_, sync_event.mesh_cq_id(), mesh_device_->get_device(coord)->sysmem_manager(), sync_event.id());