Skip to content

Commit c5cb376

Browse files
authored
Merge pull request #8 from de-vri-es/fix-waker-list-growing-indefinitely
Fix the waker lists growing indefinitely.
2 parents 9ec5d4b + 12ca611 commit c5cb376

File tree

6 files changed

+335
-27
lines changed

6 files changed

+335
-27
lines changed

CHANGELOG.md

+4-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1+
# main
2+
* Fix bug where the list of wakers to trigger on shutdown or shutdown completion could grow indefinitely.
3+
14
# Version 0.2.1 - 2023-10-08
2-
* Fix `ShutdownManager::wait_shutdown_complete()` never completing if callend when no shutdown was triggered yet and no delay tokens exist.
5+
* Fix `ShutdownManager::wait_shutdown_complete()` never completing if called when no shutdown was triggered yet and no delay tokens exist.
36

47
# Version 0.2.0 - 2023-09-26:
58
* Rename `Shutdown` struct to `ShutdownManager`.

src/lib.rs

+11-11
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,6 @@
169169

170170
use std::future::Future;
171171
use std::sync::{Arc, Mutex};
172-
use std::task::Waker;
173172

174173
mod shutdown_complete;
175174
pub use shutdown_complete::ShutdownComplete;
@@ -178,6 +177,7 @@ mod shutdown_signal;
178177
pub use shutdown_signal::ShutdownSignal;
179178

180179
mod wrap_cancel;
180+
use waker_list::WakerList;
181181
pub use wrap_cancel::WrapCancel;
182182

183183
mod wrap_trigger_shutdown;
@@ -186,6 +186,8 @@ pub use wrap_trigger_shutdown::WrapTriggerShutdown;
186186
mod wrap_delay_shutdown;
187187
pub use wrap_delay_shutdown::WrapDelayShutdown;
188188

