Skip to content

Fixes for Local Lambda Server #486

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Feb 27, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
235 changes: 156 additions & 79 deletions Sources/AWSLambdaRuntimeCore/Lambda+LocalServer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
//===----------------------------------------------------------------------===//

#if DEBUG
import DequeModule
import Dispatch
import Logging
import NIOConcurrencyHelpers
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we still need this import now that we don't use CircularBuffer anymore ?

Expand Down Expand Up @@ -47,24 +48,15 @@ extension Lambda {
/// - note: This API is designed strictly for local testing and is behind a DEBUG flag
static func withLocalServer(
invocationEndpoint: String? = nil,
_ body: @escaping () async throws -> Void
_ body: sending @escaping () async throws -> Void
) async throws {
var logger = Logger(label: "LocalServer")
logger.logLevel = Lambda.env("LOG_LEVEL").flatMap(Logger.Level.init) ?? .info

// launch the local server and wait for it to be started before running the body
try await withThrowingTaskGroup(of: Void.self) { group in
// this call will return when the server calls continuation.resume()
try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<Void, any Error>) in
group.addTask {
do {
try await LambdaHttpServer(invocationEndpoint: invocationEndpoint).start(
continuation: continuation
)
} catch {
continuation.resume(throwing: error)
}
}
}
// now that server is started, run the Lambda function itself
try await LambdaHTTPServer.withLocalServer(
invocationEndpoint: invocationEndpoint,
logger: logger
) {
try await body()
}
}
Expand All @@ -84,34 +76,46 @@ extension Lambda {
/// 1. POST /invoke - the client posts the event to the lambda function
///
/// This server passes the data received from /invoke POST request to the lambda function (GET /next) and then forwards the response back to the client.
private struct LambdaHttpServer {
private let logger: Logger
private let group: EventLoopGroup
private let host: String
private let port: Int
private struct LambdaHTTPServer {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we update the name of this struct ? This is not the HTTP Server anymore (all the NIO bootstrap and channel creation is done in the withLocalServer static function)
This struct now just provides the handleConnection() function

But I can't think about a descriptive name :-)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lets keep it for now. We can always change it later. It's internal.

private let invocationEndpoint: String

private let invocationPool = Pool<LocalServerInvocation>()
private let responsePool = Pool<LocalServerResponse>()

init(invocationEndpoint: String?) {
var logger = Logger(label: "LocalServer")
logger.logLevel = Lambda.env("LOG_LEVEL").flatMap(Logger.Level.init) ?? .info
self.logger = logger
self.group = MultiThreadedEventLoopGroup.singleton
self.host = "127.0.0.1"
self.port = 7000
private init(
invocationEndpoint: String?
) {
self.invocationEndpoint = invocationEndpoint ?? "/invoke"
}

func start(continuation: CheckedContinuation<Void, any Error>) async throws {
let channel = try await ServerBootstrap(group: self.group)
private enum TaskResult<Result: Sendable>: Sendable {
case closureResult(Swift.Result<Result, any Error>)
case serverReturned(Swift.Result<Void, any Error>)
}

struct UnsafeTransferBox<Value>: @unchecked Sendable {
let value: Value

init(value: sending Value) {
self.value = value
}
}

static func withLocalServer<Result: Sendable>(
invocationEndpoint: String?,
host: String = "127.0.0.1",
port: Int = 7000,
eventLoopGroup: MultiThreadedEventLoopGroup = .singleton,
logger: Logger,
_ closure: sending @escaping () async throws -> Result
) async throws -> Result {
let channel = try await ServerBootstrap(group: eventLoopGroup)
.serverChannelOption(.backlog, value: 256)
.serverChannelOption(.socketOption(.so_reuseaddr), value: 1)
.childChannelOption(.maxMessagesPerRead, value: 1)
.bind(
host: self.host,
port: self.port
host: host,
port: port
) { channel in
channel.eventLoop.makeCompletedFuture {

Expand All @@ -129,8 +133,6 @@ private struct LambdaHttpServer {
}
}

// notify the caller that the server is started
continuation.resume()
logger.info(
"Server started and listening",
metadata: [
Expand All @@ -139,30 +141,77 @@ private struct LambdaHttpServer {
]
)

// We are handling each incoming connection in a separate child task. It is important
// to use a discarding task group here which automatically discards finished child tasks.
// A normal task group retains all child tasks and their outputs in memory until they are
// consumed by iterating the group or by exiting the group. Since, we are never consuming
// the results of the group we need the group to automatically discard them; otherwise, this
// would result in a memory leak over time.
try await withThrowingDiscardingTaskGroup { group in
try await channel.executeThenClose { inbound in
for try await connectionChannel in inbound {

group.addTask {
logger.trace("Handling a new connection")
await self.handleConnection(channel: connectionChannel)
logger.trace("Done handling the connection")
let server = LambdaHTTPServer(invocationEndpoint: invocationEndpoint)

// Sadly the Swift compiler does not understand that the passed in closure will only be
// invoked once. Because of this we need an unsafe transfer box here. Buuuh!
let closureBox = UnsafeTransferBox(value: closure)
let result = await withTaskGroup(of: TaskResult<Result>.self, returning: Swift.Result<Result, any Error>.self) { group in
group.addTask {
let c = closureBox.value
do {
let result = try await c()
return .closureResult(.success(result))
} catch {
return .closureResult(.failure(error))
}
}

group.addTask {
do {
// We are handling each incoming connection in a separate child task. It is important
// to use a discarding task group here which automatically discards finished child tasks.
// A normal task group retains all child tasks and their outputs in memory until they are
// consumed by iterating the group or by exiting the group. Since, we are never consuming
// the results of the group we need the group to automatically discard them; otherwise, this
// would result in a memory leak over time.
try await withThrowingDiscardingTaskGroup { taskGroup in
try await channel.executeThenClose { inbound in
for try await connectionChannel in inbound {

taskGroup.addTask {
logger.trace("Handling a new connection")
await server.handleConnection(channel: connectionChannel, logger: logger)
logger.trace("Done handling the connection")
}
}
}
}
return .serverReturned(.success(()))
} catch {
return .serverReturned(.failure(error))
}
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think 2 lines of comments to explain the logic below wouldn't hurt readability :-)

// Now that the local HTTP server and LambdaHandler tasks are started, wait for the first of the two that will terminate.
// When the first task terminates, cancel the group and collect the result of the second task

let task1 = await group.next()!
group.cancelAll()
let task2 = await group.next()!

switch task1 {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

// collect and return the result of the LambdaHandler 

case .closureResult(let result):
return result

case .serverReturned:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we log.error() the fact that the server terminated before the Handler? It might be an error in the server implementation or an otherwise important information to show to the user

switch task2 {
case .closureResult(let result):
return result

case .serverReturned:
fatalError()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we add a message to the fatalError : "Only one task is a server, and only one can return serverReturned"

Also what about renaming task1 and task2 with serverOrHandlerResult1 and serverOrHandlerResult2

}
}
}

logger.info("Server shutting down")
return try result.get()
}



/// This method handles individual TCP connections
private func handleConnection(
channel: NIOAsyncChannel<HTTPServerRequestPart, HTTPServerResponsePart>
channel: NIOAsyncChannel<HTTPServerRequestPart, HTTPServerResponsePart>,
logger: Logger
) async {

var requestHead: HTTPRequestHead!
Expand All @@ -186,12 +235,14 @@ private struct LambdaHttpServer {
// process the request
let response = try await self.processRequest(
head: requestHead,
body: requestBody
body: requestBody,
logger: logger
)
// send the responses
try await self.sendResponse(
response: response,
outbound: outbound
outbound: outbound,
logger: logger
)

requestHead = nil
Expand All @@ -214,15 +265,15 @@ private struct LambdaHttpServer {
/// - body: the HTTP request body
/// - Throws:
/// - Returns: the response to send back to the client or the Lambda function
private func processRequest(head: HTTPRequestHead, body: ByteBuffer?) async throws -> LocalServerResponse {
private func processRequest(head: HTTPRequestHead, body: ByteBuffer?, logger: Logger) async throws -> LocalServerResponse {

if let body {
self.logger.trace(
logger.trace(
"Processing request",
metadata: ["URI": "\(head.method) \(head.uri)", "Body": "\(String(buffer: body))"]
)
} else {
self.logger.trace("Processing request", metadata: ["URI": "\(head.method) \(head.uri)"])
logger.trace("Processing request", metadata: ["URI": "\(head.method) \(head.uri)"])
}

switch (head.method, head.uri) {
Expand All @@ -237,7 +288,9 @@ private struct LambdaHttpServer {
}
// we always accept the /invoke request and push them to the pool
let requestId = "\(DispatchTime.now().uptimeNanoseconds)"
logger.trace("/invoke received invocation", metadata: ["requestId": "\(requestId)"])
var logger = logger
logger[metadataKey: "requestID"] = "\(requestId)"
logger.trace("/invoke received invocation")
await self.invocationPool.push(LocalServerInvocation(requestId: requestId, request: body))

// wait for the lambda function to process the request
Expand Down Expand Up @@ -273,9 +326,9 @@ private struct LambdaHttpServer {
case (.GET, let url) where url.hasSuffix(Consts.getNextInvocationURLSuffix):

// pop the tasks from the queue
self.logger.trace("/next waiting for /invoke")
logger.trace("/next waiting for /invoke")
for try await invocation in self.invocationPool {
self.logger.trace("/next retrieved invocation", metadata: ["requestId": "\(invocation.requestId)"])
logger.trace("/next retrieved invocation", metadata: ["requestId": "\(invocation.requestId)"])
// this call also stores the invocation requestId into the response
return invocation.makeResponse(status: .accepted)
}
Expand Down Expand Up @@ -322,12 +375,13 @@ private struct LambdaHttpServer {

private func sendResponse(
response: LocalServerResponse,
outbound: NIOAsyncChannelOutboundWriter<HTTPServerResponsePart>
outbound: NIOAsyncChannelOutboundWriter<HTTPServerResponsePart>,
logger: Logger
) async throws {
var headers = HTTPHeaders(response.headers ?? [])
headers.add(name: "Content-Length", value: "\(response.body?.readableBytes ?? 0)")

self.logger.trace("Writing response", metadata: ["requestId": "\(response.requestId ?? "")"])
logger.trace("Writing response", metadata: ["requestId": "\(response.requestId ?? "")"])
try await outbound.write(
HTTPServerResponsePart.head(
HTTPResponseHead(
Expand All @@ -350,44 +404,67 @@ private struct LambdaHttpServer {
private final class Pool<T>: AsyncSequence, AsyncIteratorProtocol, Sendable where T: Sendable {
typealias Element = T

private let _buffer = Mutex<CircularBuffer<T>>(.init())
private let _continuation = Mutex<CheckedContinuation<T, any Error>?>(nil)
struct State {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we really need a struct here ? Can't we simplify and just define the enum ?

Mutex's T must be Copyable, therefore passing an enum is accepted.

        enum Test {
            case A
            case B
        }

        private let m = Mutex<Test>(.A)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reduced it down to an enum but marked the enum as ~Copyable to ensure we don't allocate when we hold the lock.

enum State {
case buffer(Deque<T>)
case continuation(CheckedContinuation<T, any Error>?)
}

/// retrieve the first element from the buffer
public func popFirst() async -> T? {
self._buffer.withLock { $0.popFirst() }
var state: State

init() {
self.state = .buffer([])
}
}

private let lock = Mutex<State>(.init())

/// enqueue an element, or give it back immediately to the iterator if it is waiting for an element
public func push(_ invocation: T) async {
// if the iterator is waiting for an element, give it to it
// otherwise, enqueue the element
if let continuation = self._continuation.withLock({ $0 }) {
self._continuation.withLock { $0 = nil }
continuation.resume(returning: invocation)
} else {
self._buffer.withLock { $0.append(invocation) }
let maybeContinuation = self.lock.withLock { state -> CheckedContinuation<T, any Error>? in
switch state.state {
case .continuation(let continuation):
state.state = .buffer([])
return continuation

case .buffer(var buffer):
buffer.append(invocation)
state.state = .buffer(buffer)
return nil
}
}

maybeContinuation?.resume(returning: invocation)
}

func next() async throws -> T? {

// exit the async for loop if the task is cancelled
guard !Task.isCancelled else {
return nil
}

if let element = await self.popFirst() {
return element
} else {
// we can't return nil if there is nothing to dequeue otherwise the async for loop will stop
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would keep those two comments. They explain why this is not a regular async iterator. It blocks when the buffer is empty

// wait for an element to be enqueued
return try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<T, any Error>) in
// store the continuation for later, when an element is enqueued
self._continuation.withLock {
$0 = continuation
return try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<T, any Error>) in
let nextAction = self.lock.withLock { state -> T? in
switch state.state {
case .buffer(var buffer):
if let first = buffer.popFirst() {
state.state = .buffer(buffer)
return first
} else {
state.state = .continuation(continuation)
return nil
}

case .continuation:
fatalError("Concurrent invocations to next(). This is illigal.")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typo "illegal"

}
}

guard let nextAction else { return }

continuation.resume(returning: nextAction)
}
}

Expand Down
3 changes: 1 addition & 2 deletions Sources/AWSLambdaRuntimeCore/LambdaRuntime.swift
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,7 @@ public final class LambdaRuntime<Handler>: @unchecked Sendable where Handler: St
#if DEBUG
// we're not running on Lambda and we're compiled in DEBUG mode,
// let's start a local server for testing
try await Lambda.withLocalServer(invocationEndpoint: Lambda.env("LOCAL_LAMBDA_SERVER_INVOCATION_ENDPOINT"))
{
try await Lambda.withLocalServer(invocationEndpoint: Lambda.env("LOCAL_LAMBDA_SERVER_INVOCATION_ENDPOINT")) {

try await LambdaRuntimeClient.withRuntimeClient(
configuration: .init(ip: "127.0.0.1", port: 7000),
Expand Down
Loading