Skip to content

Commit e7a9bcb

Browse files
committed
[TypeChecker] Type-check where clauses for for-in statements separately
Instead of using `one-way` constraints, just like in closure contexts for-in statements should type-check their `where` clauses separately. This also unifies and simplifies for-in preamble handling in the solver.
1 parent 25ad700 commit e7a9bcb

File tree

6 files changed

+68
-87
lines changed

6 files changed

+68
-87
lines changed

include/swift/Sema/SyntacticElementTarget.h

+15-23
Original file line numberDiff line numberDiff line change
@@ -161,10 +161,9 @@ class SyntacticElementTarget {
161161
ForEachStmt *stmt;
162162
DeclContext *dc;
163163
Pattern *pattern;
164-
bool ignoreWhereClause;
165164
GenericEnvironment *packElementEnv;
166165
ForEachStmtInfo info;
167-
} forEachStmt;
166+
} forEachPreamble;
168167

169168
PatternBindingDecl *patternBinding;
170169
};
@@ -242,13 +241,11 @@ class SyntacticElementTarget {
242241
}
243242

244243
SyntacticElementTarget(ForEachStmt *stmt, DeclContext *dc,
245-
bool ignoreWhereClause,
246244
GenericEnvironment *packElementEnv)
247245
: kind(Kind::forEachPreamble) {
248-
forEachStmt.stmt = stmt;
249-
forEachStmt.dc = dc;
250-
forEachStmt.ignoreWhereClause = ignoreWhereClause;
251-
forEachStmt.packElementEnv = packElementEnv;
246+
forEachPreamble.stmt = stmt;
247+
forEachPreamble.dc = dc;
248+
forEachPreamble.packElementEnv = packElementEnv;
252249
}
253250

254251
/// Form a target for the initialization of a pattern from an expression.
@@ -267,10 +264,10 @@ class SyntacticElementTarget {
267264
static SyntacticElementTarget
268265
forReturn(ReturnStmt *returnStmt, Type contextTy, DeclContext *dc);
269266

270-
/// Form a target for the preamble of a for-in loop, excluding its body.
267+
/// Form a target for the preamble of a for-in loop, excluding its where
268+
/// clause and body.
271269
static SyntacticElementTarget
272270
forForEachPreamble(ForEachStmt *stmt, DeclContext *dc,
273-
bool ignoreWhereClause = false,
274271
GenericEnvironment *packElementEnv = nullptr);
275272

276273
/// Form a target for a property with an attached property wrapper that is
@@ -372,7 +369,7 @@ class SyntacticElementTarget {
372369
}
373370

374371
case Kind::forEachPreamble:
375-
return forEachStmt.dc;
372+
return forEachPreamble.dc;
376373
}
377374
llvm_unreachable("invalid decl context type");
378375
}
@@ -541,24 +538,19 @@ class SyntacticElementTarget {
541538
return expression.initialization.patternBindingIndex;
542539
}
543540

544-
bool ignoreForEachWhereClause() const {
545-
assert(isForEachPreamble());
546-
return forEachStmt.ignoreWhereClause;
547-
}
548-
549541
GenericEnvironment *getPackElementEnv() const {
550542
assert(isForEachPreamble());
551-
return forEachStmt.packElementEnv;
543+
return forEachPreamble.packElementEnv;
552544
}
553545

554546
const ForEachStmtInfo &getForEachStmtInfo() const {
555547
assert(isForEachPreamble());
556-
return forEachStmt.info;
548+
return forEachPreamble.info;
557549
}
558550

559551
ForEachStmtInfo &getForEachStmtInfo() {
560552
assert(isForEachPreamble());
561-
return forEachStmt.info;
553+
return forEachPreamble.info;
562554
}
563555

564556
/// Whether this context infers an opaque return type.
@@ -585,7 +577,7 @@ class SyntacticElementTarget {
585577
return getInitializationPattern();
586578

587579
if (kind == Kind::forEachPreamble)
588-
return forEachStmt.pattern;
580+
return forEachPreamble.pattern;
589581

590582
return nullptr;
591583
}
@@ -598,7 +590,7 @@ class SyntacticElementTarget {
598590
}
599591

600592
if (kind == Kind::forEachPreamble) {
601-
forEachStmt.pattern = pattern;
593+
forEachPreamble.pattern = pattern;
602594
return;
603595
}
604596

