Skip to content

[core] Add cancellation handling for nextInvocation() #459

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Feb 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 22 additions & 16 deletions Sources/AWSLambdaRuntimeCore/Lambda.swift
Original file line number Diff line number Diff line change
Expand Up @@ -37,25 +37,31 @@ public enum Lambda {
) async throws where Handler: StreamingLambdaHandler {
var handler = handler

while !Task.isCancelled {
let (invocation, writer) = try await runtimeClient.nextInvocation()
do {
while !Task.isCancelled {
let (invocation, writer) = try await runtimeClient.nextInvocation()

do {
try await handler.handle(
invocation.event,
responseWriter: writer,
context: LambdaContext(
requestID: invocation.metadata.requestID,
traceID: invocation.metadata.traceID,
invokedFunctionARN: invocation.metadata.invokedFunctionARN,
deadline: DispatchWallTime(millisSinceEpoch: invocation.metadata.deadlineInMillisSinceEpoch),
logger: logger
do {
try await handler.handle(
invocation.event,
responseWriter: writer,
context: LambdaContext(
requestID: invocation.metadata.requestID,
traceID: invocation.metadata.traceID,
invokedFunctionARN: invocation.metadata.invokedFunctionARN,
deadline: DispatchWallTime(
millisSinceEpoch: invocation.metadata.deadlineInMillisSinceEpoch
),
logger: logger
)
)
)
} catch {
try await writer.reportError(error)
continue
} catch {
try await writer.reportError(error)
continue
}
}
} catch is CancellationError {
// don't allow cancellation error to propagate further
}
}

Expand Down
40 changes: 26 additions & 14 deletions Sources/AWSLambdaRuntimeCore/LambdaRuntimeClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -145,22 +145,28 @@ final actor LambdaRuntimeClient: LambdaRuntimeClientProtocol {
}

func nextInvocation() async throws -> (Invocation, Writer) {
switch self.lambdaState {
case .idle:
self.lambdaState = .waitingForNextInvocation
let handler = try await self.makeOrGetConnection()
let invocation = try await handler.nextInvocation()
guard case .waitingForNextInvocation = self.lambdaState else {
try await withTaskCancellationHandler {
switch self.lambdaState {
case .idle:
self.lambdaState = .waitingForNextInvocation
let handler = try await self.makeOrGetConnection()
let invocation = try await handler.nextInvocation()
guard case .waitingForNextInvocation = self.lambdaState else {
fatalError("Invalid state: \(self.lambdaState)")
}
self.lambdaState = .waitingForResponse(requestID: invocation.metadata.requestID)
return (invocation, Writer(runtimeClient: self))

case .waitingForNextInvocation,
.waitingForResponse,
.sendingResponse,
.sentResponse:
fatalError("Invalid state: \(self.lambdaState)")
}
self.lambdaState = .waitingForResponse(requestID: invocation.metadata.requestID)
return (invocation, Writer(runtimeClient: self))

case .waitingForNextInvocation,
.waitingForResponse,
.sendingResponse,
.sentResponse:
fatalError("Invalid state: \(self.lambdaState)")
} onCancel: {
Task {
await self.close()
}
}
}

Expand Down Expand Up @@ -819,6 +825,12 @@ extension LambdaChannelHandler: ChannelInboundHandler {

func channelInactive(context: ChannelHandlerContext) {
// fail any pending responses with last error or assume peer disconnected
switch self.state {
case .connected(_, .waitingForNextInvocation(let continuation)):
continuation.resume(throwing: self.lastError ?? ChannelError.ioOnClosedChannel)
default:
break
}

// we don't need to forward channelInactive to the delegate, as the delegate observes the
// closeFuture
Expand Down
52 changes: 52 additions & 0 deletions Tests/AWSLambdaRuntimeCoreTests/LambdaRuntimeClientTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -86,4 +86,56 @@ struct LambdaRuntimeClientTests {
}
}
}

@Test
func testCancellation() async throws {
struct HappyBehavior: LambdaServerBehavior {
let requestId = UUID().uuidString
let event = "hello"

func getInvocation() -> GetInvocationResult {
.success((self.requestId, self.event))
}

func processResponse(requestId: String, response: String?) -> Result<Void, ProcessResponseError> {
#expect(self.requestId == requestId)
#expect(self.event == response)
return .success(())
}

func processError(requestId: String, error: ErrorResponse) -> Result<Void, ProcessErrorError> {
Issue.record("should not report error")
return .failure(.internalServerError)
}

func processInitError(error: ErrorResponse) -> Result<Void, ProcessErrorError> {
Issue.record("should not report init error")
return .failure(.internalServerError)
}
}

try await withMockServer(behaviour: HappyBehavior()) { port in
try await LambdaRuntimeClient.withRuntimeClient(
configuration: .init(ip: "127.0.0.1", port: port),
eventLoop: NIOSingletons.posixEventLoopGroup.next(),
logger: self.logger
) { runtimeClient in
try await withThrowingTaskGroup(of: Void.self) { group in
group.addTask {
while true {
let (_, writer) = try await runtimeClient.nextInvocation()
// Wrap this is a task so cancellation isn't propagated to the write calls
try await Task {
try await writer.write(ByteBuffer(string: "hello"))
try await writer.finish()
}.value
}
}
// wait a small amount to ensure we are waiting for continuation
try await Task.sleep(for: .milliseconds(100))
group.cancelAll()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should check that the spawned task returns with a cancellation error indeed!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It doesn't return cancellation. Currently returns isInClosedChannel. Not sure how to get it to return the cancellation error as I'm just closing the connection and waiting for it to clean up. I could catch the error and then check Task.isCancelled and throw a Cancellation error if it is cancelled but that seems a bit of a hack.

}
}
}
}
}