Skip to content

Commit ebe4e1f

Browse files
authored
Fix race condition in shutdown of background task (#81)
1 parent b3926d4 commit ebe4e1f

File tree

6 files changed

+62
-35
lines changed

6 files changed

+62
-35
lines changed

src/bin/server.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@ pub async fn main() -> mini_redis::Result<()> {
2424
// Bind a TCP listener
2525
let listener = TcpListener::bind(&format!("127.0.0.1:{}", port)).await?;
2626

27-
server::run(listener, signal::ctrl_c()).await
27+
server::run(listener, signal::ctrl_c()).await;
28+
29+
Ok(())
2830
}
2931

3032
#[derive(StructOpt, Debug)]

src/db.rs

Lines changed: 47 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,17 @@ use tokio::time::{self, Duration, Instant};
44
use bytes::Bytes;
55
use std::collections::{BTreeMap, HashMap};
66
use std::sync::{Arc, Mutex};
7+
use tracing::debug;
8+
9+
/// A wrapper around a `Db` instance. This exists to allow orderly cleanup
10+
/// of the `Db` by signalling the background purge task to shut down when
11+
/// this struct is dropped.
12+
#[derive(Debug)]
13+
pub(crate) struct DbDropGuard {
14+
/// The `Db` instance that will be shut down when this `DbHolder` struct
15+
/// is dropped.
16+
db: Db,
17+
}
718

819
/// Server state shared across all connections.
920
///
@@ -92,6 +103,27 @@ struct Entry {
92103
expires_at: Option<Instant>,
93104
}
94105

106+
impl DbDropGuard {
107+
/// Create a new `DbHolder`, wrapping a `Db` instance. When this is dropped
108+
/// the `Db`'s purge task will be shut down.
109+
pub(crate) fn new() -> DbDropGuard {
110+
DbDropGuard { db: Db::new() }
111+
}
112+
113+
/// Get the shared database. Internally, this is an
114+
/// `Arc`, so a clone only increments the ref count.
115+
pub(crate) fn db(&self) -> Db {
116+
self.db.clone()
117+
}
118+
}
119+
120+
impl Drop for DbDropGuard {
121+
fn drop(&mut self) {
122+
// Signal the 'Db' instance to shut down the task that purges expired keys
123+
self.db.shutdown_purge_task();
124+
}
125+
}
126+
95127
impl Db {
96128
/// Create a new, empty, `Db` instance. Allocates shared state and spawns a
97129
/// background task to manage key expiration.
@@ -244,28 +276,20 @@ impl Db {
244276
// subscribers. In this case, return `0`.
245277
.unwrap_or(0)
246278
}
247-
}
248279

249-
impl Drop for Db {
250-
fn drop(&mut self) {
251-
// If this is the last active `Db` instance, the background task must be
252-
// notified to shut down.
253-
//
254-
// First, determine if this is the last `Db` instance. This is done by
255-
// checking `strong_count`. The count will be 2. One for this `Db`
256-
// instance and one for the handle held by the background task.
257-
if Arc::strong_count(&self.shared) == 2 {
258-
// The background task must be signaled to shutdown. This is done by
259-
// setting `State::shutdown` to `true` and signalling the task.
260-
let mut state = self.shared.state.lock().unwrap();
261-
state.shutdown = true;
262-
263-
// Drop the lock before signalling the background task. This helps
264-
// reduce lock contention by ensuring the background task doesn't
265-
// wake up only to be unable to acquire the mutex.
266-
drop(state);
267-
self.shared.background_task.notify_one();
268-
}
280+
/// Signals the purge background task to shut down. This is called by the
281+
/// `DbShutdown`s `Drop` implementation.
282+
fn shutdown_purge_task(&self) {
283+
// The background task must be signaled to shut down. This is done by
284+
// setting `State::shutdown` to `true` and signalling the task.
285+
let mut state = self.shared.state.lock().unwrap();
286+
state.shutdown = true;
287+
288+
// Drop the lock before signalling the background task. This helps
289+
// reduce lock contention by ensuring the background task doesn't
290+
// wake up only to be unable to acquire the mutex.
291+
drop(state);
292+
self.shared.background_task.notify_one();
269293
}
270294
}
271295

@@ -349,4 +373,6 @@ async fn purge_expired_tasks(shared: Arc<Shared>) {
349373
shared.background_task.notified().await;
350374
}
351375
}
376+
377+
debug!("Purge background task shut down")
352378
}

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ pub use frame::Frame;
3939

