Skip to content
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,11 @@ extension RequestBodyLength {
init(_ body: HTTPClientRequest.Body?) {
switch body?.mode {
case .none:
self = .fixed(length: 0)
self = .known(0)
case .byteBuffer(let buffer):
self = .fixed(length: buffer.readableBytes)
case .sequence(nil, _, _), .asyncSequence(nil, _):
self = .dynamic
case .sequence(.some(let length), _, _), .asyncSequence(.some(let length), _):
self = .fixed(length: length)
self = .known(buffer.readableBytes)
case .sequence(let length, _, _), .asyncSequence(let length, _):
self = length
}
}
}
Expand Down
69 changes: 48 additions & 21 deletions Sources/AsyncHTTPClient/AsyncAwait/HTTPClientRequest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -35,32 +35,37 @@ struct HTTPClientRequest {
@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *)
extension HTTPClientRequest {
struct Body {
@usableFromInline
internal enum Mode {
case asyncSequence(length: Int?, (ByteBufferAllocator) async throws -> ByteBuffer?)
case sequence(length: Int?, canBeConsumedMultipleTimes: Bool, (ByteBufferAllocator) -> ByteBuffer)
case asyncSequence(length: RequestBodyLength, (ByteBufferAllocator) async throws -> ByteBuffer?)
case sequence(length: RequestBodyLength, canBeConsumedMultipleTimes: Bool, (ByteBufferAllocator) -> ByteBuffer)
case byteBuffer(ByteBuffer)
}

var mode: Mode
@usableFromInline
internal var mode: Mode

private init(_ mode: Mode) {
@inlinable
internal init(_ mode: Mode) {
self.mode = mode
}
}
}

@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *)
extension HTTPClientRequest.Body {
static func byteBuffer(_ byteBuffer: ByteBuffer) -> Self {
static func bytes(_ byteBuffer: ByteBuffer) -> Self {
self.init(.byteBuffer(byteBuffer))
}

@inlinable
static func bytes<Bytes: Sequence>(
length: Int?,
static func bytes<Bytes: RandomAccessCollection>(
_ bytes: Bytes
) -> Self where Bytes.Element == UInt8 {
self.init(.sequence(length: length, canBeConsumedMultipleTimes: false) { allocator in
self.init(.sequence(
length: .known(bytes.count),
canBeConsumedMultipleTimes: true
) { allocator in
if let buffer = bytes.withContiguousStorageIfAvailable({ allocator.buffer(bytes: $0) }) {
// fastpath
return buffer
Expand All @@ -71,11 +76,14 @@ extension HTTPClientRequest.Body {
}

@inlinable
static func bytes<Bytes: Collection>(
length: Int?,
_ bytes: Bytes
static func bytes<Bytes: Sequence>(
_ bytes: Bytes,
length: Length
) -> Self where Bytes.Element == UInt8 {
self.init(.sequence(length: length, canBeConsumedMultipleTimes: true) { allocator in
self.init(.sequence(
length: length.storage,
canBeConsumedMultipleTimes: false
) { allocator in
if let buffer = bytes.withContiguousStorageIfAvailable({ allocator.buffer(bytes: $0) }) {
// fastpath
return buffer
Expand All @@ -86,10 +94,14 @@ extension HTTPClientRequest.Body {
}

@inlinable
static func bytes<Bytes: RandomAccessCollection>(
_ bytes: Bytes
static func bytes<Bytes: Collection>(
_ bytes: Bytes,
length: Length
) -> Self where Bytes.Element == UInt8 {
self.init(.sequence(length: bytes.count, canBeConsumedMultipleTimes: true) { allocator in
self.init(.sequence(
length: length.storage,
canBeConsumedMultipleTimes: true
) { allocator in
if let buffer = bytes.withContiguousStorageIfAvailable({ allocator.buffer(bytes: $0) }) {
// fastpath
return buffer
Expand All @@ -101,23 +113,23 @@ extension HTTPClientRequest.Body {

@inlinable
static func stream<SequenceOfBytes: AsyncSequence>(
length: Int?,
_ sequenceOfBytes: SequenceOfBytes
_ sequenceOfBytes: SequenceOfBytes,
length: Length
) -> Self where SequenceOfBytes.Element == ByteBuffer {
var iterator = sequenceOfBytes.makeAsyncIterator()
let body = self.init(.asyncSequence(length: length) { _ -> ByteBuffer? in
let body = self.init(.asyncSequence(length: length.storage) { _ -> ByteBuffer? in
try await iterator.next()
})
return body
}

@inlinable
static func stream<Bytes: AsyncSequence>(
length: Int?,
_ bytes: Bytes
_ bytes: Bytes,
length: Length
) -> Self where Bytes.Element == UInt8 {
var iterator = bytes.makeAsyncIterator()
let body = self.init(.asyncSequence(length: length) { allocator -> ByteBuffer? in
let body = self.init(.asyncSequence(length: length.storage) { allocator -> ByteBuffer? in
var buffer = allocator.buffer(capacity: 1024) // TODO: Magic number
while buffer.writableBytes > 0, let byte = try await iterator.next() {
buffer.writeInteger(byte)
Expand All @@ -143,4 +155,19 @@ extension Optional where Wrapped == HTTPClientRequest.Body {
}
}

@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *)
extension HTTPClientRequest.Body {
struct Length {
/// size of the request body is not known before starting the request
static let unknown: Self = .init(storage: .unknown)
/// size of the request body is fixed and exactly `count` bytes
static func known(_ count: Int) -> Self {
.init(storage: .known(count))
}

@usableFromInline
internal var storage: RequestBodyLength
}
}

#endif
13 changes: 5 additions & 8 deletions Sources/AsyncHTTPClient/AsyncAwait/HTTPClientResponse.swift
Original file line number Diff line number Diff line change
Expand Up @@ -48,25 +48,22 @@ struct HTTPClientResponse {

@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *)
extension HTTPClientResponse.Body: AsyncSequence {
typealias Element = ByteBuffer
typealias AsyncIterator = Iterator

struct Iterator: AsyncIteratorProtocol {
typealias Element = ByteBuffer
typealias Element = AsyncIterator.Element

struct AsyncIterator: AsyncIteratorProtocol {
private let stream: IteratorStream

fileprivate init(stream: IteratorStream) {
self.stream = stream
}

func next() async throws -> ByteBuffer? {
mutating func next() async throws -> ByteBuffer? {
try await self.stream.next()
}
}

func makeAsyncIterator() -> Iterator {
Iterator(stream: IteratorStream(bag: self.bag))
func makeAsyncIterator() -> AsyncIterator {
AsyncIterator(stream: IteratorStream(bag: self.bag))
}
}

Expand Down
10 changes: 6 additions & 4 deletions Sources/AsyncHTTPClient/ConnectionPool/RequestBodyLength.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
//
//===----------------------------------------------------------------------===//

enum RequestBodyLength: Hashable {
/// - Note: use `HTTPClientRequest.Body.Length` if you want to expose `RequestBodyLength` publicly
@usableFromInline
internal enum RequestBodyLength: Hashable {
/// size of the request body is not known before starting the request
case dynamic
/// size of the request body is fixed and exactly `length` bytes
case fixed(length: Int)
case unknown
/// size of the request body is fixed and exactly `count` bytes
case known(_ count: Int)
}
6 changes: 3 additions & 3 deletions Sources/AsyncHTTPClient/HTTPHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -690,13 +690,13 @@ internal struct RedirectHandler<ResponseType> {
extension RequestBodyLength {
init(_ body: HTTPClient.Body?) {
guard let body = body else {
self = .fixed(length: 0)
self = .known(0)
return
}
guard let length = body.length else {
self = .dynamic
self = .unknown
return
}
self = .fixed(length: length)
self = .known(length)
}
}
14 changes: 7 additions & 7 deletions Sources/AsyncHTTPClient/RequestValidation.swift
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ extension HTTPHeaders {

if case .TRACE = method {
switch bodyLength {
case .fixed(length: 0):
case .known(0):
break
case .dynamic, .fixed:
case .unknown, .known:
// A client MUST NOT send a message body in a TRACE request.
// https://tools.ietf.org/html/rfc7230#section-4.3.8
throw HTTPClientError.traceRequestWithBody
Expand All @@ -37,9 +37,9 @@ extension HTTPHeaders {

let connectionClose = self[canonicalForm: "connection"].lazy.map { $0.lowercased() }.contains("close")
switch bodyLength {
case .dynamic:
case .unknown:
return .init(connectionClose: connectionClose, body: .stream)
case .fixed(let length):
case .known(let length):
return .init(connectionClose: connectionClose, body: .fixedSize(length))
}
}
Expand Down Expand Up @@ -88,7 +88,7 @@ extension HTTPHeaders {
self.remove(name: "Transfer-Encoding")

switch bodyLength {
case .fixed(0):
case .known(0):
// if we don't have a body we might not need to send the Content-Length field
// https://tools.ietf.org/html/rfc7230#section-3.3.2
switch method {
Expand All @@ -103,9 +103,9 @@ extension HTTPHeaders {
// for an enclosed payload body.
self.add(name: "Content-Length", value: "0")
}
case .fixed(let length):
case .known(let length):
self.add(name: "Content-Length", value: String(length))
case .dynamic:
case .unknown:
self.add(name: "Transfer-Encoding", value: "chunked")
}
}
Expand Down
38 changes: 23 additions & 15 deletions Tests/AsyncHTTPClientTests/AsyncAwaitEndToEndTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ final class AsyncAwaitEndToEndTests: XCTestCase {
let logger = Logger(label: "HTTPClient", factory: StreamLogHandler.standardOutput(label:))
var request = HTTPClientRequest(url: "https://localhost:\(bin.port)/")
request.method = .POST
request.body = .byteBuffer(ByteBuffer(string: "1234"))
request.body = .bytes(ByteBuffer(string: "1234"))

guard let response = await XCTAssertNoThrowWithResult(
try await client.execute(request, deadline: .now() + .seconds(10), logger: logger)
Expand All @@ -115,7 +115,7 @@ final class AsyncAwaitEndToEndTests: XCTestCase {
let logger = Logger(label: "HTTPClient", factory: StreamLogHandler.standardOutput(label:))
var request = HTTPClientRequest(url: "https://localhost:\(bin.port)/")
request.method = .POST
request.body = .bytes(length: nil, AnySequence("1234".utf8))
request.body = .bytes(AnySequence("1234".utf8), length: .unknown)

guard let response = await XCTAssertNoThrowWithResult(
try await client.execute(request, deadline: .now() + .seconds(10), logger: logger)
Expand All @@ -140,7 +140,7 @@ final class AsyncAwaitEndToEndTests: XCTestCase {
let logger = Logger(label: "HTTPClient", factory: StreamLogHandler.standardOutput(label:))
var request = HTTPClientRequest(url: "https://localhost:\(bin.port)/")
request.method = .POST
request.body = .bytes(length: nil, AnyCollection("1234".utf8))
request.body = .bytes(AnyCollection("1234".utf8), length: .unknown)

guard let response = await XCTAssertNoThrowWithResult(
try await client.execute(request, deadline: .now() + .seconds(10), logger: logger)
Expand Down Expand Up @@ -190,11 +190,11 @@ final class AsyncAwaitEndToEndTests: XCTestCase {
let logger = Logger(label: "HTTPClient", factory: StreamLogHandler.standardOutput(label:))
var request = HTTPClientRequest(url: "https://localhost:\(bin.port)/")
request.method = .POST
request.body = .stream(length: nil, [
request.body = .stream([
ByteBuffer(string: "1"),
ByteBuffer(string: "2"),
ByteBuffer(string: "34"),
].asAsyncSequence())
].asAsyncSequence(), length: .unknown)

guard let response = await XCTAssertNoThrowWithResult(
try await client.execute(request, deadline: .now() + .seconds(10), logger: logger)
Expand All @@ -219,7 +219,7 @@ final class AsyncAwaitEndToEndTests: XCTestCase {
let logger = Logger(label: "HTTPClient", factory: StreamLogHandler.standardOutput(label:))
var request = HTTPClientRequest(url: "https://localhost:\(bin.port)/")
request.method = .POST
request.body = .stream(length: nil, "1234".utf8.asAsyncSequence())
request.body = .stream("1234".utf8.asAsyncSequence(), length: .unknown)

guard let response = await XCTAssertNoThrowWithResult(
try await client.execute(request, deadline: .now() + .seconds(10), logger: logger)
Expand All @@ -245,7 +245,7 @@ final class AsyncAwaitEndToEndTests: XCTestCase {
var request = HTTPClientRequest(url: "https://localhost:\(bin.port)/")
request.method = .POST
let streamWriter = AsyncSequenceWriter<ByteBuffer>()
request.body = .stream(length: nil, streamWriter)
request.body = .stream(streamWriter, length: .unknown)

guard let response = await XCTAssertNoThrowWithResult(
try await client.execute(request, deadline: .now() + .seconds(10), logger: logger)
Expand All @@ -257,7 +257,7 @@ final class AsyncAwaitEndToEndTests: XCTestCase {
ByteBuffer(string: "2"),
ByteBuffer(string: "34"),
]
let bodyIterator = response.body.makeAsyncIterator()
var bodyIterator = response.body.makeAsyncIterator()
for expectedFragment in fragments {
streamWriter.write(expectedFragment)
guard let actualFragment = await XCTAssertNoThrowWithResult(
Expand Down Expand Up @@ -287,7 +287,7 @@ final class AsyncAwaitEndToEndTests: XCTestCase {
var request = HTTPClientRequest(url: "https://localhost:\(bin.port)/")
request.method = .POST
let streamWriter = AsyncSequenceWriter<ByteBuffer>()
request.body = .stream(length: nil, streamWriter)
request.body = .stream(streamWriter, length: .unknown)

guard let response = await XCTAssertNoThrowWithResult(
try await client.execute(request, deadline: .now() + .seconds(10), logger: logger)
Expand All @@ -300,7 +300,7 @@ final class AsyncAwaitEndToEndTests: XCTestCase {
ByteBuffer(string: String(repeating: "c", count: 4000)),
ByteBuffer(string: String(repeating: "d", count: 4000)),
]
let bodyIterator = response.body.makeAsyncIterator()
var bodyIterator = response.body.makeAsyncIterator()
for expectedFragment in fragments {
streamWriter.write(expectedFragment)
guard let actualFragment = await XCTAssertNoThrowWithResult(
Expand Down Expand Up @@ -330,7 +330,7 @@ final class AsyncAwaitEndToEndTests: XCTestCase {
var request = HTTPClientRequest(url: "http://localhost:\(bin.port)/offline")
request.method = .POST
let streamWriter = AsyncSequenceWriter<ByteBuffer>()
request.body = .stream(length: nil, streamWriter)
request.body = .stream(streamWriter, length: .unknown)

let task = Task<HTTPClientResponse, Error> { [request] in
try await client.execute(request, deadline: .now() + .seconds(2), logger: logger)
Expand All @@ -357,8 +357,12 @@ final class AsyncAwaitEndToEndTests: XCTestCase {
let task = Task<HTTPClientResponse, Error> { [request] in
try await client.execute(request, deadline: .now() + .milliseconds(100), logger: logger)
}
await XCTAssertThrowsError(try await task.value) {
XCTAssertEqual($0 as? HTTPClientError, HTTPClientError.deadlineExceeded)
await XCTAssertThrowsError(try await task.value) { error in
guard let error = error as? HTTPClientError else {
return XCTFail("unexpected error \(error)")
}
// a race between deadline and connect timer can result in either error
XCTAssertTrue([.deadlineExceeded, .connectTimeout].contains(error))
}
}
#endif
Expand All @@ -378,8 +382,12 @@ final class AsyncAwaitEndToEndTests: XCTestCase {
let task = Task<HTTPClientResponse, Error> { [request] in
try await client.execute(request, deadline: .now(), logger: logger)
}
await XCTAssertThrowsError(try await task.value) {
XCTAssertEqual($0 as? HTTPClientError, HTTPClientError.deadlineExceeded)
await XCTAssertThrowsError(try await task.value) { error in
guard let error = error as? HTTPClientError else {
return XCTFail("unexpected error \(error)")
}
// a race between deadline and connect timer can result in either error
XCTAssertTrue([.deadlineExceeded, .connectTimeout].contains(error))
}
}
#endif
Expand Down
Loading