Skip to content

Commit c271eaf

Browse files
committed
Merge remote-tracking branch 'upstream/cancel-next-invocation' into sebsto/servicelifecycle
2 parents d1247fa + 7d5257c commit c271eaf

File tree

3 files changed

+84
-26
lines changed

3 files changed

+84
-26
lines changed

Sources/AWSLambdaRuntimeCore/Lambda.swift

+22-22
Original file line numberDiff line numberDiff line change
@@ -47,31 +47,31 @@ public enum Lambda {
4747
) async throws where Handler: StreamingLambdaHandler {
4848
var handler = handler
4949

50-
var cancelled: Bool = Lambda.cancelled.withLock { $0 }
51-
while !Task.isCancelled && !cancelled {
52-
logger.trace("Waiting for next invocation")
53-
let (invocation, writer) = try await runtimeClient.nextInvocation()
50+
do {
51+
while !Task.isCancelled {
52+
let (invocation, writer) = try await runtimeClient.nextInvocation()
5453

55-
logger.trace("Received invocation : \(invocation.metadata.requestID)")
56-
do {
57-
try await handler.handle(
58-
invocation.event,
59-
responseWriter: writer,
60-
context: LambdaContext(
61-
requestID: invocation.metadata.requestID,
62-
traceID: invocation.metadata.traceID,
63-
invokedFunctionARN: invocation.metadata.invokedFunctionARN,
64-
deadline: DispatchWallTime(millisSinceEpoch: invocation.metadata.deadlineInMillisSinceEpoch),
65-
logger: logger
54+
do {
55+
try await handler.handle(
56+
invocation.event,
57+
responseWriter: writer,
58+
context: LambdaContext(
59+
requestID: invocation.metadata.requestID,
60+
traceID: invocation.metadata.traceID,
61+
invokedFunctionARN: invocation.metadata.invokedFunctionARN,
62+
deadline: DispatchWallTime(
63+
millisSinceEpoch: invocation.metadata.deadlineInMillisSinceEpoch
64+
),
65+
logger: logger
66+
)
6667
)
67-
)
68-
} catch {
69-
try await writer.reportError(error)
70-
continue
68+
} catch {
69+
try await writer.reportError(error)
70+
continue
71+
}
7172
}
72-
73-
logger.trace("Completed invocation : \(invocation.metadata.requestID)")
74-
cancelled = Lambda.cancelled.withLock { $0 }
73+
} catch is CancellationError {
74+
// don't allow cancellation error to propagate further
7575
}
7676
logger.trace("Lambda runLoop() \(cancelled ? "cancelled" : "completed")")
7777
}

Sources/AWSLambdaRuntimeCore/LambdaRuntimeClient.swift

+42-4
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,17 @@ private protocol LambdaChannelHandlerDelegate {
406406
func connectionErrorHappened(_ error: any Error, channel: any Channel)
407407
}
408408

409+
struct UnsafeContext: @unchecked Sendable {
410+
private let _context: ChannelHandlerContext
411+
var context: ChannelHandlerContext {
412+
self._context.eventLoop.preconditionInEventLoop()
413+
return _context
414+
}
415+
init(_ context: ChannelHandlerContext) {
416+
self._context = context
417+
}
418+
}
419+
409420
private final class LambdaChannelHandler<Delegate: LambdaChannelHandlerDelegate> {
410421
let nextInvocationPath = Consts.invocationURLPrefix + Consts.getNextInvocationURLSuffix
411422

@@ -465,10 +476,37 @@ private final class LambdaChannelHandler<Delegate: LambdaChannelHandlerDelegate>
465476
func nextInvocation(isolation: isolated (any Actor)? = #isolation) async throws -> Invocation {
466477
switch self.state {
467478
case .connected(let context, .idle):
468-
return try await withCheckedThrowingContinuation {
469-
(continuation: CheckedContinuation<Invocation, any Error>) in
470-
self.state = .connected(context, .waitingForNextInvocation(continuation))
471-
self.sendNextRequest(context: context)
479+
return try await withTaskCancellationHandler {
480+
try Task.checkCancellation()
481+
return try await withCheckedThrowingContinuation {
482+
(continuation: CheckedContinuation<Invocation, any Error>) in
483+
self.state = .connected(context, .waitingForNextInvocation(continuation))
484+
485+
let unsafeContext = UnsafeContext(context)
486+
context.eventLoop.execute { [nextInvocationPath, defaultHeaders] in
487+
// Send next request. The function `sendNextRequest` requires `self` which is not
488+
// Sendable so just inlined the code instead
489+
let httpRequest = HTTPRequestHead(
490+
version: .http1_1,
491+
method: .GET,
492+
uri: nextInvocationPath,
493+
headers: defaultHeaders
494+
)
495+
let context = unsafeContext.context
496+
context.write(Self.wrapOutboundOut(.head(httpRequest)), promise: nil)
497+
context.write(Self.wrapOutboundOut(.end(nil)), promise: nil)
498+
context.flush()
499+
}
500+
}
501+
} onCancel: {
502+
switch self.state {
503+
case .connected(_, .waitingForNextInvocation(let continuation)):
504+
continuation.resume(throwing: CancellationError())
505+
case .connected(_, .idle):
506+
break
507+
default:
508+
fatalError("Invalid state: \(self.state)")
509+
}
472510
}
473511

474512
case .connected(_, .sendingResponse),

Tests/AWSLambdaRuntimeCoreTests/LambdaRuntimeClientTests.swift

+20
Original file line numberDiff line numberDiff line change
@@ -86,4 +86,24 @@ struct LambdaRuntimeClientTests {
8686
}
8787
}
8888
}
89+
90+
@Test
91+
func testCancellation() async throws {
92+
try await LambdaRuntimeClient.withRuntimeClient(
93+
configuration: .init(ip: "127.0.0.1", port: 7000),
94+
eventLoop: NIOSingletons.posixEventLoopGroup.next(),
95+
logger: self.logger
96+
) { runtimeClient in
97+
try await withThrowingTaskGroup(of: Void.self) { group in
98+
group.addTask {
99+
while true {
100+
_ = try await runtimeClient.nextInvocation()
101+
}
102+
}
103+
// wait a small amount to ensure we are waiting for continuation
104+
try await Task.sleep(for: .milliseconds(100))
105+
group.cancelAll()
106+
}
107+
}
108+
}
89109
}

0 commit comments

Comments
 (0)