From e1cd7016425b26a7fb112fda3c021af52e5984b4 Mon Sep 17 00:00:00 2001 From: sdjukicTT Date: Fri, 21 Feb 2025 19:12:48 +0000 Subject: [PATCH] addressed comment --- src/common/module_builder.cc | 2 +- src/common/module_builder.h | 2 ++ src/common/pjrt_implementation/client_instance.cc | 3 ++- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/common/module_builder.cc b/src/common/module_builder.cc index a800d5e9..4490f74c 100644 --- a/src/common/module_builder.cc +++ b/src/common/module_builder.cc @@ -197,7 +197,7 @@ void ModuleBuilder::convertFromTTIRToTTNN( mlir::PassManager ttir_to_ttnn_pm(mlir_module.get()->getName()); mlir::tt::ttnn::TTIRToTTNNBackendPipelineOptions options; - options.systemDescPath = "system_desc.ttsys"; + options.systemDescPath = system_desc_path; mlir::tt::ttnn::createTTIRToTTNNBackendPipeline(ttir_to_ttnn_pm, options); // Run the pass manager. diff --git a/src/common/module_builder.h b/src/common/module_builder.h index 548cc321..27982750 100644 --- a/src/common/module_builder.h +++ b/src/common/module_builder.h @@ -40,6 +40,8 @@ class ModuleBuilder { // code. Currently hardcoded to one, as we only support one-chip execution. size_t getNumAddressableDevices() const { return 1; } + static constexpr char const *system_desc_path = "system_desc.ttsys"; + private: // Creates VHLO module from the input program code. mlir::OwningOpRef diff --git a/src/common/pjrt_implementation/client_instance.cc b/src/common/pjrt_implementation/client_instance.cc index 99624bb6..e5d79070 100644 --- a/src/common/pjrt_implementation/client_instance.cc +++ b/src/common/pjrt_implementation/client_instance.cc @@ -27,6 +27,7 @@ ClientInstance::ClientInstance(std::unique_ptr platform) } ClientInstance::~ClientInstance() { + std::remove(ModuleBuilder::system_desc_path); DLOG_F(LOG_DEBUG, "ClientInstance::~ClientInstance"); } @@ -164,7 +165,7 @@ void ClientInstance::BindApi(PJRT_Api *api) { tt_pjrt_status ClientInstance::PopulateDevices() { DLOG_F(LOG_DEBUG, "ClientInstance::PopulateDevices"); auto [system_desc, chip_ids] = tt::runtime::getCurrentSystemDesc(); - system_desc.store("system_desc.ttsys"); + system_desc.store(ModuleBuilder::system_desc_path); int devices_count = chip_ids.size(); devices_.resize(devices_count);