Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Full typed throws #71704

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion include/swift/AST/CatchNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
#include "swift/AST/Decl.h"
#include "swift/AST/Expr.h"
#include "swift/AST/Stmt.h"

#include "swift/AST/ASTNode.h"
namespace swift {

/// An AST node that represents a point where a thrown error can be caught and
Expand All @@ -45,6 +45,8 @@ class CatchNode: public llvm::PointerUnion<
/// needs to be inferred.
Type getExplicitCaughtType(ASTContext &ctx) const;

explicit operator ASTNode() const;

friend llvm::hash_code hash_value(CatchNode catchNode) {
using llvm::hash_value;
return hash_value(catchNode.getOpaqueValue());
Expand Down
35 changes: 35 additions & 0 deletions include/swift/Sema/Constraint.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#define SWIFT_SEMA_CONSTRAINT_H

#include "swift/AST/ASTNode.h"
#include "swift/AST/CatchNode.h"
#include "swift/AST/FunctionRefKind.h"
#include "swift/AST/Identifier.h"
#include "swift/AST/Type.h"
Expand Down Expand Up @@ -221,6 +222,10 @@ enum class ConstraintKind : char {
/// The first type is a tuple containing a single unlabeled element that is a
/// pack expansion. The second type is its pattern type.
MaterializePackExpansion,
/// The first type is the thrown error type, and the second entry is a
/// CatchNode whose potential throw sites will be collected to determine
/// the thrown error type.
CaughtError,
};

/// Classification of the different kinds of constraints.
Expand Down Expand Up @@ -443,6 +448,14 @@ class Constraint final : public llvm::ilist_node<Constraint>,
DeclContext *UseDC;
} Overload;

struct {
/// The first type, which represents the thrown error type.
Type First;

/// The catch node, for which the potential throw sites
CatchNode Node;
} CaughtError;

struct {
/// The node itself.
ASTNode Element;
Expand Down Expand Up @@ -507,6 +520,11 @@ class Constraint final : public llvm::ilist_node<Constraint>,
ConstraintLocator *locator,
SmallPtrSetImpl<TypeVariableType *> &typeVars);

/// Construct a caught error constraint.
Constraint(Type type, CatchNode catchNode,
ConstraintLocator *locator,
SmallPtrSetImpl<TypeVariableType *> &typeVars);

/// Retrieve the type variables buffer, for internal mutation.
MutableArrayRef<TypeVariableType *> getTypeVariablesBuffer() {
return { getTrailingObjects<TypeVariableType *>(), NumTypeVariables };
Expand Down Expand Up @@ -602,6 +620,13 @@ class Constraint final : public llvm::ilist_node<Constraint>,
ConstraintLocator *locator,
bool isDiscarded = false);

/// Construct a caught error constraint.
static Constraint *createCaughtError(
ConstraintSystem &cs,
Type type, CatchNode catchNode,
ConstraintLocator *locator,
ArrayRef<TypeVariableType *> referencedVars);

/// Determine the kind of constraint.
ConstraintKind getKind() const { return Kind; }

Expand Down Expand Up @@ -691,6 +716,7 @@ class Constraint final : public llvm::ilist_node<Constraint>,
case ConstraintKind::PackElementOf:
case ConstraintKind::SameShape:
case ConstraintKind::MaterializePackExpansion:
case ConstraintKind::CaughtError:
return ConstraintClassification::Relational;

case ConstraintKind::ValueMember:
Expand Down Expand Up @@ -743,6 +769,9 @@ class Constraint final : public llvm::ilist_node<Constraint>,
case ConstraintKind::SyntacticElement:
llvm_unreachable("closure body element constraint has no type operands");

case ConstraintKind::CaughtError:
return CaughtError.First;

default:
return Types.First;
}
Expand All @@ -755,6 +784,7 @@ class Constraint final : public llvm::ilist_node<Constraint>,
case ConstraintKind::Conjunction:
case ConstraintKind::BindOverload:
case ConstraintKind::SyntacticElement:
case ConstraintKind::CaughtError:
llvm_unreachable("constraint has no second type");

