13
13
//===----------------------------------------------------------------------===//
14
14
15
15
#if DEBUG
16
+ import DequeModule
16
17
import Dispatch
17
18
import Logging
18
19
import NIOConcurrencyHelpers
@@ -47,24 +48,15 @@ extension Lambda {
47
48
/// - note: This API is designed strictly for local testing and is behind a DEBUG flag
48
49
static func withLocalServer(
49
50
invocationEndpoint: String ? = nil ,
50
- _ body: @escaping ( ) async throws -> Void
51
+ _ body: sending @escaping ( ) async throws -> Void
51
52
) async throws {
53
+ var logger = Logger ( label: " LocalServer " )
54
+ logger. logLevel = Lambda . env ( " LOG_LEVEL " ) . flatMap ( Logger . Level. init) ?? . info
52
55
53
- // launch the local server and wait for it to be started before running the body
54
- try await withThrowingTaskGroup ( of: Void . self) { group in
55
- // this call will return when the server calls continuation.resume()
56
- try await withCheckedThrowingContinuation { ( continuation: CheckedContinuation < Void , any Error > ) in
57
- group. addTask {
58
- do {
59
- try await LambdaHttpServer ( invocationEndpoint: invocationEndpoint) . start (
60
- continuation: continuation
61
- )
62
- } catch {
63
- continuation. resume ( throwing: error)
64
- }
65
- }
66
- }
67
- // now that server is started, run the Lambda function itself
56
+ try await LambdaHTTPServer . withLocalServer (
57
+ invocationEndpoint: invocationEndpoint,
58
+ logger: logger
59
+ ) {
68
60
try await body ( )
69
61
}
70
62
}
@@ -84,34 +76,38 @@ extension Lambda {
84
76
/// 1. POST /invoke - the client posts the event to the lambda function
85
77
///
86
78
/// This server passes the data received from /invoke POST request to the lambda function (GET /next) and then forwards the response back to the client.
87
- private struct LambdaHttpServer {
88
- private let logger : Logger
89
- private let group : EventLoopGroup
90
- private let host : String
91
- private let port : Int
79
+ private struct LambdaHTTPServer {
92
80
private let invocationEndpoint : String
93
81
94
82
private let invocationPool = Pool < LocalServerInvocation > ( )
95
83
private let responsePool = Pool < LocalServerResponse > ( )
96
84
97
- init ( invocationEndpoint: String ? ) {
98
- var logger = Logger ( label: " LocalServer " )
99
- logger. logLevel = Lambda . env ( " LOG_LEVEL " ) . flatMap ( Logger . Level. init) ?? . info
100
- self . logger = logger
101
- self . group = MultiThreadedEventLoopGroup . singleton
102
- self . host = " 127.0.0.1 "
103
- self . port = 7000
85
+ private init (
86
+ invocationEndpoint: String ?
87
+ ) {
104
88
self . invocationEndpoint = invocationEndpoint ?? " /invoke "
105
89
}
106
90
107
- func start( continuation: CheckedContinuation < Void , any Error > ) async throws {
108
- let channel = try await ServerBootstrap ( group: self . group)
91
+ private enum TaskResult < Result: Sendable > : Sendable {
92
+ case closureResult( Swift . Result < Result , any Error > )
93
+ case serverReturned( Swift . Result < Void , any Error > )
94
+ }
95
+
96
+ static func withLocalServer< Result: Sendable > (
97
+ invocationEndpoint: String ? ,
98
+ host: String = " 127.0.0.1 " ,
99
+ port: Int = 7000 ,
100
+ eventLoopGroup: MultiThreadedEventLoopGroup = . singleton,
101
+ logger: Logger ,
102
+ _ closure: sending @escaping ( ) async throws -> Result
103
+ ) async throws -> Result {
104
+ let channel = try await ServerBootstrap ( group: eventLoopGroup)
109
105
. serverChannelOption ( . backlog, value: 256 )
110
106
. serverChannelOption ( . socketOption( . so_reuseaddr) , value: 1 )
111
107
. childChannelOption ( . maxMessagesPerRead, value: 1 )
112
108
. bind (
113
- host: self . host,
114
- port: self . port
109
+ host: host,
110
+ port: port
115
111
) { channel in
116
112
channel. eventLoop. makeCompletedFuture {
117
113
@@ -129,8 +125,6 @@ private struct LambdaHttpServer {
129
125
}
130
126
}
131
127
132
- // notify the caller that the server is started
133
- continuation. resume ( )
134
128
logger. info (
135
129
" Server started and listening " ,
136
130
metadata: [
@@ -139,30 +133,76 @@ private struct LambdaHttpServer {
139
133
]
140
134
)
141
135
136
+ let server = LambdaHTTPServer ( invocationEndpoint: invocationEndpoint)
137
+
142
138
// We are handling each incoming connection in a separate child task. It is important
143
139
// to use a discarding task group here which automatically discards finished child tasks.
144
140
// A normal task group retains all child tasks and their outputs in memory until they are
145
141
// consumed by iterating the group or by exiting the group. Since, we are never consuming
146
142
// the results of the group we need the group to automatically discard them; otherwise, this
147
143
// would result in a memory leak over time.
148
- try await withThrowingDiscardingTaskGroup { group in
149
- try await channel. executeThenClose { inbound in
150
- for try await connectionChannel in inbound {
151
-
152
- group. addTask {
153
- logger. trace ( " Handling a new connection " )
154
- await self . handleConnection ( channel: connectionChannel)
155
- logger. trace ( " Done handling the connection " )
144
+ let result = await withTaskGroup ( of: TaskResult< Result> . self , returning: Swift . Result < Result , any Error > . self) { group in
145
+
146
+ let c = closure
147
+ group. addTask {
148
+ do {
149
+
150
+ let result = try await c ( )
151
+ return . closureResult( . success( result) )
152
+ } catch {
153
+ return . closureResult( . failure( error) )
154
+ }
155
+ }
156
+
157
+ group. addTask {
158
+ do {
159
+ try await withThrowingDiscardingTaskGroup { taskGroup in
160
+ try await channel. executeThenClose { inbound in
161
+ for try await connectionChannel in inbound {
162
+
163
+ taskGroup. addTask {
164
+ logger. trace ( " Handling a new connection " )
165
+ await server. handleConnection ( channel: connectionChannel, logger: logger)
166
+ logger. trace ( " Done handling the connection " )
167
+ }
168
+ }
169
+ }
156
170
}
171
+ return . serverReturned( . success( ( ) ) )
172
+ } catch {
173
+ return . serverReturned( . failure( error) )
174
+ }
175
+ }
176
+
177
+ let task1 = await group. next ( ) !
178
+ group. cancelAll ( )
179
+ let task2 = await group. next ( ) !
180
+
181
+ switch task1 {
182
+ case . closureResult( let result) :
183
+ return result
184
+
185
+ case . serverReturned:
186
+ switch task2 {
187
+ case . closureResult( let result) :
188
+ return result
189
+
190
+ case . serverReturned:
191
+ fatalError ( )
157
192
}
158
193
}
159
194
}
195
+
160
196
logger. info ( " Server shutting down " )
197
+ return try result. get ( )
161
198
}
162
199
200
+
201
+
163
202
/// This method handles individual TCP connections
164
203
private func handleConnection(
165
- channel: NIOAsyncChannel < HTTPServerRequestPart , HTTPServerResponsePart >
204
+ channel: NIOAsyncChannel < HTTPServerRequestPart , HTTPServerResponsePart > ,
205
+ logger: Logger
166
206
) async {
167
207
168
208
var requestHead : HTTPRequestHead !
@@ -186,12 +226,14 @@ private struct LambdaHttpServer {
186
226
// process the request
187
227
let response = try await self . processRequest (
188
228
head: requestHead,
189
- body: requestBody
229
+ body: requestBody,
230
+ logger: logger
190
231
)
191
232
// send the responses
192
233
try await self . sendResponse (
193
234
response: response,
194
- outbound: outbound
235
+ outbound: outbound,
236
+ logger: logger
195
237
)
196
238
197
239
requestHead = nil
@@ -214,15 +256,15 @@ private struct LambdaHttpServer {
214
256
/// - body: the HTTP request body
215
257
/// - Throws:
216
258
/// - Returns: the response to send back to the client or the Lambda function
217
- private func processRequest( head: HTTPRequestHead , body: ByteBuffer ? ) async throws -> LocalServerResponse {
259
+ private func processRequest( head: HTTPRequestHead , body: ByteBuffer ? , logger : Logger ) async throws -> LocalServerResponse {
218
260
219
261
if let body {
220
- self . logger. trace (
262
+ logger. trace (
221
263
" Processing request " ,
222
264
metadata: [ " URI " : " \( head. method) \( head. uri) " , " Body " : " \( String ( buffer: body) ) " ]
223
265
)
224
266
} else {
225
- self . logger. trace ( " Processing request " , metadata: [ " URI " : " \( head. method) \( head. uri) " ] )
267
+ logger. trace ( " Processing request " , metadata: [ " URI " : " \( head. method) \( head. uri) " ] )
226
268
}
227
269
228
270
switch ( head. method, head. uri) {
@@ -237,7 +279,9 @@ private struct LambdaHttpServer {
237
279
}
238
280
// we always accept the /invoke request and push them to the pool
239
281
let requestId = " \( DispatchTime . now ( ) . uptimeNanoseconds) "
240
- logger. trace ( " /invoke received invocation " , metadata: [ " requestId " : " \( requestId) " ] )
282
+ var logger = logger
283
+ logger [ metadataKey: " requestID " ] = " \( requestId) "
284
+ logger. trace ( " /invoke received invocation " )
241
285
await self . invocationPool. push ( LocalServerInvocation ( requestId: requestId, request: body) )
242
286
243
287
// wait for the lambda function to process the request
@@ -273,9 +317,9 @@ private struct LambdaHttpServer {
273
317
case ( . GET, let url) where url. hasSuffix ( Consts . getNextInvocationURLSuffix) :
274
318
275
319
// pop the tasks from the queue
276
- self . logger. trace ( " /next waiting for /invoke " )
320
+ logger. trace ( " /next waiting for /invoke " )
277
321
for try await invocation in self . invocationPool {
278
- self . logger. trace ( " /next retrieved invocation " , metadata: [ " requestId " : " \( invocation. requestId) " ] )
322
+ logger. trace ( " /next retrieved invocation " , metadata: [ " requestId " : " \( invocation. requestId) " ] )
279
323
// this call also stores the invocation requestId into the response
280
324
return invocation. makeResponse ( status: . accepted)
281
325
}
@@ -322,12 +366,13 @@ private struct LambdaHttpServer {
322
366
323
367
private func sendResponse(
324
368
response: LocalServerResponse ,
325
- outbound: NIOAsyncChannelOutboundWriter < HTTPServerResponsePart >
369
+ outbound: NIOAsyncChannelOutboundWriter < HTTPServerResponsePart > ,
370
+ logger: Logger
326
371
) async throws {
327
372
var headers = HTTPHeaders ( response. headers ?? [ ] )
328
373
headers. add ( name: " Content-Length " , value: " \( response. body? . readableBytes ?? 0 ) " )
329
374
330
- self . logger. trace ( " Writing response " , metadata: [ " requestId " : " \( response. requestId ?? " " ) " ] )
375
+ logger. trace ( " Writing response " , metadata: [ " requestId " : " \( response. requestId ?? " " ) " ] )
331
376
try await outbound. write (
332
377
HTTPServerResponsePart . head (
333
378
HTTPResponseHead (
@@ -350,44 +395,67 @@ private struct LambdaHttpServer {
350
395
private final class Pool < T> : AsyncSequence , AsyncIteratorProtocol , Sendable where T: Sendable {
351
396
typealias Element = T
352
397
353
- private let _buffer = Mutex < CircularBuffer < T > > ( . init( ) )
354
- private let _continuation = Mutex < CheckedContinuation < T , any Error > ? > ( nil )
398
+ struct State {
399
+ enum State {
400
+ case buffer( Deque < T > )
401
+ case continuation( CheckedContinuation < T , any Error > ? )
402
+ }
355
403
356
- /// retrieve the first element from the buffer
357
- public func popFirst( ) async -> T ? {
358
- self . _buffer. withLock { $0. popFirst ( ) }
404
+ var state : State
405
+
406
+ init ( ) {
407
+ self . state = . buffer( [ ] )
408
+ }
359
409
}
360
410
411
+ private let lock = Mutex < State > ( . init( ) )
412
+
361
413
/// enqueue an element, or give it back immediately to the iterator if it is waiting for an element
362
414
public func push( _ invocation: T ) async {
363
415
// if the iterator is waiting for an element, give it to it
364
416
// otherwise, enqueue the element
365
- if let continuation = self . _continuation. withLock ( { $0 } ) {
366
- self . _continuation. withLock { $0 = nil }
367
- continuation. resume ( returning: invocation)
368
- } else {
369
- self . _buffer. withLock { $0. append ( invocation) }
417
+ let maybeContinuation = self . lock. withLock { state -> CheckedContinuation < T , any Error > ? in
418
+ switch state. state {
419
+ case . continuation( let continuation) :
420
+ state. state = . buffer( [ ] )
421
+ return continuation
422
+
423
+ case . buffer( var buffer) :
424
+ buffer. append ( invocation)
425
+ state. state = . buffer( buffer)
426
+ return nil
427
+ }
370
428
}
429
+
430
+ maybeContinuation? . resume ( returning: invocation)
371
431
}
372
432
373
433
func next( ) async throws -> T ? {
374
-
375
434
// exit the async for loop if the task is cancelled
376
435
guard !Task. isCancelled else {
377
436
return nil
378
437
}
379
438
380
- if let element = await self . popFirst ( ) {
381
- return element
382
- } else {
383
- // we can't return nil if there is nothing to dequeue otherwise the async for loop will stop
384
- // wait for an element to be enqueued
385
- return try await withCheckedThrowingContinuation { ( continuation: CheckedContinuation < T , any Error > ) in
386
- // store the continuation for later, when an element is enqueued
387
- self . _continuation. withLock {
388
- $0 = continuation
439
+ return try await withCheckedThrowingContinuation { ( continuation: CheckedContinuation < T , any Error > ) in
440
+ let nextAction = self . lock. withLock { state -> T ? in
441
+ switch state. state {
442
+ case . buffer( var buffer) :
443
+ if let first = buffer. popFirst ( ) {
444
+ state. state = . buffer( buffer)
445
+ return first
446
+ } else {
447
+ state. state = . continuation( continuation)
448
+ return nil
449
+ }
450
+
451
+ case . continuation:
452
+ fatalError ( " Concurrent invocations to next(). This is illigal. " )
389
453
}
390
454
}
455
+
456
+ guard let nextAction else { return }
457
+
458
+ continuation. resume ( returning: nextAction)
391
459
}
392
460
}
393
461
0 commit comments