13
13
class SimpleTuner (libtuner .TuningClient ):
14
14
def __init__ (self , tuner_context : libtuner .TunerContext ):
15
15
super ().__init__ (tuner_context )
16
- self .compile_flags = ["--compile-from=executable-sources" ]
17
- self .benchmark_flags = [ "--benchmark_repetitions=3" , "--input=1" ]
16
+ self .compile_flags : list [ str ] = []
17
+ self .benchmark_flags : list [ str ] = [ ]
18
18
19
19
def get_iree_compile_flags (self ) -> list [str ]:
20
20
return self .compile_flags
@@ -26,6 +26,14 @@ def get_benchmark_timeout_s(self) -> int:
26
26
return 10
27
27
28
28
29
+ def read_flags_file (flags_file : str ) -> list [str ]:
30
+ if not flags_file :
31
+ return []
32
+
33
+ with open (flags_file ) as file :
34
+ return file .read ().splitlines ()
35
+
36
+
29
37
def main ():
30
38
# Custom arguments for the example tuner file.
31
39
parser = argparse .ArgumentParser (description = "Autotune sample script" )
@@ -46,10 +54,16 @@ def main():
46
54
help = "Number of model candidates to produce after tuning." ,
47
55
)
48
56
client_args .add_argument (
49
- "--simple-hip-target" ,
57
+ "--simple-compile-flags-file" ,
58
+ type = str ,
59
+ default = "" ,
60
+ help = "Path to the flags file for iree-compile." ,
61
+ )
62
+ client_args .add_argument (
63
+ "--simple-model-benchmark-flags-file" ,
50
64
type = str ,
51
- default = "gfx942 " ,
52
- help = "Hip target for tuning ." ,
65
+ default = "" ,
66
+ help = "Path to the flags file for iree-benchmark-module for model benchmarking ." ,
53
67
)
54
68
# Remaining arguments come from libtuner
55
69
args = libtuner .parse_arguments (parser )
@@ -69,6 +83,11 @@ def main():
69
83
libtuner .validate_devices (args .devices )
70
84
print ("Validation successful!\n " )
71
85
86
+ compile_flags : list [str ] = read_flags_file (args .simple_compile_flags_file )
87
+ model_benchmark_flags : list [str ] = read_flags_file (
88
+ args .simple_model_benchmark_flags_file
89
+ )
90
+
72
91
print ("Generating candidate tuning specs..." )
73
92
with TunerContext () as tuner_context :
74
93
simple_tuner = SimpleTuner (tuner_context )
@@ -80,13 +99,17 @@ def main():
80
99
return
81
100
82
101
print ("Compiling dispatch candidates..." )
102
+ simple_tuner .compile_flags = compile_flags + [
103
+ "--compile-from=executable-sources"
104
+ ]
83
105
compiled_candidates = libtuner .compile (
84
106
args , path_config , candidates , candidate_trackers , simple_tuner
85
107
)
86
108
if stop_after_phase == libtuner .ExecutionPhases .compile_dispatches :
87
109
return
88
110
89
111
print ("Benchmarking compiled dispatch candidates..." )
112
+ simple_tuner .benchmark_flags = ["--input=1" , "--benchmark_repetitions=3" ]
90
113
top_candidates = libtuner .benchmark (
91
114
args ,
92
115
path_config ,
@@ -99,10 +122,7 @@ def main():
99
122
return
100
123
101
124
print ("Compiling models with top candidates..." )
102
- simple_tuner .compile_flags = [
103
- "--iree-hal-target-backends=rocm" ,
104
- f"--iree-hip-target={ args .simple_hip_target } " ,
105
- ]
125
+ simple_tuner .compile_flags = compile_flags
106
126
compiled_model_candidates = libtuner .compile (
107
127
args ,
108
128
path_config ,
@@ -115,11 +135,7 @@ def main():
115
135
return
116
136
117
137
print ("Benchmarking compiled model candidates..." )
118
- simple_tuner .benchmark_flags = [
119
- "--benchmark_repetitions=3" ,
120
- "--input=2048x2048xf16" ,
121
- "--input=2048x2048xf16" ,
122
- ]
138
+ simple_tuner .benchmark_flags = model_benchmark_flags
123
139
top_model_candidates = libtuner .benchmark (
124
140
args ,
125
141
path_config ,
0 commit comments