Skip to content

Commit ec68898

Browse files
committed
Add cancellation handling for nextInvocation()
1 parent d778048 commit ec68898

File tree

3 files changed

+84
-20
lines changed

3 files changed

+84
-20
lines changed

Sources/AWSLambdaRuntimeCore/Lambda.swift

+22-16
Original file line numberDiff line numberDiff line change
@@ -37,25 +37,31 @@ public enum Lambda {
3737
) async throws where Handler: StreamingLambdaHandler {
3838
var handler = handler
3939

40-
while !Task.isCancelled {
41-
let (invocation, writer) = try await runtimeClient.nextInvocation()
40+
do {
41+
while !Task.isCancelled {
42+
let (invocation, writer) = try await runtimeClient.nextInvocation()
4243

43-
do {
44-
try await handler.handle(
45-
invocation.event,
46-
responseWriter: writer,
47-
context: LambdaContext(
48-
requestID: invocation.metadata.requestID,
49-
traceID: invocation.metadata.traceID,
50-
invokedFunctionARN: invocation.metadata.invokedFunctionARN,
51-
deadline: DispatchWallTime(millisSinceEpoch: invocation.metadata.deadlineInMillisSinceEpoch),
52-
logger: logger
44+
do {
45+
try await handler.handle(
46+
invocation.event,
47+
responseWriter: writer,
48+
context: LambdaContext(
49+
requestID: invocation.metadata.requestID,
50+
traceID: invocation.metadata.traceID,
51+
invokedFunctionARN: invocation.metadata.invokedFunctionARN,
52+
deadline: DispatchWallTime(
53+
millisSinceEpoch: invocation.metadata.deadlineInMillisSinceEpoch
54+
),
55+
logger: logger
56+
)
5357
)
54-
)
55-
} catch {
56-
try await writer.reportError(error)
57-
continue
58+
} catch {
59+
try await writer.reportError(error)
60+
continue
61+
}
5862
}
63+
} catch is CancellationError {
64+
// don't allow cancellation error to propagate further
5965
}
6066
}
6167

Sources/AWSLambdaRuntimeCore/LambdaRuntimeClient.swift

+42-4
Original file line numberDiff line numberDiff line change
@@ -410,6 +410,17 @@ private protocol LambdaChannelHandlerDelegate {
410410
func connectionErrorHappened(_ error: any Error, channel: any Channel)
411411
}
412412

413+
struct UnsafeContext: @unchecked Sendable {
414+
private let _context: ChannelHandlerContext
415+
var context: ChannelHandlerContext {
416+
self._context.eventLoop.preconditionInEventLoop()
417+
return _context
418+
}
419+
init(_ context: ChannelHandlerContext) {
420+
self._context = context
421+
}
422+
}
423+
413424
private final class LambdaChannelHandler<Delegate: LambdaChannelHandlerDelegate> {
414425
let nextInvocationPath = Consts.invocationURLPrefix + Consts.getNextInvocationURLSuffix
415426

@@ -469,10 +480,37 @@ private final class LambdaChannelHandler<Delegate: LambdaChannelHandlerDelegate>
469480
func nextInvocation(isolation: isolated (any Actor)? = #isolation) async throws -> Invocation {
470481
switch self.state {
471482
case .connected(let context, .idle):
472-
return try await withCheckedThrowingContinuation {
473-
(continuation: CheckedContinuation<Invocation, any Error>) in
474-
self.state = .connected(context, .waitingForNextInvocation(continuation))
475-
self.sendNextRequest(context: context)
483+
return try await withTaskCancellationHandler {
484+
try Task.checkCancellation()
485+
return try await withCheckedThrowingContinuation {
486+
(continuation: CheckedContinuation<Invocation, any Error>) in
487+
self.state = .connected(context, .waitingForNextInvocation(continuation))
488+
489+
let unsafeContext = UnsafeContext(context)
490+
context.eventLoop.execute { [nextInvocationPath, defaultHeaders] in
491+
// Send next request. The function `sendNextRequest` requires `self` which is not
492+
// Sendable so just inlined the code instead
493+
let httpRequest = HTTPRequestHead(
494+
version: .http1_1,
495+
method: .GET,
496+
uri: nextInvocationPath,
497+
headers: defaultHeaders
498+
)
499+
let context = unsafeContext.context
500+
context.write(Self.wrapOutboundOut(.head(httpRequest)), promise: nil)
501+
context.write(Self.wrapOutboundOut(.end(nil)), promise: nil)
502+
context.flush()
503+
}
504+
}
505+
} onCancel: {
506+
switch self.state {
507+
case .connected(_, .waitingForNextInvocation(let continuation)):
508+
continuation.resume(throwing: CancellationError())
509+
case .connected(_, .idle):
510+
break
511+
default:
512+
fatalError("Invalid state: \(self.state)")
513+
}
476514
}
477515

478516
case .connected(_, .sendingResponse),

Tests/AWSLambdaRuntimeCoreTests/LambdaRuntimeClientTests.swift

+20
Original file line numberDiff line numberDiff line change
@@ -86,4 +86,24 @@ struct LambdaRuntimeClientTests {
8686
}
8787
}
8888
}
89+
90+
@Test
91+
func testCancellation() async throws {
92+
try await LambdaRuntimeClient.withRuntimeClient(
93+
configuration: .init(ip: "127.0.0.1", port: 7000),
94+
eventLoop: NIOSingletons.posixEventLoopGroup.next(),
95+
logger: self.logger
96+
) { runtimeClient in
97+
try await withThrowingTaskGroup(of: Void.self) { group in
98+
group.addTask {
99+
while true {
100+
_ = try await runtimeClient.nextInvocation()
101+
}
102+
}
103+
// wait a small amount to ensure we are waiting for continuation
104+
try await Task.sleep(for: .milliseconds(100))
105+
group.cancelAll()
106+
}
107+
}
108+
}
89109
}

0 commit comments

Comments
 (0)