Skip to content

Commit b87e655

Browse files
committed
[flang][hlfir] Lower forall to HLFIR
Lower Forall to the previously added hlfir.forall, hlfir.forall_mask. hlfir.forall_index, and hlfir.region_assign operations. The HLFIR assignment code lowering is moved into genDataAssignment for more readability and so that user defined assignment (still a TODO), will be able to share most of the logic. Differential Revision: https://reviews.llvm.org/D149878
1 parent 9d7eb60 commit b87e655

File tree

2 files changed

+440
-67
lines changed

2 files changed

+440
-67
lines changed

flang/lib/Lower/Bridge.cpp

Lines changed: 255 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1937,7 +1937,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
19371937
}
19381938

19391939
void genFIR(const Fortran::parser::EndForallStmt &) {
1940-
cleanupExplicitSpace();
1940+
if (!lowerToHighLevelFIR())
1941+
cleanupExplicitSpace();
19411942
}
19421943

19431944
template <typename A>
@@ -1956,11 +1957,24 @@ class FirConverter : public Fortran::lower::AbstractConverter {
19561957

19571958
/// Generate FIR for a FORALL statement.
19581959
void genFIR(const Fortran::parser::ForallStmt &stmt) {
1960+
const auto &concurrentHeader =
1961+
std::get<
1962+
Fortran::common::Indirection<Fortran::parser::ConcurrentHeader>>(
1963+
stmt.t)
1964+
.value();
1965+
if (lowerToHighLevelFIR()) {
1966+
mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint();
1967+
localSymbols.pushScope();
1968+
genForallNest(concurrentHeader);
1969+
genFIR(std::get<Fortran::parser::UnlabeledStatement<
1970+
Fortran::parser::ForallAssignmentStmt>>(stmt.t)
1971+
.statement);
1972+
localSymbols.popScope();
1973+
builder->restoreInsertionPoint(insertPt);
1974+
return;
1975+
}
19591976
prepareExplicitSpace(stmt);
1960-
genFIR(std::get<
1961-
Fortran::common::Indirection<Fortran::parser::ConcurrentHeader>>(
1962-
stmt.t)
1963-
.value());
1977+
genFIR(concurrentHeader);
19641978
genFIR(std::get<Fortran::parser::UnlabeledStatement<
19651979
Fortran::parser::ForallAssignmentStmt>>(stmt.t)
19661980
.statement);
@@ -1969,7 +1983,11 @@ class FirConverter : public Fortran::lower::AbstractConverter {
19691983

19701984
/// Generate FIR for a FORALL construct.
19711985
void genFIR(const Fortran::parser::ForallConstruct &forall) {
1972-
prepareExplicitSpace(forall);
1986+
mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint();
1987+
if (lowerToHighLevelFIR())
1988+
localSymbols.pushScope();
1989+
else
1990+
prepareExplicitSpace(forall);
19731991
genNestedStatement(
19741992
std::get<
19751993
Fortran::parser::Statement<Fortran::parser::ForallConstructStmt>>(
@@ -1987,14 +2005,101 @@ class FirConverter : public Fortran::lower::AbstractConverter {
19872005
genNestedStatement(
19882006
std::get<Fortran::parser::Statement<Fortran::parser::EndForallStmt>>(
19892007
forall.t));
2008+
if (lowerToHighLevelFIR()) {
2009+
localSymbols.popScope();
2010+
builder->restoreInsertionPoint(insertPt);
2011+
}
19902012
}
19912013

19922014
/// Lower the concurrent header specification.
19932015
void genFIR(const Fortran::parser::ForallConstructStmt &stmt) {
1994-
genFIR(std::get<
1995-
Fortran::common::Indirection<Fortran::parser::ConcurrentHeader>>(
1996-
stmt.t)
1997-
.value());
2016+
const auto &concurrentHeader =
2017+
std::get<
2018+
Fortran::common::Indirection<Fortran::parser::ConcurrentHeader>>(
2019+
stmt.t)
2020+
.value();
2021+
if (lowerToHighLevelFIR())
2022+
genForallNest(concurrentHeader);
2023+
else
2024+
genFIR(concurrentHeader);
2025+
}
2026+
2027+
/// Generate hlfir.forall and hlfir.forall_mask nest given a Forall
2028+
/// concurrent header
2029+
void genForallNest(const Fortran::parser::ConcurrentHeader &header) {
2030+
mlir::Location loc = getCurrentLocation();
2031+
const bool isOutterForall = !isInsideHlfirForallOrWhere();
2032+
hlfir::ForallOp outerForall;
2033+
auto evaluateControl = [&](const auto &parserExpr, mlir::Region &region,
2034+
bool isMask = false) {
2035+
if (region.empty())
2036+
builder->createBlock(&region);
2037+
Fortran::lower::StatementContext localStmtCtx;
2038+
const Fortran::semantics::SomeExpr *anlalyzedExpr =
2039+
Fortran::semantics::GetExpr(parserExpr);
2040+
assert(anlalyzedExpr && "expression semantics failed");
2041+
// Generate the controls of outer forall outside of the hlfir.forall
2042+
// region. They do not depend on any previous forall indices (C1123) and
2043+
// no assignment has been made yet that could modify their value. This
2044+
// will simplify hlfir.forall analysis because the SSA integer value
2045+
// yielded will obviously not depend on any variable modified by the
2046+
// forall when produced outside of it.
2047+
// This is not done for the mask because it may (and in usual code, does)
2048+
// depend on the forall indices that have just been defined as
2049+
// hlfir.forall block arguments.
2050+
mlir::OpBuilder::InsertPoint innerInsertionPoint;
2051+
if (outerForall && !isMask) {
2052+
innerInsertionPoint = builder->saveInsertionPoint();
2053+
builder->setInsertionPoint(outerForall);
2054+
}
2055+
mlir::Value exprVal =
2056+
fir::getBase(genExprValue(*anlalyzedExpr, localStmtCtx, &loc));
2057+
localStmtCtx.finalizeAndPop();
2058+
if (isMask)
2059+
exprVal = builder->createConvert(loc, builder->getI1Type(), exprVal);
2060+
if (innerInsertionPoint.isSet())
2061+
builder->restoreInsertionPoint(innerInsertionPoint);
2062+
builder->create<hlfir::YieldOp>(loc, exprVal);
2063+
};
2064+
for (const Fortran::parser::ConcurrentControl &control :
2065+
std::get<std::list<Fortran::parser::ConcurrentControl>>(header.t)) {
2066+
auto forallOp = builder->create<hlfir::ForallOp>(loc);
2067+
if (isOutterForall && !outerForall)
2068+
outerForall = forallOp;
2069+
evaluateControl(std::get<1>(control.t), forallOp.getLbRegion());
2070+
evaluateControl(std::get<2>(control.t), forallOp.getUbRegion());
2071+
if (const auto &optionalStep =
2072+
std::get<std::optional<Fortran::parser::ScalarIntExpr>>(
2073+
control.t))
2074+
evaluateControl(*optionalStep, forallOp.getStepRegion());
2075+
// Create block argument and map it to a symbol via an hlfir.forall_index
2076+
// op (symbols must be mapped to in memory values).
2077+
const Fortran::semantics::Symbol *controlVar =
2078+
std::get<Fortran::parser::Name>(control.t).symbol;
2079+
assert(controlVar && "symbol analysis failed");
2080+
mlir::Type controlVarType = genType(*controlVar);
2081+
mlir::Block *forallBody = builder->createBlock(&forallOp.getBody(), {},
2082+
{controlVarType}, {loc});
2083+
auto forallIndex = builder->create<hlfir::ForallIndexOp>(
2084+
loc, fir::ReferenceType::get(controlVarType),
2085+
forallBody->getArguments()[0],
2086+
builder->getStringAttr(controlVar->name().ToString()));
2087+
localSymbols.addVariableDefinition(*controlVar, forallIndex,
2088+
/*force=*/true);
2089+
auto end = builder->create<fir::FirEndOp>(loc);
2090+
builder->setInsertionPoint(end);
2091+
}
2092+
2093+
if (const auto &maskExpr =
2094+
std::get<std::optional<Fortran::parser::ScalarLogicalExpr>>(
2095+
header.t)) {
2096+
// Create hlfir.forall_mask and set insertion point in its body.
2097+
auto forallMaskOp = builder->create<hlfir::ForallMaskOp>(loc);
2098+
evaluateControl(*maskExpr, forallMaskOp.getMaskRegion(), /*isMask=*/true);
2099+
builder->createBlock(&forallMaskOp.getBody());
2100+
auto end = builder->create<fir::FirEndOp>(loc);
2101+
builder->setInsertionPoint(end);
2102+
}
19982103
}
19992104

20002105
void genFIR(const Fortran::parser::CompilerDirective &) {
@@ -2991,13 +3096,14 @@ class FirConverter : public Fortran::lower::AbstractConverter {
29913096
/// DestroyOp in case the returned value has hlfir::ExprType.
29923097
mlir::Value
29933098
genImplicitLogicalConvert(const Fortran::evaluate::Assignment &assign,
2994-
hlfir::Entity lhs, hlfir::Entity rhs) {
3099+
hlfir::Entity rhs,
3100+
Fortran::lower::StatementContext &stmtCtx) {
29953101
mlir::Type fromTy = rhs.getFortranElementType();
2996-
mlir::Type toTy = lhs.getFortranElementType();
2997-
if (fromTy == toTy)
3102+
if (!fromTy.isa<mlir::IntegerType, fir::LogicalType>())
29983103
return nullptr;
29993104

3000-
if (!fromTy.isa<mlir::IntegerType, fir::LogicalType>())
3105+
mlir::Type toTy = hlfir::getFortranElementType(genType(assign.lhs));
3106+
if (fromTy == toTy)
30013107
return nullptr;
30023108
if (!toTy.isa<mlir::IntegerType, fir::LogicalType>())
30033109
return nullptr;
@@ -3015,76 +3121,147 @@ class FirConverter : public Fortran::lower::AbstractConverter {
30153121
auto val = hlfir::loadTrivialScalar(loc, builder, elementPtr);
30163122
return hlfir::EntityWithAttributes{builder.createConvert(loc, toTy, val)};
30173123
};
3018-
return hlfir::genElementalOp(loc, builder, toTy, shape, /*typeParams=*/{},
3019-
genKernel);
3124+
mlir::Value convertedRhs = hlfir::genElementalOp(
3125+
loc, builder, toTy, shape, /*typeParams=*/{}, genKernel);
3126+
fir::FirOpBuilder *bldr = &builder;
3127+
stmtCtx.attachCleanup([loc, bldr, convertedRhs]() {
3128+
bldr->create<hlfir::DestroyOp>(loc, convertedRhs);
3129+
});
3130+
return convertedRhs;
3131+
}
3132+
3133+
static void
3134+
genCleanUpInRegionIfAny(mlir::Location loc, fir::FirOpBuilder &builder,
3135+
mlir::Region &region,
3136+
Fortran::lower::StatementContext &context) {
3137+
if (!context.hasCode())
3138+
return;
3139+
mlir::OpBuilder::InsertPoint insertPt = builder.saveInsertionPoint();
3140+
if (region.empty())
3141+
builder.createBlock(&region);
3142+
else
3143+
builder.setInsertionPointToEnd(&region.front());
3144+
context.finalizeAndPop();
3145+
hlfir::YieldOp::ensureTerminator(region, builder, loc);
3146+
builder.restoreInsertionPoint(insertPt);
3147+
}
3148+
3149+
void genDataAssignment(
3150+
const Fortran::evaluate::Assignment &assign,
3151+
const Fortran::evaluate::ProcedureRef *userDefinedAssignment) {
3152+
mlir::Location loc = getCurrentLocation();
3153+
fir::FirOpBuilder &builder = getFirOpBuilder();
3154+
// Gather some information about the assignment that will impact how it is
3155+
// lowered.
3156+
const bool isWholeAllocatableAssignment =
3157+
!userDefinedAssignment &&
3158+
Fortran::lower::isWholeAllocatable(assign.lhs);
3159+
std::optional<Fortran::evaluate::DynamicType> lhsType =
3160+
assign.lhs.GetType();
3161+
const bool keepLhsLengthInAllocatableAssignment =
3162+
isWholeAllocatableAssignment && lhsType.has_value() &&
3163+
lhsType->category() == Fortran::common::TypeCategory::Character &&
3164+
!lhsType->HasDeferredTypeParameter();
3165+
const bool lhsHasVectorSubscripts =
3166+
Fortran::evaluate::HasVectorSubscript(assign.lhs);
3167+
3168+
// Helper to generate the code evaluating the right-hand side.
3169+
auto evaluateRhs = [&](Fortran::lower::StatementContext &stmtCtx) {
3170+
hlfir::Entity rhs = Fortran::lower::convertExprToHLFIR(
3171+
loc, *this, assign.rhs, localSymbols, stmtCtx);
3172+
// Load trivial scalar RHS to allow the loads to be hoisted outside of
3173+
// loops early if possible. This also dereferences pointer and
3174+
// allocatable RHS: the target is being assigned from.
3175+
rhs = hlfir::loadTrivialScalar(loc, builder, rhs);
3176+
// In intrinsic assignments, Logical<->Integer assignments are allowed as
3177+
// an extension, but there is no explicit Convert expression for the RHS.
3178+
// Recognize the type mismatch here and insert explicit scalar convert or
3179+
// ElementalOp for array assignment.
3180+
if (!userDefinedAssignment)
3181+
if (mlir::Value conversion =
3182+
genImplicitLogicalConvert(assign, rhs, stmtCtx))
3183+
rhs = hlfir::Entity{conversion};
3184+
return rhs;
3185+
};
3186+
3187+
// Helper to generate the code evaluating the left-hand side.
3188+
auto evaluateLhs = [&](Fortran::lower::StatementContext &stmtCtx) {
3189+
hlfir::Entity lhs = Fortran::lower::convertExprToHLFIR(
3190+
loc, *this, assign.lhs, localSymbols, stmtCtx);
3191+
// Dereference pointer LHS: the target is being assigned to.
3192+
// Same for allocatables outside of whole allocatable assignments.
3193+
if (!isWholeAllocatableAssignment)
3194+
lhs = hlfir::derefPointersAndAllocatables(loc, builder, lhs);
3195+
return lhs;
3196+
};
3197+
3198+
if (!isInsideHlfirForallOrWhere() && !lhsHasVectorSubscripts &&
3199+
!userDefinedAssignment) {
3200+
Fortran::lower::StatementContext localStmtCtx;
3201+
hlfir::Entity rhs = evaluateRhs(localStmtCtx);
3202+
hlfir::Entity lhs = evaluateLhs(localStmtCtx);
3203+
builder.create<hlfir::AssignOp>(loc, rhs, lhs,
3204+
isWholeAllocatableAssignment,
3205+
keepLhsLengthInAllocatableAssignment);
3206+
return;
3207+
}
3208+
// Assignments inside Forall, Where, or assignments to a vector subscripted
3209+
// left-hand side requires using an hlfir.region_assign in HLFIR. The
3210+
// right-hand side and left-hand side must be evaluated inside the
3211+
// hlfir.region_assign regions.
3212+
auto regionAssignOp = builder.create<hlfir::RegionAssignOp>(loc);
3213+
3214+
// Lower RHS in its own region.
3215+
builder.createBlock(&regionAssignOp.getRhsRegion());
3216+
Fortran::lower::StatementContext rhsContext;
3217+
hlfir::Entity rhs = evaluateRhs(rhsContext);
3218+
auto rhsYieldOp = builder.create<hlfir::YieldOp>(loc, rhs);
3219+
genCleanUpInRegionIfAny(loc, builder, rhsYieldOp.getCleanup(), rhsContext);
3220+
// Lower LHS in its own region.
3221+
builder.createBlock(&regionAssignOp.getLhsRegion());
3222+
Fortran::lower::StatementContext lhsContext;
3223+
if (!lhsHasVectorSubscripts) {
3224+
hlfir::Entity lhs = evaluateLhs(lhsContext);
3225+
auto lhsYieldOp = builder.create<hlfir::YieldOp>(loc, lhs);
3226+
genCleanUpInRegionIfAny(loc, builder, lhsYieldOp.getCleanup(),
3227+
lhsContext);
3228+
} else {
3229+
TODO(loc, "assignment to vector subscripted entity");
3230+
}
3231+
3232+
// Add "realloc" flag to hlfir.region_assign.
3233+
if (isWholeAllocatableAssignment)
3234+
TODO(loc, "assignment to a whole allocatable inside FORALL");
3235+
// Generate the hlfir.region_assign userDefinedAssignment region.
3236+
if (userDefinedAssignment)
3237+
TODO(loc, "HLFIR user defined assignment");
3238+
3239+
builder.setInsertionPointAfter(regionAssignOp);
30203240
}
30213241

30223242
/// Shared for both assignments and pointer assignments.
30233243
void genAssignment(const Fortran::evaluate::Assignment &assign) {
30243244
mlir::Location loc = toLocation();
30253245
if (lowerToHighLevelFIR()) {
3026-
if (explicitIterationSpace() || !implicitIterSpace.empty())
3027-
TODO(loc, "HLFIR assignment inside FORALL or WHERE");
3028-
auto &builder = getFirOpBuilder();
3246+
if (!implicitIterSpace.empty())
3247+
TODO(loc, "HLFIR assignment inside WHERE");
30293248
std::visit(
30303249
Fortran::common::visitors{
3031-
// [1] Plain old assignment.
30323250
[&](const Fortran::evaluate::Assignment::Intrinsic &) {
3033-
if (Fortran::evaluate::HasVectorSubscript(assign.lhs))
3034-
TODO(loc, "assignment to vector subscripted entity");
3035-
Fortran::lower::StatementContext stmtCtx;
3036-
hlfir::Entity rhs = Fortran::lower::convertExprToHLFIR(
3037-
loc, *this, assign.rhs, localSymbols, stmtCtx);
3038-
// Load trivial scalar LHS to allow the loads to be hoisted
3039-
// outside of loops early if possible. This also dereferences
3040-
// pointer and allocatable RHS: the target is being assigned
3041-
// from.
3042-
rhs = hlfir::loadTrivialScalar(loc, builder, rhs);
3043-
hlfir::Entity lhs = Fortran::lower::convertExprToHLFIR(
3044-
loc, *this, assign.lhs, localSymbols, stmtCtx);
3045-
bool isWholeAllocatableAssignment = false;
3046-
bool keepLhsLengthInAllocatableAssignment = false;
3047-
if (Fortran::lower::isWholeAllocatable(assign.lhs)) {
3048-
isWholeAllocatableAssignment = true;
3049-
if (std::optional<Fortran::evaluate::DynamicType> lhsType =
3050-
assign.lhs.GetType())
3051-
keepLhsLengthInAllocatableAssignment =
3052-
lhsType->category() ==
3053-
Fortran::common::TypeCategory::Character &&
3054-
!lhsType->HasDeferredTypeParameter();
3055-
} else {
3056-
// Dereference pointer LHS: the target is being assigned to.
3057-
lhs = hlfir::derefPointersAndAllocatables(loc, builder, lhs);
3058-
}
3059-
3060-
// Logical<->Integer assignments are allowed as an extension,
3061-
// but there is no explicit Convert expression for the RHS.
3062-
// Recognize the type mismatch here and insert explicit
3063-
// scalar convert or ElementalOp for array assignment.
3064-
mlir::Value logicalConvert =
3065-
genImplicitLogicalConvert(assign, lhs, rhs);
3066-
if (logicalConvert)
3067-
rhs = hlfir::EntityWithAttributes{logicalConvert};
3068-
3069-
builder.create<hlfir::AssignOp>(
3070-
loc, rhs, lhs, isWholeAllocatableAssignment,
3071-
keepLhsLengthInAllocatableAssignment);
3072-
3073-
// Mark the end of life range of the ElementalOp's result.
3074-
if (logicalConvert &&
3075-
logicalConvert.getType().isa<hlfir::ExprType>())
3076-
builder.create<hlfir::DestroyOp>(loc, rhs);
3251+
genDataAssignment(assign, /*userDefinedAssignment=*/nullptr);
30773252
},
3078-
// [2] User defined assignment. If the context is a scalar
3079-
// expression then call the procedure.
30803253
[&](const Fortran::evaluate::ProcedureRef &procRef) {
3081-
TODO(loc, "HLFIR user defined assignment");
3254+
genDataAssignment(assign, /*userDefinedAssignment=*/&procRef);
30823255
},
30833256
[&](const Fortran::evaluate::Assignment::BoundsSpec &lbExprs) {
3257+
if (isInsideHlfirForallOrWhere())
3258+
TODO(loc, "pointer assignment inside FORALL");
30843259
genPointerAssignment(loc, assign, lbExprs);
30853260
},
30863261
[&](const Fortran::evaluate::Assignment::BoundsRemapping
30873262
&boundExprs) {
3263+
if (isInsideHlfirForallOrWhere())
3264+
TODO(loc, "pointer assignment inside FORALL");
30883265
genPointerAssignment(loc, assign, boundExprs);
30893266
},
30903267
},
@@ -3275,6 +3452,17 @@ class FirConverter : public Fortran::lower::AbstractConverter {
32753452
Fortran::lower::createArrayMergeStores(*this, explicitIterSpace);
32763453
}
32773454

3455+
bool isInsideHlfirForallOrWhere() const {
3456+
mlir::Block *block = builder->getInsertionBlock();
3457+
mlir::Operation *op = block ? block->getParentOp() : nullptr;
3458+
while (op) {
3459+
if (mlir::isa<hlfir::ForallOp, hlfir::WhereOp>(op))
3460+
return true;
3461+
op = op->getParentOp();
3462+
}
3463+
return false;
3464+
}
3465+
32783466
void genFIR(const Fortran::parser::WhereConstruct &c) {
32793467
implicitIterSpace.growStack();
32803468
genNestedStatement(

0 commit comments

Comments
 (0)