Skip to content

Commit d67e0b8

Browse files
committed
Merge provides all elements from the subsequences on cancellation
On cancellation, merge currently does not yield all elements. This leads to situations in which the final elements of AsyncStreams are not forwarded to the user. This patch ensures, that only the underlying Task is cancelled and all subsequences' elements are forwarded to the user.
1 parent 07a0c1e commit d67e0b8

File tree

3 files changed

+118
-52
lines changed

3 files changed

+118
-52
lines changed

Sources/AsyncAlgorithms/Merge/MergeStateMachine.swift

+46-44
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ struct MergeStateMachine<
4141
buffer: Deque<Element>,
4242
upstreamContinuations: [UnsafeContinuation<Void, Error>],
4343
upstreamsFinished: Int,
44-
downstreamContinuation: UnsafeContinuation<Element?, Error>?
44+
downstreamContinuation: UnsafeContinuation<Element?, Error>?,
45+
cancelled: Bool
4546
)
4647

4748
/// The state once any of the upstream sequences threw an `Error`.
@@ -100,11 +101,11 @@ struct MergeStateMachine<
100101
// Nothing to do here. No demand was signalled until now
101102
return .none
102103

103-
case .merging(_, _, _, _, .some):
104+
case .merging(_, _, _, _, .some, _):
104105
// An iterator was deinitialized while we have a suspended continuation.
105106
preconditionFailure("Internal inconsistency current state \(self.state) and received iteratorDeinitialized()")
106107

107-
case let .merging(task, _, upstreamContinuations, _, .none):
108+
case let .merging(task, _, upstreamContinuations, _, .none, _):
108109
// The iterator was dropped which signals that the consumer is finished.
109110
// We can transition to finished now and need to clean everything up.
110111
state = .finished
@@ -142,7 +143,8 @@ struct MergeStateMachine<
142143
buffer: .init(),
143144
upstreamContinuations: [], // This should reserve capacity in the variadic generics case
144145
upstreamsFinished: 0,
145-
downstreamContinuation: nil
146+
downstreamContinuation: nil,
147+
cancelled: false
146148
)
147149

148150
case .merging, .upstreamFailure, .finished:
@@ -175,11 +177,11 @@ struct MergeStateMachine<
175177
// Child tasks are only created after we transitioned to `merging`
176178
preconditionFailure("Internal inconsistency current state \(self.state) and received childTaskSuspended()")
177179

178-
case .merging(_, _, _, _, .some):
180+
case .merging(_, _, _, _, .some, _):
179181
// We have outstanding demand so request the next element
180182
return .resumeContinuation(upstreamContinuation: continuation)
181183

182-
case .merging(let task, let buffer, var upstreamContinuations, let upstreamsFinished, .none):
184+
case .merging(let task, let buffer, var upstreamContinuations, let upstreamsFinished, .none, let cancelled):
183185
// There is no outstanding demand from the downstream
184186
// so we are storing the continuation and resume it once there is demand.
185187
state = .modifying
@@ -191,7 +193,8 @@ struct MergeStateMachine<
191193
buffer: buffer,
192194
upstreamContinuations: upstreamContinuations,
193195
upstreamsFinished: upstreamsFinished,
194-
downstreamContinuation: nil
196+
downstreamContinuation: nil,
197+
cancelled: cancelled
195198
)
196199

197200
return .none
@@ -236,7 +239,7 @@ struct MergeStateMachine<
236239
// Child tasks that are producing elements are only created after we transitioned to `merging`
237240
preconditionFailure("Internal inconsistency current state \(self.state) and received elementProduced()")
238241

239-
case let .merging(task, buffer, upstreamContinuations, upstreamsFinished, .some(downstreamContinuation)):
242+
case let .merging(task, buffer, upstreamContinuations, upstreamsFinished, .some(downstreamContinuation), cancelled):
240243
// We produced an element and have an outstanding downstream continuation
241244
// this means we can go right ahead and resume the continuation with that element
242245
precondition(buffer.isEmpty, "We are holding a continuation so the buffer must be empty")
@@ -246,15 +249,16 @@ struct MergeStateMachine<
246249
buffer: buffer,
247250
upstreamContinuations: upstreamContinuations,
248251
upstreamsFinished: upstreamsFinished,
249-
downstreamContinuation: nil
252+
downstreamContinuation: nil,
253+
cancelled: cancelled
250254
)
251255

