diff --git a/tuner/examples/simple/simple_tuner.py b/tuner/examples/simple/simple_tuner.py index 979e069fd..76a4d6230 100644 --- a/tuner/examples/simple/simple_tuner.py +++ b/tuner/examples/simple/simple_tuner.py @@ -6,6 +6,7 @@ import argparse from pathlib import Path +import shutil from tuner import libtuner from tuner.common import * @@ -46,6 +47,12 @@ def main(): client_args.add_argument( "simple_model_file", type=Path, help="Path to the model file to tune (.mlir)" ) + client_args.add_argument( + "--simple-best-spec-output-path", + type=Path, + help="Path to write the best tuned spec after", + default=None, + ) client_args.add_argument( "--simple-num-dispatch-candidates", type=int, @@ -135,6 +142,10 @@ def main(): for id in top_candidates: logging.info(f"{candidate_trackers[id].spec_path.resolve()}") if stop_after_phase == libtuner.ExecutionPhases.benchmark_dispatches: + if args.simple_best_spec_output_path: + top_spec_path = path_config.specs_dir / path_config.get_candidate_spec_filename(top_candidates[0]) + shutil.copy(top_spec_path, args.simple_best_spec_output_path) + print(f"Saved top spec ({top_spec_path}) to {args.simple_best_spec_output_path}") return print("Compiling models with top candidates...") @@ -167,6 +178,11 @@ def main(): for id in top_model_candidates: logging.info(f"{candidate_trackers[id].spec_path.resolve()}") print(f"Top model candidates: {top_model_candidates}") + + if args.simple_best_spec_output_path: + top_spec_path = path_config.specs_dir / path_config.get_candidate_spec_filename(top_model_candidates[0]) + shutil.copy(top_spec_path, args.simple_best_spec_output_path) + print(f"Saved top spec ({top_spec_path}) to {args.simple_best_spec_output_path}") print("Check the detailed execution logs in:") print(path_config.run_log.resolve()) diff --git a/tuner/tuner/candidate_gen.py b/tuner/tuner/candidate_gen.py index fd6fcdcc4..0d09374dd 100644 --- a/tuner/tuner/candidate_gen.py +++ b/tuner/tuner/candidate_gen.py @@ -41,6 +41,7 @@ def 
get_td_spec( self, ir_module: ir.Module, compilation_info: iree_codegen.CompilationInfoAttr, + args: argparse.Namespace, ) -> ir.Module: """Generate a transform dialect spec that applies the compilation info attr.""" pass @@ -66,6 +67,7 @@ def get_td_spec( self, ir_module: ir.Module, compilation_info: iree_codegen.CompilationInfoAttr, + args: argparse.Namespace, ) -> ir.Module: contraction_op: ir.Operation = self.get_contraction_operation(ir_module) lhs_type = ir.ShapedType(contraction_op.operands[0].type) @@ -77,7 +79,7 @@ def get_td_spec( # TODO(Max191): Get the function name from the func.func in the input module. func_name = f"match_contraction_{M}x{N}x{K}_{lhs_type.element_type}x{rhs_type.element_type}x{acc_type.element_type}" return build_td_spec( - ir_module.context, contraction_op, compilation_info, func_name + ir_module.context, contraction_op, compilation_info, func_name, args ) @@ -86,6 +88,7 @@ def get_td_spec( self, ir_module: ir.Module, compilation_info: iree_codegen.CompilationInfoAttr, + args: argparse.Namespace, ) -> ir.Module: conv_op: ir.Operation = self.get_conv_operation(ir_module) assert ( @@ -104,7 +107,7 @@ def get_td_spec( conv_type = conv_op.name.split(".")[-1] # TODO(Max191): Get the function name from the func.func in the input module. 
func_name = f"match_{conv_type}_{N}x{H}x{W}x{C}x{P}x{Q}x{F}_{lhs_type.element_type}x{rhs_type.element_type}x{acc_type.element_type}" - return build_td_spec(ir_module.context, conv_op, compilation_info, func_name) + return build_td_spec(ir_module.context, conv_op, compilation_info, func_name, args) @dataclass @@ -150,6 +153,7 @@ def get_default_output_dir() -> str: def generate_configs_and_td_specs( input_module: ir.Module, # Path to the mlir file to be tuned tuner_context: TunerContext, + args: argparse.Namespace, limit: int = 4096, # Max candidates to be generated num_subgroups: int = 4, # GPU spec, used to determine candidate generation constraints allowed_waves_per_eu: list[int] = [2], @@ -174,7 +178,7 @@ def generate_configs_and_td_specs( tune_logger.debug(str(problem_size)) # Index 0 is reserved for default config, so it gets a placeholder spec. - config_specs: list[ir.Module] = [get_placeholder_spec(input_module.context)] + config_specs: list[ir.Module] = [get_placeholder_spec(input_module.context, args)] # Get the MMA intrinisic intructions supported by the target. 
variant_op_list = iree_codegen.get_executable_variant_ops(input_module) @@ -195,7 +199,7 @@ def generate_configs_and_td_specs( if i >= limit: break tune_logger.debug(f"Solution #{i+1}: {config}") - td_spec_module = dispatch_tuner.get_td_spec(input_module, config) + td_spec_module = dispatch_tuner.get_td_spec(input_module, config, args) assert td_spec_module, "Failed to generate transform dialect spec" config_specs.append(td_spec_module) @@ -352,6 +356,7 @@ def main() -> None: specs: list[ir.Module] = generate_configs_and_td_specs( mlir_module, tuner_ctx, + args, args.limit, args.num_subgroups, args.waves_per_eu_options, diff --git a/tuner/tuner/libtuner.py b/tuner/tuner/libtuner.py index f3a73bb3d..1a1a75e0c 100644 --- a/tuner/tuner/libtuner.py +++ b/tuner/tuner/libtuner.py @@ -257,6 +257,9 @@ def parse_arguments( # General options general_args = parser.add_argument_group("General Options") + general_args.add_argument( + "--extra-spec-file", type=Path, help="Path to an additional spec file to append to each spec" + ) general_args.add_argument( "--verbose", "-v", action="store_true", help="Enable verbose output to stdout" ) @@ -685,6 +688,7 @@ def generate_candidate_specs( config_specs: list[ir.Module] = candidate_gen.generate_configs_and_td_specs( input_module=mlir_module, tuner_context=tuning_client.tuner_context, + args=args, limit=args.num_candidates, num_subgroups=args.num_subgroups, allowed_waves_per_eu=args.waves_per_eu_options, diff --git a/tuner/tuner/spec_builder.py b/tuner/tuner/spec_builder.py index 6005e89ae..de08545f4 100644 --- a/tuner/tuner/spec_builder.py +++ b/tuner/tuner/spec_builder.py @@ -7,8 +7,10 @@ # Given an input dispatch, this code modifies the hyperparameters # in the code and runs it. 
# NOTE: the definitions below rely on these module-level names being imported
# by spec_builder.py: `argparse`, `Path` (pathlib), `ir` (iree.compiler) and
# `transform` (iree.compiler.dialects).


def get_matcher_calls(
    op: ir.Operation, matchers: list[ir.Attribute], actions: list[ir.Attribute]
):
    """Walk callback that records the symbols used by a foreach_match op.

    Args:
        op: Operation currently visited by `walk`.
        matchers: Output list; extended in place with the entries of the op's
            "matchers" attribute.
        actions: Output list; extended in place with the entries of the op's
            "actions" attribute.

    Returns:
        `ir.WalkResult.ADVANCE` when `op` is not a `transform.ForeachMatchOp`,
        otherwise `ir.WalkResult.INTERRUPT` after recording its attributes, so
        a single walk collects from at most one foreach_match op.
    """
    foreach_match = op.opview
    if not isinstance(foreach_match, transform.ForeachMatchOp):
        return ir.WalkResult.ADVANCE
    matchers.extend(foreach_match.attributes["matchers"])
    actions.extend(foreach_match.attributes["actions"])
    return ir.WalkResult.INTERRUPT


def get_extra_spec_info(
    spec_file: Path, used_matcher_names: list[str], used_action_names: list[str]
):
    """Extract reusable named sequences and matcher calls from a spec file.

    The file is parsed as an MLIR module in a fresh `ir.Context`. Top-level
    `transform.NamedSequenceOp`s other than `@__kernel_config` are collected as
    candidate sequences; the `@__kernel_config` sequence is walked to collect
    the matcher -> action pairs of its foreach_match op.

    Args:
        spec_file: Path to the extra spec, or None when no extra spec was
            requested (the `--extra-spec-file` option has no default).
        used_matcher_names: Matcher symbol names already defined by the caller.
            Calls to these matchers and sequences with these names are dropped.
        used_action_names: Action symbol names already defined by the caller.
            Sequences with these names are dropped.

    Returns:
        A `(matchers_str, matcher_calls_str)` tuple: the printed named
        sequences joined by newlines, and the `matcher -> action` pairs joined
        by "\\n, ". Both are empty strings when `spec_file` is None.

    NOTE(review): a caller that splices `matcher_calls_str` after a literal
    leading comma must handle the empty-string case itself.
    """
    # Without an extra spec there is nothing to merge; previously this path
    # crashed in open(None).
    if spec_file is None:
        return "", ""

    # Build the lookup sets once instead of concatenating lists per element.
    reserved_matchers = set(used_matcher_names)
    reserved_sequences = reserved_matchers | set(used_action_names)

    spec_text = Path(spec_file).read_text()
    with ir.Context():
        spec_module = ir.Module.parse(spec_text)
        sequences: list[transform.NamedSequenceOp] = []
        matcher_refs: list[ir.FlatSymbolRefAttr] = []
        action_refs: list[ir.FlatSymbolRefAttr] = []
        for op in spec_module.body.operations:
            if not isinstance(op, transform.NamedSequenceOp):
                continue
            if op.sym_name.value == "__kernel_config":
                # The entry point is not copied; only its matcher/action
                # pairs are harvested.
                op.walk(
                    lambda nested: get_matcher_calls(
                        nested, matcher_refs, action_refs
                    ),
                    ir.WalkOrder.POST_ORDER,
                )
                continue
            sequences.append(op)

        matcher_calls_str = "\n, ".join(
            f"{matcher} -> {action}"
            for matcher, action in zip(matcher_refs, action_refs)
            if matcher.value not in reserved_matchers
        )
        matchers_str = "\n".join(
            str(sequence)
            for sequence in sequences
            if sequence.sym_name.value not in reserved_sequences
        )
    return matchers_str, matcher_calls_str


def get_placeholder_spec(context: ir.Context, args: argparse.Namespace) -> ir.Module:
    """Build the placeholder tuning spec used for the default configuration.

    When the extra spec contributes no matcher calls, the entry point simply
    yields its input. Otherwise the extra named sequences are embedded and the
    entry point applies them through `transform.foreach_match`.

    Args:
        context: MLIR context the resulting module is parsed into.
        args: Parsed command-line arguments. `extra_spec_file` is read with a
            None fallback because not every caller's namespace is shown to
            define it.

    Returns:
        The parsed transform-dialect spec module.
    """
    extra_matchers, extra_matcher_fn_calls = get_extra_spec_info(
        getattr(args, "extra_spec_file", None),
        used_matcher_names=[],
        used_action_names=[],
    )
    if not extra_matcher_fn_calls:
        spec_text = f"""
            module attributes {{ transform.with_named_sequence }} {{
                transform.named_sequence
                @__kernel_config(%variant_op: !transform.any_op {{transform.readonly}}) -> !transform.any_op
                    attributes {{ iree_codegen.tuning_spec_entrypoint }} {{
                    transform.yield %variant_op : !transform.any_op
                }}
            }}
            """
    else:
        spec_text = f"""
            module attributes {{ transform.with_named_sequence }} {{
                {extra_matchers}
                transform.named_sequence
                @__kernel_config(%variant_op: !transform.any_op {{transform.consumed}}) -> !transform.any_op
                    attributes {{ iree_codegen.tuning_spec_entrypoint }} {{
                    %res = transform.foreach_match in %variant_op
                        {extra_matcher_fn_calls}
                        : (!transform.any_op) -> (!transform.any_op)
                    transform.yield %res : !transform.any_op
                }}
            }}
            """
    return ir.Module.parse(spec_text, context)
@@ -95,12 +160,15 @@ def build_td_spec( transform.yield %cont, %config : !transform.any_op, !transform.any_param }} + {extra_matchers} + // Entry Point transform.named_sequence @__kernel_config(%variant_op: !transform.any_op {{transform.consumed}}) -> !transform.any_op attributes {{ iree_codegen.tuning_spec_entrypoint }} {{ %res = transform.foreach_match in %variant_op @{func_name} -> @apply_op_config + , {extra_matcher_fn_calls} : (!transform.any_op) -> !transform.any_op transform.yield %res : !transform.any_op }}