|
1 |
| -use broadcaster::BroadcastChannel; |
2 |
| - |
3 |
| -use crate::sync::Mutex; |
| 1 | +use crate::sync::{Condvar,Mutex}; |
4 | 2 |
|
5 | 3 | /// A barrier enables multiple tasks to synchronize the beginning
|
6 | 4 | /// of some computation.
|
@@ -36,14 +34,13 @@ use crate::sync::Mutex;
|
36 | 34 | #[derive(Debug)]
|
37 | 35 | pub struct Barrier {
|
38 | 36 | state: Mutex<BarrierState>,
|
39 |
| - wait: BroadcastChannel<(usize, usize)>, |
40 |
| - n: usize, |
| 37 | + cvar: Condvar, |
| 38 | + num_tasks: usize, |
41 | 39 | }
|
42 | 40 |
|
43 | 41 | // The inner state of a double barrier
|
44 | 42 | #[derive(Debug)]
|
45 | 43 | struct BarrierState {
|
46 |
| - waker: BroadcastChannel<(usize, usize)>, |
47 | 44 | count: usize,
|
48 | 45 | generation_id: usize,
|
49 | 46 | }
|
@@ -81,25 +78,14 @@ impl Barrier {
|
81 | 78 | ///
|
82 | 79 | /// let barrier = Barrier::new(10);
|
83 | 80 | /// ```
|
84 |
| - pub fn new(mut n: usize) -> Barrier { |
85 |
| - let waker = BroadcastChannel::new(); |
86 |
| - let wait = waker.clone(); |
87 |
| - |
88 |
| - if n == 0 { |
89 |
| - // if n is 0, it's not clear what behavior the user wants. |
90 |
| - // in std::sync::Barrier, an n of 0 exhibits the same behavior as n == 1, where every |
91 |
| - // .wait() immediately unblocks, so we adopt that here as well. |
92 |
| - n = 1; |
93 |
| - } |
94 |
| - |
| 81 | + pub fn new(n: usize) -> Barrier { |
95 | 82 | Barrier {
|
96 | 83 | state: Mutex::new(BarrierState {
|
97 |
| - waker, |
98 | 84 | count: 0,
|
99 | 85 | generation_id: 1,
|
100 | 86 | }),
|
101 |
| - n, |
102 |
| - wait, |
| 87 | + cvar: Condvar::new(), |
| 88 | + num_tasks: n, |
103 | 89 | }
|
104 | 90 | }
|
105 | 91 |
|
@@ -143,35 +129,20 @@ impl Barrier {
|
143 | 129 | /// # });
|
144 | 130 | /// ```
|
145 | 131 | pub async fn wait(&self) -> BarrierWaitResult {
|
146 |
| - let mut lock = self.state.lock().await; |
147 |
| - let local_gen = lock.generation_id; |
148 |
| - |
149 |
| - lock.count += 1; |
| 132 | + let mut state = self.state.lock().await; |
| 133 | + let local_gen = state.generation_id; |
| 134 | + state.count += 1; |
150 | 135 |
|
151 |
| - if lock.count < self.n { |
152 |
| - let mut wait = self.wait.clone(); |
153 |
| - |
154 |
| - let mut generation_id = lock.generation_id; |
155 |
| - let mut count = lock.count; |
156 |
| - |
157 |
| - drop(lock); |
158 |
| - |
159 |
| - while local_gen == generation_id && count < self.n { |
160 |
| - let (g, c) = wait.recv().await.expect("sender has not been closed"); |
161 |
| - generation_id = g; |
162 |
| - count = c; |
| 136 | + if state.count < self.num_tasks { |
| 137 | + while local_gen == state.generation_id && state.count < self.num_tasks { |
| 138 | + state = self.cvar.wait(state).await; |
163 | 139 | }
|
164 | 140 |
|
165 | 141 | BarrierWaitResult(false)
|
166 | 142 | } else {
|
167 |
| - lock.count = 0; |
168 |
| - lock.generation_id = lock.generation_id.wrapping_add(1); |
169 |
| - |
170 |
| - lock.waker |
171 |
| - .send(&(lock.generation_id, lock.count)) |
172 |
| - .await |
173 |
| - .expect("there should be at least one receiver"); |
174 |
| - |
| 143 | + state.count = 0; |
| 144 | + state.generation_id = state.generation_id.wrapping_add(1); |
| 145 | + self.cvar.notify_all(); |
175 | 146 | BarrierWaitResult(true)
|
176 | 147 | }
|
177 | 148 | }
|
|
0 commit comments