diff --git a/include/swift/AST/CatchNode.h b/include/swift/AST/CatchNode.h index f564d8e03ac86..f2479604aeedd 100644 --- a/include/swift/AST/CatchNode.h +++ b/include/swift/AST/CatchNode.h @@ -19,7 +19,7 @@ #include "swift/AST/Decl.h" #include "swift/AST/Expr.h" #include "swift/AST/Stmt.h" - +#include "swift/AST/ASTNode.h" namespace swift { /// An AST node that represents a point where a thrown error can be caught and @@ -45,6 +45,8 @@ class CatchNode: public llvm::PointerUnion< /// needs to be inferred. Type getExplicitCaughtType(ASTContext &ctx) const; + explicit operator ASTNode() const; + friend llvm::hash_code hash_value(CatchNode catchNode) { using llvm::hash_value; return hash_value(catchNode.getOpaqueValue()); diff --git a/include/swift/Sema/Constraint.h b/include/swift/Sema/Constraint.h index bb5c5208dbf8d..efd44c0413b34 100644 --- a/include/swift/Sema/Constraint.h +++ b/include/swift/Sema/Constraint.h @@ -19,6 +19,7 @@ #define SWIFT_SEMA_CONSTRAINT_H #include "swift/AST/ASTNode.h" +#include "swift/AST/CatchNode.h" #include "swift/AST/FunctionRefKind.h" #include "swift/AST/Identifier.h" #include "swift/AST/Type.h" @@ -221,6 +222,10 @@ enum class ConstraintKind : char { /// The first type is a tuple containing a single unlabeled element that is a /// pack expansion. The second type is its pattern type. MaterializePackExpansion, + /// The first type is the thrown error type, and the second entry is a + /// CatchNode whose potential throw sites will be collected to determine + /// the thrown error type. + CaughtError, }; /// Classification of the different kinds of constraints. @@ -443,6 +448,14 @@ class Constraint final : public llvm::ilist_node, DeclContext *UseDC; } Overload; + struct { + /// The first type, which represents the thrown error type. + Type First; + + /// The catch node, for which the potential throw sites + CatchNode Node; + } CaughtError; + struct { /// The node itself. ASTNode Element; @@ -507,6 +520,11 @@ class Constraint final : public llvm::ilist_node, ConstraintLocator *locator, SmallPtrSetImpl &typeVars); + /// Construct a caught error constraint. + Constraint(Type type, CatchNode catchNode, + ConstraintLocator *locator, + SmallPtrSetImpl &typeVars); + /// Retrieve the type variables buffer, for internal mutation. MutableArrayRef getTypeVariablesBuffer() { return { getTrailingObjects(), NumTypeVariables }; @@ -602,6 +620,13 @@ class Constraint final : public llvm::ilist_node, ConstraintLocator *locator, bool isDiscarded = false); + /// Construct a caught error constraint. + static Constraint *createCaughtError( + ConstraintSystem &cs, + Type type, CatchNode catchNode, + ConstraintLocator *locator, + ArrayRef referencedVars); + /// Determine the kind of constraint. ConstraintKind getKind() const { return Kind; } @@ -691,6 +716,7 @@ class Constraint final : public llvm::ilist_node, case ConstraintKind::PackElementOf: case ConstraintKind::SameShape: case ConstraintKind::MaterializePackExpansion: + case ConstraintKind::CaughtError: return ConstraintClassification::Relational; case ConstraintKind::ValueMember: @@ -743,6 +769,9 @@ class Constraint final : public llvm::ilist_node, case ConstraintKind::SyntacticElement: llvm_unreachable("closure body element constraint has no type operands"); + case ConstraintKind::CaughtError: + return CaughtError.First; + default: return Types.First; } @@ -755,6 +784,7 @@ class Constraint final : public llvm::ilist_node, case ConstraintKind::Conjunction: case ConstraintKind::BindOverload: case ConstraintKind::SyntacticElement: + case ConstraintKind::CaughtError: llvm_unreachable("constraint has no second type"); case ConstraintKind::ValueMember: @@ -878,6 +908,11 @@ class Constraint final : public llvm::ilist_node, return SyntacticElement.IsDiscarded; } + CatchNode getCatchNode() const { + assert(Kind == ConstraintKind::CaughtError); + return CaughtError.Node; + } + /// For an applicable function constraint, retrieve the trailing closure /// matching rule. llvm::Optional getTrailingClosureMatching() const; diff --git a/include/swift/Sema/ConstraintLocator.h b/include/swift/Sema/ConstraintLocator.h index 1f1a91a84c1c0..eecb5ff43c3c0 100644 --- a/include/swift/Sema/ConstraintLocator.h +++ b/include/swift/Sema/ConstraintLocator.h @@ -337,6 +337,13 @@ class ConstraintLocator : public llvm::FoldingSetNode { return false; } + /// Determine whether this locator points directly to a given statement. + template bool directlyAtStmt() const { + if (auto *stmt = getAnchor().dyn_cast()) + return isa(stmt) && getPath().empty(); + return false; + } + /// Check whether the first element in the path of this locator (if any) /// is a given \c LocatorPathElt subclass. template diff --git a/include/swift/Sema/ConstraintSystem.h b/include/swift/Sema/ConstraintSystem.h index dc43b814d8c1d..f0fd27b7b870e 100644 --- a/include/swift/Sema/ConstraintSystem.h +++ b/include/swift/Sema/ConstraintSystem.h @@ -1446,6 +1446,15 @@ struct MatchCallArgumentResult { /// `x.y` is a potential throw site. struct PotentialThrowSite { enum Kind { + /// The caught type variable for this catch node, which is a stand-in + /// for the type that will be caught by the particular catch node. + /// + /// This is not strictly a potential throw site, but is convenient + /// to represent it as such so we don't need a separate mapping + /// from catch node -> caught type variable in the constraint solver's + /// state. + CaughtTypeVariable, + /// The application of a function or subscript. Application, @@ -3425,9 +3434,14 @@ class ConstraintSystem { PotentialThrowSite::Kind kind, Type type, ConstraintLocatorBuilder locator); - /// Determine the caught error type for the given catch node. + /// Retrieve the caught error type for the given catch node, which could be + /// a type variable if the caught error type is going to be inferred. Type getCaughtErrorType(CatchNode node); + /// Finalize the caught error type type once all of the potential throw + /// sites are known. + void finalizeCaughtErrorType(CatchNode node); + /// Retrieve the constraint locator for the given anchor and /// path, uniqued. ConstraintLocator * @@ -5206,6 +5220,12 @@ class ConstraintSystem { TypeMatchOptions flags, ConstraintLocatorBuilder locator); + /// Compute the caught error type for a given catch node. + SolutionKind + simplifyCaughtErrorConstraint(Type caughtError, CatchNode catchNode, + TypeMatchOptions flags, + ConstraintLocatorBuilder locator); + public: // FIXME: Public for use by static functions. /// Simplify a conversion constraint with a fix applied to it. SolutionKind simplifyFixConstraint(ConstraintFix *fix, Type type1, Type type2, @@ -6450,6 +6470,9 @@ class TypeVarRefCollector : public ASTWalker { /// Infer the referenced type variables from a given decl. void inferTypeVars(Decl *D); + /// Infer the referenced type variables from a type. + void inferTypeVars(Type type); + MacroWalking getMacroWalkingBehavior() const override { return MacroWalking::Arguments; } diff --git a/lib/AST/Decl.cpp b/lib/AST/Decl.cpp index a3fe9cff5fc85..85ef033781ce8 100644 --- a/lib/AST/Decl.cpp +++ b/lib/AST/Decl.cpp @@ -12088,6 +12088,18 @@ Type CatchNode::getExplicitCaughtType(ASTContext &ctx) const { ctx.evaluator, ExplicitCaughtTypeRequest{&ctx, *this}, Type()); } +CatchNode::operator ASTNode() const { + if (auto func = dyn_cast()) + return func; + if (auto closure = dyn_cast()) + return closure; + if (auto doCatch = dyn_cast()) + return doCatch; + if (auto anyTry = dyn_cast()) + return anyTry; + llvm_unreachable("Unhandled catch node"); +} + void swift::simple_display(llvm::raw_ostream &out, CatchNode catchNode) { out << "catch node"; } diff --git a/lib/Sema/CSBindings.cpp b/lib/Sema/CSBindings.cpp index 144d744a4e675..e0e8d6534fff3 100644 --- a/lib/Sema/CSBindings.cpp +++ b/lib/Sema/CSBindings.cpp @@ -1739,6 +1739,7 @@ void PotentialBindings::infer(Constraint *constraint) { case ConstraintKind::PackElementOf: case ConstraintKind::SameShape: case ConstraintKind::MaterializePackExpansion: + case ConstraintKind::CaughtError: // Constraints from which we can't do anything. break; diff --git a/lib/Sema/CSGen.cpp b/lib/Sema/CSGen.cpp index 14b23777fdb20..dba457c7bc23e 100644 --- a/lib/Sema/CSGen.cpp +++ b/lib/Sema/CSGen.cpp @@ -865,8 +865,12 @@ void TypeVarRefCollector::inferTypeVars(Decl *D) { if (!ty) return; + inferTypeVars(ty); +} + +void TypeVarRefCollector::inferTypeVars(Type type) { SmallPtrSet typeVars; - ty->getTypeVariables(typeVars); + type->getTypeVariables(typeVars); TypeVars.insert(typeVars.begin(), typeVars.end()); } @@ -915,7 +919,12 @@ TypeVarRefCollector::walkToStmtPre(Stmt *stmt) { if (isa(stmt) && DCDepth == 0 && !Locator->directlyAt()) { SmallPtrSet typeVars; - CS.getClosureType(CE)->getResult()->getTypeVariables(typeVars); + auto closureType = CS.getClosureType(CE); + closureType->getResult()->getTypeVariables(typeVars); + + if (auto thrownErrorType = closureType->getThrownError()) + thrownErrorType->getTypeVariables(typeVars); + TypeVars.insert(typeVars.begin(), typeVars.end()); } } @@ -2467,6 +2476,14 @@ namespace { if (closure->getThrowsLoc().isValid()) return Type(); + // If we are inferring thrown error types, create a type variable + // to capture the thrown error type. This will be resolved based on the + // throw sites that occur within the body of the closure. + if (CS.getASTContext().LangOpts.hasFeature( + Feature::FullTypedThrows)) { + return Type(CS.createTypeVariable(thrownErrorLocator, 0)); + } + // Thrown type inferred from context. if (auto contextualType = CS.getContextualType( closure, /*forConstraint=*/false)) { diff --git a/lib/Sema/CSSimplify.cpp b/lib/Sema/CSSimplify.cpp index 62731c6382fa8..50fe03d5653c6 100644 --- a/lib/Sema/CSSimplify.cpp +++ b/lib/Sema/CSSimplify.cpp @@ -2311,6 +2311,7 @@ ConstraintSystem::matchTupleTypes(TupleType *tuple1, TupleType *tuple2, case ConstraintKind::ExplicitGenericArguments: case ConstraintKind::SameShape: case ConstraintKind::MaterializePackExpansion: + case ConstraintKind::CaughtError: llvm_unreachable("Bad constraint kind in matchTupleTypes()"); } @@ -2672,6 +2673,7 @@ static bool matchFunctionRepresentations(FunctionType::ExtInfo einfo1, case ConstraintKind::ExplicitGenericArguments: case ConstraintKind::SameShape: case ConstraintKind::MaterializePackExpansion: + case ConstraintKind::CaughtError: return true; } @@ -3294,6 +3296,7 @@ ConstraintSystem::matchFunctionTypes(FunctionType *func1, FunctionType *func2, case ConstraintKind::ExplicitGenericArguments: case ConstraintKind::SameShape: case ConstraintKind::MaterializePackExpansion: + case ConstraintKind::CaughtError: llvm_unreachable("Not a relational constraint"); } @@ -7084,6 +7087,7 @@ ConstraintSystem::matchTypes(Type type1, Type type2, ConstraintKind kind, case ConstraintKind::ExplicitGenericArguments: case ConstraintKind::SameShape: case ConstraintKind::MaterializePackExpansion: + case ConstraintKind::CaughtError: llvm_unreachable("Not a relational constraint"); } } @@ -13012,9 +13016,10 @@ ConstraintSystem::SolutionKind ConstraintSystem::simplifyApplicableFnConstraint( // Local function to form an unsolved result. auto formUnsolved = [&](bool activate = false) { if (flags.contains(TMF_GenerateConstraints)) { + auto fixedLocator = getConstraintLocator(locator); auto *application = Constraint::createApplicableFunction( *this, type1, type2, trailingClosureMatching, - getConstraintLocator(locator)); + fixedLocator); addUnsolvedConstraint(application); if (activate) @@ -13044,10 +13049,6 @@ ConstraintSystem::SolutionKind ConstraintSystem::simplifyApplicableFnConstraint( // is not valid for operators though, where an inout parameter does not // have an explicit inout argument. if (type1.getPointer() == desugar2) { - // Note that this could throw. - recordPotentialThrowSite( - PotentialThrowSite::Application, Type(desugar2), outerLocator); - if (!isOperator || !hasInOut()) { recordMatchCallArgumentResult( getConstraintLocator( @@ -13098,10 +13099,6 @@ ConstraintSystem::SolutionKind ConstraintSystem::simplifyApplicableFnConstraint( // For a function, bind the output and convert the argument to the input. if (auto func2 = dyn_cast(desugar2)) { - // Note that this could throw. - recordPotentialThrowSite( - PotentialThrowSite::Application, Type(desugar2), outerLocator); - ConstraintKind subKind = (isOperator ? ConstraintKind::OperatorArgumentConversion : ConstraintKind::ArgumentConversion); @@ -13837,6 +13834,20 @@ ConstraintSystem::simplifyMaterializePackExpansionConstraint( return SolutionKind::Error; } +ConstraintSystem::SolutionKind +ConstraintSystem::simplifyCaughtErrorConstraint( + Type type, + CatchNode catchNode, + TypeMatchOptions flags, + ConstraintLocatorBuilder locator) { + // Keep the constraint around until it simplifies beyond a type variable. + Type simplified = simplifyType(type); + if (simplified->is()) + return SolutionKind::Unsolved; + + return SolutionKind::Solved; +} + ConstraintSystem::SolutionKind ConstraintSystem::simplifyExplicitGenericArgumentsConstraint( Type type1, Type type2, TypeMatchOptions flags, @@ -15494,6 +15505,7 @@ ConstraintSystem::addConstraintImpl(ConstraintKind kind, Type first, case ConstraintKind::KeyPathApplication: case ConstraintKind::FallbackType: case ConstraintKind::SyntacticElement: + case ConstraintKind::CaughtError: llvm_unreachable("Use the correct addConstraint()"); } @@ -15703,6 +15715,14 @@ void ConstraintSystem::addConstraint(ConstraintKind kind, Type first, Type second, ConstraintLocatorBuilder locator, bool isFavored) { + // When adding a function-application constraint, make sure to introduce + // a potential throw site. + if (kind == ConstraintKind::ApplicableFunction) { + recordPotentialThrowSite( + PotentialThrowSite::Application, second, + getConstraintLocator(locator)); + } + switch (addConstraintImpl(kind, first, second, locator, isFavored)) { case SolutionKind::Error: // Add a failing constraint, if needed. @@ -16079,6 +16099,11 @@ ConstraintSystem::simplifyConstraint(const Constraint &constraint) { return simplifyMaterializePackExpansionConstraint( constraint.getFirstType(), constraint.getSecondType(), /*flags*/ llvm::None, constraint.getLocator()); + + case ConstraintKind::CaughtError: + return simplifyCaughtErrorConstraint( + constraint.getFirstType(), constraint.getCatchNode(), + /*flags*/ llvm::None, constraint.getLocator()); } llvm_unreachable("Unhandled ConstraintKind in switch."); diff --git a/lib/Sema/CSStep.cpp b/lib/Sema/CSStep.cpp index d6f7698a5b75e..47e355d6edb31 100644 --- a/lib/Sema/CSStep.cpp +++ b/lib/Sema/CSStep.cpp @@ -970,6 +970,18 @@ StepResult ConjunctionStep::resume(bool prevFailed) { if (HadFailure) return done(/*isSuccess=*/false); + auto locator = Conjunction->getLocator(); + if (auto syntacticElement = + locator->getLastElementAs()) { + // If we are at a do..catch, and we're inferring a throw error type for + // the do..catch, we're now in a position to finalize the thrown type. + if (auto doCatch = dyn_cast_or_null( + syntacticElement->getElement().dyn_cast())) { + if (CS.getCaughtErrorType(doCatch)->is()) + CS.finalizeCaughtErrorType(doCatch); + } + } + // If this was an isolated conjunction solver needs to do // the following: // @@ -1084,6 +1096,20 @@ void ConjunctionStep::restoreOuterState(const Score &solutionScore) const { void ConjunctionStep::SolverSnapshot::applySolution(const Solution &solution) { CS.applySolution(solution); + // If we are at a closure, and the closure type has a type variable for + // its thrown error type, we're now in a position to finalize that type. + auto locator = Conjunction->getLocator(); + if (locator->directlyAt()) { + auto closureTy = + CS.getClosureType(castToExpr(locator->getAnchor())); + if (auto thrownError = closureTy->getEffectiveThrownErrorTypeOrNever()) { + if (thrownError->is()) { + auto *closure = castToExpr(locator->getAnchor()); + CS.finalizeCaughtErrorType(closure); + } + } + } + if (!CS.shouldAttemptFixes()) return; @@ -1096,7 +1122,6 @@ void ConjunctionStep::SolverSnapshot::applySolution(const Solution &solution) { // has failed, let's bind all of unresolved type variables // in its interface type to holes to avoid extraneous // fixes produced by outer context. - auto locator = Conjunction->getLocator(); if (locator->directlyAt()) { auto closureTy = CS.getClosureType(castToExpr(locator->getAnchor())); diff --git a/lib/Sema/CSSyntacticElement.cpp b/lib/Sema/CSSyntacticElement.cpp index 864a26a14f294..f6c43b913ea44 100644 --- a/lib/Sema/CSSyntacticElement.cpp +++ b/lib/Sema/CSSyntacticElement.cpp @@ -324,6 +324,17 @@ static void createConjunction(ConstraintSystem &cs, DeclContext *dc, isIsolated = true; } + if (auto syntacticElement = + locator->getLastElementAs()) { + // If we are at a do..catch, and we're inferring a throw error type for + // the do..catch, we're now in a position to finalize the thrown type. + if (auto doCatch = dyn_cast_or_null( + syntacticElement->getElement().dyn_cast())) { + if (cs.getCaughtErrorType(doCatch)->is()) + isIsolated = true; + } + } + if (locator->isForSingleValueStmtConjunction()) { auto *SVE = castToExpr(locator->getAnchor()); referencedVars.push_back(cs.getType(SVE)->castTo()); @@ -367,6 +378,17 @@ static void createConjunction(ConstraintSystem &cs, DeclContext *dc, cs, element, context, elementLoc, isDiscarded)); } + // If the body of the closure is being used to infer the thrown error type + // of that closure, introduce a constraint to do so. + if (locator->directlyAt()) { + auto *closure = castToExpr(locator->getAnchor()); + if (auto thrownError = cs.getCaughtErrorType(closure)) + paramCollector.inferTypeVars(thrownError); + } + + for (auto *externalVar : paramCollector.getTypeVars()) + referencedVars.push_back(externalVar); + // It's possible that there are no viable elements in the body, // because e.g. whole body is an `#if` statement or it only has // declarations that are checked during solution application. @@ -374,9 +396,6 @@ static void createConjunction(ConstraintSystem &cs, DeclContext *dc, if (constraints.empty()) return; - for (auto *externalVar : paramCollector.getTypeVars()) - referencedVars.push_back(externalVar); - cs.addUnsolvedConstraint(Constraint::createConjunction( cs, constraints, isIsolated, locator, referencedVars)); } @@ -1028,12 +1047,18 @@ class SyntacticElementConstraintGenerator void visitDoCatchStmt(DoCatchStmt *doStmt) { SmallVector elements; + auto myLocator = locator; + if (cs.getASTContext().LangOpts.hasFeature(Feature::FullTypedThrows)) { + myLocator = cs.getConstraintLocator( + locator, LocatorPathElt::SyntacticElement(doStmt)); + } + // First, let's record a body of `do` statement. Note we need to add a // SyntaticElement locator path element here to avoid treating the inner // brace conjunction as being isolated if 'doLoc' is for an isolated // conjunction (as is the case with 'do' expressions). auto *doBodyLoc = cs.getConstraintLocator( - locator, LocatorPathElt::SyntacticElement(doStmt->getBody())); + myLocator, LocatorPathElt::SyntacticElement(doStmt->getBody())); elements.push_back(makeElement(doStmt->getBody(), doBodyLoc)); // After that has been type-checked, let's switch to @@ -1041,7 +1066,7 @@ class SyntacticElementConstraintGenerator for (auto *catchStmt : doStmt->getCatches()) elements.push_back(makeElement(catchStmt, locator)); - createConjunction(elements, locator); + createConjunction(elements, myLocator); } void visitCaseStmt(CaseStmt *caseStmt) { @@ -1089,10 +1114,11 @@ class SyntacticElementConstraintGenerator void visitBraceStmt(BraceStmt *braceStmt) { auto &ctx = cs.getASTContext(); + ClosureExpr *closure = nullptr; CaptureListExpr *captureList = nullptr; { if (locator->directlyAt()) { - auto *closure = castToExpr(locator->getAnchor()); + closure = castToExpr(locator->getAnchor()); captureList = getAsExpr(cs.getParentExpr(closure)); } } @@ -1126,6 +1152,11 @@ class SyntacticElementConstraintGenerator visitDecl(node.get()); } } + + if (closure) { + cs.finalizeCaughtErrorType(closure); + } + return; } diff --git a/lib/Sema/Constraint.cpp b/lib/Sema/Constraint.cpp index c51d0da1607ff..d0cda0d780400 100644 --- a/lib/Sema/Constraint.cpp +++ b/lib/Sema/Constraint.cpp @@ -120,6 +120,9 @@ Constraint::Constraint(ConstraintKind Kind, Type First, Type Second, case ConstraintKind::SyntacticElement: llvm_unreachable("Syntactic element constraint should use create()"); + + case ConstraintKind::CaughtError: + llvm_unreachable("Caught error constraint should use createCaughtError()"); } std::uninitialized_copy(typeVars.begin(), typeVars.end(), @@ -175,6 +178,7 @@ Constraint::Constraint(ConstraintKind Kind, Type First, Type Second, Type Third, case ConstraintKind::ExplicitGenericArguments: case ConstraintKind::SameShape: case ConstraintKind::MaterializePackExpansion: + case ConstraintKind::CaughtError: llvm_unreachable("Wrong constructor"); case ConstraintKind::KeyPath: @@ -282,6 +286,19 @@ Constraint::Constraint(ASTNode node, ContextualTypeInfo context, std::copy(typeVars.begin(), typeVars.end(), getTypeVariablesBuffer().begin()); } +Constraint::Constraint(Type type, CatchNode catchNode, + ConstraintLocator *locator, + SmallPtrSetImpl &typeVars) + : Kind(ConstraintKind::CaughtError), TheFix(nullptr), + HasRestriction(false), IsActive(false), IsDisabled(false), + IsDisabledForPerformance(false), RememberChoice(false), IsFavored(false), + IsIsolated(false), + NumTypeVariables(typeVars.size()), + CaughtError{type, catchNode}, + Locator(locator) { + std::copy(typeVars.begin(), typeVars.end(), getTypeVariablesBuffer().begin()); +} + ProtocolDecl *Constraint::getProtocol() const { assert((Kind == ConstraintKind::ConformsTo || Kind == ConstraintKind::LiteralConformsTo || @@ -364,6 +381,10 @@ Constraint *Constraint::clone(ConstraintSystem &cs) const { case ConstraintKind::SyntacticElement: return createSyntacticElement(cs, getSyntacticElement(), getLocator(), isDiscardedElement()); + + case ConstraintKind::CaughtError: + return createCaughtError(cs, getFirstType(), getCatchNode(), getLocator(), + getTypeVariables()); } llvm_unreachable("Unhandled ConstraintKind in switch."); @@ -592,6 +613,11 @@ void Constraint::print(llvm::raw_ostream &Out, SourceManager *sm, llvm_unreachable("conjunction handled above"); case ConstraintKind::SyntacticElement: llvm_unreachable("syntactic element handled above"); + case ConstraintKind::CaughtError: + Out << " caught error type for "; + simple_display(Out, getCatchNode()); + skipSecond = true; + break; } if (!skipSecond) @@ -771,6 +797,7 @@ gatherReferencedTypeVars(Constraint *constraint, break; case ConstraintKind::SyntacticElement: + case ConstraintKind::CaughtError: typeVars.insert(constraint->getTypeVariables().begin(), constraint->getTypeVariables().end()); break; @@ -1118,6 +1145,18 @@ Constraint *Constraint::createSyntacticElement(ConstraintSystem &cs, return new (mem) Constraint(node, context, isDiscarded, locator, typeVars); } +Constraint *Constraint::createCaughtError( + ConstraintSystem &cs, + Type type, CatchNode catchNode, + ConstraintLocator *locator, + ArrayRef referencedVars) { + SmallPtrSet typeVars; + typeVars.insert(referencedVars.begin(), referencedVars.end()); + unsigned size = totalSizeToAlloc(typeVars.size()); + void *mem = cs.getAllocator().Allocate(size, alignof(Constraint)); + return new (mem) Constraint(type, catchNode, locator, typeVars); +} + llvm::Optional Constraint::getTrailingClosureMatching() const { assert(Kind == ConstraintKind::ApplicableFunction); diff --git a/lib/Sema/ConstraintSystem.cpp b/lib/Sema/ConstraintSystem.cpp index 234f7196450ae..8e9d9cc1cd1a2 100644 --- a/lib/Sema/ConstraintSystem.cpp +++ b/lib/Sema/ConstraintSystem.cpp @@ -425,17 +425,55 @@ Type ConstraintSystem::getCaughtErrorType(CatchNode catchNode) { return explicitCaughtType; } - // Retrieve the thrown error type of a closure. - // FIXME: This will need to change when we do inference of thrown error - // types in closures. + // Retrieve the thrown error type of a closure, which will be in the + // constraint system already. if (auto closure = catchNode.dyn_cast()) { return getClosureType(closure)->getEffectiveThrownErrorTypeOrNever(); } + // At this point, if we aren't doing full typed throws inference, the caught + // error type is always 'any Error'. if (!ctx.LangOpts.hasFeature(Feature::FullTypedThrows)) return ctx.getErrorExistentialType(); - // Handle inference of caught error types. + // Look for a record that we already created the caught error type of this + // catch node. + // FIXME: This is horribly inefficient linear scan, because we need a new + // data structure for potentialThrowSites. + for (const auto &potentialThrowSite : llvm::reverse(potentialThrowSites)) { + if (potentialThrowSite.first == catchNode && + potentialThrowSite.second.kind == + PotentialThrowSite::CaughtTypeVariable) { + return potentialThrowSite.second.type; + } + } + + // Create a type variable to represent the caught error. + // FIXME: How do we know when it's safe to simplify this constraint? + // FIXME: Can we reactivate based on whether we've visited the body? + auto caughtTypeVariableLocator = getConstraintLocator(ASTNode(catchNode)); + auto caughtTypeVariable = createTypeVariable(caughtTypeVariableLocator, 0); + auto constraint = Constraint::createCaughtError( + *this, Type(caughtTypeVariable), catchNode, + caughtTypeVariableLocator, { }); + addUnsolvedConstraint(constraint); + + // Record this type variable in the list of potential throw sites, + // so we can find it again later rather than creating a new one. + // This is effectively using the potential throw sites as a + // CatchNode -> TypeVariableType * map, without requiring a separate + // data structure. + potentialThrowSites.push_back( + {catchNode, + PotentialThrowSite{PotentialThrowSite::CaughtTypeVariable, + Type(caughtTypeVariable), + caughtTypeVariableLocator}}); + + return Type(caughtTypeVariable); +} + +void ConstraintSystem::finalizeCaughtErrorType(CatchNode catchNode) { + ASTContext &ctx = getASTContext(); // Collect all of the potential throw sites for this catch node. SmallVector throwSites; @@ -445,15 +483,32 @@ Type ConstraintSystem::getCaughtErrorType(CatchNode catchNode) { } } + /// The caught error type thus far, which starts at Never because Never + /// doesn't contribute to throwing anything. Type caughtErrorType = ctx.getNeverType(); + TypeVariableType *caughtTypeVariable = nullptr; for (const auto &throwSite : throwSites) { Type type = simplifyType(throwSite.type); - Type thrownErrorType; switch (throwSite.kind) { + case PotentialThrowSite::CaughtTypeVariable: + // Keep track of the caught type variable. + caughtTypeVariable = throwSite.type->castTo(); + + // Ignore the caught type variable; it contributes nothing. + continue; + case PotentialThrowSite::Application: { - auto fnType = type->castTo(); - thrownErrorType = fnType->getEffectiveThrownErrorTypeOrNever(); + auto fnType = type->getAs(); + if (!fnType) { + // This applicable function constraint either still involves type + // variables or wasn't actually a function. + continue; + } + + // Dig out the thrown error type from the function type. + thrownErrorType = + simplifyType(fnType->getEffectiveThrownErrorTypeOrNever()); break; } @@ -477,7 +532,21 @@ Type ConstraintSystem::getCaughtErrorType(CatchNode catchNode) { break; } - return caughtErrorType; + // If we didn't record a caught type variable, check whether this is a + // closure that infers its type variable. + if (!caughtTypeVariable) { + if (auto closure = catchNode.dyn_cast()) { + caughtTypeVariable = getClosureType(closure) + ->getEffectiveThrownErrorTypeOrNever()->getAs(); + } + } + + // If we have a type variable, bind it to the caught error type. + if (caughtTypeVariable) { + addConstraint( + ConstraintKind::Bind, caughtTypeVariable, caughtErrorType, + getConstraintLocator(ASTNode(catchNode))); + } } ConstraintLocator *ConstraintSystem::getConstraintLocator( @@ -3970,6 +4039,8 @@ void ConstraintSystem::resolveOverload(ConstraintLocator *locator, // If accessing this declaration could throw an error, record this as a // potential throw site. if (thrownErrorTypeOnAccess) { + // FIXME: We likely need to do this before overload resolution, + // otherwise we might establish the thrown error type too early. recordPotentialThrowSite( PotentialThrowSite::PropertyAccess, thrownErrorTypeOnAccess, locator); } diff --git a/lib/Sema/TypeCheckConstraints.cpp b/lib/Sema/TypeCheckConstraints.cpp index 3d8d9f5194e6f..2b377d2024c7c 100644 --- a/lib/Sema/TypeCheckConstraints.cpp +++ b/lib/Sema/TypeCheckConstraints.cpp @@ -1679,18 +1679,22 @@ void ConstraintSystem::print(raw_ostream &out) const { out.indent(indent) << "Potential throw sites:\n"; interleave(potentialThrowSites, [&](const auto &throwSite) { out.indent(indent + 2); + out << " -" << throwSite.second.type->getString(PO) << " is "; switch (throwSite.second.kind) { + case PotentialThrowSite::CaughtTypeVariable: + out << "caught type variable @ "; + break; case PotentialThrowSite::Application: - out << "- application @ "; + out << "application @ "; break; case PotentialThrowSite::ExplicitThrow: - out << " - explicit throw @ "; + out << "explicit throw @ "; break; case PotentialThrowSite::NonExhaustiveDoCatch: - out << " - non-exhaustive do..catch @ "; + out << "non-exhaustive do..catch @ "; break; case PotentialThrowSite::PropertyAccess: - out << " - property access @ "; + out << "property access @ "; break; } @@ -1699,7 +1703,6 @@ void ConstraintSystem::print(raw_ostream &out) const { out << "\n"; }); out << "\n"; - } } diff --git a/test/expr/closure/typed_throws_full.swift b/test/expr/closure/typed_throws_full.swift new file mode 100644 index 0000000000000..b8470efbe4687 --- /dev/null +++ b/test/expr/closure/typed_throws_full.swift @@ -0,0 +1,49 @@ +// RUN: %target-typecheck-verify-swift -enable-experimental-feature TypedThrows -enable-upcoming-feature FullTypedThrows + +enum MyError: Error { +case failed +case epicFailed +} + +func doSomething() throws(MyError) -> Int { 5 } + +func apply(body: () throws(E) -> T) throws(E) -> T { + return try body() +} + +func doNothing() { } + +func testSingleStatement() { + let c1 = { + throw MyError.failed + } + #if false + let _: () throws(MyError) -> Void = c1 + + let c2 = { + try doSomething() + } + let _: () throws(MyError) -> Int = c2 + + let c3 = { + return try doSomething() + } + let _: () throws(MyError) -> Int = c3 + #endif +} + +#if false +func testMultiStatement() { + let c1 = { + doNothing() + throw MyError.failed + } + let _: () throws(MyError) -> Void = c1 + + let c2 = { + doNothing() + return try doSomething() + } + let _: () throws(MyError) -> Int = c2 +} +#endif