Commit b6c9927

[tuner] add the output file (nod-ai#815)
This PR addresses the task described in nod-ai#806. It adds a separate output file to the tuner that summarizes the most important information, such as the top candidates (dispatch and model) and the paths to their specification files. It also fixes the incorrect reporting of the compilation success rate in the log.

Signed-off-by: Bangtian Liu <[email protected]>
1 parent 04d383b commit b6c9927
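
For orientation, the summary output introduced below works by attaching a second logging.FileHandler to the tuner's logger, so INFO-level results land in a compact summary file next to the full run log. A minimal sketch of that pattern, using made-up logger, file, and message names rather than the tuner's own objects:

import logging
from pathlib import Path

# Illustrative names only; the tuner wires this up inside main(), as shown in the diff below.
logger = logging.getLogger("tuning_demo")
logger.setLevel(logging.DEBUG)

summary_handler = logging.FileHandler(Path("summary.log"))
summary_handler.setLevel(logging.INFO)  # keep DEBUG-level progress out of the summary
summary_handler.setFormatter(
    logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
)
logger.addHandler(summary_handler)

logger.debug("per-candidate compile details")       # not written to summary.log
logger.info("Top dispatch candidates: [3, 7, 12]")  # written to summary.log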

File tree

6 files changed: +62 -21 lines changed

tuner/examples/simple/simple_tuner.py
tuner/tuner/candidate_gen.py
tuner/tuner/dispatch_constraints.py
tuner/tuner/dispatch_parser.py
tuner/tuner/libtuner.py
tuner/tuner/libtuner_test.py

tuner/examples/simple/simple_tuner.py (+23 -5)

@@ -80,7 +80,7 @@ def main():
     stop_after_phase: str = args.stop_after

     print("Setup logging")
-    libtuner.setup_logging(args, path_config)
+    root_logger = libtuner.setup_logging(args, path_config)
     print(path_config.run_log, end="\n\n")

     if not args.dry_run:
@@ -93,8 +93,15 @@ def main():
         args.simple_model_benchmark_flags_file
     )

+    summary_log_file = path_config.base_dir / "summary.log"
+    summary_handler = logging.FileHandler(summary_log_file)
+    summary_handler.setLevel(logging.INFO)
+    summary_handler.setFormatter(
+        logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
+    )
     print("Generating candidate tuning specs...")
-    with TunerContext() as tuner_context:
+    with TunerContext(logger=root_logger) as tuner_context:
+        tuner_context.logger.addHandler(summary_handler)
         simple_tuner = SimpleTuner(tuner_context)
         candidates = libtuner.generate_candidate_specs(
             args, path_config, candidate_trackers, simple_tuner
@@ -113,7 +120,9 @@ def main():
         if stop_after_phase == libtuner.ExecutionPhases.compile_dispatches:
             return

-        print("Benchmarking compiled dispatch candidates...")
+        message = "Benchmarking compiled dispatch candidates..."
+        print(message)
+        logging.info(message)
         simple_tuner.benchmark_flags = ["--input=1", "--benchmark_repetitions=3"]
         top_candidates = libtuner.benchmark(
             args,
@@ -123,6 +132,9 @@ def main():
             simple_tuner,
             args.simple_num_dispatch_candidates,
         )
+        logging.info(f"Top dispatch candidates: {top_candidates}")
+        for id in top_candidates:
+            logging.info(f"{candidate_trackers[id].spec_path.resolve()}")
         if stop_after_phase == libtuner.ExecutionPhases.benchmark_dispatches:
             return

@@ -140,7 +152,9 @@ def main():
         if stop_after_phase == libtuner.ExecutionPhases.compile_models:
             return

-        print("Benchmarking compiled model candidates...")
+        message = "Benchmarking compiled model candidates..."
+        print(message)
+        logging.info(message)
         simple_tuner.benchmark_flags = model_benchmark_flags
         simple_tuner.benchmark_timeout = 60
         top_model_candidates = libtuner.benchmark(
@@ -151,8 +165,12 @@ def main():
             simple_tuner,
             args.simple_num_model_candidates,
         )
-
+        logging.info(f"Top model candidates: {top_model_candidates}")
+        for id in top_model_candidates:
+            logging.info(f"{candidate_trackers[id].spec_path.resolve()}")
         print(f"Top model candidates: {top_model_candidates}")

         print("Check the detailed execution logs in:")
         print(path_config.run_log.resolve())
+        print("Check the summary in:")
+        print(summary_log_file.resolve())

tuner/tuner/candidate_gen.py (+2 -2)

@@ -195,12 +195,12 @@ def generate_configs_and_td_specs(
     ):
         if i >= limit:
             break
-        tune_logger.info(f"Solution #{i+1}: {config}")
+        tune_logger.debug(f"Solution #{i+1}: {config}")
         td_spec_module = dispatch_tuner.get_td_spec(input_module, config)
         assert td_spec_module, "Failed to generate transform dialect spec"
         config_specs.append(td_spec_module)

-    tune_logger.info(f"Generated {len(config_specs)} tuning specs")
+    tune_logger.debug(f"Generated {len(config_specs)} tuning specs")
     return config_specs

tuner/tuner/dispatch_constraints.py (+1 -1)

@@ -376,7 +376,7 @@ def generate_solutions(
         codegen_pipeline,
     )
     M, N, K = problem_size.MNK
-    tuner_ctx.logger.info(f"{M},{N},{K}")
+    tuner_ctx.logger.debug(f"{M},{N},{K}")
     m_vars = [z3.Int(f"m{i}") for i in range(len(M))]
     n_vars = [z3.Int(f"n{i}") for i in range(len(N))]
     k_vars = [z3.Int(f"k{i}") for i in range(len(K))]

tuner/tuner/dispatch_parser.py (+1 -1)

@@ -17,7 +17,7 @@ def parse_mlir(mlir_text: str, ctx: TunerContext) -> ir.Module:
     mlir_module = None
     try:
         mlir_module = ir.Module.parse(mlir_text, ctx.mlir_ctx)
-        ctx.logger.info("MLIR parsing successful!")
+        ctx.logger.debug("MLIR parsing successful!")
     except ir.MLIRError as e:
         ctx.logger.error(f"Error parsing MLIR: {e}")
         raise RuntimeError(f"Error parsing MLIR: {e}")

tuner/tuner/libtuner.py (+21 -10)

@@ -92,7 +92,7 @@ def _name_base_dir(self) -> Path:
         base_dir = Path(f"./tuning_{timestamp}")
         return base_dir

-    def _set_run_log(self, run_log: Path):
+    def set_run_log(self, run_log: Path):
         object.__setattr__(self, "run_log", run_log)

     def get_candidate_spec_filename(self, candidate_id: int) -> str:
@@ -334,10 +334,10 @@ def parse_arguments(
     return parser.parse_args()


-def setup_logging(args: argparse.Namespace, path_config: PathConfig):
+def setup_logging(args: argparse.Namespace, path_config: PathConfig) -> logging.Logger:
     log_file_name = f"autotune_{args.input_file.stem}.log"
     run_log_path = path_config.base_dir / log_file_name
-    path_config._set_run_log(run_log_path)
+    path_config.set_run_log(run_log_path)

     # Create file handler for logging to a file
     if path_config.run_log is None:
@@ -384,7 +384,9 @@ def format(self, record):
     # Log all arguments
     logging.debug(f"Input Arguments:")
     for arg, value in vars(args).items():
-        tune_logger.info(f"{arg}: {value}")
+        logging.debug(f"{arg}: {value}")
+
+    return logging.getLogger()


 def handle_error(
@@ -717,10 +719,18 @@ def generate_candidate_specs(
         tune_logger.exception("Error in candidate_gen.py:")
         raise

-    logging.info(f"Generated [{len(candidates) - 1}] candidates")
+    logging.debug(f"Generated [{len(candidates) - 1}] candidates")
     return candidates


+def get_compilation_success_rate(compiled_candiates: list[Optional[int]]) -> float:
+    if not compiled_candiates:
+        return 0.0
+    successful_candidates = [c for c in compiled_candiates if c is not None]
+    success_rate = float(len(successful_candidates)) / float(len(compiled_candiates))
+    return success_rate
+
+
 def collision_handler(index_hash_list: list[tuple[int, str]]) -> tuple[bool, list[int]]:
     """If a collision is found, generate a list of new indexes. If no collision, `unique_indexes = []`"""
     # Check if candidate produces tbe same .vmfb
@@ -800,11 +810,11 @@ def compile(
     compiled_candidates = multiprocess_progress_wrapper(
         num_worker=num_worker, task_list=task_list, function=run_iree_compile_command
     )
-    compiled_candidates = [c for c in compiled_candidates if c is not None]
-    success_rate = float(len(compiled_candidates)) / float(len(candidates))
-    logging.info(
+    success_rate = get_compilation_success_rate(compiled_candidates)
+    logging.debug(
         f"Successfully compiled [{len(compiled_candidates)}] candidates. Success rate: {success_rate:.2f}"
     )
+    compiled_candidates = [c for c in compiled_candidates if c is not None]

     # Remove duplicate vmfbs from the candidate list.
     compiled_candidate_hashes = []
@@ -818,7 +828,7 @@ def compile(
     if collision_detected:
         compiled_candidates = unique_compiled_candidates

-    logging.info(f"Produced [{len(compiled_candidates)}] unique vmfbs")
+    logging.debug(f"Produced [{len(compiled_candidates)}] unique vmfbs")
     return compiled_candidates


@@ -875,7 +885,8 @@ def get_speedup(result: BenchmarkResult) -> float:
             speedup = f"{round(get_speedup(r) * 100, 2)}% of baseline"
         else:
             speedup = "baseline unavailable"
-        logging.info(f"Candidate {r.candidate_id} time: {r.time:.2f} ({speedup})")
+        result = f"Candidate {r.candidate_id} time: {r.time:.2f} ms ({speedup})"
+        logging.info(result)
     return best_results
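
A note on the success-rate fix above: get_compilation_success_rate counts the None entries for failed compilations before they are filtered out, instead of computing the rate from the already-filtered list as the old code did. A quick usage sketch with made-up results (the import path is assumed from the repo layout; the tests below exercise the same cases):

from tuner import libtuner  # import path assumed, not taken from this commit

compiled = [0, None, 2, None, 4]  # two of the five candidates failed to compile
rate = libtuner.get_compilation_success_rate(compiled)
print(f"Success rate: {rate:.2f}")  # prints 0.60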

tuner/tuner/libtuner_test.py (+14 -2)

@@ -6,8 +6,6 @@

 import argparse
 import math
-import pytest
-import json
 from subprocess import CompletedProcess
 from unittest.mock import call, patch, MagicMock
 from . import libtuner
@@ -176,6 +174,20 @@ def test_validate_devices_with_invalid_device() -> None:
     assert expected_call in mock_handle_error.call_args_list


+def test_get_compilation_success_rate():
+    compiled_candidates = [0, None, 2, None, 4]
+    assert libtuner.get_compilation_success_rate(compiled_candidates) == 3.0 / 5.0
+
+    compiled_candidates = [0, 1, 2, 3, 4]
+    assert libtuner.get_compilation_success_rate(compiled_candidates) == 1.0
+
+    compiled_candidates = [None, None, None]
+    assert libtuner.get_compilation_success_rate(compiled_candidates) == 0.0
+
+    compiled_candidates = []
+    assert libtuner.get_compilation_success_rate(compiled_candidates) == 0.0
+
+
 def test_select_best_benchmark_results() -> None:
     candidate_results = [
         libtuner.BenchmarkResult(1, 0.5, "hip://0"),
