diff --git a/Sources/AWSLambdaRuntimeCore/Lambda.swift b/Sources/AWSLambdaRuntimeCore/Lambda.swift index 3ba90e9c..4634fca0 100644 --- a/Sources/AWSLambdaRuntimeCore/Lambda.swift +++ b/Sources/AWSLambdaRuntimeCore/Lambda.swift @@ -37,25 +37,31 @@ public enum Lambda { ) async throws where Handler: StreamingLambdaHandler { var handler = handler - while !Task.isCancelled { - let (invocation, writer) = try await runtimeClient.nextInvocation() + do { + while !Task.isCancelled { + let (invocation, writer) = try await runtimeClient.nextInvocation() - do { - try await handler.handle( - invocation.event, - responseWriter: writer, - context: LambdaContext( - requestID: invocation.metadata.requestID, - traceID: invocation.metadata.traceID, - invokedFunctionARN: invocation.metadata.invokedFunctionARN, - deadline: DispatchWallTime(millisSinceEpoch: invocation.metadata.deadlineInMillisSinceEpoch), - logger: logger + do { + try await handler.handle( + invocation.event, + responseWriter: writer, + context: LambdaContext( + requestID: invocation.metadata.requestID, + traceID: invocation.metadata.traceID, + invokedFunctionARN: invocation.metadata.invokedFunctionARN, + deadline: DispatchWallTime( + millisSinceEpoch: invocation.metadata.deadlineInMillisSinceEpoch + ), + logger: logger + ) ) - ) - } catch { - try await writer.reportError(error) - continue + } catch { + try await writer.reportError(error) + continue + } } + } catch is CancellationError { + // don't allow cancellation error to propagate further } } diff --git a/Sources/AWSLambdaRuntimeCore/LambdaRuntimeClient.swift b/Sources/AWSLambdaRuntimeCore/LambdaRuntimeClient.swift index 228dc471..809a6a0e 100644 --- a/Sources/AWSLambdaRuntimeCore/LambdaRuntimeClient.swift +++ b/Sources/AWSLambdaRuntimeCore/LambdaRuntimeClient.swift @@ -145,22 +145,28 @@ final actor LambdaRuntimeClient: LambdaRuntimeClientProtocol { } func nextInvocation() async throws -> (Invocation, Writer) { - switch self.lambdaState { - case .idle: - self.lambdaState = .waitingForNextInvocation - let handler = try await self.makeOrGetConnection() - let invocation = try await handler.nextInvocation() - guard case .waitingForNextInvocation = self.lambdaState else { + try await withTaskCancellationHandler { + switch self.lambdaState { + case .idle: + self.lambdaState = .waitingForNextInvocation + let handler = try await self.makeOrGetConnection() + let invocation = try await handler.nextInvocation() + guard case .waitingForNextInvocation = self.lambdaState else { + fatalError("Invalid state: \(self.lambdaState)") + } + self.lambdaState = .waitingForResponse(requestID: invocation.metadata.requestID) + return (invocation, Writer(runtimeClient: self)) + + case .waitingForNextInvocation, + .waitingForResponse, + .sendingResponse, + .sentResponse: fatalError("Invalid state: \(self.lambdaState)") } - self.lambdaState = .waitingForResponse(requestID: invocation.metadata.requestID) - return (invocation, Writer(runtimeClient: self)) - - case .waitingForNextInvocation, - .waitingForResponse, - .sendingResponse, - .sentResponse: - fatalError("Invalid state: \(self.lambdaState)") + } onCancel: { + Task { + await self.close() + } } } @@ -819,6 +825,12 @@ extension LambdaChannelHandler: ChannelInboundHandler { func channelInactive(context: ChannelHandlerContext) { // fail any pending responses with last error or assume peer disconnected + switch self.state { + case .connected(_, .waitingForNextInvocation(let continuation)): + continuation.resume(throwing: self.lastError ?? ChannelError.ioOnClosedChannel) + default: + break + } // we don't need to forward channelInactive to the delegate, as the delegate observes the // closeFuture diff --git a/Tests/AWSLambdaRuntimeCoreTests/LambdaRuntimeClientTests.swift b/Tests/AWSLambdaRuntimeCoreTests/LambdaRuntimeClientTests.swift index e779b931..5d430a0f 100644 --- a/Tests/AWSLambdaRuntimeCoreTests/LambdaRuntimeClientTests.swift +++ b/Tests/AWSLambdaRuntimeCoreTests/LambdaRuntimeClientTests.swift @@ -86,4 +86,56 @@ struct LambdaRuntimeClientTests { } } } + + @Test + func testCancellation() async throws { + struct HappyBehavior: LambdaServerBehavior { + let requestId = UUID().uuidString + let event = "hello" + + func getInvocation() -> GetInvocationResult { + .success((self.requestId, self.event)) + } + + func processResponse(requestId: String, response: String?) -> Result { + #expect(self.requestId == requestId) + #expect(self.event == response) + return .success(()) + } + + func processError(requestId: String, error: ErrorResponse) -> Result { + Issue.record("should not report error") + return .failure(.internalServerError) + } + + func processInitError(error: ErrorResponse) -> Result { + Issue.record("should not report init error") + return .failure(.internalServerError) + } + } + + try await withMockServer(behaviour: HappyBehavior()) { port in + try await LambdaRuntimeClient.withRuntimeClient( + configuration: .init(ip: "127.0.0.1", port: port), + eventLoop: NIOSingletons.posixEventLoopGroup.next(), + logger: self.logger + ) { runtimeClient in + try await withThrowingTaskGroup(of: Void.self) { group in + group.addTask { + while true { + let (_, writer) = try await runtimeClient.nextInvocation() + // Wrap this is a task so cancellation isn't propagated to the write calls + try await Task { + try await writer.write(ByteBuffer(string: "hello")) + try await writer.finish() + }.value + } + } + // wait a small amount to ensure we are waiting for continuation + try await Task.sleep(for: .milliseconds(100)) + group.cancelAll() + } + } + } + } }