Skip to content

Fix sendability issues in tests #841

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 1, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 28 additions & 8 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,36 @@

import PackageDescription

let strictConcurrencyDevelopment = false

let strictConcurrencySettings: [SwiftSetting] = {
var initialSettings: [SwiftSetting] = []
initialSettings.append(contentsOf: [
.enableUpcomingFeature("StrictConcurrency"),
.enableUpcomingFeature("InferSendableFromCaptures"),
])

if strictConcurrencyDevelopment {
// -warnings-as-errors here is a workaround so that IDE-based development can
// get tripped up on -require-explicit-sendable.
initialSettings.append(.unsafeFlags(["-Xfrontend", "-require-explicit-sendable", "-warnings-as-errors"]))
}

return initialSettings
}()

let package = Package(
name: "async-http-client",
products: [
.library(name: "AsyncHTTPClient", targets: ["AsyncHTTPClient"])
],
dependencies: [
.package(url: "https://github.com/apple/swift-nio.git", from: "2.78.0"),
.package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.27.1"),
.package(url: "https://github.com/apple/swift-nio-http2.git", from: "1.19.0"),
.package(url: "https://github.com/apple/swift-nio-extras.git", from: "1.13.0"),
.package(url: "https://github.com/apple/swift-nio-transport-services.git", from: "1.19.0"),
.package(url: "https://github.com/apple/swift-log.git", from: "1.4.4"),
.package(url: "https://github.com/apple/swift-nio.git", from: "2.81.0"),
.package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.30.0"),
.package(url: "https://github.com/apple/swift-nio-http2.git", from: "1.36.0"),
.package(url: "https://github.com/apple/swift-nio-extras.git", from: "1.26.0"),
.package(url: "https://github.com/apple/swift-nio-transport-services.git", from: "1.24.0"),
.package(url: "https://github.com/apple/swift-log.git", from: "1.6.0"),
.package(url: "https://github.com/apple/swift-atomics.git", from: "1.0.2"),
.package(url: "https://github.com/apple/swift-algorithms.git", from: "1.0.0"),
],
Expand Down Expand Up @@ -55,7 +73,8 @@ let package = Package(
.product(name: "Logging", package: "swift-log"),
.product(name: "Atomics", package: "swift-atomics"),
.product(name: "Algorithms", package: "swift-algorithms"),
]
],
swiftSettings: strictConcurrencySettings
),
.testTarget(
name: "AsyncHTTPClientTests",
Expand All @@ -79,7 +98,8 @@ let package = Package(
.copy("Resources/self_signed_key.pem"),
.copy("Resources/example.com.cert.pem"),
.copy("Resources/example.com.private-key.pem"),
]
],
swiftSettings: strictConcurrencySettings
),
]
)
Expand Down
8 changes: 5 additions & 3 deletions Tests/AsyncHTTPClientTests/AsyncAwaitEndToEndTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,8 @@ final class AsyncAwaitEndToEndTests: XCTestCase {
}