252256
return .resumeContinuation(
253257
downstreamContinuation: downstreamContinuation,
254258
element: element
255259
)
256260

257-
case .merging(let task, var buffer, let upstreamContinuations, let upstreamsFinished, .none):
261+
case .merging(let task, var buffer, let upstreamContinuations, let upstreamsFinished, .none, let cancelled):
258262
// There is not outstanding downstream continuation so we must buffer the element
259263
// This happens if we race our upstream sequences to produce elements
260264
// and the _losers_ are signalling their produced element
@@ -267,7 +271,8 @@ struct MergeStateMachine<
267271
buffer: buffer,
268272
upstreamContinuations: upstreamContinuations,
269273
upstreamsFinished: upstreamsFinished,
270-
downstreamContinuation: nil
274+
downstreamContinuation: nil,
275+
cancelled: cancelled
271276
)
272277

273278
return .none
@@ -310,7 +315,7 @@ struct MergeStateMachine<
310315
case .initial:
311316
preconditionFailure("Internal inconsistency current state \(self.state) and received upstreamFinished()")
312317

313-
case .merging(let task, let buffer, let upstreamContinuations, var upstreamsFinished, let .some(downstreamContinuation)):
318+
case .merging(let task, let buffer, let upstreamContinuations, var upstreamsFinished, let .some(downstreamContinuation), let cancelled):
314319
// One of the upstreams finished
315320
precondition(buffer.isEmpty, "We are holding a continuation so the buffer must be empty")
316321

@@ -335,13 +340,14 @@ struct MergeStateMachine<
335340
buffer: buffer,
336341
upstreamContinuations: upstreamContinuations,
337342
upstreamsFinished: upstreamsFinished,
338-
downstreamContinuation: downstreamContinuation
343+
downstreamContinuation: downstreamContinuation,
344+
cancelled: cancelled
339345
)
340346

341347
return .none
342348
}
343349

344-
case .merging(let task, let buffer, let upstreamContinuations, var upstreamsFinished, .none):
350+
case .merging(let task, let buffer, let upstreamContinuations, var upstreamsFinished, .none, let cancelled):
345351
// First we increment our counter of finished upstreams
346352
upstreamsFinished += 1
347353

@@ -350,7 +356,8 @@ struct MergeStateMachine<
350356
buffer: buffer,
351357
upstreamContinuations: upstreamContinuations,
352358
upstreamsFinished: upstreamsFinished,
353-
downstreamContinuation: nil
359+
downstreamContinuation: nil,
360+
cancelled: cancelled
354361
)
355362

356363
if upstreamsFinished == self.numberOfUpstreamSequences {
@@ -402,7 +409,7 @@ struct MergeStateMachine<
402409
case .initial:
403410
preconditionFailure("Internal inconsistency current state \(self.state) and received upstreamThrew()")
404411

405-
case let .merging(task, buffer, upstreamContinuations, _, .some(downstreamContinuation)):
412+
case let .merging(task, buffer, upstreamContinuations, _, .some(downstreamContinuation), _):
406413
// An upstream threw an error and we have a downstream continuation.
407414
// We just need to resume the downstream continuation with the error and cancel everything
408415
precondition(buffer.isEmpty, "We are holding a continuation so the buffer must be empty")
@@ -417,7 +424,7 @@ struct MergeStateMachine<
417424
upstreamContinuations: upstreamContinuations
418425
)
419426

420-
case let .merging(task, buffer, upstreamContinuations, _, .none):
427+
case let .merging(task, buffer, upstreamContinuations, _, .none, _):
421428
// An upstream threw an error and we don't have a downstream continuation.
422429
// We need to store the error and wait for the downstream to consume the
423430
// rest of the buffer and the error. However, we can already cancel the task
@@ -454,10 +461,7 @@ struct MergeStateMachine<
454461
upstreamContinuations: [UnsafeContinuation<Void, Error>]
455462
)
456463
/// Indicates that the task and the upstream continuations should be cancelled.
457-
case cancelTaskAndUpstreamContinuations(
458-
task: Task<Void, Never>,
459-
upstreamContinuations: [UnsafeContinuation<Void, Error>]
460-
)
464+
case cancelTask(Task<Void, Never>)
461465
/// Indicates that nothing should be done.
462466
case none
463467
}
@@ -471,26 +475,21 @@ struct MergeStateMachine<
471475

472476
return .none
473477

