From d279afff48c56ab5e0e6e5ebe2717ccf0b26ee50 Mon Sep 17 00:00:00 2001 From: Bangtian Liu Date: Wed, 11 Dec 2024 23:51:09 -0500 Subject: [PATCH] [tuner]: use compilation_info binding (#678) This PR addresses part of the task in https://github.com/nod-ai/shark-ai/issues/453: use the IREE bindings for compilation info (including `lowering_config` and `translation_info`). It retires the `Configuration` data class in favor of `iree_codegen.CompilationInfoAttr` from the IREE Python bindings. Signed-off-by: Bangtian Liu --- tuner/tuner/candidate_gen.py | 105 +++++++++++++--------------- tuner/tuner/candidate_gen_test.py | 51 ++++++-------- tuner/tuner/common.py | 6 -- tuner/tuner/common_test.py | 23 +++--- tuner/tuner/dispatch_constraints.py | 8 ++- tuner/tuner/dispatch_parser.py | 4 +- tuner/tuner/dispatch_parser_test.py | 87 +++++++++++++++++------ tuner/tuner/libtuner.py | 2 +- 8 files changed, 157 insertions(+), 129 deletions(-) diff --git a/tuner/tuner/candidate_gen.py b/tuner/tuner/candidate_gen.py index a3252130e..ed150bfec 100644 --- a/tuner/tuner/candidate_gen.py +++ b/tuner/tuner/candidate_gen.py @@ -41,9 +41,9 @@ def apply_configuration( template: list[str], - configuration: Configuration, + compilation_info: iree_codegen.CompilationInfoAttr, ) -> str: - lowering_config = configuration.lowering_config + lowering_config = compilation_info.lowering_config intrinsic = lowering_config.mma_kind ( subgroup_m_count, @@ -51,13 +51,13 @@ def apply_configuration( ) = lowering_config.subgroup_count_mn workgroup_sizes = lowering_config.workgroup_tile_sizes reduction_sizes = lowering_config.reduction_tile_sizes - gpu_pipeline_options = configuration.translation_info.configuration[ + gpu_pipeline_options = compilation_info.translation_info.configuration[ GPU_PIPELINE_OPTIONS_KEY ] - waves_per_eu = configuration.translation_info.configuration[LLVM_FUNC_ATTRS_KEY][ + waves_per_eu = compilation_info.translation_info.configuration[LLVM_FUNC_ATTRS_KEY][ WAVES_PER_EU_KEY ] - tune_logger.info(f"Applying: {configuration}") + tune_logger.info(f"Applying: {compilation_info}") expr0 = re.compile( r", subgroup_m_count = ([0-9]+), subgroup_n_count = ([0-9]+)>" ) @@ -69,7 +69,7 @@ def apply_configuration( expr4 = re.compile(r"gpu_pipeline_options = #iree_gpu\.pipeline_options<([^>]*)>") expr5 = re.compile(r"\"amdgpu-waves-per-eu\" = \"([0-9])\"") repl0 = f"" - repl1 = f'LLVMGPUVectorDistribute workgroup_size = [{", ".join(map(str, configuration.translation_info.workgroup_size))}] subgroup_size = {configuration.translation_info.subgroup_size},' + repl1 = f'LLVMGPUVectorDistribute workgroup_size = [{", ".join(map(str, compilation_info.translation_info.workgroup_size))}] subgroup_size = {compilation_info.translation_info.subgroup_size},' repl2 = f"workgroup = {workgroup_sizes}" repl3 = f"reduction = {reduction_sizes}" repl4 = f"gpu_pipeline_options = {gpu_pipeline_options}" @@ -101,7 +101,7 @@ def apply_params( self, problem_size: ProblemSize, template: list[str], - configuration: Configuration, + compilation_info: iree_codegen.CompilationInfoAttr, ) -> MLIRTransformation: """Apply parameter transformations to the operation.""" pass @@ -132,7 +132,10 @@ def find_handler(self, op_name: str) -> DispatchTuner: class MmtTuner(DispatchTuner, MmtParser): def get_transform_function_mmt( - self, problem_size: ProblemSize, functionName: str, configuration: Configuration + self, + problem_size: ProblemSize, + functionName: str, + compilation_info: iree_codegen.CompilationInfoAttr, ) -> str: return f""" transform.named_sequence @{functionName}(%matmul:
!transform.any_op {{transform.readonly}}) -> (!transform.any_op, !transform.any_param) {{ @@ -141,10 +144,7 @@ def get_transform_function_mmt( %rhs = transform.get_operand %matmul[1] : (!transform.any_op) -> !transform.any_value transform.iree.match.cast_compatible_type %lhs = tensor<{problem_size.lhs_type}> : !transform.any_value transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value - %config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = {configuration.lowering_config}, - translation_info = {configuration.translation_info} - > -> !transform.any_param + %config = transform.param.constant {compilation_info} -> !transform.any_param transform.yield %matmul, %config : !transform.any_op, !transform.any_param }} """ @@ -153,21 +153,23 @@ def apply_params( self, problem_size: ProblemSize, template: list[str], - configuration: Configuration, + compilation_info: iree_codegen.CompilationInfoAttr, ) -> MLIRTransformation: M, N, K = problem_size.MNK modified = indent( self.get_transform_function_mmt( - problem_size, f"match_mmt_{M}x{N}x{K}", configuration + problem_size, f"match_mmt_{M}x{N}x{K}", compilation_info ), "// ", ) modified += apply_configuration( template, - configuration, + compilation_info, ) embeddable = indent( - self.get_transform_function_mmt(problem_size, f"match_op", configuration), + self.get_transform_function_mmt( + problem_size, f"match_op", compilation_info + ), " ", ) return MLIRTransformation(template, modified, embeddable) @@ -175,7 +177,10 @@ def apply_params( class ConvTuner(DispatchTuner, ConvParser): def get_transform_function_conv( - self, problem_size: ProblemSize, functionName: str, configuration: Configuration + self, + problem_size: ProblemSize, + functionName: str, + compilation_info: iree_codegen.CompilationInfoAttr, ) -> str: dynamic_batch_input_ty = problem_size.lhs_type dynamic_batch_input_ty.shape = dynamic_batch_input_ty.shape.copy() @@ -198,10 +203,7 @@ def get_transform_function_conv( ins(%lhs, %rhs : {input}, {filter}) outs(%out : {output}) -> {output} }} : (!transform.any_op) -> (!transform.any_value, !transform.any_value) - %config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = {configuration.lowering_config}, - translation_info = {configuration.translation_info} - > -> !transform.any_param + %config = transform.param.constant {compilation_info} -> !transform.any_param transform.yield %conv, %config : !transform.any_op, !transform.any_param }} """ @@ -210,23 +212,25 @@ def apply_params( self, problem_size: ProblemSize, template: list[str], - configuration: Configuration, + compilation_info: iree_codegen.CompilationInfoAttr, ) -> MLIRTransformation: conv_dims = ConvDimInfo.from_problem_size(problem_size) modified = indent( self.get_transform_function_conv( problem_size, f"match_conv_2d_nhwc_hwcf_Bx{conv_dims.oh}x{conv_dims.ow}x{conv_dims.oc}x{conv_dims.fh}x{conv_dims.fw}x{conv_dims.ic}", - configuration, + compilation_info, ), "// ", ) modified += apply_configuration( template, - configuration, + compilation_info, ) embeddable = indent( - self.get_transform_function_conv(problem_size, f"match_op", configuration), + self.get_transform_function_conv( + problem_size, f"match_op", compilation_info + ), " ", ) return MLIRTransformation(template, modified, embeddable) @@ -237,7 +241,7 @@ def get_transform_function_broadcast_rhs_mmt( self, problem_size: ProblemSize, functionName: str, - configuration: Configuration, + compilation_info: 
iree_codegen.CompilationInfoAttr, ) -> str: lhs_dynamic_batch = problem_size.lhs_type lhs_dynamic_batch.shape = lhs_dynamic_batch.shape.copy() @@ -250,10 +254,7 @@ def get_transform_function_broadcast_rhs_mmt( %rhs = transform.get_operand %generic[1] : (!transform.any_op) -> !transform.any_value transform.iree.match.cast_compatible_type %lhs = tensor<{lhs_dynamic_batch}> : !transform.any_value transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value -%config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = {configuration.lowering_config}, - translation_info = {configuration.translation_info} - > -> !transform.any_param +%config = transform.param.constant {compilation_info} -> !transform.any_param transform.yield %generic, %config : !transform.any_op, !transform.any_param }} """ @@ -262,23 +263,23 @@ def apply_params_broadcast_rhs_mmt( self, problem_size: ProblemSize, template: list[str], - configuration: Configuration, + compilation_info: iree_codegen.CompilationInfoAttr, ) -> MLIRTransformation: M, N, K = problem_size.MNK modified = indent( self.get_transform_function_broadcast_rhs_mmt( - problem_size, f"match_broadcast_rhs_mmt_Bx{M}x{N}x{K}", configuration + problem_size, f"match_broadcast_rhs_mmt_Bx{M}x{N}x{K}", compilation_info ), "// ", ) modified += apply_configuration( template, - configuration, + compilation_info, ) embeddable = indent( self.get_transform_function_broadcast_rhs_mmt( - problem_size, f"match_op", configuration + problem_size, f"match_op", compilation_info ), " ", ) @@ -288,11 +289,11 @@ def apply_params( self, problem_size: ProblemSize, template: list[str], - configuration: Configuration, + compilation_info: iree_codegen.CompilationInfoAttr, ) -> MLIRTransformation: if self.is_broadcast_rhs_mmt(template): return self.apply_params_broadcast_rhs_mmt( - problem_size, template, configuration + problem_size, template, compilation_info ) # TODO: Generate transform function. 
@@ -300,7 +301,7 @@ def apply_params( template, apply_configuration( template, - configuration, + compilation_info, ), "", ) @@ -311,7 +312,7 @@ def get_transform_function_batch_mmt( self, problem_size: ProblemSize, functionName: str, - configuration: Configuration, + compilation_info: iree_codegen.CompilationInfoAttr, ) -> str: return f""" transform.named_sequence @{functionName}(%generic: !transform.any_op {{transform.readonly}}) -> (!transform.any_op, !transform.any_param) {{ @@ -320,10 +321,7 @@ def get_transform_function_batch_mmt( %rhs = transform.get_operand %generic[1] : (!transform.any_op) -> !transform.any_value transform.iree.match.cast_compatible_type %lhs = tensor<{problem_size.lhs_type}> : !transform.any_value transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value -%config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = {configuration.lowering_config}, - translation_info = {configuration.translation_info} - > -> !transform.any_param +%config = transform.param.constant {compilation_info} -> !transform.any_param transform.yield %generic, %config : !transform.any_op, !transform.any_param }} """ @@ -332,24 +330,24 @@ def apply_params( self, problem_size: ProblemSize, template: list[str], - configuration: Configuration, + compilation_info: iree_codegen.CompilationInfoAttr, ) -> MLIRTransformation: M, N, K = problem_size.MNK B = problem_size.matmul_size.B modified = indent( self.get_transform_function_batch_mmt( - problem_size, f"match_batch_mmt_{B}x{M}x{N}x{K}", configuration + problem_size, f"match_batch_mmt_{B}x{M}x{N}x{K}", compilation_info ), "// ", ) modified += apply_configuration( template, - configuration, + compilation_info, ) embeddable = indent( self.get_transform_function_batch_mmt( - problem_size, f"match_op", configuration + problem_size, f"match_op", compilation_info ), " ", ) @@ -362,7 +360,7 @@ def get_transform_function_batch_matmul( problem_size: ProblemSize, tile_dims: str, functionName: str, - configuration: Configuration, + compilation_info: iree_codegen.CompilationInfoAttr, ) -> str: input0 = f"tensor<{problem_size.lhs_type}>" input1 = f"tensor<{problem_size.rhs_type}>" @@ -377,10 +375,7 @@ def get_transform_function_batch_matmul( ins(%lhs, %rhs : {input0}, {input1}) outs(%out : {output}) -> {output} }} : (!transform.any_op) -> (!transform.any_value, !transform.any_value) - %config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = {configuration.lowering_config}, - translation_info = {configuration.translation_info} - > -> !transform.any_param + %config = transform.param.constant {compilation_info} -> !transform.any_param transform.yield %batch_matmul, %config : !transform.any_op, !transform.any_param }} """ @@ -389,7 +384,7 @@ def apply_params( self, problem_size: ProblemSize, template: list[str], - configuration: Configuration, + compilation_info: iree_codegen.CompilationInfoAttr, ) -> MLIRTransformation: M, N, K = problem_size.MNK modified = indent( @@ -397,18 +392,18 @@ def apply_params( problem_size, self.tile_dims, f"match_batch_matmul_{problem_size.matmul_size.B}x{M}x{N}x{K}", - configuration, + compilation_info, ), "// ", ) modified += apply_configuration( template, - configuration, + compilation_info, ) embeddable = indent( self.get_transform_function_batch_matmul( - problem_size, self.tile_dims, f"match_op", configuration + problem_size, self.tile_dims, f"match_op", compilation_info ), " ", ) diff --git a/tuner/tuner/candidate_gen_test.py 
b/tuner/tuner/candidate_gen_test.py index 7f104bcd9..0428ab7d2 100644 --- a/tuner/tuner/candidate_gen_test.py +++ b/tuner/tuner/candidate_gen_test.py @@ -65,9 +65,8 @@ def test_apply_params_mmt(tuner_ctx: common.TunerContext) -> None: translation_info = iree_codegen.TranslationInfoAttr.get( pipeline_attr, None, [16, 16, 1], 16, config_dict ) - config = common.Configuration( - translation_info=translation_info, - lowering_config=lowering_config, + compilation_info = iree_codegen.CompilationInfoAttr.get( + lowering_config, translation_info ) problem_size = common.ProblemSize( @@ -77,7 +76,9 @@ def test_apply_params_mmt(tuner_ctx: common.TunerContext) -> None: common.ShapedType([M, N], tuner_ctx.type.f32), common.DispatchKind.mmt, ) - tf_mlir = candidate_gen.MmtTuner().apply_params(problem_size, mlir_template, config) + tf_mlir = candidate_gen.MmtTuner().apply_params( + problem_size, mlir_template, compilation_info + ) modified = tf_mlir.modified embeddable = tf_mlir.embeddable @@ -134,9 +135,8 @@ def test_apply_params_conv(tuner_ctx: common.TunerContext) -> None: translation_info = iree_codegen.TranslationInfoAttr.get( pipeline_attr, None, [256, 1, 1], 64, config_dict ) - config = common.Configuration( - translation_info=translation_info, - lowering_config=lowering_config, + compilation_info = iree_codegen.CompilationInfoAttr.get( + lowering_config, translation_info ) problem_size = common.ProblemSize( @@ -147,7 +147,7 @@ def test_apply_params_conv(tuner_ctx: common.TunerContext) -> None: common.DispatchKind.conv, ) tf_mlir = candidate_gen.ConvTuner().apply_params( - problem_size, mlir_template, config + problem_size, mlir_template, compilation_info ) modified = tf_mlir.modified @@ -208,13 +208,12 @@ def test_apply_params_contract(tuner_ctx: common.TunerContext) -> None: translation_info = iree_codegen.TranslationInfoAttr.get( pipeline_attr, None, [256, 1, 1], 64, config_dict ) - config = common.Configuration( - translation_info=translation_info, - lowering_config=lowering_config, + compilation_info = iree_codegen.CompilationInfoAttr.get( + lowering_config, translation_info ) tf_mlir = candidate_gen.ContractionTuner("mk", "nk", tile_dims).apply_params( - problem_size, mlir_template, config + problem_size, mlir_template, compilation_info ) new_mlir = tf_mlir.modified @@ -268,13 +267,12 @@ def test_apply_params_batch_matmul(tuner_ctx: common.TunerContext) -> None: translation_info = iree_codegen.TranslationInfoAttr.get( pipeline_attr, None, [128, 2, 1], 64, config_dict ) - config = common.Configuration( - translation_info=translation_info, - lowering_config=lowering_config, + compilation_info = iree_codegen.CompilationInfoAttr.get( + lowering_config, translation_info ) tf_mlir = candidate_gen.BatchMatmulTuner("mk", "nk", tile_dims).apply_params( - problem_size, mlir_template, config + problem_size, mlir_template, compilation_info ) modified = tf_mlir.modified @@ -331,13 +329,12 @@ def test_apply_params_batch_mmt_float(tuner_ctx: common.TunerContext) -> None: translation_info = iree_codegen.TranslationInfoAttr.get( pipeline_attr, None, [128, 2, 1], 64, config_dict ) - config = common.Configuration( - translation_info=translation_info, - lowering_config=lowering_config, + compilation_info = iree_codegen.CompilationInfoAttr.get( + lowering_config, translation_info ) tf_mlir = candidate_gen.BatchMmtTuner().apply_params( - problem_size, mlir_template, config + problem_size, mlir_template, compilation_info ) modified = tf_mlir.modified @@ -392,13 +389,12 @@ def 
test_apply_params_batch_mmt_int(tuner_ctx: common.TunerContext) -> None: translation_info = iree_codegen.TranslationInfoAttr.get( pipeline_attr, None, [128, 2, 1], 64, config_dict ) - config = common.Configuration( - translation_info=translation_info, - lowering_config=lowering_config, + compilation_info = iree_codegen.CompilationInfoAttr.get( + lowering_config, translation_info ) tf_mlir = candidate_gen.BatchMmtTuner().apply_params( - problem_size, mlir_template, config + problem_size, mlir_template, compilation_info ) modified = tf_mlir.modified @@ -477,14 +473,13 @@ def test_apply_params_broadcast_rhs_mmt(tuner_ctx: common.TunerContext) -> None: translation_info = iree_codegen.TranslationInfoAttr.get( pipeline_attr, None, [128, 2, 1], 64, config_dict ) - config = common.Configuration( - translation_info=translation_info, - lowering_config=lowering_config, + compilation_info = iree_codegen.CompilationInfoAttr.get( + lowering_config, translation_info ) tf_mlir = candidate_gen.ContractionTuner( "mk", "nk", "mnk" - ).apply_params_broadcast_rhs_mmt(problem_size, mlir_template, config) + ).apply_params_broadcast_rhs_mmt(problem_size, mlir_template, compilation_info) modified = tf_mlir.modified embeddable = tf_mlir.embeddable diff --git a/tuner/tuner/common.py b/tuner/tuner/common.py index c683c5bdc..5c79bd8dd 100644 --- a/tuner/tuner/common.py +++ b/tuner/tuner/common.py @@ -111,12 +111,6 @@ def is_comptible(mma_intrinsic: iree_gpu.MMAIntrinsic) -> bool: return list(filter(is_comptible, mma_intrinsics)) -@dataclass -class Configuration: - translation_info: iree_codegen.TranslationInfoAttr - lowering_config: iree_gpu.LoweringConfigAttr - - # The key name for GPUPipelineOptionsAttr in the translation info config dictionary. GPU_PIPELINE_OPTIONS_KEY = "gpu_pipeline_options" # The key name for llvm_func_attrs attribute in the translation info config dictionary. 
diff --git a/tuner/tuner/common_test.py b/tuner/tuner/common_test.py index af1e1bf9a..6157bb355 100644 --- a/tuner/tuner/common_test.py +++ b/tuner/tuner/common_test.py @@ -93,12 +93,11 @@ def test_get_pipeline_config(tuner_ctx: common.TunerContext) -> None: translation_info = iree_codegen.TranslationInfoAttr.get( pipeline_attr, None, [16, 16, 1], 32, config_dict ) - config = common.Configuration( - translation_info=translation_info, - lowering_config=lowering_config, + compilation_info = iree_codegen.CompilationInfoAttr.get( + lowering_config, translation_info ) config1_str: str = str( - config.translation_info.configuration[common.LLVM_FUNC_ATTRS_KEY] + compilation_info.translation_info.configuration[common.LLVM_FUNC_ATTRS_KEY] ) assert config1_str == '{"amdgpu-waves-per-eu" = "2"}' @@ -107,11 +106,10 @@ def test_get_pipeline_config(tuner_ctx: common.TunerContext) -> None: translation_info = iree_codegen.TranslationInfoAttr.get( pipeline_attr, None, [16, 16, 1], 32, config_dict ) - config = common.Configuration( - translation_info=translation_info, - lowering_config=lowering_config, + compilation_info = iree_codegen.CompilationInfoAttr.get( + lowering_config, translation_info ) - config2_str: str = str(config.translation_info.configuration) + config2_str: str = str(compilation_info.translation_info.configuration) assert ( config2_str == '{gpu_pipeline_options = #iree_gpu.pipeline_options, llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}}' @@ -225,10 +223,9 @@ def test_get_lowering_config(tuner_ctx: common.TunerContext) -> None: translation_info = iree_codegen.TranslationInfoAttr.get( pipeline_attr, None, [16, 16, 1], 32, config_dict ) - config = common.Configuration( - translation_info=translation_info, - lowering_config=lowering_config, + compilation_info = iree_codegen.CompilationInfoAttr.get( + lowering_config, translation_info ) - assert config.lowering_config.mma_kind is None - assert config.lowering_config.subgroup_count_mn == (1, 1) + assert compilation_info.lowering_config.mma_kind is None + assert compilation_info.lowering_config.subgroup_count_mn == (1, 1) diff --git a/tuner/tuner/dispatch_constraints.py b/tuner/tuner/dispatch_constraints.py index 8ba202310..797c83534 100644 --- a/tuner/tuner/dispatch_constraints.py +++ b/tuner/tuner/dispatch_constraints.py @@ -178,7 +178,7 @@ def generate_solutions( problem_size: ProblemSize, num_subgrups: int, mma_intrinsics: list[iree_gpu.MMAIntrinsic], -) -> Iterator[Configuration]: +) -> Iterator[iree_codegen.CompilationInfoAttr]: M, N, K = problem_size.MNK tuner_ctx.logger.info(f"{M},{N},{K}") m, n, k = z3.Int("m"), z3.Int("n"), z3.Int("k") @@ -258,7 +258,9 @@ def generate_solutions( lookup(subgroup_size), config_dict, ) - config = Configuration(translation_info, lowering_config) + compilation_info = iree_codegen.CompilationInfoAttr.get( + lowering_config, translation_info + ) solver.add(z3.simplify(z3.Not(z3.And(list(x == model[x] for x in all_vars))))) i += 1 - yield config + yield compilation_info diff --git a/tuner/tuner/dispatch_parser.py b/tuner/tuner/dispatch_parser.py index cc63c89a3..fe95c52a6 100644 --- a/tuner/tuner/dispatch_parser.py +++ b/tuner/tuner/dispatch_parser.py @@ -21,7 +21,7 @@ def parse_tensor_type(tensor_type: str) -> ShapedType: def get_contract_workgroup_sizes( - configuration: Configuration, tile_dims: str + configuration: iree_codegen.CompilationInfoAttr, tile_dims: str ) -> list[int]: m, n, _k = configuration.lowering_config.workgroup_tile_sizes @@ -38,7 +38,7 @@ def get_contract_workgroup_sizes( def 
get_contract_reduction_sizes( - configuration: Configuration, tile_dims: str + configuration: iree_codegen.CompilationInfoAttr, tile_dims: str ) -> list[int]: _m, _n, k = configuration.lowering_config.reduction_tile_sizes reduction_size = [0] * len(tile_dims) diff --git a/tuner/tuner/dispatch_parser_test.py b/tuner/tuner/dispatch_parser_test.py index a63576808..9f4afbb19 100644 --- a/tuner/tuner/dispatch_parser_test.py +++ b/tuner/tuner/dispatch_parser_test.py @@ -59,11 +59,10 @@ def test_get_mmt_tile_sizes(tuner_ctx: common.TunerContext) -> None: translation_info = iree_codegen.TranslationInfoAttr.get( pipeline_attr, None, [], 0, config_dict ) - config = common.Configuration( - translation_info=translation_info, - lowering_config=lowering_config, + compilation_info = iree_codegen.CompilationInfoAttr.get( + lowering_config, translation_info ) - lowering_config = config.lowering_config + lowering_config = compilation_info.lowering_config assert lowering_config.workgroup_tile_sizes == [128, 320, 0] assert lowering_config.reduction_tile_sizes == [0, 0, 32] @@ -87,12 +86,27 @@ def test_get_conv_tile_sizes(tuner_ctx: common.TunerContext) -> None: translation_info = iree_codegen.TranslationInfoAttr.get( pipeline_attr, None, [256, 1, 1], 64, config_dict ) - config = common.Configuration( - translation_info=translation_info, - lowering_config=lowering_config, + compilation_info = iree_codegen.CompilationInfoAttr.get( + lowering_config, translation_info ) - assert config.lowering_config.workgroup_tile_sizes == [1, 1, 464, 320, 1, 1, 0] - assert config.lowering_config.reduction_tile_sizes == [0, 0, 0, 0, 0, 0, 16] + assert compilation_info.lowering_config.workgroup_tile_sizes == [ + 1, + 1, + 464, + 320, + 1, + 1, + 0, + ] + assert compilation_info.lowering_config.reduction_tile_sizes == [ + 0, + 0, + 0, + 0, + 0, + 0, + 16, + ] def test_get_contract_tile_sizes(tuner_ctx: common.TunerContext) -> None: @@ -114,18 +128,49 @@ def test_get_contract_tile_sizes(tuner_ctx: common.TunerContext) -> None: translation_info = iree_codegen.TranslationInfoAttr.get( pipeline_attr, None, [16, 16, 1], 32, config_dict ) - config = common.Configuration( - translation_info=translation_info, - lowering_config=lowering_config, - ) - assert dispatch_parser.get_contract_workgroup_sizes(config, "mnk") == [4, 8, 0] - assert dispatch_parser.get_contract_reduction_sizes(config, "mnk") == [0, 0, 16] - assert dispatch_parser.get_contract_workgroup_sizes(config, "nmk") == [8, 4, 0] - assert dispatch_parser.get_contract_reduction_sizes(config, "nmk") == [0, 0, 16] - assert dispatch_parser.get_contract_workgroup_sizes(config, "knm") == [0, 8, 4] - assert dispatch_parser.get_contract_reduction_sizes(config, "knm") == [16, 0, 0] - assert dispatch_parser.get_contract_workgroup_sizes(config, "kkk") == [0, 0, 0] - assert dispatch_parser.get_contract_reduction_sizes(config, "kkk") == [16, 16, 16] + compilation_info = iree_codegen.CompilationInfoAttr.get( + lowering_config, translation_info + ) + assert dispatch_parser.get_contract_workgroup_sizes(compilation_info, "mnk") == [ + 4, + 8, + 0, + ] + assert dispatch_parser.get_contract_reduction_sizes(compilation_info, "mnk") == [ + 0, + 0, + 16, + ] + assert dispatch_parser.get_contract_workgroup_sizes(compilation_info, "nmk") == [ + 8, + 4, + 0, + ] + assert dispatch_parser.get_contract_reduction_sizes(compilation_info, "nmk") == [ + 0, + 0, + 16, + ] + assert dispatch_parser.get_contract_workgroup_sizes(compilation_info, "knm") == [ + 0, + 8, + 4, + ] + assert 
dispatch_parser.get_contract_reduction_sizes(compilation_info, "knm") == [ + 16, + 0, + 0, + ] + assert dispatch_parser.get_contract_workgroup_sizes(compilation_info, "kkk") == [ + 0, + 0, + 0, + ] + assert dispatch_parser.get_contract_reduction_sizes(compilation_info, "kkk") == [ + 16, + 16, + 16, + ] def test_get_shapes_mmt(tuner_ctx: common.TunerContext) -> None: diff --git a/tuner/tuner/libtuner.py b/tuner/tuner/libtuner.py index 3aa932dc4..3c195520c 100644 --- a/tuner/tuner/libtuner.py +++ b/tuner/tuner/libtuner.py @@ -64,7 +64,7 @@ class CandidateTracker: candidate_id: int dispatch_mlir_path: Optional[Path] = None dispatch_config_path: Optional[Path] = None - configuration: Optional[candidate_gen.Configuration] = None + configuration: Optional[candidate_gen.iree_codegen.CompilationInfoAttr] = None compilation_successful: Optional[bool] = None compiled_dispatch_path: Optional[Path] = None compiled_dispatch_hash: Optional[str] = None
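
For reference, below is a minimal sketch of the new construction pattern, using only calls that appear in this patch. The inputs `lowering_config` (an `#iree_gpu.lowering_config` attribute), `pipeline_attr`, and `config_dict` are placeholders: the tests build them earlier, outside the hunks shown here, so this illustrates the shape of the API rather than a standalone runnable test.

    from iree.compiler.dialects import iree_codegen

    def build_compilation_info(lowering_config, pipeline_attr, config_dict):
        # Translation info: pipeline, codegen spec (None here), workgroup size,
        # subgroup size, and the config dictionary, as in the updated tests.
        translation_info = iree_codegen.TranslationInfoAttr.get(
            pipeline_attr, None, [16, 16, 1], 16, config_dict
        )
        # One IREE attribute now carries both halves that the retired
        # common.Configuration dataclass used to hold.
        return iree_codegen.CompilationInfoAttr.get(lowering_config, translation_info)

Downstream consumers read the same data directly off the attribute (e.g. `compilation_info.lowering_config.workgroup_tile_sizes`, `compilation_info.translation_info.subgroup_size`), and the transform-dialect templates now interpolate the whole attribute in a single line: `%config = transform.param.constant {compilation_info} -> !transform.any_param`.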