41
41
42
42
def apply_configuration (
43
43
template : list [str ],
44
- configuration : Configuration ,
44
+ compilation_info : iree_codegen . CompilationInfoAttr ,
45
45
) -> str :
46
- lowering_config = configuration .lowering_config
46
+ lowering_config = compilation_info .lowering_config
47
47
intrinsic = lowering_config .mma_kind
48
48
(
49
49
subgroup_m_count ,
50
50
subgroup_n_count ,
51
51
) = lowering_config .subgroup_count_mn
52
52
workgroup_sizes = lowering_config .workgroup_tile_sizes
53
53
reduction_sizes = lowering_config .reduction_tile_sizes
54
- gpu_pipeline_options = configuration .translation_info .configuration [
54
+ gpu_pipeline_options = compilation_info .translation_info .configuration [
55
55
GPU_PIPELINE_OPTIONS_KEY
56
56
]
57
- waves_per_eu = configuration .translation_info .configuration [LLVM_FUNC_ATTRS_KEY ][
57
+ waves_per_eu = compilation_info .translation_info .configuration [LLVM_FUNC_ATTRS_KEY ][
58
58
WAVES_PER_EU_KEY
59
59
]
60
- tune_logger .info (f"Applying: { configuration } " )
60
+ tune_logger .info (f"Applying: { compilation_info } " )
61
61
expr0 = re .compile (
62
62
r"<intrinsic = #iree_gpu\.mma_layout<(.+)>, subgroup_m_count = ([0-9]+), subgroup_n_count = ([0-9]+)>"
63
63
)
@@ -69,7 +69,7 @@ def apply_configuration(
69
69
expr4 = re .compile (r"gpu_pipeline_options = #iree_gpu\.pipeline_options<([^>]*)>" )
70
70
expr5 = re .compile (r"\"amdgpu-waves-per-eu\" = \"([0-9])\"" )
71
71
repl0 = f"<intrinsic = { intrinsic } , subgroup_m_count = { subgroup_m_count } , subgroup_n_count = { subgroup_n_count } >"
72
- repl1 = f'LLVMGPUVectorDistribute workgroup_size = [{ ", " .join (map (str , configuration .translation_info .workgroup_size ))} ] subgroup_size = { configuration .translation_info .subgroup_size } ,'
72
+ repl1 = f'LLVMGPUVectorDistribute workgroup_size = [{ ", " .join (map (str , compilation_info .translation_info .workgroup_size ))} ] subgroup_size = { compilation_info .translation_info .subgroup_size } ,'
73
73
repl2 = f"workgroup = { workgroup_sizes } "
74
74
repl3 = f"reduction = { reduction_sizes } "
75
75
repl4 = f"gpu_pipeline_options = { gpu_pipeline_options } "
@@ -101,7 +101,7 @@ def apply_params(
101
101
self ,
102
102
problem_size : ProblemSize ,
103
103
template : list [str ],
104
- configuration : Configuration ,
104
+ compilation_info : iree_codegen . CompilationInfoAttr ,
105
105
) -> MLIRTransformation :
106
106
"""Apply parameter transformations to the operation."""
107
107
pass
@@ -132,7 +132,10 @@ def find_handler(self, op_name: str) -> DispatchTuner:
132
132
133
133
class MmtTuner (DispatchTuner , MmtParser ):
134
134
def get_transform_function_mmt (
135
- self , problem_size : ProblemSize , functionName : str , configuration : Configuration
135
+ self ,
136
+ problem_size : ProblemSize ,
137
+ functionName : str ,
138
+ compilation_info : iree_codegen .CompilationInfoAttr ,
136
139
) -> str :
137
140
return f"""
138
141
transform.named_sequence @{ functionName } (%matmul: !transform.any_op {{transform.readonly}}) -> (!transform.any_op, !transform.any_param) {{
@@ -141,10 +144,7 @@ def get_transform_function_mmt(
141
144
%rhs = transform.get_operand %matmul[1] : (!transform.any_op) -> !transform.any_value
142
145
transform.iree.match.cast_compatible_type %lhs = tensor<{ problem_size .lhs_type } > : !transform.any_value
143
146
transform.iree.match.cast_compatible_type %rhs = tensor<{ problem_size .rhs_type } > : !transform.any_value
144
- %config = transform.param.constant #iree_codegen.compilation_info<
145
- lowering_config = { configuration .lowering_config } ,
146
- translation_info = { configuration .translation_info }
147
- > -> !transform.any_param
147
+ %config = transform.param.constant { compilation_info } -> !transform.any_param
148
148
transform.yield %matmul, %config : !transform.any_op, !transform.any_param
149
149
}}
150
150
"""
@@ -153,29 +153,34 @@ def apply_params(
153
153
self ,
154
154
problem_size : ProblemSize ,
155
155
template : list [str ],
156
- configuration : Configuration ,
156
+ compilation_info : iree_codegen . CompilationInfoAttr ,
157
157
) -> MLIRTransformation :
158
158
M , N , K = problem_size .MNK
159
159
modified = indent (
160
160
self .get_transform_function_mmt (
161
- problem_size , f"match_mmt_{ M } x{ N } x{ K } " , configuration
161
+ problem_size , f"match_mmt_{ M } x{ N } x{ K } " , compilation_info
162
162
),
163
163
"// " ,
164
164
)
165
165
modified += apply_configuration (
166
166
template ,
167
- configuration ,
167
+ compilation_info ,
168
168
)
169
169
embeddable = indent (
170
- self .get_transform_function_mmt (problem_size , f"match_op" , configuration ),
170
+ self .get_transform_function_mmt (
171
+ problem_size , f"match_op" , compilation_info
172
+ ),
171
173
" " ,
172
174
)
173
175
return MLIRTransformation (template , modified , embeddable )
174
176
175
177
176
178
class ConvTuner (DispatchTuner , ConvParser ):
177
179
def get_transform_function_conv (
178
- self , problem_size : ProblemSize , functionName : str , configuration : Configuration
180
+ self ,
181
+ problem_size : ProblemSize ,
182
+ functionName : str ,
183
+ compilation_info : iree_codegen .CompilationInfoAttr ,
179
184
) -> str :
180
185
dynamic_batch_input_ty = problem_size .lhs_type
181
186
dynamic_batch_input_ty .shape = dynamic_batch_input_ty .shape .copy ()
@@ -198,10 +203,7 @@ def get_transform_function_conv(
198
203
ins(%lhs, %rhs : { input } , { filter } )
199
204
outs(%out : { output } ) -> { output }
200
205
}} : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
201
- %config = transform.param.constant #iree_codegen.compilation_info<
202
- lowering_config = { configuration .lowering_config } ,
203
- translation_info = { configuration .translation_info }
204
- > -> !transform.any_param
206
+ %config = transform.param.constant { compilation_info } -> !transform.any_param
205
207
transform.yield %conv, %config : !transform.any_op, !transform.any_param
206
208
}}
207
209
"""
@@ -210,23 +212,25 @@ def apply_params(
210
212
self ,
211
213
problem_size : ProblemSize ,
212
214
template : list [str ],
213
- configuration : Configuration ,
215
+ compilation_info : iree_codegen . CompilationInfoAttr ,
214
216
) -> MLIRTransformation :
215
217
conv_dims = ConvDimInfo .from_problem_size (problem_size )
216
218
modified = indent (
217
219
self .get_transform_function_conv (
218
220
problem_size ,
219
221
f"match_conv_2d_nhwc_hwcf_Bx{ conv_dims .oh } x{ conv_dims .ow } x{ conv_dims .oc } x{ conv_dims .fh } x{ conv_dims .fw } x{ conv_dims .ic } " ,
220
- configuration ,
222
+ compilation_info ,
221
223
),
222
224
"// " ,
223
225
)
224
226
modified += apply_configuration (
225
227
template ,
226
- configuration ,
228
+ compilation_info ,
227
229
)
228
230
embeddable = indent (
229
- self .get_transform_function_conv (problem_size , f"match_op" , configuration ),
231
+ self .get_transform_function_conv (
232
+ problem_size , f"match_op" , compilation_info
233
+ ),
230
234
" " ,
231
235
)
232
236
return MLIRTransformation (template , modified , embeddable )
@@ -237,7 +241,7 @@ def get_transform_function_broadcast_rhs_mmt(
237
241
self ,
238
242
problem_size : ProblemSize ,
239
243
functionName : str ,
240
- configuration : Configuration ,
244
+ compilation_info : iree_codegen . CompilationInfoAttr ,
241
245
) -> str :
242
246
lhs_dynamic_batch = problem_size .lhs_type
243
247
lhs_dynamic_batch .shape = lhs_dynamic_batch .shape .copy ()
@@ -250,10 +254,7 @@ def get_transform_function_broadcast_rhs_mmt(
250
254
%rhs = transform.get_operand %generic[1] : (!transform.any_op) -> !transform.any_value
251
255
transform.iree.match.cast_compatible_type %lhs = tensor<{ lhs_dynamic_batch } > : !transform.any_value
252
256
transform.iree.match.cast_compatible_type %rhs = tensor<{ problem_size .rhs_type } > : !transform.any_value
253
- %config = transform.param.constant #iree_codegen.compilation_info<
254
- lowering_config = { configuration .lowering_config } ,
255
- translation_info = { configuration .translation_info }
256
- > -> !transform.any_param
257
+ %config = transform.param.constant { compilation_info } -> !transform.any_param
257
258
transform.yield %generic, %config : !transform.any_op, !transform.any_param
258
259
}}
259
260
"""
@@ -262,23 +263,23 @@ def apply_params_broadcast_rhs_mmt(
262
263
self ,
263
264
problem_size : ProblemSize ,
264
265
template : list [str ],
265
- configuration : Configuration ,
266
+ compilation_info : iree_codegen . CompilationInfoAttr ,
266
267
) -> MLIRTransformation :
267
268
M , N , K = problem_size .MNK
268
269
modified = indent (
269
270
self .get_transform_function_broadcast_rhs_mmt (
270
- problem_size , f"match_broadcast_rhs_mmt_Bx{ M } x{ N } x{ K } " , configuration
271
+ problem_size , f"match_broadcast_rhs_mmt_Bx{ M } x{ N } x{ K } " , compilation_info
271
272
),
272
273
"// " ,
273
274
)
274
275
modified += apply_configuration (
275
276
template ,
276
- configuration ,
277
+ compilation_info ,
277
278
)
278
279
279
280
embeddable = indent (
280
281
self .get_transform_function_broadcast_rhs_mmt (
281
- problem_size , f"match_op" , configuration
282
+ problem_size , f"match_op" , compilation_info
282
283
),
283
284
" " ,
284
285
)
@@ -288,19 +289,19 @@ def apply_params(
288
289
self ,
289
290
problem_size : ProblemSize ,
290
291
template : list [str ],
291
- configuration : Configuration ,
292
+ compilation_info : iree_codegen . CompilationInfoAttr ,
292
293
) -> MLIRTransformation :
293
294
if self .is_broadcast_rhs_mmt (template ):
294
295
return self .apply_params_broadcast_rhs_mmt (
295
- problem_size , template , configuration
296
+ problem_size , template , compilation_info
296
297
)
297
298
298
299
# TODO: Generate transform function.
299
300
return MLIRTransformation (
300
301
template ,
301
302
apply_configuration (
302
303
template ,
303
- configuration ,
304
+ compilation_info ,
304
305
),
305
306
"" ,
306
307
)
@@ -311,7 +312,7 @@ def get_transform_function_batch_mmt(
311
312
self ,
312
313
problem_size : ProblemSize ,
313
314
functionName : str ,
314
- configuration : Configuration ,
315
+ compilation_info : iree_codegen . CompilationInfoAttr ,
315
316
) -> str :
316
317
return f"""
317
318
transform.named_sequence @{ functionName } (%generic: !transform.any_op {{transform.readonly}}) -> (!transform.any_op, !transform.any_param) {{
@@ -320,10 +321,7 @@ def get_transform_function_batch_mmt(
320
321
%rhs = transform.get_operand %generic[1] : (!transform.any_op) -> !transform.any_value
321
322
transform.iree.match.cast_compatible_type %lhs = tensor<{ problem_size .lhs_type } > : !transform.any_value
322
323
transform.iree.match.cast_compatible_type %rhs = tensor<{ problem_size .rhs_type } > : !transform.any_value
323
- %config = transform.param.constant #iree_codegen.compilation_info<
324
- lowering_config = { configuration .lowering_config } ,
325
- translation_info = { configuration .translation_info }
326
- > -> !transform.any_param
324
+ %config = transform.param.constant { compilation_info } -> !transform.any_param
327
325
transform.yield %generic, %config : !transform.any_op, !transform.any_param
328
326
}}
329
327
"""
@@ -332,24 +330,24 @@ def apply_params(
332
330
self ,
333
331
problem_size : ProblemSize ,
334
332
template : list [str ],
335
- configuration : Configuration ,
333
+ compilation_info : iree_codegen . CompilationInfoAttr ,
336
334
) -> MLIRTransformation :
337
335
M , N , K = problem_size .MNK
338
336
B = problem_size .matmul_size .B
339
337
modified = indent (
340
338
self .get_transform_function_batch_mmt (
341
- problem_size , f"match_batch_mmt_{ B } x{ M } x{ N } x{ K } " , configuration
339
+ problem_size , f"match_batch_mmt_{ B } x{ M } x{ N } x{ K } " , compilation_info
342
340
),
343
341
"// " ,
344
342
)
345
343
modified += apply_configuration (
346
344
template ,
347
- configuration ,
345
+ compilation_info ,
348
346
)
349
347
350
348
embeddable = indent (
351
349
self .get_transform_function_batch_mmt (
352
- problem_size , f"match_op" , configuration
350
+ problem_size , f"match_op" , compilation_info
353
351
),
354
352
" " ,
355
353
)
@@ -362,7 +360,7 @@ def get_transform_function_batch_matmul(
362
360
problem_size : ProblemSize ,
363
361
tile_dims : str ,
364
362
functionName : str ,
365
- configuration : Configuration ,
363
+ compilation_info : iree_codegen . CompilationInfoAttr ,
366
364
) -> str :
367
365
input0 = f"tensor<{ problem_size .lhs_type } >"
368
366
input1 = f"tensor<{ problem_size .rhs_type } >"
@@ -377,10 +375,7 @@ def get_transform_function_batch_matmul(
377
375
ins(%lhs, %rhs : { input0 } , { input1 } )
378
376
outs(%out : { output } ) -> { output }
379
377
}} : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
380
- %config = transform.param.constant #iree_codegen.compilation_info<
381
- lowering_config = { configuration .lowering_config } ,
382
- translation_info = { configuration .translation_info }
383
- > -> !transform.any_param
378
+ %config = transform.param.constant { compilation_info } -> !transform.any_param
384
379
transform.yield %batch_matmul, %config : !transform.any_op, !transform.any_param
385
380
}}
386
381
"""
@@ -389,26 +384,26 @@ def apply_params(
389
384
self ,
390
385
problem_size : ProblemSize ,
391
386
template : list [str ],
392
- configuration : Configuration ,
387
+ compilation_info : iree_codegen . CompilationInfoAttr ,
393
388
) -> MLIRTransformation :
394
389
M , N , K = problem_size .MNK
395
390
modified = indent (
396
391
self .get_transform_function_batch_matmul (
397
392
problem_size ,
398
393
self .tile_dims ,
399
394
f"match_batch_matmul_{ problem_size .matmul_size .B } x{ M } x{ N } x{ K } " ,
400
- configuration ,
395
+ compilation_info ,
401
396
),
402
397
"// " ,
403
398
)
404
399
modified += apply_configuration (
405
400
template ,
406
- configuration ,
401
+ compilation_info ,
407
402
)
408
403
409
404
embeddable = indent (
410
405
self .get_transform_function_batch_matmul (
411
- problem_size , self .tile_dims , f"match_op" , configuration
406
+ problem_size , self .tile_dims , f"match_op" , compilation_info
412
407
),
413
408
" " ,
414
409
)
0 commit comments