Skip to content

Commit 46ecd7b

Browse files
authored
[MLIR][OpenMP] Create LoopRelatedClause (#99506)
This patch introduces a new OpenMP clause definition not defined by the spec. Its main purpose is to define the `loop_inclusive` (previously "inclusive", renamed according to the parent of this PR in the stack) argument of `omp.loop_nest` in such a way that a followup implementation of a tablegen backend to automatically generate clause and operation operand structures directly from `OpenMP_Op` and `OpenMP_Clause` definitions can properly generate the `LoopNestOperands` structure. `collapse` clause arguments are also moved into this new definition, as they represent information on the loop nests being collapsed rather than the `collapse` clause itself.
1 parent 389679d commit 46ecd7b

File tree

9 files changed

+72
-81
lines changed

9 files changed

+72
-81
lines changed

flang/lib/Lower/OpenMP/ClauseProcessor.cpp

+13-14
Original file line numberDiff line numberDiff line change
@@ -181,20 +181,19 @@ static void addUseDeviceClause(
181181

182182
static void convertLoopBounds(lower::AbstractConverter &converter,
183183
mlir::Location loc,
184-
mlir::omp::CollapseClauseOps &result,
184+
mlir::omp::LoopRelatedOps &result,
185185
std::size_t loopVarTypeSize) {
186186
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
187187
// The types of lower bound, upper bound, and step are converted into the
188188
// type of the loop variable if necessary.
189189
mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize);
190-
for (unsigned it = 0; it < (unsigned)result.collapseLowerBounds.size();
191-
it++) {
192-
result.collapseLowerBounds[it] = firOpBuilder.createConvert(
193-
loc, loopVarType, result.collapseLowerBounds[it]);
194-
result.collapseUpperBounds[it] = firOpBuilder.createConvert(
195-
loc, loopVarType, result.collapseUpperBounds[it]);
196-
result.collapseSteps[it] =
197-
firOpBuilder.createConvert(loc, loopVarType, result.collapseSteps[it]);
190+
for (unsigned it = 0; it < (unsigned)result.loopLowerBounds.size(); it++) {
191+
result.loopLowerBounds[it] = firOpBuilder.createConvert(
192+
loc, loopVarType, result.loopLowerBounds[it]);
193+
result.loopUpperBounds[it] = firOpBuilder.createConvert(
194+
loc, loopVarType, result.loopUpperBounds[it]);
195+
result.loopSteps[it] =
196+
firOpBuilder.createConvert(loc, loopVarType, result.loopSteps[it]);
198197
}
199198
}
200199

@@ -204,7 +203,7 @@ static void convertLoopBounds(lower::AbstractConverter &converter,
204203

205204
bool ClauseProcessor::processCollapse(
206205
mlir::Location currentLocation, lower::pft::Evaluation &eval,
207-
mlir::omp::CollapseClauseOps &result,
206+
mlir::omp::LoopRelatedOps &result,
208207
llvm::SmallVectorImpl<const semantics::Symbol *> &iv) const {
209208
bool found = false;
210209
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
@@ -233,15 +232,15 @@ bool ClauseProcessor::processCollapse(
233232
std::get_if<parser::LoopControl::Bounds>(&loopControl->u);
234233
assert(bounds && "Expected bounds for worksharing do loop");
235234
lower::StatementContext stmtCtx;
236-
result.collapseLowerBounds.push_back(fir::getBase(
235+
result.loopLowerBounds.push_back(fir::getBase(
237236
converter.genExprValue(*semantics::GetExpr(bounds->lower), stmtCtx)));
238-
result.collapseUpperBounds.push_back(fir::getBase(
237+
result.loopUpperBounds.push_back(fir::getBase(
239238
converter.genExprValue(*semantics::GetExpr(bounds->upper), stmtCtx)));
240239
if (bounds->step) {
241-
result.collapseSteps.push_back(fir::getBase(
240+
result.loopSteps.push_back(fir::getBase(
242241
converter.genExprValue(*semantics::GetExpr(bounds->step), stmtCtx)));
243242
} else { // If `step` is not present, assume it as `1`.
244-
result.collapseSteps.push_back(firOpBuilder.createIntegerConstant(
243+
result.loopSteps.push_back(firOpBuilder.createIntegerConstant(
245244
currentLocation, firOpBuilder.getIntegerType(32), 1));
246245
}
247246
iv.push_back(bounds->name.thing.symbol);

flang/lib/Lower/OpenMP/ClauseProcessor.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ class ClauseProcessor {
5555
// 'Unique' clauses: They can appear at most once in the clause list.
5656
bool
5757
processCollapse(mlir::Location currentLocation, lower::pft::Evaluation &eval,
58-
mlir::omp::CollapseClauseOps &result,
58+
mlir::omp::LoopRelatedOps &result,
5959
llvm::SmallVectorImpl<const semantics::Symbol *> &iv) const;
6060
bool processDefault() const;
6161
bool processDevice(lower::StatementContext &stmtCtx,

flang/lib/Lower/OpenMP/DataSharingProcessor.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -274,8 +274,8 @@ void DataSharingProcessor::insertLastPrivateCompare(mlir::Operation *op) {
274274
llvm::SmallVector<mlir::Value> vs;
275275
vs.reserve(loopOp.getIVs().size());
276276
for (auto [iv, ub, step] :
277-
llvm::zip_equal(loopOp.getIVs(), loopOp.getCollapseUpperBounds(),
278-
loopOp.getCollapseSteps())) {
277+
llvm::zip_equal(loopOp.getIVs(), loopOp.getLoopUpperBounds(),
278+
loopOp.getLoopSteps())) {
279279
// v = iv + step
280280
// cmp = step < 0 ? v < ub : v > ub
281281
mlir::Value v = firOpBuilder.create<mlir::arith::AddIOp>(loc, iv, step);

mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h

+2-6
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,6 @@ struct CancelDirectiveNameClauseOps {
4343
ClauseCancellationConstructTypeAttr cancelDirective;
4444
};
4545

46-
struct CollapseClauseOps {
47-
llvm::SmallVector<Value> collapseLowerBounds, collapseUpperBounds,
48-
collapseSteps;
49-
};
50-
5146
struct CopyprivateClauseOps {
5247
llvm::SmallVector<Value> copyprivateVars;
5348
llvm::SmallVector<Attribute> copyprivateSyms;
@@ -125,6 +120,7 @@ struct LinearClauseOps {
125120
};
126121

127122
struct LoopRelatedOps {
123+
llvm::SmallVector<Value> loopLowerBounds, loopUpperBounds, loopSteps;
128124
UnitAttr loopInclusive;
129125
};
130126

@@ -261,7 +257,7 @@ using DistributeOperands =
261257
detail::Clauses<AllocateClauseOps, DistScheduleClauseOps, OrderClauseOps,
262258
PrivateClauseOps>;
263259

264-
using LoopNestOperands = detail::Clauses<CollapseClauseOps, LoopRelatedOps>;
260+
using LoopNestOperands = detail::Clauses<LoopRelatedOps>;
265261

266262
using MaskedOperands = detail::Clauses<FilterClauseOps>;
267263

mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td

+32-31
Original file line numberDiff line numberDiff line change
@@ -107,37 +107,6 @@ class OpenMP_CancelDirectiveNameClauseSkip<
107107

108108
def OpenMP_CancelDirectiveNameClause : OpenMP_CancelDirectiveNameClauseSkip<>;
109109

110-
//===----------------------------------------------------------------------===//
111-
// V5.2: [4.4.3] `collapse` clause
112-
//===----------------------------------------------------------------------===//
113-
114-
class OpenMP_CollapseClauseSkip<
115-
bit traits = false, bit arguments = false, bit assemblyFormat = false,
116-
bit description = false, bit extraClassDeclaration = false
117-
> : OpenMP_Clause</*isRequired=*/false, traits, arguments, assemblyFormat,
118-
description, extraClassDeclaration> {
119-
let traits = [
120-
AllTypesMatch<
121-
["collapse_lower_bounds", "collapse_upper_bounds", "collapse_steps"]>
122-
];
123-
124-
let arguments = (ins
125-
Variadic<IntLikeType>:$collapse_lower_bounds,
126-
Variadic<IntLikeType>:$collapse_upper_bounds,
127-
Variadic<IntLikeType>:$collapse_steps
128-
);
129-
130-
let extraClassDeclaration = [{
131-
/// Returns the number of loops in the loop nest.
132-
unsigned getNumLoops() { return getCollapseLowerBounds().size(); }
133-
}];
134-
135-
// Description and formatting integrated in the `omp.loop_nest` operation,
136-
// which is the only one currently accepting this clause.
137-
}
138-
139-
def OpenMP_CollapseClause : OpenMP_CollapseClauseSkip<>;
140-
141110
//===----------------------------------------------------------------------===//
142111
// V5.2: [5.7.2] `copyprivate` clause
143112
//===----------------------------------------------------------------------===//
@@ -564,6 +533,38 @@ class OpenMP_LinearClauseSkip<
564533

565534
def OpenMP_LinearClause : OpenMP_LinearClauseSkip<>;
566535

536+
//===----------------------------------------------------------------------===//
537+
// Not in the spec: Clause-like structure to hold loop related information.
538+
//===----------------------------------------------------------------------===//
539+
540+
class OpenMP_LoopRelatedClauseSkip<
541+
bit traits = false, bit arguments = false, bit assemblyFormat = false,
542+
bit description = false, bit extraClassDeclaration = false
543+
> : OpenMP_Clause</*isRequired=*/false, traits, arguments, assemblyFormat,
544+
description, extraClassDeclaration> {
545+
let traits = [
546+
AllTypesMatch<
547+
["loop_lower_bounds", "loop_upper_bounds", "loop_steps"]>
548+
];
549+
550+
let arguments = (ins
551+
Variadic<IntLikeType>:$loop_lower_bounds,
552+
Variadic<IntLikeType>:$loop_upper_bounds,
553+
Variadic<IntLikeType>:$loop_steps,
554+
UnitAttr:$loop_inclusive
555+
);
556+
557+
let extraClassDeclaration = [{
558+
/// Returns the number of loops in the loop nest.
559+
unsigned getNumLoops() { return getLoopLowerBounds().size(); }
560+
}];
561+
562+
// Description and formatting integrated in the `omp.loop_nest` operation,
563+
// which is the only one currently accepting this clause.
564+
}
565+
566+
def OpenMP_LoopRelatedClause : OpenMP_LoopRelatedClauseSkip<>;
567+
567568
//===----------------------------------------------------------------------===//
568569
// V5.2: [5.8.3] `map` clause
569570
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

+4-6
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,7 @@ def SingleOp : OpenMP_Op<"single", traits = [
297297
def LoopNestOp : OpenMP_Op<"loop_nest", traits = [
298298
RecursiveMemoryEffects, SameVariadicOperandSize
299299
], clauses = [
300-
OpenMP_CollapseClause
300+
OpenMP_LoopRelatedClause
301301
], singleRegion = true> {
302302
let summary = "rectangular loop nest";
303303
let description = [{
@@ -306,14 +306,14 @@ def LoopNestOp : OpenMP_Op<"loop_nest", traits = [
306306
lower and upper bounds, as well as a step variable, must be defined.
307307

308308
The lower and upper bounds specify a half-open range: the range includes the
309-
lower bound but does not include the upper bound. If the `inclusive`
309+
lower bound but does not include the upper bound. If the `loop_inclusive`
310310
attribute is specified then the upper bound is also included.
311311

312312
The body region can contain any number of blocks. The region is terminated
313313
by an `omp.yield` instruction without operands. The induction variables,
314314
represented as entry block arguments to the loop nest operation's single
315-
region, match the types of the `collapse_lower_bounds`,
316-
`collapse_upper_bounds` and `collapse_steps` arguments.
315+
region, match the types of the `loop_lower_bounds`, `loop_upper_bounds` and
316+
`loop_steps` arguments.
317317

318318
```mlir
319319
omp.loop_nest (%i1, %i2) : i32 = (%c0, %c0) to (%c10, %c10) step (%c1, %c1) {
@@ -335,8 +335,6 @@ def LoopNestOp : OpenMP_Op<"loop_nest", traits = [
335335
non-perfectly nested loops.
336336
}];
337337

338-
let arguments = !con(clausesArgs, (ins UnitAttr:$inclusive));
339-
340338
let builders = [
341339
OpBuilder<(ins CArg<"const LoopNestOperands &">:$clauses)>
342340
];

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

+9-9
Original file line numberDiff line numberDiff line change
@@ -2047,7 +2047,7 @@ ParseResult LoopNestOp::parse(OpAsmParser &parser, OperationState &result) {
20472047

20482048
// Parse "inclusive" flag.
20492049
if (succeeded(parser.parseOptionalKeyword("inclusive")))
2050-
result.addAttribute("inclusive",
2050+
result.addAttribute("loop_inclusive",
20512051
UnitAttr::get(parser.getBuilder().getContext()));
20522052

20532053
// Parse step values.
@@ -2075,28 +2075,28 @@ void LoopNestOp::print(OpAsmPrinter &p) {
20752075
Region &region = getRegion();
20762076
auto args = region.getArguments();
20772077
p << " (" << args << ") : " << args[0].getType() << " = ("
2078-
<< getCollapseLowerBounds() << ") to (" << getCollapseUpperBounds() << ") ";
2079-
if (getInclusive())
2078+
<< getLoopLowerBounds() << ") to (" << getLoopUpperBounds() << ") ";
2079+
if (getLoopInclusive())
20802080
p << "inclusive ";
2081-
p << "step (" << getCollapseSteps() << ") ";
2081+
p << "step (" << getLoopSteps() << ") ";
20822082
p.printRegion(region, /*printEntryBlockArgs=*/false);
20832083
}
20842084

20852085
void LoopNestOp::build(OpBuilder &builder, OperationState &state,
20862086
const LoopNestOperands &clauses) {
2087-
LoopNestOp::build(builder, state, clauses.collapseLowerBounds,
2088-
clauses.collapseUpperBounds, clauses.collapseSteps,
2087+
LoopNestOp::build(builder, state, clauses.loopLowerBounds,
2088+
clauses.loopUpperBounds, clauses.loopSteps,
20892089
clauses.loopInclusive);
20902090
}
20912091

20922092
LogicalResult LoopNestOp::verify() {
2093-
if (getCollapseLowerBounds().empty())
2093+
if (getLoopLowerBounds().empty())
20942094
return emitOpError() << "must represent at least one loop";
20952095

2096-
if (getCollapseLowerBounds().size() != getIVs().size())
2096+
if (getLoopLowerBounds().size() != getIVs().size())
20972097
return emitOpError() << "number of range arguments and IVs do not match";
20982098

2099-
for (auto [lb, iv] : llvm::zip_equal(getCollapseLowerBounds(), getIVs())) {
2099+
for (auto [lb, iv] : llvm::zip_equal(getLoopLowerBounds(), getIVs())) {
21002100
if (lb.getType() != iv.getType())
21012101
return emitOpError()
21022102
<< "range argument type does not match corresponding IV type";

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

+8-11
Original file line numberDiff line numberDiff line change
@@ -1109,8 +1109,7 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
11091109
wsloopOp.getScheduleKind().value_or(omp::ClauseScheduleKind::Static);
11101110

11111111
// Find the loop configuration.
1112-
llvm::Value *step =
1113-
moduleTranslation.lookupValue(loopOp.getCollapseSteps()[0]);
1112+
llvm::Value *step = moduleTranslation.lookupValue(loopOp.getLoopSteps()[0]);
11141113
llvm::Type *ivType = step->getType();
11151114
llvm::Value *chunk = nullptr;
11161115
if (wsloopOp.getScheduleChunk()) {
@@ -1179,11 +1178,10 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
11791178
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
11801179
for (unsigned i = 0, e = loopOp.getNumLoops(); i < e; ++i) {
11811180
llvm::Value *lowerBound =
1182-
moduleTranslation.lookupValue(loopOp.getCollapseLowerBounds()[i]);
1181+
moduleTranslation.lookupValue(loopOp.getLoopLowerBounds()[i]);
11831182
llvm::Value *upperBound =
1184-
moduleTranslation.lookupValue(loopOp.getCollapseUpperBounds()[i]);
1185-
llvm::Value *step =
1186-
moduleTranslation.lookupValue(loopOp.getCollapseSteps()[i]);
1183+
moduleTranslation.lookupValue(loopOp.getLoopUpperBounds()[i]);
1184+
llvm::Value *step = moduleTranslation.lookupValue(loopOp.getLoopSteps()[i]);
11871185

11881186
// Make sure loop trip count are emitted in the preheader of the outermost
11891187
// loop at the latest so that they are all available for the new collapsed
@@ -1196,7 +1194,7 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
11961194
}
11971195
loopInfos.push_back(ompBuilder->createCanonicalLoop(
11981196
loc, bodyGen, lowerBound, upperBound, step,
1199-
/*IsSigned=*/true, loopOp.getInclusive(), computeIP));
1197+
/*IsSigned=*/true, loopOp.getLoopInclusive(), computeIP));
12001198

12011199
if (failed(bodyGenStatus))
12021200
return failure();
@@ -1644,11 +1642,10 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
16441642
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
16451643
for (unsigned i = 0, e = loopOp.getNumLoops(); i < e; ++i) {
16461644
llvm::Value *lowerBound =
1647-
moduleTranslation.lookupValue(loopOp.getCollapseLowerBounds()[i]);
1645+
moduleTranslation.lookupValue(loopOp.getLoopLowerBounds()[i]);
16481646
llvm::Value *upperBound =
1649-
moduleTranslation.lookupValue(loopOp.getCollapseUpperBounds()[i]);
1650-
llvm::Value *step =
1651-
moduleTranslation.lookupValue(loopOp.getCollapseSteps()[i]);
1647+
moduleTranslation.lookupValue(loopOp.getLoopUpperBounds()[i]);
1648+
llvm::Value *step = moduleTranslation.lookupValue(loopOp.getLoopSteps()[i]);
16521649

16531650
// Make sure loop trip count are emitted in the preheader of the outermost
16541651
// loop at the latest so that they are all available for the new collapsed

mlir/test/Dialect/OpenMP/ops.mlir

+1-1
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ func.func @omp_loop_nest(%lb : index, %ub : index, %step : index) -> () {
184184
"omp.loop_nest" (%lb, %ub, %step) ({
185185
^bb0(%iv: index):
186186
omp.yield
187-
}) {inclusive} : (index, index, index) -> ()
187+
}) {loop_inclusive} : (index, index, index) -> ()
188188
omp.terminator
189189
}
190190

0 commit comments

Comments
 (0)