Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Tuner] Allow passing extra TD specs during tuning #925

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions tuner/examples/simple/simple_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import argparse
from pathlib import Path
import shutil
from tuner import libtuner
from tuner.common import *

Expand Down Expand Up @@ -46,6 +47,12 @@ def main():
client_args.add_argument(
"simple_model_file", type=Path, help="Path to the model file to tune (.mlir)"
)
client_args.add_argument(
"--simple-best-spec-output-path",
type=Path,
help="Path to write the best tuned spec after",
default=None,
)
client_args.add_argument(
"--simple-num-dispatch-candidates",
type=int,
Expand Down Expand Up @@ -135,6 +142,10 @@ def main():
for id in top_candidates:
logging.info(f"{candidate_trackers[id].spec_path.resolve()}")
if stop_after_phase == libtuner.ExecutionPhases.benchmark_dispatches:
if args.simple_best_spec_output_path:
top_spec_path = path_config.specs_dir / path_config.get_candidate_spec_filename(top_candidates[0])
shutil.copy(top_spec_path, args.simple_best_spec_output_path)
print(f"Saved top spec ({top_spec_path}) to {args.simple_best_spec_output_path}")
return

print("Compiling models with top candidates...")
Expand Down Expand Up @@ -167,6 +178,11 @@ def main():
for id in top_model_candidates:
logging.info(f"{candidate_trackers[id].spec_path.resolve()}")
print(f"Top model candidates: {top_model_candidates}")

if args.simple_best_spec_output_path:
top_spec_path = path_config.specs_dir / path_config.get_candidate_spec_filename(top_model_candidates[0])
shutil.copy(top_spec_path, args.simple_best_spec_output_path)
print(f"Saved top spec ({top_spec_path}) to {args.simple_best_spec_output_path}")

print("Check the detailed execution logs in:")
print(path_config.run_log.resolve())
Expand Down
13 changes: 9 additions & 4 deletions tuner/tuner/candidate_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def get_td_spec(
self,
ir_module: ir.Module,
compilation_info: iree_codegen.CompilationInfoAttr,
args: argparse.Namespace,
) -> ir.Module:
"""Generate a transform dialect spec that applies the compilation info attr."""
pass
Expand All @@ -66,6 +67,7 @@ def get_td_spec(
self,
ir_module: ir.Module,
compilation_info: iree_codegen.CompilationInfoAttr,
args: argparse.Namespace,
) -> ir.Module:
contraction_op: ir.Operation = self.get_contraction_operation(ir_module)
lhs_type = ir.ShapedType(contraction_op.operands[0].type)
Expand All @@ -77,7 +79,7 @@ def get_td_spec(
# TODO(Max191): Get the function name from the func.func in the input module.
func_name = f"match_contraction_{M}x{N}x{K}_{lhs_type.element_type}x{rhs_type.element_type}x{acc_type.element_type}"
return build_td_spec(
ir_module.context, contraction_op, compilation_info, func_name
ir_module.context, contraction_op, compilation_info, func_name, args
)


Expand All @@ -86,6 +88,7 @@ def get_td_spec(
self,
ir_module: ir.Module,
compilation_info: iree_codegen.CompilationInfoAttr,
args: argparse.Namespace,
) -> ir.Module:
conv_op: ir.Operation = self.get_conv_operation(ir_module)
assert (
Expand All @@ -104,7 +107,7 @@ def get_td_spec(
conv_type = conv_op.name.split(".")[-1]
# TODO(Max191): Get the function name from the func.func in the input module.
func_name = f"match_{conv_type}_{N}x{H}x{W}x{C}x{P}x{Q}x{F}_{lhs_type.element_type}x{rhs_type.element_type}x{acc_type.element_type}"
return build_td_spec(ir_module.context, conv_op, compilation_info, func_name)
return build_td_spec(ir_module.context, conv_op, compilation_info, func_name, args)


@dataclass
Expand Down Expand Up @@ -150,6 +153,7 @@ def get_default_output_dir() -> str:
def generate_configs_and_td_specs(
input_module: ir.Module, # Path to the mlir file to be tuned
tuner_context: TunerContext,
args: argparse.Namespace,
limit: int = 4096, # Max candidates to be generated
num_subgroups: int = 4, # GPU spec, used to determine candidate generation constraints
allowed_waves_per_eu: list[int] = [2],
Expand All @@ -174,7 +178,7 @@ def generate_configs_and_td_specs(
tune_logger.debug(str(problem_size))

# Index 0 is reserved for default config, so it gets a placeholder spec.
config_specs: list[ir.Module] = [get_placeholder_spec(input_module.context)]
config_specs: list[ir.Module] = [get_placeholder_spec(input_module.context, args)]

# Get the MMA intrinisic intructions supported by the target.
variant_op_list = iree_codegen.get_executable_variant_ops(input_module)
Expand All @@ -195,7 +199,7 @@ def generate_configs_and_td_specs(
if i >= limit:
break
tune_logger.debug(f"Solution #{i+1}: {config}")
td_spec_module = dispatch_tuner.get_td_spec(input_module, config)
td_spec_module = dispatch_tuner.get_td_spec(input_module, config, args)
assert td_spec_module, "Failed to generate transform dialect spec"
config_specs.append(td_spec_module)

Expand Down Expand Up @@ -352,6 +356,7 @@ def main() -> None:
specs: list[ir.Module] = generate_configs_and_td_specs(
mlir_module,
tuner_ctx,
args,
args.limit,
args.num_subgroups,
args.waves_per_eu_options,
Expand Down
4 changes: 4 additions & 0 deletions tuner/tuner/libtuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,9 @@ def parse_arguments(

# General options
general_args = parser.add_argument_group("General Options")
general_args.add_argument(
"--extra-spec-file", type=Path, help="Path to an additional spec file to append to each spec"
)
general_args.add_argument(
"--verbose", "-v", action="store_true", help="Enable verbose output to stdout"
)
Expand Down Expand Up @@ -685,6 +688,7 @@ def generate_candidate_specs(
config_specs: list[ir.Module] = candidate_gen.generate_configs_and_td_specs(
input_module=mlir_module,
tuner_context=tuning_client.tuner_context,
args=args,
limit=args.num_candidates,
num_subgroups=args.num_subgroups,
allowed_waves_per_eu=args.waves_per_eu_options,
Expand Down
88 changes: 78 additions & 10 deletions tuner/tuner/spec_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,84 @@
# Given an input dispatch, this code modifies the hyperparameters
# in the code and runs it.

import argparse
from pathlib import Path
from iree.compiler import ir # type: ignore
from iree.compiler.dialects import iree_codegen # type: ignore
from iree.compiler.dialects import iree_codegen, transform # type: ignore

from .common import *
from .dispatch_constraints import *
from .dispatch_parser import *
from .op_matchers import ROOT_OP_ATTR_NAME


def get_placeholder_spec(context: ir.Context) -> ir.Module:
spec_text = f"""
module attributes {{ transform.with_named_sequence }} {{
transform.named_sequence
@__kernel_config(%variant_op: !transform.any_op {{transform.readonly}}) -> !transform.any_op
attributes {{ iree_codegen.tuning_spec_entrypoint }} {{
transform.yield %variant_op : !transform.any_op
def get_matcher_calls(
    op: ir.Operation, matchers: list[ir.Attribute], actions: list[ir.Attribute]
):
    """Walk callback that harvests matcher/action refs from a foreach_match op.

    Appends the `matchers` and `actions` array attributes of the first
    `transform.foreach_match` op encountered into the caller-provided lists,
    then interrupts the walk; any other op lets the walk continue.
    """
    opview = op.opview
    if isinstance(opview, transform.ForeachMatchOp):
        matchers.extend(opview.attributes["matchers"])
        actions.extend(opview.attributes["actions"])
        return ir.WalkResult.INTERRUPT
    return ir.WalkResult.ADVANCE

def get_extra_spec_info(
    spec_file: "Path | None",
    used_matcher_names: list[str],
    used_action_names: list[str],
) -> tuple[str, str]:
    """Extract splicable matchers and matcher calls from an extra TD spec file.

    Parses `spec_file` (a transform-dialect tuning spec) and returns a pair of
    strings `(matchers_str, matcher_calls_str)`:
      - `matchers_str`: the textual `transform.named_sequence` definitions from
        the spec, excluding any whose symbol name collides with
        `used_matcher_names` or `used_action_names`.
      - `matcher_calls_str`: the `@matcher -> @action` pairs harvested from the
        `transform.foreach_match` inside the spec's `__kernel_config` entry
        point, excluding calls whose matcher is in `used_matcher_names`.

    Returns ("", "") when `spec_file` is None (i.e. no --extra-spec-file was
    passed on the command line), so callers can use it unconditionally.
    """
    # --extra-spec-file is optional and defaults to None; without this guard
    # open(None) would raise TypeError for every tuning run.
    if spec_file is None:
        return "", ""
    spec = Path(spec_file).read_text()
    with ir.Context():
        spec_module = ir.Module.parse(spec)
        named_matchers: list[transform.NamedSequenceOp] = []
        matcher_calls: list[ir.Attribute] = []
        matcher_call_actions: list[ir.Attribute] = []
        for op in spec_module.body.operations:
            if not isinstance(op, transform.NamedSequenceOp):
                continue
            if op.sym_name.value == "__kernel_config":
                # Pull the matcher -> action pairs out of the entry point's
                # transform.foreach_match op.
                op.walk(
                    lambda o: get_matcher_calls(o, matcher_calls, matcher_call_actions),
                    ir.WalkOrder.POST_ORDER,
                )
                continue
            named_matchers.append(op)
        # Drop calls whose matcher is already emitted by the generated spec.
        filtered_matcher_calls = []
        filtered_matcher_call_actions = []
        for call, action in zip(matcher_calls, matcher_call_actions):
            if call.value in used_matcher_names:
                continue
            filtered_matcher_calls.append(call)
            filtered_matcher_call_actions.append(action)
        used_names = used_matcher_names + used_action_names
        filtered_matchers = [
            matcher
            for matcher in named_matchers
            if matcher.sym_name.value not in used_names
        ]
        matcher_calls_str = "\n, ".join(
            f"{call} -> {action}"
            for call, action in zip(
                filtered_matcher_calls, filtered_matcher_call_actions
            )
        )
        matchers_str = "\n".join(str(op) for op in filtered_matchers)
        return matchers_str, matcher_calls_str

def get_placeholder_spec(context: ir.Context, args: argparse.Namespace) -> ir.Module:
    """Build the index-0 placeholder tuning spec module.

    Index 0 is reserved for the default configuration, so without an extra
    spec this is a no-op entry point that simply yields the variant op. When
    `--extra-spec-file` is provided, its matchers and `foreach_match` calls
    are spliced in so the extra spec still applies to the baseline candidate.

    Args:
        context: MLIR context to parse the generated spec text in.
        args: parsed CLI arguments; only `extra_spec_file` is read here.

    Returns:
        The parsed placeholder spec as an `ir.Module`.
    """
    extra_matchers = ""
    extra_matcher_fn_calls = ""
    # --extra-spec-file defaults to None; only parse it when it was passed,
    # otherwise get_extra_spec_info would be handed a non-path.
    if getattr(args, "extra_spec_file", None) is not None:
        extra_matchers, extra_matcher_fn_calls = get_extra_spec_info(
            args.extra_spec_file,
            used_matcher_names=[],
            used_action_names=[],
        )
    if len(extra_matcher_fn_calls) == 0:
        # No extra matchers: the placeholder just forwards the variant op.
        spec_text = f"""
            module attributes {{ transform.with_named_sequence }} {{
              transform.named_sequence
              @__kernel_config(%variant_op: !transform.any_op {{transform.readonly}}) -> !transform.any_op
                attributes {{ iree_codegen.tuning_spec_entrypoint }} {{
                transform.yield %variant_op : !transform.any_op
              }}
            }}
            """
    else:
        # Splice the extra spec's matchers into an otherwise-empty entry point.
        spec_text = f"""
            module attributes {{ transform.with_named_sequence }} {{
              {extra_matchers}
              transform.named_sequence
              @__kernel_config(%variant_op: !transform.any_op {{transform.consumed}}) -> !transform.any_op
                attributes {{ iree_codegen.tuning_spec_entrypoint }} {{
                %res = transform.foreach_match in %variant_op
                  {extra_matcher_fn_calls}
                  : (!transform.any_op) -> (!transform.any_op)
                transform.yield %res : !transform.any_op
              }}
            }}
            """
    return ir.Module.parse(spec_text, context)


Expand All @@ -36,7 +95,13 @@ def build_td_spec(
op: ir.Operation,
compilation_info: iree_codegen.CompilationInfoAttr,
func_name: str,
args: argparse.Namespace,
) -> ir.Module:
extra_matchers, extra_matcher_fn_calls = get_extra_spec_info(
args.extra_spec_file,
used_matcher_names=[func_name],
used_action_names=["apply_op_config"],
)
bbargs = []
# The `root_op` attribute will prevent matching of ops without the attr in
# the resulting TD spec matcher if it is not removed, so we remove it here.
Expand Down Expand Up @@ -95,12 +160,15 @@ def build_td_spec(
transform.yield %cont, %config : !transform.any_op, !transform.any_param
}}

{extra_matchers}

// Entry Point
transform.named_sequence
@__kernel_config(%variant_op: !transform.any_op {{transform.consumed}}) -> !transform.any_op
attributes {{ iree_codegen.tuning_spec_entrypoint }} {{
%res = transform.foreach_match in %variant_op
@{func_name} -> @apply_op_config
, {extra_matcher_fn_calls}
: (!transform.any_op) -> !transform.any_op
transform.yield %res : !transform.any_op
}}
Expand Down
Loading