474-
case let .merging(task, _, upstreamContinuations, _, .some(downstreamContinuation)):
475-
// The downstream Task got cancelled so we need to cancel our upstream Task
476-
// and resume all continuations. We can also transition to finished.
477-
state = .finished
478+
case let .merging(task, buffer, upstreamContinuations, upstreamFinished, downstreamContinuation, cancelled):
479+
guard !cancelled else {
480+
return .none
481+
}
478482

479-
return .resumeDownstreamContinuationWithNilAndCancelTaskAndUpstreamContinuations(
480-
downstreamContinuation: downstreamContinuation,
483+
self.state = .merging(
481484
task: task,
482-
upstreamContinuations: upstreamContinuations
485+
buffer: buffer,
486+
upstreamContinuations: upstreamContinuations,
487+
upstreamsFinished: upstreamFinished,
488+
downstreamContinuation: downstreamContinuation,
489+
cancelled: true
483490
)
484491

485-
case let .merging(task, _, upstreamContinuations, _, .none):
486-
// The downstream Task got cancelled so we need to cancel our upstream Task
487-
// and resume all continuations. We can also transition to finished.
488-
state = .finished
489-
490-
return .cancelTaskAndUpstreamContinuations(
491-
task: task,
492-
upstreamContinuations: upstreamContinuations
493-
)
492+
return .cancelTask(task)
494493

495494
case .upstreamFailure:
496495
// An upstream already threw and we cancelled everything already.
@@ -531,11 +530,11 @@ struct MergeStateMachine<
531530
// We are transitioning to merging in the taskStarted method.
532531
return .startTaskAndSuspendDownstreamTask(base1, base2, base3)
533532

534-
case .merging(_, _, _, _, .some):
533+
case .merging(_, _, _, _, .some, _):
535534
// We have multiple AsyncIterators iterating the sequence
536535
preconditionFailure("Internal inconsistency current state \(self.state) and received next()")
537536

538-
case .merging(let task, var buffer, let upstreamContinuations, let upstreamsFinished, .none):
537+
case .merging(let task, var buffer, let upstreamContinuations, let upstreamsFinished, .none, let cancelled):
539538
state = .modifying
540539

541540
if let element = buffer.popFirst() {
@@ -545,7 +544,8 @@ struct MergeStateMachine<
545544
buffer: buffer,
546545
upstreamContinuations: upstreamContinuations,
547546
upstreamsFinished: upstreamsFinished,
548-
downstreamContinuation: nil
547+
downstreamContinuation: nil,
548+
cancelled: cancelled
549549
)
550550

551551
return .returnElement(.success(element))
@@ -556,7 +556,8 @@ struct MergeStateMachine<
556556
buffer: buffer,
557557
upstreamContinuations: upstreamContinuations,
558558
upstreamsFinished: upstreamsFinished,
559-
downstreamContinuation: nil
559+
downstreamContinuation: nil,
560+
cancelled: cancelled
560561
)
561562

