From f843c0041564a7388921689c0304d1ab3c0c9484 Mon Sep 17 00:00:00 2001 From: Doug Gregor Date: Thu, 14 Dec 2023 16:11:24 -0800 Subject: [PATCH] [Typed throws] Infer thrown error type for multi-statement closures --- include/swift/Sema/Constraint.h | 35 ++++++++++++++++ include/swift/Sema/ConstraintSystem.h | 14 +++++++ lib/Sema/CSBindings.cpp | 1 + lib/Sema/CSGen.cpp | 11 +++++ lib/Sema/CSSimplify.cpp | 21 ++++++++++ lib/Sema/CSSyntacticElement.cpp | 22 ++++++++-- lib/Sema/Constraint.cpp | 39 ++++++++++++++++++ lib/Sema/ConstraintSystem.cpp | 14 +++++++ lib/Sema/TypeCheckConstraints.cpp | 1 - test/expr/closure/typed_throws_full.swift | 49 +++++++++++++++++++++++ 10 files changed, 202 insertions(+), 5 deletions(-) create mode 100644 test/expr/closure/typed_throws_full.swift diff --git a/include/swift/Sema/Constraint.h b/include/swift/Sema/Constraint.h index 4c30b5e46a777..5ec0a09ce029c 100644 --- a/include/swift/Sema/Constraint.h +++ b/include/swift/Sema/Constraint.h @@ -19,6 +19,7 @@ #define SWIFT_SEMA_CONSTRAINT_H #include "swift/AST/ASTNode.h" +#include "swift/AST/CatchNode.h" #include "swift/AST/FunctionRefKind.h" #include "swift/AST/Identifier.h" #include "swift/AST/Type.h" @@ -221,6 +222,10 @@ enum class ConstraintKind : char { /// The first type is a tuple containing a single unlabeled element that is a /// pack expansion. The second type is that pack expansion. MaterializePackExpansion, + /// The first type is the thrown error type, and the second entry is a + /// CatchNode whose potential throw sites will be collected to determine + /// the thrown error type. + CaughtError, }; /// Classification of the different kinds of constraints. @@ -443,6 +448,14 @@ class Constraint final : public llvm::ilist_node, DeclContext *UseDC; } Overload; + struct { + /// The first type, which represents the thrown error type. + Type First; + + /// The catch node, for which the potential throw sites + CatchNode Node; + } CaughtError; + struct { /// The node itself. ASTNode Element; @@ -507,6 +520,11 @@ class Constraint final : public llvm::ilist_node, ConstraintLocator *locator, SmallPtrSetImpl &typeVars); + /// Construct a caught error constraint. + Constraint(Type type, CatchNode catchNode, + ConstraintLocator *locator, + SmallPtrSetImpl &typeVars); + /// Retrieve the type variables buffer, for internal mutation. MutableArrayRef getTypeVariablesBuffer() { return { getTrailingObjects(), NumTypeVariables }; @@ -602,6 +620,13 @@ class Constraint final : public llvm::ilist_node, ConstraintLocator *locator, bool isDiscarded = false); + /// Construct a caught error constraint. + static Constraint *createCaughtError( + ConstraintSystem &cs, + Type type, CatchNode catchNode, + ConstraintLocator *locator, + ArrayRef referencedVars); + /// Determine the kind of constraint. ConstraintKind getKind() const { return Kind; } @@ -691,6 +716,7 @@ class Constraint final : public llvm::ilist_node, case ConstraintKind::PackElementOf: case ConstraintKind::SameShape: case ConstraintKind::MaterializePackExpansion: + case ConstraintKind::CaughtError: return ConstraintClassification::Relational; case ConstraintKind::ValueMember: @@ -743,6 +769,9 @@ class Constraint final : public llvm::ilist_node, case ConstraintKind::SyntacticElement: llvm_unreachable("closure body element constraint has no type operands"); + case ConstraintKind::CaughtError: + return CaughtError.First; + default: return Types.First; } @@ -755,6 +784,7 @@ class Constraint final : public llvm::ilist_node, case ConstraintKind::Conjunction: case ConstraintKind::BindOverload: case ConstraintKind::SyntacticElement: + case ConstraintKind::CaughtError: llvm_unreachable("constraint has no second type"); case ConstraintKind::ValueMember: @@ -878,6 +908,11 @@ class Constraint final : public llvm::ilist_node, return SyntacticElement.IsDiscarded; } + CatchNode getCatchNode() const { + assert(Kind == ConstraintKind::CaughtError); + return CaughtError.Node; + } + /// For an applicable function constraint, retrieve the trailing closure /// matching rule. llvm::Optional getTrailingClosureMatching() const; diff --git a/include/swift/Sema/ConstraintSystem.h b/include/swift/Sema/ConstraintSystem.h index fd307553b7755..e72729182f367 100644 --- a/include/swift/Sema/ConstraintSystem.h +++ b/include/swift/Sema/ConstraintSystem.h @@ -3347,6 +3347,14 @@ class ConstraintSystem { /// Determine the caught error type for the given catch node. Type getCaughtErrorType(CatchNode node); + /// Infer the caught error type for this catch node, once we have all of + /// the potential throw sites. + Type inferCaughtErrorType(CatchNode node); + + /// Return the type variable that represents the inferred thrown error + /// type for this closure, or NULL if the thrown error type is not inferred. + TypeVariableType *getInferredThrownError(ClosureExpr *closure); + /// Retrieve the constraint locator for the given anchor and /// path, uniqued. ConstraintLocator * @@ -5108,6 +5116,12 @@ class ConstraintSystem { TypeMatchOptions flags, ConstraintLocatorBuilder locator); + /// Compute the caught error type for a given catch node. + SolutionKind + simplifyCaughtErrorConstraint(Type caughtError, CatchNode catchNode, + TypeMatchOptions flags, + ConstraintLocatorBuilder locator); + public: // FIXME: Public for use by static functions. /// Simplify a conversion constraint with a fix applied to it. SolutionKind simplifyFixConstraint(ConstraintFix *fix, Type type1, Type type2, diff --git a/lib/Sema/CSBindings.cpp b/lib/Sema/CSBindings.cpp index 3219a46a4e3dc..123aba722c01a 100644 --- a/lib/Sema/CSBindings.cpp +++ b/lib/Sema/CSBindings.cpp @@ -1730,6 +1730,7 @@ void PotentialBindings::infer(Constraint *constraint) { case ConstraintKind::PackElementOf: case ConstraintKind::SameShape: case ConstraintKind::MaterializePackExpansion: + case ConstraintKind::CaughtError: // Constraints from which we can't do anything. break; diff --git a/lib/Sema/CSGen.cpp b/lib/Sema/CSGen.cpp index 27158a2d8c4e0..1854f7a35f9e6 100644 --- a/lib/Sema/CSGen.cpp +++ b/lib/Sema/CSGen.cpp @@ -2471,6 +2471,17 @@ namespace { if (closure->getThrowsLoc().isValid()) return Type(); + // If we are inferring thrown error types, create a type variable + // to capture the thrown error type. This will be resolved based on the + // throw sites that occur within the body of the closure. + // FIXME: Single-expression closures don't yet work. + if (CS.getASTContext().LangOpts.hasFeature(Feature::FullTypedThrows) && + !CS.getAppliedResultBuilderTransform(closure) && + !closure->hasSingleExpressionBody()) { + return Type( + CS.createTypeVariable(thrownErrorLocator, TVO_CanBindToHole)); + } + // Thrown type inferred from context. if (auto contextualType = CS.getContextualType( closure, /*forConstraint=*/false)) { diff --git a/lib/Sema/CSSimplify.cpp b/lib/Sema/CSSimplify.cpp index 67378295e235a..9d290a5b54756 100644 --- a/lib/Sema/CSSimplify.cpp +++ b/lib/Sema/CSSimplify.cpp @@ -2304,6 +2304,7 @@ ConstraintSystem::matchTupleTypes(TupleType *tuple1, TupleType *tuple2, case ConstraintKind::ExplicitGenericArguments: case ConstraintKind::SameShape: case ConstraintKind::MaterializePackExpansion: + case ConstraintKind::CaughtError: llvm_unreachable("Bad constraint kind in matchTupleTypes()"); } @@ -2665,6 +2666,7 @@ static bool matchFunctionRepresentations(FunctionType::ExtInfo einfo1, case ConstraintKind::ExplicitGenericArguments: case ConstraintKind::SameShape: case ConstraintKind::MaterializePackExpansion: + case ConstraintKind::CaughtError: return true; } @@ -3186,6 +3188,7 @@ ConstraintSystem::matchFunctionTypes(FunctionType *func1, FunctionType *func2, case ConstraintKind::ExplicitGenericArguments: case ConstraintKind::SameShape: case ConstraintKind::MaterializePackExpansion: + case ConstraintKind::CaughtError: llvm_unreachable("Not a relational constraint"); } @@ -6965,6 +6968,7 @@ ConstraintSystem::matchTypes(Type type1, Type type2, ConstraintKind kind, case ConstraintKind::ExplicitGenericArguments: case ConstraintKind::SameShape: case ConstraintKind::MaterializePackExpansion: + case ConstraintKind::CaughtError: llvm_unreachable("Not a relational constraint"); } } @@ -13647,6 +13651,17 @@ ConstraintSystem::simplifyMaterializePackExpansionConstraint( return SolutionKind::Error; } +ConstraintSystem::SolutionKind +ConstraintSystem::simplifyCaughtErrorConstraint( + Type type, + CatchNode catchNode, + TypeMatchOptions flags, + ConstraintLocatorBuilder locator) { + Type caughtErrorType = inferCaughtErrorType(catchNode); + addConstraint(ConstraintKind::Bind, type, caughtErrorType, locator); + return SolutionKind::Solved; +} + ConstraintSystem::SolutionKind ConstraintSystem::simplifyExplicitGenericArgumentsConstraint( Type type1, Type type2, TypeMatchOptions flags, @@ -15293,6 +15308,7 @@ ConstraintSystem::addConstraintImpl(ConstraintKind kind, Type first, case ConstraintKind::KeyPathApplication: case ConstraintKind::FallbackType: case ConstraintKind::SyntacticElement: + case ConstraintKind::CaughtError: llvm_unreachable("Use the correct addConstraint()"); } @@ -15879,6 +15895,11 @@ ConstraintSystem::simplifyConstraint(const Constraint &constraint) { return simplifyMaterializePackExpansionConstraint( constraint.getFirstType(), constraint.getSecondType(), /*flags*/ llvm::None, constraint.getLocator()); + + case ConstraintKind::CaughtError: + return simplifyCaughtErrorConstraint( + constraint.getFirstType(), constraint.getCatchNode(), + /*flags*/ llvm::None, constraint.getLocator()); } llvm_unreachable("Unhandled ConstraintKind in switch."); diff --git a/lib/Sema/CSSyntacticElement.cpp b/lib/Sema/CSSyntacticElement.cpp index 4f0c8700039a2..05f7b07995738 100644 --- a/lib/Sema/CSSyntacticElement.cpp +++ b/lib/Sema/CSSyntacticElement.cpp @@ -361,6 +361,22 @@ static void createConjunction(ConstraintSystem &cs, DeclContext *dc, cs, element, context, elementLoc, isDiscarded)); } + for (auto *externalVar : paramCollector.getTypeVars()) + referencedVars.push_back(externalVar); + + // If the body of the closure is being used to infer the thrown error type + // of that closure, introduce a constraint to do so. + if (locator->directlyAt()) { + auto *closure = castToExpr(locator->getAnchor()); + if (auto thrownErrorTypeVar = cs.getInferredThrownError(closure)) { + referencedVars.push_back(thrownErrorTypeVar); + constraints.push_back( + Constraint::createCaughtError(cs, Type(thrownErrorTypeVar), closure, + locator, referencedVars)); + referencedVars.pop_back(); + } + } + // It's possible that there are no viable elements in the body, // because e.g. whole body is an `#if` statement or it only has // declarations that are checked during solution application. @@ -368,9 +384,6 @@ static void createConjunction(ConstraintSystem &cs, DeclContext *dc, if (constraints.empty()) return; - for (auto *externalVar : paramCollector.getTypeVars()) - referencedVars.push_back(externalVar); - cs.addUnsolvedConstraint(Constraint::createConjunction( cs, constraints, isIsolated, locator, referencedVars)); } @@ -1092,10 +1105,11 @@ class SyntacticElementConstraintGenerator void visitBraceStmt(BraceStmt *braceStmt) { auto &ctx = cs.getASTContext(); + ClosureExpr *closure = nullptr; CaptureListExpr *captureList = nullptr; { if (locator->directlyAt()) { - auto *closure = castToExpr(locator->getAnchor()); + closure = castToExpr(locator->getAnchor()); captureList = getAsExpr(cs.getParentExpr(closure)); } } diff --git a/lib/Sema/Constraint.cpp b/lib/Sema/Constraint.cpp index c51d0da1607ff..d0cda0d780400 100644 --- a/lib/Sema/Constraint.cpp +++ b/lib/Sema/Constraint.cpp @@ -120,6 +120,9 @@ Constraint::Constraint(ConstraintKind Kind, Type First, Type Second, case ConstraintKind::SyntacticElement: llvm_unreachable("Syntactic element constraint should use create()"); + + case ConstraintKind::CaughtError: + llvm_unreachable("Caught error constraint should use createCaughtError()"); } std::uninitialized_copy(typeVars.begin(), typeVars.end(), @@ -175,6 +178,7 @@ Constraint::Constraint(ConstraintKind Kind, Type First, Type Second, Type Third, case ConstraintKind::ExplicitGenericArguments: case ConstraintKind::SameShape: case ConstraintKind::MaterializePackExpansion: + case ConstraintKind::CaughtError: llvm_unreachable("Wrong constructor"); case ConstraintKind::KeyPath: @@ -282,6 +286,19 @@ Constraint::Constraint(ASTNode node, ContextualTypeInfo context, std::copy(typeVars.begin(), typeVars.end(), getTypeVariablesBuffer().begin()); } +Constraint::Constraint(Type type, CatchNode catchNode, + ConstraintLocator *locator, + SmallPtrSetImpl &typeVars) + : Kind(ConstraintKind::CaughtError), TheFix(nullptr), + HasRestriction(false), IsActive(false), IsDisabled(false), + IsDisabledForPerformance(false), RememberChoice(false), IsFavored(false), + IsIsolated(false), + NumTypeVariables(typeVars.size()), + CaughtError{type, catchNode}, + Locator(locator) { + std::copy(typeVars.begin(), typeVars.end(), getTypeVariablesBuffer().begin()); +} + ProtocolDecl *Constraint::getProtocol() const { assert((Kind == ConstraintKind::ConformsTo || Kind == ConstraintKind::LiteralConformsTo || @@ -364,6 +381,10 @@ Constraint *Constraint::clone(ConstraintSystem &cs) const { case ConstraintKind::SyntacticElement: return createSyntacticElement(cs, getSyntacticElement(), getLocator(), isDiscardedElement()); + + case ConstraintKind::CaughtError: + return createCaughtError(cs, getFirstType(), getCatchNode(), getLocator(), + getTypeVariables()); } llvm_unreachable("Unhandled ConstraintKind in switch."); @@ -592,6 +613,11 @@ void Constraint::print(llvm::raw_ostream &Out, SourceManager *sm, llvm_unreachable("conjunction handled above"); case ConstraintKind::SyntacticElement: llvm_unreachable("syntactic element handled above"); + case ConstraintKind::CaughtError: + Out << " caught error type for "; + simple_display(Out, getCatchNode()); + skipSecond = true; + break; } if (!skipSecond) @@ -771,6 +797,7 @@ gatherReferencedTypeVars(Constraint *constraint, break; case ConstraintKind::SyntacticElement: + case ConstraintKind::CaughtError: typeVars.insert(constraint->getTypeVariables().begin(), constraint->getTypeVariables().end()); break; @@ -1118,6 +1145,18 @@ Constraint *Constraint::createSyntacticElement(ConstraintSystem &cs, return new (mem) Constraint(node, context, isDiscarded, locator, typeVars); } +Constraint *Constraint::createCaughtError( + ConstraintSystem &cs, + Type type, CatchNode catchNode, + ConstraintLocator *locator, + ArrayRef referencedVars) { + SmallPtrSet typeVars; + typeVars.insert(referencedVars.begin(), referencedVars.end()); + unsigned size = totalSizeToAlloc(typeVars.size()); + void *mem = cs.getAllocator().Allocate(size, alignof(Constraint)); + return new (mem) Constraint(type, catchNode, locator, typeVars); +} + llvm::Optional Constraint::getTrailingClosureMatching() const { assert(Kind == ConstraintKind::ApplicableFunction); diff --git a/lib/Sema/ConstraintSystem.cpp b/lib/Sema/ConstraintSystem.cpp index a33c8206b9522..1929f66ede4e6 100644 --- a/lib/Sema/ConstraintSystem.cpp +++ b/lib/Sema/ConstraintSystem.cpp @@ -433,6 +433,11 @@ Type ConstraintSystem::getCaughtErrorType(CatchNode catchNode) { } // Handle inference of caught error types. + return inferCaughtErrorType(catchNode); +} + +Type ConstraintSystem::inferCaughtErrorType(CatchNode catchNode) { + ASTContext &ctx = getASTContext(); // Collect all of the potential throw sites for this catch node. SmallVector throwSites; @@ -477,6 +482,15 @@ Type ConstraintSystem::getCaughtErrorType(CatchNode catchNode) { return caughtErrorType; } +TypeVariableType * +ConstraintSystem::getInferredThrownError(ClosureExpr *closure) { + auto closureType = getClosureType(closure); + if (Type thrownError = closureType->getThrownError()) + return thrownError->getAs(); + + return nullptr; +} + ConstraintLocator *ConstraintSystem::getConstraintLocator( ASTNode anchor, ArrayRef path) { auto summaryFlags = ConstraintLocator::getSummaryFlagsForPath(path); diff --git a/lib/Sema/TypeCheckConstraints.cpp b/lib/Sema/TypeCheckConstraints.cpp index 1531c5846ebc6..d5c1b6c2c1eb7 100644 --- a/lib/Sema/TypeCheckConstraints.cpp +++ b/lib/Sema/TypeCheckConstraints.cpp @@ -1680,7 +1680,6 @@ void ConstraintSystem::print(raw_ostream &out) const { out << "\n"; }); out << "\n"; - } } diff --git a/test/expr/closure/typed_throws_full.swift b/test/expr/closure/typed_throws_full.swift new file mode 100644 index 0000000000000..33a8b991c7c12 --- /dev/null +++ b/test/expr/closure/typed_throws_full.swift @@ -0,0 +1,49 @@ +// RUN: %target-typecheck-verify-swift -enable-experimental-feature TypedThrows -enable-upcoming-feature FullTypedThrows + +enum MyError: Error { +case failed +case epicFailed +} + +func doSomething() throws(MyError) -> Int { 5 } + +func apply(body: () throws(E) -> T) throws(E) -> T { + return try body() +} + +func doNothing() { } + +func testSingleStatement() { + let c1 = { + throw MyError.failed + } + let _: () throws(MyError) -> Void = c1 + + let c2 = { + try doSomething() + } + // FIXME: Single-expression closures aren't inferring thrown error types. + // expected-error@+1{{invalid conversion of thrown error type 'any Error' to 'MyError'}} + let _: () throws(MyError) -> Int = c2 + + let c3 = { + return try doSomething() + } + // FIXME: Single-expression closures aren't inferring thrown error types. + // expected-error@+1{{invalid conversion of thrown error type 'any Error' to 'MyError'}} + let _: () throws(MyError) -> Int = c3 +} + +func testMultiStatement() { + let c1 = { + doNothing() + throw MyError.failed + } + let _: () throws(MyError) -> Void = c1 + + let c2 = { + doNothing() + return try doSomething() + } + let _: () throws(MyError) -> Int = c2 +}