Skip to content

Commit fdfeea5

Browse files
authored
[MLIR][OpenMP][Flang] Normalize clause arguments names (#99505)
Currently, there are some inconsistencies to how clause arguments are named in the OpenMP dialect. Additionally, the clause operand structures associated to them also diverge in certain cases. The purpose of this patch is to normalize argument names across all `OpenMP_Clause` tablegen definitions and clause operand structures. This has the benefit of providing more consistent representations for clauses in the dialect, but the main short-term advantage is that it enables the development of an OpenMP-specific tablegen backend to automatically generate the clause operand structures without breaking dependent code. The main re-naming decisions made in this patch are the following: - Variadic arguments (i.e. multiple values) have the "_vars" suffix. This and other similar suffixes are removed from array attribute arguments. - Individual required or optional value arguments do not have any suffix added to them (e.g. "val", "var", "expr", ...), except for `if` which would otherwise result in an invalid C++ variable name. - The associated clause's name is prepended to argument names that don't already contain it as part of its name. This avoids future collisions between arguments named the same way on different clauses and adding both clauses to the same operation. - Privatization and reduction related arguments that contain lists of symbols pointing to privatizer/reducer operations use the "_syms" suffix. This removes the inconsistencies between the names for "copyprivate_funcs", "[in]reductions", "privatizers", etc. - General improvements to names, replacement of camel case for snake case everywhere, etc. - Renaming of operation-associated operand structures to use the "Operands" suffix in place of "ClauseOps", to better differentiate between clause operand structures and operation operand structures. - Fields on clause operand structures are sorted according to the tablegen definition of the same clause. The assembly format for a few arguments is updated to better reflect the clause they are associated with: - `chunk_size` -> `dist_schedule_chunk_size` - `grain_size` -> `grainsize` - `simd` -> `par_level_simd`
1 parent a347bdb commit fdfeea5

File tree

14 files changed

+924
-948
lines changed

14 files changed

+924
-948
lines changed

flang/lib/Lower/OpenMP/ClauseProcessor.cpp

+46-48
Original file line numberDiff line numberDiff line change
@@ -187,13 +187,14 @@ static void convertLoopBounds(lower::AbstractConverter &converter,
187187
// The types of lower bound, upper bound, and step are converted into the
188188
// type of the loop variable if necessary.
189189
mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize);
190-
for (unsigned it = 0; it < (unsigned)result.loopLBVar.size(); it++) {
191-
result.loopLBVar[it] =
192-
firOpBuilder.createConvert(loc, loopVarType, result.loopLBVar[it]);
193-
result.loopUBVar[it] =
194-
firOpBuilder.createConvert(loc, loopVarType, result.loopUBVar[it]);
195-
result.loopStepVar[it] =
196-
firOpBuilder.createConvert(loc, loopVarType, result.loopStepVar[it]);
190+
for (unsigned it = 0; it < (unsigned)result.collapseLowerBounds.size();
191+
it++) {
192+
result.collapseLowerBounds[it] = firOpBuilder.createConvert(
193+
loc, loopVarType, result.collapseLowerBounds[it]);
194+
result.collapseUpperBounds[it] = firOpBuilder.createConvert(
195+
loc, loopVarType, result.collapseUpperBounds[it]);
196+
result.collapseSteps[it] =
197+
firOpBuilder.createConvert(loc, loopVarType, result.collapseSteps[it]);
197198
}
198199
}
199200

@@ -232,15 +233,15 @@ bool ClauseProcessor::processCollapse(
232233
std::get_if<parser::LoopControl::Bounds>(&loopControl->u);
233234
assert(bounds && "Expected bounds for worksharing do loop");
234235
lower::StatementContext stmtCtx;
235-
result.loopLBVar.push_back(fir::getBase(
236+
result.collapseLowerBounds.push_back(fir::getBase(
236237
converter.genExprValue(*semantics::GetExpr(bounds->lower), stmtCtx)));
237-
result.loopUBVar.push_back(fir::getBase(
238+
result.collapseUpperBounds.push_back(fir::getBase(
238239
converter.genExprValue(*semantics::GetExpr(bounds->upper), stmtCtx)));
239240
if (bounds->step) {
240-
result.loopStepVar.push_back(fir::getBase(
241+
result.collapseSteps.push_back(fir::getBase(
241242
converter.genExprValue(*semantics::GetExpr(bounds->step), stmtCtx)));
242243
} else { // If `step` is not present, assume it as `1`.
243-
result.loopStepVar.push_back(firOpBuilder.createIntegerConstant(
244+
result.collapseSteps.push_back(firOpBuilder.createIntegerConstant(
244245
currentLocation, firOpBuilder.getIntegerType(32), 1));
245246
}
246247
iv.push_back(bounds->name.thing.symbol);
@@ -291,8 +292,7 @@ bool ClauseProcessor::processDevice(lower::StatementContext &stmtCtx,
291292
}
292293
}
293294
const auto &deviceExpr = std::get<omp::SomeExpr>(clause->t);
294-
result.deviceVar =
295-
fir::getBase(converter.genExprValue(deviceExpr, stmtCtx));
295+
result.device = fir::getBase(converter.genExprValue(deviceExpr, stmtCtx));
296296
return true;
297297
}
298298
return false;
@@ -322,10 +322,10 @@ bool ClauseProcessor::processDistSchedule(
322322
lower::StatementContext &stmtCtx,
323323
mlir::omp::DistScheduleClauseOps &result) const {
324324
if (auto *clause = findUniqueClause<omp::clause::DistSchedule>()) {
325-
result.distScheduleStaticAttr = converter.getFirOpBuilder().getUnitAttr();
325+
result.distScheduleStatic = converter.getFirOpBuilder().getUnitAttr();
326326
const auto &chunkSize = std::get<std::optional<ExprTy>>(clause->t);
327327
if (chunkSize)
328-
result.distScheduleChunkSizeVar =
328+
result.distScheduleChunkSize =
329329
fir::getBase(converter.genExprValue(*chunkSize, stmtCtx));
330330
return true;
331331
}
@@ -335,7 +335,7 @@ bool ClauseProcessor::processDistSchedule(
335335
bool ClauseProcessor::processFilter(lower::StatementContext &stmtCtx,
336336
mlir::omp::FilterClauseOps &result) const {
337337
if (auto *clause = findUniqueClause<omp::clause::Filter>()) {
338-
result.filteredThreadIdVar =
338+
result.filteredThreadId =
339339
fir::getBase(converter.genExprValue(clause->v, stmtCtx));
340340
return true;
341341
}
@@ -351,7 +351,7 @@ bool ClauseProcessor::processFinal(lower::StatementContext &stmtCtx,
351351

352352
mlir::Value finalVal =
353353
fir::getBase(converter.genExprValue(clause->v, stmtCtx));
354-
result.finalVar = firOpBuilder.createConvert(
354+
result.final = firOpBuilder.createConvert(
355355
clauseLocation, firOpBuilder.getI1Type(), finalVal);
356356
return true;
357357
}
@@ -362,19 +362,19 @@ bool ClauseProcessor::processHint(mlir::omp::HintClauseOps &result) const {
362362
if (auto *clause = findUniqueClause<omp::clause::Hint>()) {
363363
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
364364
int64_t hintValue = *evaluate::ToInt64(clause->v);
365-
result.hintAttr = firOpBuilder.getI64IntegerAttr(hintValue);
365+
result.hint = firOpBuilder.getI64IntegerAttr(hintValue);
366366
return true;
367367
}
368368
return false;
369369
}
370370

371371
bool ClauseProcessor::processMergeable(
372372
mlir::omp::MergeableClauseOps &result) const {
373-
return markClauseOccurrence<omp::clause::Mergeable>(result.mergeableAttr);
373+
return markClauseOccurrence<omp::clause::Mergeable>(result.mergeable);
374374
}
375375

376376
bool ClauseProcessor::processNowait(mlir::omp::NowaitClauseOps &result) const {
377-
return markClauseOccurrence<omp::clause::Nowait>(result.nowaitAttr);
377+
return markClauseOccurrence<omp::clause::Nowait>(result.nowait);
378378
}
379379

380380
bool ClauseProcessor::processNumTeams(
@@ -385,7 +385,7 @@ bool ClauseProcessor::processNumTeams(
385385
if (auto *clause = findUniqueClause<omp::clause::NumTeams>()) {
386386
// auto lowerBound = std::get<std::optional<ExprTy>>(clause->t);
387387
auto &upperBound = std::get<ExprTy>(clause->t);
388-
result.numTeamsUpperVar =
388+
result.numTeamsUpper =
389389
fir::getBase(converter.genExprValue(upperBound, stmtCtx));
390390
return true;
391391
}
@@ -397,7 +397,7 @@ bool ClauseProcessor::processNumThreads(
397397
mlir::omp::NumThreadsClauseOps &result) const {
398398
if (auto *clause = findUniqueClause<omp::clause::NumThreads>()) {
399399
// OMPIRBuilder expects `NUM_THREADS` clause as a `Value`.
400-
result.numThreadsVar =
400+
result.numThreads =
401401
fir::getBase(converter.genExprValue(clause->v, stmtCtx));
402402
return true;
403403
}
@@ -408,17 +408,17 @@ bool ClauseProcessor::processOrder(mlir::omp::OrderClauseOps &result) const {
408408
using Order = omp::clause::Order;
409409
if (auto *clause = findUniqueClause<Order>()) {
410410
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
411-
result.orderAttr = mlir::omp::ClauseOrderKindAttr::get(
411+
result.order = mlir::omp::ClauseOrderKindAttr::get(
412412
firOpBuilder.getContext(), mlir::omp::ClauseOrderKind::Concurrent);
413413
const auto &modifier =
414414
std::get<std::optional<Order::OrderModifier>>(clause->t);
415415
if (modifier && *modifier == Order::OrderModifier::Unconstrained) {
416-
result.orderModAttr = mlir::omp::OrderModifierAttr::get(
416+
result.orderMod = mlir::omp::OrderModifierAttr::get(
417417
firOpBuilder.getContext(), mlir::omp::OrderModifier::unconstrained);
418418
} else {
419419
// "If order-modifier is not unconstrained, the behavior is as if the
420420
// reproducible modifier is present."
421-
result.orderModAttr = mlir::omp::OrderModifierAttr::get(
421+
result.orderMod = mlir::omp::OrderModifierAttr::get(
422422
firOpBuilder.getContext(), mlir::omp::OrderModifier::reproducible);
423423
}
424424
return true;
@@ -433,7 +433,7 @@ bool ClauseProcessor::processOrdered(
433433
int64_t orderedClauseValue = 0l;
434434
if (clause->v.has_value())
435435
orderedClauseValue = *evaluate::ToInt64(*clause->v);
436-
result.orderedAttr = firOpBuilder.getI64IntegerAttr(orderedClauseValue);
436+
result.ordered = firOpBuilder.getI64IntegerAttr(orderedClauseValue);
437437
return true;
438438
}
439439
return false;
@@ -443,8 +443,7 @@ bool ClauseProcessor::processPriority(
443443
lower::StatementContext &stmtCtx,
444444
mlir::omp::PriorityClauseOps &result) const {
445445
if (auto *clause = findUniqueClause<omp::clause::Priority>()) {
446-
result.priorityVar =
447-
fir::getBase(converter.genExprValue(clause->v, stmtCtx));
446+
result.priority = fir::getBase(converter.genExprValue(clause->v, stmtCtx));
448447
return true;
449448
}
450449
return false;
@@ -454,7 +453,7 @@ bool ClauseProcessor::processProcBind(
454453
mlir::omp::ProcBindClauseOps &result) const {
455454
if (auto *clause = findUniqueClause<omp::clause::ProcBind>()) {
456455
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
457-
result.procBindKindAttr = genProcBindKindAttr(firOpBuilder, *clause);
456+
result.procBindKind = genProcBindKindAttr(firOpBuilder, *clause);
458457
return true;
459458
}
460459
return false;
@@ -465,7 +464,7 @@ bool ClauseProcessor::processSafelen(
465464
if (auto *clause = findUniqueClause<omp::clause::Safelen>()) {
466465
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
467466
const std::optional<std::int64_t> safelenVal = evaluate::ToInt64(clause->v);
468-
result.safelenAttr = firOpBuilder.getI64IntegerAttr(*safelenVal);
467+
result.safelen = firOpBuilder.getI64IntegerAttr(*safelenVal);
469468
return true;
470469
}
471470
return false;
@@ -498,19 +497,19 @@ bool ClauseProcessor::processSchedule(
498497
break;
499498
}
500499

501-
result.scheduleValAttr =
500+
result.scheduleKind =
502501
mlir::omp::ClauseScheduleKindAttr::get(context, scheduleKind);
503502

504-
mlir::omp::ScheduleModifier scheduleModifier = getScheduleModifier(*clause);
505-
if (scheduleModifier != mlir::omp::ScheduleModifier::none)
506-
result.scheduleModAttr =
507-
mlir::omp::ScheduleModifierAttr::get(context, scheduleModifier);
503+
mlir::omp::ScheduleModifier scheduleMod = getScheduleModifier(*clause);
504+
if (scheduleMod != mlir::omp::ScheduleModifier::none)
505+
result.scheduleMod =
506+
mlir::omp::ScheduleModifierAttr::get(context, scheduleMod);
508507

509508
if (getSimdModifier(*clause) != mlir::omp::ScheduleModifier::none)
510-
result.scheduleSimdAttr = firOpBuilder.getUnitAttr();
509+
result.scheduleSimd = firOpBuilder.getUnitAttr();
511510

512511
if (const auto &chunkExpr = std::get<omp::MaybeExpr>(clause->t))
513-
result.scheduleChunkVar =
512+
result.scheduleChunk =
514513
fir::getBase(converter.genExprValue(*chunkExpr, stmtCtx));
515514

516515
return true;
@@ -523,7 +522,7 @@ bool ClauseProcessor::processSimdlen(
523522
if (auto *clause = findUniqueClause<omp::clause::Simdlen>()) {
524523
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
525524
const std::optional<std::int64_t> simdlenVal = evaluate::ToInt64(clause->v);
526-
result.simdlenAttr = firOpBuilder.getI64IntegerAttr(*simdlenVal);
525+
result.simdlen = firOpBuilder.getI64IntegerAttr(*simdlenVal);
527526
return true;
528527
}
529528
return false;
@@ -533,15 +532,15 @@ bool ClauseProcessor::processThreadLimit(
533532
lower::StatementContext &stmtCtx,
534533
mlir::omp::ThreadLimitClauseOps &result) const {
535534
if (auto *clause = findUniqueClause<omp::clause::ThreadLimit>()) {
536-
result.threadLimitVar =
535+
result.threadLimit =
537536
fir::getBase(converter.genExprValue(clause->v, stmtCtx));
538537
return true;
539538
}
540539
return false;
541540
}
542541

543542
bool ClauseProcessor::processUntied(mlir::omp::UntiedClauseOps &result) const {
544-
return markClauseOccurrence<omp::clause::Untied>(result.untiedAttr);
543+
return markClauseOccurrence<omp::clause::Untied>(result.untied);
545544
}
546545

547546
//===----------------------------------------------------------------------===//
@@ -565,7 +564,7 @@ static void
565564
addAlignedClause(lower::AbstractConverter &converter,
566565
const omp::clause::Aligned &clause,
567566
llvm::SmallVectorImpl<mlir::Value> &alignedVars,
568-
llvm::SmallVectorImpl<mlir::Attribute> &alignmentAttrs) {
567+
llvm::SmallVectorImpl<mlir::Attribute> &alignments) {
569568
using Aligned = omp::clause::Aligned;
570569
lower::StatementContext stmtCtx;
571570
mlir::IntegerAttr alignmentValueAttr;
@@ -594,7 +593,7 @@ addAlignedClause(lower::AbstractConverter &converter,
594593
alignmentValueAttr = builder.getI64IntegerAttr(alignment);
595594
// All the list items in a aligned clause will have same alignment
596595
for (std::size_t i = 0; i < objects.size(); i++)
597-
alignmentAttrs.push_back(alignmentValueAttr);
596+
alignments.push_back(alignmentValueAttr);
598597
}
599598
}
600599

@@ -603,7 +602,7 @@ bool ClauseProcessor::processAligned(
603602
return findRepeatableClause<omp::clause::Aligned>(
604603
[&](const omp::clause::Aligned &clause, const parser::CharBlock &) {
605604
addAlignedClause(converter, clause, result.alignedVars,
606-
result.alignmentAttrs);
605+
result.alignments);
607606
});
608607
}
609608

@@ -798,7 +797,7 @@ bool ClauseProcessor::processCopyprivate(
798797
result.copyprivateVars.push_back(cpVar);
799798
mlir::func::FuncOp funcOp =
800799
createCopyFunc(currentLocation, converter, cpVar.getType(), attrs);
801-
result.copyprivateFuncs.push_back(mlir::SymbolRefAttr::get(funcOp));
800+
result.copyprivateSyms.push_back(mlir::SymbolRefAttr::get(funcOp));
802801
};
803802

804803
bool hasCopyPrivate = findRepeatableClause<clause::Copyprivate>(
@@ -832,7 +831,7 @@ bool ClauseProcessor::processDepend(mlir::omp::DependClauseOps &result) const {
832831

833832
mlir::omp::ClauseTaskDependAttr dependTypeOperand =
834833
genDependKindAttr(firOpBuilder, kind);
835-
result.dependTypeAttrs.append(objects.size(), dependTypeOperand);
834+
result.dependKinds.append(objects.size(), dependTypeOperand);
836835

837836
for (const omp::Object &object : objects) {
838837
assert(object.ref() && "Expecting designator");
@@ -1037,10 +1036,9 @@ bool ClauseProcessor::processReduction(
10371036

10381037
// Copy local lists into the output.
10391038
llvm::copy(reductionVars, std::back_inserter(result.reductionVars));
1040-
llvm::copy(reduceVarByRef,
1041-
std::back_inserter(result.reductionVarsByRef));
1039+
llvm::copy(reduceVarByRef, std::back_inserter(result.reductionByref));
10421040
llvm::copy(reductionDeclSymbols,
1043-
std::back_inserter(result.reductionDeclSymbols));
1041+
std::back_inserter(result.reductionSyms));
10441042

10451043
if (outReductionTypes) {
10461044
outReductionTypes->reserve(outReductionTypes->size() +

flang/lib/Lower/OpenMP/DataSharingProcessor.cpp

+4-3
Original file line numberDiff line numberDiff line change
@@ -273,8 +273,9 @@ void DataSharingProcessor::insertLastPrivateCompare(mlir::Operation *op) {
273273
mlir::Value cmpOp;
274274
llvm::SmallVector<mlir::Value> vs;
275275
vs.reserve(loopOp.getIVs().size());
276-
for (auto [iv, ub, step] : llvm::zip_equal(
277-
loopOp.getIVs(), loopOp.getUpperBound(), loopOp.getStep())) {
276+
for (auto [iv, ub, step] :
277+
llvm::zip_equal(loopOp.getIVs(), loopOp.getCollapseUpperBounds(),
278+
loopOp.getCollapseSteps())) {
278279
// v = iv + step
279280
// cmp = step < 0 ? v < ub : v > ub
280281
mlir::Value v = firOpBuilder.create<mlir::arith::AddIOp>(loc, iv, step);
@@ -593,7 +594,7 @@ void DataSharingProcessor::doPrivatize(const semantics::Symbol *sym,
593594
}();
594595

595596
if (clauseOps) {
596-
clauseOps->privatizers.push_back(mlir::SymbolRefAttr::get(privatizerOp));
597+
clauseOps->privateSyms.push_back(mlir::SymbolRefAttr::get(privatizerOp));
597598
clauseOps->privateVars.push_back(hsb.getAddr());
598599
}
599600

0 commit comments

Comments
 (0)