562563
return .suspendDownstreamTask
@@ -601,21 +602,22 @@ struct MergeStateMachine<
601602
mutating func next(for continuation: UnsafeContinuation<Element?, Error>) -> NextForAction {
602603
switch state {
603604
case .initial,
604-
.merging(_, _, _, _, .some),
605+
.merging(_, _, _, _, .some, _),
605606
.upstreamFailure,
606607
.finished:
607608
// All other states are handled by `next` already so we should never get in here with
608609
// any of those
609610
preconditionFailure("Internal inconsistency current state \(self.state) and received next(for:)")
610611

611-
case let .merging(task, buffer, upstreamContinuations, upstreamsFinished, .none):
612+
case let .merging(task, buffer, upstreamContinuations, upstreamsFinished, .none, cancelled):
612613
// We suspended the task and need signal the upstreams
613614
state = .merging(
614615
task: task,
615616
buffer: buffer,
616617
upstreamContinuations: [], // TODO: don't alloc new array here
617618
upstreamsFinished: upstreamsFinished,
618-
downstreamContinuation: continuation
619+
downstreamContinuation: continuation,
620+
cancelled: cancelled
619621
)
620622

621623
return .resumeUpstreamContinuations(

Sources/AsyncAlgorithms/Merge/MergeStorage.swift

+3-8
Original file line numberDiff line numberDiff line change
@@ -128,12 +128,7 @@ final class MergeStorage<
128128

129129
downstreamContinuation.resume(returning: nil)
130130

131-
case let .cancelTaskAndUpstreamContinuations(
132-
task,
133-
upstreamContinuations
134-
):
135-
upstreamContinuations.forEach { $0.resume(throwing: CancellationError()) }
136-
131+
case let .cancelTask(task):
137132
task.cancel()
138133

139134
case .none:
@@ -262,8 +257,8 @@ final class MergeStorage<
262257
task,
263258
upstreamContinuations
264259
):
265-
upstreamContinuations.forEach { $0.resume(throwing: CancellationError()) }
266260
task.cancel()
261+
upstreamContinuations.forEach { $0.resume() }
267262

268263
downstreamContinuation.resume(returning: nil)
269264

@@ -273,8 +268,8 @@ final class MergeStorage<
273268
task,
274269
upstreamContinuations
275270
):
276-
upstreamContinuations.forEach { $0.resume(throwing: CancellationError()) }
277271
task.cancel()
272+
upstreamContinuations.forEach { $0.resume() }
278273

279274
break loop
280275
case .none:

Tests/AsyncAlgorithmsTests/TestMerge.swift

+69
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,38 @@ final class TestMerge2: XCTestCase {
201201
}
202202
t.cancel()
203203
}
204+
205+
func testAsyncStreamElementsThatAreInjectedOnCancellationAreDelivered() async {
206+
let (stream1, continuation1) = AsyncStream.makeStream(of: Int.self)
207+
let (stream2, continuation2) = AsyncStream.makeStream(of: Int.self)
208+
continuation1.onTermination = { reason in
209+
XCTAssertEqual(reason, .cancelled)
210+
continuation1.yield(1)
211+
}
212+
continuation2.onTermination = { reason in
213+
XCTAssertEqual(reason, .cancelled)
214+
continuation2.yield(2)
215+
}
216+
continuation1.yield(0) // initial
217+
let merge = merge(stream1, stream2)
218+
let finished = expectation(description: "finished")
219+
let iterated = expectation(description: "iterated")
220+
let task = Task {
221+
var count = 0
222+
for await _ in merge {
223+
if count == 0 { iterated.fulfill() }
224+
count += 1
225+
}
226+
finished.fulfill()
227+
XCTAssertEqual(count, 3)
228+
}
229+
// ensure the other task actually starts
230+
await fulfillment(of: [iterated], timeout: 1.0)
231+
// cancellation should ensure the loop finishes
232+
// without regards to the remaining underlying sequence
233+
task.cancel()
234+
await fulfillment(of: [finished], timeout: 1.0)
235+
}
204236
}
205237

206238
final class TestMerge3: XCTestCase {
@@ -555,4 +587,41 @@ final class TestMerge3: XCTestCase {
555587

556588
iterator = nil
557589
}
590+
591+
func testAsyncStreamElementsThatAreInjectedOnCancellationAreDelivered() async {
592+
let (stream1, continuation1) = AsyncStream.makeStream(of: Int.self)
593+
let (stream2, continuation2) = AsyncStream.makeStream(of: Int.self)
594+
let (stream3, continuation3) = AsyncStream.makeStream(of: Int.self)
595+
continuation1.onTermination = { reason in
596+
XCTAssertEqual(reason, .cancelled)
597+
continuation1.yield(1)
598+
}
599+
continuation2.onTermination = { reason in
600+
XCTAssertEqual(reason, .cancelled)
601+
continuation2.yield(2)
602+
}
603+
continuation3.onTermination = { reason in
604+
XCTAssertEqual(reason, .cancelled)
605+
continuation3.yield(3)
606+
}
607+
continuation1.yield(0) // initial
608+
let merge = merge(stream1, stream2, stream3)
609+
let finished = expectation(description: "finished")
610+
let iterated = expectation(description: "iterated")
611+
let task = Task {
612+
var count = 0
613+
for await _ in merge {
614+
if count == 0 { iterated.fulfill() }
615+
count += 1
616+
}
617+
finished.fulfill()
618+
XCTAssertEqual(count, 4)
619+
}
620+
// ensure the other task actually starts
621+
await fulfillment(of: [iterated], timeout: 1.0)
622+
// cancellation should ensure the loop finishes
623+
// without regards to the remaining underlying sequence
624+
task.cancel()
625+
await fulfillment(of: [finished], timeout: 1.0)
626+
}
558627
}

0 commit comments

Comments
 (0)