Skip to content

Commit d279aff

Browse files
authored
[tuner]: use compilation_info binding (#678)
This PR is relevant to the task in #453: use IREE bindings for compilation info (including lowering_config and translation_info). Retire the data class `Configuration` and use `compilation_info` from the IREE Python binding. Signed-off-by: Bangtian Liu <[email protected]>
1 parent ffb870f commit d279aff

File tree

8 files changed

+157
-129
lines changed

8 files changed

+157
-129
lines changed

tuner/tuner/candidate_gen.py

Lines changed: 50 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -41,23 +41,23 @@
4141

4242
def apply_configuration(
4343
template: list[str],
44-
configuration: Configuration,
44+
compilation_info: iree_codegen.CompilationInfoAttr,
4545
) -> str:
46-
lowering_config = configuration.lowering_config
46+
lowering_config = compilation_info.lowering_config
4747
intrinsic = lowering_config.mma_kind
4848
(
4949
subgroup_m_count,
5050
subgroup_n_count,
5151
) = lowering_config.subgroup_count_mn
5252
workgroup_sizes = lowering_config.workgroup_tile_sizes
5353
reduction_sizes = lowering_config.reduction_tile_sizes
54-
gpu_pipeline_options = configuration.translation_info.configuration[
54+
gpu_pipeline_options = compilation_info.translation_info.configuration[
5555
GPU_PIPELINE_OPTIONS_KEY
5656
]
57-
waves_per_eu = configuration.translation_info.configuration[LLVM_FUNC_ATTRS_KEY][
57+
waves_per_eu = compilation_info.translation_info.configuration[LLVM_FUNC_ATTRS_KEY][
5858
WAVES_PER_EU_KEY
5959
]
60-
tune_logger.info(f"Applying: {configuration}")
60+
tune_logger.info(f"Applying: {compilation_info}")
6161
expr0 = re.compile(
6262
r"<intrinsic = #iree_gpu\.mma_layout<(.+)>, subgroup_m_count = ([0-9]+), subgroup_n_count = ([0-9]+)>"
6363
)
@@ -69,7 +69,7 @@ def apply_configuration(
6969
expr4 = re.compile(r"gpu_pipeline_options = #iree_gpu\.pipeline_options<([^>]*)>")
7070
expr5 = re.compile(r"\"amdgpu-waves-per-eu\" = \"([0-9])\"")
7171
repl0 = f"<intrinsic = {intrinsic}, subgroup_m_count = {subgroup_m_count}, subgroup_n_count = {subgroup_n_count}>"
72-
repl1 = f'LLVMGPUVectorDistribute workgroup_size = [{", ".join(map(str, configuration.translation_info.workgroup_size))}] subgroup_size = {configuration.translation_info.subgroup_size},'
72+
repl1 = f'LLVMGPUVectorDistribute workgroup_size = [{", ".join(map(str, compilation_info.translation_info.workgroup_size))}] subgroup_size = {compilation_info.translation_info.subgroup_size},'
7373
repl2 = f"workgroup = {workgroup_sizes}"
7474
repl3 = f"reduction = {reduction_sizes}"
7575
repl4 = f"gpu_pipeline_options = {gpu_pipeline_options}"
@@ -101,7 +101,7 @@ def apply_params(
101101
self,
102102
problem_size: ProblemSize,
103103
template: list[str],
104-
configuration: Configuration,
104+
compilation_info: iree_codegen.CompilationInfoAttr,
105105
) -> MLIRTransformation:
106106
"""Apply parameter transformations to the operation."""
107107
pass
@@ -132,7 +132,10 @@ def find_handler(self, op_name: str) -> DispatchTuner:
132132

133133
class MmtTuner(DispatchTuner, MmtParser):
134134
def get_transform_function_mmt(
135-
self, problem_size: ProblemSize, functionName: str, configuration: Configuration
135+
self,
136+
problem_size: ProblemSize,
137+
functionName: str,
138+
compilation_info: iree_codegen.CompilationInfoAttr,
136139
) -> str:
137140
return f"""
138141
transform.named_sequence @{functionName}(%matmul: !transform.any_op {{transform.readonly}}) -> (!transform.any_op, !transform.any_param) {{
@@ -141,10 +144,7 @@ def get_transform_function_mmt(
141144
%rhs = transform.get_operand %matmul[1] : (!transform.any_op) -> !transform.any_value
142145
transform.iree.match.cast_compatible_type %lhs = tensor<{problem_size.lhs_type}> : !transform.any_value
143146
transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value
144-
%config = transform.param.constant #iree_codegen.compilation_info<
145-
lowering_config = {configuration.lowering_config},
146-
translation_info = {configuration.translation_info}
147-
> -> !transform.any_param
147+
%config = transform.param.constant {compilation_info} -> !transform.any_param
148148
transform.yield %matmul, %config : !transform.any_op, !transform.any_param
149149
}}
150150
"""
@@ -153,29 +153,34 @@ def apply_params(
153153
self,
154154
problem_size: ProblemSize,
155155
template: list[str],
156-
configuration: Configuration,
156+
compilation_info: iree_codegen.CompilationInfoAttr,
157157
) -> MLIRTransformation:
158158
M, N, K = problem_size.MNK
159159
modified = indent(
160160
self.get_transform_function_mmt(
161-
problem_size, f"match_mmt_{M}x{N}x{K}", configuration
161+
problem_size, f"match_mmt_{M}x{N}x{K}", compilation_info
162162
),
163163
"// ",
164164
)
165165
modified += apply_configuration(
166166
template,
167-
configuration,
167+
compilation_info,
168168
)
169169
embeddable = indent(
170-
self.get_transform_function_mmt(problem_size, f"match_op", configuration),
170+
self.get_transform_function_mmt(
171+
problem_size, f"match_op", compilation_info
172+
),
171173
" ",
172174
)
173175
return MLIRTransformation(template, modified, embeddable)
174176

175177

176178
class ConvTuner(DispatchTuner, ConvParser):
177179
def get_transform_function_conv(
178-
self, problem_size: ProblemSize, functionName: str, configuration: Configuration
180+
self,
181+
problem_size: ProblemSize,
182+
functionName: str,
183+
compilation_info: iree_codegen.CompilationInfoAttr,
179184
) -> str:
180185
dynamic_batch_input_ty = problem_size.lhs_type
181186
dynamic_batch_input_ty.shape = dynamic_batch_input_ty.shape.copy()
@@ -198,10 +203,7 @@ def get_transform_function_conv(
198203
ins(%lhs, %rhs : {input}, {filter})
199204
outs(%out : {output}) -> {output}
200205
}} : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
201-
%config = transform.param.constant #iree_codegen.compilation_info<
202-
lowering_config = {configuration.lowering_config},
203-
translation_info = {configuration.translation_info}
204-
> -> !transform.any_param
206+
%config = transform.param.constant {compilation_info} -> !transform.any_param
205207
transform.yield %conv, %config : !transform.any_op, !transform.any_param
206208
}}
207209
"""
@@ -210,23 +212,25 @@ def apply_params(
210212
self,
211213
problem_size: ProblemSize,
212214
template: list[str],
213-
configuration: Configuration,
215+
compilation_info: iree_codegen.CompilationInfoAttr,
214216
) -> MLIRTransformation:
215217
conv_dims = ConvDimInfo.from_problem_size(problem_size)
216218
modified = indent(
217219
self.get_transform_function_conv(
218220
problem_size,
219221
f"match_conv_2d_nhwc_hwcf_Bx{conv_dims.oh}x{conv_dims.ow}x{conv_dims.oc}x{conv_dims.fh}x{conv_dims.fw}x{conv_dims.ic}",
220-
configuration,
222+
compilation_info,
221223
),
222224
"// ",
223225
)
224226
modified += apply_configuration(
225227
template,
226-
configuration,
228+
compilation_info,
227229
)
228230
embeddable = indent(
229-
self.get_transform_function_conv(problem_size, f"match_op", configuration),
231+
self.get_transform_function_conv(
232+
problem_size, f"match_op", compilation_info
233+
),
230234
" ",
231235
)
232236
return MLIRTransformation(template, modified, embeddable)
@@ -237,7 +241,7 @@ def get_transform_function_broadcast_rhs_mmt(
237241
self,
238242
problem_size: ProblemSize,
239243
functionName: str,
240-
configuration: Configuration,
244+
compilation_info: iree_codegen.CompilationInfoAttr,
241245
) -> str:
242246
lhs_dynamic_batch = problem_size.lhs_type
243247
lhs_dynamic_batch.shape = lhs_dynamic_batch.shape.copy()
@@ -250,10 +254,7 @@ def get_transform_function_broadcast_rhs_mmt(
250254
%rhs = transform.get_operand %generic[1] : (!transform.any_op) -> !transform.any_value
251255
transform.iree.match.cast_compatible_type %lhs = tensor<{lhs_dynamic_batch}> : !transform.any_value
252256
transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value
253-
%config = transform.param.constant #iree_codegen.compilation_info<
254-
lowering_config = {configuration.lowering_config},
255-
translation_info = {configuration.translation_info}
256-
> -> !transform.any_param
257+
%config = transform.param.constant {compilation_info} -> !transform.any_param
257258
transform.yield %generic, %config : !transform.any_op, !transform.any_param
258259
}}
259260
"""
@@ -262,23 +263,23 @@ def apply_params_broadcast_rhs_mmt(
262263
self,
263264
problem_size: ProblemSize,
264265
template: list[str],
265-
configuration: Configuration,
266+
compilation_info: iree_codegen.CompilationInfoAttr,
266267
) -> MLIRTransformation:
267268
M, N, K = problem_size.MNK
268269
modified = indent(
269270
self.get_transform_function_broadcast_rhs_mmt(
270-
problem_size, f"match_broadcast_rhs_mmt_Bx{M}x{N}x{K}", configuration
271+
problem_size, f"match_broadcast_rhs_mmt_Bx{M}x{N}x{K}", compilation_info
271272
),
272273
"// ",
273274
)
274275
modified += apply_configuration(
275276
template,
276-
configuration,
277+
compilation_info,
277278
)
278279

279280
embeddable = indent(
280281
self.get_transform_function_broadcast_rhs_mmt(
281-
problem_size, f"match_op", configuration
282+
problem_size, f"match_op", compilation_info
282283
),
283284
" ",
284285
)
@@ -288,19 +289,19 @@ def apply_params(
288289
self,
289290
problem_size: ProblemSize,
290291
template: list[str],
291-
configuration: Configuration,
292+
compilation_info: iree_codegen.CompilationInfoAttr,
292293
) -> MLIRTransformation:
293294
if self.is_broadcast_rhs_mmt(template):
294295
return self.apply_params_broadcast_rhs_mmt(
295-
problem_size, template, configuration
296+
problem_size, template, compilation_info
296297
)
297298

298299
# TODO: Generate transform function.
299300
return MLIRTransformation(
300301
template,
301302
apply_configuration(
302303
template,
303-
configuration,
304+
compilation_info,
304305
),
305306
"",
306307
)
@@ -311,7 +312,7 @@ def get_transform_function_batch_mmt(
311312
self,
312313
problem_size: ProblemSize,
313314
functionName: str,
314-
configuration: Configuration,
315+
compilation_info: iree_codegen.CompilationInfoAttr,
315316
) -> str:
316317
return f"""
317318
transform.named_sequence @{functionName}(%generic: !transform.any_op {{transform.readonly}}) -> (!transform.any_op, !transform.any_param) {{
@@ -320,10 +321,7 @@ def get_transform_function_batch_mmt(
320321
%rhs = transform.get_operand %generic[1] : (!transform.any_op) -> !transform.any_value
321322
transform.iree.match.cast_compatible_type %lhs = tensor<{problem_size.lhs_type}> : !transform.any_value
322323
transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value
323-
%config = transform.param.constant #iree_codegen.compilation_info<
324-
lowering_config = {configuration.lowering_config},
325-
translation_info = {configuration.translation_info}
326-
> -> !transform.any_param
324+
%config = transform.param.constant {compilation_info} -> !transform.any_param
327325
transform.yield %generic, %config : !transform.any_op, !transform.any_param
328326
}}
329327
"""
@@ -332,24 +330,24 @@ def apply_params(
332330
self,
333331
problem_size: ProblemSize,
334332
template: list[str],
335-
configuration: Configuration,
333+
compilation_info: iree_codegen.CompilationInfoAttr,
336334
) -> MLIRTransformation:
337335
M, N, K = problem_size.MNK
338336
B = problem_size.matmul_size.B
339337
modified = indent(
340338
self.get_transform_function_batch_mmt(
341-
problem_size, f"match_batch_mmt_{B}x{M}x{N}x{K}", configuration
339+
problem_size, f"match_batch_mmt_{B}x{M}x{N}x{K}", compilation_info
342340
),
343341
"// ",
344342
)
345343
modified += apply_configuration(
346344
template,
347-
configuration,
345+
compilation_info,
348346
)
349347

350348
embeddable = indent(
351349
self.get_transform_function_batch_mmt(
352-
problem_size, f"match_op", configuration
350+
problem_size, f"match_op", compilation_info
353351
),
354352
" ",
355353
)
@@ -362,7 +360,7 @@ def get_transform_function_batch_matmul(
362360
problem_size: ProblemSize,
363361
tile_dims: str,
364362
functionName: str,
365-
configuration: Configuration,
363+
compilation_info: iree_codegen.CompilationInfoAttr,
366364
) -> str:
367365
input0 = f"tensor<{problem_size.lhs_type}>"
368366
input1 = f"tensor<{problem_size.rhs_type}>"
@@ -377,10 +375,7 @@ def get_transform_function_batch_matmul(
377375
ins(%lhs, %rhs : {input0}, {input1})
378376
outs(%out : {output}) -> {output}
379377
}} : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
380-
%config = transform.param.constant #iree_codegen.compilation_info<
381-
lowering_config = {configuration.lowering_config},
382-
translation_info = {configuration.translation_info}
383-
> -> !transform.any_param
378+
%config = transform.param.constant {compilation_info} -> !transform.any_param
384379
transform.yield %batch_matmul, %config : !transform.any_op, !transform.any_param
385380
}}
386381
"""
@@ -389,26 +384,26 @@ def apply_params(
389384
self,
390385
problem_size: ProblemSize,
391386
template: list[str],
392-
configuration: Configuration,
387+
compilation_info: iree_codegen.CompilationInfoAttr,
393388
) -> MLIRTransformation:
394389
M, N, K = problem_size.MNK
395390
modified = indent(
396391
self.get_transform_function_batch_matmul(
397392
problem_size,
398393
self.tile_dims,
399394
f"match_batch_matmul_{problem_size.matmul_size.B}x{M}x{N}x{K}",
400-
configuration,
395+
compilation_info,
401396
),
402397
"// ",
403398
)
404399
modified += apply_configuration(
405400
template,
406-
configuration,
401+
compilation_info,
407402
)
408403

409404
embeddable = indent(
410405
self.get_transform_function_batch_matmul(
411-
problem_size, self.tile_dims, f"match_op", configuration
406+
problem_size, self.tile_dims, f"match_op", compilation_info
412407
),
413408
" ",
414409
)

0 commit comments

Comments (0)