|
18 | 18 | import tempfile |
19 | 19 |
|
20 | 20 | from iree.compiler import ir # type: ignore |
21 | | -from iree.compiler.dialects import iree_codegen, iree_gpu, transform # type: ignore |
| 21 | +from iree.compiler.dialects import iree_codegen, iree_gpu, linalg, transform # type: ignore |
22 | 22 | import iree.compiler as ireec # type: ignore |
23 | 23 | from iree.compiler._mlir_libs._mlir import ir # type: ignore |
24 | 24 |
|
@@ -190,6 +190,23 @@ class ContractionDimensions: |
190 | 190 | batch: list[int] = field(default_factory=list) |
191 | 191 |
|
192 | 192 |
|
| 193 | +@dataclass |
| 194 | +class ConvToIgemmInfo: |
| 195 | + """ |
| 196 | + Stores information about the convolution-to-IGEMM transformation. |
| 197 | + Used by get_padding_conv_sizes to compute the padding_conv attribute. |
| 198 | +
|
| 199 | + Corresponds to ConvToIgemmInfo struct in IREE: |
| 200 | + https://github.com/iree-org/iree/blob/d3440737cc56a4d1b20c72181d9a37f194bd3ce5/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp#L373-L379 |
| 201 | + """ |
| 202 | + |
| 203 | + conv_dims: linalg.ConvolutionDimensions |
| 204 | + is_batch_dim_last: bool = False |
| 205 | + is_spatial_dim_last: bool = False |
| 206 | + conv_to_igemm_dim_map: dict[int, int] = field(default_factory=dict) |
| 207 | + input_channel_dim_to_size: dict[int, int] = field(default_factory=dict) |
| 208 | + |
| 209 | + |
193 | 210 | @dataclass |
194 | 211 | class MatmulShapeType: |
195 | 212 | m: int |
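A minimal construction sketch (not part of the change) may help clarify the fields. The dimension indices and the conv-to-IGEMM mapping below are hypothetical and depend on the conv op's iteration order; `types.SimpleNamespace` stands in for the `linalg.ConvolutionDimensions` value that IREE's convolution dimension analysis would normally provide.

```python
from types import SimpleNamespace

# Hypothetical NHWC conv with loop order (n, oh, ow, oc, fh, fw, ic) = d0..d6.
# SimpleNamespace is a stand-in for linalg.ConvolutionDimensions; only the
# fields read by get_padding_conv_sizes (batch, input_channel) are provided.
conv_dims = SimpleNamespace(batch=[0], input_channel=[6])

info = ConvToIgemmInfo(
    conv_dims=conv_dims,  # type: ignore[arg-type]
    is_batch_dim_last=False,   # would be True for batch-last layouts such as CHWN
    is_spatial_dim_last=False,  # would be True for NCHW-style layouts
    # Assumed conv-loop -> IGEMM-loop mapping: fh, fw, and ic collapse into K.
    conv_to_igemm_dim_map={0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 4, 6: 4},
    input_channel_dim_to_size={6: 16},  # static size of the ic dimension
)
```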
@@ -233,6 +250,24 @@ class AttentionKnobs(KnobAssignment): |
233 | 250 | pass |
234 | 251 |
|
235 | 252 |
|
| 253 | +def is_affine_expr_function_of_dim(expr: ir.AffineExpr, position: int) -> bool: |
| 254 | + """ |
| 255 | + Return True if the expression depends on the dimension at the given position. |
| 256 | + """ |
| 257 | + if ir.AffineDimExpr.isinstance(expr): |
| 258 | + dim_expr = ir.AffineDimExpr(expr) |
| 259 | + return dim_expr.position == position |
| 260 | + |
| 261 | + # Check if it's a binary operation and recursively check both sides. |
| 262 | + if ir.AffineBinaryExpr.isinstance(expr): |
| 263 | + binary_expr = ir.AffineBinaryExpr(expr) |
| 264 | + return is_affine_expr_function_of_dim( |
| 265 | + binary_expr.lhs, position |
| 266 | + ) or is_affine_expr_function_of_dim(binary_expr.rhs, position) |
| 267 | + |
| 268 | + return False |
| 269 | + |
| 270 | + |
236 | 271 | def get_map_result_dim_positions(map: ir.AffineMap) -> Optional[list[int]]: |
237 | 272 | if not map.is_projected_permutation: |
238 | 273 | return None |
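A small usage sketch (not part of the change) illustrating the recursion on a compound affine expression; it only relies on the standard `iree.compiler.ir` affine-expression API.

```python
from iree.compiler import ir  # type: ignore

with ir.Context():
    d0 = ir.AffineDimExpr.get(0)
    d2 = ir.AffineDimExpr.get(2)
    expr = d0 * 4 + d2  # a nested AffineBinaryExpr: add(mul(d0, 4), d2)
    assert is_affine_expr_function_of_dim(expr, 0)
    assert is_affine_expr_function_of_dim(expr, 2)
    assert not is_affine_expr_function_of_dim(expr, 1)  # d1 never appears
```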
@@ -281,7 +316,7 @@ def get_lowering_config( |
281 | 316 | # A local variable to hold the transformed value. |
282 | 317 | promoted_value = value |
283 | 318 | match key: |
284 | | - case "workgroup" | "reduction" | "subgroup" | "promote_operands" | "padding": |
| 319 | + case "workgroup" | "reduction" | "subgroup" | "promote_operands" | "padding" | "padding_conv": |
285 | 320 | if isinstance(value, Sequence): |
286 | 321 | promoted_value = ir.ArrayAttr.get( |
287 | 322 | [tuner_ctx.type.getI64(x) for x in value] |
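With this change, `padding_conv` is promoted to an `ArrayAttr` the same way as `padding`. A hypothetical call, assuming `get_lowering_config` takes the config entries as keyword arguments (as the key/value matching above suggests) and that a `tuner_ctx` is already in scope:

```python
lowering_config = get_lowering_config(
    tuner_ctx=tuner_ctx,  # assumed existing TunerContext
    workgroup=[1, 1, 32, 64, 0],
    reduction=[0, 0, 0, 0, 16],
    promote_operands=[0, 1],
    padding=[1, 1, 32, 64, 16],            # IGEMM-space padding
    padding_conv=[1, 1, 32, 64, 0, 0, 0],  # conv-space padding (new key)
)
```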
@@ -565,8 +600,103 @@ def get_dim_bounds( |
565 | 600 | return result |
566 | 601 |
|
567 | 602 |
|
568 | | -# Use padding logic from IREE side: |
569 | | -# https://github.com/iree-org/iree/blob/8ae91ebb0e555e660b8a6898f6071476f7a1f20b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp#L691-L703 |
| 603 | +# Implements the corresponding logic from the IREE side: |
| 604 | +# https://github.com/iree-org/iree/blob/8ae91ebb0e555e660b8a6898f6071476f7a1f20b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp#L382-L467 |
| 605 | +def get_padding_conv_sizes( |
| 606 | + bounds: list[int], |
| 607 | + padding_sizes: list[int], |
| 608 | + workgroup_tile_sizes: list[int], |
| 609 | + reduction_tile_sizes: list[int], |
| 610 | + conv_to_igemm_info: ConvToIgemmInfo, |
| 611 | +) -> Optional[list[int]]: |
| 612 | + """ |
| 613 | + Computes padding_conv by mapping padding from IGEMM space to convolution space. |
| 614 | +
|
| 615 | + Args: |
| 616 | + bounds: Loop bounds for each IGEMM dimension. |
| 617 | + padding_sizes: Padding sizes in IGEMM dimension space (M, N, K). |
| 618 | + workgroup_tile_sizes: Workgroup tile sizes in IGEMM dimension space. |
| 619 | + reduction_tile_sizes: Reduction tile sizes in IGEMM dimension space. |
| 620 | + conv_to_igemm_info: Convolution-to-IGEMM transformation info. |
| 621 | +
|
| 622 | + Returns: |
| 623 | + Padding sizes in convolution dimension space, or None if padding along |
| 624 | + the original convolution dimensions is unnecessary or unsupported. |
| 625 | + """ |
| 626 | + # Skip padding convolution for NCHW layout (spatial dimensions are last). |
| 627 | + if conv_to_igemm_info.is_spatial_dim_last: |
| 628 | + return None |
| 629 | + |
| 630 | + conv_to_igemm_map = conv_to_igemm_info.conv_to_igemm_dim_map |
| 631 | + padded_igemm_dims = set() |
| 632 | + conv_dims = conv_to_igemm_info.conv_dims |
| 633 | + input_channel_dims = set(conv_dims.input_channel) |
| 634 | + |
| 635 | + padding_conv_sizes = [0] * len(conv_to_igemm_map) |
| 636 | + |
| 637 | + # For batch-last layout (e.g., CHWN), only pad the batch dimension to avoid |
| 638 | + # introducing a pad op as the producer of a collapse_shape op, which may cause fusion problems. |
| 639 | + if conv_to_igemm_info.is_batch_dim_last: |
| 640 | + last_batch_dim = conv_dims.batch[-1] |
| 641 | + igemm_batch_pos = conv_to_igemm_map[last_batch_dim] |
| 642 | + |
| 643 | + if ( |
| 644 | + padding_sizes[igemm_batch_pos] |
| 645 | + and bounds[igemm_batch_pos] % padding_sizes[igemm_batch_pos] == 0 |
| 646 | + ): |
| 647 | + return None |
| 648 | + |
| 649 | + padding_conv_sizes[last_batch_dim] = padding_sizes[igemm_batch_pos] |
| 650 | + return padding_conv_sizes |
| 651 | + |
| 652 | + for conv_dim, igemm_pos in conv_to_igemm_map.items(): |
| 653 | + if reduction_tile_sizes[igemm_pos] != 0: |
| 654 | + # Skip conv padding for reduction dims whose bounds are already divisible by the padding size. |
| 655 | + if ( |
| 656 | + padding_sizes[igemm_pos] |
| 657 | + and bounds[igemm_pos] % padding_sizes[igemm_pos] == 0 |
| 658 | + ): |
| 659 | + padded_igemm_dims.add(igemm_pos) |
| 660 | + continue |
| 661 | + |
| 662 | + # Only pad input channel dims. If we need to pad filter dims, then we |
| 663 | + # would rather just do padding on the IGEMM instead. |
| 664 | + if conv_dim in input_channel_dims: |
| 666 | + # Multiple input channel dims mapping to a single IGEMM position are not supported. |
| 666 | + if igemm_pos in padded_igemm_dims: |
| 667 | + return None |
| 668 | + |
| 669 | + input_channel_size = conv_to_igemm_info.input_channel_dim_to_size.get( |
| 670 | + conv_dim, 0 |
| 671 | + ) |
| 672 | + is_input_channel_size_small = ( |
| 673 | + padding_sizes[igemm_pos] // input_channel_size > 2 |
| 674 | + ) |
| 675 | + |
| 676 | + # If the input channel dimension is much smaller than the padding size, |
| 677 | + # skip padding along that dimension while still padding the others. |
| 678 | + if is_input_channel_size_small: |
| 679 | + padding_conv_sizes[conv_dim] = 0 |
| 680 | + else: |
| 681 | + padding_conv_sizes[conv_dim] = padding_sizes[igemm_pos] |
| 682 | + |
| 683 | + padded_igemm_dims.add(igemm_pos) |
| 684 | + continue |
| 685 | + |
| 687 | + # Multiple padded parallel dims mapping to the same IGEMM dim are not supported. |
| 687 | + if workgroup_tile_sizes[igemm_pos] != 0 and igemm_pos in padded_igemm_dims: |
| 688 | + return None |
| 689 | + |
| 690 | + padding_conv_sizes[conv_dim] = padding_sizes[igemm_pos] |
| 691 | + padded_igemm_dims.add(igemm_pos) |
| 692 | + |
| 693 | + # Ensure that all dimensions have been padded. |
| 694 | + if len(padded_igemm_dims) != len(padding_sizes): |
| 695 | + return None |
| 696 | + |
| 697 | + return padding_conv_sizes |
| 698 | + |
| 699 | + |
570 | 700 | def calculate_padded_dimensions( |
571 | 701 | M: list[int], |
572 | 702 | N: list[int], |
|
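A worked example (hypothetical shapes and dimension mapping, with `types.SimpleNamespace` standing in for `linalg.ConvolutionDimensions`) showing how conv-space padding is derived for an NHWC conv whose IGEMM reduction dim is already aligned to the padding size:

```python
from types import SimpleNamespace

# Hypothetical NHWC conv: loops (n, oh, ow, oc, fh, fw, ic) = d0..d6 map to
# IGEMM loops (n, oh, ow, oc, k), with fh, fw, and ic collapsing into k.
info = ConvToIgemmInfo(
    conv_dims=SimpleNamespace(batch=[0], input_channel=[6]),  # stand-in object
    conv_to_igemm_dim_map={0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 4, 6: 4},
    input_channel_dim_to_size={6: 16},
)

sizes = get_padding_conv_sizes(
    bounds=[2, 30, 30, 128, 144],           # IGEMM-space bounds; k = 3 * 3 * 16
    padding_sizes=[1, 1, 32, 64, 16],       # IGEMM-space padding
    workgroup_tile_sizes=[1, 1, 32, 64, 0],
    reduction_tile_sizes=[0, 0, 0, 0, 16],
    conv_to_igemm_info=info,
)
# k (144) is divisible by its padding size (16), so no conv-space padding is
# applied to fh/fw/ic; the parallel conv dims inherit the IGEMM padding.
assert sizes == [1, 1, 32, 64, 0, 0, 0]
```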