Skip to content

Commit 54c88fc

Browse files
committed
[flang][hlfir] Lower WHERE to HLFIR
Lower WHERE to the newly added hlfir.where and hlfir.elsewhere operations. Differential Revision: https://reviews.llvm.org/D149950
1 parent b87e655 commit 54c88fc

File tree

2 files changed

+264
-14
lines changed

2 files changed

+264
-14
lines changed

flang/lib/Lower/Bridge.cpp

Lines changed: 94 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3154,7 +3154,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
31543154
// Gather some information about the assignment that will impact how it is
31553155
// lowered.
31563156
const bool isWholeAllocatableAssignment =
3157-
!userDefinedAssignment &&
3157+
!userDefinedAssignment && !isInsideHlfirWhere() &&
31583158
Fortran::lower::isWholeAllocatable(assign.lhs);
31593159
std::optional<Fortran::evaluate::DynamicType> lhsType =
31603160
assign.lhs.GetType();
@@ -3243,8 +3243,6 @@ class FirConverter : public Fortran::lower::AbstractConverter {
32433243
void genAssignment(const Fortran::evaluate::Assignment &assign) {
32443244
mlir::Location loc = toLocation();
32453245
if (lowerToHighLevelFIR()) {
3246-
if (!implicitIterSpace.empty())
3247-
TODO(loc, "HLFIR assignment inside WHERE");
32483246
std::visit(
32493247
Fortran::common::visitors{
32503248
[&](const Fortran::evaluate::Assignment::Intrinsic &) {
@@ -3452,23 +3450,47 @@ class FirConverter : public Fortran::lower::AbstractConverter {
34523450
Fortran::lower::createArrayMergeStores(*this, explicitIterSpace);
34533451
}
34543452

3455-
bool isInsideHlfirForallOrWhere() const {
3453+
// Is the insertion point of the builder directly or indirectly set
3454+
// inside any operation of type "Op"?
3455+
template <typename... Op>
3456+
bool isInsideOp() const {
34563457
mlir::Block *block = builder->getInsertionBlock();
34573458
mlir::Operation *op = block ? block->getParentOp() : nullptr;
34583459
while (op) {
3459-
if (mlir::isa<hlfir::ForallOp, hlfir::WhereOp>(op))
3460+
if (mlir::isa<Op...>(op))
34603461
return true;
34613462
op = op->getParentOp();
34623463
}
34633464
return false;
34643465
}
3466+
bool isInsideHlfirForallOrWhere() const {
3467+
return isInsideOp<hlfir::ForallOp, hlfir::WhereOp>();
3468+
}
3469+
bool isInsideHlfirWhere() const { return isInsideOp<hlfir::WhereOp>(); }
34653470

34663471
void genFIR(const Fortran::parser::WhereConstruct &c) {
3467-
implicitIterSpace.growStack();
3472+
mlir::Location loc = getCurrentLocation();
3473+
hlfir::WhereOp whereOp;
3474+
3475+
if (!lowerToHighLevelFIR()) {
3476+
implicitIterSpace.growStack();
3477+
} else {
3478+
whereOp = builder->create<hlfir::WhereOp>(loc);
3479+
builder->createBlock(&whereOp.getMaskRegion());
3480+
}
3481+
3482+
// Lower the where mask. For HLFIR, this is done in the hlfir.where mask
3483+
// region.
34683484
genNestedStatement(
34693485
std::get<
34703486
Fortran::parser::Statement<Fortran::parser::WhereConstructStmt>>(
34713487
c.t));
3488+
3489+
// Lower WHERE body. For HLFIR, this is done in the hlfir.where body
3490+
// region.
3491+
if (whereOp)
3492+
builder->createBlock(&whereOp.getBody());
3493+
34723494
for (const auto &body :
34733495
std::get<std::list<Fortran::parser::WhereBodyConstruct>>(c.t))
34743496
genFIR(body);
@@ -3484,6 +3506,13 @@ class FirConverter : public Fortran::lower::AbstractConverter {
34843506
genNestedStatement(
34853507
std::get<Fortran::parser::Statement<Fortran::parser::EndWhereStmt>>(
34863508
c.t));
3509+
3510+
if (whereOp) {
3511+
// For HLFIR, create fir.end terminator in the last hlfir.elsewhere, or
3512+
// in the hlfir.where if it had no elsewhere.
3513+
builder->create<fir::FirEndOp>(loc);
3514+
builder->setInsertionPointAfter(whereOp);
3515+
}
34873516
}
34883517
void genFIR(const Fortran::parser::WhereBodyConstruct &body) {
34893518
std::visit(
@@ -3499,24 +3528,61 @@ class FirConverter : public Fortran::lower::AbstractConverter {
34993528
},
35003529
body.u);
35013530
}
3531+
3532+
/// Lower a Where or Elsewhere mask into an hlfir mask region.
3533+
void lowerWhereMaskToHlfir(mlir::Location loc,
3534+
const Fortran::semantics::SomeExpr *maskExpr) {
3535+
assert(maskExpr && "mask semantic analysis failed");
3536+
Fortran::lower::StatementContext maskContext;
3537+
hlfir::Entity mask = Fortran::lower::convertExprToHLFIR(
3538+
loc, *this, *maskExpr, localSymbols, maskContext);
3539+
mask = hlfir::loadTrivialScalar(loc, *builder, mask);
3540+
auto yieldOp = builder->create<hlfir::YieldOp>(loc, mask);
3541+
genCleanUpInRegionIfAny(loc, *builder, yieldOp.getCleanup(), maskContext);
3542+
}
35023543
void genFIR(const Fortran::parser::WhereConstructStmt &stmt) {
3503-
implicitIterSpace.append(Fortran::semantics::GetExpr(
3504-
std::get<Fortran::parser::LogicalExpr>(stmt.t)));
3544+
const Fortran::semantics::SomeExpr *maskExpr = Fortran::semantics::GetExpr(
3545+
std::get<Fortran::parser::LogicalExpr>(stmt.t));
3546+
if (lowerToHighLevelFIR())
3547+
lowerWhereMaskToHlfir(getCurrentLocation(), maskExpr);
3548+
else
3549+
implicitIterSpace.append(maskExpr);
35053550
}
35063551
void genFIR(const Fortran::parser::WhereConstruct::MaskedElsewhere &ew) {
3552+
mlir::Location loc = getCurrentLocation();
3553+
hlfir::ElseWhereOp elsewhereOp;
3554+
if (lowerToHighLevelFIR()) {
3555+
elsewhereOp = builder->create<hlfir::ElseWhereOp>(loc);
3556+
// Lower mask in the mask region.
3557+
builder->createBlock(&elsewhereOp.getMaskRegion());
3558+
}
35073559
genNestedStatement(
35083560
std::get<
35093561
Fortran::parser::Statement<Fortran::parser::MaskedElsewhereStmt>>(
35103562
ew.t));
3563+
3564+
// For HLFIR, lower the body in the hlfir.elsewhere body region.
3565+
if (elsewhereOp)
3566+
builder->createBlock(&elsewhereOp.getBody());
3567+
35113568
for (const auto &body :
35123569
std::get<std::list<Fortran::parser::WhereBodyConstruct>>(ew.t))
35133570
genFIR(body);
35143571
}
35153572
void genFIR(const Fortran::parser::MaskedElsewhereStmt &stmt) {
3516-
implicitIterSpace.append(Fortran::semantics::GetExpr(
3517-
std::get<Fortran::parser::LogicalExpr>(stmt.t)));
3573+
const auto *maskExpr = Fortran::semantics::GetExpr(
3574+
std::get<Fortran::parser::LogicalExpr>(stmt.t));
3575+
if (lowerToHighLevelFIR())
3576+
lowerWhereMaskToHlfir(getCurrentLocation(), maskExpr);
3577+
else
3578+
implicitIterSpace.append(maskExpr);
35183579
}
35193580
void genFIR(const Fortran::parser::WhereConstruct::Elsewhere &ew) {
3581+
if (lowerToHighLevelFIR()) {
3582+
auto elsewhereOp =
3583+
builder->create<hlfir::ElseWhereOp>(getCurrentLocation());
3584+
builder->createBlock(&elsewhereOp.getBody());
3585+
}
35203586
genNestedStatement(
35213587
std::get<Fortran::parser::Statement<Fortran::parser::ElsewhereStmt>>(
35223588
ew.t));
@@ -3525,18 +3591,32 @@ class FirConverter : public Fortran::lower::AbstractConverter {
35253591
genFIR(body);
35263592
}
35273593
void genFIR(const Fortran::parser::ElsewhereStmt &stmt) {
3528-
implicitIterSpace.append(nullptr);
3594+
if (!lowerToHighLevelFIR())
3595+
implicitIterSpace.append(nullptr);
35293596
}
35303597
void genFIR(const Fortran::parser::EndWhereStmt &) {
3531-
implicitIterSpace.shrinkStack();
3598+
if (!lowerToHighLevelFIR())
3599+
implicitIterSpace.shrinkStack();
35323600
}
35333601

35343602
void genFIR(const Fortran::parser::WhereStmt &stmt) {
35353603
Fortran::lower::StatementContext stmtCtx;
35363604
const auto &assign = std::get<Fortran::parser::AssignmentStmt>(stmt.t);
3605+
const auto *mask = Fortran::semantics::GetExpr(
3606+
std::get<Fortran::parser::LogicalExpr>(stmt.t));
3607+
if (lowerToHighLevelFIR()) {
3608+
mlir::Location loc = getCurrentLocation();
3609+
auto whereOp = builder->create<hlfir::WhereOp>(loc);
3610+
builder->createBlock(&whereOp.getMaskRegion());
3611+
lowerWhereMaskToHlfir(loc, mask);
3612+
builder->createBlock(&whereOp.getBody());
3613+
genAssignment(*assign.typedAssignment->v);
3614+
builder->create<fir::FirEndOp>(loc);
3615+
builder->setInsertionPointAfter(whereOp);
3616+
return;
3617+
}
35373618
implicitIterSpace.growStack();
3538-
implicitIterSpace.append(Fortran::semantics::GetExpr(
3539-
std::get<Fortran::parser::LogicalExpr>(stmt.t)));
3619+
implicitIterSpace.append(mask);
35403620
genAssignment(*assign.typedAssignment->v);
35413621
implicitIterSpace.shrinkStack();
35423622
}

flang/test/Lower/HLFIR/where.f90

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
! Test lowering of WHERE construct and statements to HLFIR.
2+
! RUN: bbc --hlfir -emit-fir -o - %s | FileCheck %s
3+
4+
module where_defs
5+
logical :: mask(10)
6+
real :: x(10), y(10)
7+
real, allocatable :: a(:), b(:)
8+
interface
9+
function return_temporary_mask()
10+
logical, allocatable :: return_temporary_mask(:)
11+
end function
12+
function return_temporary_array()
13+
real, allocatable :: return_temporary_array(:)
14+
end function
15+
end interface
16+
end module
17+
18+
subroutine simple_where()
19+
use where_defs, only: mask, x, y
20+
where (mask) x = y
21+
end subroutine
22+
! CHECK-LABEL: func.func @_QPsimple_where() {
23+
! CHECK: %[[VAL_3:.*]]:2 = hlfir.declare {{.*}}Emask
24+
! CHECK: %[[VAL_7:.*]]:2 = hlfir.declare {{.*}}Ex
25+
! CHECK: %[[VAL_11:.*]]:2 = hlfir.declare {{.*}}Ey
26+
! CHECK: hlfir.where {
27+
! CHECK: hlfir.yield %[[VAL_3]]#0 : !fir.ref<!fir.array<10x!fir.logical<4>>>
28+
! CHECK: } do {
29+
! CHECK: hlfir.region_assign {
30+
! CHECK: hlfir.yield %[[VAL_11]]#0 : !fir.ref<!fir.array<10xf32>>
31+
! CHECK: } to {
32+
! CHECK: hlfir.yield %[[VAL_7]]#0 : !fir.ref<!fir.array<10xf32>>
33+
! CHECK: }
34+
! CHECK: }
35+
! CHECK: return
36+
! CHECK:}
37+
38+
subroutine where_construct()
39+
use where_defs
40+
where (mask)
41+
x = y
42+
a = b
43+
end where
44+
end subroutine
45+
! CHECK-LABEL: func.func @_QPwhere_construct() {
46+
! CHECK: %[[VAL_1:.*]]:2 = hlfir.declare %{{.*}} {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QMwhere_defsEa"}
47+
! CHECK: %[[VAL_3:.*]]:2 = hlfir.declare %{{.*}} {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QMwhere_defsEb"}
48+
! CHECK: %[[VAL_7:.*]]:2 = hlfir.declare {{.*}}Emask
49+
! CHECK: %[[VAL_11:.*]]:2 = hlfir.declare {{.*}}Ex
50+
! CHECK: %[[VAL_15:.*]]:2 = hlfir.declare {{.*}}Ey
51+
! CHECK: hlfir.where {
52+
! CHECK: hlfir.yield %[[VAL_7]]#0 : !fir.ref<!fir.array<10x!fir.logical<4>>>
53+
! CHECK: } do {
54+
! CHECK: hlfir.region_assign {
55+
! CHECK: hlfir.yield %[[VAL_15]]#0 : !fir.ref<!fir.array<10xf32>>
56+
! CHECK: } to {
57+
! CHECK: hlfir.yield %[[VAL_11]]#0 : !fir.ref<!fir.array<10xf32>>
58+
! CHECK: }
59+
! CHECK: hlfir.region_assign {
60+
! CHECK: %[[VAL_16:.*]] = fir.load %[[VAL_3]]#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
61+
! CHECK: hlfir.yield %[[VAL_16]] : !fir.box<!fir.heap<!fir.array<?xf32>>>
62+
! CHECK: } to {
63+
! CHECK: %[[VAL_17:.*]] = fir.load %[[VAL_1]]#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
64+
! CHECK: hlfir.yield %[[VAL_17]] : !fir.box<!fir.heap<!fir.array<?xf32>>>
65+
! CHECK: }
66+
! CHECK: }
67+
! CHECK: return
68+
! CHECK:}
69+
70+
subroutine where_cleanup()
71+
use where_defs, only: x, return_temporary_mask, return_temporary_array
72+
where (return_temporary_mask()) x = return_temporary_array()
73+
end subroutine
74+
! CHECK-LABEL: func.func @_QPwhere_cleanup() {
75+
! CHECK: %[[VAL_0:.*]] = fir.alloca !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = ".result"}
76+
! CHECK: %[[VAL_1:.*]] = fir.alloca !fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>> {bindc_name = ".result"}
77+
! CHECK: %[[VAL_5:.*]]:2 = hlfir.declare {{.*}}Ex
78+
! CHECK: hlfir.where {
79+
! CHECK: %[[VAL_6:.*]] = fir.call @_QPreturn_temporary_mask() fastmath<contract> : () -> !fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>>
80+
! CHECK: fir.save_result %[[VAL_6]] to %[[VAL_1]] : !fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>>>
81+
! CHECK: %[[VAL_7:.*]]:2 = hlfir.declare %[[VAL_1]] {uniq_name = ".tmp.func_result"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>>>)
82+
! CHECK: %[[VAL_8:.*]] = fir.load %[[VAL_7]]#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>>>
83+
! CHECK: hlfir.yield %[[VAL_8]] : !fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>> cleanup {
84+
! CHECK: fir.freemem
85+
! CHECK: }
86+
! CHECK: } do {
87+
! CHECK: hlfir.region_assign {
88+
! CHECK: %[[VAL_14:.*]] = fir.call @_QPreturn_temporary_array() fastmath<contract> : () -> !fir.box<!fir.heap<!fir.array<?xf32>>>
89+
! CHECK: fir.save_result %[[VAL_14]] to %[[VAL_0]] : !fir.box<!fir.heap<!fir.array<?xf32>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
90+
! CHECK: %[[VAL_15:.*]]:2 = hlfir.declare %[[VAL_0]] {uniq_name = ".tmp.func_result"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>)
91+
! CHECK: %[[VAL_16:.*]] = fir.load %[[VAL_15]]#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
92+
! CHECK: hlfir.yield %[[VAL_16]] : !fir.box<!fir.heap<!fir.array<?xf32>>> cleanup {
93+
! CHECK: fir.freemem
94+
! CHECK: }
95+
! CHECK: } to {
96+
! CHECK: hlfir.yield %[[VAL_5]]#0 : !fir.ref<!fir.array<10xf32>>
97+
! CHECK: }
98+
! CHECK: }
99+
100+
subroutine simple_elsewhere()
101+
use where_defs
102+
where (mask)
103+
x = y
104+
elsewhere
105+
y = x
106+
end where
107+
end subroutine
108+
! CHECK-LABEL: func.func @_QPsimple_elsewhere() {
109+
! CHECK: %[[VAL_7:.*]]:2 = hlfir.declare {{.*}}Emask
110+
! CHECK: %[[VAL_11:.*]]:2 = hlfir.declare {{.*}}Ex
111+
! CHECK: %[[VAL_15:.*]]:2 = hlfir.declare {{.*}}Ey
112+
! CHECK: hlfir.where {
113+
! CHECK: hlfir.yield %[[VAL_7]]#0 : !fir.ref<!fir.array<10x!fir.logical<4>>>
114+
! CHECK: } do {
115+
! CHECK: hlfir.region_assign {
116+
! CHECK: hlfir.yield %[[VAL_15]]#0 : !fir.ref<!fir.array<10xf32>>
117+
! CHECK: } to {
118+
! CHECK: hlfir.yield %[[VAL_11]]#0 : !fir.ref<!fir.array<10xf32>>
119+
! CHECK: }
120+
! CHECK: hlfir.elsewhere do {
121+
! CHECK: hlfir.region_assign {
122+
! CHECK: hlfir.yield %[[VAL_11]]#0 : !fir.ref<!fir.array<10xf32>>
123+
! CHECK: } to {
124+
! CHECK: hlfir.yield %[[VAL_15]]#0 : !fir.ref<!fir.array<10xf32>>
125+
! CHECK: }
126+
! CHECK: }
127+
! CHECK: }
128+
129+
subroutine elsewhere_2(mask2)
130+
use where_defs, only : mask, x, y
131+
logical :: mask2(:)
132+
where (mask)
133+
x = y
134+
elsewhere(mask2)
135+
y = x
136+
elsewhere
137+
x = foo()
138+
end where
139+
end subroutine
140+
! CHECK-LABEL: func.func @_QPelsewhere_2(
141+
! CHECK: %[[VAL_5:.*]]:2 = hlfir.declare {{.*}}Emask
142+
! CHECK: %[[VAL_6:.*]]:2 = hlfir.declare {{.*}}Emask2
143+
! CHECK: %[[VAL_11:.*]]:2 = hlfir.declare {{.*}}Ex
144+
! CHECK: %[[VAL_15:.*]]:2 = hlfir.declare {{.*}}Ey
145+
! CHECK: hlfir.where {
146+
! CHECK: hlfir.yield %[[VAL_5]]#0 : !fir.ref<!fir.array<10x!fir.logical<4>>>
147+
! CHECK: } do {
148+
! CHECK: hlfir.region_assign {
149+
! CHECK: hlfir.yield %[[VAL_15]]#0 : !fir.ref<!fir.array<10xf32>>
150+
! CHECK: } to {
151+
! CHECK: hlfir.yield %[[VAL_11]]#0 : !fir.ref<!fir.array<10xf32>>
152+
! CHECK: }
153+
! CHECK: hlfir.elsewhere mask {
154+
! CHECK: hlfir.yield %[[VAL_6]]#0 : !fir.box<!fir.array<?x!fir.logical<4>>>
155+
! CHECK: } do {
156+
! CHECK: hlfir.region_assign {
157+
! CHECK: hlfir.yield %[[VAL_11]]#0 : !fir.ref<!fir.array<10xf32>>
158+
! CHECK: } to {
159+
! CHECK: hlfir.yield %[[VAL_15]]#0 : !fir.ref<!fir.array<10xf32>>
160+
! CHECK: }
161+
! CHECK: hlfir.elsewhere do {
162+
! CHECK: hlfir.region_assign {
163+
! CHECK: %[[VAL_16:.*]] = fir.call @_QPfoo() fastmath<contract> : () -> f32
164+
! CHECK: hlfir.yield %[[VAL_16]] : f32
165+
! CHECK: } to {
166+
! CHECK: hlfir.yield %[[VAL_11]]#0 : !fir.ref<!fir.array<10xf32>>
167+
! CHECK: }
168+
! CHECK: }
169+
! CHECK: }
170+
! CHECK: }

0 commit comments

Comments
 (0)