Skip to content

Prepare async/await API for public release #531

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Jan 14, 2022
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,11 @@ extension RequestBodyLength {
init(_ body: HTTPClientRequest.Body?) {
switch body?.mode {
case .none:
self = .fixed(length: 0)
self = .known(0)
case .byteBuffer(let buffer):
self = .fixed(length: buffer.readableBytes)
case .sequence(nil, _, _), .asyncSequence(nil, _):
self = .dynamic
case .sequence(.some(let length), _, _), .asyncSequence(.some(let length), _):
self = .fixed(length: length)
self = .known(buffer.readableBytes)
case .sequence(let length, _, _), .asyncSequence(let length, _):
self = length
}
}
}
Expand Down
69 changes: 48 additions & 21 deletions Sources/AsyncHTTPClient/AsyncAwait/HTTPClientRequest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -35,32 +35,37 @@ struct HTTPClientRequest {
@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *)
extension HTTPClientRequest {
struct Body {
@usableFromInline
internal enum Mode {
case asyncSequence(length: Int?, (ByteBufferAllocator) async throws -> ByteBuffer?)
case sequence(length: Int?, canBeConsumedMultipleTimes: Bool, (ByteBufferAllocator) -> ByteBuffer)
case asyncSequence(length: RequestBodyLength, (ByteBufferAllocator) async throws -> ByteBuffer?)
case sequence(length: RequestBodyLength, canBeConsumedMultipleTimes: Bool, (ByteBufferAllocator) -> ByteBuffer)
case byteBuffer(ByteBuffer)
}

var mode: Mode
@usableFromInline
internal var mode: Mode

private init(_ mode: Mode) {
@inlinable
internal init(_ mode: Mode) {
self.mode = mode
}
}
}

@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *)
extension HTTPClientRequest.Body {
static func byteBuffer(_ byteBuffer: ByteBuffer) -> Self {
static func bytes(_ byteBuffer: ByteBuffer) -> Self {
self.init(.byteBuffer(byteBuffer))
}

@inlinable
static func bytes<Bytes: Sequence>(
length: Int?,
static func bytes<Bytes: RandomAccessCollection>(
_ bytes: Bytes
) -> Self where Bytes.Element == UInt8 {
self.init(.sequence(length: length, canBeConsumedMultipleTimes: false) { allocator in
self.init(.sequence(
length: .known(bytes.count),
canBeConsumedMultipleTimes: true
) { allocator in
if let buffer = bytes.withContiguousStorageIfAvailable({ allocator.buffer(bytes: $0) }) {
// fastpath
return buffer
Expand All @@ -71,11 +76,14 @@ extension HTTPClientRequest.Body {
}

@inlinable
static func bytes<Bytes: Collection>(
length: Int?,
_ bytes: Bytes
static func bytes<Bytes: Sequence>(
_ bytes: Bytes,
length: Length
) -> Self where Bytes.Element == UInt8 {
self.init(.sequence(length: length, canBeConsumedMultipleTimes: true) { allocator in
self.init(.sequence(
length: length.storage,
canBeConsumedMultipleTimes: false
) { allocator in
if let buffer = bytes.withContiguousStorageIfAvailable({ allocator.buffer(bytes: $0) }) {
// fastpath
return buffer
Expand All @@ -86,10 +94,14 @@ extension HTTPClientRequest.Body {
}

@inlinable
static func bytes<Bytes: RandomAccessCollection>(
_ bytes: Bytes
static func bytes<Bytes: Collection>(
_ bytes: Bytes,
length: Length
) -> Self where Bytes.Element == UInt8 {
self.init(.sequence(length: bytes.count, canBeConsumedMultipleTimes: true) { allocator in
self.init(.sequence(
length: length.storage,
canBeConsumedMultipleTimes: true
) { allocator in
if let buffer = bytes.withContiguousStorageIfAvailable({ allocator.buffer(bytes: $0) }) {
// fastpath
return buffer
Expand All @@ -101,23 +113,23 @@ extension HTTPClientRequest.Body {

@inlinable
static func stream<SequenceOfBytes: AsyncSequence>(
length: Int?,
_ sequenceOfBytes: SequenceOfBytes
_ sequenceOfBytes: SequenceOfBytes,
length: Length
) -> Self where SequenceOfBytes.Element == ByteBuffer {
var iterator = sequenceOfBytes.makeAsyncIterator()
let body = self.init(.asyncSequence(length: length) { _ -> ByteBuffer? in
let body = self.init(.asyncSequence(length: length.storage) { _ -> ByteBuffer? in
try await iterator.next()
})
return body
}

@inlinable
static func stream<Bytes: AsyncSequence>(
length: Int?,
_ bytes: Bytes
_ bytes: Bytes,
length: Length
) -> Self where Bytes.Element == UInt8 {
var iterator = bytes.makeAsyncIterator()
let body = self.init(.asyncSequence(length: length) { allocator -> ByteBuffer? in
let body = self.init(.asyncSequence(length: length.storage) { allocator -> ByteBuffer? in
var buffer = allocator.buffer(capacity: 1024) // TODO: Magic number
while buffer.writableBytes > 0, let byte = try await iterator.next() {
buffer.writeInteger(byte)
Expand All @@ -143,4 +155,19 @@ extension Optional where Wrapped == HTTPClientRequest.Body {
}
}

@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *)
extension HTTPClientRequest.Body {
struct Length {
/// size of the request body is not known before starting the request
static let unknown: Self = .init(storage: .unknown)
/// size of the request body is fixed and exactly `count` bytes
static func known(_ count: Int) -> Self {
.init(storage: .known(count))
}

@usableFromInline
internal var storage: RequestBodyLength
}
}

#endif
13 changes: 5 additions & 8 deletions Sources/AsyncHTTPClient/AsyncAwait/HTTPClientResponse.swift
Original file line number Diff line number Diff line change
Expand Up @@ -48,25 +48,22 @@ struct HTTPClientResponse {

@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *)
extension HTTPClientResponse.Body: AsyncSequence {
typealias Element = ByteBuffer
typealias AsyncIterator = Iterator

struct Iterator: AsyncIteratorProtocol {
typealias Element = ByteBuffer
typealias Element = AsyncIterator.Element

struct AsyncIterator: AsyncIteratorProtocol {
private let stream: IteratorStream

fileprivate init(stream: IteratorStream) {
self.stream = stream
}

func next() async throws -> ByteBuffer? {
mutating func next() async throws -> ByteBuffer? {
try await self.stream.next()
}
}

func makeAsyncIterator() -> Iterator {
Iterator(stream: IteratorStream(bag: self.bag))
func makeAsyncIterator() -> AsyncIterator {
AsyncIterator(stream: IteratorStream(bag: self.bag))
}
}

Expand Down
10 changes: 6 additions & 4 deletions Sources/AsyncHTTPClient/ConnectionPool/RequestBodyLength.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
//
//===----------------------------------------------------------------------===//

enum RequestBodyLength: Hashable {
/// - Note: use `HTTPClientRequest.Body.Length` if you want to expose `RequestBodyLength` publicly
@usableFromInline
internal enum RequestBodyLength: Hashable {
/// size of the request body is not known before starting the request
case dynamic
/// size of the request body is fixed and exactly `length` bytes
case fixed(length: Int)
case unknown
/// size of the request body is fixed and exactly `count` bytes
case known(_ count: Int)
}
6 changes: 3 additions & 3 deletions Sources/AsyncHTTPClient/HTTPHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -690,13 +690,13 @@ internal struct RedirectHandler<ResponseType> {
extension RequestBodyLength {
init(_ body: HTTPClient.Body?) {
guard let body = body else {
self = .fixed(length: 0)
self = .known(0)
return
}
guard let length = body.length else {
self = .dynamic
self = .unknown
return
}
self = .fixed(length: length)
self = .known(length)
}
}
14 changes: 7 additions & 7 deletions Sources/AsyncHTTPClient/RequestValidation.swift
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ extension HTTPHeaders {

if case .TRACE = method {
switch bodyLength {
case .fixed(length: 0):
case .known(0):
break
case .dynamic, .fixed:
case .unknown, .known:
// A client MUST NOT send a message body in a TRACE request.
// https://tools.ietf.org/html/rfc7230#section-4.3.8
throw HTTPClientError.traceRequestWithBody
Expand All @@ -37,9 +37,9 @@ extension HTTPHeaders {

let connectionClose = self[canonicalForm: "connection"].lazy.map { $0.lowercased() }.contains("close")
switch bodyLength {
case .dynamic:
case .unknown:
return .init(connectionClose: connectionClose, body: .stream)
case .fixed(let length):
case .known(let length):
return .init(connectionClose: connectionClose, body: .fixedSize(length))
}
}
Expand Down Expand Up @@ -88,7 +88,7 @@ extension HTTPHeaders {
self.remove(name: "Transfer-Encoding")

switch bodyLength {
case .fixed(0):
case .known(0):
// if we don't have a body we might not need to send the Content-Length field
// https://tools.ietf.org/html/rfc7230#section-3.3.2
switch method {
Expand All @@ -103,9 +103,9 @@ extension HTTPHeaders {
// for an enclosed payload body.
self.add(name: "Content-Length", value: "0")
}
case .fixed(let length):
case .known(let length):
self.add(name: "Content-Length", value: String(length))
case .dynamic:
case .unknown:
self.add(name: "Transfer-Encoding", value: "chunked")
}
}
Expand Down
38 changes: 23 additions & 15 deletions Tests/AsyncHTTPClientTests/AsyncAwaitEndToEndTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ final class AsyncAwaitEndToEndTests: XCTestCase {
let logger = Logger(label: "HTTPClient", factory: StreamLogHandler.standardOutput(label:))
var request = HTTPClientRequest(url: "https://localhost:\(bin.port)/")
request.method = .POST
request.body = .byteBuffer(ByteBuffer(string: "1234"))
request.body = .bytes(ByteBuffer(string: "1234"))

guard let response = await XCTAssertNoThrowWithResult(
try await client.execute(request, deadline: .now() + .seconds(10), logger: logger)
Expand All @@ -115,7 +115,7 @@ final class AsyncAwaitEndToEndTests: XCTestCase {
let logger = Logger(label: "HTTPClient", factory: StreamLogHandler.standardOutput(label:))
var request = HTTPClientRequest(url: "https://localhost:\(bin.port)/")
request.method = .POST
request.body = .bytes(length: nil, AnySequence("1234".utf8))
request.body = .bytes(AnySequence("1234".utf8), length: .unknown)

guard let response = await XCTAssertNoThrowWithResult(
try await client.execute(request, deadline: .now() + .seconds(10), logger: logger)
Expand All @@ -140,7 +140,7 @@ final class AsyncAwaitEndToEndTests: XCTestCase {
let logger = Logger(label: "HTTPClient", factory: StreamLogHandler.standardOutput(label:))
var request = HTTPClientRequest(url: "https://localhost:\(bin.port)/")
request.method = .POST
request.body = .bytes(length: nil, AnyCollection("1234".utf8))
request.body = .bytes(AnyCollection("1234".utf8), length: .unknown)

guard let response = await XCTAssertNoThrowWithResult(
try await client.execute(request, deadline: .now() + .seconds(10), logger: logger)
Expand Down Expand Up @@ -190,11 +190,11 @@ final class AsyncAwaitEndToEndTests: XCTestCase {
let logger = Logger(label: "HTTPClient", factory: StreamLogHandler.standardOutput(label:))
var request = HTTPClientRequest(url: "https://localhost:\(bin.port)/")
request.method = .POST
request.body = .stream(length: nil, [
request.body = .stream([
ByteBuffer(string: "1"),
ByteBuffer(string: "2"),
ByteBuffer(string: "34"),
].asAsyncSequence())
].asAsyncSequence(), length: .unknown)

guard let response = await XCTAssertNoThrowWithResult(
try await client.execute(request, deadline: .now() + .seconds(10), logger: logger)
Expand All @@ -219,7 +219,7 @@ final class AsyncAwaitEndToEndTests: XCTestCase {
let logger = Logger(label: "HTTPClient", factory: StreamLogHandler.standardOutput(label:))
var request = HTTPClientRequest(url: "https://localhost:\(bin.port)/")
request.method = .POST
request.body = .stream(length: nil, "1234".utf8.asAsyncSequence())
request.body = .stream("1234".utf8.asAsyncSequence(), length: .unknown)

guard let response = await XCTAssertNoThrowWithResult(
try await client.execute(request, deadline: .now() + .seconds(10), logger: logger)
Expand All @@ -245,7 +245,7 @@ final class AsyncAwaitEndToEndTests: XCTestCase {
var request = HTTPClientRequest(url: "https://localhost:\(bin.port)/")
request.method = .POST
let streamWriter = AsyncSequenceWriter<ByteBuffer>()
request.body = .stream(length: nil, streamWriter)
request.body = .stream(streamWriter, length: .unknown)

guard let response = await XCTAssertNoThrowWithResult(
try await client.execute(request, deadline: .now() + .seconds(10), logger: logger)
Expand All @@ -257,7 +257,7 @@ final class AsyncAwaitEndToEndTests: XCTestCase {
ByteBuffer(string: "2"),
ByteBuffer(string: "34"),
]
let bodyIterator = response.body.makeAsyncIterator()
var bodyIterator = response.body.makeAsyncIterator()
for expectedFragment in fragments {
streamWriter.write(expectedFragment)
guard let actualFragment = await XCTAssertNoThrowWithResult(
Expand Down Expand Up @@ -287,7 +287,7 @@ final class AsyncAwaitEndToEndTests: XCTestCase {
var request = HTTPClientRequest(url: "https://localhost:\(bin.port)/")
request.method = .POST
let streamWriter = AsyncSequenceWriter<ByteBuffer>()
request.body = .stream(length: nil, streamWriter)
request.body = .stream(streamWriter, length: .unknown)

guard let response = await XCTAssertNoThrowWithResult(
try await client.execute(request, deadline: .now() + .seconds(10), logger: logger)
Expand All @@ -300,7 +300,7 @@ final class AsyncAwaitEndToEndTests: XCTestCase {
ByteBuffer(string: String(repeating: "c", count: 4000)),
ByteBuffer(string: String(repeating: "d", count: 4000)),
]
let bodyIterator = response.body.makeAsyncIterator()
var bodyIterator = response.body.makeAsyncIterator()
for expectedFragment in fragments {
streamWriter.write(expectedFragment)
guard let actualFragment = await XCTAssertNoThrowWithResult(
Expand Down Expand Up @@ -330,7 +330,7 @@ final class AsyncAwaitEndToEndTests: XCTestCase {
var request = HTTPClientRequest(url: "http://localhost:\(bin.port)/offline")
request.method = .POST
let streamWriter = AsyncSequenceWriter<ByteBuffer>()
request.body = .stream(length: nil, streamWriter)
request.body = .stream(streamWriter, length: .unknown)

let task = Task<HTTPClientResponse, Error> { [request] in
try await client.execute(request, deadline: .now() + .seconds(2), logger: logger)
Expand All @@ -357,8 +357,12 @@ final class AsyncAwaitEndToEndTests: XCTestCase {
let task = Task<HTTPClientResponse, Error> { [request] in
try await client.execute(request, deadline: .now() + .milliseconds(100), logger: logger)
}
await XCTAssertThrowsError(try await task.value) {
XCTAssertEqual($0 as? HTTPClientError, HTTPClientError.deadlineExceeded)
await XCTAssertThrowsError(try await task.value) { error in
guard let error = error as? HTTPClientError else {
return XCTFail("unexpected error \(error)")
}
// a race between deadline and connect timer can result in either error
XCTAssertTrue([.deadlineExceeded, .connectTimeout].contains(error))
}
}
#endif
Expand All @@ -378,8 +382,12 @@ final class AsyncAwaitEndToEndTests: XCTestCase {
let task = Task<HTTPClientResponse, Error> { [request] in
try await client.execute(request, deadline: .now(), logger: logger)
}
await XCTAssertThrowsError(try await task.value) {
XCTAssertEqual($0 as? HTTPClientError, HTTPClientError.deadlineExceeded)
await XCTAssertThrowsError(try await task.value) { error in
guard let error = error as? HTTPClientError else {
return XCTFail("unexpected error \(error)")
}
// a race between deadline and connect timer can result in either error
XCTAssertTrue([.deadlineExceeded, .connectTimeout].contains(error))
}
}
#endif
Expand Down
Loading