diff --git a/Cargo.lock b/Cargo.lock index 59f8430d1..8da4c8f88 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -377,6 +377,12 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "cfg_aliases" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fd16c4719339c4530435d38e511904438d07cce7950afa3718a84ac36c10e89e" + [[package]] name = "clap" version = "4.5.1" @@ -548,6 +554,16 @@ dependencies = [ "typenum", ] +[[package]] +name = "ctrlc" +version = "3.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "672465ae37dc1bc6380a6547a8883d5dd397b0f1faaad4f265726cc7042a5345" +dependencies = [ + "nix 0.28.0", + "windows-sys 0.52.0", +] + [[package]] name = "data-encoding" version = "2.5.0" @@ -1323,6 +1339,18 @@ dependencies = [ "libc", ] +[[package]] +name = "nix" +version = "0.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab2156c4fce2f8df6c499cc1c763e4394b7482525bf2a9701c9d79d215f519e4" +dependencies = [ + "bitflags 2.4.2", + "cfg-if", + "cfg_aliases", + "libc", +] + [[package]] name = "nom" version = "7.1.3" @@ -1644,7 +1672,7 @@ dependencies = [ "inferno", "libc", "log", - "nix", + "nix 0.26.4", "once_cell", "parking_lot", "prost", @@ -1958,6 +1986,7 @@ dependencies = [ "bytes", "clap", "config", + "ctrlc", "flume", "futures-util", "metrics", diff --git a/rumqttd/CHANGELOG.md b/rumqttd/CHANGELOG.md index 090bef2fe..07ee5456d 100644 --- a/rumqttd/CHANGELOG.md +++ b/rumqttd/CHANGELOG.md @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Assign random identifier to clients connecting with empty client id. +- Ability to gracefully shut down the broker. ### Changed - Public re-export `Strategy` for shared subscriptions diff --git a/rumqttd/Cargo.toml b/rumqttd/Cargo.toml index 8d1fd604b..6865457b6 100644 --- a/rumqttd/Cargo.toml +++ b/rumqttd/Cargo.toml @@ -16,6 +16,7 @@ tokio = { version = "1.36", features = ["rt", "time", "net", "io-util", "macros" serde = { version = "1.0.196", features = ["derive"] } serde_json = "1.0.113" bytes = { version = "1", features = ["serde"] } +ctrlc = "3.4" flume = { version = "0.11.0", default-features = false, features = ["async"]} slab = "0.4.9" thiserror = "1.0.57" diff --git a/rumqttd/examples/external_auth.rs b/rumqttd/examples/external_auth.rs index f5d7cdfcd..805b6b67f 100644 --- a/rumqttd/examples/external_auth.rs +++ b/rumqttd/examples/external_auth.rs @@ -32,9 +32,15 @@ fn main() { // true // }); - let mut broker = Broker::new(config); + let _guard = Broker::new(config).start().unwrap().drop_guard(); - broker.start().unwrap(); + let (tx, rx) = flume::bounded::<()>(1); + ctrlc::set_handler(move || { + let _ = tx.send(()); + }) + .expect("Error setting Ctrl-C handler"); + + rx.recv().expect("Could not receive Ctrl-C signal"); } async fn auth(_client_id: String, _username: String, _password: String) -> bool { diff --git a/rumqttd/examples/graceful_shutdown.rs b/rumqttd/examples/graceful_shutdown.rs new file mode 100644 index 000000000..ac318688c --- /dev/null +++ b/rumqttd/examples/graceful_shutdown.rs @@ -0,0 +1,36 @@ +use rumqttd::{Broker, Config}; +use tracing::Level; + +use core::time::Duration; +use std::thread; + +fn main() { + let builder = tracing_subscriber::fmt() + .pretty() + .with_max_level(Level::DEBUG) + .with_line_number(false) + .with_file(false) + .with_thread_ids(false) + .with_thread_names(false); + + builder + .try_init() + .expect("initialized subscriber succesfully"); + + // As examples are compiled as seperate binary so this config is current path dependent. Run it + // from root of this crate + let config = config::Config::builder() + .add_source(config::File::with_name("rumqttd.toml")) + .build() + .unwrap(); + + let config: Config = config.try_deserialize().unwrap(); + + dbg!(&config); + + let handler = Broker::new(config).start().unwrap(); + + thread::sleep(Duration::from_secs(1)); + handler.shutdown(); + thread::sleep(Duration::from_secs(1)); +} diff --git a/rumqttd/examples/meters.rs b/rumqttd/examples/meters.rs index ba872cca0..a4a1388fb 100644 --- a/rumqttd/examples/meters.rs +++ b/rumqttd/examples/meters.rs @@ -15,7 +15,7 @@ fn main() { dbg!(&config); - let mut broker = Broker::new(config); + let broker = Broker::new(config); let meters = broker.meters().unwrap(); let (mut link_tx, mut link_rx) = broker.link("consumer").unwrap(); @@ -59,11 +59,7 @@ fn main() { }); } - thread::spawn(move || { - if let Err(e) = broker.start() { - println!("Broker stopped: {e}"); - } - }); + let _guard = broker.start().unwrap().drop_guard(); thread::sleep(Duration::from_secs(2)); loop { diff --git a/rumqttd/examples/singlenode.rs b/rumqttd/examples/singlenode.rs index bcab216c3..50477a568 100644 --- a/rumqttd/examples/singlenode.rs +++ b/rumqttd/examples/singlenode.rs @@ -1,7 +1,5 @@ use rumqttd::{Broker, Config, Notification}; -use std::thread; - fn main() { let builder = tracing_subscriber::fmt() .pretty() @@ -25,11 +23,9 @@ fn main() { dbg!(&config); - let mut broker = Broker::new(config); + let broker = Broker::new(config); let (mut link_tx, mut link_rx) = broker.link("singlenode").unwrap(); - thread::spawn(move || { - broker.start().unwrap(); - }); + let _guard = broker.start().unwrap().drop_guard(); link_tx.subscribe("#").unwrap(); diff --git a/rumqttd/src/lib.rs b/rumqttd/src/lib.rs index 0b62f89f8..6b544617c 100644 --- a/rumqttd/src/lib.rs +++ b/rumqttd/src/lib.rs @@ -23,7 +23,7 @@ pub use link::local; pub use link::meters; pub use router::{Alert, IncomingMeter, Meter, Notification, OutgoingMeter}; use segments::Storage; -pub use server::Broker; +pub use server::{Broker, BrokerHandler, ShutdownDropGuard}; pub use self::router::shared_subs::Strategy; diff --git a/rumqttd/src/link/bridge.rs b/rumqttd/src/link/bridge.rs index 0c6b193be..eac2050dd 100644 --- a/rumqttd/src/link/bridge.rs +++ b/rumqttd/src/link/bridge.rs @@ -12,6 +12,7 @@ use std::{io, net::AddrParseError, time::Duration}; use tokio::{ net::TcpStream, + sync::watch, time::{sleep, sleep_until, Instant}, }; @@ -48,6 +49,7 @@ pub async fn start

( config: BridgeConfig, router_tx: Sender<(ConnectionId, Event)>, protocol: P, + mut shutdown_rx: watch::Receiver<()>, ) -> Result<(), BridgeError> where P: Protocol + Clone + Send + 'static, @@ -154,6 +156,10 @@ where // resetting timeout because tokio::select! consumes the old timeout future timeout = sleep_until(ping_time + Duration::from_secs(config.ping_delay)); } + _ = shutdown_rx.changed() => { + debug!("Shutting down bridge"); + break 'outer Ok(()); + } } } } diff --git a/rumqttd/src/link/console.rs b/rumqttd/src/link/console.rs index 9f03f8a1e..d566e1896 100644 --- a/rumqttd/src/link/console.rs +++ b/rumqttd/src/link/console.rs @@ -9,8 +9,8 @@ use axum::Json; use axum::{routing::get, Router}; use flume::Sender; use std::sync::Arc; -use tokio::net::TcpListener; -use tracing::info; +use tokio::{net::TcpListener, sync::watch}; +use tracing::{debug, info}; #[derive(Debug)] pub struct ConsoleLink { @@ -39,7 +39,7 @@ impl ConsoleLink { } #[tracing::instrument] -pub async fn start(console: Arc) { +pub async fn start(console: Arc, mut shutdown_rx: watch::Receiver<()>) { let listener = TcpListener::bind(console.config.listen.clone()) .await .unwrap(); @@ -56,7 +56,13 @@ pub async fn start(console: Arc) { .route("/logs", post(logs)) .with_state(console); - axum::serve(listener, app).await.unwrap(); + axum::serve(listener, app) + .with_graceful_shutdown(async move { + debug!("Shutting down console"); + let _ = shutdown_rx.changed().await; + }) + .await + .unwrap(); } async fn root(State(console): State>) -> impl IntoResponse { diff --git a/rumqttd/src/link/timer.rs b/rumqttd/src/link/timer.rs index fcc3bf00c..e72d48d25 100644 --- a/rumqttd/src/link/timer.rs +++ b/rumqttd/src/link/timer.rs @@ -5,7 +5,8 @@ use crate::{router::Event, MetricType}; use crate::{ConnectionId, MetricSettings}; use flume::{SendError, Sender}; use tokio::select; -use tracing::error; +use tokio::sync::watch; +use tracing::{debug, error}; #[derive(Debug, thiserror::Error)] pub enum Error { @@ -18,6 +19,7 @@ pub enum Error { pub async fn start( config: HashMap, router_tx: Sender<(ConnectionId, Event)>, + mut shutdown_rx: watch::Receiver<()>, ) { let span = tracing::info_span!("metrics_timer"); let _guard = span.enter(); @@ -42,6 +44,10 @@ pub async fn start( error!("Failed to push alerts: {e}"); } } + _ = shutdown_rx.changed() => { + debug!("Shutting down metrics timer"); + break; + } } } } diff --git a/rumqttd/src/main.rs b/rumqttd/src/main.rs index 8b22805f8..d8ac9e747 100644 --- a/rumqttd/src/main.rs +++ b/rumqttd/src/main.rs @@ -84,10 +84,15 @@ fn main() { validate_config(&configs); - // println!("{:#?}", configs); + let _guard = Broker::new(configs).start().unwrap().drop_guard(); - let mut broker = Broker::new(configs); - broker.start().unwrap(); + let (tx, rx) = flume::bounded::<()>(1); + ctrlc::set_handler(move || { + let _ = tx.send(()); + }) + .expect("Error setting Ctrl-C handler"); + + rx.recv().expect("Could not receive Ctrl-C signal"); } // Do any extra validation that needs to be done before starting the broker here. diff --git a/rumqttd/src/server/broker.rs b/rumqttd/src/server/broker.rs index 9886541c9..34ad68bdb 100644 --- a/rumqttd/src/server/broker.rs +++ b/rumqttd/src/server/broker.rs @@ -11,10 +11,10 @@ use crate::protocol::{Packet, Protocol}; use crate::server::tls::{self, TLSAcceptor}; use crate::{meters, ConnectionSettings, Meter}; use flume::{RecvError, SendError, Sender}; -use std::collections::HashMap; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::sync::{Arc, Mutex}; -use tracing::{error, field, info, warn, Instrument}; +use std::{collections::HashMap, thread::JoinHandle}; +use tracing::{debug, error, field, info, warn, Instrument}; use uuid::Uuid; #[cfg(feature = "websocket")] @@ -39,6 +39,7 @@ use crate::router::{Event, Router}; use crate::{Config, ConnectionId, ServerSettings}; use tokio::net::{TcpListener, TcpStream}; +use tokio::sync::watch; use tokio::time::error::Elapsed; use tokio::{task, time}; @@ -155,8 +156,12 @@ impl Broker { Ok((link_tx, link_rx)) } + /// Starts the MQTT broker server and spawns various components like the router, cluster, metrics, and console. + /// + /// This function sets up the necessary components for the MQTT broker server to run. The function returns a + /// [`BrokerHandler`] that can be used to gracefully shut down the server. #[tracing::instrument(skip(self))] - pub fn start(&mut self) -> Result<(), Error> { + pub fn start(self) -> Result { if self.config.v4.is_none() && self.config.v5.is_none() && (cfg!(not(feature = "websocket")) || self.config.ws.is_none()) @@ -171,16 +176,18 @@ impl Broker { // we don't know which servers (v4/v5/ws) user will spawn // so we collect handles for all of the spawned servers let mut server_thread_handles = Vec::new(); + let (shutdown_tx, shutdown_rx) = watch::channel(()); if let Some(metrics_config) = self.config.metrics.clone() { let timer_thread = thread::Builder::new().name("timer".to_owned()); let router_tx = self.router_tx.clone(); + let shutdown_rx = shutdown_rx.clone(); timer_thread.spawn(move || { let mut runtime = tokio::runtime::Builder::new_current_thread(); let runtime = runtime.enable_all().build().unwrap(); runtime.block_on(async move { - timer::start(metrics_config, router_tx).await; + timer::start(metrics_config, router_tx, shutdown_rx).await; }); })?; } @@ -189,12 +196,13 @@ impl Broker { if let Some(bridge_config) = self.config.bridge.clone() { let bridge_thread = thread::Builder::new().name(bridge_config.name.clone()); let router_tx = self.router_tx.clone(); + let shutdown_rx = shutdown_rx.clone(); bridge_thread.spawn(move || { let mut runtime = tokio::runtime::Builder::new_current_thread(); let runtime = runtime.enable_all().build().unwrap(); runtime.block_on(async move { - if let Err(e) = bridge::start(bridge_config, router_tx, V4).await { + if let Err(e) = bridge::start(bridge_config, router_tx, V4, shutdown_rx).await { error!(error=?e, "Bridge Link error"); }; }); @@ -206,13 +214,16 @@ impl Broker { for (_, config) in v4_config.clone() { let server_thread = thread::Builder::new().name(config.name.clone()); let mut server = Server::new(config, self.router_tx.clone(), V4); + let shutdown_rx = shutdown_rx.clone(); let handle = server_thread.spawn(move || { let mut runtime = tokio::runtime::Builder::new_current_thread(); let runtime = runtime.enable_all().build().unwrap(); runtime.block_on(async { - if let Err(e) = server.start(LinkType::Remote).await { + if let Err(e) = server.start(LinkType::Remote, shutdown_rx).await { error!(error=?e, "Server error - V4"); + } else { + debug!("Shutting down v4 server"); } }); })?; @@ -224,13 +235,16 @@ impl Broker { for (_, config) in v5_config.clone() { let server_thread = thread::Builder::new().name(config.name.clone()); let mut server = Server::new(config, self.router_tx.clone(), V5); + let shutdown_rx = shutdown_rx.clone(); let handle = server_thread.spawn(move || { let mut runtime = tokio::runtime::Builder::new_current_thread(); let runtime = runtime.enable_all().build().unwrap(); runtime.block_on(async { - if let Err(e) = server.start(LinkType::Remote).await { + if let Err(e) = server.start(LinkType::Remote, shutdown_rx).await { error!(error=?e, "Server error - V5"); + } else { + debug!("Shutting down v5 server"); } }); })?; @@ -249,13 +263,16 @@ impl Broker { let server_thread = thread::Builder::new().name(config.name.clone()); //TODO: Add support for V5 procotol with websockets. Registered in config or on ServerSettings let mut server = Server::new(config, self.router_tx.clone(), V4); + let shutdown_rx = shutdown_rx.clone(); let handle = server_thread.spawn(move || { let mut runtime = tokio::runtime::Builder::new_current_thread(); let runtime = runtime.enable_all().build().unwrap(); runtime.block_on(async { - if let Err(e) = server.start(LinkType::Websocket).await { + if let Err(e) = server.start(LinkType::Websocket, shutdown_rx).await { error!(error=?e, "Server error - WS"); + } else { + debug!("Shutting down websocket server"); } }); })?; @@ -280,6 +297,7 @@ impl Broker { }; let metrics_thread = thread::Builder::new().name("Metrics".to_owned()); let meter_link = self.meters().unwrap(); + let shutdown_rx = shutdown_rx.clone(); metrics_thread.spawn(move || { let builder = PrometheusBuilder::new().with_http_listener(addr); builder.install().unwrap(); @@ -301,6 +319,11 @@ impl Broker { } } + if shutdown_rx.has_changed().is_ok_and(|flag| flag) { + debug!("Shutting down metrics"); + break; + } + std::thread::sleep(Duration::from_secs(timeout)); } })?; @@ -311,23 +334,21 @@ impl Broker { let console_link = Arc::new(console_link); let console_thread = thread::Builder::new().name("Console".to_string()); + let shutdown_rx = shutdown_rx.clone(); console_thread.spawn(move || { let mut runtime = tokio::runtime::Builder::new_current_thread(); let runtime = runtime.enable_all().build().unwrap(); - runtime.block_on(console::start(console_link)); + runtime.block_on(console::start(console_link, shutdown_rx)); })?; } - // in ideal case, where server doesn't crash, join() will never resolve - // we still try to join threads so that we don't return from function - // unless everything crashes. - server_thread_handles.into_iter().for_each(|handle| { - // join() might panic in case the thread panics - // we just ignore it - let _ = handle.join(); - }); - - Ok(()) + Ok(BrokerHandler { + inner: InnerHandler { + threads: server_thread_handles, + shutdown_tx, + } + .into(), + }) } } @@ -379,7 +400,11 @@ impl Server

