Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Typed throws] Infer thrown error type for multi-statement closures #70478

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions include/swift/Sema/Constraint.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#define SWIFT_SEMA_CONSTRAINT_H

#include "swift/AST/ASTNode.h"
#include "swift/AST/CatchNode.h"
#include "swift/AST/FunctionRefKind.h"
#include "swift/AST/Identifier.h"
#include "swift/AST/Type.h"
Expand Down Expand Up @@ -221,6 +222,10 @@ enum class ConstraintKind : char {
/// The first type is a tuple containing a single unlabeled element that is a
/// pack expansion. The second type is that pack expansion.
MaterializePackExpansion,
/// The first type is the thrown error type, and the second entry is a
/// CatchNode whose potential throw sites will be collected to determine
/// the thrown error type.
CaughtError,
};

/// Classification of the different kinds of constraints.
Expand Down Expand Up @@ -443,6 +448,14 @@ class Constraint final : public llvm::ilist_node<Constraint>,
DeclContext *UseDC;
} Overload;

struct {
/// The first type, which represents the thrown error type.
Type First;

/// The catch node, for which the potential throw sites
CatchNode Node;
} CaughtError;

struct {
/// The node itself.
ASTNode Element;
Expand Down Expand Up @@ -507,6 +520,11 @@ class Constraint final : public llvm::ilist_node<Constraint>,
ConstraintLocator *locator,
SmallPtrSetImpl<TypeVariableType *> &typeVars);

/// Construct a caught error constraint.
Constraint(Type type, CatchNode catchNode,
ConstraintLocator *locator,
SmallPtrSetImpl<TypeVariableType *> &typeVars);

/// Retrieve the type variables buffer, for internal mutation.
MutableArrayRef<TypeVariableType *> getTypeVariablesBuffer() {
return { getTrailingObjects<TypeVariableType *>(), NumTypeVariables };
Expand Down Expand Up @@ -602,6 +620,13 @@ class Constraint final : public llvm::ilist_node<Constraint>,
ConstraintLocator *locator,
bool isDiscarded = false);

/// Construct a caught error constraint.
static Constraint *createCaughtError(
ConstraintSystem &cs,
Type type, CatchNode catchNode,
ConstraintLocator *locator,
ArrayRef<TypeVariableType *> referencedVars);

/// Determine the kind of constraint.
ConstraintKind getKind() const { return Kind; }

Expand Down Expand Up @@ -691,6 +716,7 @@ class Constraint final : public llvm::ilist_node<Constraint>,
case ConstraintKind::PackElementOf:
case ConstraintKind::SameShape:
case ConstraintKind::MaterializePackExpansion:
case ConstraintKind::CaughtError:
return ConstraintClassification::Relational;

case ConstraintKind::ValueMember:
Expand Down Expand Up @@ -743,6 +769,9 @@ class Constraint final : public llvm::ilist_node<Constraint>,
case ConstraintKind::SyntacticElement:
llvm_unreachable("closure body element constraint has no type operands");

case ConstraintKind::CaughtError:
return CaughtError.First;

default:
return Types.First;
}
Expand All @@ -755,6 +784,7 @@ class Constraint final : public llvm::ilist_node<Constraint>,
case ConstraintKind::Conjunction:
case ConstraintKind::BindOverload:
case ConstraintKind::SyntacticElement:
case ConstraintKind::CaughtError:
llvm_unreachable("constraint has no second type");

case ConstraintKind::ValueMember:
Expand Down Expand Up @@ -878,6 +908,11 @@ class Constraint final : public llvm::ilist_node<Constraint>,
return SyntacticElement.IsDiscarded;
}

CatchNode getCatchNode() const {
assert(Kind == ConstraintKind::CaughtError);
return CaughtError.Node;
}

