Skip to content

Commit 61cd5d5

Browse files
adam-fowlerfabianfettsebsto
authored
[core] Add cancellation handling for nextInvocation() (#459)
Allow `LambdaChannelHandler.nextInvocation` to be cancelled. ### Motivation: If we want to use ServiceLifecycle with the lambda runtime the lambda runtime needs to be cancellable either via a ServiceLifecycle graceful shutdown or via Task cancellation. To avoid bringing in the ServiceLifecycle dependency this PR adds cancellation via Task cancellation handler. ### Modifications: Add `withTaskCancellationHandler` to nextInvocation which calls close on cancel. In `LambdaChannelHandler.channelInactive` resume continuation if state is `waitingForNextInvocation` Added `LambdaRuntimeClientTests.testCancellation` ### Result: You can now cancel the runtime while it is waiting for the next invocation. --------- Co-authored-by: Fabian Fett <[email protected]> Co-authored-by: Sébastien Stormacq <[email protected]>
1 parent 5de00c9 commit 61cd5d5

File tree

3 files changed

+100
-30
lines changed

3 files changed

+100
-30
lines changed

Sources/AWSLambdaRuntimeCore/Lambda.swift

+22-16
Original file line numberDiff line numberDiff line change
@@ -37,25 +37,31 @@ public enum Lambda {
3737
) async throws where Handler: StreamingLambdaHandler {
3838
var handler = handler
3939

40-
while !Task.isCancelled {
41-
let (invocation, writer) = try await runtimeClient.nextInvocation()
40+
do {
41+
while !Task.isCancelled {
42+
let (invocation, writer) = try await runtimeClient.nextInvocation()
4243

43-
do {
44-
try await handler.handle(
45-
invocation.event,
46-
responseWriter: writer,
47-
context: LambdaContext(
48-
requestID: invocation.metadata.requestID,
49-
traceID: invocation.metadata.traceID,
50-
invokedFunctionARN: invocation.metadata.invokedFunctionARN,
51-
deadline: DispatchWallTime(millisSinceEpoch: invocation.metadata.deadlineInMillisSinceEpoch),
52-
logger: logger
44+
do {
45+
try await handler.handle(
46+
invocation.event,
47+
responseWriter: writer,
48+
context: LambdaContext(
49+
requestID: invocation.metadata.requestID,
50+
traceID: invocation.metadata.traceID,
51+
invokedFunctionARN: invocation.metadata.invokedFunctionARN,
52+
deadline: DispatchWallTime(
53+
millisSinceEpoch: invocation.metadata.deadlineInMillisSinceEpoch
54+
),
55+
logger: logger
56+
)
5357
)
54-
)
55-
} catch {
56-
try await writer.reportError(error)
57-
continue
58+
} catch {
59+
try await writer.reportError(error)
60+
continue
61+
}
5862
}
63+
} catch is CancellationError {
64+
// don't allow cancellation error to propagate further
5965
}
6066
}
6167

Sources/AWSLambdaRuntimeCore/LambdaRuntimeClient.swift

+26-14
Original file line numberDiff line numberDiff line change
@@ -145,22 +145,28 @@ final actor LambdaRuntimeClient: LambdaRuntimeClientProtocol {
145145
}
146146

147147
func nextInvocation() async throws -> (Invocation, Writer) {
148-
switch self.lambdaState {
149-
case .idle:
150-
self.lambdaState = .waitingForNextInvocation
151-
let handler = try await self.makeOrGetConnection()
152-
let invocation = try await handler.nextInvocation()
153-
guard case .waitingForNextInvocation = self.lambdaState else {
148+
try await withTaskCancellationHandler {
149+
switch self.lambdaState {
150+
case .idle:
151+
self.lambdaState = .waitingForNextInvocation
152+
let handler = try await self.makeOrGetConnection()
153+
let invocation = try await handler.nextInvocation()
154+
guard case .waitingForNextInvocation = self.lambdaState else {
155+
fatalError("Invalid state: \(self.lambdaState)")
156+
}
157+
self.lambdaState = .waitingForResponse(requestID: invocation.metadata.requestID)
158+
return (invocation, Writer(runtimeClient: self))
159+
160+
case .waitingForNextInvocation,
161+
.waitingForResponse,
162+
.sendingResponse,
163+
.sentResponse:
154164
fatalError("Invalid state: \(self.lambdaState)")
155165
}
156-
self.lambdaState = .waitingForResponse(requestID: invocation.metadata.requestID)
157-
return (invocation, Writer(runtimeClient: self))
158-
159-
case .waitingForNextInvocation,
160-
.waitingForResponse,
161-
.sendingResponse,
162-
.sentResponse:
163-
fatalError("Invalid state: \(self.lambdaState)")
166+
} onCancel: {
167+
Task {
168+
await self.close()
169+
}
164170
}
165171
}
166172

@@ -819,6 +825,12 @@ extension LambdaChannelHandler: ChannelInboundHandler {
819825

820826
func channelInactive(context: ChannelHandlerContext) {
821827
// fail any pending responses with last error or assume peer disconnected
828+
switch self.state {
829+
case .connected(_, .waitingForNextInvocation(let continuation)):
830+
continuation.resume(throwing: self.lastError ?? ChannelError.ioOnClosedChannel)
831+
default:
832+
break
833+
}
822834

823835
// we don't need to forward channelInactive to the delegate, as the delegate observes the
824836
// closeFuture

Tests/AWSLambdaRuntimeCoreTests/LambdaRuntimeClientTests.swift

+52
Original file line numberDiff line numberDiff line change
@@ -86,4 +86,56 @@ struct LambdaRuntimeClientTests {
8686
}
8787
}
8888
}
89+
90+
@Test
91+
func testCancellation() async throws {
92+
struct HappyBehavior: LambdaServerBehavior {
93+
let requestId = UUID().uuidString
94+
let event = "hello"
95+
96+
func getInvocation() -> GetInvocationResult {
97+
.success((self.requestId, self.event))
98+
}
99+
100+
func processResponse(requestId: String, response: String?) -> Result<Void, ProcessResponseError> {
101+
#expect(self.requestId == requestId)
102+
#expect(self.event == response)
103+
return .success(())
104+
}
105+
106+
func processError(requestId: String, error: ErrorResponse) -> Result<Void, ProcessErrorError> {
107+
Issue.record("should not report error")
108+
return .failure(.internalServerError)
109+
}
110+
111+
func processInitError(error: ErrorResponse) -> Result<Void, ProcessErrorError> {
112+
Issue.record("should not report init error")
113+
return .failure(.internalServerError)
114+
}
115+
}
116+
117+
try await withMockServer(behaviour: HappyBehavior()) { port in
118+
try await LambdaRuntimeClient.withRuntimeClient(
119+
configuration: .init(ip: "127.0.0.1", port: port),
120+
eventLoop: NIOSingletons.posixEventLoopGroup.next(),
121+
logger: self.logger
122+
) { runtimeClient in
123+
try await withThrowingTaskGroup(of: Void.self) { group in
124+
group.addTask {
125+
while true {
126+
let (_, writer) = try await runtimeClient.nextInvocation()
127+
// Wrap this is a task so cancellation isn't propagated to the write calls
128+
try await Task {
129+
try await writer.write(ByteBuffer(string: "hello"))
130+
try await writer.finish()
131+
}.value
132+
}
133+
}
134+
// wait a small amount to ensure we are waiting for continuation
135+
try await Task.sleep(for: .milliseconds(100))
136+
group.cancelAll()
137+
}
138+
}
139+
}
140+
}
89141
}

0 commit comments

Comments
 (0)