Skip to content

Commit 2ecddcf

Browse files
committed
Fixes to Local Lambda Server
1 parent 64b4179 commit 2ecddcf

File tree

2 files changed

+142
-75
lines changed

2 files changed

+142
-75
lines changed

Sources/AWSLambdaRuntimeCore/Lambda+LocalServer.swift

+141-73
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
//===----------------------------------------------------------------------===//
1414

1515
#if DEBUG
16+
import DequeModule
1617
import Dispatch
1718
import Logging
1819
import NIOConcurrencyHelpers
@@ -47,24 +48,15 @@ extension Lambda {
4748
/// - note: This API is designed strictly for local testing and is behind a DEBUG flag
4849
static func withLocalServer(
4950
invocationEndpoint: String? = nil,
50-
_ body: @escaping () async throws -> Void
51+
_ body: sending @escaping () async throws -> Void
5152
) async throws {
53+
var logger = Logger(label: "LocalServer")
54+
logger.logLevel = Lambda.env("LOG_LEVEL").flatMap(Logger.Level.init) ?? .info
5255

53-
// launch the local server and wait for it to be started before running the body
54-
try await withThrowingTaskGroup(of: Void.self) { group in
55-
// this call will return when the server calls continuation.resume()
56-
try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<Void, any Error>) in
57-
group.addTask {
58-
do {
59-
try await LambdaHttpServer(invocationEndpoint: invocationEndpoint).start(
60-
continuation: continuation
61-
)
62-
} catch {
63-
continuation.resume(throwing: error)
64-
}
65-
}
66-
}
67-
// now that server is started, run the Lambda function itself
56+
try await LambdaHTTPServer.withLocalServer(
57+
invocationEndpoint: invocationEndpoint,
58+
logger: logger
59+
) {
6860
try await body()
6961
}
7062
}
@@ -84,34 +76,38 @@ extension Lambda {
8476
/// 1. POST /invoke - the client posts the event to the lambda function
8577
///
8678
/// This server passes the data received from /invoke POST request to the lambda function (GET /next) and then forwards the response back to the client.
87-
private struct LambdaHttpServer {
88-
private let logger: Logger
89-
private let group: EventLoopGroup
90-
private let host: String
91-
private let port: Int
79+
private struct LambdaHTTPServer {
9280
private let invocationEndpoint: String
9381

9482
private let invocationPool = Pool<LocalServerInvocation>()
9583
private let responsePool = Pool<LocalServerResponse>()
9684

97-
init(invocationEndpoint: String?) {
98-
var logger = Logger(label: "LocalServer")
99-
logger.logLevel = Lambda.env("LOG_LEVEL").flatMap(Logger.Level.init) ?? .info
100-
self.logger = logger
101-
self.group = MultiThreadedEventLoopGroup.singleton
102-
self.host = "127.0.0.1"
103-
self.port = 7000
85+
private init(
86+
invocationEndpoint: String?
87+
) {
10488
self.invocationEndpoint = invocationEndpoint ?? "/invoke"
10589
}
10690

107-
func start(continuation: CheckedContinuation<Void, any Error>) async throws {
108-
let channel = try await ServerBootstrap(group: self.group)
91+
private enum TaskResult<Result: Sendable>: Sendable {
92+
case closureResult(Swift.Result<Result, any Error>)
93+
case serverReturned(Swift.Result<Void, any Error>)
94+
}
95+
96+
static func withLocalServer<Result: Sendable>(
97+
invocationEndpoint: String?,
98+
host: String = "127.0.0.1",
99+
port: Int = 7000,
100+
eventLoopGroup: MultiThreadedEventLoopGroup = .singleton,
101+
logger: Logger,
102+
_ closure: sending @escaping () async throws -> Result
103+
) async throws -> Result {
104+
let channel = try await ServerBootstrap(group: eventLoopGroup)
109105
.serverChannelOption(.backlog, value: 256)
110106
.serverChannelOption(.socketOption(.so_reuseaddr), value: 1)
111107
.childChannelOption(.maxMessagesPerRead, value: 1)
112108
.bind(
113-
host: self.host,
114-
port: self.port
109+
host: host,
110+
port: port
115111
) { channel in
116112
channel.eventLoop.makeCompletedFuture {
117113

@@ -129,8 +125,6 @@ private struct LambdaHttpServer {
129125
}
130126
}
131127

132-
// notify the caller that the server is started
133-
continuation.resume()
134128
logger.info(
135129
"Server started and listening",
136130
metadata: [
@@ -139,30 +133,76 @@ private struct LambdaHttpServer {
139133
]
140134
)
141135

136+
let server = LambdaHTTPServer(invocationEndpoint: invocationEndpoint)
137+
142138
// We are handling each incoming connection in a separate child task. It is important
143139
// to use a discarding task group here which automatically discards finished child tasks.
144140
// A normal task group retains all child tasks and their outputs in memory until they are
145141
// consumed by iterating the group or by exiting the group. Since, we are never consuming
146142
// the results of the group we need the group to automatically discard them; otherwise, this
147143
// would result in a memory leak over time.
148-
try await withThrowingDiscardingTaskGroup { group in
149-
try await channel.executeThenClose { inbound in
150-
for try await connectionChannel in inbound {
151-
152-
group.addTask {
153-
logger.trace("Handling a new connection")
154-
await self.handleConnection(channel: connectionChannel)
155-
logger.trace("Done handling the connection")
144+
let result = await withTaskGroup(of: TaskResult<Result>.self, returning: Swift.Result<Result, any Error>.self) { group in
145+
146+
let c = closure
147+
group.addTask {
148+
do {
149+
150+
let result = try await c()
151+
return .closureResult(.success(result))
152+
} catch {
153+
return .closureResult(.failure(error))
154+
}
155+
}
156+
157+
group.addTask {
158+
do {
159+
try await withThrowingDiscardingTaskGroup { taskGroup in
160+
try await channel.executeThenClose { inbound in
161+
for try await connectionChannel in inbound {
162+
163+
taskGroup.addTask {
164+
logger.trace("Handling a new connection")
165+
await server.handleConnection(channel: connectionChannel, logger: logger)
166+
logger.trace("Done handling the connection")
167+
}
168+
}
169+
}
156170
}
171+
return .serverReturned(.success(()))
172+
} catch {
173+
return .serverReturned(.failure(error))
174+
}
175+
}
176+
177+
let task1 = await group.next()!
178+
group.cancelAll()
179+
let task2 = await group.next()!
180+
181+
switch task1 {
182+
case .closureResult(let result):
183+
return result
184+
185+
case .serverReturned:
186+
switch task2 {
187+
case .closureResult(let result):
188+
return result
189+
190+
case .serverReturned:
191+
fatalError()
157192
}
158193
}
159194
}
195+
160196
logger.info("Server shutting down")
197+
return try result.get()
161198
}
162199

200+
201+
163202
/// This method handles individual TCP connections
164203
private func handleConnection(
165-
channel: NIOAsyncChannel<HTTPServerRequestPart, HTTPServerResponsePart>
204+
channel: NIOAsyncChannel<HTTPServerRequestPart, HTTPServerResponsePart>,
205+
logger: Logger
166206
) async {
167207

168208
var requestHead: HTTPRequestHead!
@@ -186,12 +226,14 @@ private struct LambdaHttpServer {
186226
// process the request
187227
let response = try await self.processRequest(
188228
head: requestHead,
189-
body: requestBody
229+
body: requestBody,
230+
logger: logger
190231
)
191232
// send the responses
192233
try await self.sendResponse(
193234
response: response,
194-
outbound: outbound
235+
outbound: outbound,
236+
logger: logger
195237
)
196238

197239
requestHead = nil
@@ -214,15 +256,15 @@ private struct LambdaHttpServer {
214256
/// - body: the HTTP request body
215257
/// - Throws:
216258
/// - Returns: the response to send back to the client or the Lambda function
217-
private func processRequest(head: HTTPRequestHead, body: ByteBuffer?) async throws -> LocalServerResponse {
259+
private func processRequest(head: HTTPRequestHead, body: ByteBuffer?, logger: Logger) async throws -> LocalServerResponse {
218260

219261
if let body {
220-
self.logger.trace(
262+
logger.trace(
221263
"Processing request",
222264
metadata: ["URI": "\(head.method) \(head.uri)", "Body": "\(String(buffer: body))"]
223265
)
224266
} else {
225-
self.logger.trace("Processing request", metadata: ["URI": "\(head.method) \(head.uri)"])
267+
logger.trace("Processing request", metadata: ["URI": "\(head.method) \(head.uri)"])
226268
}
227269

228270
switch (head.method, head.uri) {
@@ -237,7 +279,9 @@ private struct LambdaHttpServer {
237279
}
238280
// we always accept the /invoke request and push them to the pool
239281
let requestId = "\(DispatchTime.now().uptimeNanoseconds)"
240-
logger.trace("/invoke received invocation", metadata: ["requestId": "\(requestId)"])
282+
var logger = logger
283+
logger[metadataKey: "requestID"] = "\(requestId)"
284+
logger.trace("/invoke received invocation")
241285
await self.invocationPool.push(LocalServerInvocation(requestId: requestId, request: body))
242286

243287
// wait for the lambda function to process the request
@@ -273,9 +317,9 @@ private struct LambdaHttpServer {
273317
case (.GET, let url) where url.hasSuffix(Consts.getNextInvocationURLSuffix):
274318

275319
// pop the tasks from the queue
276-
self.logger.trace("/next waiting for /invoke")
320+
logger.trace("/next waiting for /invoke")
277321
for try await invocation in self.invocationPool {
278-
self.logger.trace("/next retrieved invocation", metadata: ["requestId": "\(invocation.requestId)"])
322+
logger.trace("/next retrieved invocation", metadata: ["requestId": "\(invocation.requestId)"])
279323
// this call also stores the invocation requestId into the response
280324
return invocation.makeResponse(status: .accepted)
281325
}
@@ -322,12 +366,13 @@ private struct LambdaHttpServer {
322366

323367
private func sendResponse(
324368
response: LocalServerResponse,
325-
outbound: NIOAsyncChannelOutboundWriter<HTTPServerResponsePart>
369+
outbound: NIOAsyncChannelOutboundWriter<HTTPServerResponsePart>,
370+
logger: Logger
326371
) async throws {
327372
var headers = HTTPHeaders(response.headers ?? [])
328373
headers.add(name: "Content-Length", value: "\(response.body?.readableBytes ?? 0)")
329374

330-
self.logger.trace("Writing response", metadata: ["requestId": "\(response.requestId ?? "")"])
375+
logger.trace("Writing response", metadata: ["requestId": "\(response.requestId ?? "")"])
331376
try await outbound.write(
332377
HTTPServerResponsePart.head(
333378
HTTPResponseHead(
@@ -350,44 +395,67 @@ private struct LambdaHttpServer {
350395
private final class Pool<T>: AsyncSequence, AsyncIteratorProtocol, Sendable where T: Sendable {
351396
typealias Element = T
352397

353-
private let _buffer = Mutex<CircularBuffer<T>>(.init())
354-
private let _continuation = Mutex<CheckedContinuation<T, any Error>?>(nil)
398+
struct State {
399+
enum State {
400+
case buffer(Deque<T>)
401+
case continuation(CheckedContinuation<T, any Error>?)
402+
}
355403

356-
/// retrieve the first element from the buffer
357-
public func popFirst() async -> T? {
358-
self._buffer.withLock { $0.popFirst() }
404+
var state: State
405+
406+
init() {
407+
self.state = .buffer([])
408+
}
359409
}
360410

411+
private let lock = Mutex<State>(.init())
412+
361413
/// enqueue an element, or give it back immediately to the iterator if it is waiting for an element
362414
public func push(_ invocation: T) async {
363415
// if the iterator is waiting for an element, give it to it
364416
// otherwise, enqueue the element
365-
if let continuation = self._continuation.withLock({ $0 }) {
366-
self._continuation.withLock { $0 = nil }
367-
continuation.resume(returning: invocation)
368-
} else {
369-
self._buffer.withLock { $0.append(invocation) }
417+
let maybeContinuation = self.lock.withLock { state -> CheckedContinuation<T, any Error>? in
418+
switch state.state {
419+
case .continuation(let continuation):
420+
state.state = .buffer([])
421+
return continuation
422+
423+
case .buffer(var buffer):
424+
buffer.append(invocation)
425+
state.state = .buffer(buffer)
426+
return nil
427+
}
370428
}
429+
430+
maybeContinuation?.resume(returning: invocation)
371431
}
372432

373433
func next() async throws -> T? {
374-
375434
// exit the async for loop if the task is cancelled
376435
guard !Task.isCancelled else {
377436
return nil
378437
}
379438

380-
if let element = await self.popFirst() {
381-
return element
382-
} else {
383-
// we can't return nil if there is nothing to dequeue otherwise the async for loop will stop
384-
// wait for an element to be enqueued
385-
return try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<T, any Error>) in
386-
// store the continuation for later, when an element is enqueued
387-
self._continuation.withLock {
388-
$0 = continuation
439+
return try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<T, any Error>) in
440+
let nextAction = self.lock.withLock { state -> T? in
441+
switch state.state {
442+
case .buffer(var buffer):
443+
if let first = buffer.popFirst() {
444+
state.state = .buffer(buffer)
445+
return first
446+
} else {
447+
state.state = .continuation(continuation)
448+
return nil
449+
}
450+
451+
case .continuation:
452+
fatalError("Concurrent invocations to next(). This is illegal.")
389453
}
390454
}
455+
456+
guard let nextAction else { return }
457+
458+
continuation.resume(returning: nextAction)
391459
}
392460
}
393461

Sources/AWSLambdaRuntimeCore/LambdaRuntime.swift

+1-2
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,7 @@ public final class LambdaRuntime<Handler>: @unchecked Sendable where Handler: St
8585
#if DEBUG
8686
// we're not running on Lambda and we're compiled in DEBUG mode,
8787
// let's start a local server for testing
88-
try await Lambda.withLocalServer(invocationEndpoint: Lambda.env("LOCAL_LAMBDA_SERVER_INVOCATION_ENDPOINT"))
89-
{
88+
try await Lambda.withLocalServer(invocationEndpoint: Lambda.env("LOCAL_LAMBDA_SERVER_INVOCATION_ENDPOINT")) {
9089

9190
try await LambdaRuntimeClient.withRuntimeClient(
9291
configuration: .init(ip: "127.0.0.1", port: 7000),

0 commit comments

Comments
 (0)