Skip to content

Commit 64dfcb2

Browse files
authored
[Tuner] Clean up sample tuner (#781)
Rename it from 'test' to 'simple' to avoid mistaking it for a test: https://github.com/nod-ai/shark-ai/actions/runs/12659028888/job/35277223146#step:6:26 . Also: * Update and improve the README (account for the directory structure) * Make flag naming consistent * Handle previously missing `--stop-after` phases * Add git ignore for temporary files
1 parent 04b1819 commit 64dfcb2

File tree

6 files changed

+36
-28
lines changed

6 files changed

+36
-28
lines changed

tuner/examples/simple/.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
tmp
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
# Example Tuner Test
1+
# Simple Example Tuner
22

3-
Example of tuning a dispatch and full model.
3+
Example of tuning a dispatch and a full model.
44

55
## Environments
66
Follow instructions in [`/tuner/README.md`](../README.md)
@@ -15,27 +15,31 @@ Use the usual `iree-compile` command for your model, add
1515
`--iree-hal-dump-executable-files-to=dump --iree-config-add-tuner-attributes`,
1616
and get the dispatch benchmark that you want to tune. For example:
1717
```shell
18+
mkdir tmp
1819
iree-compile double_mmt.mlir --iree-hal-target-backends=rocm \
19-
--iree-hip-target=gfx942 --iree-hal-dump-executable-files-to=dump \
20+
--iree-hip-target=gfx942 --iree-hal-dump-executable-files-to=tmp/dump \
2021
--iree-config-add-tuner-attributes -o /dev/null
2122

22-
cp dump/module_main_dispatch_0_rocm_hsaco_fb_benchmark.mlir mmt_benchmark.mlir
23+
cp tmp/dump/module_main_dispatch_0_rocm_hsaco_fb_benchmark.mlir tmp/mmt_benchmark.mlir
2324
```
2425

2526
### Recommended Trial Run
2627
For an initial trial to test the tuning loop, use:
2728
```shell
28-
python -m examples.test double_mmt.mlir mmt_benchmark.mlir \
29-
--test_num_dispatch_candidates=5 --test_num_model_candidates=3 \
30-
--num-candidates=30
29+
cd ../../
30+
python -m examples.simple examples/simple/double_mmt.mlir \
31+
examples/simple/tmp/mmt_benchmark.mlir \
32+
--devices=hip://0 --num-candidates=30 \
33+
--simple-num-dispatch-candidates=5 --simple-num-model-candidates=3
3134
```
3235

3336
### Basic Usage
3437
```shell
35-
python -m examples.test <model_file_path> <benchmark_file_path> \
36-
--test_num_dispatch_candidates=<num_dispatch_candidates> \
37-
--test_num_model_candidates=<num_model_candidates> \
38-
--test_hip_target=<hip_target> \
38+
python -m examples.simple <model_file_path> <benchmark_file_path> \
39+
--devices=hip://0 --num-candidates=1024 \
40+
--simple-num-dispatch-candidates=<num_dispatch_candidates> \
41+
--simple-num-model-candidates=<num_model_candidates> \
42+
--simple-hip-target=<hip_target> \
3943
--num-candidates=<num_generated_candidates> \
4044
--codegen-pipeline=<codegen_pipeline>
4145
```
File renamed without changes.

tuner/examples/test/__main__.py tuner/examples/simple/__main__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,6 @@
44
# See https://llvm.org/LICENSE.txt for license information.
55
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66

7-
from . import tuner_test
7+
from . import simple_tuner
88

9-
tuner_test.main()
9+
simple_tuner.main()
File renamed without changes.

tuner/examples/test/tuner_test.py tuner/examples/simple/simple_tuner.py

+18-15
Original file line numberDiff line numberDiff line change
@@ -65,22 +65,22 @@ def main():
6565
parser = argparse.ArgumentParser(description="Autotune test script")
6666
test_args = parser.add_argument_group("Example Test Options")
6767
test_args.add_argument(
68-
"test_model_file", type=Path, help="Path to the model file to tune (.mlir)"
68+
"simple_model_file", type=Path, help="Path to the model file to tune (.mlir)"
6969
)
7070
test_args.add_argument(
71-
"--test_num_dispatch_candidates",
71+
"--simple-num-dispatch-candidates",
7272
type=int,
7373
default=None,
7474
help="Number of dispatch candidates to keep for model benchmarks.",
7575
)
7676
test_args.add_argument(
77-
"--test_num_model_candidates",
77+
"--simple-num-model-candidates",
7878
type=int,
7979
default=None,
8080
help="Number of model candidates to produce after tuning.",
8181
)
8282
test_args.add_argument(
83-
"--test_hip_target",
83+
"--simple-hip-target",
8484
type=str,
8585
default="gfx942",
8686
help="Hip target for tuning.",
@@ -98,51 +98,54 @@ def main():
9898
libtuner.setup_logging(args, path_config)
9999
print(path_config.run_log, end="\n\n")
100100

101-
# TODO(Max191): Some bug seems to be causing OOM errors in benchmarking
102-
# when device validation happens, so this is commented for now. Uncomment
103-
# when the bug is fixed.
104101
if not args.dry_run:
105102
print("Validating devices")
106103
libtuner.validate_devices(args.devices)
107104
print("Validation successful!\n")
108105

109-
print("Generating candidates...")
106+
print("Generating candidate tuning specs...")
110107
test_tuner = TestTuner()
111108
candidates = libtuner.generate_candidate_specs(
112109
args, path_config, candidate_trackers, test_tuner
113110
)
114-
print(f"Stored candidate specs in {path_config.specs_dir}\n")
111+
print(f"Stored candidate tuning specs in {path_config.specs_dir}\n")
115112
if stop_after_phase == libtuner.ExecutionPhases.generate_candidates:
116113
return
117114

118-
print("Compiling candidates...")
115+
print("Compiling dispatch candidates...")
119116
compiled_candidates = libtuner.compile(
120117
args, path_config, candidates, candidate_trackers, test_tuner
121118
)
119+
if stop_after_phase == libtuner.ExecutionPhases.compile_dispatches:
120+
return
122121

123-
print("Benchmarking compiled candidates...")
122+
print("Benchmarking compiled dispatch candidates...")
124123
top_candidates = libtuner.benchmark(
125124
args,
126125
path_config,
127126
compiled_candidates,
128127
candidate_trackers,
129128
test_tuner,
130-
args.test_num_dispatch_candidates,
129+
args.simple_num_dispatch_candidates,
131130
)
131+
if stop_after_phase == libtuner.ExecutionPhases.benchmark_dispatches:
132+
return
132133

133134
print("Compiling models with top candidates...")
134135
test_tuner.compile_flags = [
135136
"--iree-hal-target-backends=rocm",
136-
f"--iree-hip-target={args.test_hip_target}",
137+
f"--iree-hip-target={args.simple_hip_target}",
137138
]
138139
compiled_model_candidates = libtuner.compile(
139140
args,
140141
path_config,
141142
top_candidates,
142143
candidate_trackers,
143144
test_tuner,
144-
args.test_model_file,
145+
args.simple_model_file,
145146
)
147+
if stop_after_phase == libtuner.ExecutionPhases.compile_models:
148+
return
146149

147150
print("Benchmarking compiled model candidates...")
148151
test_tuner.benchmark_flags = [
@@ -156,7 +159,7 @@ def main():
156159
compiled_model_candidates,
157160
candidate_trackers,
158161
test_tuner,
159-
args.test_num_model_candidates,
162+
args.simple_num_model_candidates,
160163
)
161164

162165
print(f"Top model candidates: {top_model_candidates}")

0 commit comments

Comments
 (0)