189+
mod waker_list;
190+
189191
/// Shutdown manager for asynchronous tasks and futures.
190192
///
191193
/// The shutdown manager allows you to:
@@ -244,6 +246,7 @@ impl<T: Clone> ShutdownManager<T> {
244246
pub fn wait_shutdown_triggered(&self) -> ShutdownSignal<T> {
245247
ShutdownSignal {
246248
inner: self.inner.clone(),
249+
waker_token: None,
247250
}
248251
}
249252

@@ -258,6 +261,7 @@ impl<T: Clone> ShutdownManager<T> {
258261
pub fn wait_shutdown_complete(&self) -> ShutdownComplete<T> {
259262
ShutdownComplete {
260263
inner: self.inner.clone(),
264+
waker_token: None,
261265
}
262266
}
263267

@@ -459,19 +463,19 @@ struct ShutdownManagerInner<T> {
459463
delay_tokens: usize,
460464

461465
/// Tasks to wake when a shutdown is triggered.
462-
on_shutdown: Vec<Waker>,
466+
on_shutdown: WakerList,
463467

464468
/// Tasks to wake when the shutdown is complete.
465-
on_shutdown_complete: Vec<Waker>,
469+
on_shutdown_complete: WakerList,
466470
}
467471

468472
impl<T: Clone> ShutdownManagerInner<T> {
469473
fn new() -> Self {
470474
Self {
471475
shutdown_reason: None,
472476
delay_tokens: 0,
473-
on_shutdown_complete: Vec::new(),
474-
on_shutdown: Vec::new(),
477+
on_shutdown_complete: WakerList::new(),
478+
on_shutdown: WakerList::new(),
475479
}
476480
}
477481

@@ -493,9 +497,7 @@ impl<T: Clone> ShutdownManagerInner<T> {
493497
},
494498
None => {
495499
self.shutdown_reason = Some(reason);
496-
for abort in std::mem::take(&mut self.on_shutdown) {
497-
abort.wake()
498-
}
500+
self.on_shutdown.wake_all();
499501
if self.delay_tokens == 0 {
500502
self.notify_shutdown_complete()
501503
}
@@ -505,9 +507,7 @@ impl<T: Clone> ShutdownManagerInner<T> {
505507
}
506508

507509
fn notify_shutdown_complete(&mut self) {
508-
for waiter in std::mem::take(&mut self.on_shutdown_complete) {
509-
waiter.wake()
510-
}
510+
self.on_shutdown_complete.wake_all();
511511
}
512512
}
513513

src/shutdown_complete.rs

+115-8
Original file line numberDiff line numberDiff line change
@@ -3,30 +3,137 @@ use std::pin::Pin;
33
use std::sync::{Arc, Mutex};
44
use std::task::{Context, Poll};
55

6+
use crate::waker_list::WakerToken;
67
use crate::ShutdownManagerInner;
78

89
/// Future to wait for a shutdown to complete.
910
pub struct ShutdownComplete<T: Clone> {
1011
pub(crate) inner: Arc<Mutex<ShutdownManagerInner<T>>>,
12+
pub(crate) waker_token: Option<WakerToken>,
13+
}
14+
15+
impl<T: Clone> Clone for ShutdownComplete<T> {
16+
fn clone(&self) -> Self {
17+
// Clone only the reference to the shutdown manager, not the waker token.
18+
// The waker token is personal to each future.
19+
Self {
20+
inner: self.inner.clone(),
21+
waker_token: None,
22+
}
23+
}
24+
}
25+
26+
impl<T: Clone> Drop for ShutdownComplete<T> {
27+
fn drop(&mut self) {
28+
if let Some(token) = self.waker_token.take() {
29+
let mut inner = self.inner.lock().unwrap();
30+
inner.on_shutdown_complete.deregister(token);
31+
}
32+
}
1133
}
1234

1335
impl<T: Clone> Future for ShutdownComplete<T> {
1436
type Output = T;
1537

1638
#[inline]
1739
fn poll(self: Pin<&mut Self>, context: &mut Context) -> Poll<Self::Output> {
18-
let me = self.as_ref();
40+
let me = self.get_mut();
1941
let mut inner = me.inner.lock().unwrap();
42+
43+
// We're being polled, so we should deregister the waker (if any).
44+
if let Some(token) = me.waker_token.take() {
45+
inner.on_shutdown_complete.deregister(token);
46+
}
47+
48+
// Check if the shutdown is completed.
2049
if inner.delay_tokens == 0 {
2150
if let Some(reason) = inner.shutdown_reason.clone() {
22-
Poll::Ready(reason)
23-
} else {
24-
inner.on_shutdown_complete.push(context.waker().clone());
25-
Poll::Pending
51+
return Poll::Ready(reason);
2652
}
27-
} else {
28-
inner.on_shutdown_complete.push(context.waker().clone());
29-
Poll::Pending
53+
}
54+
55+
// We're not ready, so register the waker to wake us on shutdown completion.
56+
me.waker_token = Some(inner.on_shutdown_complete.register(context.waker().clone()));
57+
58+
Poll::Pending
59+
}
60+
}
61+
62+
#[cfg(test)]
63+
mod test {
64+
use assert2::assert;
65+
use std::future::Future;
66+
use std::pin::Pin;
67+
use std::task::Poll;
68+
69+
/// Wrapper around a future to poll it only once.
70+
struct PollOnce<'a, F>(&'a mut F);
71+
72+
impl<'a, F: std::marker::Unpin + Future> Future for PollOnce<'a, F> {
73+
type Output = Poll<F::Output>;
74+
75+
fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
76+
Poll::Ready(Pin::new(&mut self.get_mut().0).poll(cx))
77+
}
78+
}
79+
80+
/// Poll a future once.
81+
async fn poll_once<F: Future + Unpin>(future: &mut F) -> Poll<F::Output> {
82+
PollOnce(future).await
83+
}
84+
85+
#[tokio::test]
86+
async fn waker_list_doesnt_grow_infinitely() {
87+
let shutdown = crate::ShutdownManager::<()>::new();
88+
for i in 0..100_000 {
89+
let mut wait_shutdown_complete = shutdown.wait_shutdown_complete();
90+
let task = tokio::spawn(async move {
91+
assert!(let Poll::Pending = poll_once(&mut wait_shutdown_complete).await);
92+
});
93+
assert!(let Ok(()) = task.await, "task = {i}");
94+
}
95+
96+
// Since we wait for each task to complete before spawning another,
97+
// the total amount of waker slots used should be only 1.
98+
let inner = shutdown.inner.lock().unwrap();
99+
assert!(inner.on_shutdown_complete.total_slots() == 1);
100+
assert!(inner.on_shutdown_complete.empty_slots() == 1);
101+
}
102+
103+
#[tokio::test]
104+
async fn cloning_does_not_clone_waker_token() {
105+
let shutdown = crate::ShutdownManager::<()>::new();
106+
107+
let mut signal = shutdown.wait_shutdown_complete();
108+
assert!(let None = &signal.waker_token);
109+
110+
assert!(let Poll::Pending = poll_once(&mut signal).await);
111+
assert!(let Some(_) = &signal.waker_token);
112+
113+
let mut cloned = signal.clone();
114+
assert!(let None = &cloned.waker_token);
115+
assert!(let Some(_) = &signal.waker_token);
116+
117+
assert!(let Poll::Pending = poll_once(&mut cloned).await);
118+
assert!(let Some(_) = &cloned.waker_token);
119+
assert!(let Some(_) = &signal.waker_token);
120+
121+
{
122+
let inner = shutdown.inner.lock().unwrap();
123+
assert!(inner.on_shutdown_complete.total_slots() == 2);
124+
assert!(inner.on_shutdown_complete.empty_slots() == 0);
125+
}
126+
127+
{
128+
drop(signal);
129+
let inner = shutdown.inner.lock().unwrap();
130+
assert!(inner.on_shutdown_complete.empty_slots() == 1);
131+
}
132+
133+
{
134+
drop(cloned);
135+
let inner = shutdown.inner.lock().unwrap();
136+
assert!(inner.on_shutdown_complete.empty_slots() == 2);
30137
}
31138
}
32139
}

src/shutdown_signal.rs

+110-3
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,37 @@ use std::pin::Pin;
33
use std::sync::{Arc, Mutex};
44
use std::task::{Context, Poll};
55

6+
use crate::waker_list::WakerToken;
67
use crate::{WrapCancel, ShutdownManagerInner};
78

89
/// A future to wait for a shutdown signal.
910
///
1011
/// The future completes when the associated [`ShutdownManager`][crate::ShutdownManager] triggers a shutdown.
1112
///
1213
/// The shutdown signal can be cloned and sent between threads freely.
13-
#[derive(Clone)]
1414
pub struct ShutdownSignal<T: Clone> {
1515
pub(crate) inner: Arc<Mutex<ShutdownManagerInner<T>>>,
16+
pub(crate) waker_token: Option<WakerToken>,
17+
}
18+
19+
impl<T: Clone> Clone for ShutdownSignal<T> {
20+
fn clone(&self) -> Self {
21+
// Clone only the reference to the shutdown manager, not the waker token.
22+
// The waker token is personal to each future.
23+
Self {
24+
inner: self.inner.clone(),
25+
waker_token: None,
26+
}
27+
}
28+
}
29+
30+
impl<T: Clone> Drop for ShutdownSignal<T> {
31+
fn drop(&mut self) {
32+
if let Some(token) = self.waker_token.take() {
33+
let mut inner = self.inner.lock().unwrap();
34+
inner.on_shutdown.deregister(token);
35+
}
36+
}
1637
}
1738

1839
impl<T: Clone> ShutdownSignal<T> {
@@ -36,13 +57,99 @@ impl<T: Clone> Future for ShutdownSignal<T> {
3657

3758
#[inline]
3859
fn poll(self: Pin<&mut Self>, context: &mut Context) -> Poll<Self::Output> {
39-
let me = self.as_ref();
60+
let me = self.get_mut();
4061
let mut inner = me.inner.lock().unwrap();
62+
63+
// We're being polled, so we should deregister the waker (if any).
64+
if let Some(token) = me.waker_token.take() {
65+
inner.on_shutdown.deregister(token);
66+
}
67+
4168
if let Some(reason) = inner.shutdown_reason.clone() {
69+
// Shutdown started, so we're ready.
4270
Poll::Ready(reason)
4371
} else {
44-
inner.on_shutdown.push(context.waker().clone());
72+
// We're not ready, so register the waker to wake us on shutdown start.
73+
me.waker_token = Some(inner.on_shutdown.register(context.waker().clone()));
4574
Poll::Pending
4675
}
4776
}
4877
}
78+
79+
#[cfg(test)]
80+
mod test {
81+
use assert2::assert;
82+
use std::future::Future;
83+
use std::pin::Pin;
84+
use std::task::Poll;
85+
86+
/// Wrapper around a future to poll it only once.
87+
struct PollOnce<'a, F>(&'a mut F);
88+
89+
impl<'a, F: std::marker::Unpin + Future> Future for PollOnce<'a, F> {
90+
type Output = Poll<F::Output>;
91+
92+
fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
93+
Poll::Ready(Pin::new(&mut self.get_mut().0).poll(cx))
94+
}
95+
}
96+
97+
/// Poll a future once.
98+
async fn poll_once<F: Future + Unpin>(future: &mut F) -> Poll<F::Output> {
99+
PollOnce(future).await
100+
}
101+
102+
#[tokio::test]
103+
async fn waker_list_doesnt_grow_infinitely() {
104+
let shutdown = crate::ShutdownManager::<()>::new();
105+
for i in 0..100_000 {
106+
let task = tokio::spawn(shutdown.wrap_cancel(async move {
107+
tokio::task::yield_now().await;
108+
}));
109+
assert!(let Ok(Ok(())) = task.await, "task = {i}");
110+
}
111+
112+
// Since we wait for each task to complete before spawning another,
113+
// the total amount of waker slots used should be only 1.
114+
let inner = shutdown.inner.lock().unwrap();
115+
assert!(inner.on_shutdown.total_slots() == 1);
116+
assert!(inner.on_shutdown.empty_slots() == 1);
117+
}
118+
119+
#[tokio::test]
120+
async fn cloning_does_not_clone_waker_token() {
121+
let shutdown = crate::ShutdownManager::<()>::new();
122+
123+
let mut signal = shutdown.wait_shutdown_triggered();
124+
assert!(let None = &signal.waker_token);
125+
126+
assert!(let Poll::Pending = poll_once(&mut signal).await);
127+
assert!(let Some(_) = &signal.waker_token);
128+
129+
let mut cloned = signal.clone();
130+
assert!(let None = &cloned.waker_token);
131+
assert!(let Some(_) = &signal.waker_token);
132+
133+
assert!(let Poll::Pending = poll_once(&mut cloned).await);
134+
assert!(let Some(_) = &cloned.waker_token);
135+
assert!(let Some(_) = &signal.waker_token);
136+
137+
{
138+
let inner = shutdown.inner.lock().unwrap();
139+
assert!(inner.on_shutdown.total_slots() == 2);
140+
assert!(inner.on_shutdown.empty_slots() == 0);
141+
}
142+
143+
{
144+
drop(signal);
145+
let inner = shutdown.inner.lock().unwrap();
146+
assert!(inner.on_shutdown.empty_slots() == 1);
147+
}
148+
149+
{
150+
drop(cloned);
151+
let inner = shutdown.inner.lock().unwrap();
152+
assert!(inner.on_shutdown.empty_slots() == 2);
153+
}
154+
}
155+
}

0 commit comments

Comments
 (0)