@@ -3154,7 +3154,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
3154
3154
// Gather some information about the assignment that will impact how it is
3155
3155
// lowered.
3156
3156
const bool isWholeAllocatableAssignment =
3157
- !userDefinedAssignment &&
3157
+ !userDefinedAssignment && ! isInsideHlfirWhere () &&
3158
3158
Fortran::lower::isWholeAllocatable (assign.lhs );
3159
3159
std::optional<Fortran::evaluate::DynamicType> lhsType =
3160
3160
assign.lhs .GetType ();
@@ -3243,8 +3243,6 @@ class FirConverter : public Fortran::lower::AbstractConverter {
3243
3243
void genAssignment (const Fortran::evaluate::Assignment &assign) {
3244
3244
mlir::Location loc = toLocation ();
3245
3245
if (lowerToHighLevelFIR ()) {
3246
- if (!implicitIterSpace.empty ())
3247
- TODO (loc, " HLFIR assignment inside WHERE" );
3248
3246
std::visit (
3249
3247
Fortran::common::visitors{
3250
3248
[&](const Fortran::evaluate::Assignment::Intrinsic &) {
@@ -3452,23 +3450,47 @@ class FirConverter : public Fortran::lower::AbstractConverter {
3452
3450
Fortran::lower::createArrayMergeStores (*this , explicitIterSpace);
3453
3451
}
3454
3452
3455
- bool isInsideHlfirForallOrWhere () const {
3453
+ // Is the insertion point of the builder directly or indirectly set
3454
+ // inside any operation of type "Op"?
3455
+ template <typename ... Op>
3456
+ bool isInsideOp () const {
3456
3457
mlir::Block *block = builder->getInsertionBlock ();
3457
3458
mlir::Operation *op = block ? block->getParentOp () : nullptr ;
3458
3459
while (op) {
3459
- if (mlir::isa<hlfir::ForallOp, hlfir::WhereOp >(op))
3460
+ if (mlir::isa<Op... >(op))
3460
3461
return true ;
3461
3462
op = op->getParentOp ();
3462
3463
}
3463
3464
return false ;
3464
3465
}
3466
+ bool isInsideHlfirForallOrWhere () const {
3467
+ return isInsideOp<hlfir::ForallOp, hlfir::WhereOp>();
3468
+ }
3469
+ bool isInsideHlfirWhere () const { return isInsideOp<hlfir::WhereOp>(); }
3465
3470
3466
3471
void genFIR (const Fortran::parser::WhereConstruct &c) {
3467
- implicitIterSpace.growStack ();
3472
+ mlir::Location loc = getCurrentLocation ();
3473
+ hlfir::WhereOp whereOp;
3474
+
3475
+ if (!lowerToHighLevelFIR ()) {
3476
+ implicitIterSpace.growStack ();
3477
+ } else {
3478
+ whereOp = builder->create <hlfir::WhereOp>(loc);
3479
+ builder->createBlock (&whereOp.getMaskRegion ());
3480
+ }
3481
+
3482
+ // Lower the where mask. For HLFIR, this is done in the hlfir.where mask
3483
+ // region.
3468
3484
genNestedStatement (
3469
3485
std::get<
3470
3486
Fortran::parser::Statement<Fortran::parser::WhereConstructStmt>>(
3471
3487
c.t ));
3488
+
3489
+ // Lower WHERE body. For HLFIR, this is done in the hlfir.where body
3490
+ // region.
3491
+ if (whereOp)
3492
+ builder->createBlock (&whereOp.getBody ());
3493
+
3472
3494
for (const auto &body :
3473
3495
std::get<std::list<Fortran::parser::WhereBodyConstruct>>(c.t ))
3474
3496
genFIR (body);
@@ -3484,6 +3506,13 @@ class FirConverter : public Fortran::lower::AbstractConverter {
3484
3506
genNestedStatement (
3485
3507
std::get<Fortran::parser::Statement<Fortran::parser::EndWhereStmt>>(
3486
3508
c.t ));
3509
+
3510
+ if (whereOp) {
3511
+ // For HLFIR, create fir.end terminator in the last hlfir.elsewhere, or
3512
+ // in the hlfir.where if it had no elsewhere.
3513
+ builder->create <fir::FirEndOp>(loc);
3514
+ builder->setInsertionPointAfter (whereOp);
3515
+ }
3487
3516
}
3488
3517
void genFIR (const Fortran::parser::WhereBodyConstruct &body) {
3489
3518
std::visit (
@@ -3499,24 +3528,61 @@ class FirConverter : public Fortran::lower::AbstractConverter {
3499
3528
},
3500
3529
body.u );
3501
3530
}
3531
+
3532
+ // / Lower a Where or Elsewhere mask into an hlfir mask region.
3533
+ void lowerWhereMaskToHlfir (mlir::Location loc,
3534
+ const Fortran::semantics::SomeExpr *maskExpr) {
3535
+ assert (maskExpr && " mask semantic analysis failed" );
3536
+ Fortran::lower::StatementContext maskContext;
3537
+ hlfir::Entity mask = Fortran::lower::convertExprToHLFIR (
3538
+ loc, *this , *maskExpr, localSymbols, maskContext);
3539
+ mask = hlfir::loadTrivialScalar (loc, *builder, mask);
3540
+ auto yieldOp = builder->create <hlfir::YieldOp>(loc, mask);
3541
+ genCleanUpInRegionIfAny (loc, *builder, yieldOp.getCleanup (), maskContext);
3542
+ }
3502
3543
void genFIR (const Fortran::parser::WhereConstructStmt &stmt) {
3503
- implicitIterSpace.append (Fortran::semantics::GetExpr (
3504
- std::get<Fortran::parser::LogicalExpr>(stmt.t )));
3544
+ const Fortran::semantics::SomeExpr *maskExpr = Fortran::semantics::GetExpr (
3545
+ std::get<Fortran::parser::LogicalExpr>(stmt.t ));
3546
+ if (lowerToHighLevelFIR ())
3547
+ lowerWhereMaskToHlfir (getCurrentLocation (), maskExpr);
3548
+ else
3549
+ implicitIterSpace.append (maskExpr);
3505
3550
}
3506
3551
void genFIR (const Fortran::parser::WhereConstruct::MaskedElsewhere &ew) {
3552
+ mlir::Location loc = getCurrentLocation ();
3553
+ hlfir::ElseWhereOp elsewhereOp;
3554
+ if (lowerToHighLevelFIR ()) {
3555
+ elsewhereOp = builder->create <hlfir::ElseWhereOp>(loc);
3556
+ // Lower mask in the mask region.
3557
+ builder->createBlock (&elsewhereOp.getMaskRegion ());
3558
+ }
3507
3559
genNestedStatement (
3508
3560
std::get<
3509
3561
Fortran::parser::Statement<Fortran::parser::MaskedElsewhereStmt>>(
3510
3562
ew.t ));
3563
+
3564
+ // For HLFIR, lower the body in the hlfir.elsewhere body region.
3565
+ if (elsewhereOp)
3566
+ builder->createBlock (&elsewhereOp.getBody ());
3567
+
3511
3568
for (const auto &body :
3512
3569
std::get<std::list<Fortran::parser::WhereBodyConstruct>>(ew.t ))
3513
3570
genFIR (body);
3514
3571
}
3515
3572
void genFIR (const Fortran::parser::MaskedElsewhereStmt &stmt) {
3516
- implicitIterSpace.append (Fortran::semantics::GetExpr (
3517
- std::get<Fortran::parser::LogicalExpr>(stmt.t )));
3573
+ const auto *maskExpr = Fortran::semantics::GetExpr (
3574
+ std::get<Fortran::parser::LogicalExpr>(stmt.t ));
3575
+ if (lowerToHighLevelFIR ())
3576
+ lowerWhereMaskToHlfir (getCurrentLocation (), maskExpr);
3577
+ else
3578
+ implicitIterSpace.append (maskExpr);
3518
3579
}
3519
3580
void genFIR (const Fortran::parser::WhereConstruct::Elsewhere &ew) {
3581
+ if (lowerToHighLevelFIR ()) {
3582
+ auto elsewhereOp =
3583
+ builder->create <hlfir::ElseWhereOp>(getCurrentLocation ());
3584
+ builder->createBlock (&elsewhereOp.getBody ());
3585
+ }
3520
3586
genNestedStatement (
3521
3587
std::get<Fortran::parser::Statement<Fortran::parser::ElsewhereStmt>>(
3522
3588
ew.t ));
@@ -3525,18 +3591,32 @@ class FirConverter : public Fortran::lower::AbstractConverter {
3525
3591
genFIR (body);
3526
3592
}
3527
3593
void genFIR (const Fortran::parser::ElsewhereStmt &stmt) {
3528
- implicitIterSpace.append (nullptr );
3594
+ if (!lowerToHighLevelFIR ())
3595
+ implicitIterSpace.append (nullptr );
3529
3596
}
3530
3597
void genFIR (const Fortran::parser::EndWhereStmt &) {
3531
- implicitIterSpace.shrinkStack ();
3598
+ if (!lowerToHighLevelFIR ())
3599
+ implicitIterSpace.shrinkStack ();
3532
3600
}
3533
3601
3534
3602
void genFIR (const Fortran::parser::WhereStmt &stmt) {
3535
3603
Fortran::lower::StatementContext stmtCtx;
3536
3604
const auto &assign = std::get<Fortran::parser::AssignmentStmt>(stmt.t );
3605
+ const auto *mask = Fortran::semantics::GetExpr (
3606
+ std::get<Fortran::parser::LogicalExpr>(stmt.t ));
3607
+ if (lowerToHighLevelFIR ()) {
3608
+ mlir::Location loc = getCurrentLocation ();
3609
+ auto whereOp = builder->create <hlfir::WhereOp>(loc);
3610
+ builder->createBlock (&whereOp.getMaskRegion ());
3611
+ lowerWhereMaskToHlfir (loc, mask);
3612
+ builder->createBlock (&whereOp.getBody ());
3613
+ genAssignment (*assign.typedAssignment ->v );
3614
+ builder->create <fir::FirEndOp>(loc);
3615
+ builder->setInsertionPointAfter (whereOp);
3616
+ return ;
3617
+ }
3537
3618
implicitIterSpace.growStack ();
3538
- implicitIterSpace.append (Fortran::semantics::GetExpr (
3539
- std::get<Fortran::parser::LogicalExpr>(stmt.t )));
3619
+ implicitIterSpace.append (mask);
3540
3620
genAssignment (*assign.typedAssignment ->v );
3541
3621
implicitIterSpace.shrinkStack ();
3542
3622
}
0 commit comments