case ConstraintKind::ValueMember:
Expand Down Expand Up @@ -878,6 +908,11 @@ class Constraint final : public llvm::ilist_node<Constraint>,
return SyntacticElement.IsDiscarded;
}

CatchNode getCatchNode() const {
assert(Kind == ConstraintKind::CaughtError);
return CaughtError.Node;
}

/// For an applicable function constraint, retrieve the trailing closure
/// matching rule.
llvm::Optional<TrailingClosureMatching> getTrailingClosureMatching() const;
Expand Down
7 changes: 7 additions & 0 deletions include/swift/Sema/ConstraintLocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,13 @@ class ConstraintLocator : public llvm::FoldingSetNode {
return false;
}

/// Determine whether this locator points directly to a given statement.
template <typename E> bool directlyAtStmt() const {
if (auto *stmt = getAnchor().dyn_cast<Stmt *>())
return isa<E>(stmt) && getPath().empty();
return false;
}

/// Check whether the first element in the path of this locator (if any)
/// is a given \c LocatorPathElt subclass.
template <class T>
Expand Down
25 changes: 24 additions & 1 deletion include/swift/Sema/ConstraintSystem.h
Original file line number Diff line number Diff line change
Expand Up @@ -1446,6 +1446,15 @@ struct MatchCallArgumentResult {
/// `x.y` is a potential throw site.
struct PotentialThrowSite {
enum Kind {
/// The caught type variable for this catch node, which is a stand-in
/// for the type that will be caught by the particular catch node.
///
/// This is not strictly a potential throw site, but is convenient
/// to represent it as such so we don't need a separate mapping
/// from catch node -> caught type variable in the constraint solver's
/// state.
CaughtTypeVariable,

/// The application of a function or subscript.
Application,

Expand Down Expand Up @@ -3425,9 +3434,14 @@ class ConstraintSystem {
PotentialThrowSite::Kind kind, Type type,
ConstraintLocatorBuilder locator);

/// Determine the caught error type for the given catch node.
/// Retrieve the caught error type for the given catch node, which could be
/// a type variable if the caught error type is going to be inferred.
Type getCaughtErrorType(CatchNode node);

/// Finalize the caught error type type once all of the potential throw
/// sites are known.
void finalizeCaughtErrorType(CatchNode node);

/// Retrieve the constraint locator for the given anchor and
/// path, uniqued.
ConstraintLocator *
Expand Down Expand Up @@ -5206,6 +5220,12 @@ class ConstraintSystem {
TypeMatchOptions flags,
ConstraintLocatorBuilder locator);

/// Compute the caught error type for a given catch node.
SolutionKind
simplifyCaughtErrorConstraint(Type caughtError, CatchNode catchNode,
TypeMatchOptions flags,
ConstraintLocatorBuilder locator);

public: // FIXME: Public for use by static functions.
/// Simplify a conversion constraint with a fix applied to it.
SolutionKind simplifyFixConstraint(ConstraintFix *fix, Type type1, Type type2,
Expand Down Expand Up @@ -6450,6 +6470,9 @@ class TypeVarRefCollector : public ASTWalker {
/// Infer the referenced type variables from a given decl.
void inferTypeVars(Decl *D);

/// Infer the referenced type variables from a type.
void inferTypeVars(Type type);

MacroWalking getMacroWalkingBehavior() const override {
return MacroWalking::Arguments;
}
Expand Down
12 changes: 12 additions & 0 deletions lib/AST/Decl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12088,6 +12088,18 @@ Type CatchNode::getExplicitCaughtType(ASTContext &ctx) const {
ctx.evaluator, ExplicitCaughtTypeRequest{&ctx, *this}, Type());
}

CatchNode::operator ASTNode() const {
if (auto func = dyn_cast<AbstractFunctionDecl *>())
return func;
if (auto closure = dyn_cast<ClosureExpr *>())
return closure;
if (auto doCatch = dyn_cast<DoCatchStmt *>())
return doCatch;
if (auto anyTry = dyn_cast<AnyTryExpr *>())
return anyTry;
llvm_unreachable("Unhandled catch node");
}

void swift::simple_display(llvm::raw_ostream &out, CatchNode catchNode) {
out << "catch node";
}
Expand Down
1 change: 1 addition & 0 deletions lib/Sema/CSBindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1739,6 +1739,7 @@ void PotentialBindings::infer(Constraint *constraint) {
case ConstraintKind::PackElementOf:
case ConstraintKind::SameShape:
case ConstraintKind::MaterializePackExpansion:
case ConstraintKind::CaughtError:
// Constraints from which we can't do anything.
break;

Expand Down
21 changes: 19 additions & 2 deletions lib/Sema/CSGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -865,8 +865,12 @@ void TypeVarRefCollector::inferTypeVars(Decl *D) {
if (!ty)
return;

inferTypeVars(ty);
}

void TypeVarRefCollector::inferTypeVars(Type type) {
SmallPtrSet<TypeVariableType *, 4> typeVars;
ty->getTypeVariables(typeVars);
type->getTypeVariables(typeVars);
TypeVars.insert(typeVars.begin(), typeVars.end());
}

Expand Down Expand Up @@ -915,7 +919,12 @@ TypeVarRefCollector::walkToStmtPre(Stmt *stmt) {
if (isa<ReturnStmt>(stmt) && DCDepth == 0 &&
!Locator->directlyAt<ClosureExpr>()) {
SmallPtrSet<TypeVariableType *, 4> typeVars;
CS.getClosureType(CE)->getResult()->getTypeVariables(typeVars);
auto closureType = CS.getClosureType(CE);
closureType->getResult()->getTypeVariables(typeVars);

if (auto thrownErrorType = closureType->getThrownError())
thrownErrorType->getTypeVariables(typeVars);

TypeVars.insert(typeVars.begin(), typeVars.end());
}
}
Expand Down Expand Up @@ -2467,6 +2476,14 @@ namespace {
if (closure->getThrowsLoc().isValid())
return Type();

// If we are inferring thrown error types, create a type variable
// to capture the thrown error type. This will be resolved based on the
// throw sites that occur within the body of the closure.
if (CS.getASTContext().LangOpts.hasFeature(
Feature::FullTypedThrows)) {
return Type(CS.createTypeVariable(thrownErrorLocator, 0));
}

// Thrown type inferred from context.
if (auto contextualType = CS.getContextualType(
closure, /*forConstraint=*/false)) {
Expand Down
43 changes: 34 additions & 9 deletions lib/Sema/CSSimplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2311,6 +2311,7 @@ ConstraintSystem::matchTupleTypes(TupleType *tuple1, TupleType *tuple2,
case ConstraintKind::ExplicitGenericArguments:
case ConstraintKind::SameShape:
case ConstraintKind::MaterializePackExpansion:
case ConstraintKind::CaughtError:
llvm_unreachable("Bad constraint kind in matchTupleTypes()");
}

Expand Down Expand Up @@ -2672,6 +2673,7 @@ static bool matchFunctionRepresentations(FunctionType::ExtInfo einfo1,
case ConstraintKind::ExplicitGenericArguments:
case ConstraintKind::SameShape:
case ConstraintKind::MaterializePackExpansion:
case ConstraintKind::CaughtError:
return true;
}

Expand Down Expand Up @@ -3294,6 +3296,7 @@ ConstraintSystem::matchFunctionTypes(FunctionType *func1, FunctionType *func2,
case ConstraintKind::ExplicitGenericArguments:
case ConstraintKind::SameShape:
case ConstraintKind::MaterializePackExpansion:
case ConstraintKind::CaughtError:
llvm_unreachable("Not a relational constraint");
}

Expand Down Expand Up @@ -7084,6 +7087,7 @@ ConstraintSystem::matchTypes(Type type1, Type type2, ConstraintKind kind,
case ConstraintKind::ExplicitGenericArguments:
case ConstraintKind::SameShape:
case ConstraintKind::MaterializePackExpansion:
case ConstraintKind::CaughtError:
llvm_unreachable("Not a relational constraint");
}
}
Expand Down Expand Up @@ -13012,9 +13016,10 @@ ConstraintSystem::SolutionKind ConstraintSystem::simplifyApplicableFnConstraint(
// Local function to form an unsolved result.
auto formUnsolved = [&](bool activate = false) {
if (flags.contains(TMF_GenerateConstraints)) {
auto fixedLocator = getConstraintLocator(locator);
auto *application = Constraint::createApplicableFunction(
*this, type1, type2, trailingClosureMatching,
getConstraintLocator(locator));
fixedLocator);

addUnsolvedConstraint(application);
if (activate)
Expand Down Expand Up @@ -13044,10 +13049,6 @@ ConstraintSystem::SolutionKind ConstraintSystem::simplifyApplicableFnConstraint(
// is not valid for operators though, where an inout parameter does not
// have an explicit inout argument.
if (type1.getPointer() == desugar2) {
// Note that this could throw.
recordPotentialThrowSite(
PotentialThrowSite::Application, Type(desugar2), outerLocator);

if (!isOperator || !hasInOut()) {
recordMatchCallArgumentResult(
getConstraintLocator(
Expand Down Expand Up @@ -13098,10 +13099,6 @@ ConstraintSystem::SolutionKind ConstraintSystem::simplifyApplicableFnConstraint(

// For a function, bind the output and convert the argument to the input.
if (auto func2 = dyn_cast<FunctionType>(desugar2)) {
// Note that this could throw.
recordPotentialThrowSite(
PotentialThrowSite::Application, Type(desugar2), outerLocator);

ConstraintKind subKind = (isOperator
? ConstraintKind::OperatorArgumentConversion
: ConstraintKind::ArgumentConversion);
Expand Down Expand Up @@ -13837,6 +13834,20 @@ ConstraintSystem::simplifyMaterializePackExpansionConstraint(
return SolutionKind::Error;
}

ConstraintSystem::SolutionKind
ConstraintSystem::simplifyCaughtErrorConstraint(
Type type,
CatchNode catchNode,
TypeMatchOptions flags,
ConstraintLocatorBuilder locator) {
// Keep the constraint around until it simplifies beyond a type variable.
Type simplified = simplifyType(type);
if (simplified->is<TypeVariableType>())
return SolutionKind::Unsolved;

return SolutionKind::Solved;
}

ConstraintSystem::SolutionKind
ConstraintSystem::simplifyExplicitGenericArgumentsConstraint(
Type type1, Type type2, TypeMatchOptions flags,
Expand Down Expand Up @@ -15494,6 +15505,7 @@ ConstraintSystem::addConstraintImpl(ConstraintKind kind, Type first,
case ConstraintKind::KeyPathApplication:
case ConstraintKind::FallbackType:
case ConstraintKind::SyntacticElement:
case ConstraintKind::CaughtError:
llvm_unreachable("Use the correct addConstraint()");
}

Expand Down Expand Up @@ -15703,6 +15715,14 @@ void ConstraintSystem::addConstraint(ConstraintKind kind, Type first,
Type second,
ConstraintLocatorBuilder locator,
bool isFavored) {
// When adding a function-application constraint, make sure to introduce
// a potential throw site.
if (kind == ConstraintKind::ApplicableFunction) {
recordPotentialThrowSite(
PotentialThrowSite::Application, second,
getConstraintLocator(locator));
}

switch (addConstraintImpl(kind, first, second, locator, isFavored)) {
case SolutionKind::Error:
// Add a failing constraint, if needed.
Expand Down Expand Up @@ -16079,6 +16099,11 @@ ConstraintSystem::simplifyConstraint(const Constraint &constraint) {
return simplifyMaterializePackExpansionConstraint(
constraint.getFirstType(), constraint.getSecondType(),
/*flags*/ llvm::None, constraint.getLocator());

case ConstraintKind::CaughtError:
return simplifyCaughtErrorConstraint(
constraint.getFirstType(), constraint.getCatchNode(),
/*flags*/ llvm::None, constraint.getLocator());
}

llvm_unreachable("Unhandled ConstraintKind in switch.");
Expand Down
Loading