Skip to content

Commit f461373

Browse files
committed
Clean up tracking of throws sites
1 parent e469a6f commit f461373

File tree

8 files changed

+147
-29
lines changed

8 files changed

+147
-29
lines changed

include/swift/AST/CatchNode.h

+3-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
#include "swift/AST/Decl.h"
2020
#include "swift/AST/Expr.h"
2121
#include "swift/AST/Stmt.h"
22-
22+
#include "swift/AST/ASTNode.h"
2323
namespace swift {
2424

2525
/// An AST node that represents a point where a thrown error can be caught and
@@ -45,6 +45,8 @@ class CatchNode: public llvm::PointerUnion<
4545
/// needs to be inferred.
4646
Type getExplicitCaughtType(ASTContext &ctx) const;
4747

48+
explicit operator ASTNode() const;
49+
4850
friend llvm::hash_code hash_value(CatchNode catchNode) {
4951
using llvm::hash_value;
5052
return hash_value(catchNode.getOpaqueValue());

include/swift/Sema/ConstraintSystem.h

+12-2
Original file line numberDiff line numberDiff line change
@@ -1446,6 +1446,15 @@ struct MatchCallArgumentResult {
14461446
/// `x.y` is a potential throw site.
14471447
struct PotentialThrowSite {
14481448
enum Kind {
1449+
/// The caught type variable for this catch node, which is a stand-in
1450+
/// for the type that will be caught by the particular catch node.
1451+
///
1452+
/// This is not strictly a potential throw site, but is convenient
1453+
/// to represent it as such so we don't need a separate mapping
1454+
/// from catch node -> caught type variable in the constraint solver's
1455+
/// state.
1456+
CaughtTypeVariable,
1457+
14491458
/// The application of a function or subscript.
14501459
Application,
14511460

@@ -3428,8 +3437,9 @@ class ConstraintSystem {
34283437
/// Determine the caught error type for the given catch node.
34293438
Type getCaughtErrorType(CatchNode node);
34303439

3431-
/// Infer the caught error type for this catch node, once we have all of
3432-
/// the potential throw sites.
3440+
/// Infer the caught error type for this catch node, or introduce an
3441+
/// appropriate type variable to describe that caught error type if
3442+
/// it cannot be computed yet.
34333443
Type inferCaughtErrorType(CatchNode node);
34343444

34353445
/// Return the type variable that represents the inferred thrown error

lib/AST/Decl.cpp

+12
Original file line numberDiff line numberDiff line change
@@ -12088,6 +12088,18 @@ Type CatchNode::getExplicitCaughtType(ASTContext &ctx) const {
1208812088
ctx.evaluator, ExplicitCaughtTypeRequest{&ctx, *this}, Type());
1208912089
}
1209012090

12091+
CatchNode::operator ASTNode() const {
12092+
if (auto func = dyn_cast<AbstractFunctionDecl *>())
12093+
return func;
12094+
if (auto closure = dyn_cast<ClosureExpr *>())
12095+
return closure;
12096+
if (auto doCatch = dyn_cast<DoCatchStmt *>())
12097+
return doCatch;
12098+
if (auto anyTry = dyn_cast<AnyTryExpr *>())
12099+
return anyTry;
12100+
llvm_unreachable("Unhandled catch node");
12101+
}
12102+
1209112103
void swift::simple_display(llvm::raw_ostream &out, CatchNode catchNode) {
1209212104
out << "catch node";
1209312105
}

lib/Sema/CSSimplify.cpp

+13-9
Original file line numberDiff line numberDiff line change
@@ -13016,9 +13016,10 @@ ConstraintSystem::SolutionKind ConstraintSystem::simplifyApplicableFnConstraint(
1301613016
// Local function to form an unsolved result.
1301713017
auto formUnsolved = [&](bool activate = false) {
1301813018
if (flags.contains(TMF_GenerateConstraints)) {
13019+
auto fixedLocator = getConstraintLocator(locator);
1301913020
auto *application = Constraint::createApplicableFunction(
1302013021
*this, type1, type2, trailingClosureMatching,
13021-
getConstraintLocator(locator));
13022+
fixedLocator);
1302213023

1302313024
addUnsolvedConstraint(application);
1302413025
if (activate)
@@ -13048,10 +13049,6 @@ ConstraintSystem::SolutionKind ConstraintSystem::simplifyApplicableFnConstraint(
1304813049
// is not valid for operators though, where an inout parameter does not
1304913050
// have an explicit inout argument.
1305013051
if (type1.getPointer() == desugar2) {
13051-
// Note that this could throw.
13052-
recordPotentialThrowSite(
13053-
PotentialThrowSite::Application, Type(desugar2), outerLocator);
13054-
1305513052
if (!isOperator || !hasInOut()) {
1305613053
recordMatchCallArgumentResult(
1305713054
getConstraintLocator(
@@ -13102,10 +13099,6 @@ ConstraintSystem::SolutionKind ConstraintSystem::simplifyApplicableFnConstraint(
1310213099

1310313100
// For a function, bind the output and convert the argument to the input.
1310413101
if (auto func2 = dyn_cast<FunctionType>(desugar2)) {
13105-
// Note that this could throw.
13106-
recordPotentialThrowSite(
13107-
PotentialThrowSite::Application, Type(desugar2), outerLocator);
13108-
1310913102
ConstraintKind subKind = (isOperator
1311013103
? ConstraintKind::OperatorArgumentConversion
1311113104
: ConstraintKind::ArgumentConversion);
@@ -13848,6 +13841,9 @@ ConstraintSystem::simplifyCaughtErrorConstraint(
1384813841
TypeMatchOptions flags,
1384913842
ConstraintLocatorBuilder locator) {
1385013843
Type caughtErrorType = inferCaughtErrorType(catchNode);
13844+
if (caughtErrorType->isEqual(type))
13845+
return SolutionKind::Unsolved;
13846+
1385113847
addConstraint(ConstraintKind::Bind, type, caughtErrorType, locator);
1385213848
return SolutionKind::Solved;
1385313849
}
@@ -15719,6 +15715,14 @@ void ConstraintSystem::addConstraint(ConstraintKind kind, Type first,
1571915715
Type second,
1572015716
ConstraintLocatorBuilder locator,
1572115717
bool isFavored) {
15718+
// When adding a function-application constraint, make sure to introduce
15719+
// a potential throw site.
15720+
if (kind == ConstraintKind::ApplicableFunction) {
15721+
recordPotentialThrowSite(
15722+
PotentialThrowSite::Application, second,
15723+
getConstraintLocator(locator));
15724+
}
15725+
1572215726
switch (addConstraintImpl(kind, first, second, locator, isFavored)) {
1572315727
case SolutionKind::Error:
1572415728
// Add a failing constraint, if needed.

lib/Sema/CSSyntacticElement.cpp

+6-5
Original file line numberDiff line numberDiff line change
@@ -374,12 +374,8 @@ static void createConjunction(ConstraintSystem &cs, DeclContext *dc,
374374
// of that closure, introduce a constraint to do so.
375375
if (locator->directlyAt<ClosureExpr>()) {
376376
auto *closure = castToExpr<ClosureExpr>(locator->getAnchor());
377-
if (auto thrownErrorTypeVar = cs.getInferredThrownError(closure)) {
377+
if (auto thrownErrorTypeVar = cs.getInferredThrownError(closure))
378378
referencedVars.push_back(thrownErrorTypeVar);
379-
constraints.push_back(
380-
Constraint::createCaughtError(cs, Type(thrownErrorTypeVar), closure,
381-
locator, referencedVars));
382-
}
383379
}
384380

385381
// It's possible that there are no viable elements in the body,
@@ -1139,6 +1135,11 @@ class SyntacticElementConstraintGenerator
11391135
visitDecl(node.get<Decl *>());
11401136
}
11411137
}
1138+
1139+
if (closure) {
1140+
auto closureType = cs.getClosureType(closure);
1141+
}
1142+
11421143
return;
11431144
}
11441145

lib/Sema/ConstraintSystem.cpp

+93-4
Original file line numberDiff line numberDiff line change
@@ -450,15 +450,56 @@ Type ConstraintSystem::inferCaughtErrorType(CatchNode catchNode) {
450450
}
451451
}
452452

453+
// The type variable that already describes the caught error type for
454+
// this particular catch node, if it exists. These are lazily created
455+
// when we are unable to determine the caught error type completely.
456+
TypeVariableType *caughtTypeVariable = nullptr;
457+
ConstraintLocator *caughtTypeVariableLocator = nullptr;
458+
459+
/// The complete set of type variables referenced by the throw sites.
460+
SmallPtrSet<TypeVariableType *, 4> referencedTypeVariables;
461+
462+
/// The caught error type thus far, which starts at Never because Never
463+
/// doesn't contribute to throwing anything.
453464
Type caughtErrorType = ctx.getNeverType();
454465
for (const auto &throwSite : throwSites) {
455466
Type type = simplifyType(throwSite.type);
456-
457467
Type thrownErrorType;
458468
switch (throwSite.kind) {
469+
case PotentialThrowSite::CaughtTypeVariable:
470+
// The type variable hasn't been bound yet, so record that we
471+
// know it exists and keep looking.
472+
if (auto typeVar = type->getAs<TypeVariableType>()) {
473+
caughtTypeVariable = typeVar;
474+
caughtTypeVariableLocator = throwSite.locator;
475+
continue;
476+
}
477+
478+
// The caught type variable has been bound to a fixed type. Return
479+
// that.
480+
return type;
481+
459482
case PotentialThrowSite::Application: {
460-
auto fnType = type->castTo<AnyFunctionType>();
461-
thrownErrorType = fnType->getEffectiveThrownErrorTypeOrNever();
483+
auto fnType = type->getAs<AnyFunctionType>();
484+
if (!fnType) {
485+
// This applicable function constraint either still involves type
486+
// variables or wasn't actually a function.
487+
//
488+
// If it still involves type variables, track those so we know
489+
// when to look at the potential throw sites again. Otherwise,
490+
// it's something like a metatype or a callable-as-function type,
491+
// in which case there will be another Application throw site that
492+
// handles the throw site.
493+
if (type->hasTypeVariable()) {
494+
type->getTypeVariables(referencedTypeVariables);
495+
}
496+
497+
continue;
498+
}
499+
500+
// Dig out the thrown error type from the function type.
501+
thrownErrorType =
502+
simplifyType(fnType->getEffectiveThrownErrorTypeOrNever());
462503
break;
463504
}
464505

@@ -469,6 +510,13 @@ Type ConstraintSystem::inferCaughtErrorType(CatchNode catchNode) {
469510
break;
470511
}
471512

513+
// If the result has type variables, there's nothing useful we can
514+
// do. Just record them and continue on.
515+
if (thrownErrorType->hasTypeVariable()) {
516+
thrownErrorType->getTypeVariables(referencedTypeVariables);
517+
continue;
518+
}
519+
472520
// Perform the errorUnion() of the caught error type so far with the
473521
// thrown error type of this potential throw site.
474522
caughtErrorType = TypeChecker::errorUnion(
@@ -482,7 +530,46 @@ Type ConstraintSystem::inferCaughtErrorType(CatchNode catchNode) {
482530
break;
483531
}
484532

485-
return caughtErrorType;
533+
// If the caught error type is 'any Error', or if there are no
534+
// unresolved type variables, then we now have a fixed type for the
535+
// caught error type.
536+
if (caughtErrorType->isErrorExistentialType() ||
537+
referencedTypeVariables.empty()) {
538+
return caughtErrorType;
539+
}
540+
541+
// We don't have sufficient information to provide a concrete caught
542+
// error type, so we will use a type variable instead. If none exists
543+
// yet, create one now.
544+
if (!caughtTypeVariable) {
545+
// Create the type variable.
546+
caughtTypeVariableLocator = getConstraintLocator(ASTNode(catchNode));
547+
caughtTypeVariable = createTypeVariable(caughtTypeVariableLocator, 0);
548+
549+
// Create a constraint stating that this is the caught error type
550+
// for the given catch node. This constraint will get reactivated
551+
// whenever any of the type variables it depends on change, causing
552+
// us to re-evaluate the potential throw sites.
553+
SmallVector<TypeVariableType *, 2> allTypeVars(
554+
referencedTypeVariables.begin(), referencedTypeVariables.end());
555+
auto constraint = Constraint::createCaughtError(
556+
*this, Type(caughtTypeVariable), catchNode,
557+
caughtTypeVariableLocator, allTypeVars);
558+
addUnsolvedConstraint(constraint);
559+
560+
// Record this type variable in the list of potential throw sites,
561+
// so we can find it again later rather than creating a new one.
562+
// This is effectively using the potential throw sites as a
563+
// CatchNode -> TypeVariableType * map, without requiring a separate
564+
// data structure.
565+
potentialThrowSites.push_back(
566+
{catchNode,
567+
PotentialThrowSite{PotentialThrowSite::CaughtTypeVariable,
568+
Type(caughtTypeVariable),
569+
caughtTypeVariableLocator}});
570+
}
571+
572+
return Type(caughtTypeVariable);
486573
}
487574

488575
TypeVariableType *
@@ -3984,6 +4071,8 @@ void ConstraintSystem::resolveOverload(ConstraintLocator *locator,
39844071
// If accessing this declaration could throw an error, record this as a
39854072
// potential throw site.
39864073
if (thrownErrorTypeOnAccess) {
4074+
// FIXME: We likely need to do this before overload resolution,
4075+
// otherwise we might establish the thrown error type too early.
39874076
recordPotentialThrowSite(
39884077
PotentialThrowSite::PropertyAccess, thrownErrorTypeOnAccess, locator);
39894078
}

lib/Sema/TypeCheckConstraints.cpp

+8-4
Original file line numberDiff line numberDiff line change
@@ -1679,18 +1679,22 @@ void ConstraintSystem::print(raw_ostream &out) const {
16791679
out.indent(indent) << "Potential throw sites:\n";
16801680
interleave(potentialThrowSites, [&](const auto &throwSite) {
16811681
out.indent(indent + 2);
1682+
out << " -" << throwSite.second.type->getString(PO) << " is ";
16821683
switch (throwSite.second.kind) {
1684+
case PotentialThrowSite::CaughtTypeVariable:
1685+
out << "caught type variable @ ";
1686+
break;
16831687
case PotentialThrowSite::Application:
1684-
out << "- application @ ";
1688+
out << "application @ ";
16851689
break;
16861690
case PotentialThrowSite::ExplicitThrow:
1687-
out << " - explicit throw @ ";
1691+
out << "explicit throw @ ";
16881692
break;
16891693
case PotentialThrowSite::NonExhaustiveDoCatch:
1690-
out << " - non-exhaustive do..catch @ ";
1694+
out << "non-exhaustive do..catch @ ";
16911695
break;
16921696
case PotentialThrowSite::PropertyAccess:
1693-
out << " - property access @ ";
1697+
out << "property access @ ";
16941698
break;
16951699
}
16961700

test/expr/closure/typed_throws_full.swift

-4
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,11 @@ func testSingleStatement() {
2222
let c2 = {
2323
try doSomething()
2424
}
25-
// FIXME: Single-expression closures aren't inferring thrown error types.
26-
// expected-error@+1{{invalid conversion of thrown error type 'any Error' to 'MyError'}}
2725
let _: () throws(MyError) -> Int = c2
2826

2927
let c3 = {
3028
return try doSomething()
3129
}
32-
// FIXME: Single-expression closures aren't inferring thrown error types.
33-
// expected-error@+1{{invalid conversion of thrown error type 'any Error' to 'MyError'}}
3430
let _: () throws(MyError) -> Int = c3
3531
}
3632

0 commit comments

Comments
 (0)