/// For an applicable function constraint, retrieve the trailing closure
/// matching rule.
llvm::Optional<TrailingClosureMatching> getTrailingClosureMatching() const;
Expand Down
14 changes: 14 additions & 0 deletions include/swift/Sema/ConstraintSystem.h
Original file line number Diff line number Diff line change
Expand Up @@ -3347,6 +3347,14 @@ class ConstraintSystem {
/// Determine the caught error type for the given catch node.
Type getCaughtErrorType(CatchNode node);

/// Infer the caught error type for this catch node, once we have all of
/// the potential throw sites.
Type inferCaughtErrorType(CatchNode node);

/// Return the type variable that represents the inferred thrown error
/// type for this closure, or NULL if the thrown error type is not inferred.
TypeVariableType *getInferredThrownError(ClosureExpr *closure);

/// Retrieve the constraint locator for the given anchor and
/// path, uniqued.
ConstraintLocator *
Expand Down Expand Up @@ -5108,6 +5116,12 @@ class ConstraintSystem {
TypeMatchOptions flags,
ConstraintLocatorBuilder locator);

/// Compute the caught error type for a given catch node.
SolutionKind
simplifyCaughtErrorConstraint(Type caughtError, CatchNode catchNode,
TypeMatchOptions flags,
ConstraintLocatorBuilder locator);

public: // FIXME: Public for use by static functions.
/// Simplify a conversion constraint with a fix applied to it.
SolutionKind simplifyFixConstraint(ConstraintFix *fix, Type type1, Type type2,
Expand Down
1 change: 1 addition & 0 deletions lib/Sema/CSBindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1730,6 +1730,7 @@ void PotentialBindings::infer(Constraint *constraint) {
case ConstraintKind::PackElementOf:
case ConstraintKind::SameShape:
case ConstraintKind::MaterializePackExpansion:
case ConstraintKind::CaughtError:
// Constraints from which we can't do anything.
break;

Expand Down
11 changes: 11 additions & 0 deletions lib/Sema/CSGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2471,6 +2471,17 @@ namespace {
if (closure->getThrowsLoc().isValid())
return Type();

// If we are inferring thrown error types, create a type variable
// to capture the thrown error type. This will be resolved based on the
// throw sites that occur within the body of the closure.
// FIXME: Single-expression closures don't yet work.
if (CS.getASTContext().LangOpts.hasFeature(Feature::FullTypedThrows) &&
!CS.getAppliedResultBuilderTransform(closure) &&
!closure->hasSingleExpressionBody()) {
return Type(
CS.createTypeVariable(thrownErrorLocator, TVO_CanBindToHole));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this is allowed to bind to a hole we should handle it in TypeVariableBinding::fixForHole by producing an appropriate fix otherwise we'd end up crashing during solution verification when the type cannot be inferred.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suspect I don't need it to bind to a hole at all, thank you for noticing this.

}

// Thrown type inferred from context.
if (auto contextualType = CS.getContextualType(
closure, /*forConstraint=*/false)) {
Expand Down
21 changes: 21 additions & 0 deletions lib/Sema/CSSimplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2304,6 +2304,7 @@ ConstraintSystem::matchTupleTypes(TupleType *tuple1, TupleType *tuple2,
case ConstraintKind::ExplicitGenericArguments:
case ConstraintKind::SameShape:
case ConstraintKind::MaterializePackExpansion:
case ConstraintKind::CaughtError:
llvm_unreachable("Bad constraint kind in matchTupleTypes()");
}

Expand Down Expand Up @@ -2665,6 +2666,7 @@ static bool matchFunctionRepresentations(FunctionType::ExtInfo einfo1,
case ConstraintKind::ExplicitGenericArguments:
case ConstraintKind::SameShape:
case ConstraintKind::MaterializePackExpansion:
case ConstraintKind::CaughtError:
return true;
}

Expand Down Expand Up @@ -3186,6 +3188,7 @@ ConstraintSystem::matchFunctionTypes(FunctionType *func1, FunctionType *func2,
case ConstraintKind::ExplicitGenericArguments:
case ConstraintKind::SameShape:
case ConstraintKind::MaterializePackExpansion:
case ConstraintKind::CaughtError:
llvm_unreachable("Not a relational constraint");
}

Expand Down Expand Up @@ -6965,6 +6968,7 @@ ConstraintSystem::matchTypes(Type type1, Type type2, ConstraintKind kind,
case ConstraintKind::ExplicitGenericArguments:
case ConstraintKind::SameShape:
case ConstraintKind::MaterializePackExpansion:
case ConstraintKind::CaughtError:
llvm_unreachable("Not a relational constraint");
}
}
Expand Down Expand Up @@ -13647,6 +13651,17 @@ ConstraintSystem::simplifyMaterializePackExpansionConstraint(
return SolutionKind::Error;
}

ConstraintSystem::SolutionKind
ConstraintSystem::simplifyCaughtErrorConstraint(
Type type,
CatchNode catchNode,
TypeMatchOptions flags,
ConstraintLocatorBuilder locator) {
Type caughtErrorType = inferCaughtErrorType(catchNode);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since it's only used in multi-statement closures at the moment it's guaranteed to have all the throw sites resolved by the time it comes to this constraint (since it's the last in the body). The situation is more interesting for single-expression closures because they support bi-directional inference and the constraint needs to be re-activated every time something in one of its throw sites gets resolved.

I think what needs to happen to support single-expression closures is:

  • CaughtError constraint needs to reference all of the relevant (relevance should be determined by the PotentialThrowSite) type variables from all of its potential throw sites; This way it would never run into the risk of being disconnected from its context;
    • Referencing these variables also triggers re-activation which is the most important.
  • Simplification logic needs to go over all of the referenced variables and check whether they are fixed before attempting to call inferCaughtErrorType and re-introduce the constraint otherwise;
  • CaughtError constraint for single-expression closures could be introduced @ https://github.com/apple/swift/blob/main/lib/Sema/CSSyntacticElement.cpp#L1132 when all of the potential throw sites in the body have been identified.

addConstraint(ConstraintKind::Bind, type, caughtErrorType, locator);
return SolutionKind::Solved;
}

ConstraintSystem::SolutionKind
ConstraintSystem::simplifyExplicitGenericArgumentsConstraint(
Type type1, Type type2, TypeMatchOptions flags,
Expand Down Expand Up @@ -15293,6 +15308,7 @@ ConstraintSystem::addConstraintImpl(ConstraintKind kind, Type first,
case ConstraintKind::KeyPathApplication:
case ConstraintKind::FallbackType:
case ConstraintKind::SyntacticElement:
case ConstraintKind::CaughtError:
llvm_unreachable("Use the correct addConstraint()");
}

Expand Down Expand Up @@ -15879,6 +15895,11 @@ ConstraintSystem::simplifyConstraint(const Constraint &constraint) {
return simplifyMaterializePackExpansionConstraint(
constraint.getFirstType(), constraint.getSecondType(),
/*flags*/ llvm::None, constraint.getLocator());

case ConstraintKind::CaughtError:
return simplifyCaughtErrorConstraint(
constraint.getFirstType(), constraint.getCatchNode(),
/*flags*/ llvm::None, constraint.getLocator());
}

llvm_unreachable("Unhandled ConstraintKind in switch.");
Expand Down
22 changes: 18 additions & 4 deletions lib/Sema/CSSyntacticElement.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -361,16 +361,29 @@ static void createConjunction(ConstraintSystem &cs, DeclContext *dc,
cs, element, context, elementLoc, isDiscarded));
}

for (auto *externalVar : paramCollector.getTypeVars())
referencedVars.push_back(externalVar);

// If the body of the closure is being used to infer the thrown error type
// of that closure, introduce a constraint to do so.
if (locator->directlyAt<ClosureExpr>()) {
auto *closure = castToExpr<ClosureExpr>(locator->getAnchor());
if (auto thrownErrorTypeVar = cs.getInferredThrownError(closure)) {
referencedVars.push_back(thrownErrorTypeVar);
constraints.push_back(
Constraint::createCaughtError(cs, Type(thrownErrorTypeVar), closure,
locator, referencedVars));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A couple of points here:

  • This constraint needs a custom locator, maybe use ClosureThrownError just like the type variable but it cannot be directly on the closure.

  • I don't think referencedVars needs to be passed here because the constraint itself doesn't actually reference all of these variables and there is no risk of it being disconnected from the conjunction.

referencedVars.pop_back();
}
}

// It's possible that there are no viable elements in the body,
// because e.g. whole body is an `#if` statement or it only has
// declarations that are checked during solution application.
// In such cases, let's avoid creating a conjunction.
if (constraints.empty())
return;

for (auto *externalVar : paramCollector.getTypeVars())
referencedVars.push_back(externalVar);

cs.addUnsolvedConstraint(Constraint::createConjunction(
cs, constraints, isIsolated, locator, referencedVars));
}
Expand Down Expand Up @@ -1092,10 +1105,11 @@ class SyntacticElementConstraintGenerator
void visitBraceStmt(BraceStmt *braceStmt) {
auto &ctx = cs.getASTContext();

ClosureExpr *closure = nullptr;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Doesn't seem to be necessary since it's not used anywhere else?

CaptureListExpr *captureList = nullptr;
{
if (locator->directlyAt<ClosureExpr>()) {
auto *closure = castToExpr<ClosureExpr>(locator->getAnchor());
closure = castToExpr<ClosureExpr>(locator->getAnchor());
captureList = getAsExpr<CaptureListExpr>(cs.getParentExpr(closure));
}
}
Expand Down
39 changes: 39 additions & 0 deletions lib/Sema/Constraint.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,9 @@ Constraint::Constraint(ConstraintKind Kind, Type First, Type Second,

case ConstraintKind::SyntacticElement:
llvm_unreachable("Syntactic element constraint should use create()");

case ConstraintKind::CaughtError:
llvm_unreachable("Caught error constraint should use createCaughtError()");
}

std::uninitialized_copy(typeVars.begin(), typeVars.end(),
Expand Down Expand Up @@ -175,6 +178,7 @@ Constraint::Constraint(ConstraintKind Kind, Type First, Type Second, Type Third,
case ConstraintKind::ExplicitGenericArguments:
case ConstraintKind::SameShape:
case ConstraintKind::MaterializePackExpansion:
case ConstraintKind::CaughtError:
llvm_unreachable("Wrong constructor");

case ConstraintKind::KeyPath:
Expand Down Expand Up @@ -282,6 +286,19 @@ Constraint::Constraint(ASTNode node, ContextualTypeInfo context,
std::copy(typeVars.begin(), typeVars.end(), getTypeVariablesBuffer().begin());
}

Constraint::Constraint(Type type, CatchNode catchNode,
ConstraintLocator *locator,
SmallPtrSetImpl<TypeVariableType *> &typeVars)
: Kind(ConstraintKind::CaughtError), TheFix(nullptr),
HasRestriction(false), IsActive(false), IsDisabled(false),
IsDisabledForPerformance(false), RememberChoice(false), IsFavored(false),
IsIsolated(false),
NumTypeVariables(typeVars.size()),
CaughtError{type, catchNode},
Locator(locator) {
std::copy(typeVars.begin(), typeVars.end(), getTypeVariablesBuffer().begin());
}

ProtocolDecl *Constraint::getProtocol() const {
assert((Kind == ConstraintKind::ConformsTo ||
Kind == ConstraintKind::LiteralConformsTo ||
Expand Down Expand Up @@ -364,6 +381,10 @@ Constraint *Constraint::clone(ConstraintSystem &cs) const {
case ConstraintKind::SyntacticElement:
return createSyntacticElement(cs, getSyntacticElement(), getLocator(),
isDiscardedElement());

case ConstraintKind::CaughtError:
return createCaughtError(cs, getFirstType(), getCatchNode(), getLocator(),
getTypeVariables());
}

llvm_unreachable("Unhandled ConstraintKind in switch.");
Expand Down Expand Up @@ -592,6 +613,11 @@ void Constraint::print(llvm::raw_ostream &Out, SourceManager *sm,
llvm_unreachable("conjunction handled above");
case ConstraintKind::SyntacticElement:
llvm_unreachable("syntactic element handled above");
case ConstraintKind::CaughtError:
Out << " caught error type for ";
simple_display(Out, getCatchNode());
skipSecond = true;
break;
}

if (!skipSecond)
Expand Down Expand Up @@ -771,6 +797,7 @@ gatherReferencedTypeVars(Constraint *constraint,
break;

case ConstraintKind::SyntacticElement:
case ConstraintKind::CaughtError:
typeVars.insert(constraint->getTypeVariables().begin(),
constraint->getTypeVariables().end());
break;
Expand Down Expand Up @@ -1118,6 +1145,18 @@ Constraint *Constraint::createSyntacticElement(ConstraintSystem &cs,
return new (mem) Constraint(node, context, isDiscarded, locator, typeVars);
}

Constraint *Constraint::createCaughtError(
ConstraintSystem &cs,
Type type, CatchNode catchNode,
ConstraintLocator *locator,
ArrayRef<TypeVariableType *> referencedVars) {
SmallPtrSet<TypeVariableType *, 4> typeVars;
typeVars.insert(referencedVars.begin(), referencedVars.end());
unsigned size = totalSizeToAlloc<TypeVariableType *>(typeVars.size());
void *mem = cs.getAllocator().Allocate(size, alignof(Constraint));
return new (mem) Constraint(type, catchNode, locator, typeVars);
}

llvm::Optional<TrailingClosureMatching>
Constraint::getTrailingClosureMatching() const {
assert(Kind == ConstraintKind::ApplicableFunction);
Expand Down
14 changes: 14 additions & 0 deletions lib/Sema/ConstraintSystem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,11 @@ Type ConstraintSystem::getCaughtErrorType(CatchNode catchNode) {
}

// Handle inference of caught error types.
return inferCaughtErrorType(catchNode);
}

Type ConstraintSystem::inferCaughtErrorType(CatchNode catchNode) {
ASTContext &ctx = getASTContext();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: appears to be unused.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, thanks!


// Collect all of the potential throw sites for this catch node.
SmallVector<PotentialThrowSite, 2> throwSites;
Expand Down Expand Up @@ -477,6 +482,15 @@ Type ConstraintSystem::getCaughtErrorType(CatchNode catchNode) {
return caughtErrorType;
}

TypeVariableType *
ConstraintSystem::getInferredThrownError(ClosureExpr *closure) {
auto closureType = getClosureType(closure);
if (Type thrownError = closureType->getThrownError())
return thrownError->getAs<TypeVariableType>();

return nullptr;
}

ConstraintLocator *ConstraintSystem::getConstraintLocator(
ASTNode anchor, ArrayRef<ConstraintLocator::PathElement> path) {
auto summaryFlags = ConstraintLocator::getSummaryFlagsForPath(path);
Expand Down
1 change: 0 additions & 1 deletion lib/Sema/TypeCheckConstraints.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1680,7 +1680,6 @@ void ConstraintSystem::print(raw_ostream &out) const {
out << "\n";
});
out << "\n";

}
}

Expand Down
Loading