Skip to content

Commit

Permalink
[tuner] Allow tuning of pipeline options (#791)
Browse files Browse the repository at this point in the history
This PR adds some command line arguments to tune pipeline options. The
new flags are `--prefetch-shared-memory-options`,
`--no-reduce-shared-memory-bank-conflicts-options`, and
`--waves-per-eu-options`. The flags take a comma separated list of
values (bool or int), representing the possible values that the
corresponding pipeline options can take in tuning.

This PR also adds `promote_operands = [0, 1]` to the TileAndFuse tuning
configurations. This could potentially be a tuned parameter in the
future, but for now it is typically good to promote the lhs and rhs
operands.

Signed-off-by: Max Dawkins <[email protected]>
  • Loading branch information
Max191 authored Jan 8, 2025
1 parent 080b090 commit ec73934
Show file tree
Hide file tree
Showing 5 changed files with 153 additions and 58 deletions.
37 changes: 34 additions & 3 deletions tuner/tuner/candidate_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,8 @@ def generate_configs_and_td_specs(
tuner_context: TunerContext,
limit: int = 4096, # Max candidates to be generated
num_subgroups: int = 4, # GPU spec, used to determine candidate generation constraints
allowed_waves_per_eu: list[int] = [2],
pipeline_options_search_space: PipelineOptionsSearchSpace = PipelineOptionsSearchSpace(),
codegen_pipeline: iree_codegen.DispatchLoweringPassPipeline = iree_codegen.DispatchLoweringPassPipeline.LLVMGPUVectorDistribute,
) -> list[ir.Module]:
dispatch_tuner_registry = DispatchTunerRegistry(check_translation_info=False)
Expand Down Expand Up @@ -223,7 +225,13 @@ def generate_configs_and_td_specs(
mma_list = iree_codegen.query_mma_intrinsics(variant_op)
for i, config in enumerate(
generate_solutions(
tuner_context, problem_size, num_subgroups, mma_list, codegen_pipeline
tuner_context,
problem_size,
num_subgroups,
mma_list,
allowed_waves_per_eu,
pipeline_options_search_space,
codegen_pipeline,
)
):
if i >= limit:
Expand Down Expand Up @@ -348,6 +356,24 @@ def main():
type=int,
default=-1,
)
parser.add_argument(
"--prefetch-shared-memory-options",
type=lambda t: [s.strip().lower() == "true" for s in t.split(",")],
default=[True],
help="Comma-separated list of allowed values for the prefetch_shared_memory pipeline option. Possible values: [True, False]",
)
parser.add_argument(
"--no-reduce-shared-memory-bank-conflicts-options",
type=lambda t: [s.strip().lower() == "true" for s in t.split(",")],
default=[None],
help="Comma-separated list of allowed values for the no_reduce_shared_memory_bank_conflicts pipeline option. Possible values: [True, False]",
)
parser.add_argument(
"--waves-per-eu-options",
type=lambda t: [int(s) for s in t.split(",")],
default=[2],
help="Comma-separated list of allowed values for the waves_per_eu config option. Possible values: Any positive integer value",
)
parser.add_argument(
"--verbose", "-v", action="store_true", help="Enable verbose output to stdout"
)
Expand All @@ -363,15 +389,20 @@ def main():
console_handler.setFormatter(formatter)
tune_logger.addHandler(console_handler)

with ir.Context() as ctx:
tuner_ctx = TunerContext(ctx, tune_logger)
with TunerContext() as tuner_ctx:
mlir_text = strip_compilation_info(args.input)
mlir_module = parse_mlir(mlir_text, tuner_ctx)
pipeline_options_search_space = PipelineOptionsSearchSpace(
prefetch_shared_memory=args.prefetch_shared_memory_options,
no_reduce_shared_memory_bank_conflicts=args.no_reduce_shared_memory_bank_conflicts_options,
)
specs = generate_configs_and_td_specs(
mlir_module,
tuner_ctx,
args.limit,
args.num_subgroups,
args.waves_per_eu_options,
pipeline_options_search_space,
iree_codegen.DispatchLoweringPassPipeline.LLVMGPUTileAndFuse,
)
for candidate_num, spec in enumerate(specs):
Expand Down
2 changes: 1 addition & 1 deletion tuner/tuner/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def get_lowering_config(
# A local variable to hold the transformed value.
promoted_value = value
match key:
case "workgroup" | "reduction" | "subgroup":
case "workgroup" | "reduction" | "subgroup" | "promote_operands":
if isinstance(value, list):
promoted_value = ir.ArrayAttr.get(
[tuner_ctx.type.getI64(x) for x in value]
Expand Down
139 changes: 93 additions & 46 deletions tuner/tuner/dispatch_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@ def generate_vector_distribute_constraints(
workgroup_size: list[z3.ArithRef],
subgroup_m_count: z3.ArithRef,
subgroup_n_count: z3.ArithRef,
waves_per_eu: z3.ArithRef,
mma_intrinsics: list[iree_gpu.MMAIntrinsic],
):
M, N, K = (
Expand Down Expand Up @@ -142,9 +141,6 @@ def generate_vector_distribute_constraints(
else:
constraints += [subgroups >= 1, subgroups <= 10]

constraints += [waves_per_eu == 2]
# constraints += [z3.Or(waves_per_eu == 2, waves_per_eu == 3, waves_per_eu == 4)]

shared_memory = calculate_shared_memory_usage_in_bytes(problem_size, [m], [n], [k])
constraints += [shared_memory <= 65536]

Expand All @@ -162,7 +158,6 @@ def generate_tile_and_fuse_constraints(
workgroup_size: list[z3.ArithRef],
subgroup_m_count: z3.ArithRef,
subgroup_n_count: z3.ArithRef,
waves_per_eu: z3.ArithRef,
mma_intrinsics: list[iree_gpu.MMAIntrinsic],
):
M, N, K = problem_size.MNK
Expand Down Expand Up @@ -230,8 +225,6 @@ def generate_tile_and_fuse_constraints(
constraints += [subgroups >= 1, subgroups <= 10]
constraints += [wg_threads == subgroups * subgroup_size]

constraints += [waves_per_eu == 2]

shared_memory = calculate_shared_memory_usage_in_bytes(
problem_size, m_tiles, n_tiles, k_tiles
)
Expand Down Expand Up @@ -269,11 +262,89 @@ def getMMAAttr(
)


def _default_unset_option() -> list[Optional[bool]]:
    """Default candidate list for a pipeline option: a single None,
    meaning the option is left unspecified in the generated attribute."""
    return [None]


@dataclass
class PipelineOptionsSearchSpace:
    """Search space of GPU pipeline options explored during tuning.

    Each field holds the list of candidate values for the corresponding
    pipeline option; the cartesian product of the fields defines every
    pipeline-options combination a candidate may use. A None entry means
    "do not set this option explicitly".
    """

    # Candidate values for the prefetch_shared_memory pipeline option.
    prefetch_shared_memory: list[Optional[bool]] = field(
        default_factory=_default_unset_option
    )
    # Candidate values for the no_reduce_shared_memory_bank_conflicts option.
    no_reduce_shared_memory_bank_conflicts: list[Optional[bool]] = field(
        default_factory=_default_unset_option
    )


def generate_allowed_pipeline_options(
    pipeline_options_search_space: PipelineOptionsSearchSpace,
) -> list[iree_gpu.PipelineOptionsAttr]:
    """Enumerate one PipelineOptionsAttr per point in the search space.

    The result is the cartesian product of the prefetch_shared_memory and
    no_reduce_shared_memory_bank_conflicts candidate lists, in the same
    nesting order as the fields (prefetch varies slowest).
    """
    space = pipeline_options_search_space
    return [
        iree_gpu.PipelineOptionsAttr.get(
            prefetch_shared_memory=prefetch,
            no_reduce_shared_memory_bank_conflicts=no_reduce_conflicts,
        )
        for prefetch in space.prefetch_shared_memory
        for no_reduce_conflicts in space.no_reduce_shared_memory_bank_conflicts
    ]


def generate_compilation_infos(
    tuner_ctx: TunerContext,
    mma_attr: iree_gpu.MMAAttr,
    workgroup_tile_sizes: list[int],
    reduction_tile_sizes: list[int],
    subgroup_tile_sizes: list[int],
    workgroup_sizes: tuple[int, int, int],
    subgroup_size: int,
    subgroup_m_count: int,
    subgroup_n_count: int,
    codegen_pipeline: iree_codegen.DispatchLoweringPassPipeline,
    pipeline_options_search_space: PipelineOptionsSearchSpace,
    allowed_waves_per_eu: list[int],
) -> list[iree_codegen.CompilationInfoAttr]:
    """Build one CompilationInfoAttr per (pipeline options, waves_per_eu) pair.

    A single LoweringConfigAttr is shared by all results; only the
    TranslationInfoAttr varies across the search space.
    """
    # Assemble the lowering-config arguments shared by every candidate.
    lowering_kwargs = dict(
        tuner_ctx=tuner_ctx,
        mma_kind=mma_attr,
        workgroup=workgroup_tile_sizes,
        reduction=reduction_tile_sizes,
        subgroup_m_count=subgroup_m_count,
        subgroup_n_count=subgroup_n_count,
    )
    is_tile_and_fuse = (
        codegen_pipeline
        == iree_codegen.DispatchLoweringPassPipeline.LLVMGPUTileAndFuse
    )
    if is_tile_and_fuse:
        lowering_kwargs["subgroup"] = subgroup_tile_sizes
        # Promoting lhs and rhs operands is typically beneficial for
        # TileAndFuse; this could become a tuned parameter later.
        lowering_kwargs["promote_operands"] = [0, 1]
    lowering_config = get_lowering_config(**lowering_kwargs)

    # The pipeline attribute and workgroup size are invariant across the loop.
    pipeline_attr = iree_codegen.DispatchLoweringPassPipelineAttr.get(
        codegen_pipeline
    )
    workgroup_size_list = list(workgroup_sizes)

    compilation_infos = []
    for options in generate_allowed_pipeline_options(pipeline_options_search_space):
        for waves in allowed_waves_per_eu:
            config_dict = get_translation_info_config(options, waves)
            translation_info = iree_codegen.TranslationInfoAttr.get(
                pipeline_attr,
                None,
                workgroup_size_list,
                subgroup_size,
                config_dict,
            )
            compilation_infos.append(
                iree_codegen.CompilationInfoAttr.get(
                    lowering_config, translation_info
                )
            )
    return compilation_infos


def generate_solutions(
tuner_ctx: TunerContext,
problem_size: ProblemSize,
num_subgrups: int,
mma_intrinsics: list[iree_gpu.MMAIntrinsic],
allowed_waves_per_eu: list[int] = [2],
pipeline_options_search_space: PipelineOptionsSearchSpace = PipelineOptionsSearchSpace(),
codegen_pipeline: iree_codegen.DispatchLoweringPassPipeline = iree_codegen.DispatchLoweringPassPipeline.LLVMGPUVectorDistribute,
) -> Iterator[iree_codegen.CompilationInfoAttr]:
M, N, K = problem_size.MNK
Expand All @@ -290,7 +361,6 @@ def generate_solutions(
wg_x, wg_y, wg_z = z3.Int("wg_x"), z3.Int("wg_y"), z3.Int("wg_z")
sg_m_cnt = z3.Int("sg_m_cnt")
sg_n_cnt = z3.Int("sg_n_cnt")
waves_per_eu = z3.Int("waves_per_eu")
all_vars = (
m_vars
+ n_vars
Expand All @@ -304,7 +374,6 @@ def generate_solutions(
wg_z,
sg_m_cnt,
sg_n_cnt,
waves_per_eu,
]
)

Expand All @@ -320,7 +389,6 @@ def generate_solutions(
[wg_x, wg_y, wg_z],
sg_m_cnt,
sg_n_cnt,
waves_per_eu,
mma_intrinsics,
)
constraints += [v == 0 for v in subgroup_m_vars + subgroup_n_vars]
Expand All @@ -334,7 +402,6 @@ def generate_solutions(
[wg_x, wg_y, wg_z],
sg_m_cnt,
sg_n_cnt,
waves_per_eu,
mma_intrinsics,
)
solver.add(z3.simplify(z3.And(constraints)))
Expand Down Expand Up @@ -407,43 +474,23 @@ def set_cdim_tile_sizes(tile_sizes, contraction_dims, csizes):
[lookup(v) for v in k_vars],
)

# Create the LoweringConfigAttr.
lowering_config_args = {
"tuner_ctx": tuner_ctx,
"mma_kind": mma_attr,
"workgroup": workgroup_tile_sizes,
"reduction": reduction_tile_sizes,
"subgroup_m_count": lookup(sg_m_cnt),
"subgroup_n_count": lookup(sg_n_cnt),
}
if (
codegen_pipeline
== iree_codegen.DispatchLoweringPassPipeline.LLVMGPUTileAndFuse
):
lowering_config_args["subgroup"] = subgroup_tile_sizes
lowering_config = get_lowering_config(**lowering_config_args)

# Create the TranslationInfoAttr
pipeline_attr = iree_codegen.DispatchLoweringPassPipelineAttr.get(
codegen_pipeline
)
pipeline_options = iree_gpu.PipelineOptionsAttr.get()
config_dict = get_translation_info_config(
pipeline_options, lookup(waves_per_eu)
)
translation_info = iree_codegen.TranslationInfoAttr.get(
pipeline_attr,
None,
[lookup(wg_x), lookup(wg_y), lookup(wg_z)],
compilation_infos = generate_compilation_infos(
tuner_ctx,
mma_attr,
workgroup_tile_sizes,
reduction_tile_sizes,
subgroup_tile_sizes,
(lookup(wg_x), lookup(wg_y), lookup(wg_z)),
lookup(subgroup_size),
config_dict,
)

# Create the CompilationInfoAttr.
compilation_info = iree_codegen.CompilationInfoAttr.get(
lowering_config, translation_info
lookup(sg_m_cnt),
lookup(sg_n_cnt),
codegen_pipeline,
pipeline_options_search_space,
allowed_waves_per_eu,
)

solver.add(z3.simplify(z3.Not(z3.And(list(x == model[x] for x in all_vars)))))
i += 1
yield compilation_info

for compilation_info in compilation_infos:
yield compilation_info
8 changes: 0 additions & 8 deletions tuner/tuner/dispatch_constraints_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,6 @@ def test_generate_tile_and_fuse_constraints_valid_input(
)
sg_m_cnt = z3.Int("sg_m_cnt")
sg_n_cnt = z3.Int("sg_n_cnt")
waves_per_eu = z3.Int("waves_per_eu")

constraints = dispatch_constraints.generate_tile_and_fuse_constraints(
problem_size,
Expand All @@ -177,7 +176,6 @@ def test_generate_tile_and_fuse_constraints_valid_input(
[wg_x, wg_y, wg_z],
sg_m_cnt,
sg_n_cnt,
waves_per_eu,
[
iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16,
iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16,
Expand Down Expand Up @@ -240,7 +238,6 @@ def test_generate_tile_and_fuse_constraints_invalid_input(
)
sg_m_cnt = z3.Int("sg_m_cnt")
sg_n_cnt = z3.Int("sg_n_cnt")
waves_per_eu = z3.Int("waves_per_eu")

constraints = dispatch_constraints.generate_tile_and_fuse_constraints(
problem_size,
Expand All @@ -251,7 +248,6 @@ def test_generate_tile_and_fuse_constraints_invalid_input(
[wg_x, wg_y, wg_z],
sg_m_cnt,
sg_n_cnt,
waves_per_eu,
[
iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16,
iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16,
Expand Down Expand Up @@ -300,7 +296,6 @@ def test_generate_vector_distribute_constraints_valid_input(
)
sg_m_cnt = z3.Int("sg_m_cnt")
sg_n_cnt = z3.Int("sg_n_cnt")
waves_per_eu = z3.Int("waves_per_eu")

constraints = dispatch_constraints.generate_vector_distribute_constraints(
problem_size,
Expand All @@ -311,7 +306,6 @@ def test_generate_vector_distribute_constraints_valid_input(
[wg_x, wg_y, wg_z],
sg_m_cnt,
sg_n_cnt,
waves_per_eu,
[
iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16,
iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16,
Expand Down Expand Up @@ -359,7 +353,6 @@ def test_generate_vector_distribute_constraints_invalid_input(
)
sg_m_cnt = z3.Int("sg_m_cnt")
sg_n_cnt = z3.Int("sg_n_cnt")
waves_per_eu = z3.Int("waves_per_eu")

constraints = dispatch_constraints.generate_vector_distribute_constraints(
problem_size,
Expand All @@ -370,7 +363,6 @@ def test_generate_vector_distribute_constraints_invalid_input(
[wg_x, wg_y, wg_z],
sg_m_cnt,
sg_n_cnt,
waves_per_eu,
[
iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16,
iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16,
Expand Down
Loading

0 comments on commit ec73934

Please sign in to comment.