Skip to content

Commit b920696

Browse files
authored
[Tuner] Fix large model benchmarking (#808)
* Add model-specific benchmark timeout * Fix benchmark argument parsing to allow for `=` in command line argument values * Don't print candidate trackers at the very end (too much noise) * Always promote operands
1 parent 45322d4 commit b920696

File tree

3 files changed

+9
-9
lines changed

3 files changed

+9
-9
lines changed

tuner/examples/simple/simple_tuner.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ def __init__(self, tuner_context: libtuner.TunerContext):
1515
super().__init__(tuner_context)
1616
self.compile_flags: list[str] = []
1717
self.benchmark_flags: list[str] = []
18-
self.compile_timeout: int = 10
18+
self.compile_timeout: int = 16
19+
self.benchmark_timeout: int = 16
1920

2021
def get_iree_compile_flags(self) -> list[str]:
2122
return self.compile_flags
@@ -27,7 +28,7 @@ def get_iree_benchmark_module_flags(self) -> list[str]:
2728
return self.benchmark_flags
2829

2930
def get_benchmark_timeout_s(self) -> int:
30-
return 10
31+
return self.benchmark_timeout
3132

3233

3334
def read_flags_file(flags_file: str) -> list[str]:
@@ -127,7 +128,7 @@ def main():
127128

128129
print("Compiling models with top candidates...")
129130
simple_tuner.compile_flags = compile_flags
130-
simple_tuner.compile_timeout = 60
131+
simple_tuner.compile_timeout = 120
131132
compiled_model_candidates = libtuner.compile(
132133
args,
133134
path_config,
@@ -141,6 +142,7 @@ def main():
141142

142143
print("Benchmarking compiled model candidates...")
143144
simple_tuner.benchmark_flags = model_benchmark_flags
145+
simple_tuner.benchmark_timeout = 60
144146
top_model_candidates = libtuner.benchmark(
145147
args,
146148
path_config,
@@ -154,6 +156,3 @@ def main():
154156

155157
print("Check the detailed execution logs in:")
156158
print(path_config.run_log.resolve())
157-
158-
for candidate in candidate_trackers:
159-
libtuner.logging.debug(candidate)

tuner/tuner/dispatch_constraints.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -309,10 +309,11 @@ def generate_compilation_infos(
309309
"reduction": reduction_tile_sizes,
310310
"subgroup_m_count": subgroup_m_count,
311311
"subgroup_n_count": subgroup_n_count,
312+
"promote_operands": [0, 1],
312313
}
313314
if codegen_pipeline == iree_codegen.DispatchLoweringPassPipeline.LLVMGPUTileAndFuse:
314315
lowering_config_args["subgroup"] = subgroup_tile_sizes
315-
lowering_config_args["promote_operands"] = [0, 1]
316+
316317
lowering_config = get_lowering_config(**lowering_config_args)
317318

318319
# Create the TranslationInfoAttr

tuner/tuner/libtuner.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -497,10 +497,10 @@ def run_iree_benchmark_module_command(benchmark_pack: BenchmarkPack):
497497
assert flag[:2] == "--", "iree_benchmark_module_flags should begin with '--'"
498498
split_key_value = flag[2:].split("=")
499499
assert (
500-
len(split_key_value) == 2
500+
len(split_key_value) >= 1
501501
), "iree_benchmark_module_flags should have the format --<key>=<value>"
502502
key = split_key_value[0]
503-
value = split_key_value[1]
503+
value = "=".join(split_key_value[1:])
504504
# Allow the tuning client to pass `--function=@func_name`.
505505
if key == "function":
506506
func_name = value

0 commit comments

Comments
 (0)