diff --git a/Package.swift b/Package.swift index bcc45444..9eac9771 100644 --- a/Package.swift +++ b/Package.swift @@ -19,6 +19,7 @@ let package = Package( dependencies: [ .package(url: "https://github.com/apple/swift-nio.git", from: "2.81.0"), .package(url: "https://github.com/apple/swift-log.git", from: "1.5.4"), + .package(url: "https://github.com/apple/swift-collections.git", from: "1.1.4"), ], targets: [ .target( @@ -31,10 +32,10 @@ let package = Package( .target( name: "AWSLambdaRuntimeCore", dependencies: [ + .product(name: "DequeModule", package: "swift-collections"), .product(name: "Logging", package: "swift-log"), .product(name: "NIOHTTP1", package: "swift-nio"), .product(name: "NIOCore", package: "swift-nio"), - .product(name: "NIOConcurrencyHelpers", package: "swift-nio"), .product(name: "NIOPosix", package: "swift-nio"), ] ), diff --git a/Sources/AWSLambdaRuntimeCore/Lambda+LocalServer.swift b/Sources/AWSLambdaRuntimeCore/Lambda+LocalServer.swift index 64a9acb7..4bb113a7 100644 --- a/Sources/AWSLambdaRuntimeCore/Lambda+LocalServer.swift +++ b/Sources/AWSLambdaRuntimeCore/Lambda+LocalServer.swift @@ -13,9 +13,9 @@ //===----------------------------------------------------------------------===// #if DEBUG +import DequeModule import Dispatch import Logging -import NIOConcurrencyHelpers import NIOCore import NIOHTTP1 import NIOPosix @@ -47,24 +47,15 @@ extension Lambda { /// - note: This API is designed strictly for local testing and is behind a DEBUG flag static func withLocalServer( invocationEndpoint: String? = nil, - _ body: @escaping () async throws -> Void + _ body: sending @escaping () async throws -> Void ) async throws { + var logger = Logger(label: "LocalServer") + logger.logLevel = Lambda.env("LOG_LEVEL").flatMap(Logger.Level.init) ?? .info - // launch the local server and wait for it to be started before running the body - try await withThrowingTaskGroup(of: Void.self) { group in - // this call will return when the server calls continuation.resume() - try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in - group.addTask { - do { - try await LambdaHttpServer(invocationEndpoint: invocationEndpoint).start( - continuation: continuation - ) - } catch { - continuation.resume(throwing: error) - } - } - } - // now that server is started, run the Lambda function itself + try await LambdaHTTPServer.withLocalServer( + invocationEndpoint: invocationEndpoint, + logger: logger + ) { try await body() } } @@ -84,34 +75,46 @@ extension Lambda { /// 1. POST /invoke - the client posts the event to the lambda function /// /// This server passes the data received from /invoke POST request to the lambda function (GET /next) and then forwards the response back to the client. -private struct LambdaHttpServer { - private let logger: Logger - private let group: EventLoopGroup - private let host: String - private let port: Int +private struct LambdaHTTPServer { private let invocationEndpoint: String private let invocationPool = Pool() private let responsePool = Pool() - init(invocationEndpoint: String?) { - var logger = Logger(label: "LocalServer") - logger.logLevel = Lambda.env("LOG_LEVEL").flatMap(Logger.Level.init) ?? .info - self.logger = logger - self.group = MultiThreadedEventLoopGroup.singleton - self.host = "127.0.0.1" - self.port = 7000 + private init( + invocationEndpoint: String? + ) { self.invocationEndpoint = invocationEndpoint ?? "/invoke" } - func start(continuation: CheckedContinuation) async throws { - let channel = try await ServerBootstrap(group: self.group) + private enum TaskResult: Sendable { + case closureResult(Swift.Result) + case serverReturned(Swift.Result) + } + + struct UnsafeTransferBox: @unchecked Sendable { + let value: Value + + init(value: sending Value) { + self.value = value + } + } + + static func withLocalServer( + invocationEndpoint: String?, + host: String = "127.0.0.1", + port: Int = 7000, + eventLoopGroup: MultiThreadedEventLoopGroup = .singleton, + logger: Logger, + _ closure: sending @escaping () async throws -> Result + ) async throws -> Result { + let channel = try await ServerBootstrap(group: eventLoopGroup) .serverChannelOption(.backlog, value: 256) .serverChannelOption(.socketOption(.so_reuseaddr), value: 1) .childChannelOption(.maxMessagesPerRead, value: 1) .bind( - host: self.host, - port: self.port + host: host, + port: port ) { channel in channel.eventLoop.makeCompletedFuture { @@ -129,8 +132,6 @@ private struct LambdaHttpServer { } } - // notify the caller that the server is started - continuation.resume() logger.info( "Server started and listening", metadata: [ @@ -139,30 +140,87 @@ private struct LambdaHttpServer { ] ) - // We are handling each incoming connection in a separate child task. It is important - // to use a discarding task group here which automatically discards finished child tasks. - // A normal task group retains all child tasks and their outputs in memory until they are - // consumed by iterating the group or by exiting the group. Since, we are never consuming - // the results of the group we need the group to automatically discard them; otherwise, this - // would result in a memory leak over time. - try await withThrowingDiscardingTaskGroup { group in - try await channel.executeThenClose { inbound in - for try await connectionChannel in inbound { - - group.addTask { - logger.trace("Handling a new connection") - await self.handleConnection(channel: connectionChannel) - logger.trace("Done handling the connection") + let server = LambdaHTTPServer(invocationEndpoint: invocationEndpoint) + + // Sadly the Swift compiler does not understand that the passed in closure will only be + // invoked once. Because of this we need an unsafe transfer box here. Buuuh! + let closureBox = UnsafeTransferBox(value: closure) + let result = await withTaskGroup(of: TaskResult.self, returning: Swift.Result.self) { + group in + group.addTask { + let c = closureBox.value + do { + let result = try await c() + return .closureResult(.success(result)) + } catch { + return .closureResult(.failure(error)) + } + } + + group.addTask { + do { + // We are handling each incoming connection in a separate child task. It is important + // to use a discarding task group here which automatically discards finished child tasks. + // A normal task group retains all child tasks and their outputs in memory until they are + // consumed by iterating the group or by exiting the group. Since, we are never consuming + // the results of the group we need the group to automatically discard them; otherwise, this + // would result in a memory leak over time. + try await withThrowingDiscardingTaskGroup { taskGroup in + try await channel.executeThenClose { inbound in + for try await connectionChannel in inbound { + + taskGroup.addTask { + logger.trace("Handling a new connection") + await server.handleConnection(channel: connectionChannel, logger: logger) + logger.trace("Done handling the connection") + } + } + } } + return .serverReturned(.success(())) + } catch { + return .serverReturned(.failure(error)) + } + } + + // Now that the local HTTP server and LambdaHandler tasks are started, wait for the + // first of the two that will terminate. + // When the first task terminates, cancel the group and collect the result of the + // second task. + + // collect and return the result of the LambdaHandler + let serverOrHandlerResult1 = await group.next()! + group.cancelAll() + + switch serverOrHandlerResult1 { + case .closureResult(let result): + return result + + case .serverReturned(let result): + logger.error( + "Server shutdown before closure completed", + metadata: [ + "error": "\(result.maybeError != nil ? "\(result.maybeError!)" : "none")" + ] + ) + switch await group.next()! { + case .closureResult(let result): + return result + + case .serverReturned: + fatalError("Only one task is a server, and only one can return `serverReturned`") } } } + logger.info("Server shutting down") + return try result.get() } /// This method handles individual TCP connections private func handleConnection( - channel: NIOAsyncChannel + channel: NIOAsyncChannel, + logger: Logger ) async { var requestHead: HTTPRequestHead! @@ -186,12 +244,14 @@ private struct LambdaHttpServer { // process the request let response = try await self.processRequest( head: requestHead, - body: requestBody + body: requestBody, + logger: logger ) // send the responses try await self.sendResponse( response: response, - outbound: outbound + outbound: outbound, + logger: logger ) requestHead = nil @@ -214,15 +274,19 @@ private struct LambdaHttpServer { /// - body: the HTTP request body /// - Throws: /// - Returns: the response to send back to the client or the Lambda function - private func processRequest(head: HTTPRequestHead, body: ByteBuffer?) async throws -> LocalServerResponse { + private func processRequest( + head: HTTPRequestHead, + body: ByteBuffer?, + logger: Logger + ) async throws -> LocalServerResponse { if let body { - self.logger.trace( + logger.trace( "Processing request", metadata: ["URI": "\(head.method) \(head.uri)", "Body": "\(String(buffer: body))"] ) } else { - self.logger.trace("Processing request", metadata: ["URI": "\(head.method) \(head.uri)"]) + logger.trace("Processing request", metadata: ["URI": "\(head.method) \(head.uri)"]) } switch (head.method, head.uri) { @@ -237,7 +301,9 @@ private struct LambdaHttpServer { } // we always accept the /invoke request and push them to the pool let requestId = "\(DispatchTime.now().uptimeNanoseconds)" - logger.trace("/invoke received invocation", metadata: ["requestId": "\(requestId)"]) + var logger = logger + logger[metadataKey: "requestID"] = "\(requestId)" + logger.trace("/invoke received invocation") await self.invocationPool.push(LocalServerInvocation(requestId: requestId, request: body)) // wait for the lambda function to process the request @@ -273,9 +339,9 @@ private struct LambdaHttpServer { case (.GET, let url) where url.hasSuffix(Consts.getNextInvocationURLSuffix): // pop the tasks from the queue - self.logger.trace("/next waiting for /invoke") + logger.trace("/next waiting for /invoke") for try await invocation in self.invocationPool { - self.logger.trace("/next retrieved invocation", metadata: ["requestId": "\(invocation.requestId)"]) + logger.trace("/next retrieved invocation", metadata: ["requestId": "\(invocation.requestId)"]) // this call also stores the invocation requestId into the response return invocation.makeResponse(status: .accepted) } @@ -322,12 +388,13 @@ private struct LambdaHttpServer { private func sendResponse( response: LocalServerResponse, - outbound: NIOAsyncChannelOutboundWriter + outbound: NIOAsyncChannelOutboundWriter, + logger: Logger ) async throws { var headers = HTTPHeaders(response.headers ?? []) headers.add(name: "Content-Length", value: "\(response.body?.readableBytes ?? 0)") - self.logger.trace("Writing response", metadata: ["requestId": "\(response.requestId ?? "")"]) + logger.trace("Writing response", metadata: ["requestId": "\(response.requestId ?? "")"]) try await outbound.write( HTTPServerResponsePart.head( HTTPResponseHead( @@ -350,44 +417,59 @@ private struct LambdaHttpServer { private final class Pool: AsyncSequence, AsyncIteratorProtocol, Sendable where T: Sendable { typealias Element = T - private let _buffer = Mutex>(.init()) - private let _continuation = Mutex?>(nil) - - /// retrieve the first element from the buffer - public func popFirst() async -> T? { - self._buffer.withLock { $0.popFirst() } + enum State: ~Copyable { + case buffer(Deque) + case continuation(CheckedContinuation?) } + private let lock = Mutex(.buffer([])) + /// enqueue an element, or give it back immediately to the iterator if it is waiting for an element public func push(_ invocation: T) async { // if the iterator is waiting for an element, give it to it // otherwise, enqueue the element - if let continuation = self._continuation.withLock({ $0 }) { - self._continuation.withLock { $0 = nil } - continuation.resume(returning: invocation) - } else { - self._buffer.withLock { $0.append(invocation) } + let maybeContinuation = self.lock.withLock { state -> CheckedContinuation? in + switch consume state { + case .continuation(let continuation): + state = .buffer([]) + return continuation + + case .buffer(var buffer): + buffer.append(invocation) + state = .buffer(buffer) + return nil + } } + + maybeContinuation?.resume(returning: invocation) } func next() async throws -> T? { - // exit the async for loop if the task is cancelled guard !Task.isCancelled else { return nil } - if let element = await self.popFirst() { - return element - } else { - // we can't return nil if there is nothing to dequeue otherwise the async for loop will stop - // wait for an element to be enqueued - return try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in - // store the continuation for later, when an element is enqueued - self._continuation.withLock { - $0 = continuation + return try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + let nextAction = self.lock.withLock { state -> T? in + switch consume state { + case .buffer(var buffer): + if let first = buffer.popFirst() { + state = .buffer(buffer) + return first + } else { + state = .continuation(continuation) + return nil + } + + case .continuation: + fatalError("Concurrent invocations to next(). This is illegal.") } } + + guard let nextAction else { return } + + continuation.resume(returning: nextAction) } } @@ -432,3 +514,14 @@ private struct LambdaHttpServer { } } #endif + +extension Result { + var maybeError: Failure? { + switch self { + case .success: + return nil + case .failure(let error): + return error + } + } +}