Skip to content

Commit 28eb2ac

Browse files
authored
AsyncChannel: Provide throwing finish method on test stream (#2493)
# Motivation We recently provided testing utilities for the `NIOAsyncChannelInboundStream`. On thing that was missing is a way to finish the stream with an error. # Modification This PR provides a `finish()` method that takes an error which is thrown from the inbound stream. # Result Better way to test code relying on the AsyncChannel work.
1 parent 1ec71be commit 28eb2ac

File tree

2 files changed

+38
-12
lines changed

2 files changed

+38
-12
lines changed

Sources/NIOCore/AsyncChannel/AsyncChannelInboundStream.swift

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,10 @@ public struct NIOAsyncChannelInboundStream<Inbound: Sendable>: Sendable {
2424
/// A source used for driving a ``NIOAsyncChannelInboundStream`` during tests.
2525
public struct TestSource {
2626
@usableFromInline
27-
internal let continuation: AsyncStream<Inbound>.Continuation
27+
internal let continuation: AsyncThrowingStream<Inbound, Error>.Continuation
2828

2929
@inlinable
30-
init(continuation: AsyncStream<Inbound>.Continuation) {
30+
init(continuation: AsyncThrowingStream<Inbound, Error>.Continuation) {
3131
self.continuation = continuation
3232
}
3333

@@ -40,23 +40,25 @@ public struct NIOAsyncChannelInboundStream<Inbound: Sendable>: Sendable {
4040
}
4141

4242
/// Finished the inbound stream.
43+
///
44+
/// - Parameter error: The error to throw, or nil, to finish normally.
4345
@inlinable
44-
public func finish() {
45-
self.continuation.finish()
46+
public func finish(throwing error: Error? = nil) {
47+
self.continuation.finish(throwing: error)
4648
}
4749
}
4850

4951
#if swift(>=5.7)
5052
@usableFromInline
5153
enum _Backing: Sendable {
52-
case asyncStream(AsyncStream<Inbound>)
54+
case asyncStream(AsyncThrowingStream<Inbound, Error>)
5355
case producer(Producer)
5456
}
5557
#else
5658
// AsyncStream wasn't marked as `Sendable` in 5.6
5759
@usableFromInline
5860
enum _Backing: @unchecked Sendable {
59-
case asyncStream(AsyncStream<Inbound>)
61+
case asyncStream(AsyncThrowingStream<Inbound, Error>)
6062
case producer(Producer)
6163
}
6264
#endif
@@ -72,15 +74,15 @@ public struct NIOAsyncChannelInboundStream<Inbound: Sendable>: Sendable {
7274
/// - Returns: A tuple containing the input stream and a test source to drive it.
7375
@inlinable
7476
public static func makeTestingStream() -> (Self, TestSource) {
75-
var continuation: AsyncStream<Inbound>.Continuation!
76-
let stream = AsyncStream<Inbound> { continuation = $0 }
77+
var continuation: AsyncThrowingStream<Inbound, Error>.Continuation!
78+
let stream = AsyncThrowingStream<Inbound, Error> { continuation = $0 }
7779
let source = TestSource(continuation: continuation)
7880
let inputStream = Self(stream: stream)
7981
return (inputStream, source)
8082
}
8183

8284
@inlinable
83-
init(stream: AsyncStream<Inbound>) {
85+
init(stream: AsyncThrowingStream<Inbound, Error>) {
8486
self._backing = .asyncStream(stream)
8587
}
8688

@@ -163,7 +165,7 @@ extension NIOAsyncChannelInboundStream: AsyncSequence {
163165
public struct AsyncIterator: AsyncIteratorProtocol {
164166
@usableFromInline
165167
enum _Backing {
166-
case asyncStream(AsyncStream<Inbound>.Iterator)
168+
case asyncStream(AsyncThrowingStream<Inbound, Error>.Iterator)
167169
case producer(Producer.AsyncIterator)
168170
}
169171

@@ -183,8 +185,10 @@ extension NIOAsyncChannelInboundStream: AsyncSequence {
183185
public mutating func next() async throws -> Element? {
184186
switch self._backing {
185187
case .asyncStream(var iterator):
186-
let value = await iterator.next()
187-
self._backing = .asyncStream(iterator)
188+
defer {
189+
self._backing = .asyncStream(iterator)
190+
}
191+
let value = try await iterator.next()
188192
return value
189193

190194
case .producer(let iterator):

Tests/NIOCoreTests/AsyncChannel/AsyncChannelInboundStreamTests.swift

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,4 +37,26 @@ final class AsyncChannelInboundStreamTests: XCTestCase {
3737
XCTAssertEqual(result, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
3838
}
3939
}
40+
41+
func testTestingStream_whenThrowing() async throws {
42+
let (stream, source) = NIOAsyncChannelInboundStream<Int>.makeTestingStream()
43+
44+
await withThrowingTaskGroup(of: [Int].self) { group in
45+
group.addTask {
46+
var elements = [Int]()
47+
for try await element in stream {
48+
elements.append(element)
49+
}
50+
return elements
51+
}
52+
source.finish(throwing: ChannelError.alreadyClosed)
53+
54+
do {
55+
_ = try await group.next()
56+
XCTFail("Expected an error to be thrown")
57+
} catch {
58+
XCTAssertEqual(error as? ChannelError, .alreadyClosed)
59+
}
60+
}
61+
}
4062
}

0 commit comments

Comments
 (0)