-
Notifications
You must be signed in to change notification settings - Fork 160
[WIP] Broadcast algorithm #214
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
phausler
wants to merge
2
commits into
apple:main
Choose a base branch
from
phausler:pr/broadcast
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+307
−1
Draft
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,246 @@ | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// This source file is part of the Swift Async Algorithms open source project | ||
// | ||
// Copyright (c) 2022 Apple Inc. and the Swift project authors | ||
// Licensed under Apache License v2.0 with Runtime Library Exception | ||
// | ||
// See https://swift.org/LICENSE.txt for license information | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
import DequeModule | ||
|
||
extension AsyncSequence where Self: Sendable, Element: Sendable { | ||
public func broadcast() -> AsyncBroadcastSequence<Self> { | ||
AsyncBroadcastSequence(self) | ||
} | ||
} | ||
|
||
public struct AsyncBroadcastSequence<Base: AsyncSequence>: Sendable where Base: Sendable, Base.Element: Sendable { | ||
struct State : Sendable { | ||
enum Terminal { | ||
case failure(Error) | ||
case finished | ||
} | ||
|
||
struct Side { | ||
var buffer = Deque<Element>() | ||
var terminal: Terminal? | ||
var continuation: UnsafeContinuation<Result<Element?, Error>, Never>? | ||
|
||
mutating func drain() { | ||
if !buffer.isEmpty, let continuation { | ||
let element = buffer.removeFirst() | ||
continuation.resume(returning: .success(element)) | ||
self.continuation = nil | ||
} else if let terminal, let continuation { | ||
switch terminal { | ||
case .failure(let error): | ||
self.terminal = .finished | ||
continuation.resume(returning: .failure(error)) | ||
case .finished: | ||
continuation.resume(returning: .success(nil)) | ||
} | ||
self.continuation = nil | ||
} | ||
} | ||
|
||
mutating func cancel() { | ||
buffer.removeAll() | ||
terminal = .finished | ||
drain() | ||
} | ||
|
||
mutating func next(_ continuation: UnsafeContinuation<Result<Element?, Error>, Never>) { | ||
assert(self.continuation == nil) // presume that the sides are NOT sendable iterators... | ||
self.continuation = continuation | ||
drain() | ||
} | ||
|
||
mutating func emit(_ result: Result<Element?, Error>) { | ||
switch result { | ||
case .success(let element): | ||
if let element { | ||
buffer.append(element) | ||
} else { | ||
terminal = .finished | ||
} | ||
case .failure(let error): | ||
terminal = .failure(error) | ||
} | ||
drain() | ||
} | ||
} | ||
|
||
var id = 0 | ||
var sides = [Int: Side]() | ||
|
||
init() { } | ||
|
||
mutating func establish() -> Int { | ||
defer { id += 1 } | ||
sides[id] = Side() | ||
return id | ||
} | ||
|
||
static func establish(_ state: ManagedCriticalState<State>) -> Int { | ||
state.withCriticalRegion { $0.establish() } | ||
} | ||
|
||
mutating func cancel(_ id: Int) { | ||
if var side = sides.removeValue(forKey: id) { | ||
side.cancel() | ||
} | ||
} | ||
|
||
static func cancel(_ state: ManagedCriticalState<State>, id: Int) { | ||
state.withCriticalRegion { $0.cancel(id) } | ||
} | ||
|
||
mutating func next(_ id: Int, continuation: UnsafeContinuation<Result<Element?, Error>, Never>) { | ||
sides[id]?.next(continuation) | ||
} | ||
|
||
static func next(_ state: ManagedCriticalState<State>, id: Int) async -> Result<Element?, Error> { | ||
await withUnsafeContinuation { continuation in | ||
state.withCriticalRegion { $0.next(id, continuation: continuation) } | ||
} | ||
} | ||
|
||
mutating func emit(_ result: Result<Element?, Error>) { | ||
for id in sides.keys { | ||
sides[id]?.emit(result) | ||
} | ||
} | ||
|
||
static func emit(_ state: ManagedCriticalState<State>, result: Result<Element?, Error>) { | ||
state.withCriticalRegion { $0.emit(result) } | ||
} | ||
} | ||
|
||
struct Iteration { | ||
enum Status { | ||
case initial(Base) | ||
case iterating(Task<Void, Never>) | ||
case terminal | ||
} | ||
|
||
var status: Status | ||
|
||
init(_ base: Base) { | ||
status = .initial(base) | ||
} | ||
|
||
static func task(_ state: ManagedCriticalState<State>, base: Base) -> Task<Void, Never> { | ||
Task { | ||
do { | ||
for try await element in base { | ||
State.emit(state, result: .success(element)) | ||
} | ||
State.emit(state, result: .success(nil)) | ||
} catch { | ||
State.emit(state, result: .failure(error)) | ||
} | ||
} | ||
} | ||
|
||
mutating func start(_ state: ManagedCriticalState<State>) -> Bool { | ||
switch status { | ||
case .terminal: | ||
return false | ||
case .initial(let base): | ||
status = .iterating(Iteration.task(state, base: base)) | ||
default: | ||
break | ||
} | ||
return true | ||
} | ||
|
||
mutating func cancel() { | ||
switch status { | ||
case .iterating(let task): | ||
task.cancel() | ||
default: | ||
break | ||
} | ||
status = .terminal | ||
} | ||
|
||
static func start(_ iteration: ManagedCriticalState<Iteration>, state: ManagedCriticalState<State>) -> Bool { | ||
iteration.withCriticalRegion { $0.start(state) } | ||
} | ||
|
||
static func cancel(_ iteration: ManagedCriticalState<Iteration>) { | ||
iteration.withCriticalRegion { $0.cancel() } | ||
} | ||
} | ||
|
||
let state: ManagedCriticalState<State> | ||
let iteration: ManagedCriticalState<Iteration> | ||
|
||
init(_ base: Base) { | ||
state = ManagedCriticalState(State()) | ||
iteration = ManagedCriticalState(Iteration(base)) | ||
} | ||
} | ||
|
||
|
||
extension AsyncBroadcastSequence: AsyncSequence { | ||
public typealias Element = Base.Element | ||
|
||
public struct Iterator: AsyncIteratorProtocol { | ||
final class Context { | ||
let state: ManagedCriticalState<State> | ||
var iteration: ManagedCriticalState<Iteration> | ||
let id: Int | ||
|
||
init(_ state: ManagedCriticalState<State>, _ iteration: ManagedCriticalState<Iteration>) { | ||
self.state = state | ||
self.iteration = iteration | ||
self.id = State.establish(state) | ||
} | ||
|
||
deinit { | ||
State.cancel(state, id: id) | ||
if iteration.isKnownUniquelyReferenced() { | ||
Iteration.cancel(iteration) | ||
} | ||
Comment on lines
+206
to
+208
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I might be missing something here, but if the last two iterators enter |
||
} | ||
|
||
func next() async rethrows -> Element? { | ||
guard Iteration.start(iteration, state: state) else { | ||
return nil | ||
} | ||
defer { | ||
if Task.isCancelled && iteration.isKnownUniquelyReferenced() { | ||
Iteration.cancel(iteration) | ||
} | ||
} | ||
return try await withTaskCancellationHandler { | ||
let result = await State.next(state, id: id) | ||
return try result._rethrowGet() | ||
} onCancel: { [state, id] in | ||
State.cancel(state, id: id) | ||
} | ||
} | ||
} | ||
|
||
let context: Context | ||
|
||
init(_ state: ManagedCriticalState<State>, _ iteration: ManagedCriticalState<Iteration>) { | ||
context = Context(state, iteration) | ||
} | ||
|
||
public mutating func next() async rethrows -> Element? { | ||
try await context.next() | ||
} | ||
} | ||
|
||
public func makeAsyncIterator() -> Iterator { | ||
Iterator(state, iteration) | ||
} | ||
} | ||
|
||
@available(*, unavailable) | ||
extension AsyncBroadcastSequence.Iterator: Sendable { } |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// This source file is part of the Swift Async Algorithms open source project | ||
// | ||
// Copyright (c) 2022 Apple Inc. and the Swift project authors | ||
// Licensed under Apache License v2.0 with Runtime Library Exception | ||
// | ||
// See https://swift.org/LICENSE.txt for license information | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
@preconcurrency import XCTest | ||
import AsyncAlgorithms | ||
|
||
final class TestBroadcast: XCTestCase { | ||
func test_basic_broadcasting() async { | ||
let base = [1, 2, 3, 4].async | ||
let a = base.broadcast() | ||
let b = a | ||
let results = await withTaskGroup(of: [Int].self) { group in | ||
group.addTask { | ||
await Array(a) | ||
} | ||
group.addTask { | ||
await Array(b) | ||
} | ||
return await Array(group) | ||
} | ||
XCTAssertEqual(results[0], results[1]) | ||
} | ||
|
||
func test_basic_broadcasting_from_channel() async { | ||
let base = AsyncChannel<Int>() | ||
let a = base.broadcast() | ||
let b = a | ||
let results = await withTaskGroup(of: [Int].self) { group in | ||
group.addTask { | ||
var sent = [Int]() | ||
for i in 0..<10 { | ||
sent.append(i) | ||
await base.send(i) | ||
} | ||
base.finish() | ||
return sent | ||
} | ||
group.addTask { | ||
await Array(a) | ||
} | ||
group.addTask { | ||
await Array(b) | ||
} | ||
return await Array(group) | ||
} | ||
XCTAssertEqual(results[0], results[1]) | ||
} | ||
} |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One thing I just noticed after the discussion in the forums around deferred. This is actually creating multiple upstream iterators if I understand this code correctly. IMO we really need to avoid this otherwise this algorithm is not capable of transforming a unicast
AsyncSequence
into a broadcasted one. @phausler please correct if I am misunderstanding the code hereThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
no, this is creating 1 singular task to iterate so only 1 upstream iterator is made.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm clearly being dense, because it also looks to me like this would create a new
Task
, at least each timebroadcast
is called. And for eachTask
there would be a separate iterator of the base sequence.Ah, so maybe the intent is that
broadcast
is only called once, and then the broadcast sequence is copied to duplicate it? That's a bit awkward.