diff --git a/include/swift/AST/Builtins.def b/include/swift/AST/Builtins.def index ff7327bf7de72..9daa736a75b85 100644 --- a/include/swift/AST/Builtins.def +++ b/include/swift/AST/Builtins.def @@ -524,7 +524,7 @@ BUILTIN_SIL_OPERATION(ApplyTranspose, "applyTranspose", Special) /// resumed. BUILTIN_SIL_OPERATION(WithUnsafeContinuation, "withUnsafeContinuation", Special) -/// withUnsafeThrowingContinuation : (Builtin.RawUnsafeContinuation -> ()) async throws -> sending T +/// withUnsafeThrowingContinuation : (Builtin.RawUnsafeContinuation -> ()) async throws(E) -> sending T /// /// Unsafely capture the current continuation and pass it to the given /// function value. Returns a value of type T or throws an error when diff --git a/lib/AST/Builtins.cpp b/lib/AST/Builtins.cpp index 53712e47a9b74..48f9ce7ffd514 100644 --- a/lib/AST/Builtins.cpp +++ b/lib/AST/Builtins.cpp @@ -2137,27 +2137,31 @@ static ValueDecl *getPolymorphicBinaryOperation(ASTContext &ctx, static ValueDecl *getWithUnsafeContinuation(ASTContext &ctx, Identifier id, bool throws) { - BuiltinFunctionBuilder builder(ctx); - - auto contTy = ctx.TheRawUnsafeContinuationType; - SmallVector params; - params.emplace_back(contTy); - - auto voidTy = ctx.TheEmptyTupleType; - auto extInfo = FunctionType::ExtInfoBuilder().withNoEscape().build(); - auto *fnTy = FunctionType::get(params, voidTy, extInfo); - - builder.addParameter(makeConcrete(fnTy)); + BuiltinFunctionBuilder builder(ctx, throws ? 2 : 1); auto resultTy = makeGenericParam(); builder.addConformanceRequirement(resultTy, KnownProtocolKind::Escapable); builder.setResult(resultTy); - builder.setAsync(); - if (throws) + + SmallVector params; + if (throws) { + auto errorTy = makeGenericParam(1); + builder.addConformanceRequirement(errorTy, KnownProtocolKind::Error); builder.setThrows(); + builder.setThrownError(errorTy); + builder.addParameter(makeMetatype(errorTy)); + } builder.setSendingResult(); + // Add the closure parameter. + auto voidTy = ctx.TheEmptyTupleType; + auto contTy = ctx.TheRawUnsafeContinuationType; + params.emplace_back(contTy); + auto extInfo = FunctionType::ExtInfoBuilder().withNoEscape().build(); + auto *fnTy = FunctionType::get(params, voidTy, extInfo); + builder.addParameter(makeConcrete(fnTy)); + return builder.build(id); } diff --git a/lib/SILGen/SILGenBuiltin.cpp b/lib/SILGen/SILGenBuiltin.cpp index dc6b002d9474a..c2ba8928833a3 100644 --- a/lib/SILGen/SILGenBuiltin.cpp +++ b/lib/SILGen/SILGenBuiltin.cpp @@ -1848,10 +1848,11 @@ static ManagedValue emitBuiltinWithUnsafeContinuation( throws); // Get the callee value. - auto substFnType = args[0].getType().castTo(); + unsigned calleeIndex = throws ? 1 : 0; + auto substFnType = args[calleeIndex].getType().castTo(); SILValue fnValue = (substFnType->isCalleeConsumed() - ? args[0].forward(SGF) - : args[0].getValue()); + ? args[calleeIndex].forward(SGF) + : args[calleeIndex].getValue()); // Call the provided function value. SGF.B.createApply(loc, fnValue, {}, {continuation}); @@ -1871,7 +1872,13 @@ static ManagedValue emitBuiltinWithUnsafeContinuation( Scope errorScope(SGF, loc); - auto errorTy = SGF.getASTContext().getErrorExistentialType(); + CanType errorTy; + if (subs.getReplacementTypes().size() > 1) { + errorTy = subs.getReplacementTypes()[1]->getCanonicalType(); + } else { + errorTy = SGF.getASTContext().getErrorExistentialType(); + } + auto errorVal = SGF.B.createTermResult( SILType::getPrimitiveObjectType(errorTy), OwnershipKind::Owned); diff --git a/stdlib/public/Concurrency/CheckedContinuation.swift b/stdlib/public/Concurrency/CheckedContinuation.swift index a4f0f1553d050..558ea9edb2db0 100644 --- a/stdlib/public/Concurrency/CheckedContinuation.swift +++ b/stdlib/public/Concurrency/CheckedContinuation.swift @@ -357,14 +357,31 @@ public func _unsafeInheritExecutor_withCheckedContinuation( @inlinable @available(SwiftStdlib 5.1, *) #if !$Embedded -@backDeployed(before: SwiftStdlib 6.0) +@backDeployed(before: SwiftStdlib 6.2) #endif -public func withCheckedThrowingContinuation( +public func withCheckedThrowingContinuation( + isolation: isolated (any Actor)? = #isolation, + function: String = #function, + _ body: (CheckedContinuation) -> Void +) async throws(E) -> sending T { + return try await Builtin.withUnsafeThrowingContinuation(E.self) { + let unsafeContinuation = unsafe UnsafeContinuation($0) + return body(unsafe CheckedContinuation(continuation: unsafeContinuation, + function: function)) + } +} + +// Superseded by the typed-throws version of this function. This function +// is retained for ABI purposes. +@available(SwiftStdlib 5.1, *) +@usableFromInline +@_silgen_name("$ss31withCheckedThrowingContinuation9isolation8function_xScA_pSgYi_SSyScCyxs5Error_pGXEtYaKlF") +internal func __abi_withCheckedThrowingContinuation( isolation: isolated (any Actor)? = #isolation, function: String = #function, _ body: (CheckedContinuation) -> Void ) async throws -> sending T { - return try await Builtin.withUnsafeThrowingContinuation { + return try await Builtin.withUnsafeThrowingContinuation(Error.self) { let unsafeContinuation = unsafe UnsafeContinuation($0) return body(unsafe CheckedContinuation(continuation: unsafeContinuation, function: function)) diff --git a/stdlib/public/Concurrency/PartialAsyncTask.swift b/stdlib/public/Concurrency/PartialAsyncTask.swift index 977ebe82f63ef..d60d1f8e68cfd 100644 --- a/stdlib/public/Concurrency/PartialAsyncTask.swift +++ b/stdlib/public/Concurrency/PartialAsyncTask.swift @@ -721,12 +721,12 @@ public func withUnsafeContinuation( @available(SwiftStdlib 5.1, *) @_alwaysEmitIntoClient @unsafe -public func withUnsafeThrowingContinuation( +public func withUnsafeThrowingContinuation( isolation: isolated (any Actor)? = #isolation, - _ fn: (UnsafeContinuation) -> Void -) async throws -> sending T { - return try await Builtin.withUnsafeThrowingContinuation { - unsafe fn(UnsafeContinuation($0)) + _ fn: (UnsafeContinuation) -> Void +) async throws(E) -> sending T { + return try await Builtin.withUnsafeThrowingContinuation(E.self) { + unsafe fn(UnsafeContinuation($0)) } } @@ -750,11 +750,11 @@ public func _unsafeInheritExecutor_withUnsafeContinuation( @available(SwiftStdlib 5.1, *) @_alwaysEmitIntoClient @_unsafeInheritExecutor -public func _unsafeInheritExecutor_withUnsafeThrowingContinuation( - _ fn: (UnsafeContinuation) -> Void -) async throws -> sending T { - return try await Builtin.withUnsafeThrowingContinuation { - unsafe fn(UnsafeContinuation($0)) +public func _unsafeInheritExecutor_withUnsafeThrowingContinuation( + _ fn: (UnsafeContinuation) -> Void +) async throws(E) -> sending T { + return try await Builtin.withUnsafeThrowingContinuation(E.self) { + unsafe fn(UnsafeContinuation($0)) } } diff --git a/test/Concurrency/with_continuation_typed_throws.swift b/test/Concurrency/with_continuation_typed_throws.swift new file mode 100644 index 0000000000000..e6437d3fad37d --- /dev/null +++ b/test/Concurrency/with_continuation_typed_throws.swift @@ -0,0 +1,17 @@ +// RUN: %target-swift-frontend -target %target-swift-5.1-abi-triple %s -typecheck /dev/null -verify + +// REQUIRES: concurrency + +enum MyError: Error { +case exploded +} + +func testTypedThrowsContinuations() async throws(MyError) { + let _: Int = try await withUnsafeThrowingContinuation { (continuation: UnsafeContinuation) in + continuation.resume(throwing: .exploded) + } + + let _: Int = try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + continuation.resume(throwing: .exploded) + } +}