Skip to content

Commit

Permalink
Switched to using logging in Explorer (#1920)
Browse files Browse the repository at this point in the history
- Changed ModelRunner.log to have a `severity` param that defaults to
`logging.info`.
- Changed print to `logging.info`.
- This closes #1500
  • Loading branch information
vprajapati-tt authored Feb 10, 2025
1 parent a4bdb21 commit 7eefddf
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 14 deletions.
4 changes: 2 additions & 2 deletions tools/explorer/tt_adapter/src/tt_adapter/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import model_explorer
from . import runner, utils, mlir
import dataclasses
import enum
import logging
from ttmlir import optimizer_overrides

OPTIMIZER_DISABLED_POLICY = "Optimizer Disabled"
Expand Down Expand Up @@ -92,7 +92,7 @@ def convert(
if optimized_model_path := self.model_runner.get_optimized_model_path(
model_path
):
print(f"Using optimized model: {optimized_model_path}")
logging.info(f"Using optimized model: {optimized_model_path}")
# Get performance results.
perf_trace = self.model_runner.get_perf_trace(model_path)

Expand Down
26 changes: 14 additions & 12 deletions tools/explorer/tt_adapter/src/tt_adapter/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# SPDX-License-Identifier: Apache-2.0
import subprocess
import os
import tempfile
import logging

# TODO(odjuricic) Cleaner to implement ttrt --quiet flag.
# os.environ["TTRT_LOGGER_LEVEL"] = "ERROR"
Expand Down Expand Up @@ -56,14 +56,14 @@ class ModelRunner:

def __new__(cls, *args, **kwargs):
if not cls._instance:
print("Creating a new ModelRunner instance.")
logging.info("Creating a new ModelRunner instance.")
cls._instance = super(ModelRunner, cls).__new__(cls, *args, **kwargs)
cls._instance.initialize()
return cls._instance

def initialize(self):
# Initialize machine to generate SystemDesc and load up functionality to begin
print("Running ttrt initialization.")
logging.info("Running ttrt initialization.")
ttrt.initialize_apis()

if "TT_MLIR_HOME" not in os.environ:
Expand All @@ -84,7 +84,7 @@ def initialize(self):
}
)()

print("ModelRunner initialized.")
logging.info("ModelRunner initialized.")

def get_optimized_model_path(self, model_path):
if model_path in self.model_state:
Expand Down Expand Up @@ -127,8 +127,8 @@ def reset_state(self, model_path):
if model_path in self.model_state:
del self.model_state[model_path]

def log(self, message):
print(message)
def log(self, message, severity=logging.info):
severity(message)
self.log_queue.put(message)

def get_perf_trace(self, model_path):
Expand Down Expand Up @@ -166,7 +166,7 @@ def compile_and_run_wrapper(self, model_path, overrides_string):
raise e
except Exception as e:
self.runner_error = "An unexpected error occurred: " + str(e)
self.log(self.runner_error)
self.log(self.runner_error, severity=logging.error)
raise e
finally:
self.progress = 100
Expand Down Expand Up @@ -224,7 +224,7 @@ def compile_and_run(self, model_path, overrides_string):
compile_process = self.run_in_subprocess(compile_command)
if compile_process.returncode != 0:
error = "Error running compile TTIR to TTNN Backend Pipeline"
self.log(error)
self.log(error, severity=logging.error)
raise ExplorerRunException(error)
self.progress = 20

Expand Down Expand Up @@ -288,7 +288,7 @@ def compile_and_run(self, model_path, overrides_string):
translate_process = self.run_in_subprocess(to_flatbuffer_command)
if translate_process.returncode != 0:
error = "Error while running TTNN to Flatbuffer File"
self.log(error)
self.log(error, severtity=logging.error)
raise ExplorerRunException(error)

self.progress = 30
Expand All @@ -306,7 +306,7 @@ def compile_and_run(self, model_path, overrides_string):

if ttrt_process.returncode != 0:
error = "Error while running TTRT perf"
self.log(error)
self.log(error, severity=logging.error)
raise ExplorerRunException(error)

perf = self.get_perf_trace(model_path)
Expand All @@ -319,9 +319,11 @@ def compile_and_run(self, model_path, overrides_string):
"LOC",
]
perf = perf[columns]
print(perf)
logging.info(perf)

print("Total device duration: ", perf["DEVICE FW DURATION [ns]"].sum(), "ns")
logging.info(
"Total device duration: ", perf["DEVICE FW DURATION [ns]"].sum(), "ns"
)

# TTNN_IR_FILE from flatbuffer is still relevant since model_path is the FB with golden data and it will rented optimized_model_path instead
state.optimized_model_path = ttnn_ir_file
Expand Down

0 comments on commit 7eefddf

Please sign in to comment.