func testConnectTimeout() {
let serverGroup = self.serverGroup!
let clientGroup = self.clientGroup!
XCTAsyncTest(timeout: 60) {
#if os(Linux)
// 198.51.100.254 is reserved for documentation only and therefore should not accept any TCP connection
Expand All @@ -542,7 +544,7 @@ final class AsyncAwaitEndToEndTests: XCTestCase {
XCTAssertNoThrow(try group.syncShutdownGracefully())
}

let serverChannel = try await ServerBootstrap(group: self.serverGroup)
let serverChannel = try await ServerBootstrap(group: serverGroup)
.serverChannelOption(ChannelOptions.backlog, value: 1)
.serverChannelOption(ChannelOptions.autoRead, value: false)
.bind(host: "127.0.0.1", port: 0)
Expand All @@ -551,7 +553,7 @@ final class AsyncAwaitEndToEndTests: XCTestCase {
XCTAssertNoThrow(try serverChannel.close().wait())
}
let port = serverChannel.localAddress!.port!
let firstClientChannel = try await ClientBootstrap(group: self.serverGroup)
let firstClientChannel = try await ClientBootstrap(group: serverGroup)
.connect(host: "127.0.0.1", port: port)
.get()
defer {
Expand All @@ -561,7 +563,7 @@ final class AsyncAwaitEndToEndTests: XCTestCase {
#endif

let httpClient = HTTPClient(
eventLoopGroupProvider: .shared(self.clientGroup),
eventLoopGroupProvider: .shared(clientGroup),
configuration: .init(timeout: .init(connect: .milliseconds(100), read: .milliseconds(150)))
)

Expand Down
2 changes: 1 addition & 1 deletion Tests/AsyncHTTPClientTests/AsyncTestHelpers.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import NIOCore

/// ``AsyncSequenceWriter`` is `Sendable` because its state is protected by a Lock
@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *)
final class AsyncSequenceWriter<Element>: AsyncSequence, @unchecked Sendable {
final class AsyncSequenceWriter<Element: Sendable>: AsyncSequence, @unchecked Sendable {
typealias AsyncIterator = Iterator

struct Iterator: AsyncIteratorProtocol {
Expand Down
70 changes: 40 additions & 30 deletions Tests/AsyncHTTPClientTests/HTTP1ClientChannelHandlerTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
//===----------------------------------------------------------------------===//

import Logging
import NIOConcurrencyHelpers
import NIOCore
import NIOEmbedded
import NIOHTTP1
Expand Down Expand Up @@ -833,10 +834,11 @@ class HTTP1ClientChannelHandlerTests: XCTestCase {
)
try channel.connect(to: .init(ipAddress: "127.0.0.1", port: 80)).wait()

let request = MockHTTPExecutableRequest()
// non empty body is important to trigger this bug as we otherwise finish the request in a single flush
request.requestFramingMetadata.body = .fixedSize(1)
request.raiseErrorIfUnimplementedMethodIsCalled = false
let request = MockHTTPExecutableRequest(
framingMetadata: RequestFramingMetadata(connectionClose: false, body: .fixedSize(1)),
raiseErrorIfUnimplementedMethodIsCalled: false
)
channel.writeAndFlush(request, promise: nil)
XCTAssertEqual(request.events.map(\.kind), [.willExecuteRequest, .requestHeadSent])
}
Expand Down Expand Up @@ -897,34 +899,43 @@ class HTTP1ClientChannelHandlerTests: XCTestCase {
}
}

class TestBackpressureWriter {
final class TestBackpressureWriter: Sendable {
let eventLoop: EventLoop

let parts: Int

var finishFuture: EventLoopFuture<Void> { self.finishPromise.futureResult }
private let finishPromise: EventLoopPromise<Void>
private(set) var written: Int = 0

private var channelIsWritable: Bool = false
private struct State {
var written = 0
var channelIsWritable = false
}

var written: Int {
self.state.value.written
}

private let state: NIOLoopBoundBox<State>

init(eventLoop: EventLoop, parts: Int) {
self.eventLoop = eventLoop
self.parts = parts

self.state = .makeBoxSendingValue(State(), eventLoop: eventLoop)
self.finishPromise = eventLoop.makePromise(of: Void.self)
}

func start(writer: HTTPClient.Body.StreamWriter, expectedErrors: [HTTPClientError] = []) -> EventLoopFuture<Void> {
@Sendable
func recursive() {
XCTAssert(self.eventLoop.inEventLoop)
XCTAssert(self.channelIsWritable)
if self.written == self.parts {
XCTAssert(self.state.value.channelIsWritable)
if self.state.value.written == self.parts {
self.finishPromise.succeed(())
} else {
self.eventLoop.execute {
let future = writer.write(.byteBuffer(.init(bytes: [0, 1])))
self.written += 1
self.state.value.written += 1
future.whenComplete { result in
switch result {
case .success:
Expand All @@ -951,36 +962,35 @@ class TestBackpressureWriter {
}

func writabilityChanged(_ newValue: Bool) {
self.channelIsWritable = newValue
self.state.value.channelIsWritable = newValue
}
}

class ResponseBackpressureDelegate: HTTPClientResponseDelegate {
final class ResponseBackpressureDelegate: HTTPClientResponseDelegate {
typealias Response = Void

enum State {
enum State: Sendable {
case consuming(EventLoopPromise<Void>)
case waitingForRemote(CircularBuffer<EventLoopPromise<ByteBuffer?>>)
case buffering((ByteBuffer?, EventLoopPromise<Void>)?)
case done
}

let eventLoop: EventLoop
private var state: State = .buffering(nil)
private let state: NIOLoopBoundBox<State>

init(eventLoop: EventLoop) {
self.eventLoop = eventLoop

self.state = .consuming(self.eventLoop.makePromise(of: Void.self))
self.state = .makeBoxSendingValue(.consuming(eventLoop.makePromise(of: Void.self)), eventLoop: eventLoop)
}

func next() -> EventLoopFuture<ByteBuffer?> {
switch self.state {
switch self.state.value {
case .consuming(let backpressurePromise):
var promiseBuffer = CircularBuffer<EventLoopPromise<ByteBuffer?>>()
let newPromise = self.eventLoop.makePromise(of: ByteBuffer?.self)
promiseBuffer.append(newPromise)
self.state = .waitingForRemote(promiseBuffer)
self.state.value = .waitingForRemote(promiseBuffer)
backpressurePromise.succeed(())
return newPromise.futureResult

Expand All @@ -991,18 +1001,18 @@ class ResponseBackpressureDelegate: HTTPClientResponseDelegate {
)
let promise = self.eventLoop.makePromise(of: ByteBuffer?.self)
promiseBuffer.append(promise)
self.state = .waitingForRemote(promiseBuffer)
self.state.value = .waitingForRemote(promiseBuffer)
return promise.futureResult

case .buffering(.none):
var promiseBuffer = CircularBuffer<EventLoopPromise<ByteBuffer?>>()
let promise = self.eventLoop.makePromise(of: ByteBuffer?.self)
promiseBuffer.append(promise)
self.state = .waitingForRemote(promiseBuffer)
self.state.value = .waitingForRemote(promiseBuffer)
return promise.futureResult

case .buffering(.some((let buffer, let promise))):
self.state = .buffering(nil)
self.state.value = .buffering(nil)
promise.succeed(())
return self.eventLoop.makeSucceededFuture(buffer)

Expand All @@ -1012,7 +1022,7 @@ class ResponseBackpressureDelegate: HTTPClientResponseDelegate {
}

func didReceiveHead(task: HTTPClient.Task<Void>, _ head: HTTPResponseHead) -> EventLoopFuture<Void> {
switch self.state {
switch self.state.value {
case .consuming(let backpressurePromise):
return backpressurePromise.futureResult

Expand All @@ -1025,7 +1035,7 @@ class ResponseBackpressureDelegate: HTTPClientResponseDelegate {
}

func didReceiveBodyPart(task: HTTPClient.Task<Void>, _ buffer: ByteBuffer) -> EventLoopFuture<Void> {
switch self.state {
switch self.state.value {
case .waitingForRemote(var promiseBuffer):
assert(
!promiseBuffer.isEmpty,
Expand All @@ -1034,18 +1044,18 @@ class ResponseBackpressureDelegate: HTTPClientResponseDelegate {
let promise = promiseBuffer.removeFirst()
if promiseBuffer.isEmpty {
let newBackpressurePromise = self.eventLoop.makePromise(of: Void.self)
self.state = .consuming(newBackpressurePromise)
self.state.value = .consuming(newBackpressurePromise)
promise.succeed(buffer)
return newBackpressurePromise.futureResult
} else {
self.state = .waitingForRemote(promiseBuffer)
self.state.value = .waitingForRemote(promiseBuffer)
promise.succeed(buffer)
return self.eventLoop.makeSucceededVoidFuture()
}

case .buffering(.none):
let promise = self.eventLoop.makePromise(of: Void.self)
self.state = .buffering((buffer, promise))
self.state.value = .buffering((buffer, promise))
return promise.futureResult

case .buffering(.some):
Expand All @@ -1059,15 +1069,15 @@ class ResponseBackpressureDelegate: HTTPClientResponseDelegate {
}

func didFinishRequest(task: HTTPClient.Task<Void>) throws {
switch self.state {
switch self.state.value {
case .waitingForRemote(let promiseBuffer):
for promise in promiseBuffer {
promise.succeed(.none)
}
self.state = .done
self.state.value = .done

case .buffering(.none):
self.state = .done
self.state.value = .done

case .done, .consuming:
preconditionFailure("Invalid state: \(self.state)")
Expand All @@ -1093,7 +1103,7 @@ class ReadEventHitHandler: ChannelOutboundHandler {
}
}

final class FailEndHandler: ChannelOutboundHandler {
final class FailEndHandler: ChannelOutboundHandler, Sendable {
typealias OutboundIn = HTTPClientRequestPart
typealias OutboundOut = HTTPClientRequestPart

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -568,10 +568,11 @@ class HTTP2ClientRequestHandlerTests: XCTestCase {
)
try channel.connect(to: .init(ipAddress: "127.0.0.1", port: 80)).wait()

let request = MockHTTPExecutableRequest()
// non empty body is important to trigger this bug as we otherwise finish the request in a single flush
request.requestFramingMetadata.body = .fixedSize(1)
request.raiseErrorIfUnimplementedMethodIsCalled = false
let request = MockHTTPExecutableRequest(
framingMetadata: RequestFramingMetadata(connectionClose: false, body: .fixedSize(1)),
raiseErrorIfUnimplementedMethodIsCalled: false
)
channel.writeAndFlush(request, promise: nil)
XCTAssertEqual(request.events.map(\.kind), [.willExecuteRequest, .requestHeadSent])
}
Expand Down
20 changes: 12 additions & 8 deletions Tests/AsyncHTTPClientTests/HTTP2ClientTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import AsyncHTTPClient // NOT @testable - tests that really need @testable go into HTTP2ClientInternalTests.swift
import Logging
import NIOConcurrencyHelpers
import NIOCore
import NIOFoundationCompat
import NIOHTTP1
Expand Down Expand Up @@ -283,15 +284,16 @@ class HTTP2ClientTests: XCTestCase {
XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "https://localhost:\(bin.port)"))
guard let request = maybeRequest else { return }

var task: HTTPClient.Task<Void>!
let taskBox = NIOLockedValueBox<HTTPClient.Task<Void>?>(nil)
let delegate = HeadReceivedCallback { _ in
// request is definitely running because we just received a head from the server
task.cancel()
taskBox.withLockedValue { $0 }!.cancel()
}
task = client.execute(
let task = client.execute(
request: request,
delegate: delegate
)
taskBox.withLockedValue { $0 = task }

XCTAssertThrowsError(try task.futureResult.timeout(after: .seconds(2)).wait()) {
XCTAssertEqualTypeAndValue($0, HTTPClientError.cancelled)
Expand Down Expand Up @@ -360,18 +362,20 @@ class HTTP2ClientTests: XCTestCase {
guard let request = maybeRequest else { return }

let tasks = (0..<100).map { _ -> HTTPClient.Task<TestHTTPDelegate.Response> in
var task: HTTPClient.Task<Void>!
let taskBox = NIOLockedValueBox<HTTPClient.Task<Void>?>(nil)

let delegate = HeadReceivedCallback { _ in
// request is definitely running because we just received a head from the server
cancelPool.next().execute {
// canceling from a different thread
task.cancel()
taskBox.withLockedValue { $0 }!.cancel()
}
}
task = client.execute(
let task = client.execute(
request: request,
delegate: delegate
)
taskBox.withLockedValue { $0 = task }
return task
}

Expand Down Expand Up @@ -547,8 +551,8 @@ class HTTP2ClientTests: XCTestCase {

private final class HeadReceivedCallback: HTTPClientResponseDelegate {
typealias Response = Void
private let didReceiveHeadCallback: (HTTPResponseHead) -> Void
init(didReceiveHead: @escaping (HTTPResponseHead) -> Void) {
private let didReceiveHeadCallback: @Sendable (HTTPResponseHead) -> Void
init(didReceiveHead: @escaping @Sendable (HTTPResponseHead) -> Void) {
self.didReceiveHeadCallback = didReceiveHead
}

Expand Down
Loading
Loading