Skip to content

Commit a472147

Browse files
[mlir][affine]make affine-loop-unroll a FunctionOpInterface pass. (#126475)
[mlir][affine] Make affine-loop-unroll a FunctionOpInterface pass. Make `affine-loop-unroll` a `FunctionOpInterface` pass. Now unrolling can also be done on `gpu.func`.
1 parent 9cc8442 commit a472147

File tree

5 files changed

+35
-15
lines changed

5 files changed

+35
-15
lines changed

mlir/include/mlir/Dialect/Affine/Passes.h

+2-1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#ifndef MLIR_DIALECT_AFFINE_PASSES_H
1515
#define MLIR_DIALECT_AFFINE_PASSES_H
1616

17+
#include "mlir/Interfaces/FunctionInterfaces.h"
1718
#include "mlir/Pass/Pass.h"
1819
#include <limits>
1920

@@ -93,7 +94,7 @@ std::unique_ptr<OperationPass<func::FuncOp>> createLoopTilingPass();
9394
/// factors supplied through other means. If -1 is passed as the unrollFactor
9495
/// and no callback is provided, anything passed from the command-line (if at
9596
/// all) or the default unroll factor is used (LoopUnroll:kDefaultUnrollFactor).
96-
std::unique_ptr<OperationPass<func::FuncOp>> createLoopUnrollPass(
97+
std::unique_ptr<InterfacePass<FunctionOpInterface>> createLoopUnrollPass(
9798
int unrollFactor = -1, bool unrollUpToFactor = false,
9899
bool unrollFull = false,
99100
const std::function<unsigned(AffineForOp)> &getUnrollFactor = nullptr);

mlir/include/mlir/Dialect/Affine/Passes.td

+1-1
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ def AffineLoopTiling : Pass<"affine-loop-tile", "func::FuncOp"> {
199199
];
200200
}
201201

202-
def AffineLoopUnroll : Pass<"affine-loop-unroll", "func::FuncOp"> {
202+
def AffineLoopUnroll : InterfacePass<"affine-loop-unroll", "FunctionOpInterface"> {
203203
let summary = "Unroll affine loops";
204204
let constructor = "mlir::affine::createLoopUnrollPass()";
205205
let options = [

mlir/lib/Dialect/Affine/Transforms/LoopUnroll.cpp

+6-5
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ static bool isInnermostAffineForOp(AffineForOp op) {
8282
}
8383

8484
/// Gathers loops that have no affine.for's nested within.
85-
static void gatherInnermostLoops(func::FuncOp f,
85+
static void gatherInnermostLoops(FunctionOpInterface f,
8686
SmallVectorImpl<AffineForOp> &loops) {
8787
f.walk([&](AffineForOp forOp) {
8888
if (isInnermostAffineForOp(forOp))
@@ -91,7 +91,7 @@ static void gatherInnermostLoops(func::FuncOp f,
9191
}
9292

9393
void LoopUnroll::runOnOperation() {
94-
func::FuncOp func = getOperation();
94+
FunctionOpInterface func = getOperation();
9595
if (func.isExternal())
9696
return;
9797

@@ -100,8 +100,8 @@ void LoopUnroll::runOnOperation() {
100100
SmallVector<AffineForOp, 4> loops;
101101

102102
// Gathers all loops with trip count <= unrollFullThreshold. Do a post order walk
103-
// so that loops are gathered from innermost to outermost (or else unrolling
104-
// an outer one may delete gathered inner ones).
103+
// so that loops are gathered from innermost to outermost (or else
104+
// unrolling an outer one may delete gathered inner ones).
105105
getOperation().walk([&](AffineForOp forOp) {
106106
std::optional<uint64_t> tripCount = getConstantTripCount(forOp);
107107
if (tripCount && *tripCount <= unrollFullThreshold)
@@ -145,7 +145,8 @@ LogicalResult LoopUnroll::runOnAffineForOp(AffineForOp forOp) {
145145
cleanUpUnroll);
146146
}
147147

148-
std::unique_ptr<OperationPass<func::FuncOp>> mlir::affine::createLoopUnrollPass(
148+
std::unique_ptr<InterfacePass<FunctionOpInterface>>
149+
mlir::affine::createLoopUnrollPass(
149150
int unrollFactor, bool unrollUpToFactor, bool unrollFull,
150151
const std::function<unsigned(AffineForOp)> &getUnrollFactor) {
151152
return std::make_unique<LoopUnroll>(

mlir/test/Dialect/Affine/unroll.mlir

+23-5
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
// RUN: mlir-opt -allow-unregistered-dialect %s -affine-loop-unroll="unroll-full" | FileCheck %s --check-prefix UNROLL-FULL
2-
// RUN: mlir-opt -allow-unregistered-dialect %s -affine-loop-unroll="unroll-full unroll-full-threshold=2" | FileCheck %s --check-prefix SHORT
3-
// RUN: mlir-opt -allow-unregistered-dialect %s -affine-loop-unroll="unroll-factor=4" | FileCheck %s --check-prefix UNROLL-BY-4
4-
// RUN: mlir-opt -allow-unregistered-dialect %s -affine-loop-unroll="unroll-factor=1" | FileCheck %s --check-prefix UNROLL-BY-1
5-
// RUN: mlir-opt -allow-unregistered-dialect %s -affine-loop-unroll="unroll-factor=5 cleanup-unroll=true" | FileCheck %s --check-prefix UNROLL-CLEANUP-LOOP
1+
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline="builtin.module(func.func(affine-loop-unroll{unroll-full=true}))" | FileCheck %s --check-prefix UNROLL-FULL
2+
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline="builtin.module(func.func(affine-loop-unroll{unroll-full=true unroll-full-threshold=2}))" | FileCheck %s --check-prefix SHORT
3+
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline="builtin.module(func.func(affine-loop-unroll{unroll-factor=4}))" | FileCheck %s --check-prefix UNROLL-BY-4
4+
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline="builtin.module(func.func(affine-loop-unroll{unroll-factor=1}))" | FileCheck %s --check-prefix UNROLL-BY-1
5+
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline="builtin.module(func.func(affine-loop-unroll{unroll-factor=5 cleanup-unroll=true}))" | FileCheck %s --check-prefix UNROLL-CLEANUP-LOOP
6+
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline="builtin.module(gpu.module(gpu.func(affine-loop-unroll{unroll-full=true})))" | FileCheck %s --check-prefix GPU-UNROLL-FULL
67

78
// UNROLL-FULL-DAG: [[$MAP0:#map[0-9]*]] = affine_map<(d0) -> (d0 + 1)>
89
// UNROLL-FULL-DAG: [[$MAP1:#map[0-9]*]] = affine_map<(d0) -> (d0 + 2)>
@@ -240,6 +241,23 @@ func.func @loop_nest_unroll_full() {
240241
return
241242
} // UNROLL-FULL }
242243

244+
gpu.module @unroll_full {
245+
// GPU-UNROLL-FULL-LABEL: func @gpu_loop_nest_simplest() {
246+
gpu.func @gpu_loop_nest_simplest() {
247+
// GPU-UNROLL-FULL: affine.for %arg0 = 0 to 100 step 2 {
248+
affine.for %i = 0 to 100 step 2 {
249+
// GPU-UNROLL-FULL: %c1_i32 = arith.constant 1 : i32
250+
// GPU-UNROLL-FULL-NEXT: %c1_i32_0 = arith.constant 1 : i32
251+
// GPU-UNROLL-FULL-NEXT: %c1_i32_1 = arith.constant 1 : i32
252+
// GPU-UNROLL-FULL-NEXT: %c1_i32_2 = arith.constant 1 : i32
253+
affine.for %j = 0 to 4 {
254+
%x = arith.constant 1 : i32
255+
}
256+
} // GPU-UNROLL-FULL: }
257+
gpu.return // GPU-UNROLL-FULL: return
258+
}
259+
}
260+
243261
// SHORT-LABEL: func @loop_nest_outer_unroll() {
244262
func.func @loop_nest_outer_unroll() {
245263
// SHORT: affine.for %arg0 = 0 to 4 {

mlir/test/Dialect/SCF/loop-unroll.mlir

+3-3
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
// RUN: mlir-opt %s -test-loop-unrolling='unroll-factor=2 loop-depth=0' | FileCheck %s --check-prefix UNROLL-OUTER-BY-2
44
// RUN: mlir-opt %s -test-loop-unrolling='unroll-factor=2 loop-depth=1' | FileCheck %s --check-prefix UNROLL-INNER-BY-2
55
// RUN: mlir-opt %s -test-loop-unrolling='unroll-factor=2 annotate=true' | FileCheck %s --check-prefix UNROLL-BY-2-ANNOTATE
6-
// RUN: mlir-opt %s --affine-loop-unroll='unroll-factor=6 unroll-up-to-factor=true' | FileCheck %s --check-prefix UNROLL-UP-TO
7-
// RUN: mlir-opt %s --affine-loop-unroll='unroll-factor=5 cleanup-unroll=true' | FileCheck %s --check-prefix CLEANUP-UNROLL-BY-5
8-
// RUN: mlir-opt %s --affine-loop-unroll --split-input-file | FileCheck %s
6+
// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(affine-loop-unroll{unroll-factor=6 unroll-up-to-factor=true}))" | FileCheck %s --check-prefix UNROLL-UP-TO
7+
// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(affine-loop-unroll{unroll-factor=5 cleanup-unroll=true}))" | FileCheck %s --check-prefix CLEANUP-UNROLL-BY-5
8+
// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(affine-loop-unroll))" --split-input-file | FileCheck %s
99

1010
func.func @dynamic_loop_unroll(%arg0 : index, %arg1 : index, %arg2 : index,
1111
%arg3: memref<?xf32>) {

0 commit comments

Comments
 (0)