Skip to content

Commit

Permalink
[tuner] Add timeout for compilation in TuningClient (#798)
Browse files Browse the repository at this point in the history
This PR adds a timeout for compilation to prevent the tuner from hanging
when iree-compile hangs. The `iree.compiler.compile_file` binding is no
longer used, since it does not support a timeout. Instead, the command
is run as a regular subprocess.

---------

Signed-off-by: Max Dawkins <[email protected]>
  • Loading branch information
Max191 authored Jan 9, 2025
1 parent c2ad576 commit 1e23a74
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 17 deletions.
5 changes: 5 additions & 0 deletions tuner/examples/simple/simple_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,14 @@ def __init__(self, tuner_context: libtuner.TunerContext):
super().__init__(tuner_context)
self.compile_flags: list[str] = []
self.benchmark_flags: list[str] = []
self.compile_timeout: int = 10

def get_iree_compile_flags(self) -> list[str]:
return self.compile_flags

def get_iree_compile_timeout_s(self) -> int:
return self.compile_timeout

def get_iree_benchmark_module_flags(self) -> list[str]:
return self.benchmark_flags

Expand Down Expand Up @@ -123,6 +127,7 @@ def main():

print("Compiling models with top candidates...")
simple_tuner.compile_flags = compile_flags
simple_tuner.compile_timeout = 60
compiled_model_candidates = libtuner.compile(
args,
path_config,
Expand Down
40 changes: 23 additions & 17 deletions tuner/tuner/libtuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,10 @@ def __init__(self, tuner_context: TunerContext):
def get_iree_compile_flags(self) -> list[str]:
pass

@abstractmethod
def get_iree_compile_timeout_s(self) -> int:
pass

@abstractmethod
def get_iree_benchmark_module_flags(self) -> list[str]:
pass
Expand All @@ -122,6 +126,7 @@ def get_benchmark_timeout_s(self) -> int:
@dataclass
class CompilePack:
iree_compile_flags: list[str]
iree_compile_timeout: int
candidate_tracker: CandidateTracker


Expand Down Expand Up @@ -440,30 +445,29 @@ def run_iree_compile_command(compile_pack: CompilePack) -> Optional[int]:
logging.debug(
f"Compiling candidate {candidate_tracker.candidate_id} with spec: {td_spec_path}"
)
extra_flags = [
f"--iree-codegen-tuning-spec-path={td_spec_path}",
]
extra_flags += compile_pack.iree_compile_flags
assert candidate_tracker.compiled_vmfb_path, "expected output vmfb path"
output_path = candidate_tracker.compiled_vmfb_path.as_posix()
crash_dump_path = f"{output_path}.crash_report.mlir"
assert candidate_tracker.mlir_path, "expected input mlir file path"
input_file = candidate_tracker.mlir_path.as_posix()
# TODO(Max191): Make the device in `traget_backends` a command line option
# instead of hardcoding in ireec.compile_str.
try:
ireec.compile_file(
input_file=input_file,
target_backends=["rocm"],
output_file=output_path,
extra_args=extra_flags,
crash_reproducer_path=crash_dump_path,
iree_compile = ireec.binaries.find_tool("iree-compile")
compile_command = [
iree_compile,
input_file,
f"-o={output_path}",
f"--mlir-pass-pipeline-crash-reproducer={crash_dump_path}",
f"--iree-codegen-tuning-spec-path={td_spec_path}",
]
compile_command += compile_pack.iree_compile_flags
result = candidate_gen.run_command(
candidate_gen.RunPack(
command=compile_command,
check=False,
timeout_seconds=compile_pack.iree_compile_timeout,
)
except ireec.CompilerToolError as e:
logging.info(f"Compilation returned non-zero exit status.")
logging.debug(e)
)
if result.process_res is None or result.is_timeout:
return None

return candidate_tracker.candidate_id


Expand Down Expand Up @@ -775,6 +779,7 @@ def compile(
task_list = [
CompilePack(
iree_compile_flags=tuning_client.get_iree_compile_flags(),
iree_compile_timeout=tuning_client.get_iree_compile_timeout_s(),
candidate_tracker=candidate_trackers[i],
)
for i in candidates
Expand All @@ -783,6 +788,7 @@ def compile(
task_list.append(
CompilePack(
iree_compile_flags=tuning_client.get_iree_compile_flags(),
iree_compile_timeout=tuning_client.get_iree_compile_timeout_s(),
candidate_tracker=candidate_trackers[0],
)
)
Expand Down

0 comments on commit 1e23a74

Please sign in to comment.