@@ -65,22 +65,22 @@ def main():
65
65
parser = argparse .ArgumentParser (description = "Autotune test script" )
66
66
test_args = parser .add_argument_group ("Example Test Options" )
67
67
test_args .add_argument (
68
- "test_model_file " , type = Path , help = "Path to the model file to tune (.mlir)"
68
+ "simple_model_file " , type = Path , help = "Path to the model file to tune (.mlir)"
69
69
)
70
70
test_args .add_argument (
71
- "--test_num_dispatch_candidates " ,
71
+ "--simple-num-dispatch-candidates " ,
72
72
type = int ,
73
73
default = None ,
74
74
help = "Number of dispatch candidates to keep for model benchmarks." ,
75
75
)
76
76
test_args .add_argument (
77
- "--test_num_model_candidates " ,
77
+ "--simple-num-model-candidates " ,
78
78
type = int ,
79
79
default = None ,
80
80
help = "Number of model candidates to produce after tuning." ,
81
81
)
82
82
test_args .add_argument (
83
- "--test_hip_target " ,
83
+ "--simple-hip-target " ,
84
84
type = str ,
85
85
default = "gfx942" ,
86
86
help = "Hip target for tuning." ,
@@ -98,51 +98,54 @@ def main():
98
98
libtuner .setup_logging (args , path_config )
99
99
print (path_config .run_log , end = "\n \n " )
100
100
101
- # TODO(Max191): Some bug seems to be causing OOM errors in benchmarking
102
- # when device validation happens, so this is commented for now. Uncomment
103
- # when the bug is fixed.
104
101
if not args .dry_run :
105
102
print ("Validating devices" )
106
103
libtuner .validate_devices (args .devices )
107
104
print ("Validation successful!\n " )
108
105
109
- print ("Generating candidates ..." )
106
+ print ("Generating candidate tuning specs ..." )
110
107
test_tuner = TestTuner ()
111
108
candidates = libtuner .generate_candidate_specs (
112
109
args , path_config , candidate_trackers , test_tuner
113
110
)
114
- print (f"Stored candidate specs in { path_config .specs_dir } \n " )
111
+ print (f"Stored candidate tuning specs in { path_config .specs_dir } \n " )
115
112
if stop_after_phase == libtuner .ExecutionPhases .generate_candidates :
116
113
return
117
114
118
- print ("Compiling candidates..." )
115
+ print ("Compiling dispatch candidates..." )
119
116
compiled_candidates = libtuner .compile (
120
117
args , path_config , candidates , candidate_trackers , test_tuner
121
118
)
119
+ if stop_after_phase == libtuner .ExecutionPhases .compile_dispatches :
120
+ return
122
121
123
- print ("Benchmarking compiled candidates..." )
122
+ print ("Benchmarking compiled dispatch candidates..." )
124
123
top_candidates = libtuner .benchmark (
125
124
args ,
126
125
path_config ,
127
126
compiled_candidates ,
128
127
candidate_trackers ,
129
128
test_tuner ,
130
- args .test_num_dispatch_candidates ,
129
+ args .simple_num_dispatch_candidates ,
131
130
)
131
+ if stop_after_phase == libtuner .ExecutionPhases .benchmark_dispatches :
132
+ return
132
133
133
134
print ("Compiling models with top candidates..." )
134
135
test_tuner .compile_flags = [
135
136
"--iree-hal-target-backends=rocm" ,
136
- f"--iree-hip-target={ args .test_hip_target } " ,
137
+ f"--iree-hip-target={ args .simple_hip_target } " ,
137
138
]
138
139
compiled_model_candidates = libtuner .compile (
139
140
args ,
140
141
path_config ,
141
142
top_candidates ,
142
143
candidate_trackers ,
143
144
test_tuner ,
144
- args .test_model_file ,
145
+ args .simple_model_file ,
145
146
)
147
+ if stop_after_phase == libtuner .ExecutionPhases .compile_models :
148
+ return
146
149
147
150
print ("Benchmarking compiled model candidates..." )
148
151
test_tuner .benchmark_flags = [
@@ -156,7 +159,7 @@ def main():
156
159
compiled_model_candidates ,
157
160
candidate_trackers ,
158
161
test_tuner ,
159
- args .test_num_model_candidates ,
162
+ args .simple_num_model_candidates ,
160
163
)
161
164
162
165
print (f"Top model candidates: { top_model_candidates } " )
0 commit comments