Skip to content

Commit 5888624

Browse files
authored
[Tuner] Add support for flag files in the simple tuner (#786)
This generalizes the simple tuner such that it can be used for real-world models without having to modify the python code.
1 parent 3bef80f commit 5888624

File tree

4 files changed

+42
-18
lines changed

4 files changed

+42
-18
lines changed

tuner/examples/simple/README.md

+7-4
Original file line numberDiff line numberDiff line change
@@ -29,17 +29,20 @@ For an initial trial to test the tuning loop, use:
2929
cd ../../
3030
python -m examples.simple examples/simple/double_mmt.mlir \
3131
examples/simple/tmp/mmt_benchmark.mlir \
32+
--simple-compile-flags-file=examples/simple/compile_flags.txt \
33+
--simple-model-benchmark-flags-file=examples/simple/model_benchmark_flags.txt \
3234
--devices=hip://0 --num-candidates=30 \
3335
--simple-num-dispatch-candidates=5 --simple-num-model-candidates=3 \
3436
```
3537

3638
### Basic Usage
3739
```shell
3840
python -m examples.simple <model_file_path> <benchmark_file_path> \
39-
--devices=hip://0 --num-candidates=1024 \
40-
--test-num-dispatch-candidates=<num_dispatch_candidates> \
41-
--test-num-model-candidates=<num_model_candidates> \
42-
--test-hip-target=<hip_target> \
41+
--devices=hip://0,hip://1 --num-candidates=1024 \
42+
--simple-compile-flags-file=<compile_flags_path> \
43+
--simple-model-benchmark-flags-file=<model_benchmark_flags_path> \
44+
--simple-num-dispatch-candidates=<num_dispatch_candidates> \
45+
--simple-num-model-candidates=<num_model_candidates> \
4346
--num-candidates=<num_generated_candidates> \
4447
--codegen-pipeline=<codegen_pipeline>
4548
```
+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
--iree-hal-target-backends=rocm
2+
--iree-hip-target=gfx942
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
--device_allocator=caching
2+
--input=2048x2048xf16
3+
--input=2048x2048xf16

tuner/examples/simple/simple_tuner.py

+30-14
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
class SimpleTuner(libtuner.TuningClient):
1414
def __init__(self, tuner_context: libtuner.TunerContext):
1515
super().__init__(tuner_context)
16-
self.compile_flags = ["--compile-from=executable-sources"]
17-
self.benchmark_flags = ["--benchmark_repetitions=3", "--input=1"]
16+
self.compile_flags: list[str] = []
17+
self.benchmark_flags: list[str] = []
1818

1919
def get_iree_compile_flags(self) -> list[str]:
2020
return self.compile_flags
@@ -26,6 +26,14 @@ def get_benchmark_timeout_s(self) -> int:
2626
return 10
2727

2828

29+
def read_flags_file(flags_file: str) -> list[str]:
30+
if not flags_file:
31+
return []
32+
33+
with open(flags_file) as file:
34+
return file.read().splitlines()
35+
36+
2937
def main():
3038
# Custom arguments for the example tuner file.
3139
parser = argparse.ArgumentParser(description="Autotune sample script")
@@ -46,10 +54,16 @@ def main():
4654
help="Number of model candidates to produce after tuning.",
4755
)
4856
client_args.add_argument(
49-
"--simple-hip-target",
57+
"--simple-compile-flags-file",
58+
type=str,
59+
default="",
60+
help="Path to the flags file for iree-compile.",
61+
)
62+
client_args.add_argument(
63+
"--simple-model-benchmark-flags-file",
5064
type=str,
51-
default="gfx942",
52-
help="Hip target for tuning.",
65+
default="",
66+
help="Path to the flags file for iree-benchmark-module for model benchmarking.",
5367
)
5468
# Remaining arguments come from libtuner
5569
args = libtuner.parse_arguments(parser)
@@ -69,6 +83,11 @@ def main():
6983
libtuner.validate_devices(args.devices)
7084
print("Validation successful!\n")
7185

86+
compile_flags: list[str] = read_flags_file(args.simple_compile_flags_file)
87+
model_benchmark_flags: list[str] = read_flags_file(
88+
args.simple_model_benchmark_flags_file
89+
)
90+
7291
print("Generating candidate tuning specs...")
7392
with TunerContext() as tuner_context:
7493
simple_tuner = SimpleTuner(tuner_context)
@@ -80,13 +99,17 @@ def main():
8099
return
81100

82101
print("Compiling dispatch candidates...")
102+
simple_tuner.compile_flags = compile_flags + [
103+
"--compile-from=executable-sources"
104+
]
83105
compiled_candidates = libtuner.compile(
84106
args, path_config, candidates, candidate_trackers, simple_tuner
85107
)
86108
if stop_after_phase == libtuner.ExecutionPhases.compile_dispatches:
87109
return
88110

89111
print("Benchmarking compiled dispatch candidates...")
112+
simple_tuner.benchmark_flags = ["--input=1", "--benchmark_repetitions=3"]
90113
top_candidates = libtuner.benchmark(
91114
args,
92115
path_config,
@@ -99,10 +122,7 @@ def main():
99122
return
100123

101124
print("Compiling models with top candidates...")
102-
simple_tuner.compile_flags = [
103-
"--iree-hal-target-backends=rocm",
104-
f"--iree-hip-target={args.simple_hip_target}",
105-
]
125+
simple_tuner.compile_flags = compile_flags
106126
compiled_model_candidates = libtuner.compile(
107127
args,
108128
path_config,
@@ -115,11 +135,7 @@ def main():
115135
return
116136

117137
print("Benchmarking compiled model candidates...")
118-
simple_tuner.benchmark_flags = [
119-
"--benchmark_repetitions=3",
120-
"--input=2048x2048xf16",
121-
"--input=2048x2048xf16",
122-
]
138+
simple_tuner.benchmark_flags = model_benchmark_flags
123139
top_model_candidates = libtuner.benchmark(
124140
args,
125141
path_config,

0 commit comments

Comments
 (0)