@@ -729,7 +721,7 @@ class SyntacticElementTarget {
729721
return nullptr;
730722

731723
case Kind::forEachPreamble:
732-
return forEachStmt.stmt;
724+
return forEachPreamble.stmt;
733725
}
734726
llvm_unreachable("invalid case label type");
735727
}
@@ -841,7 +833,7 @@ class SyntacticElementTarget {
841833

842834
// For-in preamble target doesn't cover the body.
843835
case Kind::forEachPreamble:
844-
auto *stmt = forEachStmt.stmt;
836+
auto *stmt = forEachPreamble.stmt;
845837
SourceLoc startLoc = stmt->getForLoc();
846838
SourceLoc endLoc = stmt->getParsedSequence()->getEndLoc();
847839

@@ -884,7 +876,7 @@ class SyntacticElementTarget {
884876
}
885877

886878
case Kind::forEachPreamble:
887-
return forEachStmt.stmt->getStartLoc();
879+
return forEachPreamble.stmt->getStartLoc();
888880
}
889881
llvm_unreachable("invalid target type");
890882
}

lib/Sema/CSApply.cpp

+13-23
Original file line numberDiff line numberDiff line change
@@ -9247,9 +9247,9 @@ applySolutionToInitialization(SyntacticElementTarget target, Expr *initializer,
92479247
}
92489248

