Fix race condition in shutdown of background task (#81)
Hodkinson authored Jul 13, 2021
1 parent b3926d4 commit ebe4e1f
Showing 6 changed files with 62 additions and 35 deletions.
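The `Drop for Db` implementation removed from src/db.rs below decided whether to signal the background purge task by checking `Arc::strong_count(&self.shared) == 2` (one count for the `Db` being dropped, one for the task's handle). That check is racy, and this is presumably the race the title refers to: if two `Db` handles are dropped concurrently, each runs its `Drop` body before releasing its own `Arc`, so both can observe a count above 2 and neither signals the task, which then never shuts down. This commit replaces the count check with a `DbDropGuard` owned by the server's `Listener`; dropping the guard unconditionally signals the purge task via `shutdown_purge_task`. As part of the same cleanup, `server::run` no longer returns a `Result`, which is why `src/bin/server.rs` and the tests' `JoinHandle` types change as well.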
4 changes: 3 additions & 1 deletion src/bin/server.rs
@@ -24,7 +24,9 @@ pub async fn main() -> mini_redis::Result<()> {
// Bind a TCP listener
let listener = TcpListener::bind(&format!("127.0.0.1:{}", port)).await?;

server::run(listener, signal::ctrl_c()).await
server::run(listener, signal::ctrl_c()).await;

Ok(())
}

#[derive(StructOpt, Debug)]
68 changes: 47 additions & 21 deletions src/db.rs
@@ -4,6 +4,17 @@ use tokio::time::{self, Duration, Instant};
use bytes::Bytes;
use std::collections::{BTreeMap, HashMap};
use std::sync::{Arc, Mutex};
use tracing::debug;

/// A wrapper around a `Db` instance. This exists to allow orderly cleanup
/// of the `Db` by signalling the background purge task to shut down when
/// this struct is dropped.
#[derive(Debug)]
pub(crate) struct DbDropGuard {
/// The `Db` instance that will be shut down when this `DbDropGuard` struct
/// is dropped.
db: Db,
}

/// Server state shared across all connections.
///
@@ -92,6 +103,27 @@ struct Entry {
expires_at: Option<Instant>,
}

impl DbDropGuard {
/// Create a new `DbDropGuard`, wrapping a `Db` instance. When this is dropped
/// the `Db`'s purge task will be shut down.
pub(crate) fn new() -> DbDropGuard {
DbDropGuard { db: Db::new() }
}

/// Get the shared database. Internally, this is an
/// `Arc`, so a clone only increments the ref count.
pub(crate) fn db(&self) -> Db {
self.db.clone()
}
}

impl Drop for DbDropGuard {
fn drop(&mut self) {
// Signal the `Db` instance to shut down the task that purges expired keys
self.db.shutdown_purge_task();
}
}

impl Db {
/// Create a new, empty, `Db` instance. Allocates shared state and spawns a
/// background task to manage key expiration.
@@ -244,28 +276,20 @@ impl Db {
// subscribers. In this case, return `0`.
.unwrap_or(0)
}
}

impl Drop for Db {
fn drop(&mut self) {
// If this is the last active `Db` instance, the background task must be
// notified to shut down.
//
// First, determine if this is the last `Db` instance. This is done by
// checking `strong_count`. The count will be 2. One for this `Db`
// instance and one for the handle held by the background task.
if Arc::strong_count(&self.shared) == 2 {
// The background task must be signaled to shutdown. This is done by
// setting `State::shutdown` to `true` and signalling the task.
let mut state = self.shared.state.lock().unwrap();
state.shutdown = true;

// Drop the lock before signalling the background task. This helps
// reduce lock contention by ensuring the background task doesn't
// wake up only to be unable to acquire the mutex.
drop(state);
self.shared.background_task.notify_one();
}
/// Signals the purge background task to shut down. This is called by the
/// `DbDropGuard`'s `Drop` implementation.
fn shutdown_purge_task(&self) {
// The background task must be signaled to shut down. This is done by
// setting `State::shutdown` to `true` and signalling the task.
let mut state = self.shared.state.lock().unwrap();
state.shutdown = true;

// Drop the lock before signalling the background task. This helps
// reduce lock contention by ensuring the background task doesn't
// wake up only to be unable to acquire the mutex.
drop(state);
self.shared.background_task.notify_one();
}
}

@@ -349,4 +373,6 @@ async fn purge_expired_tasks(shared: Arc<Shared>) {
shared.background_task.notified().await;
}
}

debug!("Purge background task shut down")
}
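For context, here is a minimal, self-contained sketch of the drop-guard pattern that `DbDropGuard` applies. The `Guard` and `Shared` names and the `main` harness are illustrative only (they are not part of mini-redis), but the set-flag-then-notify sequence mirrors `shutdown_purge_task`:

```rust
use std::sync::{Arc, Mutex};
use tokio::sync::Notify;

/// State shared between the guard and the background task.
struct Shared {
    shutdown: Mutex<bool>,
    background_task: Notify,
}

/// Minimal drop guard: signals the background task unconditionally when
/// dropped, so no `Arc::strong_count` bookkeeping is needed.
struct Guard {
    shared: Arc<Shared>,
}

impl Drop for Guard {
    fn drop(&mut self) {
        // Set the flag first, then wake the task (mirrors `shutdown_purge_task`).
        *self.shared.shutdown.lock().unwrap() = true;
        self.shared.background_task.notify_one();
    }
}

#[tokio::main]
async fn main() {
    let shared = Arc::new(Shared {
        shutdown: Mutex::new(false),
        background_task: Notify::new(),
    });
    let guard = Guard { shared: shared.clone() };

    let task = tokio::spawn(async move {
        // Loop until the guard flips the flag; `notify_one` stores a permit,
        // so a wakeup sent before we `await` is not lost.
        while !*shared.shutdown.lock().unwrap() {
            shared.background_task.notified().await;
        }
        println!("background task shut down");
    });

    drop(guard); // dropping the guard is all it takes to stop the task
    task.await.unwrap();
}
```

Because the guard signals unconditionally on drop, correctness no longer depends on observing a particular `Arc::strong_count`, which is what made the old `Drop for Db` racy.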
1 change: 1 addition & 0 deletions src/lib.rs
Expand Up @@ -39,6 +39,7 @@ pub use frame::Frame;

mod db;
use db::Db;
use db::DbDropGuard;

mod parse;
use parse::{Parse, ParseError};
20 changes: 9 additions & 11 deletions src/server.rs
@@ -3,7 +3,7 @@
//! Provides an async `run` function that listens for inbound connections,
//! spawning a task per connection.
use crate::{Command, Connection, Db, Shutdown};
use crate::{Command, Connection, Db, DbDropGuard, Shutdown};

use std::future::Future;
use std::sync::Arc;
@@ -21,9 +21,9 @@ struct Listener {
/// Contains the key / value store as well as the broadcast channels for
/// pub/sub.
///
/// This is a wrapper around an `Arc`. This enables `db` to be cloned and
/// passed into the per connection state (`Handler`).
db: Db,
/// This holds a wrapper around an `Arc`. The internal `Db` can be
/// retrieved and passed into the per connection state (`Handler`).
db_holder: DbDropGuard,

/// TCP listener supplied by the `run` caller.
listener: TcpListener,
@@ -128,7 +128,7 @@ const MAX_CONNECTIONS: usize = 250;
///
/// `tokio::signal::ctrl_c()` can be used as the `shutdown` argument. This will
/// listen for a SIGINT signal.
pub async fn run(listener: TcpListener, shutdown: impl Future) -> crate::Result<()> {
pub async fn run(listener: TcpListener, shutdown: impl Future) {
// When the provided `shutdown` future completes, we must send a shutdown
// message to all active connections. We use a broadcast channel for this
// purpose. The call below ignores the receiver of the broadcast pair, and when
@@ -140,7 +140,7 @@ pub async fn run(listener: TcpListener, shutdown: impl Future) -> crate::Result<
// Initialize the listener state
let mut server = Listener {
listener,
db: Db::new(),
db_holder: DbDropGuard::new(),
limit_connections: Arc::new(Semaphore::new(MAX_CONNECTIONS)),
notify_shutdown,
shutdown_complete_tx,
@@ -193,6 +193,7 @@ pub async fn run(listener: TcpListener, shutdown: impl Future) -> crate::Result<
notify_shutdown,
..
} = server;

// When `notify_shutdown` is dropped, all tasks which have `subscribe`d will
// receive the shutdown signal and can exit
drop(notify_shutdown);
@@ -204,8 +205,6 @@ pub async fn run(listener: TcpListener, shutdown: impl Future) -> crate::Result<
// `Sender` instances are held by connection handler tasks. When those drop,
// the `mpsc` channel will close and `recv()` will return `None`.
let _ = shutdown_complete_rx.recv().await;

Ok(())
}

impl Listener {
@@ -250,9 +249,8 @@ impl Listener {

// Create the necessary per-connection handler state.
let mut handler = Handler {
// Get a handle to the shared database. Internally, this is an
// `Arc`, so a clone only increments the ref count.
db: self.db.clone(),
// Get a handle to the shared database.
db: self.db_holder.db(),

// Initialize the connection state. This allocates read/write
// buffers to perform redis protocol frame parsing.
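Note that `run` never calls `shutdown_purge_task` directly: the `Listener` owns `db_holder`, so when `server` is consumed at the end of `run`, the `DbDropGuard` is dropped and its `Drop` impl signals the purge task. That is why no explicit database cleanup call appears in this diff.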
2 changes: 1 addition & 1 deletion tests/buffer.rs
@@ -20,7 +20,7 @@ async fn pool_key_value_get_set() {
assert_eq!(b"world", &value[..])
}

async fn start_server() -> (SocketAddr, JoinHandle<mini_redis::Result<()>>) {
async fn start_server() -> (SocketAddr, JoinHandle<()>) {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();

2 changes: 1 addition & 1 deletion tests/client.rs
@@ -82,7 +82,7 @@ async fn unsubscribes_from_channels() {
assert_eq!(subscriber.get_subscribed().len(), 0);
}

async fn start_server() -> (SocketAddr, JoinHandle<mini_redis::Result<()>>) {
async fn start_server() -> (SocketAddr, JoinHandle<()>) {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();

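Both test helpers change their `JoinHandle` type parameter because the spawned future now resolves to `()` rather than `mini_redis::Result<()>`. The `start_server` bodies are collapsed in this diff; a purely hypothetical sketch of such a helper (the `server::run` import path and the `pending()` shutdown future are assumptions, not taken from this commit):

```rust
use std::net::SocketAddr;
use tokio::net::TcpListener;
use tokio::task::JoinHandle;
use mini_redis::server;

// Illustrative only: the actual start_server body is not shown in this diff.
async fn start_server() -> (SocketAddr, JoinHandle<()>) {
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let addr = listener.local_addr().unwrap();

    // `run` now resolves to `()`, so the spawned task's handle is `JoinHandle<()>`.
    let handle = tokio::spawn(async move {
        server::run(listener, std::future::pending::<()>()).await
    });

    (addr, handle)
}
```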