4040
mod db;
4141
use db::Db;
42+
use db::DbDropGuard;
4243

4344
mod parse;
4445
use parse::{Parse, ParseError};

src/server.rs

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
//! Provides an async `run` function that listens for inbound connections,
44
//! spawning a task per connection.
55
6-
use crate::{Command, Connection, Db, Shutdown};
6+
use crate::{Command, Connection, Db, DbDropGuard, Shutdown};
77

88
use std::future::Future;
99
use std::sync::Arc;
@@ -21,9 +21,9 @@ struct Listener {
2121
/// Contains the key / value store as well as the broadcast channels for
2222
/// pub/sub.
2323
///
24-
/// This is a wrapper around an `Arc`. This enables `db` to be cloned and
25-
/// passed into the per connection state (`Handler`).
26-
db: Db,
24+
/// This holds a wrapper around an `Arc`. The internal `Db` can be
25+
/// retrieved and passed into the per connection state (`Handler`).
26+
db_holder: DbDropGuard,
2727

2828
/// TCP listener supplied by the `run` caller.
2929
listener: TcpListener,
@@ -128,7 +128,7 @@ const MAX_CONNECTIONS: usize = 250;
128128
///
129129
/// `tokio::signal::ctrl_c()` can be used as the `shutdown` argument. This will
130130
/// listen for a SIGINT signal.
131-
pub async fn run(listener: TcpListener, shutdown: impl Future) -> crate::Result<()> {
131+
pub async fn run(listener: TcpListener, shutdown: impl Future) {
132132
// When the provided `shutdown` future completes, we must send a shutdown
133133
// message to all active connections. We use a broadcast channel for this
134134
// purpose. The call below ignores the receiver of the broadcast pair, and when
@@ -140,7 +140,7 @@ pub async fn run(listener: TcpListener, shutdown: impl Future) -> crate::Result<
140140
// Initialize the listener state
141141
let mut server = Listener {
142142
listener,
143-
db: Db::new(),
143+
db_holder: DbDropGuard::new(),
144144
limit_connections: Arc::new(Semaphore::new(MAX_CONNECTIONS)),
145145
notify_shutdown,
146146
shutdown_complete_tx,
@@ -193,6 +193,7 @@ pub async fn run(listener: TcpListener, shutdown: impl Future) -> crate::Result<
193193
notify_shutdown,
194194
..
195195
} = server;
196+
196197
// When `notify_shutdown` is dropped, all tasks which have `subscribe`d will
197198
// receive the shutdown signal and can exit
198199
drop(notify_shutdown);
@@ -204,8 +205,6 @@ pub async fn run(listener: TcpListener, shutdown: impl Future) -> crate::Result<
204205
// `Sender` instances are held by connection handler tasks. When those drop,
205206
// the `mpsc` channel will close and `recv()` will return `None`.
206207
let _ = shutdown_complete_rx.recv().await;
207-
208-
Ok(())
209208
}
210209

211210
impl Listener {
@@ -250,9 +249,8 @@ impl Listener {
250249

251250
// Create the necessary per-connection handler state.
252251
let mut handler = Handler {
253-
// Get a handle to the shared database. Internally, this is an
254-
// `Arc`, so a clone only increments the ref count.
255-
db: self.db.clone(),
252+
// Get a handle to the shared database.
253+
db: self.db_holder.db(),
256254

257255
// Initialize the connection state. This allocates read/write
258256
// buffers to perform redis protocol frame parsing.

tests/buffer.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ async fn pool_key_value_get_set() {
2020
assert_eq!(b"world", &value[..])
2121
}
2222

23-
async fn start_server() -> (SocketAddr, JoinHandle<mini_redis::Result<()>>) {
23+
async fn start_server() -> (SocketAddr, JoinHandle<()>) {
2424
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2525
let addr = listener.local_addr().unwrap();
2626

tests/client.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ async fn unsubscribes_from_channels() {
8282
assert_eq!(subscriber.get_subscribed().len(), 0);
8383
}
8484

85-
async fn start_server() -> (SocketAddr, JoinHandle<mini_redis::Result<()>>) {
85+
async fn start_server() -> (SocketAddr, JoinHandle<()>) {
8686
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
8787
let addr = listener.local_addr().unwrap();
8888

0 commit comments

Comments
 (0)