92499249
static std::optional<SequenceIterationInfo>
9250-
applySolutionToForEachStmt(ForEachStmt *stmt, SequenceIterationInfo info,
9251-
DeclContext *dc,
9252-
SyntacticElementTargetRewriter &rewriter) {
9250+
applySolutionToForEachStmtPreamble(ForEachStmt *stmt,
9251+
SequenceIterationInfo info, DeclContext *dc,
9252+
SyntacticElementTargetRewriter &rewriter) {
92539253
auto &solution = rewriter.getSolution();
92549254
auto &cs = solution.getConstraintSystem();
92559255
auto &ctx = cs.getASTContext();
@@ -9374,23 +9374,12 @@ applySolutionToForEachStmt(ForEachStmt *stmt, SequenceIterationInfo info,
93749374
"Couldn't find sequence conformance");
93759375
stmt->setSequenceConformance(type, sequenceConformance);
93769376

9377-
// Apply the solution to the filtering condition, if there is one.
9378-
if (auto *whereExpr = stmt->getWhere()) {
9379-
auto whereTarget = *cs.getTargetFor(whereExpr);
9380-
9381-
auto rewrittenTarget = rewriter.rewriteTarget(whereTarget);
9382-
if (!rewrittenTarget)
9383-
return std::nullopt;
9384-
9385-
stmt->setWhere(rewrittenTarget->getAsExpr());
9386-
}
9387-
93889377
return info;
93899378
}
93909379

93919380
static std::optional<PackIterationInfo>
9392-
applySolutionToForEachStmt(ForEachStmt *stmt, PackIterationInfo info,
9393-
SyntacticElementTargetRewriter &rewriter) {
9381+
applySolutionToForEachStmtPreamble(ForEachStmt *stmt, PackIterationInfo info,
9382+
SyntacticElementTargetRewriter &rewriter) {
93949383
auto &solution = rewriter.getSolution();
93959384
auto &cs = solution.getConstraintSystem();
93969385
auto *sequenceExpr = stmt->getParsedSequence();
@@ -9412,16 +9401,16 @@ applySolutionToForEachStmt(ForEachStmt *stmt, PackIterationInfo info,
94129401
///
94139402
/// \returns the resulting initialization expression.
94149403
static std::optional<SyntacticElementTarget>
9415-
applySolutionToForEachStmt(SyntacticElementTarget target,
9416-
SyntacticElementTargetRewriter &rewriter) {
9404+
applySolutionToForEachStmtPreamble(SyntacticElementTarget target,
9405+
SyntacticElementTargetRewriter &rewriter) {
94179406
auto resultTarget = target;
94189407
auto &forEachStmtInfo = resultTarget.getForEachStmtInfo();
94199408
auto *stmt = target.getAsForEachStmt();
94209409

94219410
Type rewrittenPatternType;
94229411

94239412
if (auto *info = forEachStmtInfo.dyn_cast<SequenceIterationInfo>()) {
9424-
auto resultInfo = applySolutionToForEachStmt(
9413+
auto resultInfo = applySolutionToForEachStmtPreamble(
94259414
stmt, *info, target.getDeclContext(), rewriter);
94269415
if (!resultInfo) {
94279416
return std::nullopt;
@@ -9430,7 +9419,7 @@ applySolutionToForEachStmt(SyntacticElementTarget target,
94309419
forEachStmtInfo = *resultInfo;
94319420
rewrittenPatternType = resultInfo->initType;
94329421
} else {
9433-
auto resultInfo = applySolutionToForEachStmt(
9422+
auto resultInfo = applySolutionToForEachStmtPreamble(
94349423
stmt, forEachStmtInfo.get<PackIterationInfo>(), rewriter);
94359424
if (!resultInfo) {
94369425
return std::nullopt;
@@ -9664,11 +9653,12 @@ ExprWalker::rewriteTarget(SyntacticElementTarget target) {
96649653

96659654
return std::nullopt;
96669655
} else if (auto *forEach = target.getAsForEachStmt()) {
9667-
auto forEachResultTarget = applySolutionToForEachStmt(target, *this);
9668-
if (!forEachResultTarget)
9656+
auto forEachPreambleResultTarget =
9657+
applySolutionToForEachStmtPreamble(target, *this);
9658+
if (!forEachPreambleResultTarget)
96699659
return std::nullopt;
96709660

9671-
result = *forEachResultTarget;
9661+
result = *forEachPreambleResultTarget;
96729662
} else {
96739663
auto fn = *target.getAsFunction();
96749664
if (rewriteFunction(fn))

lib/Sema/CSGen.cpp

+3-22
Original file line numberDiff line numberDiff line change
@@ -3660,8 +3660,7 @@ generateForEachStmtConstraints(ConstraintSystem &cs, DeclContext *dc,
36603660
/// expression that conforms to `Swift.Sequence`.
36613661
static std::optional<SequenceIterationInfo>
36623662
generateForEachStmtConstraints(ConstraintSystem &cs, DeclContext *dc,
3663-
ForEachStmt *stmt, Pattern *typeCheckedPattern,
3664-
bool ignoreForEachWhereClause) {
3663+
ForEachStmt *stmt, Pattern *typeCheckedPattern) {
36653664
ASTContext &ctx = cs.getASTContext();
36663665
bool isAsync = stmt->getAwaitLoc().isValid();
36673666
auto *sequenceExpr = stmt->getParsedSequence();
@@ -3835,24 +3834,6 @@ generateForEachStmtConstraints(ConstraintSystem &cs, DeclContext *dc,
38353834
elementLocator);
38363835
}
38373836

3838-
// Generate constraints for the "where" expression, if there is one.
3839-
auto *whereExpr = stmt->getWhere();
3840-
if (whereExpr && !ignoreForEachWhereClause) {
3841-
Type boolType = dc->getASTContext().getBoolType();
3842-
if (!boolType)
3843-
return std::nullopt;
3844-
3845-
SyntacticElementTarget whereTarget(whereExpr, dc, CTP_Condition, boolType,
3846-
/*isDiscarded=*/false);
3847-
if (cs.generateConstraints(whereTarget, FreeTypeVariableBinding::Disallow))
3848-
return std::nullopt;
3849-
3850-
cs.setTargetFor(whereExpr, whereTarget);
3851-
3852-
ContextualTypeInfo contextInfo(boolType, CTP_Condition);
3853-
cs.setContextualInfo(whereExpr, contextInfo);
3854-
}
3855-
38563837
// Populate all of the information for a for-each loop.
38573838
sequenceIterationInfo.elementType = elementType;
38583839
sequenceIterationInfo.initType = initType;
@@ -3904,8 +3885,8 @@ generateForEachPreambleConstraints(ConstraintSystem &cs,
39043885

39053886
target.getForEachStmtInfo() = *packIterationInfo;
39063887
} else {
3907-
auto sequenceIterationInfo = generateForEachStmtConstraints(
3908-
cs, dc, stmt, pattern, target.ignoreForEachWhereClause());
3888+
auto sequenceIterationInfo =
3889+
generateForEachStmtConstraints(cs, dc, stmt, pattern);
39093890
if (!sequenceIterationInfo) {
39103891
return std::nullopt;
39113892
}

lib/Sema/CSSyntacticElement.cpp

+14-9
Original file line numberDiff line numberDiff line change
@@ -686,13 +686,8 @@ class SyntacticElementConstraintGenerator
686686
///
687687
/// - From sequence to pattern, when pattern has no type information.
688688
void visitForEachPattern(Pattern *pattern, ForEachStmt *forEachStmt) {
689-
// The `where` clause should be ignored because \c visitForEachStmt
690-
// records it as a separate conjunction element to allow for a more
691-
// granular control over what contextual information is brought into
692-
// the scope during pattern + sequence and `where` clause solving.
693689
auto target = SyntacticElementTarget::forForEachPreamble(
694-
forEachStmt, context.getAsDeclContext(),
695-
/*ignoreWhereClause=*/true);
690+
forEachStmt, context.getAsDeclContext());
696691

697692
if (cs.generateConstraints(target)) {
698693
hadError = true;
@@ -1898,10 +1893,20 @@ class SyntacticElementSolutionApplication
18981893
ASTNode visitForEachStmt(ForEachStmt *forEachStmt) {
18991894
ConstraintSystem &cs = solution.getConstraintSystem();
19001895

1901-
auto forEachTarget = rewriter.rewriteTarget(*cs.getTargetFor(forEachStmt));
1902-
1903-
if (!forEachTarget)
1896+
// Apply solution to the preamble first.
1897+
if (!rewriter.rewriteTarget(*cs.getTargetFor(forEachStmt))) {
19041898
hadError = true;
1899+
}
1900+
1901+
// Then apply the solution to the filtering condition, if there is one.
1902+
if (auto *whereExpr = forEachStmt->getWhere()) {
1903+
auto whereTarget = *cs.getTargetFor(whereExpr);
1904+
if (auto rewrittenWhereTarget = rewriter.rewriteTarget(whereTarget)) {
1905+
forEachStmt->setWhere(rewrittenWhereTarget->getAsExpr());
1906+
} else {
1907+
hadError = true;
1908+
}
1909+
}
19051910

19061911
auto body = visit(forEachStmt->getBody()).get<Stmt *>();
19071912
forEachStmt->setBody(cast<BraceStmt>(body));

lib/Sema/SyntacticElementTarget.cpp

+6-8
Original file line numberDiff line numberDiff line change
@@ -192,9 +192,8 @@ SyntacticElementTarget::forReturn(ReturnStmt *returnStmt, Type contextTy,
192192

193193
SyntacticElementTarget
194194
SyntacticElementTarget::forForEachPreamble(ForEachStmt *stmt, DeclContext *dc,
195-
bool ignoreWhereClause,
196195
GenericEnvironment *packElementEnv) {
197-
SyntacticElementTarget target(stmt, dc, ignoreWhereClause, packElementEnv);
196+
SyntacticElementTarget target(stmt, dc, packElementEnv);
198197
return target;
199198
}
200199

@@ -234,8 +233,8 @@ ContextualPattern SyntacticElementTarget::getContextualPattern() const {
234233
}
235234

236235
if (isForEachPreamble()) {
237-
return ContextualPattern::forRawPattern(forEachStmt.pattern,
238-
forEachStmt.dc);
236+
return ContextualPattern::forRawPattern(forEachPreamble.pattern,
237+
forEachPreamble.dc);
239238
}
240239

241240
auto ctp = getExprContextualTypePurpose();
@@ -400,7 +399,7 @@ SyntacticElementTarget::walk(ASTWalker &walker) const {
400399
break;
401400
}
402401
case Kind::forEachPreamble: {
403-
// We need to skip the where clause if requested, and we currently do not
402+
// We need to skip the where clause, and we currently do not
404403
// type-check a for loop's BraceStmt as part of the SyntacticElementTarget,
405404
// so we need to skip it here.
406405
// TODO: We ought to be able to fold BraceStmt checking into the constraint
@@ -421,8 +420,7 @@ SyntacticElementTarget::walk(ASTWalker &walker) const {
421420
}
422421

423422
PreWalkResult<Expr *> walkToExprPre(Expr *E) override {
424-
// Ignore where clause if needed.
425-
if (Target.ignoreForEachWhereClause() && E == ForStmt->getWhere())
423+
if (E == ForStmt->getWhere())
426424
return Action::SkipNode(E);
427425

428426
E = E->walk(Walker);
@@ -458,7 +456,7 @@ SyntacticElementTarget::walk(ASTWalker &walker) const {
458456
ForEachWalker forEachWalker(walker, *this);
459457

460458
if (auto *newStmt = getAsForEachStmt()->walk(forEachWalker)) {
461-
result.forEachStmt.stmt = cast<ForEachStmt>(newStmt);
459+
result.forEachPreamble.stmt = cast<ForEachStmt>(newStmt);
462460
} else {
463461
return std::nullopt;
464462
}

lib/Sema/TypeCheckConstraints.cpp

+17-2
Original file line numberDiff line numberDiff line change
@@ -890,11 +890,26 @@ bool TypeChecker::typeCheckForEachPreamble(DeclContext *dc, ForEachStmt *stmt,
890890
return true;
891891
};
892892

893-
auto target = SyntacticElementTarget::forForEachPreamble(
894-
stmt, dc, /*ignoreWhereClause=*/false, packElementEnv);
893+
auto target =
894+
SyntacticElementTarget::forForEachPreamble(stmt, dc, packElementEnv);
895895
if (!typeCheckTarget(target))
896896
return failed();
897897

898+
if (auto *where = stmt->getWhere()) {
899+
auto boolType = dc->getASTContext().getBoolType();
900+
if (!boolType)
901+
return failed();
902+
903+
SyntacticElementTarget whereClause(stmt->getWhere(), dc,
904+
{boolType, CTP_Condition},
905+
/*isDiscarded=*/false);
906+
auto result = typeCheckTarget(whereClause);
907+
if (!result)
908+
return true;
909+
910+
stmt->setWhere(result->getAsExpr());
911+
}
912+
898913
// Check to see if the sequence expr is throwing (in async context),
899914
// if so require the stmt to have a `try`.
900915
if (diagnoseUnhandledThrowsInAsyncContext(dc, stmt))

0 commit comments

Comments
 (0)