Skip to content

Commit 8b1f368

Browse files
committed
WIP full types throws inference
1 parent 03fce0e commit 8b1f368

File tree

8 files changed

+140
-133
lines changed

8 files changed

+140
-133
lines changed

include/swift/Sema/ConstraintLocator.h

+7
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,13 @@ class ConstraintLocator : public llvm::FoldingSetNode {
337337
return false;
338338
}
339339

340+
/// Determine whether this locator points directly to a given statement.
341+
template <typename E> bool directlyAtStmt() const {
342+
if (auto *stmt = getAnchor().dyn_cast<Stmt *>())
343+
return isa<E>(stmt) && getPath().empty();
344+
return false;
345+
}
346+
340347
/// Check whether the first element in the path of this locator (if any)
341348
/// is a given \c LocatorPathElt subclass.
342349
template <class T>

include/swift/Sema/ConstraintSystem.h

+8-9
Original file line numberDiff line numberDiff line change
@@ -3434,17 +3434,13 @@ class ConstraintSystem {
34343434
PotentialThrowSite::Kind kind, Type type,
34353435
ConstraintLocatorBuilder locator);
34363436

3437-
/// Determine the caught error type for the given catch node.
3437+
/// Retrieve the caught error type for the given catch node, which could be
3438+
/// a type variable if the caught error type is going to be inferred.
34383439
Type getCaughtErrorType(CatchNode node);
34393440

3440-
/// Infer the caught error type for this catch node, or introduce an
3441-
/// appropriate type variable to describe that caught error type if
3442-
/// it cannot be computed yet.
3443-
Type inferCaughtErrorType(CatchNode node);
3444-
3445-
/// Return the type variable that represents the inferred thrown error
3446-
/// type for this closure, or NULL if the thrown error type is not inferred.
3447-
TypeVariableType *getInferredThrownError(ClosureExpr *closure);
3441+
/// Finalize the caught error type type once all of the potential throw
3442+
/// sites are known.
3443+
void finalizeCaughtErrorType(CatchNode node);
34483444

