
Commit ac90ffe

[tuner] add padding_conv attribute along with IGEMM support for conv
Signed-off-by: Bangtian Liu <[email protected]>
1 parent 927a2ff commit ac90ffe

6 files changed, +403 -12 lines changed

amdsharktuner/amdsharktuner/common.py

Lines changed: 143 additions & 4 deletions
@@ -18,7 +18,7 @@
 import tempfile

 from iree.compiler import ir  # type: ignore
-from iree.compiler.dialects import iree_codegen, iree_gpu, transform  # type: ignore
+from iree.compiler.dialects import iree_codegen, iree_gpu, linalg, transform  # type: ignore
 import iree.compiler as ireec  # type: ignore
 from iree.compiler._mlir_libs._mlir import ir  # type: ignore


@@ -190,6 +190,23 @@ class ContractionDimensions:
     batch: list[int] = field(default_factory=list)


+@dataclass
+class ConvToIgemmInfo:
+    """
+    Stores information about the convolution to IGEMM transformation.
+    This corresponds to the C++ ConvToIgemmInfo struct in IREE.
+
+    Note: In C++, conv_to_igemm_dim_map is a DenseMap<int64_t, AffineExpr>,
+    but in the Python bindings it is a dict[int, int] mapping a conv dim to an IGEMM position.
+    """
+
+    is_batch_dim_last: bool = False
+    is_spatial_dim_last: bool = False
+    conv_dims: Optional[linalg.ConvolutionDimensions] = None
+    conv_to_igemm_dim_map: dict[int, int] = field(default_factory=dict)
+    input_channel_dim_to_size: dict[int, int] = field(default_factory=dict)
+
+
 @dataclass
 class MatmulShapeType:
     m: int
@@ -233,6 +250,24 @@ class AttentionKnobs(KnobAssignment):
     pass


+def is_affine_expr_function_of_dim(expr: ir.AffineExpr, position: int) -> bool:
+    """
+    Return True if the expression depends on the dimension at the given position.
+    """
+    if ir.AffineDimExpr.isinstance(expr):
+        dim_expr = ir.AffineDimExpr(expr)
+        return dim_expr.position == position
+
+    # Check if it's a binary operation and recursively check both sides.
+    if ir.AffineBinaryExpr.isinstance(expr):
+        binary_expr = ir.AffineBinaryExpr(expr)
+        return is_affine_expr_function_of_dim(
+            binary_expr.lhs, position
+        ) or is_affine_expr_function_of_dim(binary_expr.rhs, position)
+
+    return False
+
+
 def get_map_result_dim_positions(map: ir.AffineMap) -> Optional[list[int]]:
     if not map.is_projected_permutation:
         return None
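
For reference, a small sanity check of the recursion above (a sketch, not part of the commit; it assumes the amdsharktuner package is importable and runs under an MLIR context): the strided access expression d1 * 2 + d4 depends on d1 and d4 but not on d2.

from iree.compiler import ir  # type: ignore

from amdsharktuner.common import is_affine_expr_function_of_dim

with ir.Context():
    d1 = ir.AffineDimExpr.get(1)
    d4 = ir.AffineDimExpr.get(4)
    two = ir.AffineConstantExpr.get(2)
    # Build d1 * 2 + d4, the kind of expression a strided conv input map contains.
    expr = ir.AffineExpr.get_add(ir.AffineExpr.get_mul(d1, two), d4)
    assert is_affine_expr_function_of_dim(expr, 1)
    assert is_affine_expr_function_of_dim(expr, 4)
    assert not is_affine_expr_function_of_dim(expr, 2)
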
@@ -281,7 +316,7 @@ def get_lowering_config(
        # A local variable to hold the transformed value.
        promoted_value = value
        match key:
-            case "workgroup" | "reduction" | "subgroup" | "promote_operands" | "padding":
+            case "workgroup" | "reduction" | "subgroup" | "promote_operands" | "padding" | "padding_conv":
                if isinstance(value, Sequence):
                    promoted_value = ir.ArrayAttr.get(
                        [tuner_ctx.type.getI64(x) for x in value]
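
The new "padding_conv" key goes through the same promotion path as "padding": a Python list of ints is wrapped into an i64 ArrayAttr on the lowering config. Below is a minimal sketch of that promotion using the MLIR Python API directly, assuming tuner_ctx.type.getI64 builds an i64 IntegerAttr (an assumption about that helper, which is not shown in this diff).

from iree.compiler import ir  # type: ignore

with ir.Context():
    padding_conv = [1, 1, 32, 128, 0, 0, 32]  # illustrative values
    i64 = ir.IntegerType.get_signless(64)
    # Mirrors [tuner_ctx.type.getI64(x) for x in value] in the case arm above.
    promoted = ir.ArrayAttr.get([ir.IntegerAttr.get(i64, x) for x in padding_conv])
    # promoted is the ArrayAttr stored under the "padding_conv" key of the lowering config.
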
@@ -562,8 +597,112 @@ def get_dim_bounds(
     return result, any_padding_applied


-# Use padding logic from IREE side:
-# https://github.com/iree-org/iree/blob/8ae91ebb0e555e660b8a6898f6071476f7a1f20b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp#L691-L703
+# Implements the padding logic from the IREE side:
+# https://github.com/iree-org/iree/blob/8ae91ebb0e555e660b8a6898f6071476f7a1f20b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp#L382-L467
+def get_padding_conv_sizes(
+    bounds: list[int],
+    padding_sizes: list[int],
+    workgroup_tile_sizes: list[int],
+    reduction_tile_sizes: list[int],
+    conv_to_igemm_info: ConvToIgemmInfo,
+) -> Optional[list[int]]:
+    """
+    Calculate padding sizes for convolution dimensions when using IGEMM.
+    This corresponds to the C++ getPaddingConvSizes function in IREE.
+
+    Args:
+        bounds: Loop bounds for each dimension.
+        padding_sizes: Padding sizes for the IGEMM dimensions.
+        workgroup_tile_sizes: Workgroup tile sizes.
+        reduction_tile_sizes: Reduction tile sizes.
+        conv_to_igemm_info: Convolution to IGEMM transformation info (must not be None).
+
+    Returns:
+        A list of padding sizes for the convolution dimensions, or None if padding should be skipped.
+        The caller should convert it to an ArrayAttr if needed.
+    """
+    # Skip padding the convolution for NCHW layouts (spatial dimension last).
+    if conv_to_igemm_info.is_spatial_dim_last:
+        return None
+
+    conv_to_igemm_map = conv_to_igemm_info.conv_to_igemm_dim_map
+    padded_igemm_dims = set()
+    conv_dims = conv_to_igemm_info.conv_dims
+
+    assert conv_dims is not None, "Expected conv_dims to be set in ConvToIgemmInfo"
+
+    input_channel_dims = set(conv_dims.input_channel)
+
+    padding_conv_sizes = [0] * len(conv_to_igemm_map)
+
+    # For batch-last layouts (e.g., CHWN), only pad the batch dimension to avoid
+    # introducing a pad op as the producer of a collapse_shape op, which may cause fusion problems.
+    if conv_to_igemm_info.is_batch_dim_last:
+        last_batch_dim = conv_dims.batch[-1]
+        # The map stores integer positions, so use them directly.
+        igemm_batch_pos = conv_to_igemm_map[last_batch_dim]
+
+        if (
+            padding_sizes[igemm_batch_pos]
+            and bounds[igemm_batch_pos] % padding_sizes[igemm_batch_pos] == 0
+        ):
+            return None
+
+        padding_conv_sizes[last_batch_dim] = padding_sizes[igemm_batch_pos]
+        return padding_conv_sizes
+
+    # Process each convolution dimension mapping.
+    for conv_dim, igemm_pos in conv_to_igemm_map.items():
+        # The map stores integer positions directly.
+
+        if reduction_tile_sizes[igemm_pos] != 0:
+            # For reduction dimensions, avoid setting padding on the convolution
+            # if the product of the corresponding conv sizes is already divisible by the padding size.
+            if (
+                padding_sizes[igemm_pos]
+                and bounds[igemm_pos] % padding_sizes[igemm_pos] == 0
+            ):
+                padded_igemm_dims.add(igemm_pos)
+                continue
+
+            # Only pad input channel dims. If we needed to pad filter dims, we
+            # would rather just do the padding on the IGEMM instead.
+            if conv_dim in input_channel_dims:
+                # Multiple input channel dims for a single IGEMM position are not supported.
+                if igemm_pos in padded_igemm_dims:
+                    return None
+
+                input_channel_size = conv_to_igemm_info.input_channel_dim_to_size.get(
+                    conv_dim, 0
+                )
+                is_input_channel_size_small = (
+                    padding_sizes[igemm_pos] // input_channel_size > 2
+                )
+
+                # If the input channel dimension is much smaller than the padding size,
+                # skip padding along that dimension while still padding the others.
+                if is_input_channel_size_small:
+                    padding_conv_sizes[conv_dim] = 0
+                else:
+                    padding_conv_sizes[conv_dim] = padding_sizes[igemm_pos]
+
+                padded_igemm_dims.add(igemm_pos)
+            continue
+
+        # Multiple padded parallel dims mapping to the same IGEMM dim is not supported.
+        if workgroup_tile_sizes[igemm_pos] != 0 and igemm_pos in padded_igemm_dims:
+            return None
+
+        padding_conv_sizes[conv_dim] = padding_sizes[igemm_pos]
+        padded_igemm_dims.add(igemm_pos)
+
+    # Ensure that all IGEMM dimensions have been padded.
+    if len(padded_igemm_dims) != len(padding_sizes):
+        return None
+
+    return padding_conv_sizes
+
+
 def calculate_padded_dimensions(
     M: list[int],
     N: list[int],
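
To illustrate the flow above, here is a hypothetical call for a stride-2 conv_2d_nhwc_hwcf with N=2, OH=OW=30, OC=256, KH=KW=3, IC=20. The dim map, bounds, and tile sizes are made-up example values rather than anything produced by IREE, and types.SimpleNamespace stands in for linalg.ConvolutionDimensions since only its batch and input_channel fields are read here.

from types import SimpleNamespace

from amdsharktuner import common  # assumes the package is importable

info = common.ConvToIgemmInfo(
    is_batch_dim_last=False,
    is_spatial_dim_last=False,
    conv_dims=SimpleNamespace(batch=[0], input_channel=[6]),
    # Conv loop dim -> IGEMM loop position: n, oh, ow, oc map one-to-one,
    # while kh, kw, ic all collapse into the single IGEMM reduction dim 4.
    conv_to_igemm_dim_map={0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 4, 6: 4},
    input_channel_dim_to_size={6: 20},
)

padding_conv = common.get_padding_conv_sizes(
    bounds=[2, 30, 30, 256, 180],       # IGEMM loop bounds; K = 3 * 3 * 20 = 180
    padding_sizes=[1, 1, 32, 128, 32],  # IGEMM padding sizes
    workgroup_tile_sizes=[1, 1, 32, 128, 0],
    reduction_tile_sizes=[0, 0, 0, 0, 32],
    conv_to_igemm_info=info,
)
# K = 180 is not divisible by 32 and IC = 20 is not much smaller than the
# padding size, so the input channel (conv dim 6) is padded while kh and kw
# stay at 0: padding_conv == [1, 1, 32, 128, 0, 0, 32]
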

amdsharktuner/amdsharktuner/constraint_generator.py

Lines changed: 18 additions & 0 deletions
@@ -82,6 +82,7 @@ def generate_generic_contraction_solutions(
     allowed_waves_per_eu: list[int] = [2],
     pipeline_options_search_space: dispatch_constraints.PipelineOptionsSearchSpace = dispatch_constraints.PipelineOptionsSearchSpace(),
     igemm_details: Optional[iree_codegen.IGEMMGenericConvDetails] = None,
+    conv_to_igemm_info: Optional[common.ConvToIgemmInfo] = None,
 ) -> Iterator[list[common.TuningConfiguration]]:
     adjust_problem_size_for_pipeline(
         contraction_dims,
@@ -258,6 +259,7 @@ def set_cdim_tile_sizes(tile_sizes, contraction_dims, csizes):

         promote_operands = [0, 1]
         padding = None
+        padding_conv = None
         if padding_applied:
             # TODO: Remove promotion of operand 2 once codegen supports handling padded outputs without promotion.
             promote_operands = [0, 1, 2]
@@ -270,6 +272,18 @@ def set_cdim_tile_sizes(tile_sizes, contraction_dims, csizes):
             padding_tile_sizes[inner_k_dim] *= mma_intrinsic_k

             padding = padding_tile_sizes
+
+            # Calculate padding_conv sizes for convolutions when using IGEMM.
+            if conv_to_igemm_info and igemm_details:
+                # Use IGEMM loop bounds directly from igemm_details.
+                bounds = list(igemm_details.igemm_loop_bounds)
+                padding_conv = common.get_padding_conv_sizes(
+                    bounds,
+                    padding_tile_sizes,
+                    workgroup_tile_sizes,
+                    reduction_tile_sizes,
+                    conv_to_igemm_info,
+                )
         # Setting subgroup basis.
         # TODO(Bangtian): Sync changes from IREE PR: https://github.com/iree-org/iree/pull/22000.
         subgroup_basis_counts = [1] * num_loops
@@ -294,6 +308,7 @@ def set_cdim_tile_sizes(tile_sizes, contraction_dims, csizes):
             pipeline_options_search_space,
             allowed_waves_per_eu,
             padding=padding,
+            padding_conv=padding_conv,
         )

         solver.add(z3.simplify(z3.Not(z3.And(list(x == model[x] for x in all_vars)))))
@@ -594,6 +609,8 @@ def generate_solutions(
         codegen_pipeline: iree_codegen.DispatchLoweringPassPipeline,
         **pipeline_constraint_options,
     ) -> Iterator[list[common.TuningConfiguration]]:
+        # TODO(Bangtian): Simplify the function signature to accept op_info directly instead of
+        # unpacking all individual fields.
         return generate_generic_contraction_solutions(
             tuner_ctx=tuner_context,
             gpu_target_info=gpu_target_info,
@@ -606,6 +623,7 @@ def generate_solutions(
             indexing_maps=self.op_info.indexing_maps,
             codegen_pipeline=codegen_pipeline,
             igemm_details=self.op_info.igemm_details,
+            conv_to_igemm_info=self.op_info.conv_to_igemm_info,
             **pipeline_constraint_options,
         )


amdsharktuner/amdsharktuner/dispatch_constraints.py

Lines changed: 4 additions & 0 deletions
@@ -672,6 +672,7 @@ def generate_compilation_infos(
     pipeline_options_search_space: PipelineOptionsSearchSpace,
     allowed_waves_per_eu: list[int],
     padding: Optional[list[int]] = None,
+    padding_conv: Optional[list[int]] = None,
 ) -> list[iree_codegen.CompilationInfoAttr]:
     subgroup_basis = [subgroup_basis_counts, subgroup_basis_mapping]
     # Create the LoweringConfigAttr.
@@ -688,6 +689,9 @@ def generate_compilation_infos(
     if padding is not None:
         lowering_config_args["padding"] = padding

+    if padding_conv is not None:
+        lowering_config_args["padding_conv"] = padding_conv
+
     if codegen_pipeline == iree_codegen.DispatchLoweringPassPipeline.LLVMGPUTileAndFuse:
         lowering_config_args["subgroup"] = subgroup_tile_sizes

amdsharktuner/amdsharktuner/dispatch_parser.py

Lines changed: 59 additions & 0 deletions
@@ -74,6 +74,8 @@ class ConvolutionOpInfo(OpInfo):

     # IGEMM details for TileAndFuse pipeline (None if not available).
     igemm_details: Optional[iree_codegen.IGEMMGenericConvDetails] = None
+    # Convolution to IGEMM transformation info (None if not available).
+    conv_to_igemm_info: Optional[common.ConvToIgemmInfo] = None


 @dataclass
@@ -275,6 +277,62 @@ def __init__(self, root_op: ir.Operation, tuner_ctx: common.TunerContext):
         # for any convolution layout (nhwc_hwcf, nchw_fchw, etc.).
         igemm_details = iree_codegen.get_igemm_generic_conv_details(root_op)

+        # Build ConvToIgemmInfo using convolution_dims (similar to the C++ implementation).
+        conv_to_igemm_info = None
+        if igemm_details:
+            conv_to_igemm_info = common.ConvToIgemmInfo()
+
+            # Get the input operand information (first operand).
+            input_type = lhs_type
+            input_shape = input_type.shape
+            input_map = indexing_maps[0]
+
+            # Store the convolution dimensions.
+            conv_to_igemm_info.conv_dims = convolution_dims
+
+            # Process input channel dimensions.
+            # Note: For convolutions with strides, expressions may be complex (e.g., d0 * 2 + d1).
+            # We use is_affine_expr_function_of_dim to check whether an expression depends on a specific dimension.
+            for dim in convolution_dims.input_channel:
+                for idx, expr in enumerate(input_map.results):
+                    if common.is_affine_expr_function_of_dim(expr, dim):
+                        conv_to_igemm_info.input_channel_dim_to_size[dim] = input_shape[
+                            idx
+                        ]
+
+            # Process output image dimensions to find input image positions.
+            input_image_pos = []
+            for dim in convolution_dims.output_image:
+                for idx, expr in enumerate(input_map.results):
+                    if common.is_affine_expr_function_of_dim(expr, dim):
+                        input_image_pos.append(idx)
+
+            # Process batch dimensions to find batch positions.
+            batch_pos = []
+            for dim in convolution_dims.batch:
+                for idx, expr in enumerate(input_map.results):
+                    if common.is_affine_expr_function_of_dim(expr, dim):
+                        batch_pos.append(idx)
+
+            # Sort positions.
+            input_image_pos = sorted(input_image_pos)
+            batch_pos = sorted(batch_pos)
+
+            # Determine whether the batch dimension is last.
+            conv_to_igemm_info.is_batch_dim_last = (
+                len(batch_pos) > 0 and batch_pos[-1] == len(input_shape) - 1
+            )
+
+            # Determine whether a spatial dimension is last.
+            conv_to_igemm_info.is_spatial_dim_last = (
+                len(input_image_pos) > 0 and input_image_pos[-1] == len(input_shape) - 1
+            )
+
+            # Store the conv to IGEMM dimension mapping from the IGEMM details.
+            conv_to_igemm_info.conv_to_igemm_dim_map = dict(
+                igemm_details.conv_to_igemm_dim_map
+            )
+
         self._op_info: ConvolutionOpInfo = ConvolutionOpInfo(
             root_op=root_op,
             indexing_maps=indexing_maps,
@@ -292,6 +350,7 @@ def __init__(self, root_op: ir.Operation, tuner_ctx: common.TunerContext):
             strides=strides,
             dilations=dilations,
             igemm_details=igemm_details,
+            conv_to_igemm_info=conv_to_igemm_info,
         )

     def has_valid_root_op(self) -> bool:
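
To make the input-channel scan above concrete, here is a small self-contained sketch (assuming the amdsharktuner package is importable; the map and shapes are illustrative) that builds the input indexing map of a stride-2 conv_2d_nhwc_hwcf by hand and recovers the input channel size the same way the constructor does.

from iree.compiler import ir  # type: ignore

from amdsharktuner import common

with ir.Context():
    d = [ir.AffineDimExpr.get(i) for i in range(7)]
    two = ir.AffineConstantExpr.get(2)
    # Input indexing map of a stride-2 conv_2d_nhwc_hwcf:
    # (n, oh, ow, oc, kh, kw, ic) -> (n, oh * 2 + kh, ow * 2 + kw, ic)
    input_map = ir.AffineMap.get(
        7,
        0,
        [
            d[0],
            ir.AffineExpr.get_add(ir.AffineExpr.get_mul(d[1], two), d[4]),
            ir.AffineExpr.get_add(ir.AffineExpr.get_mul(d[2], two), d[5]),
            d[6],
        ],
    )
    input_shape = [2, 61, 61, 20]  # N, IH, IW, IC

    # Conv dim 6 (ic) appears only in result 3 of the map, so its size is
    # taken from input_shape[3], exactly as in the loop above.
    input_channel_dim_to_size: dict[int, int] = {}
    for dim in [6]:  # convolution_dims.input_channel for this layout
        for idx, expr in enumerate(input_map.results):
            if common.is_affine_expr_function_of_dim(expr, dim):
                input_channel_dim_to_size[dim] = input_shape[idx]
    assert input_channel_dim_to_size == {6: 20}
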
