From 70017c3d0de2d81a30b107990c78c28f01dcecfb Mon Sep 17 00:00:00 2001 From: Bangtian Liu Date: Tue, 3 Dec 2024 00:55:06 -0500 Subject: [PATCH] [tuner]: use lowering config binding (#629) This PR is relevant to the task in https://github.com/nod-ai/shark-ai/issues/453 : use IREE bindings for compilation info (incl., lowering_config and translation_info). Remove data class `ReorderWorkgroupsStrategy`, and use lowering_config binding. --------- Signed-off-by: Bangtian Liu --- tuner/tuner/candidate_gen.py | 133 ++++++++++++++-------- tuner/tuner/candidate_gen_test.py | 134 +++++++++++++++-------- tuner/tuner/common.py | 91 ++++++++++++--- tuner/tuner/common_test.py | 42 ++++++- tuner/tuner/dispatch_constraints.py | 42 ++++--- tuner/tuner/dispatch_constraints_test.py | 2 +- tuner/tuner/dispatch_parser.py | 68 ++++++++---- tuner/tuner/dispatch_parser_test.py | 69 ++++++++---- 8 files changed, 408 insertions(+), 173 deletions(-) diff --git a/tuner/tuner/candidate_gen.py b/tuner/tuner/candidate_gen.py index f09e08888..c903ec85f 100644 --- a/tuner/tuner/candidate_gen.py +++ b/tuner/tuner/candidate_gen.py @@ -38,10 +38,16 @@ tune_logger = logging.getLogger("tune") - +# TODO: remove the argument 'workgroup_sizes' and 'reduction_sizes'. def apply_configuration( - template: list[str], configuration: Configuration, tile_sizes: list[int] + template: list[str], + configuration: Configuration, + workgroup_sizes: list[int], + reduction_sizes: list[int], ) -> str: + intrinsic = get_intrinsic(configuration) + subgroup_m_count = get_subgroup_m_count(configuration) + subgroup_n_count = get_subgroup_n_count(configuration) tune_logger.info(f"Applying: {configuration}") expr0 = re.compile( r", subgroup_m_count = ([0-9]+), subgroup_n_count = ([0-9]+)>" @@ -49,14 +55,16 @@ def apply_configuration( expr1 = re.compile( r"LLVMGPUVectorDistribute workgroup_size = \[.+\] subgroup_size = ([0-9]+)," ) - expr2 = re.compile(r"tile_sizes = \[\[([0-9]+)(, ([0-9]+))+\]\]") - expr3 = re.compile(r"gpu_pipeline_options = #iree_gpu\.pipeline_options<([^>]*)>") - expr4 = re.compile(r"\"amdgpu-waves-per-eu\" = \"([0-9])\"") - repl0 = f"" + expr2 = re.compile(r"workgroup = \[([0-9]+)(, ([0-9]+))+\]") + expr3 = re.compile(r"reduction = \[([0-9]+)(, ([0-9]+))+\]") + expr4 = re.compile(r"gpu_pipeline_options = #iree_gpu\.pipeline_options<([^>]*)>") + expr5 = re.compile(r"\"amdgpu-waves-per-eu\" = \"([0-9])\"") + repl0 = f"" repl1 = f'LLVMGPUVectorDistribute workgroup_size = [{", ".join(map(str, configuration.workgroup_size))}] subgroup_size = {configuration.subgroup_size},' - repl2 = f'tile_sizes = [[{", ".join(map(str, tile_sizes))}]]' - repl3 = f"gpu_pipeline_options = {configuration.gpu_pipeline_options}" - repl4 = f'"amdgpu-waves-per-eu" = "{configuration.waves_per_eu}"' + repl2 = f"workgroup = {workgroup_sizes}" + repl3 = f"reduction = {reduction_sizes}" + repl4 = f"gpu_pipeline_options = {configuration.gpu_pipeline_options}" + repl5 = f'"amdgpu-waves-per-eu" = "{configuration.waves_per_eu}"' new_mlir = "" for line in template: @@ -64,12 +72,14 @@ def apply_configuration( line = re.sub(expr0, repl0, line) if "LLVMGPUVectorDistribute " in line: line = re.sub(expr1, repl1, line) - if "tile_sizes" in line: + if "workgroup" in line: line = re.sub(expr2, repl2, line) - if "gpu_pipeline_options =" in line: + if "reduction" in line: line = re.sub(expr3, repl3, line) - if "amdgpu-waves-per-eu" in line: + if "gpu_pipeline_options =" in line: line = re.sub(expr4, repl4, line) + if "amdgpu-waves-per-eu" in line: + line = re.sub(expr5, repl5, line) new_mlir += line return new_mlir @@ -115,7 +125,9 @@ class MmtTuner(DispatchTuner, MmtParser): def get_transform_function_mmt( self, problem_size: ProblemSize, functionName: str, configuration: Configuration ) -> str: - tile_sizes = ", ".join(map(str, get_mmt_tile_sizes(configuration))) + intrinsic = get_intrinsic(configuration) + subgroup_m_count = get_subgroup_m_count(configuration) + subgroup_n_count = get_subgroup_n_count(configuration) wg_x, wg_y, wg_z = configuration.workgroup_size extra_config = get_pipeline_config(configuration) @@ -127,12 +139,12 @@ def get_transform_function_mmt( transform.iree.match.cast_compatible_type %lhs = tensor<{problem_size.lhs_type}> : !transform.any_value transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value %config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = #iree_codegen.lowering_config, + lowering_config = {configuration.lowering_config}>, translation_info = #iree_codegen.translation_info + intrinsic = {intrinsic}, + subgroup_m_count = {subgroup_m_count}, subgroup_n_count = {subgroup_n_count}> {extra_config}}}> > -> !transform.any_param transform.yield %matmul, %config : !transform.any_op, !transform.any_param @@ -153,7 +165,10 @@ def apply_params( "// ", ) modified += apply_configuration( - template, configuration, get_mmt_tile_sizes(configuration) + template, + configuration, + get_mmt_workgroup_sizes(configuration), + get_mmt_reduction_sizes(configuration), ) embeddable = indent( self.get_transform_function_mmt(problem_size, f"match_op", configuration), @@ -163,13 +178,6 @@ def apply_params( class ConvTuner(DispatchTuner, ConvParser): - # int64_t n = outputShape[0]; - # int64_t oh = outputShape[1]; - # int64_t ow = outputShape[2]; - # int64_t oc = outputShape[3]; - # int64_t fh = filterShape[0]; - # int64_t fw = filterShape[1]; - # int64_t ic = filterShape[2]; def get_transform_function_conv( self, problem_size: ProblemSize, functionName: str, configuration: Configuration ) -> str: @@ -185,7 +193,15 @@ def get_transform_function_conv( filter = f"tensor<{problem_size.rhs_type}>" output = f"tensor<{dynamic_batch_output_ty}>" - tile_sizes = ", ".join(map(str, self.get_conv_tile_sizes(configuration))) + workgroup_sizes = ", ".join( + map(str, self.get_conv_workgroup_sizes(configuration)) + ) + reduction_sizes = ", ".join( + map(str, self.get_conv_reduction_sizes(configuration)) + ) + intrinsic = get_intrinsic(configuration) + subgroup_m_count = get_subgroup_m_count(configuration) + subgroup_n_count = get_subgroup_n_count(configuration) wg_x, wg_y, wg_z = configuration.workgroup_size extra_config = get_pipeline_config(configuration) @@ -200,12 +216,12 @@ def get_transform_function_conv( outs(%out : {output}) -> {output} }} : (!transform.any_op) -> (!transform.any_value, !transform.any_value) %config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = #iree_codegen.lowering_config, + lowering_config = {configuration.lowering_config}>, translation_info = #iree_codegen.translation_info + intrinsic = {intrinsic}, + subgroup_m_count = {subgroup_m_count}, subgroup_n_count = {subgroup_n_count}> {extra_config}}}> > -> !transform.any_param transform.yield %conv, %config : !transform.any_op, !transform.any_param @@ -228,7 +244,10 @@ def apply_params( "// ", ) modified += apply_configuration( - template, configuration, self.get_conv_tile_sizes(configuration) + template, + configuration, + self.get_conv_workgroup_sizes(configuration), + self.get_conv_reduction_sizes(configuration), ) embeddable = indent( self.get_transform_function_conv(problem_size, f"match_op", configuration), @@ -244,7 +263,15 @@ def get_transform_function_broadcast_rhs_mmt( functionName: str, configuration: Configuration, ) -> str: - tile_sizes = ", ".join(map(str, get_batch_mmt_tile_sizes(configuration))) + workgroup_sizes = ", ".join( + map(str, get_batch_mmt_workgroup_sizes(configuration)) + ) + reduction_sizes = ", ".join( + map(str, get_batch_mmt_reduction_sizes(configuration)) + ) + intrinsic = get_intrinsic(configuration) + subgroup_m_count = get_subgroup_m_count(configuration) + subgroup_n_count = get_subgroup_n_count(configuration) wg_x, wg_y, wg_z = configuration.workgroup_size extra_config = get_pipeline_config(configuration) @@ -261,12 +288,12 @@ def get_transform_function_broadcast_rhs_mmt( transform.iree.match.cast_compatible_type %lhs = tensor<{lhs_dynamic_batch}> : !transform.any_value transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value %config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = #iree_codegen.lowering_config, + lowering_config = {configuration.lowering_config}>, translation_info = #iree_codegen.translation_info + intrinsic = {intrinsic}, + subgroup_m_count = {subgroup_m_count}, subgroup_n_count = {subgroup_n_count}> {extra_config}}}> > -> !transform.any_param transform.yield %generic, %config : !transform.any_op, !transform.any_param @@ -287,7 +314,10 @@ def apply_params_broadcast_rhs_mmt( "// ", ) modified += apply_configuration( - template, configuration, get_batch_mmt_tile_sizes(configuration) + template, + configuration, + get_batch_mmt_workgroup_sizes(configuration), + get_batch_mmt_reduction_sizes(configuration), ) embeddable = indent( @@ -315,7 +345,8 @@ def apply_params( apply_configuration( template, configuration, - get_contract_tile_sizes(configuration, self.tile_dims), + get_contract_workgroup_sizes(configuration, self.tile_dims), + get_contract_reduction_sizes(configuration, self.tile_dims), ), "", ) @@ -328,7 +359,9 @@ def get_transform_function_batch_mmt( functionName: str, configuration: Configuration, ) -> str: - tile_sizes = ", ".join(map(str, get_batch_mmt_tile_sizes(configuration))) + intrinsic = get_intrinsic(configuration) + subgroup_m_count = get_subgroup_m_count(configuration) + subgroup_n_count = get_subgroup_n_count(configuration) wg_x, wg_y, wg_z = configuration.workgroup_size extra_config = get_pipeline_config(configuration) @@ -341,12 +374,12 @@ def get_transform_function_batch_mmt( transform.iree.match.cast_compatible_type %lhs = tensor<{problem_size.lhs_type}> : !transform.any_value transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value %config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = #iree_codegen.lowering_config, + lowering_config = {configuration.lowering_config}>, translation_info = #iree_codegen.translation_info + intrinsic = {intrinsic}, + subgroup_m_count = {subgroup_m_count}, subgroup_n_count = {subgroup_n_count}> {extra_config}}}> > -> !transform.any_param transform.yield %generic, %config : !transform.any_op, !transform.any_param @@ -368,7 +401,10 @@ def apply_params( "// ", ) modified += apply_configuration( - template, configuration, get_batch_mmt_tile_sizes(configuration) + template, + configuration, + get_batch_mmt_workgroup_sizes(configuration), + get_batch_mmt_reduction_sizes(configuration), ) embeddable = indent( @@ -392,9 +428,9 @@ def get_transform_function_batch_matmul( input1 = f"tensor<{problem_size.rhs_type}>" output = f"tensor<{problem_size.res_type}>" - tile_sizes = ", ".join( - map(str, get_contract_tile_sizes(configuration, tile_dims)) - ) + intrinsic = get_intrinsic(configuration) + subgroup_m_count = get_subgroup_m_count(configuration) + subgroup_n_count = get_subgroup_n_count(configuration) wg_x, wg_y, wg_z = configuration.workgroup_size extra_config = get_pipeline_config(configuration) @@ -409,12 +445,12 @@ def get_transform_function_batch_matmul( outs(%out : {output}) -> {output} }} : (!transform.any_op) -> (!transform.any_value, !transform.any_value) %config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = #iree_codegen.lowering_config, + lowering_config = {configuration.lowering_config}>, translation_info = #iree_codegen.translation_info + intrinsic = {intrinsic}, + subgroup_m_count = {subgroup_m_count}, subgroup_n_count = {subgroup_n_count}> {extra_config}}}> > -> !transform.any_param transform.yield %batch_matmul, %config : !transform.any_op, !transform.any_param @@ -440,7 +476,8 @@ def apply_params( modified += apply_configuration( template, configuration, - get_contract_tile_sizes(configuration, self.tile_dims), + get_contract_workgroup_sizes(configuration, self.tile_dims), + get_contract_reduction_sizes(configuration, self.tile_dims), ) embeddable = indent( @@ -548,7 +585,7 @@ def tune( tune_logger.debug(str(problem_size)) configs = [] for i, config in enumerate( - generate_solutions(tune_logger, problem_size, num_subgroups, mma_list) + generate_solutions(tuner_context, problem_size, num_subgroups, mma_list) ): if i >= limit: break diff --git a/tuner/tuner/candidate_gen_test.py b/tuner/tuner/candidate_gen_test.py index 19b6e1fe7..11de8a900 100644 --- a/tuner/tuner/candidate_gen_test.py +++ b/tuner/tuner/candidate_gen_test.py @@ -39,7 +39,7 @@ def test_apply_params_mmt(tuner_ctx: common.TunerContext) -> None: mlir_template = [ ", subgroup_m_count = 16, subgroup_n_count = 16>", "", + "", "gpu_pipeline_options = #iree_gpu.pipeline_options", '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}', ] @@ -48,13 +48,18 @@ def test_apply_params_mmt(tuner_ctx: common.TunerContext) -> None: mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16 mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) + lowering_config = common.get_lowering_config( + tuner_ctx=tuner_ctx, + mma_kind=mma_attr, + workgroup=[8, 8, 0], + reduction=[0, 0, 8], + subgroup_m_count=16, + subgroup_n_count=16, + ) config = common.Configuration( subgroup_size=16, workgroup_size=[16, 16, 1], - intrinsic=mma_attr, - tile_sizes=[8, 8, 8], - subgroup_m_count=16, - subgroup_n_count=16, + lowering_config=lowering_config, gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get( prefetch_shared_memory=True ), @@ -84,7 +89,8 @@ def test_apply_params_mmt(tuner_ctx: common.TunerContext) -> None: "LLVMGPUVectorDistribute workgroup_size = [16, 16, 1] subgroup_size = 16" in modified ) - assert "tile_sizes = [[8, 8, 8]]" in modified + assert "workgroup = [8, 8, 0]" in modified + assert "reduction = [0, 0, 8]" in modified assert ( "gpu_pipeline_options = #iree_gpu.pipeline_options" in modified @@ -96,7 +102,7 @@ def test_apply_params_conv(tuner_ctx: common.TunerContext) -> None: mlir_template = [ ", subgroup_m_count = 16, subgroup_n_count = 16>", "", + "", 'gpu_pipeline_options = #iree_gpu.pipeline_options, {llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}', ] @@ -104,13 +110,18 @@ def test_apply_params_conv(tuner_ctx: common.TunerContext) -> None: mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16 mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) + lowering_config = common.get_lowering_config( + tuner_ctx=tuner_ctx, + mma_kind=mma_attr, + workgroup=[464, 320, 0], + reduction=[0, 0, 16], + subgroup_m_count=1, + subgroup_n_count=4, + ) config = common.Configuration( subgroup_size=64, workgroup_size=[256, 1, 1], - intrinsic=mma_attr, - tile_sizes=[464, 320, 16], - subgroup_m_count=1, - subgroup_n_count=4, + lowering_config=lowering_config, gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get( reorder_workgroups_strategy=iree_gpu.ReorderWorkgroupsStrategyAttr.get( iree_gpu.ReorderWorkgroupsStrategy.Transpose @@ -144,7 +155,8 @@ def test_apply_params_conv(tuner_ctx: common.TunerContext) -> None: "LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64" in modified ) - assert "tile_sizes = [[1, 1, 464, 320, 1, 1, 16]]" in modified + assert "workgroup = [1, 1, 464, 320, 1, 1, 0]" in modified + assert "reduction = [0, 0, 0, 0, 0, 0, 16]" in modified assert ( "gpu_pipeline_options = #iree_gpu.pipeline_options>" in modified @@ -156,7 +168,7 @@ def test_apply_params_contract(tuner_ctx: common.TunerContext) -> None: mlir_template = [ ", subgroup_m_count = 2, subgroup_n_count = 2>}>", "", + "", '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}', ] @@ -171,13 +183,18 @@ def test_apply_params_contract(tuner_ctx: common.TunerContext) -> None: mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16 mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) + lowering_config = common.get_lowering_config( + tuner_ctx=tuner_ctx, + mma_kind=mma_attr, + workgroup=[480, 384, 0], + reduction=[0, 0, 32], + subgroup_m_count=1, + subgroup_n_count=4, + ) config = common.Configuration( subgroup_size=64, workgroup_size=[256, 1, 1], - intrinsic=mma_attr, - tile_sizes=[480, 384, 32], - subgroup_m_count=1, - subgroup_n_count=4, + lowering_config=lowering_config, gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(), waves_per_eu=2, ) @@ -197,7 +214,8 @@ def test_apply_params_contract(tuner_ctx: common.TunerContext) -> None: "LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64" in new_mlir ) - assert "tile_sizes = [[1, 480, 384, 32]]" in new_mlir + assert "workgroup = [1, 480, 384, 0]" in new_mlir + assert "reduction = [0, 0, 0, 32]" in new_mlir assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}' in new_mlir @@ -205,7 +223,7 @@ def test_apply_params_batch_matmul(tuner_ctx: common.TunerContext) -> None: mlir_template = [ ", subgroup_m_count = 4, subgroup_n_count = 1>}>", "", + "", '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}', ] @@ -220,13 +238,18 @@ def test_apply_params_batch_matmul(tuner_ctx: common.TunerContext) -> None: mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16 mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) + lowering_config = common.get_lowering_config( + tuner_ctx=tuner_ctx, + mma_kind=mma_attr, + workgroup=[416, 320, 0], + reduction=[0, 0, 128], + subgroup_m_count=2, + subgroup_n_count=2, + ) config = common.Configuration( subgroup_size=64, workgroup_size=[128, 2, 1], - intrinsic=mma_attr, - tile_sizes=[416, 320, 128], - subgroup_m_count=2, - subgroup_n_count=2, + lowering_config=lowering_config, gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(), waves_per_eu=2, ) @@ -250,7 +273,8 @@ def test_apply_params_batch_matmul(tuner_ctx: common.TunerContext) -> None: "LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64" in modified ) - assert "tile_sizes = [[1, 416, 320, 128]]" in modified + assert "workgroup = [1, 416, 320, 0]" in modified + assert "reduction = [0, 0, 0, 128]" in modified assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}' in modified @@ -258,7 +282,7 @@ def test_apply_params_batch_mmt_float(tuner_ctx: common.TunerContext) -> None: mlir_template = [ ", subgroup_m_count = 4, subgroup_n_count = 1>}>", "", + "", '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}', ] @@ -272,13 +296,18 @@ def test_apply_params_batch_mmt_float(tuner_ctx: common.TunerContext) -> None: mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16 mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) + lowering_config = common.get_lowering_config( + tuner_ctx=tuner_ctx, + mma_kind=mma_attr, + workgroup=[128, 64, 0], + reduction=[0, 0, 128], + subgroup_m_count=2, + subgroup_n_count=2, + ) config = common.Configuration( subgroup_size=64, workgroup_size=[128, 2, 1], - intrinsic=mma_attr, - tile_sizes=[128, 64, 128], - subgroup_m_count=2, - subgroup_n_count=2, + lowering_config=lowering_config, gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(), waves_per_eu=2, ) @@ -300,7 +329,8 @@ def test_apply_params_batch_mmt_float(tuner_ctx: common.TunerContext) -> None: "LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64" in modified ) - assert "tile_sizes = [[1, 128, 64, 128]]" in modified + assert "workgroup = [1, 128, 64, 0]" in modified + assert "reduction = [0, 0, 0, 128]" in modified assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}' in modified @@ -308,7 +338,7 @@ def test_apply_params_batch_mmt_int(tuner_ctx: common.TunerContext) -> None: mlir_template = [ ", subgroup_m_count = 4, subgroup_n_count = 1>}>", "", + "", '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}', ] @@ -322,13 +352,18 @@ def test_apply_params_batch_mmt_int(tuner_ctx: common.TunerContext) -> None: mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8 mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) + lowering_config = common.get_lowering_config( + tuner_ctx=tuner_ctx, + mma_kind=mma_attr, + workgroup=[128, 64, 0], + reduction=[0, 0, 128], + subgroup_m_count=2, + subgroup_n_count=2, + ) config = common.Configuration( subgroup_size=64, workgroup_size=[128, 2, 1], - intrinsic=mma_attr, - tile_sizes=[128, 64, 128], - subgroup_m_count=2, - subgroup_n_count=2, + lowering_config=lowering_config, gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(), waves_per_eu=4, ) @@ -352,7 +387,8 @@ def test_apply_params_batch_mmt_int(tuner_ctx: common.TunerContext) -> None: "LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64" in modified ) - assert "tile_sizes = [[1, 128, 64, 128]]" in modified + assert "workgroup = [1, 128, 64, 0]" in modified + assert "reduction = [0, 0, 0, 128]" in modified assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' in modified assert embeddable @@ -372,7 +408,8 @@ def test_apply_params_batch_mmt_int(tuner_ctx: common.TunerContext) -> None: "%config = transform.param.constant #iree_codegen.compilation_info<" in embeddable ) - assert "tile_sizes = [[1, 128, 64, 128]]" in embeddable + assert "workgroup = [128, 64, 0]" in embeddable + assert "reduction = [0, 0, 128]" in embeddable assert 'llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' in embeddable assert "workgroup_size = [128, 2, 1] subgroup_size = 64" in embeddable @@ -381,7 +418,7 @@ def test_apply_params_broadcast_rhs_mmt(tuner_ctx: common.TunerContext) -> None: mlir_template = [ ", subgroup_m_count = 4, subgroup_n_count = 1>}>", "", + "", '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}', ] @@ -395,13 +432,18 @@ def test_apply_params_broadcast_rhs_mmt(tuner_ctx: common.TunerContext) -> None: mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8 mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) + lowering_config = common.get_lowering_config( + tuner_ctx=tuner_ctx, + mma_kind=mma_attr, + workgroup=[128, 64, 0], + reduction=[0, 0, 128], + subgroup_m_count=2, + subgroup_n_count=2, + ) config = common.Configuration( subgroup_size=64, workgroup_size=[128, 2, 1], - intrinsic=mma_attr, - tile_sizes=[128, 64, 128], - subgroup_m_count=2, - subgroup_n_count=2, + lowering_config=lowering_config, gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(), waves_per_eu=4, ) @@ -428,7 +470,8 @@ def test_apply_params_broadcast_rhs_mmt(tuner_ctx: common.TunerContext) -> None: "LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64" in modified ) - assert "tile_sizes = [[1, 128, 64, 128]]" in modified + assert "workgroup = [1, 128, 64, 0]" in modified + assert "reduction = [0, 0, 0, 128]" in modified assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' in modified assert embeddable @@ -449,7 +492,8 @@ def test_apply_params_broadcast_rhs_mmt(tuner_ctx: common.TunerContext) -> None: "%config = transform.param.constant #iree_codegen.compilation_info<" in embeddable ) - assert "tile_sizes = [[1, 128, 64, 128]]" in embeddable + assert "workgroup = [128, 64, 0]" in embeddable + assert "reduction = [0, 0, 128]" in embeddable assert 'llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' in embeddable assert "workgroup_size = [128, 2, 1] subgroup_size = 64" in embeddable @@ -457,7 +501,7 @@ def test_apply_params_broadcast_rhs_mmt(tuner_ctx: common.TunerContext) -> None: def test_detect_broadcast_rhs_mmt(tuner_ctx: common.TunerContext) -> None: mlir_lines = [ r"%18 = tensor.empty() : tensor<2x1024x10240xi32>", - r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%c0_i32 : i32) outs(%18 : tensor<2x1024x10240xi32>) -> tensor<2x1024x10240xi32>", + r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%c0_i32 : i32) outs(%18 : tensor<2x1024x10240xi32>) -> tensor<2x1024x10240xi32>", r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%11, %12 : tensor<2x1024x1280xi8>, tensor<10240x1280xi8>) outs(%19 : tensor<2x1024x10240xi32>) attrs = {lowering_config = #iree_codegen.lowering_config} {', ] assert candidate_gen.ContractionTuner("mk", "nk", "mnk").is_broadcast_rhs_mmt( diff --git a/tuner/tuner/common.py b/tuner/tuner/common.py index 80c755aa7..702008f5e 100644 --- a/tuner/tuner/common.py +++ b/tuner/tuner/common.py @@ -9,6 +9,7 @@ from dataclasses import astuple, dataclass from enum import Enum from typing import Optional +from typing import Any from iree.compiler import ir # type: ignore @@ -22,6 +23,7 @@ def __init__(self, ctx: ir.Context): self.i8 = ir.IntegerType.get_signless(8, ctx) self.i16 = ir.IntegerType.get_signless(16, ctx) self.i32 = ir.IntegerType.get_signless(32, ctx) + self.i64 = ir.IntegerType.get_signless(64, ctx) self.f8E4M3FNUZ = ir.Float8E4M3FNUZType.get(ctx) self.f8E5M2FNUZ = ir.Float8E5M2FNUZType.get(ctx) @@ -30,6 +32,9 @@ def __init__(self, ctx: ir.Context): self.bf16 = ir.BF16Type.get(ctx) + def getI64(self, value: int) -> ir.IntegerAttr: + return ir.IntegerAttr.get(self.i64, value) + class TunerContext: def __init__(self, mlir_ctx: ir.Context, logger: logging.Logger): @@ -105,27 +110,87 @@ def is_comptible(mma_intrinsic: iree_gpu.MMAIntrinsic) -> bool: return list(filter(is_comptible, mma_intrinsics)) -class ReorderWorkgroupsStrategy(Enum): - NONE = 0 - SWIZZLE = 1 - TRANSPOSE = 2 - - def __str__(self) -> str: - return self.name.title() - - @dataclass class Configuration: subgroup_size: int workgroup_size: list[int] - intrinsic: iree_gpu.MMAAttr - tile_sizes: list[int] - subgroup_m_count: int - subgroup_n_count: int + lowering_config: iree_gpu.LoweringConfigAttr gpu_pipeline_options: iree_gpu.PipelineOptionsAttr waves_per_eu: int +def get_intrinsic(config: Configuration) -> Optional[iree_gpu.MMAAttr]: + if "mma_kind" in config.lowering_config.attributes: + return config.lowering_config.attributes["mma_kind"] + return None + + +def get_workgroup_tile_sizes(config: Configuration) -> list[int]: + if "workgroup" in config.lowering_config.attributes: + workgroup_attrs = config.lowering_config.attributes["workgroup"] + return [attr.value for attr in workgroup_attrs] + return [] + + +def get_reduction_tile_sizes(config: Configuration) -> list[int]: + if "reduction" in config.lowering_config.attributes: + reduction_attrs = config.lowering_config.attributes["reduction"] + return [attr.value for attr in reduction_attrs] + return [] + + +def get_subgroup_m_count(config: Configuration) -> Optional[int]: + if "subgroup_m_count" in config.lowering_config.attributes: + attr = config.lowering_config.attributes["subgroup_m_count"] + return attr.value + return None + + +def get_subgroup_n_count(config: Configuration) -> Optional[int]: + if "subgroup_n_count" in config.lowering_config.attributes: + attr = config.lowering_config.attributes["subgroup_n_count"] + return attr.value + return None + + +def get_lowering_config( + tuner_ctx: TunerContext, + **kwargs: Any, +) -> iree_gpu.LoweringConfigAttr: + lowering_config_dict: dict[str, Any] = {} + for key, value in kwargs.items(): + # A local variable to hold the transformed value. + promoted_value = value + match key: + case "workgroup" | "reduction": + if isinstance(value, list): + promoted_value = ir.ArrayAttr.get( + [tuner_ctx.type.getI64(x) for x in value] + ) + elif not isinstance(value, ir.ArrayAttr): + assert ( + False + ), f"Unsupported type for key '{key}': {type(value).__name__}" + case "subgroup_m_count" | "subgroup_n_count": + if isinstance(value, int): + promoted_value = tuner_ctx.type.getI64(value) + elif not isinstance(value, tuner_ctx.type.i64): + assert ( + False + ), f"Unsupported type for key '{key}': {type(value).__name__}" + case "mma_kind": + if not isinstance(value, iree_gpu.MMAAttr): + assert ( + False + ), f"Unsupported type for key '{key}': {type(value).__name__}" + case _: + assert False, f"Unhandled key in lowering configuration: {key}" + + lowering_config_dict[key] = promoted_value + lowering_config_attrs = ir.DictAttr.get(lowering_config_dict) + return iree_gpu.LoweringConfigAttr.get(lowering_config_attrs) + + def get_pipeline_config(configuration: Configuration) -> str: extra_config = "" pipeline_options = configuration.gpu_pipeline_options diff --git a/tuner/tuner/common_test.py b/tuner/tuner/common_test.py index 73d3f04e3..f13aed3d7 100644 --- a/tuner/tuner/common_test.py +++ b/tuner/tuner/common_test.py @@ -73,16 +73,21 @@ def test_gpu_pipeline_options(tuner_ctx: common.TunerContext) -> None: ) -def test_get_pipeline_config(mlir_ctx: ir.Context) -> None: +def test_get_pipeline_config(tuner_ctx: common.TunerContext) -> None: mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16 mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) + lowering_config = common.get_lowering_config( + tuner_ctx=tuner_ctx, + mma_kind=mma_attr, + workgroup=[4, 8, 0], + reduction=[0, 0, 16], + subgroup_m_count=1, + subgroup_n_count=1, + ) config = common.Configuration( subgroup_size=32, workgroup_size=[16, 16, 1], - intrinsic=mma_attr, - tile_sizes=[4, 8, 16], - subgroup_m_count=1, - subgroup_n_count=1, + lowering_config=lowering_config, gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(), waves_per_eu=2, ) @@ -186,3 +191,30 @@ def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None: ) == [] ) + + +def test_get_lowering_config(tuner_ctx: common.TunerContext) -> None: + lowering_config = common.get_lowering_config( + tuner_ctx=tuner_ctx, + workgroup=[4, 8, 0], + reduction=[0, 0, 16], + subgroup_m_count=1, + subgroup_n_count=1, + ) + + assert ( + str(lowering_config) + == "#iree_gpu.lowering_config<{reduction = [0, 0, 16], subgroup_m_count = 1 : i64, subgroup_n_count = 1 : i64, workgroup = [4, 8, 0]}>" + ) + + config = common.Configuration( + subgroup_size=32, + workgroup_size=[16, 16, 1], + lowering_config=lowering_config, + gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(), + waves_per_eu=2, + ) + + assert common.get_intrinsic(config) is None + assert common.get_subgroup_m_count(config) == 1 + assert common.get_subgroup_n_count(config) == 1 diff --git a/tuner/tuner/dispatch_constraints.py b/tuner/tuner/dispatch_constraints.py index cdfb1bd50..f86523389 100644 --- a/tuner/tuner/dispatch_constraints.py +++ b/tuner/tuner/dispatch_constraints.py @@ -10,8 +10,10 @@ import z3 # type: ignore from typing import Iterator +from iree.compiler import ir # type: ignore from iree.compiler.dialects import iree_gpu # type: ignore +from iree.compiler.dialects import iree_codegen # type: ignore from .common import * @@ -172,13 +174,13 @@ def getMMAAttr( def generate_solutions( - logger: logging.Logger, + tuner_ctx: TunerContext, problem_size: ProblemSize, num_subgrups: int, mma_intrinsics: list[iree_gpu.MMAIntrinsic], ) -> Iterator[Configuration]: M, N, K = problem_size.MNK - logger.info(f"{M},{N},{K}") + tuner_ctx.logger.info(f"{M},{N},{K}") m, n, k = z3.Int("m"), z3.Int("n"), z3.Int("k") subgroup_size = z3.Int("subgroup_size") intrinsic_mn = z3.Int("intrinsic_mn") @@ -216,26 +218,36 @@ def generate_solutions( mma_intrinsics, ) solver.add(z3.simplify(z3.And(constraints))) - logger.debug(f"Initial constraints: {solver}") + tuner_ctx.logger.debug(f"Initial constraints: {solver}") + i = 0 while solver.check() == z3.sat: model = solver.model() lookup = lambda var: model[var].as_long() - + mma_attr = getMMAAttr( + problem_size.res_type.element_type, + lookup(intrinsic_mn), + lookup(intrinsic_mn), + lookup(intrinsic_k), + problem_size.lhs_type.element_type, + problem_size.rhs_type.element_type, + ) + lowering_config = get_lowering_config( + tuner_ctx=tuner_ctx, + mma_kind=mma_attr, + workgroup=[lookup(m), lookup(n), 0], + reduction=[ + 0, + 0, + lookup(k), + ], + subgroup_m_count=lookup(sg_m_cnt), + subgroup_n_count=lookup(sg_n_cnt), + ) config = Configuration( lookup(subgroup_size), [lookup(wg_x), lookup(wg_y), lookup(wg_z)], - getMMAAttr( - problem_size.res_type.element_type, - lookup(intrinsic_mn), - lookup(intrinsic_mn), - lookup(intrinsic_k), - problem_size.lhs_type.element_type, - problem_size.rhs_type.element_type, - ), - [lookup(m), lookup(n), lookup(k)], - lookup(sg_m_cnt), - lookup(sg_n_cnt), + lowering_config, iree_gpu.PipelineOptionsAttr.get(), lookup(waves_per_eu), ) diff --git a/tuner/tuner/dispatch_constraints_test.py b/tuner/tuner/dispatch_constraints_test.py index 9de4beeee..842ea8509 100644 --- a/tuner/tuner/dispatch_constraints_test.py +++ b/tuner/tuner/dispatch_constraints_test.py @@ -39,7 +39,7 @@ def test_generate_solutions(tuner_ctx: common.TunerContext) -> None: matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt ) configs = dispatch_constraints.generate_solutions( - tuner_ctx.logger, + tuner_ctx, problem_size, 4, [ diff --git a/tuner/tuner/dispatch_parser.py b/tuner/tuner/dispatch_parser.py index c4b4b9ad5..ad63ba815 100644 --- a/tuner/tuner/dispatch_parser.py +++ b/tuner/tuner/dispatch_parser.py @@ -20,25 +20,49 @@ def parse_tensor_type(tensor_type: str) -> ShapedType: return ShapedType(shaped_ty.shape, shaped_ty.element_type) -def get_mmt_tile_sizes(configuration: Configuration): - return configuration.tile_sizes +def get_mmt_workgroup_sizes(configuration: Configuration): + return get_workgroup_tile_sizes(configuration) -def get_contract_tile_sizes(configuration: Configuration, tile_dims: str) -> list[int]: - m, n, k = configuration.tile_sizes - tile_size = [1] * len(tile_dims) +def get_mmt_reduction_sizes(configuration: Configuration): + return get_reduction_tile_sizes(configuration) + + +def get_contract_workgroup_sizes( + configuration: Configuration, tile_dims: str +) -> list[int]: + m, n, _k = get_workgroup_tile_sizes(configuration) + + workgroup_size = [1] * len(tile_dims) for idx, dim in enumerate(tile_dims): if dim == "m": - tile_size[idx] = m + workgroup_size[idx] = m if dim == "n": - tile_size[idx] = n + workgroup_size[idx] = n + if dim == "k": + workgroup_size[idx] = 0 + + return workgroup_size + + +def get_contract_reduction_sizes( + configuration: Configuration, tile_dims: str +) -> list[int]: + _m, _n, k = get_reduction_tile_sizes(configuration) + reduction_size = [0] * len(tile_dims) + for idx, dim in enumerate(tile_dims): if dim == "k": - tile_size[idx] = k - return tile_size + reduction_size[idx] = k + + return reduction_size + + +def get_batch_mmt_workgroup_sizes(configuration: Configuration) -> list[int]: + return [1] + get_workgroup_tile_sizes(configuration) -def get_batch_mmt_tile_sizes(configuration: Configuration) -> list[int]: - return [1] + configuration.tile_sizes +def get_batch_mmt_reduction_sizes(configuration: Configuration) -> list[int]: + return [0] + get_reduction_tile_sizes(configuration) class MlirRegex(Enum): @@ -140,18 +164,21 @@ class ConvParser(DispatchParser): def supports(self, op_name: str) -> bool: return "conv_2d_nhwc_hwcf" in op_name - def get_conv_tile_sizes(self, configuration: Configuration) -> list[int]: - m, n, k = configuration.tile_sizes + def get_conv_workgroup_sizes(self, configuration: Configuration) -> list[int]: batch = 1 fh = 1 fw = 1 oh = 1 - oc = n - ow = m - ic = k - return [batch, oh, ow, oc, fh, fw, ic] + ow, oc, _ic = get_workgroup_tile_sizes(configuration) + + return [batch, oh, ow, oc, fh, fw, 0] + + def get_conv_reduction_sizes(self, configuration: Configuration) -> list[int]: + _ow, _oc, ic = get_reduction_tile_sizes(configuration) + + return [0, 0, 0, 0, 0, 0, ic] def get_shapes(self, template: list[str]) -> ProblemSize: for line in template: @@ -178,13 +205,6 @@ def get_shapes(self, template: list[str]) -> ProblemSize: res_shaped_type = parse_tensor_type(res_tensor_type) assert res_shaped_type.rank() == 4 - # int64_t n = outputShape[0]; - # int64_t oh = outputShape[1]; - # int64_t ow = outputShape[2]; - # int64_t oc = outputShape[3]; - # int64_t fh = filterShape[0]; - # int64_t fw = filterShape[1]; - # int64_t ic = filterShape[2]; dim_info = ConvDimInfo.from_rhs_res(rhs_shaped_type, res_shaped_type) return ProblemSize( MatmulSize( diff --git a/tuner/tuner/dispatch_parser_test.py b/tuner/tuner/dispatch_parser_test.py index 529559f83..650540c63 100644 --- a/tuner/tuner/dispatch_parser_test.py +++ b/tuner/tuner/dispatch_parser_test.py @@ -42,39 +42,59 @@ def test_parse_tensor_type(tuner_ctx: common.TunerContext) -> None: def test_get_mmt_tile_sizes(tuner_ctx: common.TunerContext) -> None: mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16 mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) + lowering_config = common.get_lowering_config( + tuner_ctx=tuner_ctx, + mma_kind=mma_attr, + workgroup=[128, 320, 0], + reduction=[0, 0, 32], + subgroup_m_count=1, + subgroup_n_count=4, + ) config = dispatch_parser.Configuration( subgroup_size=0, workgroup_size=[], - intrinsic=mma_attr, - tile_sizes=[128, 320, 32], - subgroup_m_count=0, - subgroup_n_count=0, + lowering_config=lowering_config, gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(), waves_per_eu=0, ) - assert dispatch_parser.get_mmt_tile_sizes(config) == [128, 320, 32] + assert dispatch_parser.get_mmt_workgroup_sizes(config) == [128, 320, 0] + assert dispatch_parser.get_mmt_reduction_sizes(config) == [0, 0, 32] def test_get_conv_tile_sizes(tuner_ctx: common.TunerContext) -> None: mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16 mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) + lowering_config = common.get_lowering_config( + tuner_ctx=tuner_ctx, + mma_kind=mma_attr, + workgroup=[464, 320, 0], + reduction=[0, 0, 16], + subgroup_m_count=1, + subgroup_n_count=4, + ) config = dispatch_parser.Configuration( subgroup_size=64, workgroup_size=[256, 1, 1], - intrinsic=mma_attr, - tile_sizes=[464, 320, 16], - subgroup_m_count=1, - subgroup_n_count=4, + lowering_config=lowering_config, gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(), waves_per_eu=1, ) - assert dispatch_parser.ConvParser().get_conv_tile_sizes(config) == [ + assert dispatch_parser.ConvParser().get_conv_workgroup_sizes(config) == [ 1, 1, 464, 320, 1, 1, + 0, + ] + assert dispatch_parser.ConvParser().get_conv_reduction_sizes(config) == [ + 0, + 0, + 0, + 0, + 0, + 0, 16, ] @@ -82,24 +102,29 @@ def test_get_conv_tile_sizes(tuner_ctx: common.TunerContext) -> None: def test_get_contract_tile_sizes(tuner_ctx: common.TunerContext) -> None: mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16 mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) + lowering_config = common.get_lowering_config( + tuner_ctx=tuner_ctx, + mma_kind=mma_attr, + workgroup=[4, 8, 0], + reduction=[0, 0, 16], + subgroup_m_count=1, + subgroup_n_count=1, + ) config = dispatch_parser.Configuration( subgroup_size=32, workgroup_size=[16, 16, 1], - intrinsic=mma_attr, - tile_sizes=[4, 8, 16], - subgroup_m_count=1, - subgroup_n_count=1, + lowering_config=lowering_config, gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(), waves_per_eu=2, ) - assert dispatch_parser.get_contract_tile_sizes(config, "mnk") == [4, 8, 16] - assert dispatch_parser.get_contract_tile_sizes(config, "nmk") == [8, 4, 16] - assert dispatch_parser.get_contract_tile_sizes(config, "knm") == [16, 8, 4] - assert dispatch_parser.get_contract_tile_sizes(config, "kkk") == [ - 16, - 16, - 16, - ] + assert dispatch_parser.get_contract_workgroup_sizes(config, "mnk") == [4, 8, 0] + assert dispatch_parser.get_contract_reduction_sizes(config, "mnk") == [0, 0, 16] + assert dispatch_parser.get_contract_workgroup_sizes(config, "nmk") == [8, 4, 0] + assert dispatch_parser.get_contract_reduction_sizes(config, "nmk") == [0, 0, 16] + assert dispatch_parser.get_contract_workgroup_sizes(config, "knm") == [0, 8, 4] + assert dispatch_parser.get_contract_reduction_sizes(config, "knm") == [16, 0, 0] + assert dispatch_parser.get_contract_workgroup_sizes(config, "kkk") == [0, 0, 0] + assert dispatch_parser.get_contract_reduction_sizes(config, "kkk") == [16, 16, 16] def test_get_shapes_mmt(tuner_ctx: common.TunerContext) -> None: