diff --git a/Sources/AsyncAlgorithms_XCTest/ValidationTest.swift b/Sources/AsyncAlgorithms_XCTest/ValidationTest.swift index 2a7f5667..acabb1a9 100644 --- a/Sources/AsyncAlgorithms_XCTest/ValidationTest.swift +++ b/Sources/AsyncAlgorithms_XCTest/ValidationTest.swift @@ -24,7 +24,7 @@ extension XCTestCase { #endif } - func validate(theme: Theme, expectedFailures: Set, @AsyncSequenceValidationDiagram _ build: (AsyncSequenceValidationDiagram) -> Test, file: StaticString = #file, line: UInt = #line) { + func validate(theme: Theme, expectedFailures: Set, @AsyncSequenceValidationDiagram _ build: (AsyncSequenceValidationDiagram) -> Test, file: StaticString = #filePath, line: UInt = #line) { var expectations = expectedFailures var result: AsyncSequenceValidationDiagram.ExpectationResult? var failures = [AsyncSequenceValidationDiagram.ExpectationFailure]() @@ -62,15 +62,15 @@ extension XCTestCase { } } - func validate(expectedFailures: Set, @AsyncSequenceValidationDiagram _ build: (AsyncSequenceValidationDiagram) -> Test, file: StaticString = #file, line: UInt = #line) { + func validate(expectedFailures: Set, @AsyncSequenceValidationDiagram _ build: (AsyncSequenceValidationDiagram) -> Test, file: StaticString = #filePath, line: UInt = #line) { validate(theme: .ascii, expectedFailures: expectedFailures, build, file: file, line: line) } - public func validate(theme: Theme, @AsyncSequenceValidationDiagram _ build: (AsyncSequenceValidationDiagram) -> Test, file: StaticString = #file, line: UInt = #line) { + public func validate(theme: Theme, @AsyncSequenceValidationDiagram _ build: (AsyncSequenceValidationDiagram) -> Test, file: StaticString = #filePath, line: UInt = #line) { validate(theme: theme, expectedFailures: [], build, file: file, line: line) } - public func validate(@AsyncSequenceValidationDiagram _ build: (AsyncSequenceValidationDiagram) -> Test, file: StaticString = #file, line: UInt = #line) { + public func validate(@AsyncSequenceValidationDiagram _ build: (AsyncSequenceValidationDiagram) -> Test, file: StaticString = #filePath, line: UInt = #line) { validate(theme: .ascii, expectedFailures: [], build, file: file, line: line) } } diff --git a/Sources/AsyncSequenceValidation/Clock.swift b/Sources/AsyncSequenceValidation/Clock.swift index e76f5aa9..fb636c83 100644 --- a/Sources/AsyncSequenceValidation/Clock.swift +++ b/Sources/AsyncSequenceValidation/Clock.swift @@ -115,7 +115,7 @@ extension AsyncSequenceValidationDiagram.Clock { let token = queue.prepare() try await withTaskCancellationHandler { try await withUnsafeThrowingContinuation { continuation in - queue.enqueue(AsyncSequenceValidationDiagram.Context.currentJob, deadline: deadline, continuation: continuation, token: token) + queue.enqueue(AsyncSequenceValidationDiagram.Context.state.withCriticalRegion(\.currentJob), deadline: deadline, continuation: continuation, token: token) } } onCancel: { queue.cancel(token) diff --git a/Sources/AsyncSequenceValidation/Input.swift b/Sources/AsyncSequenceValidation/Input.swift index 26e23da6..7db44d51 100644 --- a/Sources/AsyncSequenceValidation/Input.swift +++ b/Sources/AsyncSequenceValidation/Input.swift @@ -52,7 +52,7 @@ extension AsyncSequenceValidationDiagram { } return try await withTaskCancellationHandler { try await withUnsafeThrowingContinuation { continuation in - queue.enqueue(Context.currentJob, deadline: when, continuation: continuation, results[eventIndex], index: index, token: token) + queue.enqueue(Context.state.withCriticalRegion(\.currentJob), deadline: when, continuation: continuation, results[eventIndex], index: index, token: token) } } onCancel: { [queue] in queue.cancel(token) diff --git a/Sources/AsyncSequenceValidation/TaskDriver.swift b/Sources/AsyncSequenceValidation/TaskDriver.swift index ed128193..153b3960 100644 --- a/Sources/AsyncSequenceValidation/TaskDriver.swift +++ b/Sources/AsyncSequenceValidation/TaskDriver.swift @@ -19,22 +19,8 @@ import Glibc #error("TODO: Port TaskDriver threading to windows") #endif -#if canImport(Darwin) -func start_thread(_ raw: UnsafeMutableRawPointer) -> UnsafeMutableRawPointer? { - Unmanaged.fromOpaque(raw).takeRetainedValue().run() - return nil -} -#elseif canImport(Glibc) -func start_thread(_ raw: UnsafeMutableRawPointer?) -> UnsafeMutableRawPointer? { - Unmanaged.fromOpaque(raw!).takeRetainedValue().run() - return nil -} -#elseif canImport(WinSDK) -#error("TODO: Port TaskDriver threading to windows") -#endif - -final class TaskDriver { - let work: (TaskDriver) -> Void +final class TaskDriver: @unchecked Sendable { + let work: @Sendable (TaskDriver) -> Void let queue: WorkQueue #if canImport(Darwin) var thread: pthread_t? @@ -43,19 +29,37 @@ final class TaskDriver { #elseif canImport(WinSDK) #error("TODO: Port TaskDriver threading to windows") #endif - - init(queue: WorkQueue, _ work: @escaping (TaskDriver) -> Void) { + + private let lock = Lock.allocate() + + init(queue: WorkQueue, _ work: @Sendable @escaping (TaskDriver) -> Void) { self.queue = queue self.work = work } func start() { +#if canImport(Darwin) + func start_thread(_ raw: UnsafeMutableRawPointer) -> UnsafeMutableRawPointer? { + Unmanaged.fromOpaque(raw).takeRetainedValue().run() + return nil + } +#elseif canImport(Glibc) + func start_thread(_ raw: UnsafeMutableRawPointer?) -> UnsafeMutableRawPointer? { + Unmanaged.fromOpaque(raw!).takeRetainedValue().run() + return nil + } +#elseif canImport(WinSDK) +#error("TODO: Port TaskDriver threading to windows") +#endif + + lock.withLockVoid { #if canImport(Darwin) || canImport(Glibc) - pthread_create(&thread, nil, start_thread, - Unmanaged.passRetained(self).toOpaque()) + pthread_create(&thread, nil, start_thread, + Unmanaged.passRetained(self).toOpaque()) #elseif canImport(WinSDK) #error("TODO: Port TaskDriver threading to windows") #endif + } } func run() { @@ -77,11 +81,11 @@ final class TaskDriver { func enqueue(_ job: JobRef) { let job = Job(job) - queue.enqueue(AsyncSequenceValidationDiagram.Context.currentJob) { - let previous = AsyncSequenceValidationDiagram.Context.currentJob - AsyncSequenceValidationDiagram.Context.currentJob = job + queue.enqueue(AsyncSequenceValidationDiagram.Context.state.withCriticalRegion(\.currentJob)) { + let previous = AsyncSequenceValidationDiagram.Context.state.withCriticalRegion(\.currentJob) + AsyncSequenceValidationDiagram.Context.state.withCriticalRegion { $0.currentJob = job } job.execute() - AsyncSequenceValidationDiagram.Context.currentJob = previous + AsyncSequenceValidationDiagram.Context.state.withCriticalRegion { $0.currentJob = previous } } } } diff --git a/Sources/AsyncSequenceValidation/Test.swift b/Sources/AsyncSequenceValidation/Test.swift index 8dc86832..9f3f3cef 100644 --- a/Sources/AsyncSequenceValidation/Test.swift +++ b/Sources/AsyncSequenceValidation/Test.swift @@ -9,7 +9,7 @@ // //===----------------------------------------------------------------------===// -import _CAsyncSequenceValidationSupport +@preconcurrency import _CAsyncSequenceValidationSupport import AsyncAlgorithms @_silgen_name("swift_job_run") @@ -48,17 +48,17 @@ extension AsyncSequenceValidationDiagram { do { if let pastEnd = try await iterator.next(){ let failure = ExpectationFailure( - when: Context.clock!.now, + when: Context.state.withCriticalRegion(\.clock!.now), kind: .specificationViolationGotValueAfterIteration(pastEnd), specification: output) - Context.specificationFailures.append(failure) + Context.state.withCriticalRegion { $0.specificationFailures.append(failure) } } } catch { let failure = ExpectationFailure( - when: Context.clock!.now, + when: Context.state.withCriticalRegion(\.clock!.now), kind: .specificationViolationGotFailureAfterIteration(error), specification: output) - Context.specificationFailures.append(failure) + Context.state.withCriticalRegion { $0.specificationFailures.append(failure) } } } catch { throw error @@ -107,7 +107,7 @@ extension AsyncSequenceValidationDiagram { } } - private static let _executor: AnyObject = { + private static let _executor: any SerialExecutor = { if #available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) { return ClockExecutor_5_9() } else { @@ -116,19 +116,18 @@ extension AsyncSequenceValidationDiagram { }() static var unownedExecutor: UnownedSerialExecutor { - (_executor as! any SerialExecutor).asUnownedSerialExecutor() + _executor.asUnownedSerialExecutor() } #endif - static var clock: Clock? - - - - static var driver: TaskDriver? - - static var currentJob: Job? - - static var specificationFailures = [ExpectationFailure]() + static let state = ManagedCriticalState(State()) + + struct State { + var clock: Clock? + var driver: TaskDriver? + var currentJob: Job? + var specificationFailures = [ExpectationFailure]() + } } enum ActualResult { @@ -158,9 +157,9 @@ extension AsyncSequenceValidationDiagram { actual: [(Clock.Instant, Result)] ) -> (ExpectationResult, [ExpectationFailure]) { let result = ExpectationResult(expected: expected, actual: actual) - var failures = Context.specificationFailures - Context.specificationFailures.removeAll() - + var failures = Context.state.withCriticalRegion(\.specificationFailures) + Context.state.withCriticalRegion { $0.specificationFailures.removeAll() } + let actualTimes = actual.map { when, _ in when } let expectedTimes = expected.map { $0.when } @@ -349,55 +348,58 @@ extension AsyncSequenceValidationDiagram { } let actual = ManagedCriticalState([(Clock.Instant, Result)]()) - Context.clock = clock - Context.specificationFailures.removeAll() - // This all needs to be isolated from potential Tasks (the caller function might be async!) - Context.driver = TaskDriver(queue: diagram.queue) { driver in - swift_task_enqueueGlobal_hook = { job, original in - Context.driver?.enqueue(job) - } - - let runner = Task { - do { - try await test.test(with: clock, activeTicks: activeTicks, output: test.output) { event in + Context.state.withCriticalRegion { state in + state.clock = clock + state.specificationFailures.removeAll() + // This all needs to be isolated from potential Tasks (the caller function might be async!) + state.driver = TaskDriver(queue: diagram.queue) { driver in + swift_task_enqueueGlobal_hook = { job, original in + Context.state.withCriticalRegion(\.driver)?.enqueue(job) + } + + let runner = Task { + do { + try await test.test(with: clock, activeTicks: activeTicks, output: test.output) { event in + actual.withCriticalRegion { values in + values.append((clock.now, .success(event))) + } + } actual.withCriticalRegion { values in - values.append((clock.now, .success(event))) + values.append((clock.now, .success(nil))) + } + } catch { + actual.withCriticalRegion { values in + values.append((clock.now, .failure(error))) } - } - actual.withCriticalRegion { values in - values.append((clock.now, .success(nil))) - } - } catch { - actual.withCriticalRegion { values in - values.append((clock.now, .failure(error))) } } - } - - // Drain off any initial work. Work may spawn additional work to be done. - // If the driver ever becomes blocked on the clock, exit early out of that - // drain, because the drain cant make any forward progress if it is blocked - // by a needed clock advancement. - diagram.queue.drain() - // Next make sure to iterate a decent amount past the end of the maximum - // scheduled things (that way we ensure any reasonable errors are caught) - for _ in 0..<(end.when.rawValue * 2) { - if cancelEvents.contains(diagram.queue.now.advanced(by: .steps(1))) { - runner.cancel() + + // Drain off any initial work. Work may spawn additional work to be done. + // If the driver ever becomes blocked on the clock, exit early out of that + // drain, because the drain cant make any forward progress if it is blocked + // by a needed clock advancement. + diagram.queue.drain() + // Next make sure to iterate a decent amount past the end of the maximum + // scheduled things (that way we ensure any reasonable errors are caught) + for _ in 0..<(end.when.rawValue * 2) { + if cancelEvents.contains(diagram.queue.now.advanced(by: .steps(1))) { + runner.cancel() + } + diagram.queue.advance() } - diagram.queue.advance() + + runner.cancel() + Context.state.withCriticalRegion { $0.clock = nil } + swift_task_enqueueGlobal_hook = nil } - - runner.cancel() - Context.clock = nil - swift_task_enqueueGlobal_hook = nil } - Context.driver?.start() + let driver = Context.state.withCriticalRegion(\.driver) + driver?.start() // This is only valid since we are doing tests here // else wise this would cause QoS inversions - Context.driver?.join() - Context.driver = nil - + driver?.join() + Context.state.withCriticalRegion { $0.driver = nil } + return validate( inputs: test.inputs, output: test.output, diff --git a/Tests/AsyncAlgorithmsTests/Support/Asserts.swift b/Tests/AsyncAlgorithmsTests/Support/Asserts.swift index d891cf91..8a295989 100644 --- a/Tests/AsyncAlgorithmsTests/Support/Asserts.swift +++ b/Tests/AsyncAlgorithmsTests/Support/Asserts.swift @@ -154,7 +154,7 @@ public func XCTAssertEqual(_ expressio @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) internal func XCTAssertThrowsError( _ expression: @autoclosure () async throws -> T, - file: StaticString = #file, + file: StaticString = #filePath, line: UInt = #line, verify: (Error) -> Void = { _ in } ) async {