{ Ok((Box::new(stream), None)) } - async fn start(&mut self, link_type: LinkType) -> Result<(), Error> { + async fn start( + &mut self, + link_type: LinkType, + mut shutdown_rx: watch::Receiver<()>, + ) -> Result<(), Error> { let listener = TcpListener::bind(&self.config.listen).await?; let delay = Duration::from_millis(self.config.next_connection_delay_ms); let mut count: usize = 0; @@ -392,19 +417,33 @@ impl Server

{ ); loop { // Await new network connection. - let (stream, addr) = match listener.accept().await { - Ok((s, r)) => (s, r), - Err(e) => { - error!(error=?e, "Unable to accept socket."); - continue; + let (stream, addr) = tokio::select! { + accept = listener.accept() => { + match accept { + Ok((s, r)) => (s, r), + Err(e) => { + error!(error=?e, "Unable to accept socket."); + continue; + } + } + } + _ = shutdown_rx.changed() => { + return Ok(()); } }; - let (network, tenant_id) = match self.tls_accept(stream).await { - Ok(o) => o, - Err(e) => { - error!(error=?e, "Tls accept error"); - continue; + let (network, tenant_id) = tokio::select! { + accept = self.tls_accept(stream) => { + match accept { + Ok(o) => o, + Err(e) => { + error!(error=?e, "Tls accept error"); + continue; + } + } + } + _ = shutdown_rx.changed() => { + return Ok(()); } }; @@ -420,11 +459,18 @@ impl Server

{ match link_type { #[cfg(feature = "websocket")] LinkType::Websocket => { - let stream = match accept_hdr_async(network, WSCallback).await { - Ok(s) => Box::new(WsStream::new(s)), - Err(e) => { - error!(error=?e, "Websocket failed handshake"); - continue; + let stream = tokio::select! { + hdr_accept = accept_hdr_async(network, WSCallback) => { + match hdr_accept { + Ok(s) => Box::new(WsStream::new(s)), + Err(e) => { + error!(error=?e, "Websocket failed handshake"); + continue; + } + } + } + _ = shutdown_rx.changed() => { + return Ok(()); } }; task::spawn( @@ -461,7 +507,12 @@ impl Server

{ ), }; - time::sleep(delay).await; + tokio::select! { + _ = time::sleep(delay) => {} + _ = shutdown_rx.changed() => { + return Ok(()); + } + }; } } } @@ -627,3 +678,100 @@ async fn remote( router_tx.send((connection_id, message)).ok(); } } + +/// An internal handler that manages the shutdown process for a broker. +/// +/// The `InnerHandler` struct is responsible for coordinating the shutdown of a broker by maintaining a list of running +/// threads and a channel for sending a shutdown signal. It is used internally by the `ShutdownHandler` to manage the +/// shutdown process. +#[derive(Debug)] +struct InnerHandler { + threads: Vec>, + shutdown_tx: watch::Sender<()>, +} + +/// A struct that handles the shutdown process for a broker. +/// +/// The `ShutdownHandler` struct is responsible for coordinating the shutdown of a broker by sending a shutdown signal to +/// all running threads and waiting for them to join. +#[derive(Debug)] +pub struct BrokerHandler { + inner: Option, +} + +impl BrokerHandler { + /// Shuts down the server by sending a shutdown signal to all running threads and waiting for them to join. + /// + /// This method is responsible for coordinating the shutdown of the broker by taking ownership of the `BrokerHandler` + /// and sending a shutdown signal to all running threads. It then waits for each thread to join before returning. + /// This ensures that all broker resources are properly cleaned up and the broker can be safely shut down. + pub fn shutdown(mut self) { + if let Some(handler) = self.inner.take() { + let _ = handler.shutdown_tx.send(()); + for thread in handler.threads { + let _ = thread.join(); + } + } + } + + /// Joins all running threads associated with the `BrokerHandler`. + /// + /// This method takes ownership of the `BrokerHandler` and waits for all running threads to join. This ensures that all + /// broker resources are properly cleaned up and the broker can be safely shut down. + pub fn join(mut self) { + if let Some(handler) = self.inner.take() { + for thread in handler.threads { + let _ = thread.join(); + } + } + } + + /// Creates a `ShutdownDropGuard` that will automatically handle the shutdown process when the guard is dropped. + /// + /// The `drop_guard()` method takes ownership of the `BrokerHandler` and returns a [`ShutdownDropGuard`] that will + /// automatically send a shutdown signal to all running threads and wait for them to join when the guard is dropped. + /// This is useful for ensuring that the server is properly shut down, even in the event of an unexpected error or + /// early exit from the program. + #[inline] + pub fn drop_guard(mut self) -> ShutdownDropGuard { + ShutdownDropGuard { + inner: self.inner.take(), + } + } +} + +/// A guard that automatically handles the shutdown process for a broker when dropped. +/// +/// The `ShutdownDropGuard` struct is responsible for coordinating the shutdown of a broker by sending a shutdown signal to +/// all running threads and waiting for them to join when the guard is dropped. This ensures that the broker is properly +/// shut down, even in the event of an unexpected error or early exit from the program. +#[derive(Debug)] +pub struct ShutdownDropGuard { + inner: Option, +} + +impl ShutdownDropGuard { + /// Disarms the `ShutdownDropGuard` and returns a [`BrokerHandler`] that can be used to manually shut down the broker. + /// + /// This method takes ownership of the `ShutdownDropGuard` and returns a new [`BrokerHandler`] that contains the same + /// internal state as the `ShutdownDropGuard`. This allows the caller to take control of the shutdown process and + /// manually shut down the server when needed, rather than relying on the automatic shutdown when the `ShutdownDropGuard` + /// is dropped. + #[inline] + pub fn disarm(mut self) -> BrokerHandler { + BrokerHandler { + inner: self.inner.take(), + } + } +} + +impl Drop for ShutdownDropGuard { + fn drop(&mut self) { + if let Some(handler) = self.inner.take() { + let _ = handler.shutdown_tx.send(()); + for thread in handler.threads { + let _ = thread.join(); + } + } + } +} diff --git a/rumqttd/src/server/mod.rs b/rumqttd/src/server/mod.rs index fb151b1bf..4fcfc9a00 100644 --- a/rumqttd/src/server/mod.rs +++ b/rumqttd/src/server/mod.rs @@ -4,7 +4,7 @@ mod broker; #[cfg(any(feature = "use-rustls", feature = "use-native-tls"))] mod tls; -pub use broker::Broker; +pub use broker::{Broker, BrokerHandler, ShutdownDropGuard}; // pub trait IO: AsyncRead + AsyncWrite + Send + Sync + Unpin {} // impl IO for T {}