From 119833b8637556aba160a946688a2090ce468be5 Mon Sep 17 00:00:00 2001 From: Max Dawkins Date: Thu, 12 Dec 2024 14:49:44 -0600 Subject: [PATCH] address comments, use root_op attribute in matching Signed-off-by: Max Dawkins --- tuner/examples/test/conv_benchmark.mlir | 2 +- tuner/examples/test/mmt_benchmark.mlir | 2 +- tuner/tuner/candidate_gen.py | 8 +-- tuner/tuner/candidate_gen_test.py | 5 +- tuner/tuner/dispatch_parser.py | 25 +++------ tuner/tuner/dispatch_parser_test.py | 3 +- tuner/tuner/libtuner.py | 2 +- tuner/tuner/op_matchers.py | 72 ++++++++++++++----------- tuner/tuner/spec_builder.py | 2 +- 9 files changed, 60 insertions(+), 61 deletions(-) diff --git a/tuner/examples/test/conv_benchmark.mlir b/tuner/examples/test/conv_benchmark.mlir index 99eb0f69b..945d67fbb 100644 --- a/tuner/examples/test/conv_benchmark.mlir +++ b/tuner/examples/test/conv_benchmark.mlir @@ -18,7 +18,7 @@ module { %4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0, 0], sizes = [3, 3, 1280, 1280], strides = [1, 1, 1, 1] : !flow.dispatch.tensor> -> tensor<3x3x1280x1280xi8> %5 = tensor.empty() : tensor<2x32x32x1280xi32> %6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<2x32x32x1280xi32>) -> tensor<2x32x32x1280xi32> - %7 = linalg.conv_2d_nhwc_hwcf {lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout, promote_operands = [0, 1], reduction = [0, 0, 0, 0, 1, 1, 64], subgroup_m_count = 1 : i64, subgroup_n_count = 4 : i64, workgroup = [1, 1, 32, 256, 0, 0, 0]}>} ins(%3, %4 : tensor<2x34x34x1280xi8>, tensor<3x3x1280x1280xi8>) outs(%6 : tensor<2x32x32x1280xi32>) -> tensor<2x32x32x1280xi32> + %7 = linalg.conv_2d_nhwc_hwcf {lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout, promote_operands = [0, 1], reduction = [0, 0, 0, 0, 1, 1, 64], subgroup_m_count = 1 : i64, subgroup_n_count = 4 : i64, workgroup = [1, 1, 32, 256, 0, 0, 0]}>, root_op} ins(%3, %4 : tensor<2x34x34x1280xi8>, tensor<3x3x1280x1280xi8>) outs(%6 : tensor<2x32x32x1280xi32>) -> tensor<2x32x32x1280xi32> flow.dispatch.tensor.store %7, %2, offsets = [0, 0, 0, 0], sizes = [2, 32, 32, 1280], strides = [1, 1, 1, 1] : tensor<2x32x32x1280xi32> -> !flow.dispatch.tensor> return } diff --git a/tuner/examples/test/mmt_benchmark.mlir b/tuner/examples/test/mmt_benchmark.mlir index 68edeb9d7..d95e46015 100644 --- a/tuner/examples/test/mmt_benchmark.mlir +++ b/tuner/examples/test/mmt_benchmark.mlir @@ -18,7 +18,7 @@ module { %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [2048, 2048], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<2048x2048xf16> %5 = tensor.empty() : tensor<2048x2048xf32> %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<2048x2048xf32>) -> tensor<2048x2048xf32> - %7 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%3, %4 : tensor<2048x2048xf16>, tensor<2048x2048xf16>) outs(%6 : tensor<2048x2048xf32>) attrs = {lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout, promote_operands = [0, 1], reduction = [0, 0, 64], subgroup_m_count = 2 : i64, subgroup_n_count = 2 : i64, workgroup = [64, 128, 0]}>} { + %7 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%3, %4 : tensor<2048x2048xf16>, tensor<2048x2048xf16>) outs(%6 : tensor<2048x2048xf32>) attrs = 
{lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout, promote_operands = [0, 1], reduction = [0, 0, 64], subgroup_m_count = 2 : i64, subgroup_n_count = 2 : i64, workgroup = [64, 128, 0]}>, root_op} { ^bb0(%in: f16, %in_0: f16, %out: f32): %8 = arith.extf %in : f16 to f32 %9 = arith.extf %in_0 : f16 to f32 diff --git a/tuner/tuner/candidate_gen.py b/tuner/tuner/candidate_gen.py index 00a617d8e..07ee175d8 100644 --- a/tuner/tuner/candidate_gen.py +++ b/tuner/tuner/candidate_gen.py @@ -113,7 +113,7 @@ def get_td_spec( ir_module: ir.Module, compilation_info: iree_codegen.CompilationInfoAttr, ) -> ir.Module: - """Generate a transform dialect spec module for the funcOp.""" + """Generate a transform dialect spec that applies the compilation info attr.""" pass @@ -162,7 +162,8 @@ def get_td_spec( M = acc_type.get_dim_size(0) N = acc_type.get_dim_size(1) K = lhs_type.get_dim_size(1) - func_name = f"match_mmt_{M}x{N}x{K}_{lhs_type.element_type}x{rhs_type.element_type}x{acc_type.element_type}" + # TODO(Max191): Get the function name from the func.func in the input module. + func_name = f"match_contraction_{M}x{N}x{K}_{lhs_type.element_type}x{rhs_type.element_type}x{acc_type.element_type}" return build_td_spec( ir_module.context, contraction_op, compilation_info, func_name ) @@ -195,6 +196,7 @@ def get_td_spec( Q = rhs_type.get_dim_size(1) F = rhs_type.get_dim_size(3) conv_type = conv_op.name.split(".")[-1] + # TODO(Max191): Get the function name from the func.func in the input module. func_name = f"match_{conv_type}_{N}x{H}x{W}x{C}x{P}x{Q}x{F}_{lhs_type.element_type}x{rhs_type.element_type}x{acc_type.element_type}" return build_td_spec( ir_module.context, conv_op, compilation_info, func_name @@ -558,7 +560,7 @@ def get_default_output_dir() -> str: return "tuning_" + datetime.now().strftime("%Y_%m_%d_%H_%M") -# TODO(Max191): Remove in favor of using tune_with_td. +# TODO(https://github.com/nod-ai/shark-ai/issues/453): Remove in favor of using tune_with_td. 
def tune( input: str, # Path to the mlir file to be tuned output: str = "", # Path to the output directory, auto creates one if not given diff --git a/tuner/tuner/candidate_gen_test.py b/tuner/tuner/candidate_gen_test.py index 6424f7e70..3a28cd791 100644 --- a/tuner/tuner/candidate_gen_test.py +++ b/tuner/tuner/candidate_gen_test.py @@ -52,6 +52,7 @@ def test_get_td_spec_contraction(tuner_ctx: common.TunerContext) -> None: affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} + {root_op} ins(%arg0, %arg1 : tensor<2048x2048xf16>, tensor<2048x2048xf16>) outs(%1 : tensor<2048x2048xf32>) { ^bb0(%in: f16, %in_0: f16, %out: f32): @@ -101,7 +102,6 @@ def test_get_td_spec_contraction(tuner_ctx: common.TunerContext) -> None: matcher_sequence = None entry_point = None for op in named_sequence_ops: - print(op.opview.sym_name) if str(op.opview.sym_name) == "\"apply_op_config\"": apply_config_sequence = op elif str(op.opview.sym_name) == "\"__kernel_config\"": @@ -139,7 +139,7 @@ def test_get_td_spec_convolution(tuner_ctx: common.TunerContext) -> None: %cst = arith.constant 0 : i32 %0 = tensor.empty() : tensor<2x32x32x2048xi32> %1 = linalg.fill ins(%cst : i32) outs(%0 : tensor<2x32x32x2048xi32>) -> tensor<2x32x32x2048xi32> - %2 = linalg.conv_2d_nhwc_hwcf + %2 = linalg.conv_2d_nhwc_hwcf {root_op} ins(%arg0, %arg1 : tensor<2x34x34x2048xi8>, tensor<3x3x2048x2048xi8>) outs(%1 : tensor<2x32x32x2048xi32>) -> tensor<2x32x32x2048xi32> return %2 : tensor<2x32x32x2048xi32> @@ -182,7 +182,6 @@ def test_get_td_spec_convolution(tuner_ctx: common.TunerContext) -> None: matcher_sequence = None entry_point = None for op in named_sequence_ops: - print(op.opview.sym_name) if str(op.opview.sym_name) == "\"apply_op_config\"": apply_config_sequence = op elif str(op.opview.sym_name) == "\"__kernel_config\"": diff --git a/tuner/tuner/dispatch_parser.py b/tuner/tuner/dispatch_parser.py index b749b7e26..1ea77fc4c 100644 --- a/tuner/tuner/dispatch_parser.py +++ b/tuner/tuner/dispatch_parser.py @@ -78,19 +78,6 @@ def parse_mlir(mlir_text: str, ctx: TunerContext) -> ir.Module: return mlir_module -def find_root_op( - ir_module: ir.Module, - matcher: NamedOpMatcher, -) -> Optional[ir.Operation]: - func_ops: list[ir.Operation] = get_named_ops(ir_module, "func.func") - if len(func_ops) != 1: - return None - matched_ops = matcher.get_matched_ops(func_ops[0].operation) - if len(matched_ops) != 1: - return None - return matched_ops[0] - - class DispatchParser(metaclass=ABCMeta): @abstractmethod def supports(self, op_name: str) -> bool: @@ -118,14 +105,14 @@ def get_contraction_operation( self, ir_module: ir.Module, ) -> Optional[ir.Operation]: - return find_root_op(ir_module, ContractionOpInterfaceMatcher()) + return match_root_op(ir_module, ContractionOpInterfaceMatcher()) # TODO(Max191): Pass the ir_module directly instead of the template str. 
def get_shapes(self, template: list[str]) -> ProblemSize: matcher = ContractionOpInterfaceMatcher() with ir.Context() as ctx: - ir_module = ir.Module.parse("".join(template), ctx) - contraction_op = find_root_op(ir_module, matcher) + ir_module = ir.Module.parse("\n".join(template), ctx) + contraction_op = match_root_op(ir_module, matcher) if contraction_op is None: assert False, f"contraction op not found" cdims = matcher.contraction_dimensions @@ -167,13 +154,13 @@ def get_conv_operation( self, ir_module: ir.Module, ) -> Optional[ir.Operation]: - return find_root_op(ir_module, NamedOpMatcher(self.supported_ops)) + return match_root_op(ir_module, NamedOpMatcher(self.supported_ops)) # TODO(Max191): Pass the ir_module directly instead of the template str. def get_shapes(self, template: list[str]) -> ProblemSize: with ir.Context() as ctx: - ir_module = ir.Module.parse("".join(template), ctx) - conv_op = find_root_op(ir_module, NamedOpMatcher(self.supported_ops)) + ir_module = ir.Module.parse("\n".join(template), ctx) + conv_op = match_root_op(ir_module, NamedOpMatcher(self.supported_ops)) if conv_op is None: assert False, f"convolution op not found" lhs_type = ir.RankedTensorType(conv_op.operands[0].type) diff --git a/tuner/tuner/dispatch_parser_test.py b/tuner/tuner/dispatch_parser_test.py index 6adb3b83c..fba0efa11 100644 --- a/tuner/tuner/dispatch_parser_test.py +++ b/tuner/tuner/dispatch_parser_test.py @@ -55,6 +55,7 @@ def test_get_contraction_operation(tuner_ctx: common.TunerContext) -> None: affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} + {root_op} ins(%arg0, %arg1 : tensor<4x4xf16>, tensor<4x4xf16>) outs(%1 : tensor<4x4xf32>) { ^bb0(%in: f16, %in_0: f16, %out: f32): @@ -82,7 +83,7 @@ def test_get_conv_operation(tuner_ctx: common.TunerContext) -> None: %cst = arith.constant 0 : i32 %0 = tensor.empty() : tensor<2x32x32x16xi32> %1 = linalg.fill ins(%cst : i32) outs(%0 : tensor<2x32x32x16xi32>) -> tensor<2x32x32x16xi32> - %2 = linalg.conv_2d_nhwc_hwcf + %2 = linalg.conv_2d_nhwc_hwcf {root_op} ins(%arg0, %arg1 : tensor<2x34x34x16xi8>, tensor<3x3x16x16xi8>) outs(%1 : tensor<2x32x32x16xi32>) -> tensor<2x32x32x16xi32> return %2 : tensor<2x32x32x16xi32> diff --git a/tuner/tuner/libtuner.py b/tuner/tuner/libtuner.py index ac48c5930..6bece17f4 100644 --- a/tuner/tuner/libtuner.py +++ b/tuner/tuner/libtuner.py @@ -878,7 +878,7 @@ def generate_candidate_specs( candidate_trackers.append(new_candidate) except Exception as e: logging.error("An error occurred during candidates generation: %s", str(e)) - # Capture and log debug messages from candidate_gen.py + # Capture and log debug messages from candidate_gen.py. tune_logger = logging.getLogger("tune_with_td") for handler in logging.getLogger().handlers: if isinstance(handler, logging.FileHandler): diff --git a/tuner/tuner/op_matchers.py b/tuner/tuner/op_matchers.py index 8108019b1..0f46093c9 100644 --- a/tuner/tuner/op_matchers.py +++ b/tuner/tuner/op_matchers.py @@ -6,12 +6,19 @@ # This code implements matcher functions for MLIR modules using python bindings. 
-from abc import abstractmethod +from abc import ABCMeta, abstractmethod from .common import * from iree.compiler import ir # type: ignore +class OpMatcher(metaclass=ABCMeta): + @abstractmethod + def match(self, op: ir.Operation) -> bool: + """Check if the op passes the matching criteria.""" + pass + + def walk_collect_ops( op: ir.Operation, ops: list[ir.Operation], @@ -22,15 +29,6 @@ def walk_collect_ops( return ir.WalkResult.ADVANCE -def get_ops(op: ir.Operation, fn): - ops: list[ir.Operation] = [] - op.opview.walk( - lambda op: walk_collect_ops(op, ops, fn), - ir.WalkOrder.POST_ORDER, - ) - return ops - - def get_ops_from_module(module: ir.Module, fn): ops: list[ir.Operation] = [] for op in module.body.operations: @@ -41,37 +39,32 @@ def get_ops_from_module(module: ir.Module, fn): return ops -def get_named_ops(module: ir.Module, name: str): - return get_ops_from_module(module, lambda op: op.name == name) +def is_root_op(op: ir.Operation) -> bool: + for attr in op.opview.attributes: + if attr.name == "root_op": + return True + return False -def get_map_result_dim_positions(map: ir.AffineMap): - exprs = [] - if not map.is_projected_permutation: +def match_root_op( + ir_module: ir.Module, + matcher: OpMatcher, +) -> Optional[ir.Operation]: + root_ops: list[ir.Operation] = get_ops_from_module(ir_module, is_root_op) + if len(root_ops) != 1: return None - for expr in map.results: - dim_str = str(expr) - if len(dim_str) < 1: - return None - if dim_str[0] != "d": - return None - if not dim_str[1:].isdigit(): - return None - dim_position = int(dim_str[1:]) - exprs.append(dim_position) - return exprs + if not matcher.match(root_ops[0].operation): + return None + return root_ops[0] -class NamedOpMatcher: +class NamedOpMatcher(OpMatcher): def __init__(self, op_names: list[str]): self.op_names = op_names def match(self, op: ir.Operation) -> bool: return op.name in self.op_names - def get_matched_ops(self, op: ir.Operation): - return get_ops(op, lambda nestedOp: self.match(nestedOp)) - # TODO(Max191): Add logic to match the body of the generic op. class GenericOpMatcher(NamedOpMatcher): @@ -111,6 +104,23 @@ def match(self, op: ir.Operation) -> bool: return True +def get_map_result_dim_positions(map: ir.AffineMap): + exprs = [] + if not map.is_projected_permutation: + return None + for expr in map.results: + dim_str = str(expr) + if len(dim_str) < 1: + return None + if dim_str[0] != "d": + return None + if not dim_str[1:].isdigit(): + return None + dim_position = int(dim_str[1:]) + exprs.append(dim_position) + return exprs + + class ContractionOpInterfaceMatcher(GenericOpMatcher): def __init__(self): super().__init__() @@ -152,7 +162,7 @@ def match_indexing_maps(self, maps: list[ir.AffineMap]) -> bool: k_dims.append(d) continue return False - + self.contraction_dimensions = ContractionDimensions( batch=batch_dims, m=m_dims, diff --git a/tuner/tuner/spec_builder.py b/tuner/tuner/spec_builder.py index bc59b33c5..a27bd072f 100644 --- a/tuner/tuner/spec_builder.py +++ b/tuner/tuner/spec_builder.py @@ -8,7 +8,7 @@ # in the code and runs it. from iree.compiler import ir # type: ignore -from iree.compiler.dialects import iree_codegen # type: ignore +from iree.compiler.dialects import iree_codegen # type: ignore from .common import * from .dispatch_constraints import *
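
Usage sketch (not part of the patch): a minimal illustration of how the new `root_op`-attribute matching is meant to be consumed. `match_root_op` and `NamedOpMatcher` are the helpers defined in tuner/tuner/op_matchers.py above; the `tuner.tuner.op_matchers` import path and the example IR are assumptions made for illustration only.

    # Sketch only: selects the unique op tagged with the `root_op` unit attribute,
    # and returns it only if it also satisfies the given matcher.
    from iree.compiler import ir  # type: ignore

    # Assumed import path; adjust to however the tuner package is installed.
    from tuner.tuner.op_matchers import NamedOpMatcher, match_root_op

    MLIR_TEXT = """
    module {
      func.func @matmul(%a: tensor<4x4xf16>, %b: tensor<4x4xf16>,
                        %acc: tensor<4x4xf32>) -> tensor<4x4xf32> {
        %0 = linalg.matmul {root_op}
          ins(%a, %b : tensor<4x4xf16>, tensor<4x4xf16>)
          outs(%acc : tensor<4x4xf32>) -> tensor<4x4xf32>
        return %0 : tensor<4x4xf32>
      }
    }
    """

    with ir.Context() as ctx:
        module = ir.Module.parse(MLIR_TEXT, ctx)
        # Exactly one op in the module carries `root_op`, and it is a
        # linalg.matmul, so the matcher accepts it.
        root_op = match_root_op(module, NamedOpMatcher(["linalg.matmul"]))
        assert root_op is not None, "expected a matching `root_op`-tagged operation"
        print(root_op.name)  # linalg.matmul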