Fixes for Local Lambda Server #486
Changes from 2 commits
@@ -13,6 +13,7 @@
//===----------------------------------------------------------------------===//

#if DEBUG
import DequeModule
import Dispatch
import Logging
import NIOConcurrencyHelpers

@@ -47,24 +48,15 @@ extension Lambda {
/// - note: This API is designed strictly for local testing and is behind a DEBUG flag
static func withLocalServer(
invocationEndpoint: String? = nil,
_ body: @escaping () async throws -> Void
_ body: sending @escaping () async throws -> Void
) async throws {
var logger = Logger(label: "LocalServer")
logger.logLevel = Lambda.env("LOG_LEVEL").flatMap(Logger.Level.init) ?? .info

// launch the local server and wait for it to be started before running the body
try await withThrowingTaskGroup(of: Void.self) { group in
// this call will return when the server calls continuation.resume()
try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<Void, any Error>) in
group.addTask {
do {
try await LambdaHttpServer(invocationEndpoint: invocationEndpoint).start(
continuation: continuation
)
} catch {
continuation.resume(throwing: error)
}
}
}
// now that server is started, run the Lambda function itself
try await LambdaHTTPServer.withLocalServer(
invocationEndpoint: invocationEndpoint,
logger: logger
) {
try await body()
}
}

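The hunk above replaces the old continuation-based startup (start the server, resume a continuation, then run the body) with a single scoped call to `LambdaHTTPServer.withLocalServer` that owns the server for the lifetime of the body. Below is a minimal, self-contained sketch of that `with...`-style scoping pattern; the names `FakeServer` and `withFakeServer` are illustrative only and not part of the PR.

```swift
// Sketch only: a with-style function owns a resource for exactly the lifetime
// of the closure it is given, instead of signalling readiness via a continuation.
final class FakeServer: Sendable {
    func start() async { print("server started") }
    func stop() { print("server stopped") }
}

func withFakeServer<Result: Sendable>(
    _ body: @escaping @Sendable () async throws -> Result
) async throws -> Result {
    let server = FakeServer()
    await server.start()
    defer { server.stop() }     // always released, even if `body` throws
    return try await body()     // run the caller's work while the server is up
}
```

Tying the server's lifetime to the closure removes the need to signal readiness through a `CheckedContinuation` and guarantees shutdown even when the body throws.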
@@ -84,34 +76,46 @@ extension Lambda {
/// 1. POST /invoke - the client posts the event to the lambda function
///
/// This server passes the data received from /invoke POST request to the lambda function (GET /next) and then forwards the response back to the client.
private struct LambdaHttpServer {
private let logger: Logger
private let group: EventLoopGroup
private let host: String
private let port: Int
private struct LambdaHTTPServer {
Comment: Should we update the name of this struct? This is not the HTTP server anymore (all the NIO bootstrap and channel creation is done in the static `withLocalServer` function). But I can't think of a descriptive name :-)

Reply: Let's keep it for now. We can always change it later. It's internal.
private let invocationEndpoint: String

private let invocationPool = Pool<LocalServerInvocation>()
private let responsePool = Pool<LocalServerResponse>()

init(invocationEndpoint: String?) {
var logger = Logger(label: "LocalServer")
logger.logLevel = Lambda.env("LOG_LEVEL").flatMap(Logger.Level.init) ?? .info
self.logger = logger
self.group = MultiThreadedEventLoopGroup.singleton
self.host = "127.0.0.1"
self.port = 7000
private init(
invocationEndpoint: String?
) {
self.invocationEndpoint = invocationEndpoint ?? "/invoke"
}

func start(continuation: CheckedContinuation<Void, any Error>) async throws {
let channel = try await ServerBootstrap(group: self.group)
private enum TaskResult<Result: Sendable>: Sendable {
case closureResult(Swift.Result<Result, any Error>)
case serverReturned(Swift.Result<Void, any Error>)
}

struct UnsafeTransferBox<Value>: @unchecked Sendable {
let value: Value

init(value: sending Value) {
self.value = value
}
}

static func withLocalServer<Result: Sendable>(
invocationEndpoint: String?,
host: String = "127.0.0.1",
port: Int = 7000,
eventLoopGroup: MultiThreadedEventLoopGroup = .singleton,
logger: Logger,
_ closure: sending @escaping () async throws -> Result
) async throws -> Result {
let channel = try await ServerBootstrap(group: eventLoopGroup)
.serverChannelOption(.backlog, value: 256)
.serverChannelOption(.socketOption(.so_reuseaddr), value: 1)
.childChannelOption(.maxMessagesPerRead, value: 1)
.bind(
host: self.host,
port: self.port
host: host,
port: port
) { channel in
channel.eventLoop.makeCompletedFuture {

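The `UnsafeTransferBox` added above exists to move the `sending` closure into a child task when the compiler cannot prove on its own that the transfer is safe. Here is a hedged, self-contained sketch of the same pattern; `TransferBox` and `runOnce` are illustrative names, not code from the PR, and the caller is responsible for consuming the boxed value exactly once.

```swift
// Sketch: a @unchecked Sendable wrapper lets a value received as `sending`
// cross into a child task. Safety is the caller's responsibility.
struct TransferBox<Value>: @unchecked Sendable {
    let value: Value
    init(value: sending Value) { self.value = value }
}

func runOnce<R: Sendable>(
    _ work: sending @escaping () async throws -> R
) async throws -> R {
    let box = TransferBox(value: work)      // move the closure into the box
    return try await withThrowingTaskGroup(of: R.self) { group in
        group.addTask {
            let work = box.value            // take it back out exactly once
            return try await work()
        }
        return try await group.next()!
    }
}
```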
@@ -129,8 +133,6 @@ private struct LambdaHttpServer {
}
}

// notify the caller that the server is started
continuation.resume()
logger.info(
"Server started and listening",
metadata: [

@@ -139,30 +141,77 @@ private struct LambdaHttpServer {
]
)

// We are handling each incoming connection in a separate child task. It is important
// to use a discarding task group here which automatically discards finished child tasks.
// A normal task group retains all child tasks and their outputs in memory until they are
// consumed by iterating the group or by exiting the group. Since, we are never consuming
// the results of the group we need the group to automatically discard them; otherwise, this
// would result in a memory leak over time.
try await withThrowingDiscardingTaskGroup { group in
try await channel.executeThenClose { inbound in
for try await connectionChannel in inbound {

group.addTask {
logger.trace("Handling a new connection")
await self.handleConnection(channel: connectionChannel)
logger.trace("Done handling the connection")
let server = LambdaHTTPServer(invocationEndpoint: invocationEndpoint)

// Sadly the Swift compiler does not understand that the passed in closure will only be
// invoked once. Because of this we need an unsafe transfer box here. Buuuh!
let closureBox = UnsafeTransferBox(value: closure)
let result = await withTaskGroup(of: TaskResult<Result>.self, returning: Swift.Result<Result, any Error>.self) { group in
group.addTask {
let c = closureBox.value
do {
let result = try await c()
return .closureResult(.success(result))
} catch {
return .closureResult(.failure(error))
}
}

group.addTask {
do {
// We are handling each incoming connection in a separate child task. It is important
// to use a discarding task group here which automatically discards finished child tasks.
// A normal task group retains all child tasks and their outputs in memory until they are
// consumed by iterating the group or by exiting the group. Since, we are never consuming
// the results of the group we need the group to automatically discard them; otherwise, this
// would result in a memory leak over time.
try await withThrowingDiscardingTaskGroup { taskGroup in
try await channel.executeThenClose { inbound in
for try await connectionChannel in inbound {

taskGroup.addTask {
logger.trace("Handling a new connection")
await server.handleConnection(channel: connectionChannel, logger: logger)
logger.trace("Done handling the connection")
}
}
}
}
return .serverReturned(.success(()))
} catch {
return .serverReturned(.failure(error))
}
}

Comment: I think 2 lines of comments to explain the logic below wouldn't hurt readability :-)

let task1 = await group.next()!
group.cancelAll()
let task2 = await group.next()!

switch task1 {
case .closureResult(let result):
return result

case .serverReturned:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should we |
||
switch task2 {
case .closureResult(let result):
return result

case .serverReturned:
fatalError()
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we add a message to the Also what about renaming |
||
}
}
}

logger.info("Server shutting down")
return try result.get()
}

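A reviewer asked for a couple of comment lines on the `group.next()` / `cancelAll()` / `group.next()` sequence above. As a hedged illustration of that pattern, the self-contained sketch below races two child tasks, takes whichever finishes first, cancels the other, and then drains it so neither result is lost; `Outcome` and the sleeping work items are placeholders, not code from the PR.

```swift
// Sketch: race two tasks, keep the first result, cancel and drain the other.
enum Outcome: Sendable {
    case closureFinished(Int)
    case serverFinished
}

func raceClosureAgainstServer() async -> Int? {
    await withTaskGroup(of: Outcome.self, returning: Int?.self) { group in
        group.addTask {
            try? await Task.sleep(for: .milliseconds(10))   // stands in for the user's closure
            return .closureFinished(42)
        }
        group.addTask {
            try? await Task.sleep(for: .seconds(60))        // stands in for the server loop
            return .serverFinished
        }

        let first = await group.next()!    // whichever task completed first
        group.cancelAll()                  // tell the other task to stop
        let second = await group.next()!   // drain it so nothing outlives the group

        // Prefer the closure's result no matter which task won the race.
        switch (first, second) {
        case (.closureFinished(let value), _), (_, .closureFinished(let value)):
            return value
        default:
            return nil
        }
    }
}
```

Draining the second task before leaving the group is what lets the closure's result (or error) win even when the server task happens to return first, which appears to be the intent of the diff above.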
/// This method handles individual TCP connections
private func handleConnection(
channel: NIOAsyncChannel<HTTPServerRequestPart, HTTPServerResponsePart>
channel: NIOAsyncChannel<HTTPServerRequestPart, HTTPServerResponsePart>,
logger: Logger
) async {

var requestHead: HTTPRequestHead!

@@ -186,12 +235,14 @@ private struct LambdaHttpServer {
// process the request
let response = try await self.processRequest(
head: requestHead,
body: requestBody
body: requestBody,
logger: logger
)
// send the responses
try await self.sendResponse(
response: response,
outbound: outbound
outbound: outbound,
logger: logger
)

requestHead = nil

@@ -214,15 +265,15 @@ private struct LambdaHttpServer {
/// - body: the HTTP request body
/// - Throws:
/// - Returns: the response to send back to the client or the Lambda function
private func processRequest(head: HTTPRequestHead, body: ByteBuffer?) async throws -> LocalServerResponse {
private func processRequest(head: HTTPRequestHead, body: ByteBuffer?, logger: Logger) async throws -> LocalServerResponse {

if let body {
self.logger.trace(
logger.trace(
"Processing request",
metadata: ["URI": "\(head.method) \(head.uri)", "Body": "\(String(buffer: body))"]
)
} else {
self.logger.trace("Processing request", metadata: ["URI": "\(head.method) \(head.uri)"])
logger.trace("Processing request", metadata: ["URI": "\(head.method) \(head.uri)"])
}

switch (head.method, head.uri) {

@@ -237,7 +288,9 @@ private struct LambdaHttpServer {
}
// we always accept the /invoke request and push them to the pool
let requestId = "\(DispatchTime.now().uptimeNanoseconds)"
logger.trace("/invoke received invocation", metadata: ["requestId": "\(requestId)"])
var logger = logger
logger[metadataKey: "requestID"] = "\(requestId)"
logger.trace("/invoke received invocation")
await self.invocationPool.push(LocalServerInvocation(requestId: requestId, request: body))

// wait for the lambda function to process the request

@@ -273,9 +326,9 @@ private struct LambdaHttpServer {
case (.GET, let url) where url.hasSuffix(Consts.getNextInvocationURLSuffix):

// pop the tasks from the queue
self.logger.trace("/next waiting for /invoke")
logger.trace("/next waiting for /invoke")
for try await invocation in self.invocationPool {
self.logger.trace("/next retrieved invocation", metadata: ["requestId": "\(invocation.requestId)"])
logger.trace("/next retrieved invocation", metadata: ["requestId": "\(invocation.requestId)"])
// this call also stores the invocation requestId into the response
return invocation.makeResponse(status: .accepted)
}

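The two hunks above are the two halves of a handshake: POST /invoke pushes the event into `invocationPool` and waits for the reply, while GET /next pops from `invocationPool`. As a rough, self-contained sketch of that flow, the example below uses `AsyncStream` as a stand-in for the PR's `Pool` type; all names are illustrative.

```swift
// Sketch of the /invoke <-> /next round trip, with AsyncStream standing in for Pool.
func simulateLocalServerRoundTrip() async {
    let (invocations, invocationWriter) = AsyncStream.makeStream(of: String.self)
    let (responses, responseWriter) = AsyncStream.makeStream(of: String.self)

    await withTaskGroup(of: Void.self) { group in
        // The Lambda runtime side: GET /next, run the handler, POST the response.
        group.addTask {
            for await event in invocations {
                responseWriter.yield("echo: \(event)")   // pretend handler
            }
        }

        // The client side: POST /invoke, then wait for the forwarded response.
        group.addTask {
            invocationWriter.yield(#"{"name": "world"}"#)
            var iterator = responses.makeAsyncIterator()
            if let reply = await iterator.next() {
                print("client received:", reply)
            }
            invocationWriter.finish()   // let the runtime loop end
        }
    }
}
```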
@@ -322,12 +375,13 @@ private struct LambdaHttpServer {

private func sendResponse(
response: LocalServerResponse,
outbound: NIOAsyncChannelOutboundWriter<HTTPServerResponsePart>
outbound: NIOAsyncChannelOutboundWriter<HTTPServerResponsePart>,
logger: Logger
) async throws {
var headers = HTTPHeaders(response.headers ?? [])
headers.add(name: "Content-Length", value: "\(response.body?.readableBytes ?? 0)")

self.logger.trace("Writing response", metadata: ["requestId": "\(response.requestId ?? "")"])
logger.trace("Writing response", metadata: ["requestId": "\(response.requestId ?? "")"])
try await outbound.write(
HTTPServerResponsePart.head(
HTTPResponseHead(

@@ -350,44 +404,67 @@ private struct LambdaHttpServer {
private final class Pool<T>: AsyncSequence, AsyncIteratorProtocol, Sendable where T: Sendable {
typealias Element = T

private let _buffer = Mutex<CircularBuffer<T>>(.init())
private let _continuation = Mutex<CheckedContinuation<T, any Error>?>(nil)
struct State {
Comment: Do we really need a struct here? Can't we simplify and just define the …? Mutex's T must be …

Reply: Reduced it down to an …
enum State {
case buffer(Deque<T>)
case continuation(CheckedContinuation<T, any Error>?)
}

/// retrieve the first element from the buffer
public func popFirst() async -> T? {
self._buffer.withLock { $0.popFirst() }
var state: State

init() {
self.state = .buffer([])
}
}

private let lock = Mutex<State>(.init())

/// enqueue an element, or give it back immediately to the iterator if it is waiting for an element
public func push(_ invocation: T) async {
// if the iterator is waiting for an element, give it to it
// otherwise, enqueue the element
if let continuation = self._continuation.withLock({ $0 }) {
self._continuation.withLock { $0 = nil }
continuation.resume(returning: invocation)
} else {
self._buffer.withLock { $0.append(invocation) }
let maybeContinuation = self.lock.withLock { state -> CheckedContinuation<T, any Error>? in
switch state.state {
case .continuation(let continuation):
state.state = .buffer([])
return continuation

case .buffer(var buffer):
buffer.append(invocation)
state.state = .buffer(buffer)
return nil
}
}

maybeContinuation?.resume(returning: invocation)
}

func next() async throws -> T? {

// exit the async for loop if the task is cancelled
guard !Task.isCancelled else {
return nil
}

if let element = await self.popFirst() {
return element
} else {
// we can't return nil if there is nothing to dequeue otherwise the async for loop will stop
Comment: I would keep those two comments. They explain why this is not a regular async iterator: it blocks when the buffer is empty.
// wait for an element to be enqueued
return try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<T, any Error>) in
// store the continuation for later, when an element is enqueued
self._continuation.withLock {
$0 = continuation
return try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<T, any Error>) in
let nextAction = self.lock.withLock { state -> T? in
switch state.state {
case .buffer(var buffer):
if let first = buffer.popFirst() {
state.state = .buffer(buffer)
return first
} else {
state.state = .continuation(continuation)
return nil
}

case .continuation:
fatalError("Concurrent invocations to next(). This is illigal.")
Comment: Typo: "illegal".
}
}

guard let nextAction else { return }

continuation.resume(returning: nextAction)
}
}

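For readers outside the diff, here is a simplified, hedged sketch of the state machine the new `Pool` implements: `next()` either takes a buffered element or parks its continuation, and `push()` either buffers the element or resumes the parked waiter. It assumes a single consumer, uses a plain array instead of `Deque`, omits the cancellation check, and uses the Swift 6 `Synchronization.Mutex` (an assumption); `MiniPool` is an illustrative name, not the PR's type.

```swift
import Synchronization   // Mutex; assumed here, the PR may source its lock type differently

// Sketch: buffer-or-parked-continuation state machine behind an async pull.
final class MiniPool<T: Sendable>: Sendable {
    private enum State {
        case buffer([T])
        case continuation(CheckedContinuation<T, Never>)
    }
    private let state = Mutex<State>(.buffer([]))

    func push(_ element: T) {
        let waiter = self.state.withLock { state -> CheckedContinuation<T, Never>? in
            switch state {
            case .continuation(let continuation):
                state = .buffer([])          // the waiter consumes this element directly
                return continuation
            case .buffer(var buffer):
                buffer.append(element)       // nobody waiting: keep it for later
                state = .buffer(buffer)
                return nil
            }
        }
        waiter?.resume(returning: element)
    }

    func next() async -> T {
        await withCheckedContinuation { continuation in
            let buffered = self.state.withLock { state -> T? in
                switch state {
                case .buffer(var buffer) where !buffer.isEmpty:
                    let first = buffer.removeFirst()
                    state = .buffer(buffer)
                    return first
                case .buffer:
                    state = .continuation(continuation)   // park until push() arrives
                    return nil
                case .continuation:
                    fatalError("only one consumer may call next() at a time")
                }
            }
            if let buffered { continuation.resume(returning: buffered) }
        }
    }
}
```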
Comment: Do we still need this import now that we don't use CircularBuffer anymore?