34493445
/// Retrieve the constraint locator for the given anchor and
34503446
/// path, uniqued.
@@ -6474,6 +6470,9 @@ class TypeVarRefCollector : public ASTWalker {
64746470
/// Infer the referenced type variables from a given decl.
64756471
void inferTypeVars(Decl *D);
64766472

6473+
/// Infer the referenced type variables from a type.
6474+
void inferTypeVars(Type type);
6475+
64776476
MacroWalking getMacroWalkingBehavior() const override {
64786477
return MacroWalking::Arguments;
64796478
}

lib/Sema/CSGen.cpp

+10-4
Original file line numberDiff line numberDiff line change
@@ -865,8 +865,12 @@ void TypeVarRefCollector::inferTypeVars(Decl *D) {
865865
if (!ty)
866866
return;
867867

868+
inferTypeVars(ty);
869+
}
870+
871+
void TypeVarRefCollector::inferTypeVars(Type type) {
868872
SmallPtrSet<TypeVariableType *, 4> typeVars;
869-
ty->getTypeVariables(typeVars);
873+
type->getTypeVariables(typeVars);
870874
TypeVars.insert(typeVars.begin(), typeVars.end());
871875
}
872876

@@ -915,10 +919,12 @@ TypeVarRefCollector::walkToStmtPre(Stmt *stmt) {
915919
if (isa<ReturnStmt>(stmt) && DCDepth == 0 &&
916920
!Locator->directlyAt<ClosureExpr>()) {
917921
SmallPtrSet<TypeVariableType *, 4> typeVars;
918-
CS.getClosureType(CE)->getResult()->getTypeVariables(typeVars);
922+
auto closureType = CS.getClosureType(CE);
923+
closureType->getResult()->getTypeVariables(typeVars);
924+
925+
if (auto thrownErrorType = closureType->getThrownError())
926+
thrownErrorType->getTypeVariables(typeVars);
919927

920-
FIXME if we're doing full typed throws, also look at the thrown
921-
error type here?
922928
TypeVars.insert(typeVars.begin(), typeVars.end());
923929
}
924930
}

lib/Sema/CSSimplify.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -13840,11 +13840,11 @@ ConstraintSystem::simplifyCaughtErrorConstraint(
1384013840
CatchNode catchNode,
1384113841
TypeMatchOptions flags,
1384213842
ConstraintLocatorBuilder locator) {
13843-
Type caughtErrorType = inferCaughtErrorType(catchNode);
13844-
if (caughtErrorType->isEqual(type))
13843+
// Keep the constraint around until it simplifies beyond a type variable.
13844+
Type simplified = simplifyType(type);
13845+
if (simplified->is<TypeVariableType>())
1384513846
return SolutionKind::Unsolved;
1384613847

13847-
addConstraint(ConstraintKind::Bind, type, caughtErrorType, locator);
1384813848
return SolutionKind::Solved;
1384913849
}
1385013850

lib/Sema/CSStep.cpp

+26-1
Original file line numberDiff line numberDiff line change
@@ -970,6 +970,18 @@ StepResult ConjunctionStep::resume(bool prevFailed) {
970970
if (HadFailure)
971971
return done(/*isSuccess=*/false);
972972

973+
auto locator = Conjunction->getLocator();
974+
if (auto syntacticElement =
975+
locator->getLastElementAs<LocatorPathElt::SyntacticElement>()) {
976+
// If we are at a do..catch, and we're inferring a throw error type for
977+
// the do..catch, we're now in a position to finalize the thrown type.
978+
if (auto doCatch = dyn_cast_or_null<DoCatchStmt>(
979+
syntacticElement->getElement().dyn_cast<Stmt *>())) {
980+
if (CS.getCaughtErrorType(doCatch)->is<TypeVariableType>())
981+
CS.finalizeCaughtErrorType(doCatch);
982+
}
983+
}
984+
973985
// If this was an isolated conjunction solver needs to do
974986
// the following:
975987
//
@@ -1084,6 +1096,20 @@ void ConjunctionStep::restoreOuterState(const Score &solutionScore) const {
10841096
void ConjunctionStep::SolverSnapshot::applySolution(const Solution &solution) {
10851097
CS.applySolution(solution);
10861098

1099+
// If we are at a closure, and the closure type has a type variable for
1100+
// its thrown error type, we're now in a position to finalize that type.
1101+
auto locator = Conjunction->getLocator();
1102+
if (locator->directlyAt<ClosureExpr>()) {
1103+
auto closureTy =
1104+
CS.getClosureType(castToExpr<ClosureExpr>(locator->getAnchor()));
1105+
if (auto thrownError = closureTy->getEffectiveThrownErrorTypeOrNever()) {
1106+
if (thrownError->is<TypeVariableType>()) {
1107+
auto *closure = castToExpr<ClosureExpr>(locator->getAnchor());
1108+
CS.finalizeCaughtErrorType(closure);
1109+
}
1110+
}
1111+
}
1112+
10871113
if (!CS.shouldAttemptFixes())
10881114
return;
10891115

@@ -1096,7 +1122,6 @@ void ConjunctionStep::SolverSnapshot::applySolution(const Solution &solution) {
10961122
// has failed, let's bind all of unresolved type variables
10971123
// in its interface type to holes to avoid extraneous
10981124
// fixes produced by outer context.
1099-
auto locator = Conjunction->getLocator();
11001125
if (locator->directlyAt<ClosureExpr>()) {
11011126
auto closureTy =
11021127
CS.getClosureType(castToExpr<ClosureExpr>(locator->getAnchor()));

lib/Sema/CSSyntacticElement.cpp

+26-28
Original file line numberDiff line numberDiff line change
@@ -287,23 +287,6 @@ static bool isViableElement(ASTNode element,
287287
using ElementInfo = std::tuple<ASTNode, ContextualTypeInfo,
288288
/*isDiscarded=*/bool, ConstraintLocator *>;
289289

290-
static TypeVariableType *assignClosureThrownErrorType(
291-
ConstraintSystem &cs, ClosureExpr *closure) {
292-
// FIXME: Remove this once the inference is working in general.
293-
if (!cs.getASTContext().LangOpts.hasFeature(Feature::FullTypedThrows))
294-
return nullptr;
295-
296-
auto closureType = cs.getClosureType(closure);
297-
auto thrownType = closureType->getEffectiveThrownErrorTypeOrNever();
298-
auto computedThrownType = cs.inferCaughtErrorType(closure);
299-
cs.addConstraint(
300-
ConstraintKind::Conversion, computedThrownType, thrownType,
301-
cs.getConstraintLocator(closure,
302-
ConstraintLocator::ClosureThrownError));
303-
304-
return computedThrownType->getAs<TypeVariableType>();
305-
}
306-
307290
static void createConjunction(ConstraintSystem &cs, DeclContext *dc,
308291
ArrayRef<ElementInfo> elements,
309292
ConstraintLocator *locator, bool isIsolated,
@@ -341,6 +324,17 @@ static void createConjunction(ConstraintSystem &cs, DeclContext *dc,
341324
isIsolated = true;
342325
}
343326

327+
if (auto syntacticElement =
328+
locator->getLastElementAs<LocatorPathElt::SyntacticElement>()) {
329+
// If we are at a do..catch, and we're inferring a throw error type for
330+
// the do..catch, we're now in a position to finalize the thrown type.
331+
if (auto doCatch = dyn_cast_or_null<DoCatchStmt>(
332+
syntacticElement->getElement().dyn_cast<Stmt *>())) {
333+
if (cs.getCaughtErrorType(doCatch)->is<TypeVariableType>())
334+
isIsolated = true;
335+
}
336+
}
337+
344338
if (locator->isForSingleValueStmtConjunction()) {
345339
auto *SVE = castToExpr<SingleValueStmtExpr>(locator->getAnchor());
346340
referencedVars.push_back(cs.getType(SVE)->castTo<TypeVariableType>());
@@ -384,19 +378,17 @@ static void createConjunction(ConstraintSystem &cs, DeclContext *dc,
384378
cs, element, context, elementLoc, isDiscarded));
385379
}
386380

387-
for (auto *externalVar : paramCollector.getTypeVars())
388-
referencedVars.push_back(externalVar);
389-
390-
#if false
391381
// If the body of the closure is being used to infer the thrown error type
392382
// of that closure, introduce a constraint to do so.
393383
if (locator->directlyAt<ClosureExpr>()) {
394384
auto *closure = castToExpr<ClosureExpr>(locator->getAnchor());
395-
if (auto thrownErrorTypeVar = assignClosureThrownErrorType(cs, closure))
396-
referencedVars.push_back(thrownErrorTypeVar);
385+
if (auto thrownError = cs.getCaughtErrorType(closure))
386+
paramCollector.inferTypeVars(thrownError);
397387
}
398-
#endif
399-
388+
389+
for (auto *externalVar : paramCollector.getTypeVars())
390+
referencedVars.push_back(externalVar);
391+
400392
// It's possible that there are no viable elements in the body,
401393
// because e.g. whole body is an `#if` statement or it only has
402394
// declarations that are checked during solution application.
@@ -1055,20 +1047,26 @@ class SyntacticElementConstraintGenerator
10551047
void visitDoCatchStmt(DoCatchStmt *doStmt) {
10561048
SmallVector<ElementInfo, 4> elements;
10571049

1050+
auto myLocator = locator;
1051+
if (cs.getASTContext().LangOpts.hasFeature(Feature::FullTypedThrows)) {
1052+
myLocator = cs.getConstraintLocator(
1053+
locator, LocatorPathElt::SyntacticElement(doStmt));
1054+
}
1055+
10581056
// First, let's record a body of `do` statement. Note we need to add a
10591057
// SyntaticElement locator path element here to avoid treating the inner
10601058
// brace conjunction as being isolated if 'doLoc' is for an isolated
10611059
// conjunction (as is the case with 'do' expressions).
10621060
auto *doBodyLoc = cs.getConstraintLocator(
1063-
locator, LocatorPathElt::SyntacticElement(doStmt->getBody()));
1061+
myLocator, LocatorPathElt::SyntacticElement(doStmt->getBody()));
10641062
elements.push_back(makeElement(doStmt->getBody(), doBodyLoc));
10651063

10661064
// After that has been type-checked, let's switch to
10671065
// individual `catch` statements.
10681066
for (auto *catchStmt : doStmt->getCatches())
10691067
elements.push_back(makeElement(catchStmt, locator));
10701068

1071-
createConjunction(elements, locator);
1069+
createConjunction(elements, myLocator);
10721070
}
10731071

10741072
void visitCaseStmt(CaseStmt *caseStmt) {
@@ -1156,7 +1154,7 @@ class SyntacticElementConstraintGenerator
11561154
}
11571155

11581156
if (closure) {
1159-
assignClosureThrownErrorType(cs, closure);
1157+
cs.finalizeCaughtErrorType(closure);
11601158
}
11611159

11621160
return;

0 commit comments

Comments
 (0)