Skip to content

Commit 82719ed

Browse files
committed
introduce new PreparedRequest to make AsyncRequestBag.init non-throwing
1 parent 962e12d commit 82719ed

File tree

4 files changed

+97
-90
lines changed

4 files changed

+97
-90
lines changed

Sources/AsyncHTTPClient/AsyncAwait/AsyncRequest+Validation.swift

+10-8
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,18 @@
1515
import struct Foundation.URL
1616
import NIOHTTP1
1717

18+
@available(macOS 12.0, iOS 15.0, watchOS 8.0, tvOS 15.0, *)
19+
struct PreparedRequest {
20+
let poolKey: ConnectionPool.Key
21+
let requestFramingMetadata: RequestFramingMetadata
22+
let head: HTTPRequestHead
23+
let body: HTTPClientRequest.Body?
24+
}
25+
1826
@available(macOS 12.0, iOS 15.0, watchOS 8.0, tvOS 15.0, *)
1927
extension HTTPClientRequest {
2028

21-
struct ValidationResult {
22-
let requestFramingMetadata: RequestFramingMetadata
23-
let poolKey: ConnectionPool.Key
24-
let head: HTTPRequestHead
25-
}
26-
27-
func validate() throws -> ValidationResult {
29+
func prepareForExecution() throws -> PreparedRequest {
2830

2931
guard let url = URL(string: self.url) else {
3032
throw HTTPClientError.invalidURL
@@ -58,7 +60,7 @@ extension HTTPClientRequest {
5860
head.headers.add(name: "host", value: urlHost)
5961
}
6062

61-
return .init(requestFramingMetadata: metadata, poolKey: poolKey, head: head)
63+
return .init(poolKey: poolKey, requestFramingMetadata: metadata, head: head, body: body)
6264
}
6365
}
6466

Sources/AsyncHTTPClient/AsyncAwait/AsyncRequestBag.swift

+10-12
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,12 @@ import NIOHTTP1
2020

2121
@available(macOS 12.0, iOS 15.0, watchOS 8.0, tvOS 15.0, *)
2222
class AsyncRequestBag {
23-
// TODO: We should drop the request after sending to free up resource ASAP
24-
let request: HTTPClientRequest
23+
2524

2625
let logger: Logger
27-
26+
// TODO: We should drop the request after sending to free up resource ASAP
2827
let requestHead: HTTPRequestHead
28+
let requestBody: HTTPClientRequest.Body?
2929
let requestOptions: RequestOptions
3030
let requestFramingMetadata: RequestFramingMetadata
3131

@@ -37,24 +37,22 @@ class AsyncRequestBag {
3737
private var state: StateMachine = .init()
3838

3939
init(
40-
request: HTTPClientRequest,
40+
request: PreparedRequest,
4141
requestOptions: RequestOptions,
4242
logger: Logger,
4343
connectionDeadline: NIODeadline,
4444
preferredEventLoop: EventLoop,
4545
responseContinuation: UnsafeContinuation<HTTPClientResponse, Error>
46-
) throws {
47-
self.request = request
46+
) {
47+
self.poolKey = request.poolKey
48+
self.requestHead = request.head
49+
self.requestBody = request.body
50+
self.requestFramingMetadata = request.requestFramingMetadata
4851
self.requestOptions = requestOptions
4952
self.logger = logger
5053
self.connectionDeadline = connectionDeadline
5154
self.preferredEventLoop = preferredEventLoop
5255

53-
let validatedRequest = try request.validate()
54-
self.poolKey = validatedRequest.poolKey
55-
self.requestHead = validatedRequest.head
56-
self.requestFramingMetadata = validatedRequest.requestFramingMetadata
57-
5856
self.state.registerContinuation(responseContinuation)
5957
}
6058

@@ -121,7 +119,7 @@ class AsyncRequestBag {
121119
case .none:
122120
break
123121
case .resumeStream(let allocator):
124-
switch self.request.body?.mode {
122+
switch self.requestBody?.mode {
125123
case .asyncSequence(_, let next):
126124
// it is safe to call this async here. it dispatches...
127125
self.continueRequestBodyStream(allocator, next: next)

Sources/AsyncHTTPClient/AsyncAwait/HTTPClientRequest.swift

+3-3
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,7 @@ extension HTTPClient {
226226
}
227227
}
228228
}
229+
let preparedRequest = try request.prepareForExecution()
229230

230231
let cancelHandler = SwiftCancellationHandlingSucksAsFuck()
231232

@@ -234,9 +235,8 @@ extension HTTPClient {
234235
return try await withTaskCancellationHandler(operation: { () async throws -> HTTPClientResponse in
235236
try await withUnsafeThrowingContinuation{
236237
(continuation: UnsafeContinuation<HTTPClientResponse, Swift.Error>) -> Void in
237-
238-
let bag = try! AsyncRequestBag(
239-
request: request,
238+
let bag = AsyncRequestBag(
239+
request: preparedRequest,
240240
requestOptions: .init(idleReadTimeout: nil),
241241
logger: logger,
242242
connectionDeadline: .now() + .seconds(10),

Tests/AsyncHTTPClientTests/AsyncRequestTests.swift

+74-67
Original file line numberDiff line numberDiff line change
@@ -21,38 +21,28 @@ import XCTest
2121
@available(macOS 12.0, iOS 15.0, watchOS 8.0, tvOS 15.0, *)
2222
final class AsyncRequestTests: XCTestCase {
2323
func testCancelAsyncRequest() async {
24-
let logger = Logger(label: "test")
2524
let embeddedEventLoop = EmbeddedEventLoop()
2625
defer { XCTAssertNoThrow(try embeddedEventLoop.syncShutdownGracefully()) }
2726

2827
var request = HTTPClientRequest(url: "https://localhost/")
2928
request.method = .GET
30-
31-
var maybeRequestBag: AsyncRequestBag?
32-
var requestBagPromise = embeddedEventLoop.makePromise(of: AsyncRequestBag.self)
33-
async let result = withUnsafeContinuation { (continuation: UnsafeContinuation<HTTPClientResponse, Error>) in
34-
let requestBag = try AsyncRequestBag(
35-
request: request,
36-
requestOptions: .forTests(),
37-
logger: logger,
38-
connectionDeadline: .distantFuture,
39-
preferredEventLoop: embeddedEventLoop
40-
)
41-
requestBagPromise.succeed(requestBag)
29+
var maybePreparedRequest: PreparedRequest?
30+
XCTAssertNoThrow(maybePreparedRequest = try request.prepareForExecution())
31+
guard let preparedRequest = maybePreparedRequest else {
32+
return
4233
}
43-
44-
45-
46-
// guard let requestBag = maybeRequestBag else { return XCTFail("unexpectedly found nil") }
47-
34+
let (requestBag, responseTask) = AsyncRequestBag.makeWithResultTask(
35+
request: preparedRequest,
36+
preferredEventLoop: embeddedEventLoop
37+
)
4838

4939
Task.detached {
5040
try await Task.sleep(nanoseconds: 5 * 1000 * 1000)
5141
requestBag.cancel()
5242
}
5343

5444
do {
55-
_ = try await result
45+
_ = try await responseTask.result.get()
5646
XCTFail("Expected to throw error")
5747
} catch {
5848
XCTAssertEqual(error as? HTTPClientError, .cancelled)
@@ -62,24 +52,19 @@ final class AsyncRequestTests: XCTestCase {
6252
func testResponseStreamingWorks() async {
6353
let embeddedEventLoop = EmbeddedEventLoop()
6454
defer { XCTAssertNoThrow(try embeddedEventLoop.syncShutdownGracefully()) }
65-
let logger = Logger(label: "test")
6655

6756
var request = HTTPClientRequest(url: "https://localhost/")
6857
request.method = .GET
69-
70-
var maybeRequestBag: AsyncRequestBag?
71-
XCTAssertNoThrow(maybeRequestBag = try AsyncRequestBag(
72-
request: request,
73-
requestOptions: .forTests(),
74-
logger: logger,
75-
connectionDeadline: .distantFuture,
76-
preferredEventLoop: embeddedEventLoop
77-
))
78-
guard let requestBag = maybeRequestBag else { return XCTFail("unexpectedly found nil") }
7958

80-
81-
async let awaitableResponse = requestBag.result()
82-
await Task.yield()
59+
var maybePreparedRequest: PreparedRequest?
60+
XCTAssertNoThrow(maybePreparedRequest = try request.prepareForExecution())
61+
guard let preparedRequest = maybePreparedRequest else {
62+
return
63+
}
64+
let (requestBag, responseTask) = AsyncRequestBag.makeWithResultTask(
65+
request: preparedRequest,
66+
preferredEventLoop: embeddedEventLoop
67+
)
8368

8469
let executor = MockRequestExecutor(
8570
pauseRequestBodyPartStreamAfterASingleWrite: true,
@@ -94,7 +79,7 @@ final class AsyncRequestTests: XCTestCase {
9479
requestBag.receiveResponseHead(responseHead)
9580

9681
do {
97-
let response = try await awaitableResponse
82+
let response = try await responseTask.result.get()
9883
XCTAssertEqual(response.status, responseHead.status)
9984
XCTAssertEqual(response.headers, responseHead.headers)
10085
XCTAssertEqual(response.version, responseHead.version)
@@ -133,7 +118,6 @@ final class AsyncRequestTests: XCTestCase {
133118
func testWriteBackpressureWorks() async {
134119
let embeddedEventLoop = EmbeddedEventLoop()
135120
defer { XCTAssertNoThrow(try embeddedEventLoop.syncShutdownGracefully()) }
136-
let logger = Logger(label: "test")
137121

138122
let streamWriter = AsyncSequenceWriter()
139123
if await streamWriter.hasDemand { XCTFail("Did not expect to have a demand at this point") }
@@ -142,20 +126,15 @@ final class AsyncRequestTests: XCTestCase {
142126
request.method = .POST
143127
request.body = .stream(streamWriter)
144128

145-
var maybeRequestBag: AsyncRequestBag?
146-
XCTAssertNoThrow(maybeRequestBag = try AsyncRequestBag(
147-
request: request,
148-
requestOptions: .forTests(),
149-
logger: logger,
150-
connectionDeadline: .distantFuture,
129+
var maybePreparedRequest: PreparedRequest?
130+
XCTAssertNoThrow(maybePreparedRequest = try request.prepareForExecution())
131+
guard let preparedRequest = maybePreparedRequest else {
132+
return
133+
}
134+
let (requestBag, responseTask) = AsyncRequestBag.makeWithResultTask(
135+
request: preparedRequest,
151136
preferredEventLoop: embeddedEventLoop
152-
))
153-
guard let requestBag = maybeRequestBag else { return XCTFail("unexpectedly found nil") }
154-
155-
async let awaitableResponse = requestBag.result()
156-
157-
// we need to yield here, to ensure the continuation can be build up
158-
await Task.yield()
137+
)
159138

160139
let executor = MockRequestExecutor(eventLoop: embeddedEventLoop)
161140

@@ -201,7 +180,7 @@ final class AsyncRequestTests: XCTestCase {
201180
XCTAssertFalse(executor.signalledDemandForResponseBody)
202181
requestBag.receiveResponseHead(responseHead)
203182

204-
let response = try await awaitableResponse
183+
let response = try await responseTask.result.get()
205184
XCTAssertEqual(response.status, responseHead.status)
206185
XCTAssertEqual(response.headers, responseHead.headers)
207186
XCTAssertEqual(response.version, responseHead.version)
@@ -245,22 +224,21 @@ final class AsyncRequestTests: XCTestCase {
245224
var request = HTTPClientRequest(url: "https://localhost:\(httpBin.port)/")
246225
request.headers = ["host": "localhost:\(httpBin.port)"]
247226

248-
let requestBag = try AsyncRequestBag(
249-
request: request,
250-
requestOptions: .forTests(),
251-
logger: Logger(label: "test"),
252-
connectionDeadline: .distantFuture,
227+
var maybePreparedRequest: PreparedRequest?
228+
XCTAssertNoThrow(maybePreparedRequest = try request.prepareForExecution())
229+
guard let preparedRequest = maybePreparedRequest else {
230+
return
231+
}
232+
let (requestBag, responseTask) = AsyncRequestBag.makeWithResultTask(
233+
request: preparedRequest,
253234
preferredEventLoop: eventLoopGroup.next()
254235
)
255236

256-
async let awaitableResponse = requestBag.result()
257-
await Task.yield()
258-
259237
http2Connection.executeRequest(requestBag)
260238

261239
XCTAssertEqual(delegate.hitStreamClosed, 0)
262240

263-
let response = try await awaitableResponse
241+
let response = try await responseTask.result.get()
264242

265243
XCTAssertEqual(response.status, .ok)
266244
XCTAssertEqual(response.version, .http2)
@@ -306,22 +284,21 @@ final class AsyncRequestTests: XCTestCase {
306284
request.headers = ["host": "localhost:\(httpBin.port)"]
307285
request.body = .stream(length: 800, streamWriter)
308286

309-
let requestBag = try AsyncRequestBag(
310-
request: request,
311-
requestOptions: .forTests(),
312-
logger: Logger(label: "test"),
313-
connectionDeadline: .distantFuture,
287+
var maybePreparedRequest: PreparedRequest?
288+
XCTAssertNoThrow(maybePreparedRequest = try request.prepareForExecution())
289+
guard let preparedRequest = maybePreparedRequest else {
290+
return
291+
}
292+
let (requestBag, responseTask) = AsyncRequestBag.makeWithResultTask(
293+
request: preparedRequest,
314294
preferredEventLoop: eventLoopGroup.next()
315295
)
316296

317-
async let awaitableResponse = requestBag.result()
318-
await Task.yield() // yield is used here to ensure register continuation is executed here.
319-
320297
http2Connection.executeRequest(requestBag)
321298

322299
XCTAssertEqual(delegate.hitStreamClosed, 0)
323300

324-
let response = try await awaitableResponse
301+
let response = try await responseTask.result.get()
325302

326303
XCTAssertEqual(response.status, .ok)
327304
XCTAssertEqual(response.version, .http2)
@@ -506,3 +483,33 @@ actor AsyncSequenceWriter: AsyncSequence {
506483
}
507484
}
508485
}
486+
487+
@available(macOS 12.0, iOS 15.0, watchOS 8.0, tvOS 15.0, *)
488+
extension AsyncRequestBag {
489+
fileprivate static func makeWithResultTask(
490+
request: PreparedRequest,
491+
requestOptions: RequestOptions = .forTests(),
492+
logger: Logger = Logger(label: "test"),
493+
connectionDeadline: NIODeadline = .distantFuture,
494+
preferredEventLoop: EventLoop
495+
) -> (AsyncRequestBag, _Concurrency.Task<HTTPClientResponse, Error>) {
496+
let requestBagPromise = preferredEventLoop.makePromise(of: AsyncRequestBag.self)
497+
let result = Task {
498+
try await withUnsafeThrowingContinuation { (continuation: UnsafeContinuation<HTTPClientResponse, Error>) in
499+
let requestBag = AsyncRequestBag(
500+
request: request,
501+
requestOptions: requestOptions,
502+
logger: logger,
503+
connectionDeadline: connectionDeadline,
504+
preferredEventLoop: preferredEventLoop,
505+
responseContinuation: continuation
506+
)
507+
requestBagPromise.succeed(requestBag)
508+
}
509+
}
510+
// the promise can never fail and it is therefore safe to force unwrap
511+
let requestBag = try! requestBagPromise.futureResult.wait()
512+
513+
return (requestBag, result)
514+
}
515+
}

0 commit comments

Comments
 (0)