
Commit 9f56535

committed
[tuner] add padding_conv attribute along with IGEMM support for conv
Signed-off-by: Bangtian Liu <[email protected]>
1 parent 715b6a6 commit 9f56535

File tree

7 files changed: +446 −4 lines changed

amdsharktuner/amdsharktuner/common.py

Lines changed: 134 additions & 4 deletions

@@ -18,7 +18,7 @@
 import tempfile

 from iree.compiler import ir  # type: ignore
-from iree.compiler.dialects import iree_codegen, iree_gpu, transform  # type: ignore
+from iree.compiler.dialects import iree_codegen, iree_gpu, linalg, transform  # type: ignore
 import iree.compiler as ireec  # type: ignore
 from iree.compiler._mlir_libs._mlir import ir  # type: ignore

@@ -190,6 +190,23 @@ class ContractionDimensions:
     batch: list[int] = field(default_factory=list)


+@dataclass
+class ConvToIgemmInfo:
+    """
+    Stores information about the convolution-to-IGEMM transformation.
+    Used by get_padding_conv_sizes to calculate the padding_conv attribute.
+
+    Corresponds to the ConvToIgemmInfo struct in IREE:
+    https://github.com/iree-org/iree/blob/d3440737cc56a4d1b20c72181d9a37f194bd3ce5/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp#L373-L379
+    """
+
+    conv_dims: linalg.ConvolutionDimensions
+    is_batch_dim_last: bool = False
+    is_spatial_dim_last: bool = False
+    conv_to_igemm_dim_map: dict[int, int] = field(default_factory=dict)
+    input_channel_dim_to_size: dict[int, int] = field(default_factory=dict)
+
+
 @dataclass
 class MatmulShapeType:
     m: int
@@ -233,6 +250,24 @@ class AttentionKnobs(KnobAssignment):
     pass


+def is_affine_expr_function_of_dim(expr: ir.AffineExpr, position: int) -> bool:
+    """
+    Return True if the expression depends on the dimension at the given position.
+    """
+    if ir.AffineDimExpr.isinstance(expr):
+        dim_expr = ir.AffineDimExpr(expr)
+        return dim_expr.position == position
+
+    # Check if it's a binary operation and recursively check both sides.
+    if ir.AffineBinaryExpr.isinstance(expr):
+        binary_expr = ir.AffineBinaryExpr(expr)
+        return is_affine_expr_function_of_dim(
+            binary_expr.lhs, position
+        ) or is_affine_expr_function_of_dim(binary_expr.rhs, position)
+
+    return False
+
+
 def get_map_result_dim_positions(map: ir.AffineMap) -> Optional[list[int]]:
     if not map.is_projected_permutation:
         return None
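
A minimal usage sketch of the new helper (illustrative only, not part of the commit): it recurses through AffineBinaryExpr nodes until it reaches an AffineDimExpr with the matching position.

    from iree.compiler import ir  # type: ignore

    with ir.Context():
        d0 = ir.AffineExpr.get_dim(0)
        d1 = ir.AffineExpr.get_dim(1)
        expr = d0 * 2 + d1  # binary tree: add(mul(d0, 2), d1)
        assert is_affine_expr_function_of_dim(expr, 0)      # found in the mul subtree
        assert is_affine_expr_function_of_dim(expr, 1)      # found on the rhs
        assert not is_affine_expr_function_of_dim(expr, 2)  # d2 never appears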
@@ -281,7 +316,7 @@ def get_lowering_config(
         # A local variable to hold the transformed value.
         promoted_value = value
         match key:
-            case "workgroup" | "reduction" | "subgroup" | "promote_operands" | "padding":
+            case "workgroup" | "reduction" | "subgroup" | "promote_operands" | "padding" | "padding_conv":
                 if isinstance(value, Sequence):
                     promoted_value = ir.ArrayAttr.get(
                         [tuner_ctx.type.getI64(x) for x in value]
@@ -565,8 +600,103 @@ def get_dim_bounds(
     return result


-# Use padding logic from IREE side:
-# https://github.com/iree-org/iree/blob/8ae91ebb0e555e660b8a6898f6071476f7a1f20b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp#L691-L703
+# Implements the logic from the IREE side:
+# https://github.com/iree-org/iree/blob/8ae91ebb0e555e660b8a6898f6071476f7a1f20b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp#L382-L467
+def get_padding_conv_sizes(
+    bounds: list[int],
+    padding_sizes: list[int],
+    workgroup_tile_sizes: list[int],
+    reduction_tile_sizes: list[int],
+    conv_to_igemm_info: ConvToIgemmInfo,
+) -> Optional[list[int]]:
+    """
+    Computes padding_conv by mapping padding from IGEMM space to convolution space.
+
+    Args:
+        bounds: Loop bounds for each dimension.
+        padding_sizes: Padding sizes in IGEMM dimension space (M, N, K).
+        workgroup_tile_sizes: Workgroup tile sizes.
+        reduction_tile_sizes: Reduction tile sizes.
+        conv_to_igemm_info: Convolution-to-IGEMM transformation info.
+
+    Returns:
+        Padding sizes in convolution dimension space, or None if no padding
+        is needed along the original convolution dimensions.
+    """
+    # Skip padding the convolution for NCHW layouts (spatial dimensions are last).
+    if conv_to_igemm_info.is_spatial_dim_last:
+        return None
+
+    conv_to_igemm_map = conv_to_igemm_info.conv_to_igemm_dim_map
+    padded_igemm_dims = set()
+    conv_dims = conv_to_igemm_info.conv_dims
+    input_channel_dims = set(conv_dims.input_channel)
+
+    padding_conv_sizes = [0] * len(conv_to_igemm_map)
+
+    # For a batch-last layout (e.g., CHWN), only pad the batch dimension, to avoid
+    # introducing a pad op as the producer of a collapse_shape op, which may cause
+    # fusion problems.
+    if conv_to_igemm_info.is_batch_dim_last:
+        last_batch_dim = conv_dims.batch[-1]
+        igemm_batch_pos = conv_to_igemm_map[last_batch_dim]
+
+        if (
+            padding_sizes[igemm_batch_pos]
+            and bounds[igemm_batch_pos] % padding_sizes[igemm_batch_pos] == 0
+        ):
+            return None
+
+        padding_conv_sizes[last_batch_dim] = padding_sizes[igemm_batch_pos]
+        return padding_conv_sizes
+
+    for conv_dim, igemm_pos in conv_to_igemm_map.items():
+        if reduction_tile_sizes[igemm_pos] != 0:
+            # Skip conv padding for reduction dims already divisible by the padding size.
+            if (
+                padding_sizes[igemm_pos]
+                and bounds[igemm_pos] % padding_sizes[igemm_pos] == 0
+            ):
+                padded_igemm_dims.add(igemm_pos)
+                continue
+
+            # Only pad input channel dims. If we need to pad filter dims, we would
+            # rather just do the padding on the IGEMM instead.
+            if conv_dim in input_channel_dims:
+                # Multiple input channel dims for a single IGEMM position are not supported.
+                if igemm_pos in padded_igemm_dims:
+                    return None
+
+                input_channel_size = conv_to_igemm_info.input_channel_dim_to_size.get(
+                    conv_dim, 0
+                )
+                is_input_channel_size_small = (
+                    padding_sizes[igemm_pos] // input_channel_size > 2
+                )
+
+                # If the input channel dimension is much smaller than the padding size,
+                # skip padding along that dimension while still padding the others.
+                if is_input_channel_size_small:
+                    padding_conv_sizes[conv_dim] = 0
+                else:
+                    padding_conv_sizes[conv_dim] = padding_sizes[igemm_pos]
+
+                padded_igemm_dims.add(igemm_pos)
+                continue
+
+        # Multiple padded parallel dims mapping to the same IGEMM dim are not supported.
+        if workgroup_tile_sizes[igemm_pos] != 0 and igemm_pos in padded_igemm_dims:
+            return None
+
+        padding_conv_sizes[conv_dim] = padding_sizes[igemm_pos]
+        padded_igemm_dims.add(igemm_pos)
+
+    # Ensure that all dimensions have been padded.
+    if len(padded_igemm_dims) != len(padding_sizes):
+        return None
+
+    return padding_conv_sizes
+
+
 def calculate_padded_dimensions(
     M: list[int],
     N: list[int],
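
To make the mapping concrete, here is a hypothetical walk-through for a conv_2d_nhwc_hwcf whose IGEMM loops are (n, oh, ow, oc, k), with k collapsing (fh, fw, ic). All values are illustrative, and conv_dims stands in for the op's linalg.ConvolutionDimensions (batch=[0], output_image=[1, 2], output_channel=[3], filter_loop=[4, 5], input_channel=[6]):

    info = ConvToIgemmInfo(
        conv_dims=conv_dims,
        conv_to_igemm_dim_map={0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 4, 6: 4},
    )
    sizes = get_padding_conv_sizes(
        bounds=[2, 32, 32, 64, 288],       # IGEMM loop bounds, k = 3 * 3 * 32
        padding_sizes=[1, 1, 32, 64, 32],  # padding in IGEMM (M, M, M, N, K) space
        workgroup_tile_sizes=[1, 1, 32, 64, 0],
        reduction_tile_sizes=[0, 0, 0, 0, 32],
        conv_to_igemm_info=info,
    )
    # The K bound is divisible by its padding size (288 % 32 == 0), so the filter
    # and input-channel dims need no padding in convolution space:
    assert sizes == [1, 1, 32, 64, 0, 0, 0]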

amdsharktuner/amdsharktuner/constraint_generator.py

Lines changed: 18 additions & 0 deletions

@@ -82,6 +82,7 @@ def generate_generic_contraction_solutions(
     allowed_waves_per_eu: list[int] = [2],
     pipeline_options_search_space: dispatch_constraints.PipelineOptionsSearchSpace = dispatch_constraints.PipelineOptionsSearchSpace(),
     igemm_details: Optional[iree_codegen.IGEMMGenericConvDetails] = None,
+    conv_to_igemm_info: Optional[common.ConvToIgemmInfo] = None,
 ) -> Iterator[list[common.TuningConfiguration]]:
     adjust_problem_size_for_pipeline(
         contraction_dims,
@@ -259,6 +260,7 @@ def set_cdim_tile_sizes(tile_sizes, contraction_dims, csizes):

         promote_operands = [0, 1]
         padding = None
+        padding_conv = None
         if padding_applied:
             # TODO: Remove promotion of operand 2 once codegen supports handling padded outputs without promotion.
             promote_operands = [0, 1, 2]
@@ -271,6 +273,18 @@ def set_cdim_tile_sizes(tile_sizes, contraction_dims, csizes):
                 padding_tile_sizes[inner_k_dim] *= mma_intrinsic_k

             padding = padding_tile_sizes
+
+            # Calculate padding_conv sizes for convolutions when using IGEMM.
+            if conv_to_igemm_info and igemm_details:
+                # Use IGEMM loop bounds directly from igemm_details.
+                bounds = list(igemm_details.igemm_loop_bounds)
+                padding_conv = common.get_padding_conv_sizes(
+                    bounds,
+                    padding_tile_sizes,
+                    workgroup_tile_sizes,
+                    reduction_tile_sizes,
+                    conv_to_igemm_info,
+                )
         # Setting subgroup basis.
         # TODO(Bangtian): Sync changes from IREE PR: https://github.com/iree-org/iree/pull/22000.
@@ -295,6 +309,7 @@ def set_cdim_tile_sizes(tile_sizes, contraction_dims, csizes):
             pipeline_options_search_space,
             allowed_waves_per_eu,
             padding=padding,
+            padding_conv=padding_conv,
         )

         solver.add(z3.simplify(z3.Not(z3.And(list(x == model[x] for x in all_vars)))))
@@ -595,6 +610,8 @@ def generate_solutions(
         codegen_pipeline: iree_codegen.DispatchLoweringPassPipeline,
         **pipeline_constraint_options,
     ) -> Iterator[list[common.TuningConfiguration]]:
+        # TODO(Bangtian): Simplify the function signature to accept op_info directly instead of
+        # unpacking all individual fields.
         return generate_generic_contraction_solutions(
             tuner_ctx=tuner_context,
             gpu_target_info=gpu_target_info,
@@ -607,6 +624,7 @@ def generate_solutions(
             indexing_maps=self.op_info.indexing_maps,
             codegen_pipeline=codegen_pipeline,
             igemm_details=self.op_info.igemm_details,
+            conv_to_igemm_info=self.op_info.conv_to_igemm_info,
             **pipeline_constraint_options,
         )
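
A small numeric sketch of the K padding that feeds get_padding_conv_sizes (numbers are illustrative, not from the commit): the inner K entry of padding_tile_sizes is the reduction tile size scaled by the MMA intrinsic's K, and the scaled vector is then re-expressed in convolution space as shown in common.py above.

    padding_tile_sizes = [1, 1, 32, 64, 2]  # copy of the tile sizes, K innermost
    inner_k_dim = 4
    mma_intrinsic_k = 16                    # K extent of the chosen MMA intrinsic
    padding_tile_sizes[inner_k_dim] *= mma_intrinsic_k
    assert padding_tile_sizes == [1, 1, 32, 64, 32]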

amdsharktuner/amdsharktuner/dispatch_constraints.py

Lines changed: 4 additions & 0 deletions

@@ -672,6 +672,7 @@ def generate_compilation_infos(
     pipeline_options_search_space: PipelineOptionsSearchSpace,
     allowed_waves_per_eu: list[int],
     padding: Optional[list[int]] = None,
+    padding_conv: Optional[list[int]] = None,
 ) -> list[iree_codegen.CompilationInfoAttr]:
     subgroup_basis = [subgroup_basis_counts, subgroup_basis_mapping]
     # Create the LoweringConfigAttr.
@@ -688,6 +689,9 @@ def generate_compilation_infos(
     if padding is not None:
         lowering_config_args["padding"] = padding

+    if padding_conv is not None:
+        lowering_config_args["padding_conv"] = padding_conv
+
     if codegen_pipeline == iree_codegen.DispatchLoweringPassPipeline.LLVMGPUTileAndFuse:
         lowering_config_args["subgroup"] = subgroup_tile_sizes
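
Sketch of the resulting argument dict (hypothetical values): padding_conv simply rides along with padding, and common.get_lowering_config promotes both lists to i64 ArrayAttrs via the "padding_conv" match case added in common.py above.

    lowering_config_args = {
        "workgroup": [1, 1, 32, 64, 0],
        "reduction": [0, 0, 0, 0, 32],
        "padding": [1, 1, 32, 64, 32],
    }
    padding_conv = [1, 1, 32, 64, 0, 0, 0]
    if padding_conv is not None:
        lowering_config_args["padding_conv"] = padding_conv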

amdsharktuner/amdsharktuner/dispatch_parser.py

Lines changed: 59 additions & 0 deletions

@@ -40,6 +40,55 @@ def parse_mlir(mlir_text: str, ctx: common.TunerContext) -> ir.Module:
     return mlir_module


+def build_conv_to_igemm_info(
+    convolution_dims: linalg.ConvolutionDimensions,
+    input_type: ir.Type,
+    input_map: ir.AffineMap,
+    igemm_details: iree_codegen.IGEMMGenericConvDetails,
+) -> common.ConvToIgemmInfo:
+    """
+    Builds a ConvToIgemmInfo from convolution dimensions and IGEMM details.
+
+    Corresponds to IREE:
+    https://github.com/iree-org/iree/blob/d3440737cc56a4d1b20c72181d9a37f194bd3ce5/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp#L872-L909
+    """
+    input_shape = input_type.shape
+    conv_to_igemm_info = common.ConvToIgemmInfo(conv_dims=convolution_dims)
+
+    # Map input channel dimensions to their sizes in the input tensor.
+    for dim in convolution_dims.input_channel:
+        for idx, expr in enumerate(input_map.results):
+            if common.is_affine_expr_function_of_dim(expr, dim):
+                conv_to_igemm_info.input_channel_dim_to_size[dim] = input_shape[idx]
+
+    # Process output image dimensions to find input image positions.
+    input_image_pos = []
+    for dim in convolution_dims.output_image:
+        for idx, expr in enumerate(input_map.results):
+            if common.is_affine_expr_function_of_dim(expr, dim):
+                input_image_pos.append(idx)
+
+    # Process batch dimensions to find batch positions.
+    batch_pos = []
+    for dim in convolution_dims.batch:
+        for idx, expr in enumerate(input_map.results):
+            if common.is_affine_expr_function_of_dim(expr, dim):
+                batch_pos.append(idx)
+
+    input_image_pos = sorted(input_image_pos)
+    batch_pos = sorted(batch_pos)
+
+    conv_to_igemm_info.is_batch_dim_last = (
+        len(batch_pos) > 0 and batch_pos[-1] == len(input_shape) - 1
+    )
+    conv_to_igemm_info.is_spatial_dim_last = (
+        len(input_image_pos) > 0 and input_image_pos[-1] == len(input_shape) - 1
+    )
+
+    conv_to_igemm_info.conv_to_igemm_dim_map = dict(igemm_details.conv_to_igemm_dim_map)
+    return conv_to_igemm_info
+
+
 @dataclass
 class OpInfo:
     root_op: ir.Operation
@@ -74,6 +123,8 @@ class ConvolutionOpInfo(OpInfo):

     # IGEMM details for TileAndFuse pipeline (None if not available).
     igemm_details: Optional[iree_codegen.IGEMMGenericConvDetails] = None
+    # Convolution to IGEMM transformation info (None if not available).
+    conv_to_igemm_info: Optional[common.ConvToIgemmInfo] = None


 @dataclass
@@ -275,6 +326,13 @@ def __init__(self, root_op: ir.Operation, tuner_ctx: common.TunerContext):
         # for any convolution layout (nhwc_hwcf, nchw_fchw, etc.).
         igemm_details = iree_codegen.get_igemm_generic_conv_details(root_op)

+        # Build ConvToIgemmInfo using convolution_dims.
+        conv_to_igemm_info = None
+        if igemm_details:
+            conv_to_igemm_info = build_conv_to_igemm_info(
+                convolution_dims, lhs_type, indexing_maps[0], igemm_details
+            )
+
         self._op_info: ConvolutionOpInfo = ConvolutionOpInfo(
             root_op=root_op,
             indexing_maps=indexing_maps,
@@ -292,6 +350,7 @@ def __init__(self, root_op: ir.Operation, tuner_ctx: common.TunerContext):
             strides=strides,
             dilations=dilations,
             igemm_details=igemm_details,
+            conv_to_igemm_info=conv_to_igemm_info,
         )

     def has_valid_root_op(self) -> bool:
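
An illustrative check of the layout detection (hypothetical op, assuming the imports already used in dispatch_parser.py): for a conv_2d_nchw_fchw with loop order (n, oc, oh, ow, ic, fh, fw), the NCHW input map places the output-image loops on the last input axis, so is_spatial_dim_last becomes True and get_padding_conv_sizes bails out with None.

    with ir.Context():
        d = [ir.AffineExpr.get_dim(i) for i in range(7)]
        # Input (NCHW) indexing map, strides and dilations omitted:
        input_map = ir.AffineMap.get(7, 0, [d[0], d[4], d[2] + d[5], d[3] + d[6]])
        # Output-image loops d2/d3 both feed the last input axis (index 3):
        assert common.is_affine_expr_function_of_dim(input_map.results[3], 3)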
